Reapply commit 112699, speculatively reverted by echristo, since
[oota-llvm.git] / lib / Transforms / Scalar / JumpThreading.cpp
index c0b0cbebdea6405938e46e17df59b06e61ef5af0..104d5aecbdd32810122662e63586df86bd538db5 100644 (file)
@@ -24,6 +24,7 @@
 #include "llvm/Transforms/Utils/SSAUpdater.h"
 #include "llvm/Target/TargetData.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -77,6 +78,21 @@ namespace {
 #else
     SmallSet<AssertingVH<BasicBlock>, 16> LoopHeaders;
 #endif
+    DenseSet<std::pair<Value*, BasicBlock*> > RecursionSet;
+    
+    // RAII helper for updating the recursion stack.
+    struct RecursionSetRemover {
+      DenseSet<std::pair<Value*, BasicBlock*> > &TheSet;
+      std::pair<Value*, BasicBlock*> ThePair;
+      
+      RecursionSetRemover(DenseSet<std::pair<Value*, BasicBlock*> > &S,
+                          std::pair<Value*, BasicBlock*> P)
+        : TheSet(S), ThePair(P) { }
+      
+      ~RecursionSetRemover() {
+        TheSet.erase(ThePair);
+      }
+    };
   public:
     static char ID; // Pass identification
     JumpThreading() : FunctionPass(ID) {}
@@ -84,8 +100,10 @@ namespace {
     bool runOnFunction(Function &F);
     
     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
-      if (EnableLVI)
+      if (EnableLVI) {
         AU.addRequired<LazyValueInfo>();
+        AU.addPreserved<LazyValueInfo>();
+      }
     }
     
     void FindLoopHeaders(Function &F);
@@ -260,6 +278,17 @@ void JumpThreading::FindLoopHeaders(Function &F) {
     LoopHeaders.insert(const_cast<BasicBlock*>(Edges[i].second));
 }
 
+// Helper method for ComputeValueKnownInPredecessors.  If Value is a
+// ConstantInt, push it.  If it's an undef, push 0.  Otherwise, do nothing.
+static void PushConstantIntOrUndef(SmallVectorImpl<std::pair<ConstantInt*,
+                                                        BasicBlock*> > &Result,
+                              Constant *Value, BasicBlock* BB){
+  if (ConstantInt *FoldedCInt = dyn_cast<ConstantInt>(Value))
+    Result.push_back(std::make_pair(FoldedCInt, BB));
+  else if (isa<UndefValue>(Value))
+    Result.push_back(std::make_pair((ConstantInt*)0, BB));
+}
+
 /// ComputeValueKnownInPredecessors - Given a basic block BB and a value V, see
 /// if we can infer that the value is a known ConstantInt in any of our
 /// predecessors.  If so, return the known list of value and pred BB in the
@@ -269,12 +298,24 @@ void JumpThreading::FindLoopHeaders(Function &F) {
 ///
 bool JumpThreading::
 ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB,PredValueInfo &Result){
+  // This method walks up use-def chains recursively.  Because of this, we could
+  // get into an infinite loop going around loops in the use-def chain.  To
+  // prevent this, keep track of what (value, block) pairs we've already visited
+  // and terminate the search if we loop back to them
+  if (!RecursionSet.insert(std::make_pair(V, BB)).second)
+    return false;
+  
+  // An RAII help to remove this pair from the recursion set once the recursion
+  // stack pops back out again.
+  RecursionSetRemover remover(RecursionSet, std::make_pair(V, BB));
+  
   // If V is a constantint, then it is known in all predecessors.
   if (isa<ConstantInt>(V) || isa<UndefValue>(V)) {
     ConstantInt *CI = dyn_cast<ConstantInt>(V);
     
     for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI)
       Result.push_back(std::make_pair(CI, *PI));
+    
     return true;
   }
   
@@ -325,11 +366,12 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB,PredValueInfo &Result){
       } else if (LVI) {
         Constant *CI = LVI->getConstantOnEdge(InVal,
                                               PN->getIncomingBlock(i), BB);
-        ConstantInt *CInt = dyn_cast_or_null<ConstantInt>(CI);
-        if (CInt)
-          Result.push_back(std::make_pair(CInt, PN->getIncomingBlock(i)));
+        // LVI returns null is no value could be determined.
+        if (!CI) continue;
+        PushConstantIntOrUndef(Result, CI, PN->getIncomingBlock(i));
       }
     }
+    
     return !Result.empty();
   }
   
@@ -372,11 +414,8 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB,PredValueInfo &Result){
             Result.back().first = InterestingVal;
           }
         }
