Remember to clear the maximal sets between functions.
[oota-llvm.git] / lib / Transforms / Scalar / GVNPRE.cpp
1 //===- GVNPRE.cpp - Eliminate redundant values and expressions ------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by the Owen Anderson and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass performs a hybrid of global value numbering and partial redundancy
11 // elimination, known as GVN-PRE.  It performs partial redundancy elimination on
12 // values, rather than lexical expressions, allowing a more comprehensive view 
13 // the optimization.  It replaces redundant values with uses of earlier 
14 // occurences of the same value.  While this is beneficial in that it eliminates
15 // unneeded computation, it also increases register pressure by creating large
16 // live ranges, and should be used with caution on platforms that are very 
17 // sensitive to register pressure.
18 //
19 //===----------------------------------------------------------------------===//
20
21 #define DEBUG_TYPE "gvnpre"
22 #include "llvm/Value.h"
23 #include "llvm/Transforms/Scalar.h"
24 #include "llvm/Instructions.h"
25 #include "llvm/Function.h"
26 #include "llvm/Analysis/Dominators.h"
27 #include "llvm/Analysis/PostDominators.h"
28 #include "llvm/ADT/DepthFirstIterator.h"
29 #include "llvm/ADT/Statistic.h"
30 #include "llvm/Support/CFG.h"
31 #include "llvm/Support/Compiler.h"
32 #include "llvm/Support/Debug.h"
33 #include <algorithm>
34 #include <deque>
35 #include <map>
36 #include <vector>
37 #include <set>
38 using namespace llvm;
39
40 //===----------------------------------------------------------------------===//
41 //                         ValueTable Class
42 //===----------------------------------------------------------------------===//
43
44 /// This class holds the mapping between values and value numbers.
45
46 namespace {
47   class VISIBILITY_HIDDEN ValueTable {
48     public:
49       struct Expression {
50         enum ExpressionOpcode { ADD, SUB, MUL, UDIV, SDIV, FDIV, UREM, SREM, 
51                               FREM, SHL, LSHR, ASHR, AND, OR, XOR, ICMPEQ, 
52                               ICMPNE, ICMPUGT, ICMPUGE, ICMPULT, ICMPULE, 
53                               ICMPSGT, ICMPSGE, ICMPSLT, ICMPSLE, FCMPOEQ, 
54                               FCMPOGT, FCMPOGE, FCMPOLT, FCMPOLE, FCMPONE, 
55                               FCMPORD, FCMPUNO, FCMPUEQ, FCMPUGT, FCMPUGE, 
56                               FCMPULT, FCMPULE, FCMPUNE };
57     
58         ExpressionOpcode opcode;
59         uint32_t leftVN;
60         uint32_t rightVN;
61       
62         bool operator< (const Expression& other) const {
63           if (opcode < other.opcode)
64             return true;
65           else if (opcode > other.opcode)
66             return false;
67           else if (leftVN < other.leftVN)
68             return true;
69           else if (leftVN > other.leftVN)
70             return false;
71           else if (rightVN < other.rightVN)
72             return true;
73           else if (rightVN > other.rightVN)
74             return false;
75           else
76             return false;
77         }
78       };
79     
80     private:
81       std::map<Value*, uint32_t> valueNumbering;
82       std::map<Expression, uint32_t> expressionNumbering;
83   
84       std::set<Expression> maximalExpressions;
85       std::set<Value*> maximalValues;
86   
87       uint32_t nextValueNumber;
88     
89       Expression::ExpressionOpcode getOpcode(BinaryOperator* BO);
90       Expression::ExpressionOpcode getOpcode(CmpInst* C);
91     public:
92       ValueTable() { nextValueNumber = 1; }
93       uint32_t lookup_or_add(Value* V);
94       uint32_t lookup(Value* V);
95       void add(Value* V, uint32_t num);
96       void clear();
97       std::set<Expression>& getMaximalExpressions() {
98         return maximalExpressions;
99       
100       }
101       std::set<Value*>& getMaximalValues() { return maximalValues; }
102       Expression create_expression(BinaryOperator* BO);
103       Expression create_expression(CmpInst* C);
104   };
105 }
106
107 ValueTable::Expression::ExpressionOpcode 
108                                      ValueTable::getOpcode(BinaryOperator* BO) {
109   switch(BO->getOpcode()) {
110     case Instruction::Add:
111       return Expression::ADD;
112     case Instruction::Sub:
113       return Expression::SUB;
114     case Instruction::Mul:
115       return Expression::MUL;
116     case Instruction::UDiv:
117       return Expression::UDIV;
118     case Instruction::SDiv:
119       return Expression::SDIV;
120     case Instruction::FDiv:
121       return Expression::FDIV;
122     case Instruction::URem:
123       return Expression::UREM;
124     case Instruction::SRem:
125       return Expression::SREM;
126     case Instruction::FRem:
127       return Expression::FREM;
128     case Instruction::Shl:
129       return Expression::SHL;
130     case Instruction::LShr:
131       return Expression::LSHR;
132     case Instruction::AShr:
133       return Expression::ASHR;
134     case Instruction::And:
135       return Expression::AND;
136     case Instruction::Or:
137       return Expression::OR;
138     case Instruction::Xor:
139       return Expression::XOR;
140     
141     // THIS SHOULD NEVER HAPPEN
142     default:
143       assert(0 && "Binary operator with unknown opcode?");
144       return Expression::ADD;
145   }
146 }
147
148 ValueTable::Expression::ExpressionOpcode ValueTable::getOpcode(CmpInst* C) {
149   if (C->getOpcode() == Instruction::ICmp) {
150     switch (C->getPredicate()) {
151       case ICmpInst::ICMP_EQ:
152         return Expression::ICMPEQ;
153       case ICmpInst::ICMP_NE:
154         return Expression::ICMPNE;
155       case ICmpInst::ICMP_UGT:
156         return Expression::ICMPUGT;
157       case ICmpInst::ICMP_UGE:
158         return Expression::ICMPUGE;
159       case ICmpInst::ICMP_ULT:
160         return Expression::ICMPULT;
161       case ICmpInst::ICMP_ULE:
162         return Expression::ICMPULE;
163       case ICmpInst::ICMP_SGT:
164         return Expression::ICMPSGT;
165       case ICmpInst::ICMP_SGE:
166         return Expression::ICMPSGE;
167       case ICmpInst::ICMP_SLT:
168         return Expression::ICMPSLT;
169       case ICmpInst::ICMP_SLE:
170         return Expression::ICMPSLE;
171       
172       // THIS SHOULD NEVER HAPPEN
173       default:
174         assert(0 && "Comparison with unknown predicate?");
175         return Expression::ICMPEQ;
176     }
177   } else {
178     switch (C->getPredicate()) {
179       case FCmpInst::FCMP_OEQ:
180         return Expression::FCMPOEQ;
181       case FCmpInst::FCMP_OGT:
182         return Expression::FCMPOGT;
183       case FCmpInst::FCMP_OGE:
184         return Expression::FCMPOGE;
185       case FCmpInst::FCMP_OLT:
186         return Expression::FCMPOLT;
187       case FCmpInst::FCMP_OLE:
188         return Expression::FCMPOLE;
189       case FCmpInst::FCMP_ONE:
190         return Expression::FCMPONE;
191       case FCmpInst::FCMP_ORD:
192         return Expression::FCMPORD;
193       case FCmpInst::FCMP_UNO:
194         return Expression::FCMPUNO;
195       case FCmpInst::FCMP_UEQ:
196         return Expression::FCMPUEQ;
197       case FCmpInst::FCMP_UGT:
198         return Expression::FCMPUGT;
199       case FCmpInst::FCMP_UGE:
200         return Expression::FCMPUGE;
201       case FCmpInst::FCMP_ULT:
202         return Expression::FCMPULT;
203       case FCmpInst::FCMP_ULE:
204         return Expression::FCMPULE;
205       case FCmpInst::FCMP_UNE:
206         return Expression::FCMPUNE;
207       
208       // THIS SHOULD NEVER HAPPEN
209       default:
210         assert(0 && "Comparison with unknown predicate?");
211         return Expression::FCMPOEQ;
212     }
213   }
214 }
215
216 uint32_t ValueTable::lookup_or_add(Value* V) {
217   maximalValues.insert(V);
218
219   std::map<Value*, uint32_t>::iterator VI = valueNumbering.find(V);
220   if (VI != valueNumbering.end())
221     return VI->second;
222   
223   
224   if (BinaryOperator* BO = dyn_cast<BinaryOperator>(V)) {
225     Expression e = create_expression(BO);
226     
227     std::map<Expression, uint32_t>::iterator EI = expressionNumbering.find(e);
228     if (EI != expressionNumbering.end()) {
229       valueNumbering.insert(std::make_pair(V, EI->second));
230       return EI->second;
231     } else {
232       expressionNumbering.insert(std::make_pair(e, nextValueNumber));
233       valueNumbering.insert(std::make_pair(V, nextValueNumber));
234       
235       return nextValueNumber++;
236     }
237   } else if (CmpInst* C = dyn_cast<CmpInst>(V)) {
238     Expression e = create_expression(C);
239     
240     std::map<Expression, uint32_t>::iterator EI = expressionNumbering.find(e);
241     if (EI != expressionNumbering.end()) {
242       valueNumbering.insert(std::make_pair(V, EI->second));
243       return EI->second;
244     } else {
245       expressionNumbering.insert(std::make_pair(e, nextValueNumber));
246       valueNumbering.insert(std::make_pair(V, nextValueNumber));
247       
248       return nextValueNumber++;
249     }
250   } else {
251     valueNumbering.insert(std::make_pair(V, nextValueNumber));
252     return nextValueNumber++;
253   }
254 }
255
256 uint32_t ValueTable::lookup(Value* V) {
257   std::map<Value*, uint32_t>::iterator VI = valueNumbering.find(V);
258   if (VI != valueNumbering.end())
259     return VI->second;
260   else
261     assert(0 && "Value not numbered?");
262   
263   return 0;
264 }
265
266 void ValueTable::add(Value* V, uint32_t num) {
267   std::map<Value*, uint32_t>::iterator VI = valueNumbering.find(V);
268   if (VI != valueNumbering.end())
269     valueNumbering.erase(VI);
270   valueNumbering.insert(std::make_pair(V, num));
271 }
272
273 ValueTable::Expression ValueTable::create_expression(BinaryOperator* BO) {
274   Expression e;
275     
276   e.leftVN = lookup_or_add(BO->getOperand(0));
277   e.rightVN = lookup_or_add(BO->getOperand(1));
278   e.opcode = getOpcode(BO);
279   
280   maximalExpressions.insert(e);
281   
282   return e;
283 }
284
285 ValueTable::Expression ValueTable::create_expression(CmpInst* C) {
286   Expression e;
287     
288   e.leftVN = lookup_or_add(C->getOperand(0));
289   e.rightVN = lookup_or_add(C->getOperand(1));
290   e.opcode = getOpcode(C);
291   
292   maximalExpressions.insert(e);
293   
294   return e;
295 }
296
297 void ValueTable::clear() {
298   valueNumbering.clear();
299   expressionNumbering.clear();
300   maximalExpressions.clear();
301   maximalValues.clear();
302   nextValueNumber = 1;
303 }
304
305 namespace {
306
307   class VISIBILITY_HIDDEN GVNPRE : public FunctionPass {
308     bool runOnFunction(Function &F);
309   public:
310     static char ID; // Pass identification, replacement for typeid
311     GVNPRE() : FunctionPass((intptr_t)&ID) { }
312
313   private:
314     ValueTable VN;
315     std::vector<Instruction*> createdExpressions;
316     
317     std::map<BasicBlock*, std::set<Value*> > availableOut;
318     std::map<BasicBlock*, std::set<Value*> > anticipatedIn;
319     std::map<User*, bool> invokeDep;
320     
321     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
322       AU.setPreservesCFG();
323       AU.addRequired<DominatorTree>();
324       AU.addRequired<PostDominatorTree>();
325     }
326   
327     // Helper fuctions
328     // FIXME: eliminate or document these better
329     void dump(const std::set<Value*>& s) const;
330     void dump_unique(const std::set<Value*>& s) const;
331     void clean(std::set<Value*>& set);
332     Value* find_leader(std::set<Value*>& vals,
333                        uint32_t v);
334     Value* phi_translate(Value* V, BasicBlock* pred, BasicBlock* succ);
335     void phi_translate_set(std::set<Value*>& anticIn, BasicBlock* pred,
336                            BasicBlock* succ, std::set<Value*>& out);
337     
338     void topo_sort(std::set<Value*>& set,
339                    std::vector<Value*>& vec);
340     
341     // For a given block, calculate the generated expressions, temporaries,
342     // and the AVAIL_OUT set
343     void cleanup();
344     void elimination();
345     
346     void val_insert(std::set<Value*>& s, Value* v);
347     void val_replace(std::set<Value*>& s, Value* v);
348     bool dependsOnInvoke(Value* V);
349   
350   };
351   
352   char GVNPRE::ID = 0;
353   
354 }
355
356 FunctionPass *llvm::createGVNPREPass() { return new GVNPRE(); }
357
358 RegisterPass<GVNPRE> X("gvnpre",
359                        "Global Value Numbering/Partial Redundancy Elimination");
360
361
362 STATISTIC(NumInsertedVals, "Number of values inserted");
363 STATISTIC(NumInsertedPhis, "Number of PHI nodes inserted");
364 STATISTIC(NumEliminated, "Number of redundant instructions eliminated");
365
366 Value* GVNPRE::find_leader(std::set<Value*>& vals, uint32_t v) {
367   for (std::set<Value*>::iterator I = vals.begin(), E = vals.end();
368        I != E; ++I)
369     if (v == VN.lookup(*I))
370       return *I;
371   
372   return 0;
373 }
374
375 void GVNPRE::val_insert(std::set<Value*>& s, Value* v) {
376   uint32_t num = VN.lookup(v);
377   Value* leader = find_leader(s, num);
378   if (leader == 0)
379     s.insert(v);
380 }
381
382 void GVNPRE::val_replace(std::set<Value*>& s, Value* v) {
383   uint32_t num = VN.lookup(v);
384   Value* leader = find_leader(s, num);
385   while (leader != 0) {
386     s.erase(leader);
387     leader = find_leader(s, num);
388   }
389   s.insert(v);
390 }
391
392 Value* GVNPRE::phi_translate(Value* V, BasicBlock* pred, BasicBlock* succ) {
393   if (V == 0)
394     return 0;
395   
396   if (BinaryOperator* BO = dyn_cast<BinaryOperator>(V)) {
397     Value* newOp1 = 0;
398     if (isa<Instruction>(BO->getOperand(0)))
399       newOp1 = phi_translate(find_leader(anticipatedIn[succ],         
400                                          VN.lookup(BO->getOperand(0))),
401                              pred, succ);
402     else
403       newOp1 = BO->getOperand(0);
404     
405     if (newOp1 == 0)
406       return 0;
407     
408     Value* newOp2 = 0;
409     if (isa<Instruction>(BO->getOperand(1)))
410       newOp2 = phi_translate(find_leader(anticipatedIn[succ],         
411                                          VN.lookup(BO->getOperand(1))),
412                              pred, succ);
413     else
414       newOp2 = BO->getOperand(1);
415     
416     if (newOp2 == 0)
417       return 0;
418     
419     if (newOp1 != BO->getOperand(0) || newOp2 != BO->getOperand(1)) {
420       Instruction* newVal = BinaryOperator::create(BO->getOpcode(),
421                                              newOp1, newOp2,
422                                              BO->getName()+".gvnpre");
423       
424       uint32_t v = VN.lookup_or_add(newVal);
425       
426       Value* leader = find_leader(availableOut[pred], v);
427       if (leader == 0) {
428         createdExpressions.push_back(newVal);
429         return newVal;
430       } else {
431         delete newVal;
432         return leader;
433       }
434     }
435   } else if (PHINode* P = dyn_cast<PHINode>(V)) {
436     if (P->getParent() == succ)
437       return P->getIncomingValueForBlock(pred);
438   } else if (CmpInst* C = dyn_cast<CmpInst>(V)) {
439     Value* newOp1 = 0;
440     if (isa<Instruction>(C->getOperand(0)))
441       newOp1 = phi_translate(find_leader(anticipatedIn[succ],         
442                                          VN.lookup(C->getOperand(0))),
443                              pred, succ);
444     else
445       newOp1 = C->getOperand(0);
446     
447     if (newOp1 == 0)
448       return 0;
449     
450     Value* newOp2 = 0;
451     if (isa<Instruction>(C->getOperand(1)))
452       newOp2 = phi_translate(find_leader(anticipatedIn[succ],         
453                                          VN.lookup(C->getOperand(1))),
454                              pred, succ);
455     else
456       newOp2 = C->getOperand(1);
457       
458     if (newOp2 == 0)
459       return 0;
460     
461     if (newOp1 != C->getOperand(0) || newOp2 != C->getOperand(1)) {
462       Instruction* newVal = CmpInst::create(C->getOpcode(),
463                                             C->getPredicate(),
464                                              newOp1, newOp2,
465                                              C->getName()+".gvnpre");
466       
467       uint32_t v = VN.lookup_or_add(newVal);
468         
469       Value* leader = find_leader(availableOut[pred], v);
470       if (leader == 0) {
471         createdExpressions.push_back(newVal);
472         return newVal;
473       } else {
474         delete newVal;
475         return leader;
476       }
477     }
478   }
479   
480   return V;
481 }
482
483 void GVNPRE::phi_translate_set(std::set<Value*>& anticIn,
484                               BasicBlock* pred, BasicBlock* succ,
485                               std::set<Value*>& out) {
486   for (std::set<Value*>::iterator I = anticIn.begin(),
487        E = anticIn.end(); I != E; ++I) {
488     Value* V = phi_translate(*I, pred, succ);
489     if (V != 0)
490       out.insert(V);
491   }
492 }
493
494 bool GVNPRE::dependsOnInvoke(Value* V) {
495   if (!isa<User>(V))
496     return false;
497     
498   User* U = cast<User>(V);
499   std::map<User*, bool>::iterator I = invokeDep.find(U);
500   if (I != invokeDep.end())
501     return I->second;
502   
503   std::vector<Value*> worklist(U->op_begin(), U->op_end());
504   std::set<Value*> visited;
505   
506   while (!worklist.empty()) {
507     Value* current = worklist.back();
508     worklist.pop_back();
509     visited.insert(current);
510     
511     if (!isa<User>(current))
512       continue;
513     else if (isa<InvokeInst>(current))
514       return true;
515     
516     User* curr = cast<User>(current);
517     std::map<User*, bool>::iterator CI = invokeDep.find(curr);
518     if (CI != invokeDep.end()) {
519       if (CI->second)
520         return true;
521     } else {
522       for (unsigned i = 0; i < curr->getNumOperands(); ++i)
523         if (visited.find(curr->getOperand(i)) == visited.end())
524           worklist.push_back(curr->getOperand(i));
525     }
526   }
527   
528   return false;
529 }
530
531 // Remove all expressions whose operands are not themselves in the set
532 void GVNPRE::clean(std::set<Value*>& set) {
533   std::vector<Value*> worklist;
534   topo_sort(set, worklist);
535   
536   for (unsigned i = 0; i < worklist.size(); ++i) {
537     Value* v = worklist[i];
538     
539     if (BinaryOperator* BO = dyn_cast<BinaryOperator>(v)) {   
540       bool lhsValid = !isa<Instruction>(BO->getOperand(0));
541       if (!lhsValid)
542         for (std::set<Value*>::iterator I = set.begin(), E = set.end();
543              I != E; ++I)
544           if (VN.lookup(*I) == VN.lookup(BO->getOperand(0))) {
545             lhsValid = true;
546             break;
547           }
548           
549       // Check for dependency on invoke insts
550       // NOTE: This check is expensive, so don't do it if we
551       // don't have to
552       if (lhsValid)
553         lhsValid = !dependsOnInvoke(BO->getOperand(0));
554     
555       bool rhsValid = !isa<Instruction>(BO->getOperand(1));
556       if (!rhsValid)
557         for (std::set<Value*>::iterator I = set.begin(), E = set.end();
558              I != E; ++I)
559           if (VN.lookup(*I) == VN.lookup(BO->getOperand(1))) {
560             rhsValid = true;
561             break;
562           }
563       
564       // Check for dependency on invoke insts
565       // NOTE: This check is expensive, so don't do it if we
566       // don't have to
567       if (rhsValid)
568         rhsValid = !dependsOnInvoke(BO->getOperand(1));
569       
570       if (!lhsValid || !rhsValid)
571         set.erase(BO);
572     } else if (CmpInst* C = dyn_cast<CmpInst>(v)) {
573       bool lhsValid = !isa<Instruction>(C->getOperand(0));
574       if (!lhsValid)
575         for (std::set<Value*>::iterator I = set.begin(), E = set.end();
576              I != E; ++I)
577           if (VN.lookup(*I) == VN.lookup(C->getOperand(0))) {
578             lhsValid = true;
579             break;
580           }
581       lhsValid &= !dependsOnInvoke(C->getOperand(0));
582       
583       bool rhsValid = !isa<Instruction>(C->getOperand(1));
584       if (!rhsValid)
585       for (std::set<Value*>::iterator I = set.begin(), E = set.end();
586            I != E; ++I)
587         if (VN.lookup(*I) == VN.lookup(C->getOperand(1))) {
588           rhsValid = true;
589           break;
590         }
591       rhsValid &= !dependsOnInvoke(C->getOperand(1));
592     
593       if (!lhsValid || !rhsValid)
594         set.erase(C);
595     }
596   }
597 }
598
599 void GVNPRE::topo_sort(std::set<Value*>& set,
600                        std::vector<Value*>& vec) {
601   std::set<Value*> toErase;
602   for (std::set<Value*>::iterator I = set.begin(), E = set.end();
603        I != E; ++I) {
604     if (BinaryOperator* BO = dyn_cast<BinaryOperator>(*I))
605       for (std::set<Value*>::iterator SI = set.begin(); SI != E; ++SI) {
606         if (VN.lookup(BO->getOperand(0)) == VN.lookup(*SI) ||
607             VN.lookup(BO->getOperand(1)) == VN.lookup(*SI)) {
608           toErase.insert(*SI);
609         }
610       }
611     else if (CmpInst* C = dyn_cast<CmpInst>(*I))
612       for (std::set<Value*>::iterator SI = set.begin(); SI != E; ++SI) {
613         if (VN.lookup(C->getOperand(0)) == VN.lookup(*SI) ||
614             VN.lookup(C->getOperand(1)) == VN.lookup(*SI)) {
615           toErase.insert(*SI);
616         }
617       }
618   }
619   
620   std::vector<Value*> Q;
621   for (std::set<Value*>::iterator I = set.begin(), E = set.end();
622        I != E; ++I) {
623     if (toErase.find(*I) == toErase.end())
624       Q.push_back(*I);
625   }
626   
627   std::set<Value*> visited;
628   while (!Q.empty()) {
629     Value* e = Q.back();
630   
631     if (BinaryOperator* BO = dyn_cast<BinaryOperator>(e)) {
632       Value* l = find_leader(set, VN.lookup(BO->getOperand(0)));
633       Value* r = find_leader(set, VN.lookup(BO->getOperand(1)));
634       
635       if (l != 0 && isa<Instruction>(l) &&
636           visited.find(l) == visited.end())
637         Q.push_back(l);
638       else if (r != 0 && isa<Instruction>(r) &&
639                visited.find(r) == visited.end())
640         Q.push_back(r);
641       else {
642         vec.push_back(e);
643         visited.insert(e);
644         Q.pop_back();
645       }
646     } else if (CmpInst* C = dyn_cast<CmpInst>(e)) {
647       Value* l = find_leader(set, VN.lookup(C->getOperand(0)));
648       Value* r = find_leader(set, VN.lookup(C->getOperand(1)));
649       
650       if (l != 0 && isa<Instruction>(l) &&
651           visited.find(l) == visited.end())
652         Q.push_back(l);
653       else if (r != 0 && isa<Instruction>(r) &&
654                visited.find(r) == visited.end())
655         Q.push_back(r);
656       else {
657         vec.push_back(e);
658         visited.insert(e);
659         Q.pop_back();
660       }
661     } else {
662       visited.insert(e);
663       vec.push_back(e);
664       Q.pop_back();
665     }
666   }
667 }
668
669
670 void GVNPRE::dump(const std::set<Value*>& s) const {
671   DOUT << "{ ";
672   for (std::set<Value*>::iterator I = s.begin(), E = s.end();
673        I != E; ++I) {
674     DEBUG((*I)->dump());
675   }
676   DOUT << "}\n\n";
677 }
678
679 void GVNPRE::dump_unique(const std::set<Value*>& s) const {
680   DOUT << "{ ";
681   for (std::set<Value*>::iterator I = s.begin(), E = s.end();
682        I != E; ++I) {
683     DEBUG((*I)->dump());
684   }
685   DOUT << "}\n\n";
686 }
687
688 void GVNPRE::elimination() {
689   DOUT << "\n\nPhase 3: Elimination\n\n";
690   
691   std::vector<std::pair<Instruction*, Value*> > replace;
692   std::vector<Instruction*> erase;
693   
694   DominatorTree& DT = getAnalysis<DominatorTree>();
695   
696   for (df_iterator<DomTreeNode*> DI = df_begin(DT.getRootNode()),
697          E = df_end(DT.getRootNode()); DI != E; ++DI) {
698     BasicBlock* BB = DI->getBlock();
699     
700     DOUT << "Block: " << BB->getName() << "\n";
701     dump_unique(availableOut[BB]);
702     DOUT << "\n\n";
703     
704     for (BasicBlock::iterator BI = BB->begin(), BE = BB->end();
705          BI != BE; ++BI) {
706
707       if (isa<BinaryOperator>(BI) || isa<CmpInst>(BI)) {
708          Value *leader = find_leader(availableOut[BB], VN.lookup(BI));
709   
710         if (leader != 0)
711           if (Instruction* Instr = dyn_cast<Instruction>(leader))
712             if (Instr->getParent() != 0 && Instr != BI) {
713               replace.push_back(std::make_pair(BI, leader));
714               erase.push_back(BI);
715               ++NumEliminated;
716             }
717       }
718     }
719   }
720   
721   while (!replace.empty()) {
722     std::pair<Instruction*, Value*> rep = replace.back();
723     replace.pop_back();
724     rep.first->replaceAllUsesWith(rep.second);
725   }
726     
727   for (std::vector<Instruction*>::iterator I = erase.begin(), E = erase.end();
728        I != E; ++I)
729      (*I)->eraseFromParent();
730 }
731
732
733 void GVNPRE::cleanup() {
734   while (!createdExpressions.empty()) {
735     Instruction* I = createdExpressions.back();
736     createdExpressions.pop_back();
737     
738     delete I;
739   }
740 }
741
742 bool GVNPRE::runOnFunction(Function &F) {
743   VN.clear();
744   createdExpressions.clear();
745   availableOut.clear();
746   anticipatedIn.clear();
747   invokeDep.clear();
748
749   std::map<BasicBlock*, std::set<Value*> > generatedExpressions;
750   std::map<BasicBlock*, std::set<PHINode*> > generatedPhis;
751   std::map<BasicBlock*, std::set<Value*> > generatedTemporaries;
752   
753   
754   DominatorTree &DT = getAnalysis<DominatorTree>();   
755   
756   // Phase 1: BuildSets
757   
758   // Phase 1, Part 1: calculate AVAIL_OUT
759   
760   // Top-down walk of the dominator tree
761   for (df_iterator<DomTreeNode*> DI = df_begin(DT.getRootNode()),
762          E = df_end(DT.getRootNode()); DI != E; ++DI) {
763     
764     // Get the sets to update for this block
765     std::set<Value*>& currExps = generatedExpressions[DI->getBlock()];
766     std::set<PHINode*>& currPhis = generatedPhis[DI->getBlock()];
767     std::set<Value*>& currTemps = generatedTemporaries[DI->getBlock()];
768     std::set<Value*>& currAvail = availableOut[DI->getBlock()];     
769     
770     BasicBlock* BB = DI->getBlock();
771   
772     // A block inherits AVAIL_OUT from its dominator
773     if (DI->getIDom() != 0)
774     currAvail.insert(availableOut[DI->getIDom()->getBlock()].begin(),
775                      availableOut[DI->getIDom()->getBlock()].end());
776     
777     
778     for (BasicBlock::iterator BI = BB->begin(), BE = BB->end();
779          BI != BE; ++BI) {
780        
781       // Handle PHI nodes...
782       if (PHINode* p = dyn_cast<PHINode>(BI)) {
783         VN.lookup_or_add(p);
784         currPhis.insert(p);
785     
786       // Handle binary ops...
787       } else if (BinaryOperator* BO = dyn_cast<BinaryOperator>(BI)) {
788         Value* leftValue = BO->getOperand(0);
789         Value* rightValue = BO->getOperand(1);
790       
791         VN.lookup_or_add(BO);
792       
793         if (isa<Instruction>(leftValue))
794           val_insert(currExps, leftValue);
795         if (isa<Instruction>(rightValue))
796           val_insert(currExps, rightValue);
797         val_insert(currExps, BO);
798       
799       // Handle cmp ops...
800       } else if (CmpInst* C = dyn_cast<CmpInst>(BI)) {
801         Value* leftValue = C->getOperand(0);
802         Value* rightValue = C->getOperand(1);
803       
804         VN.lookup_or_add(C);
805       
806         if (isa<Instruction>(leftValue))
807           val_insert(currExps, leftValue);
808         if (isa<Instruction>(rightValue))
809           val_insert(currExps, rightValue);
810         val_insert(currExps, C);
811       
812       // Handle unsupported ops
813       } else if (!BI->isTerminator()){
814         VN.lookup_or_add(BI);
815         currTemps.insert(BI);
816       }
817     
818       if (!BI->isTerminator())
819         val_insert(currAvail, BI);
820     }
821   }
822   
823   DOUT << "Maximal Set: ";
824   dump_unique(VN.getMaximalValues());
825   DOUT << "\n";
826   
827   // If function has no exit blocks, only perform GVN
828   PostDominatorTree &PDT = getAnalysis<PostDominatorTree>();
829   if (PDT[&F.getEntryBlock()] == 0) {
830     elimination();
831     cleanup();
832     
833     return true;
834   }
835   
836   
837   // Phase 1, Part 2: calculate ANTIC_IN
838   
839   std::set<BasicBlock*> visited;
840   
841   bool changed = true;
842   unsigned iterations = 0;
843   while (changed) {
844     changed = false;
845     std::set<Value*> anticOut;
846     
847     // Top-down walk of the postdominator tree
848     for (df_iterator<DomTreeNode*> PDI = 
849          df_begin(PDT.getRootNode()), E = df_end(PDT.getRootNode());
850          PDI != E; ++PDI) {
851       BasicBlock* BB = PDI->getBlock();
852       if (BB == 0)
853         continue;
854       
855       DOUT << "Block: " << BB->getName() << "\n";
856       DOUT << "TMP_GEN: ";
857       dump(generatedTemporaries[BB]);
858       DOUT << "\n";
859     
860       DOUT << "EXP_GEN: ";
861       dump_unique(generatedExpressions[BB]);
862       visited.insert(BB);
863       
864       std::set<Value*>& anticIn = anticipatedIn[BB];
865       std::set<Value*> old (anticIn.begin(), anticIn.end());
866       
867       if (BB->getTerminator()->getNumSuccessors() == 1) {
868          if (visited.find(BB->getTerminator()->getSuccessor(0)) == 
869              visited.end())
870            phi_translate_set(VN.getMaximalValues(), BB,    
871                              BB->getTerminator()->getSuccessor(0),
872                              anticOut);
873          else
874           phi_translate_set(anticipatedIn[BB->getTerminator()->getSuccessor(0)],
875                             BB,  BB->getTerminator()->getSuccessor(0), 
876                             anticOut);
877       } else if (BB->getTerminator()->getNumSuccessors() > 1) {
878         BasicBlock* first = BB->getTerminator()->getSuccessor(0);
879         anticOut.insert(anticipatedIn[first].begin(),
880                         anticipatedIn[first].end());
881         for (unsigned i = 1; i < BB->getTerminator()->getNumSuccessors(); ++i) {
882           BasicBlock* currSucc = BB->getTerminator()->getSuccessor(i);
883           std::set<Value*>& succAnticIn = anticipatedIn[currSucc];
884           
885           std::set<Value*> temp;
886           std::insert_iterator<std::set<Value*> >  temp_ins(temp, 
887                                                             temp.begin());
888           std::set_intersection(anticOut.begin(), anticOut.end(),
889                                 succAnticIn.begin(), succAnticIn.end(),
890                                 temp_ins);
891           
892           anticOut.clear();
893           anticOut.insert(temp.begin(), temp.end());
894         }
895       }
896       
897       DOUT << "ANTIC_OUT: ";
898       dump_unique(anticOut);
899       DOUT << "\n";
900       
901       std::set<Value*> S;
902       std::insert_iterator<std::set<Value*> >  s_ins(S, S.begin());
903       std::set_difference(anticOut.begin(), anticOut.end(),
904                      generatedTemporaries[BB].begin(),
905                      generatedTemporaries[BB].end(),
906                      s_ins);
907       
908       anticIn.clear();
909       std::insert_iterator<std::set<Value*> >  ai_ins(anticIn, anticIn.begin());
910       std::set_difference(generatedExpressions[BB].begin(),
911                      generatedExpressions[BB].end(),
912                      generatedTemporaries[BB].begin(),
913                      generatedTemporaries[BB].end(),
914                      ai_ins);
915       
916       for (std::set<Value*>::iterator I = S.begin(), E = S.end();
917            I != E; ++I) {
918         if (find_leader(anticIn, VN.lookup(*I)) == 0)
919           val_insert(anticIn, *I);
920       }
921       
922       clean(anticIn);
923       
924       DOUT << "ANTIC_IN: ";
925       dump_unique(anticIn);
926       DOUT << "\n";
927       
928       if (old.size() != anticIn.size())
929         changed = true;
930       
931       anticOut.clear();
932     }
933     
934     iterations++;
935   }
936   
937   DOUT << "Iterations: " << iterations << "\n";
938   
939   for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) {
940     DOUT << "Name: " << I->getName().c_str() << "\n";
941     
942     DOUT << "TMP_GEN: ";
943     dump(generatedTemporaries[I]);
944     DOUT << "\n";
945     
946     DOUT << "EXP_GEN: ";
947     dump_unique(generatedExpressions[I]);
948     DOUT << "\n";
949     
950     DOUT << "ANTIC_IN: ";
951     dump_unique(anticipatedIn[I]);
952     DOUT << "\n";
953     
954     DOUT << "AVAIL_OUT: ";
955     dump_unique(availableOut[I]);
956     DOUT << "\n";
957   }
958   
959   // Phase 2: Insert
960   DOUT<< "\nPhase 2: Insertion\n";
961   
962   std::map<BasicBlock*, std::set<Value*> > new_sets;
963   unsigned i_iterations = 0;
964   bool new_stuff = true;
965   while (new_stuff) {
966     new_stuff = false;
967     DOUT << "Iteration: " << i_iterations << "\n\n";
968     for (df_iterator<DomTreeNode*> DI = df_begin(DT.getRootNode()),
969          E = df_end(DT.getRootNode()); DI != E; ++DI) {
970       BasicBlock* BB = DI->getBlock();
971       
972       if (BB == 0)
973         continue;
974       
975       std::set<Value*>& new_set = new_sets[BB];
976       std::set<Value*>& availOut = availableOut[BB];
977       std::set<Value*>& anticIn = anticipatedIn[BB];
978       
979       new_set.clear();
980       
981       // Replace leaders with leaders inherited from dominator
982       if (DI->getIDom() != 0) {
983         std::set<Value*>& dom_set = new_sets[DI->getIDom()->getBlock()];
984         for (std::set<Value*>::iterator I = dom_set.begin(),
985              E = dom_set.end(); I != E; ++I) {
986           new_set.insert(*I);
987           val_replace(availOut, *I);
988         }
989       }
990       
991       // If there is more than one predecessor...
992       if (pred_begin(BB) != pred_end(BB) && ++pred_begin(BB) != pred_end(BB)) {
993         std::vector<Value*> workList;
994         topo_sort(anticIn, workList);
995         
996         DOUT << "Merge Block: " << BB->getName() << "\n";
997         DOUT << "ANTIC_IN: ";
998         dump_unique(anticIn);
999         DOUT << "\n";
1000         
1001         for (unsigned i = 0; i < workList.size(); ++i) {
1002           Value* e = workList[i];
1003           
1004           if (isa<BinaryOperator>(e) || isa<CmpInst>(e)) {
1005             if (find_leader(availableOut[DI->getIDom()->getBlock()], VN.lookup(e)) != 0)
1006               continue;
1007             
1008             std::map<BasicBlock*, Value*> avail;
1009             bool by_some = false;
1010             int num_avail = 0;
1011             
1012             for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB); PI != PE;
1013                  ++PI) {
1014               Value *e2 = phi_translate(e, *PI, BB);
1015               Value *e3 = find_leader(availableOut[*PI], VN.lookup(e2));
1016               
1017               if (e3 == 0) {
1018                 std::map<BasicBlock*, Value*>::iterator av = avail.find(*PI);
1019                 if (av != avail.end())
1020                   avail.erase(av);
1021                 avail.insert(std::make_pair(*PI, e2));
1022               } else {
1023                 std::map<BasicBlock*, Value*>::iterator av = avail.find(*PI);
1024                 if (av != avail.end())
1025                   avail.erase(av);
1026                 avail.insert(std::make_pair(*PI, e3));
1027                 
1028                 by_some = true;
1029                 num_avail++;
1030               }
1031             }
1032             
1033             if (by_some &&
1034                 num_avail < std::distance(pred_begin(BB), pred_end(BB))) {
1035               DOUT << "Processing Value: ";
1036               DEBUG(e->dump());
1037               DOUT << "\n\n";
1038             
1039               for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB);
1040                    PI != PE; ++PI) {
1041                 Value* e2 = avail[*PI];
1042                 if (!find_leader(availableOut[*PI], VN.lookup(e2))) {
1043                   User* U = cast<User>(e2);
1044                 
1045                   Value* s1 = 0;
1046                   if (isa<BinaryOperator>(U->getOperand(0)) ||
1047                       isa<CmpInst>(U->getOperand(0)))
1048                     s1 = find_leader(availableOut[*PI], VN.lookup(U->getOperand(0)));
1049                   else
1050                     s1 = U->getOperand(0);
1051                   
1052                   Value* s2 = 0;
1053                   if (isa<BinaryOperator>(U->getOperand(1)) ||
1054                       isa<CmpInst>(U->getOperand(1)))
1055                     s2 = find_leader(availableOut[*PI], VN.lookup(U->getOperand(1)));
1056                   else
1057                     s2 = U->getOperand(1);
1058                   
1059                   Value* newVal = 0;
1060                   if (BinaryOperator* BO = dyn_cast<BinaryOperator>(U))
1061                     newVal = BinaryOperator::create(BO->getOpcode(),
1062                                              s1, s2,
1063                                              BO->getName()+".gvnpre",
1064                                              (*PI)->getTerminator());
1065                   else if (CmpInst* C = dyn_cast<CmpInst>(U))
1066                     newVal = CmpInst::create(C->getOpcode(),
1067                                              C->getPredicate(),
1068                                              s1, s2,
1069                                              C->getName()+".gvnpre",
1070                                              (*PI)->getTerminator());
1071                   
1072                   VN.add(newVal, VN.lookup(U));
1073                   
1074                   std::set<Value*>& predAvail = availableOut[*PI];
1075                   val_replace(predAvail, newVal);
1076                   
1077                   DOUT << "Creating value: " << std::hex << newVal << std::dec << "\n";
1078                   
1079                   std::map<BasicBlock*, Value*>::iterator av = avail.find(*PI);
1080                   if (av != avail.end())
1081                     avail.erase(av);
1082                   avail.insert(std::make_pair(*PI, newVal));
1083                   
1084                   ++NumInsertedVals;
1085                 }
1086               }
1087               
1088               PHINode* p = 0;
1089               
1090               for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB);
1091                    PI != PE; ++PI) {
1092                 if (p == 0)
1093                   p = new PHINode(avail[*PI]->getType(), "gvnpre-join", 
1094                                   BB->begin());
1095                 
1096                 p->addIncoming(avail[*PI], *PI);
1097               }
1098               
1099               VN.add(p, VN.lookup(e));
1100               DOUT << "Creating value: " << std::hex << p << std::dec << "\n";
1101               
1102               val_replace(availOut, p);
1103               availOut.insert(p);
1104               
1105               new_stuff = true;
1106               
1107               DOUT << "Preds After Processing: ";
1108               for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB);
1109                    PI != PE; ++PI)
1110                 DEBUG((*PI)->dump());
1111               DOUT << "\n\n";
1112               
1113               DOUT << "Merge Block After Processing: ";
1114               DEBUG(BB->dump());
1115               DOUT << "\n\n";
1116               
1117               new_set.insert(p);
1118               
1119               ++NumInsertedPhis;
1120             }
1121           }
1122         }
1123       }
1124     }
1125     i_iterations++;
1126   }
1127   
1128   // Phase 3: Eliminate
1129   elimination();
1130   
1131   // Phase 4: Cleanup
1132   cleanup();
1133   
1134   return true;
1135 }