Introduce a range version of std::find, and use in SCEV
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index 4652e449f4bfb0a164ff5e0d2a3ca8a24d18c725..d04028b15e2fa6adc58ce92c61c1636ea5a53128 100644 (file)
@@ -1957,11 +1957,11 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
       ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
 
   // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
-  auto IsKnownNonNegative =
-    std::bind(std::mem_fn(&ScalarEvolution::isKnownNonNegative), SE, _1);
+  auto IsKnownNonNegative = [&](const SCEV *S) {
+    return SE->isKnownNonNegative(S);
+  };
 
-  if (SignOrUnsignWrap == SCEV::FlagNSW &&
-      std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative))
+  if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
     Flags =
         ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
 
@@ -2888,9 +2888,8 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
       // AddRecs require their operands be loop-invariant with respect to their
       // loops. Don't perform this transformation if it would break this
       // requirement.
-      bool AllInvariant =
-          std::all_of(Operands.begin(), Operands.end(),
-                      [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
+      bool AllInvariant = all_of(
+          Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
 
       if (AllInvariant) {
         // Create a recurrence for the outer loop with the same step size.
@@ -2901,9 +2900,9 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
           maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
 
         NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
-        AllInvariant = std::all_of(
-            NestedOperands.begin(), NestedOperands.end(),
-            [&](const SCEV *Op) { return isLoopInvariant(Op, NestedLoop); });
+        AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
+          return isLoopInvariant(Op, NestedLoop);
+        });
 
         if (AllInvariant) {
           // Ok, both add recurrences are valid after the transformation.
@@ -3626,6 +3625,7 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
   }
 }
 
+namespace {
 class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
 public:
   static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
@@ -3690,6 +3690,7 @@ private:
   const Loop *L;
   bool Valid;
 };
+} // end anonymous namespace
 
 const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
   const Loop *L = LI.getLoopFor(PN->getParent());
@@ -6124,22 +6125,22 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
 /// In the case that a relevant loop exit value cannot be computed, the
 /// original value V is returned.
 const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
+  SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
+      ValuesAtScopes[V];
   // Check to see if we've folded this expression at this loop before.
-  SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values = ValuesAtScopes[V];
-  for (unsigned u = 0; u < Values.size(); u++) {
-    if (Values[u].first == L)
-      return Values[u].second ? Values[u].second : V;
-  }
-  Values.push_back(std::make_pair(L, static_cast<const SCEV *>(nullptr)));
+  for (auto &LS : Values)
+    if (LS.first == L)
+      return LS.second ? LS.second : V;
+
+  Values.emplace_back(L, nullptr);
+
   // Otherwise compute it.
   const SCEV *C = computeSCEVAtScope(V, L);
-  SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values2 = ValuesAtScopes[V];
-  for (unsigned u = Values2.size(); u > 0; u--) {
-    if (Values2[u - 1].first == L) {
-      Values2[u - 1].second = C;
+  for (auto &LS : reverse(ValuesAtScopes[V]))
+    if (LS.first == L) {
+      LS.second = C;
       break;
     }
-  }
   return C;
 }
 
@@ -7043,16 +7044,14 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
       Pred = ICmpInst::ICMP_ULT;
       Changed = true;
     } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) {
-      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
-                       SCEV::FlagNUW);
+      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
       Pred = ICmpInst::ICMP_ULT;
       Changed = true;
     }
     break;
   case ICmpInst::ICMP_UGE:
     if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) {
-      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
-                       SCEV::FlagNUW);
+      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
       Pred = ICmpInst::ICMP_UGT;
       Changed = true;
     } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) {
@@ -7348,6 +7347,62 @@ ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred,
   return false;
 }
 
+bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
+                                                    const SCEV *LHS,
+                                                    const SCEV *RHS) {
+
+  // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer.
+  // Return Y via OutY.
+  auto MatchBinaryAddToConst =
+      [this](const SCEV *Result, const SCEV *X, APInt &OutY,
+             SCEV::NoWrapFlags ExpectedFlags) {
+    const SCEV *NonConstOp, *ConstOp;
+    SCEV::NoWrapFlags FlagsPresent;
+
+    if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) ||
+        !isa<SCEVConstant>(ConstOp) || NonConstOp != X)
+      return false;
+
+    OutY = cast<SCEVConstant>(ConstOp)->getValue()->getValue();
+    return (FlagsPresent & ExpectedFlags) == ExpectedFlags;
+  };
+
+  APInt C;
+
+  switch (Pred) {
+  default:
+    break;
+
+  case ICmpInst::ICMP_SGE:
+    std::swap(LHS, RHS);
+  case ICmpInst::ICMP_SLE:
+    // X s<= (X + C)<nsw> if C >= 0
+    if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative())
+      return true;
+
+    // (X + C)<nsw> s<= X if C <= 0
+    if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
+        !C.isStrictlyPositive())
+      return true;
+    break;
+
+  case ICmpInst::ICMP_SGT:
+    std::swap(LHS, RHS);
+  case ICmpInst::ICMP_SLT:
+    // X s< (X + C)<nsw> if C > 0
+    if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
+        C.isStrictlyPositive())
+      return true;
+
+    // (X + C)<nsw> s< X if C < 0
+    if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
+      return true;
+    break;
+  }
+
+  return false;
+}
+
 bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
                                                    const SCEV *LHS,
                                                    const SCEV *RHS) {
@@ -7909,8 +7964,7 @@ static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr,
   const MaxExprType *MaxExpr = dyn_cast<MaxExprType>(MaybeMaxExpr);
   if (!MaxExpr) return false;
 
-  auto It = std::find(MaxExpr->op_begin(), MaxExpr->op_end(), Candidate);
-  return It != MaxExpr->op_end();
+  return find(MaxExpr->operands(), Candidate) != MaxExpr->op_end();
 }
 
 