-      return !Result.empty();
-    
-    // Try to process a few other binary operator patterns.
-    } else if (isa<BinaryOperator>(I)) {
       
+      return !Result.empty();
     }
     
     // Handle the NOT form of XOR.
@@ -392,23 +431,27 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB,PredValueInfo &Result){
         if (Result[i].first)
           Result[i].first =
             cast<ConstantInt>(ConstantExpr::getNot(Result[i].first));
+      
       return true;
     }
   
   // Try to simplify some other binary operator values.
   } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I)) {
-    // AND or OR of a value with itself is that value.
-    ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1));
-    if (CI && (BO->getOpcode() == Instruction::And ||
-         BO->getOpcode() == Instruction::Or)) {
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
       SmallVector<std::pair<ConstantInt*, BasicBlock*>, 8> LHSVals;
       ComputeValueKnownInPredecessors(BO->getOperand(0), BB, LHSVals);
-      for (unsigned i = 0, e = LHSVals.size(); i != e; ++i) 
-        if (LHSVals[i].first == CI)
-        Result.push_back(std::make_pair(CI, LHSVals[i].second));
-      
-      return !Result.empty();
+    
+      // Try to use constant folding to simplify the binary operator.
+      for (unsigned i = 0, e = LHSVals.size(); i != e; ++i) {
+        Constant *V = LHSVals[i].first ? LHSVals[i].first :
+                                 cast<Constant>(UndefValue::get(BO->getType()));
+        Constant *Folded = ConstantExpr::get(BO->getOpcode(), V, CI);
+        
+        PushConstantIntOrUndef(Result, Folded, LHSVals[i].second);
+      }
     }
+      
+    return !Result.empty();
   }
   
   // Handle compare with phi operand, where the PHI is defined in this block.
@@ -435,10 +478,8 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB,PredValueInfo &Result){
           Res = ConstantInt::get(Type::getInt1Ty(LHS->getContext()), ResT);
         }
         
-        if (isa<UndefValue>(Res))
-          Result.push_back(std::make_pair((ConstantInt*)0, PredBB));
-        else if (ConstantInt *CI = dyn_cast<ConstantInt>(Res))
-          Result.push_back(std::make_pair(CI, PredBB));
+        if (Constant *ConstRes = dyn_cast<Constant>(Res))
+          PushConstantIntOrUndef(Result, ConstRes, PredBB);
       }
       
       return !Result.empty();
