Move dumpPassStructure out of line.
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index db9792e3b3e4535ca44eaea239359e7414efcf05..4e0dba7e04b3acde61083563d9baa1b987202d5d 100644 (file)
@@ -66,6 +66,7 @@
 #include "llvm/GlobalVariable.h"
 #include "llvm/Instructions.h"
 #include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/Dominators.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Assembly/Writer.h"
 #include "llvm/Transforms/Scalar.h"
@@ -83,9 +84,6 @@
 #include <cmath>
 using namespace llvm;
 
-STATISTIC(NumBruteForceEvaluations,
-          "Number of brute force evaluations needed to "
-          "calculate high-order polynomial exit values");
 STATISTIC(NumArrayLenItCounts,
           "Number of trip counts computed with array length");
 STATISTIC(NumTripCountsComputed,
@@ -115,6 +113,7 @@ char ScalarEvolution::ID = 0;
 SCEV::~SCEV() {}
 void SCEV::dump() const {
   print(cerr);
+  cerr << '\n';
 }
 
 uint32_t SCEV::getBitWidth() const {
@@ -207,6 +206,10 @@ SCEVTruncateExpr::~SCEVTruncateExpr() {
   SCEVTruncates->erase(std::make_pair(Op, Ty));
 }
 
+bool SCEVTruncateExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
+  return Op->dominates(BB, DT);
+}
+
 void SCEVTruncateExpr::print(std::ostream &OS) const {
   OS << "(truncate " << *Op << " to " << *Ty << ")";
 }
@@ -229,6 +232,10 @@ SCEVZeroExtendExpr::~SCEVZeroExtendExpr() {
   SCEVZeroExtends->erase(std::make_pair(Op, Ty));
 }
 
+bool SCEVZeroExtendExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
+  return Op->dominates(BB, DT);
+}
+
 void SCEVZeroExtendExpr::print(std::ostream &OS) const {
   OS << "(zeroextend " << *Op << " to " << *Ty << ")";
 }
@@ -251,6 +258,10 @@ SCEVSignExtendExpr::~SCEVSignExtendExpr() {
   SCEVSignExtends->erase(std::make_pair(Op, Ty));
 }
 
+bool SCEVSignExtendExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
+  return Op->dominates(BB, DT);
+}
+
 void SCEVSignExtendExpr::print(std::ostream &OS) const {
   OS << "(signextend " << *Op << " to " << *Ty << ")";
 }
@@ -308,6 +319,14 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
   return this;
 }
 
+bool SCEVCommutativeExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
+  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
+    if (!getOperand(i)->dominates(BB, DT))
+      return false;
+  }
+  return true;
+}
+
 
 // SCEVUDivs - Only allow the creation of one SCEVUDivExpr for any particular
 // input.  Don't use a SCEVHandle here, or else the object will never be
@@ -319,6 +338,10 @@ SCEVUDivExpr::~SCEVUDivExpr() {
   SCEVUDivs->erase(std::make_pair(LHS, RHS));
 }
 
+bool SCEVUDivExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
+  return LHS->dominates(BB, DT) && RHS->dominates(BB, DT);
+}
+
 void SCEVUDivExpr::print(std::ostream &OS) const {
   OS << "(" << *LHS << " /u " << *RHS << ")";
 }
@@ -339,6 +362,15 @@ SCEVAddRecExpr::~SCEVAddRecExpr() {
                                                            Operands.end())));
 }
 
+bool SCEVAddRecExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
+  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
+    if (!getOperand(i)->dominates(BB, DT))
+      return false;
+  }
+  return true;
+}
+
+
 SCEVHandle SCEVAddRecExpr::
 replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
                                   const SCEVHandle &Conc,
@@ -393,6 +425,12 @@ bool SCEVUnknown::isLoopInvariant(const Loop *L) const {
   return true;
 }
 
