Allow min/max detection to see through casts.
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index 26b468013c6f67bcf38f58a6d1764173612113cd..9d99b8f77c99afc0109df833048d0f086885af06 100644 (file)
@@ -726,6 +726,13 @@ public:
       return;
     }
 
+    // A simple case when N/1. The quotient is N.
+    if (Denominator->isOne()) {
+      *Quotient = Numerator;
+      *Remainder = D.Zero;
+      return;
+    }
+
     // Split the Denominator when it is a product.
     if (const SCEVMulExpr *T = dyn_cast<const SCEVMulExpr>(Denominator)) {
       const SCEV *Q, *R;
@@ -788,6 +795,14 @@ public:
     assert(Numerator->isAffine() && "Numerator should be affine");
     divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
     divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
+    // Bail out if the types do not match.
+    Type *Ty = Denominator->getType();
+    if (Ty != StartQ->getType() || Ty != StartR->getType() ||
+        Ty != StepQ->getType() || Ty != StepR->getType()) {
+      Quotient = Zero;
+      Remainder = Numerator;
+      return;
+    }
     Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
                                 Numerator->getNoWrapFlags());
     Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
@@ -5690,7 +5705,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
             if (PTy->getElementType()->isStructTy())
               C2 = ConstantExpr::getIntegerCast(
                   C2, Type::getInt32Ty(C->getContext()), true);
-            C = ConstantExpr::getGetElementPtr(C, C2);
+            C = ConstantExpr::getGetElementPtr(PTy->getElementType(), C, C2);
           } else
             C = ConstantExpr::getAdd(C, C2);
         }
@@ -6686,6 +6701,37 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
                     LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
     return true;
 
+  // Check conditions due to any @llvm.assume intrinsics.
+  for (auto &AssumeVH : AC->assumptions()) {
+    if (!AssumeVH)
+      continue;
+    auto *CI = cast<CallInst>(AssumeVH);
+    if (!DT->dominates(CI, Latch->getTerminator()))
+      continue;
+
+    if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
+      return true;
+  }
+
+  struct ClearWalkingBEDominatingCondsOnExit {
+    ScalarEvolution &SE;
+
+    explicit ClearWalkingBEDominatingCondsOnExit(ScalarEvolution &SE)
+        : SE(SE){};
+
+    ~ClearWalkingBEDominatingCondsOnExit() {
+      SE.WalkingBEDominatingConds = false;
+    }
+  };
+
+  // We don't want more than one activation of the following loop on the stack
+  // -- that can lead to O(n!) time complexity.
+  if (WalkingBEDominatingConds)
+    return false;
+
+  WalkingBEDominatingConds = true;
+  ClearWalkingBEDominatingCondsOnExit ClearOnExit(*this);
+
   // If the loop is not reachable from the entry block, we risk running into an
   // infinite loop as we walk up into the dom tree.  These loops do not matter
   // anyway, so we just return a conservative answer when we see them.
@@ -6726,18 +6772,6 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
     }
   }
 
-  // Check conditions due to any @llvm.assume intrinsics.
-  for (auto &AssumeVH : AC->assumptions()) {
-    if (!AssumeVH)
-      continue;
-    auto *CI = cast<CallInst>(AssumeVH);
-    if (!DT->dominates(CI, Latch->getTerminator()))
-      continue;
-
-    if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
-      return true;
-  }
-
   return false;
 }
 
@@ -8008,8 +8042,8 @@ ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
 //===----------------------------------------------------------------------===//
 
 ScalarEvolution::ScalarEvolution()
-  : FunctionPass(ID), ValuesAtScopes(64), LoopDispositions(64),
-    BlockDispositions(64), FirstUnknown(nullptr) {
+    : FunctionPass(ID), WalkingBEDominatingConds(false), ValuesAtScopes(64),
+      LoopDispositions(64), BlockDispositions(64), FirstUnknown(nullptr) {
   initializeScalarEvolutionPass(*PassRegistry::getPassRegistry());
 }
 
@@ -8040,6 +8074,7 @@ void ScalarEvolution::releaseMemory() {
   }
 
   assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
+  assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
 
   BackedgeTakenCounts.clear();
   ConstantEvolutionLoopExitValue.clear();