Introduce a range version of std::find, and use in SCEV
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index 6b32a7a9d2960d4ae820d8771a9fc89bb564f92d..d04028b15e2fa6adc58ce92c61c1636ea5a53128 100644 (file)
@@ -83,6 +83,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -123,12 +124,11 @@ VerifySCEV("verify-scev",
 // Implementation of the SCEV class.
 //
 
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD
 void SCEV::dump() const {
   print(dbgs());
   dbgs() << '\n';
 }
-#endif
 
 void SCEV::print(raw_ostream &OS) const {
   switch (static_cast<SCEVTypes>(getSCEVType())) {
@@ -1268,7 +1268,9 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
   // `Step`:
 
   // 1. NSW/NUW flags on the step increment.
-  const SCEV *PreStart = SE->getAddExpr(DiffOps, SA->getNoWrapFlags());
+  auto PreStartFlags =
+    ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
+  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
   const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
       SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
 
@@ -1303,9 +1305,9 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
       ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
 
   if (OverflowLimit &&
-      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) {
+      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
     return PreStart;
-  }
+
   return nullptr;
 }
 
@@ -1376,19 +1378,17 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
   for (unsigned Delta : {-2, -1, 1, 2}) {
     const SCEV *PreStart = getConstant(StartAI - Delta);
 
+    FoldingSetNodeID ID;
+    ID.AddInteger(scAddRecExpr);
+    ID.AddPointer(PreStart);
+    ID.AddPointer(Step);
+    ID.AddPointer(L);
+    void *IP = nullptr;
+    const auto *PreAR =
+      static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+
     // Give up if we don't already have the add recurrence we need because
     // actually constructing an add recurrence is relatively expensive.
-    const SCEVAddRecExpr *PreAR = [&]() {
-      FoldingSetNodeID ID;
-      ID.AddInteger(scAddRecExpr);
-      ID.AddPointer(PreStart);
-      ID.AddPointer(Step);
-      ID.AddPointer(L);
-      void *IP = nullptr;
-      return static_cast<SCEVAddRecExpr *>(
-          this->UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
-    }();
-
     if (PreAR && PreAR->getNoWrapFlags(WrapType)) {  // proves (2)
       const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
       ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
@@ -1559,6 +1559,18 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
       }
     }
 
+  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
+    // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
+    if (SA->getNoWrapFlags(SCEV::FlagNUW)) {
+      // If the addition does not unsign overflow then we can, by definition,
+      // commute the zero extension with the addition operation.
+      SmallVector<const SCEV *, 4> Ops;
+      for (const auto *Op : SA->operands())
+        Ops.push_back(getZeroExtendExpr(Op, Ty));
+      return getAddExpr(Ops, SCEV::FlagNUW);
+    }
+  }
+
   // The cast wasn't folded; create an explicit cast node.
   // Recompute the insert position, as it may have been invalidated.
   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
@@ -1631,6 +1643,16 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
         }
       }
     }
+
+    // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
+    if (SA->getNoWrapFlags(SCEV::FlagNSW)) {
+      // If the addition does not sign overflow then we can, by definition,
+      // commute the sign extension with the addition operation.
+      SmallVector<const SCEV *, 4> Ops;
+      for (const auto *Op : SA->operands())
+        Ops.push_back(getSignExtendExpr(Op, Ty));
+      return getAddExpr(Ops, SCEV::FlagNSW);
+    }
   }
   // If the input value is a chrec scev, and we can prove that the value
   // did not overflow the old, smaller, value, we can sign extend all of the
@@ -1921,8 +1943,9 @@ namespace {
 static SCEV::NoWrapFlags
 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
                       const SmallVectorImpl<const SCEV *> &Ops,
-                      SCEV::NoWrapFlags OldFlags) {
+                      SCEV::NoWrapFlags Flags) {
   using namespace std::placeholders;
+  typedef OverflowingBinaryOperator OBO;
 
   bool CanAnalyze =
       Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
@@ -1931,18 +1954,42 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
 
   int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
   SCEV::NoWrapFlags SignOrUnsignWrap =
-      ScalarEvolution::maskFlags(OldFlags, SignOrUnsignMask);
+      ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
 
   // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
-  auto IsKnownNonNegative =
-    std::bind(std::mem_fn(&ScalarEvolution::isKnownNonNegative), SE, _1);
+  auto IsKnownNonNegative = [&](const SCEV *S) {
+    return SE->isKnownNonNegative(S);
+  };
+
+  if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
+    Flags =
+        ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
+
+  SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
+
+  if (SignOrUnsignWrap != SignOrUnsignMask && Type == scAddExpr &&
+      Ops.size() == 2 && isa<SCEVConstant>(Ops[0])) {
 
-  if (SignOrUnsignWrap == SCEV::FlagNSW &&
-      std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative))
-    return ScalarEvolution::setFlags(OldFlags,
-                                     (SCEV::NoWrapFlags)SignOrUnsignMask);
+    // (A + C) --> (A + C)<nsw> if the addition does not sign overflow
+    // (A + C) --> (A + C)<nuw> if the addition does not unsign overflow
 
-  return OldFlags;
+    const APInt &C = cast<SCEVConstant>(Ops[0])->getValue()->getValue();
+    if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
+      auto NSWRegion =
+        ConstantRange::makeNoWrapRegion(Instruction::Add, C, OBO::NoSignedWrap);
+      if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
+        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
+    }
+    if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
+      auto NUWRegion =
+        ConstantRange::makeNoWrapRegion(Instruction::Add, C,
+                                        OBO::NoUnsignedWrap);
+      if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
+        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
+    }
+  }
+
+  return Flags;
 }
 
 /// getAddExpr - Get a canonical add expression, or something simpler if