@@ -8004,7 +8058,8 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
       [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
     return isKnownPredicateWithRanges(Pred, LHS, RHS) ||
            IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
-           IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS);
+           IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
+           isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
   };
 
   switch (Pred) {
@@ -8347,8 +8402,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
 
   // The only time we can solve this is when we have all constant indices.
   // Otherwise, we cannot determine the overflow conditions.
-  if (std::any_of(op_begin(), op_end(),
-                  [](const SCEV *Op) { return !isa<SCEVConstant>(Op);}))
+  if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
     return SE.getCouldNotCompute();
 
   // Okay at this point we know that all elements of the chrec are constants and
@@ -8401,15 +8455,13 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
                                              FlagAnyWrap);
 
     // Next, solve the constructed addrec
-    std::pair<const SCEV *,const SCEV *> Roots =
-      SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
+    auto Roots = SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
     const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
     const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
     if (R1) {
       // Pick the smallest positive root value.
-      if (ConstantInt *CB =
-          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
-                         R1->getValue(), R2->getValue()))) {
+      if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(
+              ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) {
         if (!CB->getZExtValue())
           std::swap(R1, R2);   // R1 is the minimum root now.
 
@@ -9210,9 +9262,8 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
 
     // This recurrence is variant w.r.t. L if any of its operands
     // are variant.
-    for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
-         I != E; ++I)
-      if (!isLoopInvariant(*I, L))
+    for (auto *Op : AR->operands())
+      if (!isLoopInvariant(Op, L))
         return LoopVariant;
 
     // Otherwise it's loop-invariant.
@@ -9222,11 +9273,9 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
   case scMulExpr:
   case scUMaxExpr:
   case scSMaxExpr: {
-    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
     bool HasVarying = false;
-    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
-         I != E; ++I) {
-      LoopDisposition D = getLoopDisposition(*I, L);
+    for (auto *Op : cast<SCEVNAryExpr>(S)->operands()) {
+      LoopDisposition D = getLoopDisposition(Op, L);
       if (D == LoopVariant)
         return LoopVariant;
       if (D == LoopComputable)
@@ -9250,7 +9299,7 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
     // invariant if they are not contained in the specified loop.
     // Instructions are never considered invariant in the function body
     // (null loop) because they are defined within the "loop".
-    if (Instruction *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
+    if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
       return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
     return LoopInvariant;
   case scCouldNotCompute:
@@ -9557,6 +9606,7 @@ ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
   return Eq;
 }
 
+namespace {
 class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
 public:
   static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
@@ -9581,6 +9631,7 @@ public:
 private:
   SCEVUnionPredicate &P;
 };
+} // end anonymous namespace
 
 const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
                                                    SCEVUnionPredicate &Preds) {
@@ -9619,8 +9670,8 @@ SCEVUnionPredicate::SCEVUnionPredicate()
     : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
 
 bool SCEVUnionPredicate::isAlwaysTrue() const {
-  return std::all_of(Preds.begin(), Preds.end(),
-                     [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
+  return all_of(Preds,
+                [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
 }
 
 ArrayRef<const SCEVPredicate *>
@@ -9633,17 +9684,16 @@ SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
 
 bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
   if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N))
-    return std::all_of(
-        Set->Preds.begin(), Set->Preds.end(),
-        [this](const SCEVPredicate *I) { return this->implies(I); });
+    return all_of(Set->Preds,
+                  [this](const SCEVPredicate *I) { return this->implies(I); });
 
   auto ScevPredsIt = SCEVToPreds.find(N->getExpr());
   if (ScevPredsIt == SCEVToPreds.end())
     return false;
   auto &SCEVPreds = ScevPredsIt->second;
 
-  return std::any_of(SCEVPreds.begin(), SCEVPreds.end(),
-                     [N](const SCEVPredicate *I) { return I->implies(N); });
+  return any_of(SCEVPreds,
+                [N](const SCEVPredicate *I) { return I->implies(N); });
 }
 
 const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; }