@@ -450,7 +491,7 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB,PredValueInfo &Result){
     if (LVI && isa<Constant>(Cmp->getOperand(1)) &&
         Cmp->getType()->isIntegerTy()) {
       if (!isa<Instruction>(Cmp->getOperand(0)) ||
-           cast<Instruction>(Cmp->getOperand(0))->getParent() != BB) {
+          cast<Instruction>(Cmp->getOperand(0))->getParent() != BB) {
         Constant *RHSCst = cast<Constant>(Cmp->getOperand(1));
 
         for (pred_iterator PI = pred_begin(BB), E = pred_end(BB);PI != E; ++PI){
@@ -470,22 +511,18 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB,PredValueInfo &Result){
         return !Result.empty();
       }
       
-      // Try to find a constant value for the LHS of an equality comparison,
+      // Try to find a constant value for the LHS of a comparison,
       // and evaluate it statically if we can.
-      if (Cmp->getPredicate() == CmpInst::ICMP_EQ || 
-          Cmp->getPredicate() == CmpInst::ICMP_NE) {
+      if (Constant *CmpConst = dyn_cast<Constant>(Cmp->getOperand(1))) {
         SmallVector<std::pair<ConstantInt*, BasicBlock*>, 8> LHSVals;
         ComputeValueKnownInPredecessors(I->getOperand(0), BB, LHSVals);
         
-        ConstantInt *True = ConstantInt::getTrue(I->getContext());
-        ConstantInt *False = ConstantInt::getFalse(I->getContext());
-        if (Cmp->getPredicate() == CmpInst::ICMP_NE) std::swap(True, False);
-        
         for (unsigned i = 0, e = LHSVals.size(); i != e; ++i) {
-          if (LHSVals[i].first == Cmp->getOperand(1))
-            Result.push_back(std::make_pair(True, LHSVals[i].second));
-          else 
-            Result.push_back(std::make_pair(False, LHSVals[i].second));
+          Constant *V = LHSVals[i].first ? LHSVals[i].first :
+                           cast<Constant>(UndefValue::get(CmpConst->getType()));
+          Constant *Folded = ConstantExpr::getCompare(Cmp->getPredicate(),
+                                                      V, CmpConst);
+          PushConstantIntOrUndef(Result, Folded, LHSVals[i].second);
         }
         
         return !Result.empty();
@@ -675,37 +712,36 @@ bool JumpThreading::ProcessBlock(BasicBlock *BB) {
     // the branch based on that.
     BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator());
     Constant *CondConst = dyn_cast<Constant>(CondCmp->getOperand(1));
-    if (LVI && CondBr && CondConst && CondBr->isConditional() &&
+    pred_iterator PI = pred_begin(BB), PE = pred_end(BB);
+    if (LVI && CondBr && CondConst && CondBr->isConditional() && PI != PE &&
         (!isa<Instruction>(CondCmp->getOperand(0)) ||
          cast<Instruction>(CondCmp->getOperand(0))->getParent() != BB)) {
       // For predecessor edge, determine if the comparison is true or false
       // on that edge.  If they're all true or all false, we can simplify the
       // branch.
       // FIXME: We could handle mixed true/false by duplicating code.
-      unsigned Trues = 0, Falses = 0, predcount = 0;
-      for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB);PI != PE; ++PI){
-        ++predcount;
-        LazyValueInfo::Tristate Ret =
-          LVI->getPredicateOnEdge(CondCmp->getPredicate(), 
-                                  CondCmp->getOperand(0), CondConst, *PI, BB);
-        if (Ret == LazyValueInfo::True)
-          ++Trues;
-        else if (Ret == LazyValueInfo::False)
-          ++Falses;
-      }
-      
-      // If we can determine the branch direction statically, convert
-      // the conditional branch to an unconditional one.
-      if (Trues && Trues == predcount) {
-        RemovePredecessorAndSimplify(CondBr->getSuccessor(1), BB, TD);
-        BranchInst::Create(CondBr->getSuccessor(0), CondBr);
-        CondBr->eraseFromParent();
-        return true;
-      } else if (Falses && Falses == predcount) {
-        RemovePredecessorAndSimplify(CondBr->getSuccessor(0), BB, TD);
-        BranchInst::Create(CondBr->getSuccessor(1), CondBr);
-        CondBr->eraseFromParent();
-        return true;
+      LazyValueInfo::Tristate Baseline =      
+        LVI->getPredicateOnEdge(CondCmp->getPredicate(), CondCmp->getOperand(0),
+                                CondConst, *PI, BB);
+      if (Baseline != LazyValueInfo::Unknown) {
+        // Check that all remaining incoming values match the first one.
+        while (++PI != PE) {
+          LazyValueInfo::Tristate Ret = LVI->getPredicateOnEdge(
+                                          CondCmp->getPredicate(),
+                                          CondCmp->getOperand(0),
+                                          CondConst, *PI, BB);
+          if (Ret != Baseline) break;
+        }
+        
+        // If we terminated early, then one of the values didn't match.
+        if (PI == PE) {
+          unsigned ToRemove = Baseline == LazyValueInfo::True ? 1 : 0;
+          unsigned ToKeep = Baseline == LazyValueInfo::True ? 0 : 1;
+          RemovePredecessorAndSimplify(CondBr->getSuccessor(ToRemove), BB, TD);
+          BranchInst::Create(CondBr->getSuccessor(ToKeep), CondBr);
+          CondBr->eraseFromParent();
+          return true;
+        }
       }
     }
   }
@@ -1125,6 +1161,7 @@ bool JumpThreading::ProcessThreadableEdges(Value *Cond, BasicBlock *BB) {
   SmallVector<std::pair<ConstantInt*, BasicBlock*>, 8> PredValues;
   if (!ComputeValueKnownInPredecessors(Cond, BB, PredValues))
     return false;
+  
   assert(!PredValues.empty() &&
          "ComputeValueKnownInPredecessors returned true with no values");
 
@@ -1491,7 +1528,7 @@ bool JumpThreading::ThreadEdge(BasicBlock *BB,
     // We found a use of I outside of BB.  Rename all uses of I that are outside
     // its block to be uses of the appropriate PHI node etc.  See ValuesInBlocks
     // with the two values we know.
-    SSAUpdate.Initialize(I);
+    SSAUpdate.Initialize(I->getType(), I->getName());
     SSAUpdate.AddAvailableValue(BB, I);
     SSAUpdate.AddAvailableValue(NewBB, ValueMapping[I]);
     
@@ -1646,7 +1683,7 @@ bool JumpThreading::DuplicateCondBranchOnPHIIntoPred(BasicBlock *BB,
     // We found a use of I outside of BB.  Rename all uses of I that are outside
     // its block to be uses of the appropriate PHI node etc.  See ValuesInBlocks
     // with the two values we know.
-    SSAUpdate.Initialize(I);
+    SSAUpdate.Initialize(I->getType(), I->getName());
     SSAUpdate.AddAvailableValue(BB, I);
     SSAUpdate.AddAvailableValue(PredBB, ValueMapping[I]);