Extend ScalarEvolution's multiple-exit support to compute exact
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index 436b79dc07e4976b171fe14c9214d57fa9ac660d..d1f6679a43760480b7558b966d6b6bf60e09e089 100644 (file)
@@ -2813,7 +2813,6 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
   const SCEV* BECount = CouldNotCompute;
   const SCEV* MaxBECount = CouldNotCompute;
   bool CouldNotComputeBECount = false;
-  bool CouldNotComputeMaxBECount = false;
   for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
     BackedgeTakenInfo NewBTI =
       ComputeBackedgeTakenCountFromExit(L, ExitingBlocks[i]);
@@ -2826,25 +2825,13 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
     } else if (!CouldNotComputeBECount) {
       if (BECount == CouldNotCompute)
         BECount = NewBTI.Exact;
-      else {
-        // TODO: More analysis could be done here. For example, a
-        // loop with a short-circuiting && operator has an exact count
-        // of the min of both sides.
-        CouldNotComputeBECount = true;
-        BECount = CouldNotCompute;
-      }
-    }
-    if (NewBTI.Max == CouldNotCompute) {
-      // We couldn't compute an maximum value for this exit, so
-      // we won't be able to compute an maximum value for the loop.
-      CouldNotComputeMaxBECount = true;
-      MaxBECount = CouldNotCompute;
-    } else if (!CouldNotComputeMaxBECount) {
-      if (MaxBECount == CouldNotCompute)
-        MaxBECount = NewBTI.Max;
       else
-        MaxBECount = getUMaxFromMismatchedTypes(MaxBECount, NewBTI.Max);
+        BECount = getUMinFromMismatchedTypes(BECount, NewBTI.Exact);
     }
+    if (MaxBECount == CouldNotCompute)
+      MaxBECount = NewBTI.Max;
+    else if (NewBTI.Max != CouldNotCompute)
+      MaxBECount = getUMinFromMismatchedTypes(MaxBECount, NewBTI.Max);
   }
 
   return BackedgeTakenInfo(BECount, MaxBECount);
@@ -2925,9 +2912,7 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCond(const Loop *L,
                                                        Value *ExitCond,
                                                        BasicBlock *TBB,
                                                        BasicBlock *FBB) {
-  // Check if the controlling expression for this loop is an and or or. In
-  // such cases, an exact backedge-taken count may be infeasible, but a
-  // maximum count may still be feasible.
+  // Check if the controlling expression for this loop is an And or Or.
   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
     if (BO->getOpcode() == Instruction::And) {
       // Recurse on the operands of the and.
@@ -3899,88 +3884,111 @@ bool ScalarEvolution::isLoopGuardedByCond(const Loop *L,
         LoopEntryPredicate->isUnconditional())
       continue;
 
-    ICmpInst *ICI = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition());
-    if (!ICI) continue;
+    if (isNecessaryCond(LoopEntryPredicate->getCondition(), Pred, LHS, RHS,
+                        LoopEntryPredicate->getSuccessor(0) != PredecessorDest))
+      return true;
+  }
 
-    // Now that we found a conditional branch that dominates the loop, check to
-    // see if it is the comparison we are looking for.
-    Value *PreCondLHS = ICI->getOperand(0);
-    Value *PreCondRHS = ICI->getOperand(1);
-    ICmpInst::Predicate Cond;
-    if (LoopEntryPredicate->getSuccessor(0) == PredecessorDest)
-      Cond = ICI->getPredicate();
-    else
-      Cond = ICI->getInversePredicate();
+  return false;
+}
 