@@ -2106,18 +2153,16 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
       // re-generate the operands list. Group the operands by constant scale,
       // to avoid multiplying by the same constant scale multiple times.
       std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
-      for (SmallVectorImpl<const SCEV *>::const_iterator I = NewOps.begin(),
-           E = NewOps.end(); I != E; ++I)
-        MulOpLists[M.find(*I)->second].push_back(*I);
+      for (const SCEV *NewOp : NewOps)
+        MulOpLists[M.find(NewOp)->second].push_back(NewOp);
       // Re-generate the operands list.
       Ops.clear();
       if (AccumulatedConstant != 0)
         Ops.push_back(getConstant(AccumulatedConstant));
-      for (std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare>::iterator
-           I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I)
-        if (I->first != 0)
-          Ops.push_back(getMulExpr(getConstant(I->first),
-                                   getAddExpr(I->second)));
+      for (auto &MulOp : MulOpLists)
+        if (MulOp.first != 0)
+          Ops.push_back(getMulExpr(getConstant(MulOp.first),
+                                   getAddExpr(MulOp.second)));
       if (Ops.empty())
         return getZero(Ty);
       if (Ops.size() == 1)
@@ -2258,8 +2303,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
                                                AddRec->op_end());
         for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
              ++OtherIdx)
-          if (const SCEVAddRecExpr *OtherAddRec =
-                dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
+          if (const auto *OtherAddRec = dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
             if (OtherAddRec->getLoop() == AddRecLoop) {
               for (unsigned i = 0, e = OtherAddRec->getNumOperands();
                    i != e; ++i) {
@@ -2844,9 +2888,8 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
       // AddRecs require their operands be loop-invariant with respect to their
       // loops. Don't perform this transformation if it would break this
       // requirement.
-      bool AllInvariant =
-          std::all_of(Operands.begin(), Operands.end(),
-                      [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
+      bool AllInvariant = all_of(
+          Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
 
       if (AllInvariant) {
         // Create a recurrence for the outer loop with the same step size.
@@ -2857,9 +2900,9 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
           maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
 
         NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
-        AllInvariant = std::all_of(
-            NestedOperands.begin(), NestedOperands.end(),
-            [&](const SCEV *Op) { return isLoopInvariant(Op, NestedLoop); });
+        AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
+          return isLoopInvariant(Op, NestedLoop);
+        });
 
         if (AllInvariant) {
           // Ok, both add recurrences are valid after the transformation.
@@ -3172,8 +3215,7 @@ const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
   // We can bypass creating a target-independent
   // constant expression and then folding it back into a ConstantInt.
   // This is just a compile-time optimization.
-  return getConstant(IntTy,
-                     F.getParent()->getDataLayout().getTypeAllocSize(AllocTy));
+  return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
 }
 
 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
@@ -3183,9 +3225,7 @@ const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
   // constant expression and then folding it back into a ConstantInt.
   // This is just a compile-time optimization.
   return getConstant(
-      IntTy,
-      F.getParent()->getDataLayout().getStructLayout(STy)->getElementOffset(
-          FieldNo));
+      IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
 }
 
 const SCEV *ScalarEvolution::getUnknown(Value *V) {
@@ -3227,7 +3267,7 @@ bool ScalarEvolution::isSCEVable(Type *Ty) const {
 /// for which isSCEVable must return true.
 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
   assert(isSCEVable(Ty) && "Type is not SCEVable!");
-  return F.getParent()->getDataLayout().getTypeSizeInBits(Ty);
+  return getDataLayout().getTypeSizeInBits(Ty);
 }
 
 /// getEffectiveSCEVType - Return a type with the same bitwidth as
@@ -3237,13 +3277,12 @@ uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
   assert(isSCEVable(Ty) && "Type is not SCEVable!");
 
-  if (Ty->isIntegerTy()) {
+  if (Ty->isIntegerTy())
     return Ty;
-  }
 
   // The only other support type is pointer.
   assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
-  return F.getParent()->getDataLayout().getIntPtrType(Ty);
+  return getDataLayout().getIntPtrType(Ty);
 }
 
 const SCEV *ScalarEvolution::getCouldNotCompute() {
@@ -3514,8 +3553,7 @@ const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
 
   if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
     return getPointerBase(Cast->getOperand());
-  }
-  else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
+  } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
     const SCEV *PtrOp = nullptr;
     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
          I != E; ++I) {
@@ -3587,6 +3625,73 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
   }
 }
 
+namespace {
+class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
+public:
+  static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
+                             ScalarEvolution &SE) {
+    SCEVInitRewriter Rewriter(L, SE);
+    const SCEV *Result = Rewriter.visit(Scev);
+    return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
+  }
+
+  SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
+      : SCEVRewriteVisitor(SE), L(L), Valid(true) {}
+
+  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+    if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant))
+      Valid = false;
+    return Expr;
+  }
+
+  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+    // Only allow AddRecExprs for this loop.
+    if (Expr->getLoop() == L)
+      return Expr->getStart();
+    Valid = false;
+    return Expr;
+  }
+
+  bool isValid() { return Valid; }
+
+private:
+  const Loop *L;
+  bool Valid;
+};
+
+class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
+public:
+  static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
+                             ScalarEvolution &SE) {
+    SCEVShiftRewriter Rewriter(L, SE);
+    const SCEV *Result = Rewriter.visit(Scev);
+    return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
+  }
+
+  SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
+      : SCEVRewriteVisitor(SE), L(L), Valid(true) {}
+
+  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+    // Only allow AddRecExprs for this loop.
+    if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant))
+      Valid = false;
+    return Expr;
+  }
+
+  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+    if (Expr->getLoop() == L && Expr->isAffine())
+      return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
+    Valid = false;
+    return Expr;
+  }
+  bool isValid() { return Valid; }
+
+private:
+  const Loop *L;
+  bool Valid;
+};
+} // end anonymous namespace
+
 const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
   const Loop *L = LI.getLoopFor(PN->getParent());
   if (!L || L->getHeader() != PN->getParent())