+bool SCEVUnknown::dominates(BasicBlock *BB, DominatorTree *DT) const {
+  if (Instruction *I = dyn_cast<Instruction>(getValue()))
+    return DT->dominates(I->getParent(), BB);
+  return true;
+}
+
 const Type *SCEVUnknown::getType() const {
   return V->getType();
 }
@@ -587,17 +625,7 @@ static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K,
   }
 
   // We need at least W + T bits for the multiplication step
-  // FIXME: A temporary hack; we round up the bitwidths
-  // to the nearest power of 2 to be nice to the code generator.
-  unsigned CalculationBits = 1U << Log2_32_Ceil(W + T);
-  // FIXME: Temporary hack to avoid generating integers that are too wide.
-  // Although, it's not completely clear how to determine how much
-  // widening is safe; for example, on X86, we can't really widen
-  // beyond 64 because we need to be able to do multiplication
-  // that's CalculationBits wide, but on X86-64, we can safely widen up to
-  // 128 bits.
-  if (CalculationBits > 64)
-    return new SCEVCouldNotCompute();
+  unsigned CalculationBits = W + T;
 
   // Calcuate 2^T, at width T+W.
   APInt DivFactor = APInt(CalculationBits, 1).shl(T);
@@ -644,11 +672,12 @@ SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It,
     // The computation is correct in the face of overflow provided that the
     // multiplication is performed _after_ the evaluation of the binomial
     // coefficient.
-    SCEVHandle Val =
-      SE.getMulExpr(getOperand(i),
-                    BinomialCoefficient(It, i, SE,
-                                        cast<IntegerType>(getType())));
-    Result = SE.getAddExpr(Result, Val);
+    SCEVHandle Coeff = BinomialCoefficient(It, i, SE,
+                                           cast<IntegerType>(getType()));
+    if (isa<SCEVCouldNotCompute>(Coeff))
+      return Coeff;
+
+    Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
   }
   return Result;
 }
@@ -1405,6 +1434,7 @@ namespace {
     void setSCEV(Value *V, const SCEVHandle &H) {
       bool isNew = Scalars.insert(std::make_pair(V, H)).second;
       assert(isNew && "This entry already existed!");
+      isNew = false;
     }
 
 
@@ -1414,6 +1444,11 @@ namespace {
     SCEVHandle getSCEVAtScope(SCEV *V, const Loop *L);
 
 
+    /// isLoopGuardedByCond - Test whether entry to the loop is protected by
+    /// a conditional between LHS and RHS.
+    bool isLoopGuardedByCond(const Loop *L, ICmpInst::Predicate Pred,
+                             SCEV *LHS, SCEV *RHS);
+
     /// hasLoopInvariantIterationCount - Return true if the specified loop has
     /// an analyzable loop-invariant iteration count.
     bool hasLoopInvariantIterationCount(const Loop *L);
@@ -1480,9 +1515,11 @@ namespace {
     SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L,
                                 bool isSigned);
 
-    /// executesAtLeastOnce - Test whether entry to the loop is protected by
-    /// a conditional between LHS and RHS.
-    bool executesAtLeastOnce(const Loop *L, bool isSigned, SCEV *LHS, SCEV *RHS);
+    /// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
+    /// (which may not be an immediate predecessor) which has exactly one
+    /// successor from which BB is reachable, or null if no such block is
+    /// found.
+    BasicBlock* getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB);
 
     /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
     /// in the header of its containing loop, we know the loop executes a
@@ -1975,8 +2012,8 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) {
 
   // At this point, we would like to compute how many iterations of the 
   // loop the predicate will return true for these inputs.
-  if (isa<SCEVConstant>(LHS) && !isa<SCEVConstant>(RHS)) {
-    // If there is a constant, force it into the RHS.
+  if (LHS->isLoopInvariant(L) && !RHS->isLoopInvariant(L)) {
+    // If there is a loop-invariant, force it into the RHS.
     std::swap(LHS, RHS);
     Cond = ICmpInst::getSwappedPredicate(Cond);
   }
@@ -2597,6 +2634,11 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
     // The divisions must be performed as signed divisions.
     APInt NegB(-B);
     APInt TwoA( A << 1 );
+    if (TwoA.isMinValue()) {
+      SCEV *CNC = new SCEVCouldNotCompute();
+      return std::make_pair(CNC, CNC);
+    }
+
     ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA));
     ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA));
 
