Update SetVector to rely on the underlying set's insert to return a pair<iterator...
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index 7324344c3e0e71ec421ce8d4061fd2985bf628bb..68549efad501f655f84b05c220674425338cd450 100644 (file)
@@ -703,6 +703,34 @@ static const APInt sdiv(const SCEVConstant *C1, const SCEVConstant *C2) {
   return APIntOps::sdiv(A, B);
 }
 
+static const APInt urem(const SCEVConstant *C1, const SCEVConstant *C2) {
+  APInt A = C1->getValue()->getValue();
+  APInt B = C2->getValue()->getValue();
+  uint32_t ABW = A.getBitWidth();
+  uint32_t BBW = B.getBitWidth();
+
+  if (ABW > BBW)
+    B = B.zext(ABW);
+  else if (ABW < BBW)
+    A = A.zext(BBW);
+
+  return APIntOps::urem(A, B);
+}
+
+static const APInt udiv(const SCEVConstant *C1, const SCEVConstant *C2) {
+  APInt A = C1->getValue()->getValue();
+  APInt B = C2->getValue()->getValue();
+  uint32_t ABW = A.getBitWidth();
+  uint32_t BBW = B.getBitWidth();
+
+  if (ABW > BBW)
+    B = B.zext(ABW);
+  else if (ABW < BBW)
+    A = A.zext(BBW);
+
+  return APIntOps::udiv(A, B);
+}
+
 namespace {
 struct FindSCEVSize {
   int Size;
@@ -729,7 +757,8 @@ static inline int sizeOfSCEV(const SCEV *S) {
 
 namespace {
 
-struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> {
+template <typename Derived>
+struct SCEVDivision : public SCEVVisitor<Derived, void> {
 public:
   // Computes the Quotient and Remainder of the division of Numerator by
   // Denominator.
@@ -738,7 +767,7 @@ public:
                      const SCEV **Remainder) {
     assert(Numerator && Denominator && "Uninitialized SCEV");
 
-    SCEVDivision D(SE, Numerator, Denominator);
+    Derived D(SE, Numerator, Denominator);
 
     // Check for the trivial case here to avoid having to check for it in the
     // rest of the code.
@@ -779,17 +808,6 @@ public:
     *Remainder = D.Remainder;
   }
 
-  SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, const SCEV *Denominator)
-      : SE(S), Denominator(Denominator) {
-    Zero = SE.getConstant(Denominator->getType(), 0);
-    One = SE.getConstant(Denominator->getType(), 1);
-
-    // By default, we don't know how to divide Expr by Denominator.
-    // Providing the default here simplifies the rest of the code.
-    Quotient = Zero;
-    Remainder = Numerator;
-  }
-
   // Except in the trivial case described above, we do not know how to divide
   // Expr by Denominator for the following functions with empty implementation.
   void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {}
@@ -801,14 +819,6 @@ public:
   void visitUnknown(const SCEVUnknown *Numerator) {}
   void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {}
 
-  void visitConstant(const SCEVConstant *Numerator) {
-    if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
-      Quotient = SE.getConstant(sdiv(Numerator, D));
-      Remainder = SE.getConstant(srem(Numerator, D));
-      return;
-    }
-  }
-
   void visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
     const SCEV *StartQ, *StartR, *StepQ, *StepR;
     assert(Numerator->isAffine() && "Numerator should be affine");
@@ -932,12 +942,54 @@ public:
   }
 
 private:
+  SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
+               const SCEV *Denominator)
+      : SE(S), Denominator(Denominator) {
+    Zero = SE.getConstant(Denominator->getType(), 0);
+    One = SE.getConstant(Denominator->getType(), 1);
+
+    // By default, we don't know how to divide Expr by Denominator.
+    // Providing the default here simplifies the rest of the code.
+    Quotient = Zero;
+    Remainder = Numerator;
+  }
+
   ScalarEvolution &SE;
   const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One;
+
+  friend struct SCEVSDivision;
+  friend struct SCEVUDivision;
+};
+
+struct SCEVSDivision : public SCEVDivision<SCEVSDivision> {
+  SCEVSDivision(ScalarEvolution &S, const SCEV *Numerator,
+                const SCEV *Denominator)
+      : SCEVDivision(S, Numerator, Denominator) {}
+
+  void visitConstant(const SCEVConstant *Numerator) {
+    if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
+      Quotient = SE.getConstant(sdiv(Numerator, D));
+      Remainder = SE.getConstant(srem(Numerator, D));
+      return;
+    }
+  }
 };
-}
 
+struct SCEVUDivision : public SCEVDivision<SCEVUDivision> {
+  SCEVUDivision(ScalarEvolution &S, const SCEV *Numerator,
+                const SCEV *Denominator)
+      : SCEVDivision(S, Numerator, Denominator) {}
+
+  void visitConstant(const SCEVConstant *Numerator) {
+    if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
+      Quotient = SE.getConstant(udiv(Numerator, D));
+      Remainder = SE.getConstant(urem(Numerator, D));
+      return;
+    }
+  }
+};
 
+}
 
 //===----------------------------------------------------------------------===//
 //                      Simple SCEV method implementations
