Introduce a range version of std::find, and use in SCEV
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index c2db02fe85a49aaa5cfef7ed849a84e30f9dbdfb..d04028b15e2fa6adc58ce92c61c1636ea5a53128 100644 (file)
@@ -1957,11 +1957,11 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
       ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
 
   // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
-  auto IsKnownNonNegative =
-    std::bind(std::mem_fn(&ScalarEvolution::isKnownNonNegative), SE, _1);
+  auto IsKnownNonNegative = [&](const SCEV *S) {
+    return SE->isKnownNonNegative(S);
+  };
 
-  if (SignOrUnsignWrap == SCEV::FlagNSW &&
-      std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative))
+  if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
     Flags =
         ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
 
@@ -2888,9 +2888,8 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
       // AddRecs require their operands be loop-invariant with respect to their
       // loops. Don't perform this transformation if it would break this
       // requirement.
-      bool AllInvariant =
-          std::all_of(Operands.begin(), Operands.end(),
-                      [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
+      bool AllInvariant = all_of(
+          Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
 
       if (AllInvariant) {
         // Create a recurrence for the outer loop with the same step size.
@@ -2901,9 +2900,9 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
           maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
 
         NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
-        AllInvariant = std::all_of(
-            NestedOperands.begin(), NestedOperands.end(),
-            [&](const SCEV *Op) { return isLoopInvariant(Op, NestedLoop); });
+        AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
+          return isLoopInvariant(Op, NestedLoop);
+        });
 
         if (AllInvariant) {
           // Ok, both add recurrences are valid after the transformation.
@@ -3626,6 +3625,7 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
   }
 }
 
+namespace {
 class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
 public:
   static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
@@ -3690,6 +3690,7 @@ private:
   const Loop *L;
   bool Valid;
 };
+} // end anonymous namespace
 
 const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
   const Loop *L = LI.getLoopFor(PN->getParent());
@@ -6124,22 +6125,22 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
 /// In the case that a relevant loop exit value cannot be computed, the
 /// original value V is returned.
 const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
+  SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
+      ValuesAtScopes[V];
   // Check to see if we've folded this expression at this loop before.
-  SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values = ValuesAtScopes[V];
-  for (unsigned u = 0; u < Values.size(); u++) {
-    if (Values[u].first == L)
-      return Values[u].second ? Values[u].second : V;
-  }
-  Values.push_back(std::make_pair(L, static_cast<const SCEV *>(nullptr)));
+  for (auto &LS : Values)
+    if (LS.first == L)
+      return LS.second ? LS.second : V;
+
+  Values.emplace_back(L, nullptr);
+
   // Otherwise compute it.
   const SCEV *C = computeSCEVAtScope(V, L);
-  SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values2 = ValuesAtScopes[V];
-  for (unsigned u = Values2.size(); u > 0; u--) {
-    if (Values2[u - 1].first == L) {
-      Values2[u - 1].second = C;
+  for (auto &LS : reverse(ValuesAtScopes[V]))
+    if (LS.first == L) {
+      LS.second = C;
       break;
     }
-  }
   return C;
 }
 
@@ -7043,16 +7044,14 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
       Pred = ICmpInst::ICMP_ULT;
       Changed = true;
     } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) {
-      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
-                       SCEV::FlagNUW);
+      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
       Pred = ICmpInst::ICMP_ULT;
       Changed = true;
     }
     break;
   case ICmpInst::ICMP_UGE:
     if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) {
-      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
-                       SCEV::FlagNUW);
+      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
       Pred = ICmpInst::ICMP_UGT;
       Changed = true;
     } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) {
@@ -7965,8 +7964,7 @@ static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr,
   const MaxExprType *MaxExpr = dyn_cast<MaxExprType>(MaybeMaxExpr);
   if (!MaxExpr) return false;
 
-  auto It = std::find(MaxExpr->op_begin(), MaxExpr->op_end(), Candidate);
-  return It != MaxExpr->op_end();
+  return find(MaxExpr->operands(), Candidate) != MaxExpr->op_end();
 }
 
 