@@ -3699,30 +3804,28 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
           return PHISCEV;
         }
       }
-    } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(BEValue)) {
+    } else {
       // Otherwise, this could be a loop like this:
       //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
       // In this case, j = {1,+,1}  and BEValue is j.
       // Because the other in-value of i (0) fits the evolution of BEValue
       // i really is an addrec evolution.
-      if (AddRec->getLoop() == L && AddRec->isAffine()) {
+      //
+      // We can generalize this saying that i is the shifted value of BEValue
+      // by one iteration:
+      //   PHI(f(0), f({1,+,1})) --> f({0,+,1})
+      const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
+      const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this);
+      if (Shifted != getCouldNotCompute() &&
+          Start != getCouldNotCompute()) {
         const SCEV *StartVal = getSCEV(StartValueV);
-
-        // If StartVal = j.start - j.stride, we can use StartVal as the
-        // initial step of the addrec evolution.
-        if (StartVal ==
-            getMinusSCEV(AddRec->getOperand(0), AddRec->getOperand(1))) {
-          // FIXME: For constant StartVal, we should be able to infer
-          // no-wrap flags.
-          const SCEV *PHISCEV = getAddRecExpr(StartVal, AddRec->getOperand(1),
-                                              L, SCEV::FlagAnyWrap);
-
+        if (Start == StartVal) {
           // Okay, for the entire analysis of this edge we assumed the PHI
           // to be symbolic.  We now need to go back and purge all of the
           // entries for the scalars that use the symbolic expression.
           ForgetSymbolicName(PN, SymbolicName);
-          ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
-          return PHISCEV;
+          ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted;
+          return Shifted;
         }
       }
     }
@@ -3756,8 +3859,8 @@ static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
       switch (S->getSCEVType()) {
       case scConstant: case scTruncate: case scZeroExtend: case scSignExtend:
       case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr:
-      // These expressions are available if their operand(s) is/are.
-      return true;
+        // These expressions are available if their operand(s) is/are.
+        return true;
 
       case scAddRecExpr: {
         // We allow add recurrences that are on the loop BB is in, or some
@@ -3788,8 +3891,8 @@ static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
 
       case scUDivExpr:
       case scCouldNotCompute:
-      // We do not try to smart about these at all.
-      return setUnavailable();
+        // We do not try to smart about these at all.
+        return setUnavailable();
       }
       llvm_unreachable("switch should be fully covered!");
     }
@@ -3841,6 +3944,11 @@ const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
   if (PN->getNumIncomingValues() == 2) {
     const Loop *L = LI.getLoopFor(PN->getParent());
 
+    // We don't want to break LCSSA, even in a SCEV expression tree.
+    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
+      if (LI.getLoopFor(PN->getIncomingBlock(i)) != L)
+        return nullptr;
+
     // Try to match
     //
     //  br %cond, label %left, label %right
@@ -3880,8 +3988,7 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
   // PHI's incoming blocks are in a different loop, in which case doing so
   // risks breaking LCSSA form. Instcombine would normally zap these, but
   // it doesn't have DominatorTree information, so it may miss cases.
-  if (Value *V = SimplifyInstruction(PN, F.getParent()->getDataLayout(), &TLI,
-                                     &DT, &AC))
+  if (Value *V = SimplifyInstruction(PN, getDataLayout(), &TLI, &DT, &AC))
     if (LI.replacementPreservesLCSSAForm(PN, V))
       return getSCEV(V);
 
@@ -4076,8 +4183,8 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
     // For a SCEVUnknown, ask ValueTracking.
     unsigned BitWidth = getTypeSizeInBits(U->getType());
     APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
-    computeKnownBits(U->getValue(), Zeros, Ones, F.getParent()->getDataLayout(),
-                     0, &AC, nullptr, &DT);
+    computeKnownBits(U->getValue(), Zeros, Ones, getDataLayout(), 0, &AC,
+                     nullptr, &DT);
     return Zeros.countTrailingOnes();
   }
 
@@ -4088,26 +4195,9 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
 /// GetRangeFromMetadata - Helper method to assign a range to V from
 /// metadata present in the IR.
 static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
-  if (Instruction *I = dyn_cast<Instruction>(V)) {
-    if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) {
-      ConstantRange TotalRange(
-          cast<IntegerType>(I->getType())->getBitWidth(), false);
-
-      unsigned NumRanges = MD->getNumOperands() / 2;
-      assert(NumRanges >= 1);
-
-      for (unsigned i = 0; i < NumRanges; ++i) {
-        ConstantInt *Lower =
-            mdconst::extract<ConstantInt>(MD->getOperand(2 * i + 0));
-        ConstantInt *Upper =
-            mdconst::extract<ConstantInt>(MD->getOperand(2 * i + 1));
-        ConstantRange Range(Lower->getValue(), Upper->getValue());
-        TotalRange = TotalRange.unionWith(Range);
-      }
-
-      return TotalRange;
-    }
-  }
+  if (Instruction *I = dyn_cast<Instruction>(V))
+    if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
+      return getConstantRangeFromMetadata(*MD);
 
   return None;
 }