@@ -2703,22 +2745,42 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) {
   return UnknownValue;
 }
 
-/// executesAtLeastOnce - Test whether entry to the loop is protected by
+/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
+/// (which may not be an immediate predecessor) which has exactly one
+/// successor from which BB is reachable, or null if no such block is
+/// found.
+///
+BasicBlock *
+ScalarEvolutionsImpl::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
+  // If the block has a unique predecessor, the predecessor must have
+  // no other successors from which BB is reachable.
+  if (BasicBlock *Pred = BB->getSinglePredecessor())
+    return Pred;
+
+  // A loop's header is defined to be a block that dominates the loop.
+  // If the loop has a preheader, it must be a block that has exactly
+  // one successor that can reach BB. This is slightly more strict
+  // than necessary, but works if critical edges are split.
+  if (Loop *L = LI.getLoopFor(BB))
+    return L->getLoopPreheader();
+
+  return 0;
+}
+
+/// isLoopGuardedByCond - Test whether entry to the loop is protected by
 /// a conditional between LHS and RHS.
-bool ScalarEvolutionsImpl::executesAtLeastOnce(const Loop *L, bool isSigned,
+bool ScalarEvolutionsImpl::isLoopGuardedByCond(const Loop *L,
+                                               ICmpInst::Predicate Pred,
                                                SCEV *LHS, SCEV *RHS) {
   BasicBlock *Preheader = L->getLoopPreheader();
   BasicBlock *PreheaderDest = L->getHeader();
 
   // Starting at the preheader, climb up the predecessor chain, as long as
-  // there are unique predecessors, looking for a conditional branch that
-  // protects the loop.
-  // 
-  // This is a conservative apporoximation of a climb of the
-  // control-dependence predecessors.
-
-  for (; Preheader; PreheaderDest = Preheader,
-                    Preheader = Preheader->getSinglePredecessor()) {
+  // there are predecessors that can be found that have unique successors
+  // leading to the original header.
+  for (; Preheader;
+       PreheaderDest = Preheader,
+       Preheader = getPredecessorWithUniqueSuccessorForBB(Preheader)) {
 
     BranchInst *LoopEntryPredicate =
       dyn_cast<BranchInst>(Preheader->getTerminator());
@@ -2739,26 +2801,62 @@ bool ScalarEvolutionsImpl::executesAtLeastOnce(const Loop *L, bool isSigned,
     else
       Cond = ICI->getInversePredicate();
 
-    switch (Cond) {
-    case ICmpInst::ICMP_UGT:
-      if (isSigned) continue;
-      std::swap(PreCondLHS, PreCondRHS);
-      Cond = ICmpInst::ICMP_ULT;
-      break;
-    case ICmpInst::ICMP_SGT:
-      if (!isSigned) continue;
-      std::swap(PreCondLHS, PreCondRHS);
-      Cond = ICmpInst::ICMP_SLT;
-      break;
-    case ICmpInst::ICMP_ULT:
-      if (isSigned) continue;
-      break;
-    case ICmpInst::ICMP_SLT:
-      if (!isSigned) continue;
-      break;
-    default:
-      continue;
-    }
+    if (Cond == Pred)
+      ; // An exact match.
+    else if (!ICmpInst::isTrueWhenEqual(Cond) && Pred == ICmpInst::ICMP_NE)
+      ; // The actual condition is beyond sufficient.
+    else
+      // Check a few special cases.
+      switch (Cond) {
+      case ICmpInst::ICMP_UGT:
+        if (Pred == ICmpInst::ICMP_ULT) {
+          std::swap(PreCondLHS, PreCondRHS);
+          Cond = ICmpInst::ICMP_ULT;
+          break;
+        }
+        continue;
+      case ICmpInst::ICMP_SGT:
+        if (Pred == ICmpInst::ICMP_SLT) {
+          std::swap(PreCondLHS, PreCondRHS);
+          Cond = ICmpInst::ICMP_SLT;
+          break;
+        }
+        continue;
+      case ICmpInst::ICMP_NE:
+        // Expressions like (x >u 0) are often canonicalized to (x != 0),
+        // so check for this case by checking if the NE is comparing against
+        // a minimum or maximum constant.
+        if (!ICmpInst::isTrueWhenEqual(Pred))
+          if (ConstantInt *CI = dyn_cast<ConstantInt>(PreCondRHS)) {
+            const APInt &A = CI->getValue();
+            switch (Pred) {
+            case ICmpInst::ICMP_SLT:
+              if (A.isMaxSignedValue()) break;
+              continue;
+            case ICmpInst::ICMP_SGT:
+              if (A.isMinSignedValue()) break;
+              continue;
+            case ICmpInst::ICMP_ULT:
+              if (A.isMaxValue()) break;
+              continue;
+            case ICmpInst::ICMP_UGT:
+              if (A.isMinValue()) break;
+              continue;
+            default:
+              continue;
+            }
+            Cond = ICmpInst::ICMP_NE;
+            // NE is symmetric but the original comparison may not be. Swap
+            // the operands if necessary so that they match below.
+            if (isa<SCEVConstant>(LHS))
+              std::swap(PreCondLHS, PreCondRHS);
+            break;
+          }
+        continue;
+      default:
+        // We weren't able to reconcile the condition.
+        continue;
+      }
 
     if (!PreCondLHS->getType()->isInteger()) continue;
 
@@ -2799,7 +2897,8 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) {
     // First, we get the value of the LHS in the first iteration: n
     SCEVHandle Start = AddRec->getOperand(0);
 
-    if (executesAtLeastOnce(L, isSigned,
+    if (isLoopGuardedByCond(L,
+                            isSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT,
                             SE.getMinusSCEV(AddRec->getOperand(0), One), RHS)) {
       // Since we know that the condition is true in order to enter the loop,
       // we know that it will run exactly m-n times.
@@ -2935,27 +3034,6 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
     }
   }
 
-  // Fallback, if this is a general polynomial, figure out the progression
-  // through brute force: evaluate until we find an iteration that fails the
-  // test.  This is likely to be slow, but getting an accurate trip count is
-  // incredibly important, we will be able to simplify the exit test a lot, and
-  // we are almost guaranteed to get a trip count in this case.
-  ConstantInt *TestVal = ConstantInt::get(getType(), 0);
-  ConstantInt *EndVal  = TestVal;  // Stop when we wrap around.
-  do {
-    ++NumBruteForceEvaluations;
-    SCEVHandle Val = evaluateAtIteration(SE.getConstant(TestVal), SE);
-    if (!isa<SCEVConstant>(Val))  // This shouldn't happen.
-      return new SCEVCouldNotCompute();
-
-    // Check to see if we found the value!
-    if (!Range.contains(cast<SCEVConstant>(Val)->getValue()->getValue()))
-      return SE.getConstant(TestVal);
-
-    // Increment to test the next index.
-    TestVal = ConstantInt::get(TestVal->getValue()+1);
-  } while (TestVal != EndVal);
-
   return new SCEVCouldNotCompute();
 }
 
@@ -2998,6 +3076,13 @@ void ScalarEvolution::setSCEV(Value *V, const SCEVHandle &H) {
 }
 
 
+bool ScalarEvolution::isLoopGuardedByCond(const Loop *L,
+                                          ICmpInst::Predicate Pred,
+                                          SCEV *LHS, SCEV *RHS) {
+  return ((ScalarEvolutionsImpl*)Impl)->isLoopGuardedByCond(L, Pred,
+                                                            LHS, RHS);
+}
+
 SCEVHandle ScalarEvolution::getIterationCount(const Loop *L) const {
   return ((ScalarEvolutionsImpl*)Impl)->getIterationCount(L);
 }