-    if (Cond == Pred)
-      ; // An exact match.
-    else if (!ICmpInst::isTrueWhenEqual(Cond) && Pred == ICmpInst::ICMP_NE)
-      ; // The actual condition is beyond sufficient.
-    else
-      // Check a few special cases.
-      switch (Cond) {
-      case ICmpInst::ICMP_UGT:
-        if (Pred == ICmpInst::ICMP_ULT) {
-          std::swap(PreCondLHS, PreCondRHS);
-          Cond = ICmpInst::ICMP_ULT;
-          break;
-        }
-        continue;
-      case ICmpInst::ICMP_SGT:
-        if (Pred == ICmpInst::ICMP_SLT) {
-          std::swap(PreCondLHS, PreCondRHS);
-          Cond = ICmpInst::ICMP_SLT;
+/// isNecessaryCond - Test whether the given CondValue value is a condition
+/// which is at least as strict as the one described by Pred, LHS, and RHS.
+bool ScalarEvolution::isNecessaryCond(Value *CondValue,
+                                      ICmpInst::Predicate Pred,
+                                      const SCEV *LHS, const SCEV *RHS,
+                                      bool Inverse) {
+  // Recursivly handle And and Or conditions.
+  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(CondValue)) {
+    if (BO->getOpcode() == Instruction::And) {
+      if (!Inverse)
+        return isNecessaryCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) ||
+               isNecessaryCond(BO->getOperand(1), Pred, LHS, RHS, Inverse);
+    } else if (BO->getOpcode() == Instruction::Or) {
+      if (Inverse)
+        return isNecessaryCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) ||
+               isNecessaryCond(BO->getOperand(1), Pred, LHS, RHS, Inverse);
+    }
+  }
+
+  ICmpInst *ICI = dyn_cast<ICmpInst>(CondValue);
+  if (!ICI) return false;
+
+  // Now that we found a conditional branch that dominates the loop, check to
+  // see if it is the comparison we are looking for.
+  Value *PreCondLHS = ICI->getOperand(0);
+  Value *PreCondRHS = ICI->getOperand(1);
+  ICmpInst::Predicate Cond;
+  if (Inverse)
+    Cond = ICI->getInversePredicate();
+  else
+    Cond = ICI->getPredicate();
+
+  if (Cond == Pred)
+    ; // An exact match.
+  else if (!ICmpInst::isTrueWhenEqual(Cond) && Pred == ICmpInst::ICMP_NE)
+    ; // The actual condition is beyond sufficient.
+  else
+    // Check a few special cases.
+    switch (Cond) {
+    case ICmpInst::ICMP_UGT:
+      if (Pred == ICmpInst::ICMP_ULT) {
+        std::swap(PreCondLHS, PreCondRHS);
+        Cond = ICmpInst::ICMP_ULT;
+        break;
+      }
+      return false;
+    case ICmpInst::ICMP_SGT:
+      if (Pred == ICmpInst::ICMP_SLT) {
+        std::swap(PreCondLHS, PreCondRHS);
+        Cond = ICmpInst::ICMP_SLT;
+        break;
+      }
+      return false;
+    case ICmpInst::ICMP_NE:
+      // Expressions like (x >u 0) are often canonicalized to (x != 0),
+      // so check for this case by checking if the NE is comparing against
+      // a minimum or maximum constant.
+      if (!ICmpInst::isTrueWhenEqual(Pred))
+        if (ConstantInt *CI = dyn_cast<ConstantInt>(PreCondRHS)) {
+          const APInt &A = CI->getValue();
+          switch (Pred) {
+          case ICmpInst::ICMP_SLT:
+            if (A.isMaxSignedValue()) break;
+            return false;
+          case ICmpInst::ICMP_SGT:
+            if (A.isMinSignedValue()) break;
+            return false;
+          case ICmpInst::ICMP_ULT:
+            if (A.isMaxValue()) break;
+            return false;
+          case ICmpInst::ICMP_UGT:
+            if (A.isMinValue()) break;
+            return false;
+          default:
+            return false;
+          }
+          Cond = ICmpInst::ICMP_NE;
+          // NE is symmetric but the original comparison may not be. Swap
+          // the operands if necessary so that they match below.
+          if (isa<SCEVConstant>(LHS))
+            std::swap(PreCondLHS, PreCondRHS);
           break;
         }
-        continue;
-      case ICmpInst::ICMP_NE:
-        // Expressions like (x >u 0) are often canonicalized to (x != 0),
-        // so check for this case by checking if the NE is comparing against
-        // a minimum or maximum constant.
-        if (!ICmpInst::isTrueWhenEqual(Pred))
-          if (ConstantInt *CI = dyn_cast<ConstantInt>(PreCondRHS)) {
-            const APInt &A = CI->getValue();
-            switch (Pred) {
-            case ICmpInst::ICMP_SLT:
-              if (A.isMaxSignedValue()) break;
-              continue;
-            case ICmpInst::ICMP_SGT:
-              if (A.isMinSignedValue()) break;
-              continue;
-            case ICmpInst::ICMP_ULT:
-              if (A.isMaxValue()) break;
-              continue;
-            case ICmpInst::ICMP_UGT:
-              if (A.isMinValue()) break;
-              continue;
-            default:
-              continue;
-            }
-            Cond = ICmpInst::ICMP_NE;
-            // NE is symmetric but the original comparison may not be. Swap
-            // the operands if necessary so that they match below.
-            if (isa<SCEVConstant>(LHS))
-              std::swap(PreCondLHS, PreCondRHS);
-            break;
-          }
-        continue;
-      default:
-        // We weren't able to reconcile the condition.
-        continue;
-      }
-
-    if (!PreCondLHS->getType()->isInteger()) continue;
+      return false;
+    default:
+      // We weren't able to reconcile the condition.
+      return false;
+    }
 
-    const SCEV* PreCondLHSSCEV = getSCEV(PreCondLHS);
-    const SCEV* PreCondRHSSCEV = getSCEV(PreCondRHS);
-    if ((HasSameValue(LHS, PreCondLHSSCEV) &&
-         HasSameValue(RHS, PreCondRHSSCEV)) ||
-        (HasSameValue(LHS, getNotSCEV(PreCondRHSSCEV)) &&
-         HasSameValue(RHS, getNotSCEV(PreCondLHSSCEV))))
-      return true;
-  }
+  if (!PreCondLHS->getType()->isInteger()) return false;
 
-  return false;
+  const SCEV *PreCondLHSSCEV = getSCEV(PreCondLHS);
+  const SCEV *PreCondRHSSCEV = getSCEV(PreCondRHS);
+  return (HasSameValue(LHS, PreCondLHSSCEV) &&
+          HasSameValue(RHS, PreCondRHSSCEV)) ||
+         (HasSameValue(LHS, getNotSCEV(PreCondRHSSCEV)) &&
+          HasSameValue(RHS, getNotSCEV(PreCondLHSSCEV)));
 }
 
 /// getBECount - Subtract the end and start values and divide by the step,