@@ -4307,7 +4397,7 @@ ScalarEvolution::getRange(const SCEV *S,
     // Split here to avoid paying the compile-time cost of calling both
     // computeKnownBits and ComputeNumSignBits.  This restriction can be lifted
     // if needed.
-    const DataLayout &DL = F.getParent()->getDataLayout();
+    const DataLayout &DL = getDataLayout();
     if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
       // For a SCEVUnknown, ask ValueTracking.
       APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
@@ -4515,8 +4605,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       unsigned TZ = A.countTrailingZeros();
       unsigned BitWidth = A.getBitWidth();
       APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
-      computeKnownBits(U->getOperand(0), KnownZero, KnownOne,
-                       F.getParent()->getDataLayout(), 0, &AC, nullptr, &DT);
+      computeKnownBits(U->getOperand(0), KnownZero, KnownOne, getDataLayout(),
+                       0, &AC, nullptr, &DT);
 
       APInt EffectiveMask =
           APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
@@ -5375,6 +5465,11 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
         return ItCnt;
     }
 
+  ExitLimit ShiftEL = computeShiftCompareExitLimit(
+      ExitCond->getOperand(0), ExitCond->getOperand(1), L, Cond);
+  if (ShiftEL.hasAnyInfo())
+    return ShiftEL;
+
   const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
   const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
 
@@ -5434,14 +5529,6 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
     break;
   }
   default:
-#if 0
-    dbgs() << "computeBackedgeTakenCount ";
-    if (ExitCond->getOperand(0)->getType()->isUnsigned())
-      dbgs() << "[unsigned] ";
-    dbgs() << *LHS << "   "
-         << Instruction::getOpcodeName(Instruction::ICmp)
-         << "   " << *RHS << "\n";
-#endif
     break;
   }
   return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
@@ -5554,11 +5641,6 @@ ScalarEvolution::computeLoadConstantCompareExitLimit(
     Result = ConstantExpr::getICmp(predicate, Result, RHS);
     if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
     if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
-#if 0
-      dbgs() << "\n***\n*** Computed loop count " << *ItCst
-             << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
-             << "***\n";
-#endif
       ++NumArrayLenItCounts;
       return getConstant(ItCst);   // Found terminating iteration!
     }
@@ -5566,6 +5648,149 @@ ScalarEvolution::computeLoadConstantCompareExitLimit(
   return getCouldNotCompute();
 }
 
+ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
+    Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
+  ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
+  if (!RHS)
+    return getCouldNotCompute();
+
+  const BasicBlock *Latch = L->getLoopLatch();
+  if (!Latch)
+    return getCouldNotCompute();
+
+  const BasicBlock *Predecessor = L->getLoopPredecessor();
+  if (!Predecessor)
+    return getCouldNotCompute();
+
+  // Return true if V is of the form "LHS `shift_op` <positive constant>".
+  // Return LHS in OutLHS and shift_opt in OutOpCode.
+  auto MatchPositiveShift =
+      [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
+
+    using namespace PatternMatch;
+
+    ConstantInt *ShiftAmt;
+    if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+      OutOpCode = Instruction::LShr;
+    else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+      OutOpCode = Instruction::AShr;
+    else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+      OutOpCode = Instruction::Shl;
+    else
+      return false;
+
+    return ShiftAmt->getValue().isStrictlyPositive();
+  };
+
+  // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
+  //
+  // loop:
+  //   %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
+  //   %iv.shifted = lshr i32 %iv, <positive constant>
+  //
+  // Return true on a succesful match.  Return the corresponding PHI node (%iv
+  // above) in PNOut and the opcode of the shift operation in OpCodeOut.
+  auto MatchShiftRecurrence =
+      [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
+    Optional<Instruction::BinaryOps> PostShiftOpCode;
+
+    {
+      Instruction::BinaryOps OpC;
+      Value *V;
+
+      // If we encounter a shift instruction, "peel off" the shift operation,
+      // and remember that we did so.  Later when we inspect %iv's backedge
+      // value, we will make sure that the backedge value uses the same
+      // operation.
+      //
+      // Note: the peeled shift operation does not have to be the same
+      // instruction as the one feeding into the PHI's backedge value.  We only
+      // really care about it being the same *kind* of shift instruction --
+      // that's all that is required for our later inferences to hold.
+      if (MatchPositiveShift(LHS, V, OpC)) {
+        PostShiftOpCode = OpC;
+        LHS = V;
+      }
+    }
+
+    PNOut = dyn_cast<PHINode>(LHS);
+    if (!PNOut || PNOut->getParent() != L->getHeader())
+      return false;
+
+    Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
+    Value *OpLHS;
+
+    return
+        // The backedge value for the PHI node must be a shift by a positive
+        // amount
+        MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
+
+        // of the PHI node itself
+        OpLHS == PNOut &&
+
+        // and the kind of shift should be match the kind of shift we peeled
+        // off, if any.
+        (!PostShiftOpCode.hasValue() || *PostShiftOpCode == OpCodeOut);
+  };
+
+  PHINode *PN;
+  Instruction::BinaryOps OpCode;
+  if (!MatchShiftRecurrence(LHS, PN, OpCode))
+    return getCouldNotCompute();
+
+  const DataLayout &DL = getDataLayout();
+
+  // The key rationale for this optimization is that for some kinds of shift
+  // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
+  // within a finite number of iterations.  If the condition guarding the
+  // backedge (in the sense that the backedge is taken if the condition is true)
+  // is false for the value the shift recurrence stabilizes to, then we know
+  // that the backedge is taken only a finite number of times.
+
+  ConstantInt *StableValue = nullptr;
+  switch (OpCode) {
+  default:
+    llvm_unreachable("Impossible case!");
+
+  case Instruction::AShr: {
+    // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
+    // bitwidth(K) iterations.
+    Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
+    bool KnownZero, KnownOne;
+    ComputeSignBit(FirstValue, KnownZero, KnownOne, DL, 0, nullptr,
+                   Predecessor->getTerminator(), &DT);
+    auto *Ty = cast<IntegerType>(RHS->getType());
+    if (KnownZero)
+      StableValue = ConstantInt::get(Ty, 0);
+    else if (KnownOne)
+      StableValue = ConstantInt::get(Ty, -1, true);
+    else
+      return getCouldNotCompute();
+
+    break;
+  }
+  case Instruction::LShr:
+  case Instruction::Shl:
+    // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
+    // stabilize to 0 in at most bitwidth(K) iterations.
+    StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
+    break;
+  }
+
+  auto *Result =
+      ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
+  assert(Result->getType()->isIntegerTy(1) &&
+         "Otherwise cannot be an operand to a branch instruction");
+
+  if (Result->isZeroValue()) {
+    unsigned BitWidth = getTypeSizeInBits(RHS->getType());
+    const SCEV *UpperBound =
+        getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
+    return ExitLimit(getCouldNotCompute(), UpperBound);
+  }
+
+  return getCouldNotCompute();
+}
 
 /// CanConstantFold - Return true if we can constant fold an instruction of the
 /// specified type, assuming that all operands were constants.
@@ -5646,9 +5871,8 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
   Instruction *I = dyn_cast<Instruction>(V);
   if (!I || !canConstantEvolve(I, L)) return nullptr;
 
-  if (PHINode *PN = dyn_cast<PHINode>(I)) {
+  if (PHINode *PN = dyn_cast<PHINode>(I))
     return PN;
-  }
 
   // Record non-constant instructions contained by the loop.
   DenseMap<Instruction *, PHINode *> PHIMap;
@@ -5705,6 +5929,30 @@ static Constant *EvaluateExpression(Value *V, const Loop *L,
                                   TLI);
 }
 
+
+// If every incoming value to PN except the one for BB is a specific Constant,
+// return that, else return nullptr.
+static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
+  Constant *IncomingVal = nullptr;
+
+  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+    if (PN->getIncomingBlock(i) == BB)
+      continue;
+
+    auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
+    if (!CurrentVal)
+      return nullptr;
+
+    if (IncomingVal != CurrentVal) {
+      if (IncomingVal)
+        return nullptr;
+      IncomingVal = CurrentVal;
+    }
+  }
+
+  return IncomingVal;
+}
+
 /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
 /// in the header of its containing loop, we know the loop executes a
 /// constant number of times, and the PHI node is just a recurrence
@@ -5730,25 +5978,10 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
   if (!Latch)
     return nullptr;
 