@@ -8404,8 +8402,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
 
   // The only time we can solve this is when we have all constant indices.
   // Otherwise, we cannot determine the overflow conditions.
-  if (std::any_of(op_begin(), op_end(),
-                  [](const SCEV *Op) { return !isa<SCEVConstant>(Op);}))
+  if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
     return SE.getCouldNotCompute();
 
   // Okay at this point we know that all elements of the chrec are constants and
@@ -8458,15 +8455,13 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
                                              FlagAnyWrap);
 
     // Next, solve the constructed addrec
-    std::pair<const SCEV *,const SCEV *> Roots =
-      SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
+    auto Roots = SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
     const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
     const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
     if (R1) {
       // Pick the smallest positive root value.
-      if (ConstantInt *CB =
-          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
-                         R1->getValue(), R2->getValue()))) {
+      if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(
+              ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) {
         if (!CB->getZExtValue())
           std::swap(R1, R2);   // R1 is the minimum root now.
 
@@ -9267,9 +9262,8 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
 
     // This recurrence is variant w.r.t. L if any of its operands
     // are variant.
-    for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
-         I != E; ++I)
-      if (!isLoopInvariant(*I, L))
+    for (auto *Op : AR->operands())
+      if (!isLoopInvariant(Op, L))
         return LoopVariant;
 
     // Otherwise it's loop-invariant.
@@ -9279,11 +9273,9 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
   case scMulExpr:
   case scUMaxExpr:
   case scSMaxExpr: {
-    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
     bool HasVarying = false;
-    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
-         I != E; ++I) {
-      LoopDisposition D = getLoopDisposition(*I, L);
+    for (auto *Op : cast<SCEVNAryExpr>(S)->operands()) {
+      LoopDisposition D = getLoopDisposition(Op, L);
       if (D == LoopVariant)
         return LoopVariant;
       if (D == LoopComputable)
@@ -9307,7 +9299,7 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
     // invariant if they are not contained in the specified loop.
     // Instructions are never considered invariant in the function body
     // (null loop) because they are defined within the "loop".
-    if (Instruction *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
+    if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
       return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
     return LoopInvariant;
   case scCouldNotCompute:
@@ -9614,6 +9606,7 @@ ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
   return Eq;
 }
 
+namespace {
 class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
 public:
   static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
@@ -9638,6 +9631,7 @@ public:
 private:
   SCEVUnionPredicate &P;
 };
+} // end anonymous namespace
 
 const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
                                                    SCEVUnionPredicate &Preds) {
@@ -9676,8 +9670,8 @@ SCEVUnionPredicate::SCEVUnionPredicate()
     : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
 
 bool SCEVUnionPredicate::isAlwaysTrue() const {
-  return std::all_of(Preds.begin(), Preds.end(),
-                     [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
+  return all_of(Preds,
+                [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
 }
 
 ArrayRef<const SCEVPredicate *>
@@ -9690,17 +9684,16 @@ SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
 
 bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
   if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N))
-    return std::all_of(
-        Set->Preds.begin(), Set->Preds.end(),
-        [this](const SCEVPredicate *I) { return this->implies(I); });
+    return all_of(Set->Preds,
+                  [this](const SCEVPredicate *I) { return this->implies(I); });
 
   auto ScevPredsIt = SCEVToPreds.find(N->getExpr());
   if (ScevPredsIt == SCEVToPreds.end())
     return false;
   auto &SCEVPreds = ScevPredsIt->second;
 
-  return std::any_of(SCEVPreds.begin(), SCEVPreds.end(),
-                     [N](const SCEVPredicate *I) { return I->implies(N); });
+  return any_of(SCEVPreds,
+                [N](const SCEVPredicate *I) { return I->implies(N); });
 }
 
 const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; }