@@ -3343,7 +3395,8 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
   Visited.insert(PN);
   while (!Worklist.empty()) {
     Instruction *I = Worklist.pop_back_val();
-    if (!Visited.insert(I)) continue;
+    if (!Visited.insert(I).second)
+      continue;
 
     ValueExprMapType::iterator It =
       ValueExprMap.find_as(static_cast<Value *>(I));
@@ -4541,7 +4594,8 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
     SmallPtrSet<Instruction *, 8> Visited;
     while (!Worklist.empty()) {
       Instruction *I = Worklist.pop_back_val();
-      if (!Visited.insert(I)) continue;
+      if (!Visited.insert(I).second)
+        continue;
 
       ValueExprMapType::iterator It =
         ValueExprMap.find_as(static_cast<Value *>(I));
@@ -4593,7 +4647,8 @@ void ScalarEvolution::forgetLoop(const Loop *L) {
   SmallPtrSet<Instruction *, 8> Visited;
   while (!Worklist.empty()) {
     Instruction *I = Worklist.pop_back_val();
-    if (!Visited.insert(I)) continue;
+    if (!Visited.insert(I).second)
+      continue;
 
     ValueExprMapType::iterator It =
       ValueExprMap.find_as(static_cast<Value *>(I));
@@ -4627,7 +4682,8 @@ void ScalarEvolution::forgetValue(Value *V) {
   SmallPtrSet<Instruction *, 8> Visited;
   while (!Worklist.empty()) {
     I = Worklist.pop_back_val();
-    if (!Visited.insert(I)) continue;
+    if (!Visited.insert(I).second)
+      continue;
 
     ValueExprMapType::iterator It =
       ValueExprMap.find_as(static_cast<Value *>(I));
@@ -6086,7 +6142,7 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) {
   // backedge count.
   const SCEV *Q, *R;
   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
-  SCEVDivision::divide(SE, Distance, Step, &Q, &R);
+  SCEVUDivision::divide(SE, Distance, Step, &Q, &R);
   if (R->isZero()) {
     const SCEV *Exact =
         getUDivExactExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
@@ -6782,6 +6838,66 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
                                    RHS, LHS, FoundLHS, FoundRHS);
   }
 
+  // Check if we can make progress by sharpening ranges.
+  if (FoundPred == ICmpInst::ICMP_NE &&
+      (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
+
+    const SCEVConstant *C = nullptr;
+    const SCEV *V = nullptr;
+
+    if (isa<SCEVConstant>(FoundLHS)) {
+      C = cast<SCEVConstant>(FoundLHS);
+      V = FoundRHS;
+    } else {
+      C = cast<SCEVConstant>(FoundRHS);
+      V = FoundLHS;
+    }
+
+    // The guarding predicate tells us that C != V. If the known range
+    // of V is [C, t), we can sharpen the range to [C + 1, t).  The
+    // range we consider has to correspond to same signedness as the
+    // predicate we're interested in folding.
+
+    APInt Min = ICmpInst::isSigned(Pred) ?
+        getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin();
+
+    if (Min == C->getValue()->getValue()) {
+      // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
+      // This is true even if (Min + 1) wraps around -- in case of
+      // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
+
+      APInt SharperMin = Min + 1;
+
+      switch (Pred) {
+        case ICmpInst::ICMP_SGE:
+        case ICmpInst::ICMP_UGE:
+          // We know V `Pred` SharperMin.  If this implies LHS `Pred`
+          // RHS, we're done.
+          if (isImpliedCondOperands(Pred, LHS, RHS, V,
+                                    getConstant(SharperMin)))
+            return true;
+
+        case ICmpInst::ICMP_SGT:
+        case ICmpInst::ICMP_UGT:
+          // We know from the range information that (V `Pred` Min ||
+          // V == Min).  We know from the guarding condition that !(V
+          // == Min).  This gives us
+          //
+          //       V `Pred` Min || V == Min && !(V == Min)
+          //   =>  V `Pred` Min
+          //
+          // If V `Pred` Min implies LHS `Pred` RHS, we're done.
+
+          if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min)))
+            return true;
+
+        default:
+          // No change
+          break;
+      }
+    }
+  }
+
   // Check whether the actual condition is beyond sufficient.
   if (FoundPred == ICmpInst::ICMP_EQ)
     if (ICmpInst::isTrueWhenEqual(Pred))
@@ -7341,7 +7457,7 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE,
   for (const SCEV *&Term : Terms) {
     // Normalize the terms before the next call to findArrayDimensionsRec.
     const SCEV *Q, *R;
-    SCEVDivision::divide(SE, Term, Step, &Q, &R);
+    SCEVSDivision::divide(SE, Term, Step, &Q, &R);
 
     // Bail out when GCD does not evenly divide one of the terms.
     if (!R->isZero())
@@ -7478,7 +7594,7 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
   // Divide all terms by the element size.
   for (const SCEV *&Term : Terms) {
     const SCEV *Q, *R;
-    SCEVDivision::divide(SE, Term, ElementSize, &Q, &R);
+    SCEVSDivision::divide(SE, Term, ElementSize, &Q, &R);
     Term = Q;
   }
 
@@ -7525,7 +7641,7 @@ void SCEVAddRecExpr::computeAccessFunctions(
   int Last = Sizes.size() - 1;
   for (int i = Last; i >= 0; i--) {
     const SCEV *Q, *R;
-    SCEVDivision::divide(SE, Res, Sizes[i], &Q, &R);
+    SCEVSDivision::divide(SE, Res, Sizes[i], &Q, &R);
 
     DEBUG({
         dbgs() << "Res: " << *Res << "\n";
@@ -7680,7 +7796,7 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
     // that until everything else is done.
     if (U == Old)
       continue;
-    if (!Visited.insert(U))
+    if (!Visited.insert(U).second)
       continue;
     if (PHINode *PN = dyn_cast<PHINode>(U))
       SE->ConstantEvolutionLoopExitValue.erase(PN);