-  // Since the loop has one latch, the PHI node must have two entries.  One
-  // entry must be a constant (coming in from outside of the loop), and the
-  // second must be derived from the same PHI.
-
-  BasicBlock *NonLatch = Latch == PN->getIncomingBlock(0)
-                             ? PN->getIncomingBlock(1)
-                             : PN->getIncomingBlock(0);
-
-  assert(PN->getNumIncomingValues() == 2 && "Follows from having one latch!");
-
-  // Note: not all PHI nodes in the same block have to have their incoming
-  // values in the same order, so we use the basic block to look up the incoming
-  // value, not an index.
-
   for (auto &I : *Header) {
     PHINode *PHI = dyn_cast<PHINode>(&I);
     if (!PHI) break;
-    auto *StartCST =
-        dyn_cast<Constant>(PHI->getIncomingValueForBlock(NonLatch));
+    auto *StartCST = getOtherIncomingValue(PHI, Latch);
     if (!StartCST) continue;
     CurrentIterVals[PHI] = StartCST;
   }
@@ -5763,7 +5996,7 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
 
   unsigned NumIterations = BEs.getZExtValue(); // must be in range
   unsigned IterationNum = 0;
-  const DataLayout &DL = F.getParent()->getDataLayout();
+  const DataLayout &DL = getDataLayout();
   for (; ; ++IterationNum) {
     if (IterationNum == NumIterations)
       return RetVal = CurrentIterVals[PN];  // Got exit value!
@@ -5827,21 +6060,11 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
   BasicBlock *Latch = L->getLoopLatch();
   assert(Latch && "Should follow from NumIncomingValues == 2!");
 
-  // NonLatch is the preheader, or something equivalent.
-  BasicBlock *NonLatch = Latch == PN->getIncomingBlock(0)
-                             ? PN->getIncomingBlock(1)
-                             : PN->getIncomingBlock(0);
-
-  // Note: not all PHI nodes in the same block have to have their incoming
-  // values in the same order, so we use the basic block to look up the incoming
-  // value, not an index.
-
   for (auto &I : *Header) {
     PHINode *PHI = dyn_cast<PHINode>(&I);
     if (!PHI)
       break;
-    auto *StartCST =
-      dyn_cast<Constant>(PHI->getIncomingValueForBlock(NonLatch));
+    auto *StartCST = getOtherIncomingValue(PHI, Latch);
     if (!StartCST) continue;
     CurrentIterVals[PHI] = StartCST;
   }
@@ -5852,7 +6075,7 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
   // the loop symbolically to determine when the condition gets a value of
   // "ExitWhen".
   unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
-  const DataLayout &DL = F.getParent()->getDataLayout();
+  const DataLayout &DL = getDataLayout();
   for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
     auto *CondVal = dyn_cast_or_null<ConstantInt>(
         EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
@@ -5902,22 +6125,22 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
 /// In the case that a relevant loop exit value cannot be computed, the
 /// original value V is returned.
 const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
+  SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
+      ValuesAtScopes[V];
   // Check to see if we've folded this expression at this loop before.
-  SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values = ValuesAtScopes[V];
-  for (unsigned u = 0; u < Values.size(); u++) {
-    if (Values[u].first == L)
-      return Values[u].second ? Values[u].second : V;
-  }
-  Values.push_back(std::make_pair(L, static_cast<const SCEV *>(nullptr)));
+  for (auto &LS : Values)
+    if (LS.first == L)
+      return LS.second ? LS.second : V;
+
+  Values.emplace_back(L, nullptr);
+
   // Otherwise compute it.
   const SCEV *C = computeSCEVAtScope(V, L);
-  SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values2 = ValuesAtScopes[V];
-  for (unsigned u = Values2.size(); u > 0; u--) {
-    if (Values2[u - 1].first == L) {
-      Values2[u - 1].second = C;
+  for (auto &LS : reverse(ValuesAtScopes[V]))
+    if (LS.first == L) {
+      LS.second = C;
       break;
     }
-  }
   return C;
 }
 
@@ -6084,7 +6307,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
         // Check to see if getSCEVAtScope actually made an improvement.
         if (MadeImprovement) {
           Constant *C = nullptr;
-          const DataLayout &DL = F.getParent()->getDataLayout();
+          const DataLayout &DL = getDataLayout();
           if (const CmpInst *CI = dyn_cast<CmpInst>(I))
             C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
                                                 Operands[1], DL, &TLI);
@@ -6366,10 +6589,6 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) {
     const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
     const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
     if (R1 && R2) {
-#if 0
-      dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1
-             << "  sol#2: " << *R2 << "\n";
-#endif
       // Pick the smallest positive root value.
       if (ConstantInt *CB =
           dyn_cast<ConstantInt>(ConstantExpr::getICmp(CmpInst::ICMP_ULT,
@@ -6825,16 +7044,14 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
       Pred = ICmpInst::ICMP_ULT;
       Changed = true;
     } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) {
-      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
-                       SCEV::FlagNUW);
+      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
       Pred = ICmpInst::ICMP_ULT;
       Changed = true;
     }
     break;
   case ICmpInst::ICMP_UGE:
     if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) {
-      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
-                       SCEV::FlagNUW);
+      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
       Pred = ICmpInst::ICMP_UGT;
       Changed = true;
     } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) {
@@ -7130,6 +7347,62 @@ ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred,
   return false;
 }
 
+bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
+                                                    const SCEV *LHS,
+                                                    const SCEV *RHS) {
+
+  // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer.
+  // Return Y via OutY.
+  auto MatchBinaryAddToConst =
+      [this](const SCEV *Result, const SCEV *X, APInt &OutY,
+             SCEV::NoWrapFlags ExpectedFlags) {
+    const SCEV *NonConstOp, *ConstOp;
+    SCEV::NoWrapFlags FlagsPresent;
+
+    if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) ||
+        !isa<SCEVConstant>(ConstOp) || NonConstOp != X)
+      return false;
+
+    OutY = cast<SCEVConstant>(ConstOp)->getValue()->getValue();
+    return (FlagsPresent & ExpectedFlags) == ExpectedFlags;
+  };
+
+  APInt C;
+
+  switch (Pred) {
+  default:
+    break;
+
+  case ICmpInst::ICMP_SGE:
+    std::swap(LHS, RHS);
+  case ICmpInst::ICMP_SLE:
+    // X s<= (X + C)<nsw> if C >= 0
+    if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative())
+      return true;
+
+    // (X + C)<nsw> s<= X if C <= 0
+    if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
+        !C.isStrictlyPositive())
+      return true;
+    break;
+
+  case ICmpInst::ICMP_SGT:
+    std::swap(LHS, RHS);
+  case ICmpInst::ICMP_SLT:
+    // X s< (X + C)<nsw> if C > 0
+    if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
+        C.isStrictlyPositive())
+      return true;
+
+    // (X + C)<nsw> s< X if C < 0
+    if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
+      return true;
+    break;
+  }
+
+  return false;
+}
+
 bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
                                                    const SCEV *LHS,
                                                    const SCEV *RHS) {
@@ -7147,12 +7420,9 @@ bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
   // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
   // interesting cases seen in practice.  We can consider "upgrading" L >= 0 to
   // use isKnownPredicate later if needed.
-  if (isKnownNonNegative(RHS) &&
-      isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
-      isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS))
-    return true;
-
-  return false;
+  return isKnownNonNegative(RHS) &&
+         isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
+         isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
 }
 
 /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
@@ -7305,6 +7575,7 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
   return false;
 }
 
+namespace {
 /// RAII wrapper to prevent recursive application of isImpliedCond.
 /// ScalarEvolution's PendingLoopPredicates set must be empty unless we are
 /// currently evaluating isImpliedCond.
@@ -7322,6 +7593,7 @@ struct MarkPendingLoopPredicate {
       LoopPreds.erase(Cond);
   }
 };
+} // end anonymous namespace
 
 /// isImpliedCond - Test whether the condition described by Pred, LHS,
 /// and RHS is true whenever the given Cond value evaluates to true.
@@ -7423,6 +7695,13 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
                                    RHS, LHS, FoundLHS, FoundRHS);
   }
 
+  // Unsigned comparison is the same as signed comparison when both the operands
+  // are non-negative.
+  if (CmpInst::isUnsigned(FoundPred) &&
+      CmpInst::getSignedPredicate(FoundPred) == Pred &&
+      isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS))
+    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
+
   // Check if we can make progress by sharpening ranges.
   if (FoundPred == ICmpInst::ICMP_NE &&
       (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
@@ -7685,8 +7964,7 @@ static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr,
   const MaxExprType *MaxExpr = dyn_cast<MaxExprType>(MaybeMaxExpr);
   if (!MaxExpr) return false;
 
-  auto It = std::find(MaxExpr->op_begin(), MaxExpr->op_end(), Candidate);
-  return It != MaxExpr->op_end();
+  return find(MaxExpr->operands(), Candidate) != MaxExpr->op_end();
 }
 
 
@@ -7779,8 +8057,9 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
   auto IsKnownPredicateFull =
       [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
     return isKnownPredicateWithRanges(Pred, LHS, RHS) ||
-        IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
-        IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS);
+           IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
+           IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
+           isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
   };
 
   switch (Pred) {
@@ -8123,8 +8402,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
 
   // The only time we can solve this is when we have all constant indices.
   // Otherwise, we cannot determine the overflow conditions.
-  if (std::any_of(op_begin(), op_end(),
-                  [](const SCEV *Op) { return !isa<SCEVConstant>(Op);}))
+  if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
     return SE.getCouldNotCompute();
 
   // Okay at this point we know that all elements of the chrec are constants and
@@ -8177,15 +8455,13 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
                                              FlagAnyWrap);
 
     // Next, solve the constructed addrec
