91c1cd4d47b30af8b068252a9aa0f862a1a0cdd6
[oota-llvm.git] / lib / Transforms / Scalar / GVNPRE.cpp
1 //===- GVNPRE.cpp - Eliminate redundant values and expressions ------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by the Owen Anderson and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass performs a hybrid of global value numbering and partial redundancy
11 // elimination, known as GVN-PRE.  It performs partial redundancy elimination on
12 // values, rather than lexical expressions, allowing a more comprehensive view 
13 // the optimization.  It replaces redundant values with uses of earlier 
14 // occurences of the same value.  While this is beneficial in that it eliminates
15 // unneeded computation, it also increases register pressure by creating large
16 // live ranges, and should be used with caution on platforms that are very 
17 // sensitive to register pressure.
18 //
19 //===----------------------------------------------------------------------===//
20
21 #define DEBUG_TYPE "gvnpre"
22 #include "llvm/Value.h"
23 #include "llvm/Transforms/Scalar.h"
24 #include "llvm/Instructions.h"
25 #include "llvm/Function.h"
26 #include "llvm/Analysis/Dominators.h"
27 #include "llvm/Analysis/PostDominators.h"
28 #include "llvm/ADT/DepthFirstIterator.h"
29 #include "llvm/ADT/Statistic.h"
30 #include "llvm/Support/CFG.h"
31 #include "llvm/Support/Compiler.h"
32 #include "llvm/Support/Debug.h"
33 #include <algorithm>
34 #include <deque>
35 #include <map>
36 #include <vector>
37 #include <set>
38 using namespace llvm;
39
40 //===----------------------------------------------------------------------===//
41 //                         ValueTable Class
42 //===----------------------------------------------------------------------===//
43
44 /// This class holds the mapping between values and value numbers.
45
46 namespace {
47   class VISIBILITY_HIDDEN ValueTable {
48     public:
49       struct Expression {
50         enum ExpressionOpcode { ADD, SUB, MUL, UDIV, SDIV, FDIV, UREM, SREM, 
51                               FREM, SHL, LSHR, ASHR, AND, OR, XOR, ICMPEQ, 
52                               ICMPNE, ICMPUGT, ICMPUGE, ICMPULT, ICMPULE, 
53                               ICMPSGT, ICMPSGE, ICMPSLT, ICMPSLE, FCMPOEQ, 
54                               FCMPOGT, FCMPOGE, FCMPOLT, FCMPOLE, FCMPONE, 
55                               FCMPORD, FCMPUNO, FCMPUEQ, FCMPUGT, FCMPUGE, 
56                               FCMPULT, FCMPULE, FCMPUNE };
57     
58         ExpressionOpcode opcode;
59         uint32_t leftVN;
60         uint32_t rightVN;
61       
62         bool operator< (const Expression& other) const {
63           if (opcode < other.opcode)
64             return true;
65           else if (opcode > other.opcode)
66             return false;
67           else if (leftVN < other.leftVN)
68             return true;
69           else if (leftVN > other.leftVN)
70             return false;
71           else if (rightVN < other.rightVN)
72             return true;
73           else if (rightVN > other.rightVN)
74             return false;
75           else
76             return false;
77         }
78       };
79     
80     private:
81       std::map<Value*, uint32_t> valueNumbering;
82       std::map<Expression, uint32_t> expressionNumbering;
83   
84       std::set<Expression> maximalExpressions;
85       std::set<Value*> maximalValues;
86   
87       uint32_t nextValueNumber;
88     
89       Expression::ExpressionOpcode getOpcode(BinaryOperator* BO);
90       Expression::ExpressionOpcode getOpcode(CmpInst* C);
91     public:
92       ValueTable() { nextValueNumber = 1; }
93       uint32_t lookup_or_add(Value* V);
94       uint32_t lookup(Value* V);
95       void add(Value* V, uint32_t num);
96       void clear();
97       std::set<Expression>& getMaximalExpressions() {
98         return maximalExpressions;
99       
100       }
101       std::set<Value*>& getMaximalValues() { return maximalValues; }
102       Expression create_expression(BinaryOperator* BO);
103       Expression create_expression(CmpInst* C);
104   };
105 }
106
107 ValueTable::Expression::ExpressionOpcode 
108                                      ValueTable::getOpcode(BinaryOperator* BO) {
109   switch(BO->getOpcode()) {
110     case Instruction::Add:
111       return Expression::ADD;
112     case Instruction::Sub:
113       return Expression::SUB;
114     case Instruction::Mul:
115       return Expression::MUL;
116     case Instruction::UDiv:
117       return Expression::UDIV;
118     case Instruction::SDiv:
119       return Expression::SDIV;
120     case Instruction::FDiv:
121       return Expression::FDIV;
122     case Instruction::URem:
123       return Expression::UREM;
124     case Instruction::SRem:
125       return Expression::SREM;
126     case Instruction::FRem:
127       return Expression::FREM;
128     case Instruction::Shl:
129       return Expression::SHL;
130     case Instruction::LShr:
131       return Expression::LSHR;
132     case Instruction::AShr:
133       return Expression::ASHR;
134     case Instruction::And:
135       return Expression::AND;
136     case Instruction::Or:
137       return Expression::OR;
138     case Instruction::Xor:
139       return Expression::XOR;
140     
141     // THIS SHOULD NEVER HAPPEN
142     default:
143       assert(0 && "Binary operator with unknown opcode?");
144       return Expression::ADD;
145   }
146 }
147
148 ValueTable::Expression::ExpressionOpcode ValueTable::getOpcode(CmpInst* C) {
149   if (C->getOpcode() == Instruction::ICmp) {
150     switch (C->getPredicate()) {
151       case ICmpInst::ICMP_EQ:
152         return Expression::ICMPEQ;
153       case ICmpInst::ICMP_NE:
154         return Expression::ICMPNE;
155       case ICmpInst::ICMP_UGT:
156         return Expression::ICMPUGT;
157       case ICmpInst::ICMP_UGE:
158         return Expression::ICMPUGE;
159       case ICmpInst::ICMP_ULT:
160         return Expression::ICMPULT;
161       case ICmpInst::ICMP_ULE:
162         return Expression::ICMPULE;
163       case ICmpInst::ICMP_SGT:
164         return Expression::ICMPSGT;
165       case ICmpInst::ICMP_SGE:
166         return Expression::ICMPSGE;
167       case ICmpInst::ICMP_SLT:
168         return Expression::ICMPSLT;
169       case ICmpInst::ICMP_SLE:
170         return Expression::ICMPSLE;
171       
172       // THIS SHOULD NEVER HAPPEN
173       default:
174         assert(0 && "Comparison with unknown predicate?");
175         return Expression::ICMPEQ;
176     }
177   } else {
178     switch (C->getPredicate()) {
179       case FCmpInst::FCMP_OEQ:
180         return Expression::FCMPOEQ;
181       case FCmpInst::FCMP_OGT:
182         return Expression::FCMPOGT;
183       case FCmpInst::FCMP_OGE:
184         return Expression::FCMPOGE;
185       case FCmpInst::FCMP_OLT:
186         return Expression::FCMPOLT;
187       case FCmpInst::FCMP_OLE:
188         return Expression::FCMPOLE;
189       case FCmpInst::FCMP_ONE:
190         return Expression::FCMPONE;
191       case FCmpInst::FCMP_ORD:
192         return Expression::FCMPORD;
193       case FCmpInst::FCMP_UNO:
194         return Expression::FCMPUNO;
195       case FCmpInst::FCMP_UEQ:
196         return Expression::FCMPUEQ;
197       case FCmpInst::FCMP_UGT:
198         return Expression::FCMPUGT;
199       case FCmpInst::FCMP_UGE:
200         return Expression::FCMPUGE;
201       case FCmpInst::FCMP_ULT:
202         return Expression::FCMPULT;
203       case FCmpInst::FCMP_ULE:
204         return Expression::FCMPULE;
205       case FCmpInst::FCMP_UNE:
206         return Expression::FCMPUNE;
207       
208       // THIS SHOULD NEVER HAPPEN
209       default:
210         assert(0 && "Comparison with unknown predicate?");
211         return Expression::FCMPOEQ;
212     }
213   }
214 }
215
216 uint32_t ValueTable::lookup_or_add(Value* V) {
217   maximalValues.insert(V);
218
219   std::map<Value*, uint32_t>::iterator VI = valueNumbering.find(V);
220   if (VI != valueNumbering.end())
221     return VI->second;
222   
223   
224   if (BinaryOperator* BO = dyn_cast<BinaryOperator>(V)) {
225     Expression e = create_expression(BO);
226     
227     std::map<Expression, uint32_t>::iterator EI = expressionNumbering.find(e);
228     if (EI != expressionNumbering.end()) {
229       valueNumbering.insert(std::make_pair(V, EI->second));
230       return EI->second;
231     } else {
232       expressionNumbering.insert(std::make_pair(e, nextValueNumber));
233       valueNumbering.insert(std::make_pair(V, nextValueNumber));
234       
235       return nextValueNumber++;
236     }
237   } else if (CmpInst* C = dyn_cast<CmpInst>(V)) {
238     Expression e = create_expression(C);
239     
240     std::map<Expression, uint32_t>::iterator EI = expressionNumbering.find(e);
241     if (EI != expressionNumbering.end()) {
242       valueNumbering.insert(std::make_pair(V, EI->second));
243       return EI->second;
244     } else {
245       expressionNumbering.insert(std::make_pair(e, nextValueNumber));
246       valueNumbering.insert(std::make_pair(V, nextValueNumber));
247       
248       return nextValueNumber++;
249     }
250   } else {
251     valueNumbering.insert(std::make_pair(V, nextValueNumber));
252     return nextValueNumber++;
253   }
254 }
255
256 uint32_t ValueTable::lookup(Value* V) {
257   std::map<Value*, uint32_t>::iterator VI = valueNumbering.find(V);
258   if (VI != valueNumbering.end())
259     return VI->second;
260   else
261     assert(0 && "Value not numbered?");
262   
263   return 0;
264 }
265
266 void ValueTable::add(Value* V, uint32_t num) {
267   std::map<Value*, uint32_t>::iterator VI = valueNumbering.find(V);
268   if (VI != valueNumbering.end())
269     valueNumbering.erase(VI);
270   valueNumbering.insert(std::make_pair(V, num));
271 }
272
273 ValueTable::Expression ValueTable::create_expression(BinaryOperator* BO) {
274   Expression e;
275     
276   e.leftVN = lookup_or_add(BO->getOperand(0));
277   e.rightVN = lookup_or_add(BO->getOperand(1));
278   e.opcode = getOpcode(BO);
279   
280   maximalExpressions.insert(e);
281   
282   return e;
283 }
284
285 ValueTable::Expression ValueTable::create_expression(CmpInst* C) {
286   Expression e;
287     
288   e.leftVN = lookup_or_add(C->getOperand(0));
289   e.rightVN = lookup_or_add(C->getOperand(1));
290   e.opcode = getOpcode(C);
291   
292   maximalExpressions.insert(e);
293   
294   return e;
295 }
296
297 void ValueTable::clear() {
298   valueNumbering.clear();
299   expressionNumbering.clear();
300   nextValueNumber = 1;
301 }
302
303 namespace {
304
305   class VISIBILITY_HIDDEN GVNPRE : public FunctionPass {
306     bool runOnFunction(Function &F);
307   public:
308     static char ID; // Pass identification, replacement for typeid
309     GVNPRE() : FunctionPass((intptr_t)&ID) { nextValueNumber = 1; }
310
311   private:
312     uint32_t nextValueNumber;
313     ValueTable VN;
314     std::vector<Instruction*> createdExpressions;
315     
316     std::map<BasicBlock*, std::set<Value*> > availableOut;
317     std::map<BasicBlock*, std::set<Value*> > anticipatedIn;
318     std::map<User*, bool> invokeDep;
319     
320     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
321       AU.setPreservesCFG();
322       AU.addRequired<DominatorTree>();
323       AU.addRequired<PostDominatorTree>();
324     }
325   
326     // Helper fuctions
327     // FIXME: eliminate or document these better
328     void dump(const std::set<Value*>& s) const;
329     void dump_unique(const std::set<Value*>& s) const;
330     void clean(std::set<Value*>& set);
331     Value* find_leader(std::set<Value*>& vals,
332                        uint32_t v);
333     Value* phi_translate(Value* V, BasicBlock* pred, BasicBlock* succ);
334     void phi_translate_set(std::set<Value*>& anticIn, BasicBlock* pred,
335                            BasicBlock* succ, std::set<Value*>& out);
336     
337     void topo_sort(std::set<Value*>& set,
338                    std::vector<Value*>& vec);
339     
340     // For a given block, calculate the generated expressions, temporaries,
341     // and the AVAIL_OUT set
342     void cleanup();
343     void elimination();
344     
345     void val_insert(std::set<Value*>& s, Value* v);
346     void val_replace(std::set<Value*>& s, Value* v);
347     bool dependsOnInvoke(Value* V);
348   
349   };
350   
351   char GVNPRE::ID = 0;
352   
353 }
354
355 FunctionPass *llvm::createGVNPREPass() { return new GVNPRE(); }
356
357 RegisterPass<GVNPRE> X("gvnpre",
358                        "Global Value Numbering/Partial Redundancy Elimination");
359
360
361 STATISTIC(NumInsertedVals, "Number of values inserted");
362 STATISTIC(NumInsertedPhis, "Number of PHI nodes inserted");
363 STATISTIC(NumEliminated, "Number of redundant instructions eliminated");
364
365 Value* GVNPRE::find_leader(std::set<Value*>& vals, uint32_t v) {
366   for (std::set<Value*>::iterator I = vals.begin(), E = vals.end();
367        I != E; ++I)
368     if (v == VN.lookup(*I))
369       return *I;
370   
371   return 0;
372 }
373
374 void GVNPRE::val_insert(std::set<Value*>& s, Value* v) {
375   uint32_t num = VN.lookup(v);
376   Value* leader = find_leader(s, num);
377   if (leader == 0)
378     s.insert(v);
379 }
380
381 void GVNPRE::val_replace(std::set<Value*>& s, Value* v) {
382   uint32_t num = VN.lookup(v);
383   Value* leader = find_leader(s, num);
384   while (leader != 0) {
385     s.erase(leader);
386     leader = find_leader(s, num);
387   }
388   s.insert(v);
389 }
390
391 Value* GVNPRE::phi_translate(Value* V, BasicBlock* pred, BasicBlock* succ) {
392   if (V == 0)
393     return 0;
394   
395   if (BinaryOperator* BO = dyn_cast<BinaryOperator>(V)) {
396     Value* newOp1 = 0;
397     if (isa<Instruction>(BO->getOperand(0)))
398       newOp1 = phi_translate(find_leader(anticipatedIn[succ],         
399                                          VN.lookup(BO->getOperand(0))),
400                              pred, succ);
401     else
402       newOp1 = BO->getOperand(0);
403     
404     if (newOp1 == 0)
405       return 0;
406     
407     Value* newOp2 = 0;
408     if (isa<Instruction>(BO->getOperand(1)))
409       newOp2 = phi_translate(find_leader(anticipatedIn[succ],         
410                                          VN.lookup(BO->getOperand(1))),
411                              pred, succ);
412     else
413       newOp2 = BO->getOperand(1);
414     
415     if (newOp2 == 0)
416       return 0;
417     
418     if (newOp1 != BO->getOperand(0) || newOp2 != BO->getOperand(1)) {
419       Instruction* newVal = BinaryOperator::create(BO->getOpcode(),
420                                              newOp1, newOp2,
421                                              BO->getName()+".gvnpre");
422       
423       uint32_t v = VN.lookup_or_add(newVal);
424       
425       Value* leader = find_leader(availableOut[pred], v);
426       if (leader == 0) {
427         createdExpressions.push_back(newVal);
428         return newVal;
429       } else {
430         delete newVal;
431         return leader;
432       }
433     }
434   } else if (PHINode* P = dyn_cast<PHINode>(V)) {
435     if (P->getParent() == succ)
436       return P->getIncomingValueForBlock(pred);
437   } else if (CmpInst* C = dyn_cast<CmpInst>(V)) {
438     Value* newOp1 = 0;
439     if (isa<Instruction>(C->getOperand(0)))
440       newOp1 = phi_translate(find_leader(anticipatedIn[succ],         
441                                          VN.lookup(C->getOperand(0))),
442                              pred, succ);
443     else
444       newOp1 = C->getOperand(0);
445     
446     if (newOp1 == 0)
447       return 0;
448     
449     Value* newOp2 = 0;
450     if (isa<Instruction>(C->getOperand(1)))
451       newOp2 = phi_translate(find_leader(anticipatedIn[succ],         
452                                          VN.lookup(C->getOperand(1))),
453                              pred, succ);
454     else
455       newOp2 = C->getOperand(1);
456       
457     if (newOp2 == 0)
458       return 0;
459     
460     if (newOp1 != C->getOperand(0) || newOp2 != C->getOperand(1)) {
461       Instruction* newVal = CmpInst::create(C->getOpcode(),
462                                             C->getPredicate(),
463                                              newOp1, newOp2,
464                                              C->getName()+".gvnpre");
465       
466       uint32_t v = VN.lookup_or_add(newVal);
467         
468       Value* leader = find_leader(availableOut[pred], v);
469       if (leader == 0) {
470         createdExpressions.push_back(newVal);
471         return newVal;
472       } else {
473         delete newVal;
474         return leader;
475       }
476     }
477   }
478   
479   return V;
480 }
481
482 void GVNPRE::phi_translate_set(std::set<Value*>& anticIn,
483                               BasicBlock* pred, BasicBlock* succ,
484                               std::set<Value*>& out) {
485   for (std::set<Value*>::iterator I = anticIn.begin(),
486        E = anticIn.end(); I != E; ++I) {
487     Value* V = phi_translate(*I, pred, succ);
488     if (V != 0)
489       out.insert(V);
490   }
491 }
492
493 bool GVNPRE::dependsOnInvoke(Value* V) {
494   if (!isa<User>(V))
495     return false;
496     
497   User* U = cast<User>(V);
498   std::map<User*, bool>::iterator I = invokeDep.find(U);
499   if (I != invokeDep.end())
500     return I->second;
501   
502   std::vector<Value*> worklist(U->op_begin(), U->op_end());
503   std::set<Value*> visited;
504   
505   while (!worklist.empty()) {
506     Value* current = worklist.back();
507     worklist.pop_back();
508     visited.insert(current);
509     
510     if (!isa<User>(current))
511       continue;
512     else if (isa<InvokeInst>(current))
513       return true;
514     
515     User* curr = cast<User>(current);
516     std::map<User*, bool>::iterator CI = invokeDep.find(curr);
517     if (CI != invokeDep.end()) {
518       if (CI->second)
519         return true;
520     } else {
521       for (unsigned i = 0; i < curr->getNumOperands(); ++i)
522         if (visited.find(curr->getOperand(i)) == visited.end())
523           worklist.push_back(curr->getOperand(i));
524     }
525   }
526   
527   return false;
528 }
529
530 // Remove all expressions whose operands are not themselves in the set
531 void GVNPRE::clean(std::set<Value*>& set) {
532   std::vector<Value*> worklist;
533   topo_sort(set, worklist);
534   
535   for (unsigned i = 0; i < worklist.size(); ++i) {
536     Value* v = worklist[i];
537     
538     if (BinaryOperator* BO = dyn_cast<BinaryOperator>(v)) {   
539       bool lhsValid = !isa<Instruction>(BO->getOperand(0));
540       if (!lhsValid)
541         for (std::set<Value*>::iterator I = set.begin(), E = set.end();
542              I != E; ++I)
543           if (VN.lookup(*I) == VN.lookup(BO->getOperand(0))) {
544             lhsValid = true;
545             break;
546           }
547           
548       // Check for dependency on invoke insts
549       // NOTE: This check is expensive, so don't do it if we
550       // don't have to
551       if (lhsValid)
552         lhsValid = !dependsOnInvoke(BO->getOperand(0));
553     
554       bool rhsValid = !isa<Instruction>(BO->getOperand(1));
555       if (!rhsValid)
556         for (std::set<Value*>::iterator I = set.begin(), E = set.end();
557              I != E; ++I)
558           if (VN.lookup(*I) == VN.lookup(BO->getOperand(1))) {
559             rhsValid = true;
560             break;
561           }
562       
563       // Check for dependency on invoke insts
564       // NOTE: This check is expensive, so don't do it if we
565       // don't have to
566       if (rhsValid)
567         rhsValid = !dependsOnInvoke(BO->getOperand(1));
568       
569       if (!lhsValid || !rhsValid)
570         set.erase(BO);
571     } else if (CmpInst* C = dyn_cast<CmpInst>(v)) {
572       bool lhsValid = !isa<Instruction>(C->getOperand(0));
573       if (!lhsValid)
574         for (std::set<Value*>::iterator I = set.begin(), E = set.end();
575              I != E; ++I)
576           if (VN.lookup(*I) == VN.lookup(C->getOperand(0))) {
577             lhsValid = true;
578             break;
579           }
580       lhsValid &= !dependsOnInvoke(C->getOperand(0));
581       
582       bool rhsValid = !isa<Instruction>(C->getOperand(1));
583       if (!rhsValid)
584       for (std::set<Value*>::iterator I = set.begin(), E = set.end();
585            I != E; ++I)
586         if (VN.lookup(*I) == VN.lookup(C->getOperand(1))) {
587           rhsValid = true;
588           break;
589         }
590       rhsValid &= !dependsOnInvoke(C->getOperand(1));
591     
592       if (!lhsValid || !rhsValid)
593         set.erase(C);
594     }
595   }
596 }
597
598 void GVNPRE::topo_sort(std::set<Value*>& set,
599                        std::vector<Value*>& vec) {
600   std::set<Value*> toErase;
601   for (std::set<Value*>::iterator I = set.begin(), E = set.end();
602        I != E; ++I) {
603     if (BinaryOperator* BO = dyn_cast<BinaryOperator>(*I))
604       for (std::set<Value*>::iterator SI = set.begin(); SI != E; ++SI) {
605         if (VN.lookup(BO->getOperand(0)) == VN.lookup(*SI) ||
606             VN.lookup(BO->getOperand(1)) == VN.lookup(*SI)) {
607           toErase.insert(*SI);
608         }
609       }
610     else if (CmpInst* C = dyn_cast<CmpInst>(*I))
611       for (std::set<Value*>::iterator SI = set.begin(); SI != E; ++SI) {
612         if (VN.lookup(C->getOperand(0)) == VN.lookup(*SI) ||
613             VN.lookup(C->getOperand(1)) == VN.lookup(*SI)) {
614           toErase.insert(*SI);
615         }
616       }
617   }
618   
619   std::vector<Value*> Q;
620   for (std::set<Value*>::iterator I = set.begin(), E = set.end();
621        I != E; ++I) {
622     if (toErase.find(*I) == toErase.end())
623       Q.push_back(*I);
624   }
625   
626   std::set<Value*> visited;
627   while (!Q.empty()) {
628     Value* e = Q.back();
629   
630     if (BinaryOperator* BO = dyn_cast<BinaryOperator>(e)) {
631       Value* l = find_leader(set, VN.lookup(BO->getOperand(0)));
632       Value* r = find_leader(set, VN.lookup(BO->getOperand(1)));
633       
634       if (l != 0 && isa<Instruction>(l) &&
635           visited.find(l) == visited.end())
636         Q.push_back(l);
637       else if (r != 0 && isa<Instruction>(r) &&
638                visited.find(r) == visited.end())
639         Q.push_back(r);
640       else {
641         vec.push_back(e);
642         visited.insert(e);
643         Q.pop_back();
644       }
645     } else if (CmpInst* C = dyn_cast<CmpInst>(e)) {
646       Value* l = find_leader(set, VN.lookup(C->getOperand(0)));
647       Value* r = find_leader(set, VN.lookup(C->getOperand(1)));
648       
649       if (l != 0 && isa<Instruction>(l) &&
650           visited.find(l) == visited.end())
651         Q.push_back(l);
652       else if (r != 0 && isa<Instruction>(r) &&
653                visited.find(r) == visited.end())
654         Q.push_back(r);
655       else {
656         vec.push_back(e);
657         visited.insert(e);
658         Q.pop_back();
659       }
660     } else {
661       visited.insert(e);
662       vec.push_back(e);
663       Q.pop_back();
664     }
665   }
666 }
667
668
669 void GVNPRE::dump(const std::set<Value*>& s) const {
670   DOUT << "{ ";
671   for (std::set<Value*>::iterator I = s.begin(), E = s.end();
672        I != E; ++I) {
673     DEBUG((*I)->dump());
674   }
675   DOUT << "}\n\n";
676 }
677
678 void GVNPRE::dump_unique(const std::set<Value*>& s) const {
679   DOUT << "{ ";
680   for (std::set<Value*>::iterator I = s.begin(), E = s.end();
681        I != E; ++I) {
682     DEBUG((*I)->dump());
683   }
684   DOUT << "}\n\n";
685 }
686
687 void GVNPRE::elimination() {
688   DOUT << "\n\nPhase 3: Elimination\n\n";
689   
690   std::vector<std::pair<Instruction*, Value*> > replace;
691   std::vector<Instruction*> erase;
692   
693   DominatorTree& DT = getAnalysis<DominatorTree>();
694   
695   for (df_iterator<DomTreeNode*> DI = df_begin(DT.getRootNode()),
696          E = df_end(DT.getRootNode()); DI != E; ++DI) {
697     BasicBlock* BB = DI->getBlock();
698     
699     DOUT << "Block: " << BB->getName() << "\n";
700     dump_unique(availableOut[BB]);
701     DOUT << "\n\n";
702     
703     for (BasicBlock::iterator BI = BB->begin(), BE = BB->end();
704          BI != BE; ++BI) {
705
706       if (isa<BinaryOperator>(BI) || isa<CmpInst>(BI)) {
707          Value *leader = find_leader(availableOut[BB], VN.lookup(BI));
708   
709         if (leader != 0)
710           if (Instruction* Instr = dyn_cast<Instruction>(leader))
711             if (Instr->getParent() != 0 && Instr != BI) {
712               replace.push_back(std::make_pair(BI, leader));
713               erase.push_back(BI);
714               ++NumEliminated;
715             }
716       }
717     }
718   }
719   
720   while (!replace.empty()) {
721     std::pair<Instruction*, Value*> rep = replace.back();
722     replace.pop_back();
723     rep.first->replaceAllUsesWith(rep.second);
724   }
725     
726   for (std::vector<Instruction*>::iterator I = erase.begin(), E = erase.end();
727        I != E; ++I)
728      (*I)->eraseFromParent();
729 }
730
731
732 void GVNPRE::cleanup() {
733   while (!createdExpressions.empty()) {
734     Instruction* I = createdExpressions.back();
735     createdExpressions.pop_back();
736     
737     delete I;
738   }
739 }
740
741 bool GVNPRE::runOnFunction(Function &F) {
742   VN.clear();
743   createdExpressions.clear();
744   availableOut.clear();
745   anticipatedIn.clear();
746   invokeDep.clear();
747
748   std::map<BasicBlock*, std::set<Value*> > generatedExpressions;
749   std::map<BasicBlock*, std::set<PHINode*> > generatedPhis;
750   std::map<BasicBlock*, std::set<Value*> > generatedTemporaries;
751   
752   
753   DominatorTree &DT = getAnalysis<DominatorTree>();   
754   
755   // Phase 1: BuildSets
756   
757   // Phase 1, Part 1: calculate AVAIL_OUT
758   
759   // Top-down walk of the dominator tree
760   for (df_iterator<DomTreeNode*> DI = df_begin(DT.getRootNode()),
761          E = df_end(DT.getRootNode()); DI != E; ++DI) {
762     
763     // Get the sets to update for this block
764     std::set<Value*>& currExps = generatedExpressions[DI->getBlock()];
765     std::set<PHINode*>& currPhis = generatedPhis[DI->getBlock()];
766     std::set<Value*>& currTemps = generatedTemporaries[DI->getBlock()];
767     std::set<Value*>& currAvail = availableOut[DI->getBlock()];     
768     
769     BasicBlock* BB = DI->getBlock();
770   
771     // A block inherits AVAIL_OUT from its dominator
772     if (DI->getIDom() != 0)
773     currAvail.insert(availableOut[DI->getIDom()->getBlock()].begin(),
774                      availableOut[DI->getIDom()->getBlock()].end());
775     
776     
777     for (BasicBlock::iterator BI = BB->begin(), BE = BB->end();
778          BI != BE; ++BI) {
779        
780       // Handle PHI nodes...
781       if (PHINode* p = dyn_cast<PHINode>(BI)) {
782         VN.lookup_or_add(p);
783         currPhis.insert(p);
784     
785       // Handle binary ops...
786       } else if (BinaryOperator* BO = dyn_cast<BinaryOperator>(BI)) {
787         Value* leftValue = BO->getOperand(0);
788         Value* rightValue = BO->getOperand(1);
789       
790         VN.lookup_or_add(BO);
791       
792         if (isa<Instruction>(leftValue))
793           val_insert(currExps, leftValue);
794         if (isa<Instruction>(rightValue))
795           val_insert(currExps, rightValue);
796         val_insert(currExps, BO);
797       
798       // Handle cmp ops...
799       } else if (CmpInst* C = dyn_cast<CmpInst>(BI)) {
800         Value* leftValue = C->getOperand(0);
801         Value* rightValue = C->getOperand(1);
802       
803         VN.lookup_or_add(C);
804       
805         if (isa<Instruction>(leftValue))
806           val_insert(currExps, leftValue);
807         if (isa<Instruction>(rightValue))
808           val_insert(currExps, rightValue);
809         val_insert(currExps, C);
810       
811       // Handle unsupported ops
812       } else if (!BI->isTerminator()){
813         VN.lookup_or_add(BI);
814         currTemps.insert(BI);
815       }
816     
817       if (!BI->isTerminator())
818         val_insert(currAvail, BI);
819     }
820   }
821   
822   DOUT << "Maximal Set: ";
823   dump_unique(VN.getMaximalValues());
824   DOUT << "\n";
825   
826   // If function has no exit blocks, only perform GVN
827   PostDominatorTree &PDT = getAnalysis<PostDominatorTree>();
828   if (PDT[&F.getEntryBlock()] == 0) {
829     elimination();
830     cleanup();
831     
832     return true;
833   }
834   
835   
836   // Phase 1, Part 2: calculate ANTIC_IN
837   
838   std::set<BasicBlock*> visited;
839   
840   bool changed = true;
841   unsigned iterations = 0;
842   while (changed) {
843     changed = false;
844     std::set<Value*> anticOut;
845     
846     // Top-down walk of the postdominator tree
847     for (df_iterator<DomTreeNode*> PDI = 
848          df_begin(PDT.getRootNode()), E = df_end(PDT.getRootNode());
849          PDI != E; ++PDI) {
850       BasicBlock* BB = PDI->getBlock();
851       if (BB == 0)
852         continue;
853       
854       DOUT << "Block: " << BB->getName() << "\n";
855       DOUT << "TMP_GEN: ";
856       dump(generatedTemporaries[BB]);
857       DOUT << "\n";
858     
859       DOUT << "EXP_GEN: ";
860       dump_unique(generatedExpressions[BB]);
861       visited.insert(BB);
862       
863       std::set<Value*>& anticIn = anticipatedIn[BB];
864       std::set<Value*> old (anticIn.begin(), anticIn.end());
865       
866       if (BB->getTerminator()->getNumSuccessors() == 1) {
867          if (visited.find(BB->getTerminator()->getSuccessor(0)) == 
868              visited.end())
869            phi_translate_set(VN.getMaximalValues(), BB,    
870                              BB->getTerminator()->getSuccessor(0),
871                              anticOut);
872          else
873           phi_translate_set(anticipatedIn[BB->getTerminator()->getSuccessor(0)],
874                             BB,  BB->getTerminator()->getSuccessor(0), 
875                             anticOut);
876       } else if (BB->getTerminator()->getNumSuccessors() > 1) {
877         BasicBlock* first = BB->getTerminator()->getSuccessor(0);
878         anticOut.insert(anticipatedIn[first].begin(),
879                         anticipatedIn[first].end());
880         for (unsigned i = 1; i < BB->getTerminator()->getNumSuccessors(); ++i) {
881           BasicBlock* currSucc = BB->getTerminator()->getSuccessor(i);
882           std::set<Value*>& succAnticIn = anticipatedIn[currSucc];
883           
884           std::set<Value*> temp;
885           std::insert_iterator<std::set<Value*> >  temp_ins(temp, 
886                                                             temp.begin());
887           std::set_intersection(anticOut.begin(), anticOut.end(),
888                                 succAnticIn.begin(), succAnticIn.end(),
889                                 temp_ins);
890           
891           anticOut.clear();
892           anticOut.insert(temp.begin(), temp.end());
893         }
894       }
895       
896       DOUT << "ANTIC_OUT: ";
897       dump_unique(anticOut);
898       DOUT << "\n";
899       
900       std::set<Value*> S;
901       std::insert_iterator<std::set<Value*> >  s_ins(S, S.begin());
902       std::set_difference(anticOut.begin(), anticOut.end(),
903                      generatedTemporaries[BB].begin(),
904                      generatedTemporaries[BB].end(),
905                      s_ins);
906       
907       anticIn.clear();
908       std::insert_iterator<std::set<Value*> >  ai_ins(anticIn, anticIn.begin());
909       std::set_difference(generatedExpressions[BB].begin(),
910                      generatedExpressions[BB].end(),
911                      generatedTemporaries[BB].begin(),
912                      generatedTemporaries[BB].end(),
913                      ai_ins);
914       
915       for (std::set<Value*>::iterator I = S.begin(), E = S.end();
916            I != E; ++I) {
917         if (find_leader(anticIn, VN.lookup(*I)) == 0)
918           val_insert(anticIn, *I);
919       }
920       
921       clean(anticIn);
922       
923       DOUT << "ANTIC_IN: ";
924       dump_unique(anticIn);
925       DOUT << "\n";
926       
927       if (old.size() != anticIn.size())
928         changed = true;
929       
930       anticOut.clear();
931     }
932     
933     iterations++;
934   }
935   
936   DOUT << "Iterations: " << iterations << "\n";
937   
938   for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) {
939     DOUT << "Name: " << I->getName().c_str() << "\n";
940     
941     DOUT << "TMP_GEN: ";
942     dump(generatedTemporaries[I]);
943     DOUT << "\n";
944     
945     DOUT << "EXP_GEN: ";
946     dump_unique(generatedExpressions[I]);
947     DOUT << "\n";
948     
949     DOUT << "ANTIC_IN: ";
950     dump_unique(anticipatedIn[I]);
951     DOUT << "\n";
952     
953     DOUT << "AVAIL_OUT: ";
954     dump_unique(availableOut[I]);
955     DOUT << "\n";
956   }
957   
958   // Phase 2: Insert
959   DOUT<< "\nPhase 2: Insertion\n";
960   
961   std::map<BasicBlock*, std::set<Value*> > new_sets;
962   unsigned i_iterations = 0;
963   bool new_stuff = true;
964   while (new_stuff) {
965     new_stuff = false;
966     DOUT << "Iteration: " << i_iterations << "\n\n";
967     for (df_iterator<DomTreeNode*> DI = df_begin(DT.getRootNode()),
968          E = df_end(DT.getRootNode()); DI != E; ++DI) {
969       BasicBlock* BB = DI->getBlock();
970       
971       if (BB == 0)
972         continue;
973       
974       std::set<Value*>& new_set = new_sets[BB];
975       std::set<Value*>& availOut = availableOut[BB];
976       std::set<Value*>& anticIn = anticipatedIn[BB];
977       
978       new_set.clear();
979       
980       // Replace leaders with leaders inherited from dominator
981       if (DI->getIDom() != 0) {
982         std::set<Value*>& dom_set = new_sets[DI->getIDom()->getBlock()];
983         for (std::set<Value*>::iterator I = dom_set.begin(),
984              E = dom_set.end(); I != E; ++I) {
985           new_set.insert(*I);
986           val_replace(availOut, *I);
987         }
988       }
989       
990       // If there is more than one predecessor...
991       if (pred_begin(BB) != pred_end(BB) && ++pred_begin(BB) != pred_end(BB)) {
992         std::vector<Value*> workList;
993         topo_sort(anticIn, workList);
994         
995         DOUT << "Merge Block: " << BB->getName() << "\n";
996         DOUT << "ANTIC_IN: ";
997         dump_unique(anticIn);
998         DOUT << "\n";
999         
1000         for (unsigned i = 0; i < workList.size(); ++i) {
1001           Value* e = workList[i];
1002           
1003           if (isa<BinaryOperator>(e) || isa<CmpInst>(e)) {
1004             if (find_leader(availableOut[DI->getIDom()->getBlock()], VN.lookup(e)) != 0)
1005               continue;
1006             
1007             std::map<BasicBlock*, Value*> avail;
1008             bool by_some = false;
1009             int num_avail = 0;
1010             
1011             for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB); PI != PE;
1012                  ++PI) {
1013               Value *e2 = phi_translate(e, *PI, BB);
1014               Value *e3 = find_leader(availableOut[*PI], VN.lookup(e2));
1015               
1016               if (e3 == 0) {
1017                 std::map<BasicBlock*, Value*>::iterator av = avail.find(*PI);
1018                 if (av != avail.end())
1019                   avail.erase(av);
1020                 avail.insert(std::make_pair(*PI, e2));
1021               } else {
1022                 std::map<BasicBlock*, Value*>::iterator av = avail.find(*PI);
1023                 if (av != avail.end())
1024                   avail.erase(av);
1025                 avail.insert(std::make_pair(*PI, e3));
1026                 
1027                 by_some = true;
1028                 num_avail++;
1029               }
1030             }
1031             
1032             if (by_some &&
1033                 num_avail < std::distance(pred_begin(BB), pred_end(BB))) {
1034               DOUT << "Processing Value: ";
1035               DEBUG(e->dump());
1036               DOUT << "\n\n";
1037             
1038               for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB);
1039                    PI != PE; ++PI) {
1040                 Value* e2 = avail[*PI];
1041                 if (!find_leader(availableOut[*PI], VN.lookup(e2))) {
1042                   User* U = cast<User>(e2);
1043                 
1044                   Value* s1 = 0;
1045                   if (isa<BinaryOperator>(U->getOperand(0)) ||
1046                       isa<CmpInst>(U->getOperand(0)))
1047                     s1 = find_leader(availableOut[*PI], VN.lookup(U->getOperand(0)));
1048                   else
1049                     s1 = U->getOperand(0);
1050                   
1051                   Value* s2 = 0;
1052                   if (isa<BinaryOperator>(U->getOperand(1)) ||
1053                       isa<CmpInst>(U->getOperand(1)))
1054                     s2 = find_leader(availableOut[*PI], VN.lookup(U->getOperand(1)));
1055                   else
1056                     s2 = U->getOperand(1);
1057                   
1058                   Value* newVal = 0;
1059                   if (BinaryOperator* BO = dyn_cast<BinaryOperator>(U))
1060                     newVal = BinaryOperator::create(BO->getOpcode(),
1061                                              s1, s2,
1062                                              BO->getName()+".gvnpre",
1063                                              (*PI)->getTerminator());
1064                   else if (CmpInst* C = dyn_cast<CmpInst>(U))
1065                     newVal = CmpInst::create(C->getOpcode(),
1066                                              C->getPredicate(),
1067                                              s1, s2,
1068                                              C->getName()+".gvnpre",
1069                                              (*PI)->getTerminator());
1070                   
1071                   VN.add(newVal, VN.lookup(U));
1072                   
1073                   std::set<Value*>& predAvail = availableOut[*PI];
1074                   val_replace(predAvail, newVal);
1075                   
1076                   DOUT << "Creating value: " << std::hex << newVal << std::dec << "\n";
1077                   
1078                   std::map<BasicBlock*, Value*>::iterator av = avail.find(*PI);
1079                   if (av != avail.end())
1080                     avail.erase(av);
1081                   avail.insert(std::make_pair(*PI, newVal));
1082                   
1083                   ++NumInsertedVals;
1084                 }
1085               }
1086               
1087               PHINode* p = 0;
1088               
1089               for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB);
1090                    PI != PE; ++PI) {
1091                 if (p == 0)
1092                   p = new PHINode(avail[*PI]->getType(), "gvnpre-join", 
1093                                   BB->begin());
1094                 
1095                 p->addIncoming(avail[*PI], *PI);
1096               }
1097               
1098               VN.add(p, VN.lookup(e));
1099               DOUT << "Creating value: " << std::hex << p << std::dec << "\n";
1100               
1101               val_replace(availOut, p);
1102               availOut.insert(p);
1103               
1104               new_stuff = true;
1105               
1106               DOUT << "Preds After Processing: ";
1107               for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB);
1108                    PI != PE; ++PI)
1109                 DEBUG((*PI)->dump());
1110               DOUT << "\n\n";
1111               
1112               DOUT << "Merge Block After Processing: ";
1113               DEBUG(BB->dump());
1114               DOUT << "\n\n";
1115               
1116               new_set.insert(p);
1117               
1118               ++NumInsertedPhis;
1119             }
1120           }
1121         }
1122       }
1123     }
1124     i_iterations++;
1125   }
1126   
1127   // Phase 3: Eliminate
1128   elimination();
1129   
1130   // Phase 4: Cleanup
1131   cleanup();
1132   
1133   return true;
1134 }