-    std::pair<const SCEV *,const SCEV *> Roots =
-      SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
+    auto Roots = SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
     const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
     const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
     if (R1) {
       // Pick the smallest positive root value.
-      if (ConstantInt *CB =
-          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
-                         R1->getValue(), R2->getValue()))) {
+      if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(
+              ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) {
         if (!CB->getZExtValue())
           std::swap(R1, R2);   // R1 is the minimum root now.
 
@@ -8811,6 +9087,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
       UnsignedRanges(std::move(Arg.UnsignedRanges)),
       SignedRanges(std::move(Arg.SignedRanges)),
       UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
+      UniquePreds(std::move(Arg.UniquePreds)),
       SCEVAllocator(std::move(Arg.SCEVAllocator)),
       FirstUnknown(Arg.FirstUnknown) {
   Arg.FirstUnknown = nullptr;
@@ -8985,9 +9262,8 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
 
     // This recurrence is variant w.r.t. L if any of its operands
     // are variant.
-    for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
-         I != E; ++I)
-      if (!isLoopInvariant(*I, L))
+    for (auto *Op : AR->operands())
+      if (!isLoopInvariant(Op, L))
         return LoopVariant;
 
     // Otherwise it's loop-invariant.
@@ -8997,11 +9273,9 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
   case scMulExpr:
   case scUMaxExpr:
   case scSMaxExpr: {
-    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
     bool HasVarying = false;
-    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
-         I != E; ++I) {
-      LoopDisposition D = getLoopDisposition(*I, L);
+    for (auto *Op : cast<SCEVNAryExpr>(S)->operands()) {
+      LoopDisposition D = getLoopDisposition(Op, L);
       if (D == LoopVariant)
         return LoopVariant;
       if (D == LoopComputable)
@@ -9025,7 +9299,7 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
     // invariant if they are not contained in the specified loop.
     // Instructions are never considered invariant in the function body
     // (null loop) because they are defined within the "loop".
-    if (Instruction *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
+    if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
       return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
     return LoopInvariant;
   case scCouldNotCompute:
@@ -9314,3 +9588,135 @@ void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.addRequiredTransitive<DominatorTreeWrapperPass>();
   AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
 }
+
+const SCEVPredicate *
+ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
+                                   const SCEVConstant *RHS) {
+  FoldingSetNodeID ID;
+  // Unique this node based on the arguments
+  ID.AddInteger(SCEVPredicate::P_Equal);
+  ID.AddPointer(LHS);
+  ID.AddPointer(RHS);
+  void *IP = nullptr;
+  if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
+    return S;
+  SCEVEqualPredicate *Eq = new (SCEVAllocator)
+      SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS);
+  UniquePreds.InsertNode(Eq, IP);
+  return Eq;
+}
+
+namespace {
+class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
+public:
+  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
+                             SCEVUnionPredicate &A) {
+    SCEVPredicateRewriter Rewriter(SE, A);
+    return Rewriter.visit(Scev);
+  }
+
+  SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P)
+      : SCEVRewriteVisitor(SE), P(P) {}
+
+  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+    auto ExprPreds = P.getPredicatesForExpr(Expr);
+    for (auto *Pred : ExprPreds)
+      if (const auto *IPred = dyn_cast<const SCEVEqualPredicate>(Pred))
+        if (IPred->getLHS() == Expr)
+          return IPred->getRHS();
+
+    return Expr;
+  }
+
+private:
+  SCEVUnionPredicate &P;
+};
+} // end anonymous namespace
+
+const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
+                                                   SCEVUnionPredicate &Preds) {
+  return SCEVPredicateRewriter::rewrite(Scev, *this, Preds);
+}
+
+/// SCEV predicates
+SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
+                             SCEVPredicateKind Kind)
+    : FastID(ID), Kind(Kind) {}
+
+SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
+                                       const SCEVUnknown *LHS,
+                                       const SCEVConstant *RHS)
+    : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}
+
+bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
+  const auto *Op = dyn_cast<const SCEVEqualPredicate>(N);
+
+  if (!Op)
+    return false;
+
+  return Op->LHS == LHS && Op->RHS == RHS;
+}
+
+bool SCEVEqualPredicate::isAlwaysTrue() const { return false; }
+
+const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; }
+
+void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
+  OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
+}
+
+/// Union predicates don't get cached so create a dummy set ID for it.
+SCEVUnionPredicate::SCEVUnionPredicate()
+    : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
+
+bool SCEVUnionPredicate::isAlwaysTrue() const {
+  return all_of(Preds,
+                [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
+}
+
+ArrayRef<const SCEVPredicate *>
+SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
+  auto I = SCEVToPreds.find(Expr);
+  if (I == SCEVToPreds.end())
+    return ArrayRef<const SCEVPredicate *>();
+  return I->second;
+}
+
+bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
+  if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N))
+    return all_of(Set->Preds,
+                  [this](const SCEVPredicate *I) { return this->implies(I); });
+
+  auto ScevPredsIt = SCEVToPreds.find(N->getExpr());
+  if (ScevPredsIt == SCEVToPreds.end())
+    return false;
+  auto &SCEVPreds = ScevPredsIt->second;
+
+  return any_of(SCEVPreds,
+                [N](const SCEVPredicate *I) { return I->implies(N); });
+}
+
+const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; }
+
+void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
+  for (auto Pred : Preds)
+    Pred->print(OS, Depth);
+}
+
+void SCEVUnionPredicate::add(const SCEVPredicate *N) {
+  if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N)) {
+    for (auto Pred : Set->Preds)
+      add(Pred);
+    return;
+  }
+
+  if (implies(N))
+    return;
+
+  const SCEV *Key = N->getExpr();
+  assert(Key && "Only SCEVUnionPredicate doesn't have an "
+                " associated expression!");
+
+  SCEVToPreds[Key].push_back(N);
+  Preds.push_back(N);
+}