Create a wrapper pass for BranchProbabilityInfo.
[oota-llvm.git] / lib / Transforms / Scalar / InductiveRangeCheckElimination.cpp
index 86a00b1590e5144de8424f96a3380324eb5370eb..08fdcc38c045d8a48545756efe2f1f0930d82b42 100644 (file)
@@ -42,7 +42,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/Optional.h"
-
 #include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/ValueTracking.h"
-
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
-#include "llvm/IR/Instructions.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/ValueHandle.h"
 #include "llvm/IR/Verifier.h"
-
+#include "llvm/Pass.h"
 #include "llvm/Support/Debug.h"
-
+#include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
 #include "llvm/Transforms/Utils/SimplifyIndVar.h"
 #include "llvm/Transforms/Utils/UnrollLoop.h"
-
-#include "llvm/Pass.h"
-
 #include <array>
 
 using namespace llvm;
@@ -82,6 +77,12 @@ static cl::opt<unsigned> LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden,
 static cl::opt<bool> PrintChangedLoops("irce-print-changed-loops", cl::Hidden,
                                        cl::init(false));
 
+static cl::opt<bool> PrintRangeChecks("irce-print-range-checks", cl::Hidden,
+                                      cl::init(false));
+
+static cl::opt<int> MaxExitProbReciprocal("irce-max-exit-prob-reciprocal",
+                                          cl::Hidden, cl::init(10));
+
 #define DEBUG_TYPE "irce"
 
 namespace {
@@ -93,23 +94,41 @@ namespace {
 ///
 ///  and
 ///
-///  2. a condition that is provably true for some range of values taken by the
-///     containing loop's induction variable.
+///  2. a condition that is provably true for some contiguous range of values
+///     taken by the containing loop's induction variable.
 ///
-/// Currently all inductive range checks are branches conditional on an
-/// expression of the form
-///
-///   0 <= (Offset + Scale * I) < Length
-///
-/// where `I' is the canonical induction variable of a loop to which Offset and
-/// Scale are loop invariant, and Length is >= 0.  Currently the 'false' branch
-/// is considered cold, looking at profiling data to verify that is a TODO.
-
 class InductiveRangeCheck {
+  // Classifies a range check
+  enum RangeCheckKind : unsigned {
+    // Range check of the form "0 <= I".
+    RANGE_CHECK_LOWER = 1,
+
+    // Range check of the form "I < L" where L is known positive.
+    RANGE_CHECK_UPPER = 2,
+
+    // The logical and of the RANGE_CHECK_LOWER and RANGE_CHECK_UPPER
+    // conditions.
+    RANGE_CHECK_BOTH = RANGE_CHECK_LOWER | RANGE_CHECK_UPPER,
+
+    // Unrecognized range check condition.
+    RANGE_CHECK_UNKNOWN = (unsigned)-1
+  };
+
+  static const char *rangeCheckKindToStr(RangeCheckKind);
+
   const SCEV *Offset;
   const SCEV *Scale;
   Value *Length;
   BranchInst *Branch;
+  RangeCheckKind Kind;
+
+  static RangeCheckKind parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
+                                            ScalarEvolution &SE, Value *&Index,
+                                            Value *&Length);
+
+  static InductiveRangeCheck::RangeCheckKind
+  parseRangeCheck(Loop *L, ScalarEvolution &SE, Value *Condition,
+                  const SCEV *&Index, Value *&UpperLimit);
 
   InductiveRangeCheck() :
     Offset(nullptr), Scale(nullptr), Length(nullptr), Branch(nullptr) { }
@@ -121,14 +140,19 @@ public:
 
   void print(raw_ostream &OS) const {
     OS << "InductiveRangeCheck:\n";
+    OS << "  Kind: " << rangeCheckKindToStr(Kind) << "\n";
     OS << "  Offset: ";
     Offset->print(OS);
     OS << "  Scale: ";
     Scale->print(OS);
     OS << "  Length: ";
-    Length->print(OS);
-    OS << "  Branch: ";
+    if (Length)
+      Length->print(OS);
+    else
+      OS << "(null)";
+    OS << "\n  Branch: ";
     getBranch()->print(OS);
+    OS << "\n";
   }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -162,9 +186,11 @@ public:
   /// branch to take the hot successor (see (1) above).
   bool getPassingDirection() { return true; }
 
-  /// Computes a range for the induction variable in which the range check is
-  /// redundant and can be constant-folded away.
+  /// Computes a range for the induction variable (IndVar) in which the range
+  /// check is redundant and can be constant-folded away.  The induction
+  /// variable is not required to be the canonical {0,+,1} induction variable.
   Optional<Range> computeSafeIterationSpace(ScalarEvolution &SE,
+                                            const SCEVAddRecExpr *IndVar,
                                             IRBuilder<> &B) const;
 
   /// Create an inductive range check out of BI if possible, else return
@@ -189,7 +215,7 @@ public:
     AU.addRequiredID(LoopSimplifyID);
     AU.addRequiredID(LCSSAID);
     AU.addRequired<ScalarEvolution>();
-    AU.addRequired<BranchProbabilityInfo>();
+    AU.addRequired<BranchProbabilityInfoWrapperPass>();
   }
 
   bool runOnLoop(Loop *L, LPPassManager &LPM) override;
@@ -201,160 +227,156 @@ char InductiveRangeCheckElimination::ID = 0;
 INITIALIZE_PASS(InductiveRangeCheckElimination, "irce",
                 "Inductive range check elimination", false, false)
 
-static bool IsLowerBoundCheck(Value *Check, Value *&IndexV) {
-  using namespace llvm::PatternMatch;
+const char *InductiveRangeCheck::rangeCheckKindToStr(
+    InductiveRangeCheck::RangeCheckKind RCK) {
+  switch (RCK) {
+  case InductiveRangeCheck::RANGE_CHECK_UNKNOWN:
+    return "RANGE_CHECK_UNKNOWN";
 
-  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
-  Value *LHS = nullptr, *RHS = nullptr;
+  case InductiveRangeCheck::RANGE_CHECK_UPPER:
+    return "RANGE_CHECK_UPPER";
 
-  if (!match(Check, m_ICmp(Pred, m_Value(LHS), m_Value(RHS))))
-    return false;
+  case InductiveRangeCheck::RANGE_CHECK_LOWER:
+    return "RANGE_CHECK_LOWER";
+
+  case InductiveRangeCheck::RANGE_CHECK_BOTH:
+    return "RANGE_CHECK_BOTH";
+  }
+
+  llvm_unreachable("unknown range check type!");
+}
+
+/// Parse a single ICmp instruction, `ICI`, into a range check.  If `ICI`
+/// cannot
+/// be interpreted as a range check, return `RANGE_CHECK_UNKNOWN` and set
+/// `Index` and `Length` to `nullptr`.  Otherwise set `Index` to the value
+/// being
+/// range checked, and set `Length` to the upper limit `Index` is being range
+/// checked with if (and only if) the range check type is stronger or equal to
+/// RANGE_CHECK_UPPER.
+///
+InductiveRangeCheck::RangeCheckKind
+InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
+                                         ScalarEvolution &SE, Value *&Index,
+                                         Value *&Length) {
+
+  auto IsNonNegativeAndNotLoopVarying = [&SE, L](Value *V) {
+    const SCEV *S = SE.getSCEV(V);
+    if (isa<SCEVCouldNotCompute>(S))
+      return false;
+
+    return SE.getLoopDisposition(S, L) == ScalarEvolution::LoopInvariant &&
+           SE.isKnownNonNegative(S);
+  };
+
+  using namespace llvm::PatternMatch;
+
+  ICmpInst::Predicate Pred = ICI->getPredicate();
+  Value *LHS = ICI->getOperand(0);
+  Value *RHS = ICI->getOperand(1);
 
   switch (Pred) {
   default:
-    return false;
+    return RANGE_CHECK_UNKNOWN;
 
   case ICmpInst::ICMP_SLE:
     std::swap(LHS, RHS);
   // fallthrough
   case ICmpInst::ICMP_SGE:
-    if (!match(RHS, m_ConstantInt<0>()))
-      return false;
-    IndexV = LHS;
-    return true;
+    if (match(RHS, m_ConstantInt<0>())) {
+      Index = LHS;
+      return RANGE_CHECK_LOWER;
+    }
+    return RANGE_CHECK_UNKNOWN;
 
   case ICmpInst::ICMP_SLT:
     std::swap(LHS, RHS);
   // fallthrough
   case ICmpInst::ICMP_SGT:
-    if (!match(RHS, m_ConstantInt<-1>()))
-      return false;
-    IndexV = LHS;
-    return true;
-  }
-}
-
-static bool IsUpperBoundCheck(Value *Check, Value *Index, Value *&UpperLimit) {
-  using namespace llvm::PatternMatch;
-
-  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
-  Value *LHS = nullptr, *RHS = nullptr;
-
-  if (!match(Check, m_ICmp(Pred, m_Value(LHS), m_Value(RHS))))
-    return false;
+    if (match(RHS, m_ConstantInt<-1>())) {
+      Index = LHS;
+      return RANGE_CHECK_LOWER;
+    }
 
-  switch (Pred) {
-  default:
-    return false;
+    if (IsNonNegativeAndNotLoopVarying(LHS)) {
+      Index = RHS;
+      Length = LHS;
+      return RANGE_CHECK_UPPER;
+    }
+    return RANGE_CHECK_UNKNOWN;
 
-  case ICmpInst::ICMP_SGT:
+  case ICmpInst::ICMP_ULT:
     std::swap(LHS, RHS);
   // fallthrough
-  case ICmpInst::ICMP_SLT:
-    if (LHS != Index)
-      return false;
-    UpperLimit = RHS;
-    return true;
-
   case ICmpInst::ICMP_UGT:
-    std::swap(LHS, RHS);
-  // fallthrough
-  case ICmpInst::ICMP_ULT:
-    if (LHS != Index)
-      return false;
-    UpperLimit = RHS;
-    return true;
+    if (IsNonNegativeAndNotLoopVarying(LHS)) {
+      Index = RHS;
+      Length = LHS;
+      return RANGE_CHECK_BOTH;
+    }
+    return RANGE_CHECK_UNKNOWN;
   }
+
+  llvm_unreachable("default clause returns!");
 }
 
-/// Split a condition into something semantically equivalent to (0 <= I <
-/// Limit), both comparisons signed and Len loop invariant on L and positive.
-/// On success, return true and set Index to I and UpperLimit to Limit.  Return
-/// false on failure (we may still write to UpperLimit and Index on failure).
-/// It does not try to interpret I as a loop index.
-///
-static bool SplitRangeCheckCondition(Loop *L, ScalarEvolution &SE,
+/// Parses an arbitrary condition into a range check.  `Length` is set only if
+/// the range check is recognized to be `RANGE_CHECK_UPPER` or stronger.
+InductiveRangeCheck::RangeCheckKind
+InductiveRangeCheck::parseRangeCheck(Loop *L, ScalarEvolution &SE,
                                      Value *Condition, const SCEV *&Index,
-                                     Value *&UpperLimit) {
-
-  // TODO: currently this catches some silly cases like comparing "%idx slt 1".
-  // Our transformations are still correct, but less likely to be profitable in
-  // those cases.  We have to come up with some heuristics that pick out the
-  // range checks that are more profitable to clone a loop for.  This function
-  // in general can be made more robust.
-
+                                     Value *&Length) {
   using namespace llvm::PatternMatch;
 
   Value *A = nullptr;
   Value *B = nullptr;
-  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
-
-  // In these early checks we assume that the matched UpperLimit is positive.
-  // We'll verify that fact later, before returning true.
 
   if (match(Condition, m_And(m_Value(A), m_Value(B)))) {
-    Value *IndexV = nullptr;
-    Value *ExpectedUpperBoundCheck = nullptr;
+    Value *IndexA = nullptr, *IndexB = nullptr;
+    Value *LengthA = nullptr, *LengthB = nullptr;
+    ICmpInst *ICmpA = dyn_cast<ICmpInst>(A), *ICmpB = dyn_cast<ICmpInst>(B);
 
-    if (IsLowerBoundCheck(A, IndexV))
-      ExpectedUpperBoundCheck = B;
-    else if (IsLowerBoundCheck(B, IndexV))
-      ExpectedUpperBoundCheck = A;
-    else
-      return false;
+    if (!ICmpA || !ICmpB)
+      return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
 
-    if (!IsUpperBoundCheck(ExpectedUpperBoundCheck, IndexV, UpperLimit))
-      return false;
+    auto RCKindA = parseRangeCheckICmp(L, ICmpA, SE, IndexA, LengthA);
+    auto RCKindB = parseRangeCheckICmp(L, ICmpB, SE, IndexB, LengthB);
 
-    Index = SE.getSCEV(IndexV);
+    if (RCKindA == InductiveRangeCheck::RANGE_CHECK_UNKNOWN ||
+        RCKindB == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
+      return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
 
-    if (isa<SCEVCouldNotCompute>(Index))
-      return false;
+    if (IndexA != IndexB)
+      return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
 
-  } else if (match(Condition, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
-    switch (Pred) {
-    default:
-      return false;
+    if (LengthA != nullptr && LengthB != nullptr && LengthA != LengthB)
+      return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
 
-    case ICmpInst::ICMP_SGT:
-      std::swap(A, B);
-    // fall through
-    case ICmpInst::ICMP_SLT:
-      UpperLimit = B;
-      Index = SE.getSCEV(A);
-      if (isa<SCEVCouldNotCompute>(Index) || !SE.isKnownNonNegative(Index))
-        return false;
-      break;
+    Index = SE.getSCEV(IndexA);
+    if (isa<SCEVCouldNotCompute>(Index))
+      return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
 
-    case ICmpInst::ICMP_UGT:
-      std::swap(A, B);
-    // fall through
-    case ICmpInst::ICMP_ULT:
-      UpperLimit = B;
-      Index = SE.getSCEV(A);
-      if (isa<SCEVCouldNotCompute>(Index))
-        return false;
-      break;
-    }
-  } else {
-    return false;
+    Length = LengthA == nullptr ? LengthB : LengthA;
+
+    return (InductiveRangeCheck::RangeCheckKind)(RCKindA | RCKindB);
   }
 
-  const SCEV *UpperLimitSCEV = SE.getSCEV(UpperLimit);
-  if (isa<SCEVCouldNotCompute>(UpperLimitSCEV) ||
-      !SE.isKnownNonNegative(UpperLimitSCEV))
-    return false;
+  if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
+    Value *IndexVal = nullptr;
 
-  if (SE.getLoopDisposition(UpperLimitSCEV, L) !=
-      ScalarEvolution::LoopInvariant) {
-    DEBUG(dbgs() << " in function: " << L->getHeader()->getParent()->getName()
-                 << " ";
-          dbgs() << " UpperLimit is not loop invariant: "
-                 << UpperLimit->getName() << "\n";);
-    return false;
+    auto RCKind = parseRangeCheckICmp(L, ICI, SE, IndexVal, Length);
+
+    if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
+      return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
+
+    Index = SE.getSCEV(IndexVal);
+    if (isa<SCEVCouldNotCompute>(Index))
+      return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
+
+    return RCKind;
   }
 
-  return true;
+  return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
 }
 
 
@@ -374,10 +396,15 @@ InductiveRangeCheck::create(InductiveRangeCheck::AllocatorTy &A, BranchInst *BI,
   Value *Length = nullptr;
   const SCEV *IndexSCEV = nullptr;
 
-  if (!SplitRangeCheckCondition(L, SE, BI->getCondition(), IndexSCEV, Length))
+  auto RCKind = InductiveRangeCheck::parseRangeCheck(L, SE, BI->getCondition(),
+                                                     IndexSCEV, Length);
+
+  if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
     return nullptr;
 
-  assert(IndexSCEV && Length && "contract with SplitRangeCheckCondition!");
+  assert(IndexSCEV && "contract with SplitRangeCheckCondition!");
+  assert((!(RCKind & InductiveRangeCheck::RANGE_CHECK_UPPER) || Length) &&
+         "contract with SplitRangeCheckCondition!");
 
   const SCEVAddRecExpr *IndexAddRec = dyn_cast<SCEVAddRecExpr>(IndexSCEV);
   bool IsAffineIndex =
@@ -391,11 +418,60 @@ InductiveRangeCheck::create(InductiveRangeCheck::AllocatorTy &A, BranchInst *BI,
   IRC->Offset = IndexAddRec->getStart();
   IRC->Scale = IndexAddRec->getStepRecurrence(SE);
   IRC->Branch = BI;
+  IRC->Kind = RCKind;
   return IRC;
 }
 
 namespace {
 
+// Keeps track of the structure of a loop.  This is similar to llvm::Loop,
+// except that it is more lightweight and can track the state of a loop through
+// changing and potentially invalid IR.  This structure also formalizes the
+// kinds of loops we can deal with -- ones that have a single latch that is also
+// an exiting block *and* have a canonical induction variable.
+struct LoopStructure {
+  const char *Tag;
+
+  BasicBlock *Header;
+  BasicBlock *Latch;
+
+  // `Latch's terminator instruction is `LatchBr', and it's `LatchBrExitIdx'th
+  // successor is `LatchExit', the exit block of the loop.
+  BranchInst *LatchBr;
+  BasicBlock *LatchExit;
+  unsigned LatchBrExitIdx;
+
+  Value *IndVarNext;
+  Value *IndVarStart;
+  Value *LoopExitAt;
+  bool IndVarIncreasing;
+
+  LoopStructure()
+      : Tag(""), Header(nullptr), Latch(nullptr), LatchBr(nullptr),
+        LatchExit(nullptr), LatchBrExitIdx(-1), IndVarNext(nullptr),
+        IndVarStart(nullptr), LoopExitAt(nullptr), IndVarIncreasing(false) {}
+
+  template <typename M> LoopStructure map(M Map) const {
+    LoopStructure Result;
+    Result.Tag = Tag;
+    Result.Header = cast<BasicBlock>(Map(Header));
+    Result.Latch = cast<BasicBlock>(Map(Latch));
+    Result.LatchBr = cast<BranchInst>(Map(LatchBr));
+    Result.LatchExit = cast<BasicBlock>(Map(LatchExit));
+    Result.LatchBrExitIdx = LatchBrExitIdx;
+    Result.IndVarNext = Map(IndVarNext);
+    Result.IndVarStart = Map(IndVarStart);
+    Result.LoopExitAt = Map(LoopExitAt);
+    Result.IndVarIncreasing = IndVarIncreasing;
+    return Result;
+  }
+
+  static Optional<LoopStructure> parseLoopStructure(ScalarEvolution &,
+                                                    BranchProbabilityInfo &BPI,
+                                                    Loop &,
+                                                    const char *&);
+};
+
 /// This class is used to constrain loops to run within a given iteration space.
 /// The algorithm this class implements is given a Loop and a range [Begin,
 /// End).  The algorithm then tries to break out a "main loop" out of the loop
@@ -406,51 +482,6 @@ namespace {
 /// iterations in which the induction variable is >= End.
 ///
 class LoopConstrainer {
-
-  // Keeps track of the structure of a loop.  This is similar to llvm::Loop,
-  // except that it is more lightweight and can track the state of a loop
-  // through changing and potentially invalid IR.  This structure also
-  // formalizes the kinds of loops we can deal with -- ones that have a single
-  // latch that is also an exiting block *and* have a canonical induction
-  // variable.
-  struct LoopStructure {
-    const char *Tag;
-
-    BasicBlock *Header;
-    BasicBlock *Latch;
-
-    // `Latch's terminator instruction is `LatchBr', and it's `LatchBrExitIdx'th
-    // successor is `LatchExit', the exit block of the loop.
-    BranchInst *LatchBr;
-    BasicBlock *LatchExit;
-    unsigned LatchBrExitIdx;
-
-    // The canonical induction variable.  It's value is `CIVStart` on the 0th
-    // itertion and `CIVNext` for all iterations after that.
-    PHINode *CIV;
-    Value *CIVStart;
-    Value *CIVNext;
-
-    LoopStructure() : Tag(""), Header(nullptr), Latch(nullptr),
-                      LatchBr(nullptr), LatchExit(nullptr),
-                      LatchBrExitIdx(-1), CIV(nullptr),
-                      CIVStart(nullptr), CIVNext(nullptr) { }
-
-    template <typename M> LoopStructure map(M Map) const {
-      LoopStructure Result;
-      Result.Tag = Tag;
-      Result.Header = cast<BasicBlock>(Map(Header));
-      Result.Latch = cast<BasicBlock>(Map(Latch));
-      Result.LatchBr = cast<BranchInst>(Map(LatchBr));
-      Result.LatchExit = cast<BasicBlock>(Map(LatchExit));
-      Result.LatchBrExitIdx = LatchBrExitIdx;
-      Result.CIV = cast<PHINode>(Map(CIV));
-      Result.CIVNext = Map(CIVNext);
-      Result.CIVStart = Map(CIVStart);
-      return Result;
-    }
-  };
-
   // The representation of a clone of the original loop we started out with.
   struct ClonedLoop {
     // The cloned blocks
@@ -469,17 +500,22 @@ class LoopConstrainer {
     BasicBlock *PseudoExit;
     BasicBlock *ExitSelector;
     std::vector<PHINode *> PHIValuesAtPseudoExit;
+    PHINode *IndVarEnd;
 
-    RewrittenRangeInfo() : PseudoExit(nullptr), ExitSelector(nullptr) { }
+    RewrittenRangeInfo()
+        : PseudoExit(nullptr), ExitSelector(nullptr), IndVarEnd(nullptr) {}
   };
 
   // Calculated subranges we restrict the iteration space of the main loop to.
   // See the implementation of `calculateSubRanges' for more details on how
-  // these fields are computed.  `ExitPreLoopAt' is `None' if we don't need a
-  // pre loop.  `ExitMainLoopAt' is `None' if we don't need a post loop.
+  // these fields are computed.  `LowLimit` is None if there is no restriction
+  // on low end of the restricted iteration space of the main loop.  `HighLimit`
+  // is None if there is no restriction on high end of the restricted iteration
+  // space of the main loop.
+
   struct SubRanges {
-    Optional<Value *> ExitPreLoopAt;
-    Optional<Value *> ExitMainLoopAt;
+    Optional<const SCEV *> LowLimit;
+    Optional<const SCEV *> HighLimit;
   };
 
   // A utility function that does a `replaceUsesOfWith' on the incoming block
@@ -488,19 +524,11 @@ class LoopConstrainer {
   static void replacePHIBlock(PHINode *PN, BasicBlock *Block,
                               BasicBlock *ReplaceBy);
 
-  // Try to "parse" `OriginalLoop' and populate the various out parameters.
-  // Returns true on success, false on failure.
-  //
-  bool recognizeLoop(LoopStructure &LoopStructureOut,
-                     const SCEV *&LatchCountOut, BasicBlock *&PreHeaderOut,
-                     const char *&FailureReasonOut) const;
-
   // Compute a safe set of limits for the main loop to run in -- effectively the
   // intersection of `Range' and the iteration space of the original loop.
-  // Return the header count (1 + the latch taken count) in `HeaderCount'.
   // Return None if unable to compute the set of subranges.
   //
-  Optional<SubRanges> calculateSubRanges(Value *&HeaderCount) const;
+  Optional<SubRanges> calculateSubRanges() const;
 
   // Clone `OriginalLoop' and return the result in CLResult.  The IR after
   // running `cloneLoop' is well formed except for the PHI nodes in CLResult --
@@ -539,16 +567,15 @@ class LoopConstrainer {
   // The loop denoted by `LS' has `OldPreheader' as its preheader.  This
   // function creates a new preheader for `LS' and returns it.
   //
-  BasicBlock *createPreheader(const LoopConstrainer::LoopStructure &LS,
-                              BasicBlock *OldPreheader, const char *Tag) const;
+  BasicBlock *createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader,
+                              const char *Tag) const;
 
   // `ContinuationBlockAndPreheader' was the continuation block for some call to
   // `changeIterationSpaceEnd' and is the preheader to the loop denoted by `LS'.
   // This function rewrites the PHI nodes in `LS.Header' to start with the
   // correct value.
   void rewriteIncomingValuesForPHIs(
-      LoopConstrainer::LoopStructure &LS,
-      BasicBlock *ContinuationBlockAndPreheader,
+      LoopStructure &LS, BasicBlock *ContinuationBlockAndPreheader,
       const LoopConstrainer::RewrittenRangeInfo &RRI) const;
 
   // Even though we do not preserve any passes at this time, we at least need to
@@ -567,7 +594,6 @@ class LoopConstrainer {
   LoopInfo &OriginalLoopInfo;
   const SCEV *LatchTakenCount;
   BasicBlock *OriginalPreheader;
-  Value *OriginalHeaderCount;
 
   // The preheader of the main loop.  This may or may not be different from
   // `OriginalPreheader'.
@@ -581,12 +607,12 @@ class LoopConstrainer {
   LoopStructure MainLoopStructure;
 
 public:
-  LoopConstrainer(Loop &L, LoopInfo &LI, ScalarEvolution &SE,
-                  InductiveRangeCheck::Range R)
-    : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE),
-      OriginalLoop(L), OriginalLoopInfo(LI), LatchTakenCount(nullptr),
-      OriginalPreheader(nullptr), OriginalHeaderCount(nullptr),
-      MainLoopPreheader(nullptr), Range(R) { }
+  LoopConstrainer(Loop &L, LoopInfo &LI, const LoopStructure &LS,
+                  ScalarEvolution &SE, InductiveRangeCheck::Range R)
+      : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()),
+        SE(SE), OriginalLoop(L), OriginalLoopInfo(LI), LatchTakenCount(nullptr),
+        OriginalPreheader(nullptr), MainLoopPreheader(nullptr), Range(R),
+        MainLoopStructure(LS) {}
 
   // Entry point for the algorithm.  Returns true on success.
   bool run();
@@ -601,158 +627,288 @@ void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block,
       PN->setIncomingBlock(i, ReplaceBy);
 }
 
-bool LoopConstrainer::recognizeLoop(LoopStructure &LoopStructureOut,
-                                    const SCEV *&LatchCountOut,
-                                    BasicBlock *&PreheaderOut,
-                                    const char *&FailureReason) const {
-  using namespace llvm::PatternMatch;
+static bool CanBeSMax(ScalarEvolution &SE, const SCEV *S) {
+  APInt SMax =
+      APInt::getSignedMaxValue(cast<IntegerType>(S->getType())->getBitWidth());
+  return SE.getSignedRange(S).contains(SMax) &&
+         SE.getUnsignedRange(S).contains(SMax);
+}
 
-  assert(OriginalLoop.isLoopSimplifyForm() &&
-         "should follow from addRequired<>");
+static bool CanBeSMin(ScalarEvolution &SE, const SCEV *S) {
+  APInt SMin =
+      APInt::getSignedMinValue(cast<IntegerType>(S->getType())->getBitWidth());
+  return SE.getSignedRange(S).contains(SMin) &&
+         SE.getUnsignedRange(S).contains(SMin);
+}
 
-  BasicBlock *Latch = OriginalLoop.getLoopLatch();
-  if (!OriginalLoop.isLoopExiting(Latch)) {
-    FailureReason = "no loop latch";
-    return false;
-  }
+Optional<LoopStructure>
+LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BPI,
+                                  Loop &L, const char *&FailureReason) {
+  assert(L.isLoopSimplifyForm() && "should follow from addRequired<>");
 
-  PHINode *CIV = OriginalLoop.getCanonicalInductionVariable();
-  if (!CIV) {
-    FailureReason = "no CIV";
-    return false;
+  BasicBlock *Latch = L.getLoopLatch();
+  if (!L.isLoopExiting(Latch)) {
+    FailureReason = "no loop latch";
+    return None;
   }
 
-  BasicBlock *Header = OriginalLoop.getHeader();
-  BasicBlock *Preheader = OriginalLoop.getLoopPreheader();
+  BasicBlock *Header = L.getHeader();
+  BasicBlock *Preheader = L.getLoopPreheader();
   if (!Preheader) {
     FailureReason = "no preheader";
-    return false;
+    return None;
   }
 
-  Value *CIVNext = CIV->getIncomingValueForBlock(Latch);
-  Value *CIVStart = CIV->getIncomingValueForBlock(Preheader);
-
-  const SCEV *LatchCount = SE.getExitCount(&OriginalLoop, Latch);
-  if (isa<SCEVCouldNotCompute>(LatchCount)) {
-    FailureReason = "could not compute latch count";
-    return false;
+  BranchInst *LatchBr = dyn_cast<BranchInst>(&*Latch->rbegin());
+  if (!LatchBr || LatchBr->isUnconditional()) {
+    FailureReason = "latch terminator not conditional branch";
+    return None;
   }
 
-  // While SCEV does most of the analysis for us, we still have to
-  // modify the latch; and currently we can only deal with certain
-  // kinds of latches.  This can be made more sophisticated as needed.
+  unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;
 
-  BranchInst *LatchBr = dyn_cast<BranchInst>(&*Latch->rbegin());
+  BranchProbability ExitProbability =
+    BPI.getEdgeProbability(LatchBr->getParent(), LatchBrExitIdx);
 
-  if (!LatchBr || LatchBr->isUnconditional()) {
-    FailureReason = "latch terminator not conditional branch";
-    return false;
+  if (ExitProbability > BranchProbability(1, MaxExitProbReciprocal)) {
+    FailureReason = "short running loop, not profitable";
+    return None;
   }
 
-  // Currently we only support a latch condition of the form:
-  //
-  //  %condition = icmp slt %civNext, %limit
-  //  br i1 %condition, label %header, label %exit
+  ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
+  if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
+    FailureReason = "latch terminator branch not conditional on integral icmp";
+    return None;
+  }
 
-  if (LatchBr->getSuccessor(0) != Header) {
-    FailureReason = "unknown latch form (header not first successor)";
-    return false;
+  const SCEV *LatchCount = SE.getExitCount(&L, Latch);
+  if (isa<SCEVCouldNotCompute>(LatchCount)) {
+    FailureReason = "could not compute latch count";
+    return None;
   }
 
-  Value *CIVComparedTo = nullptr;
-  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
-  if (!(match(LatchBr->getCondition(),
-              m_ICmp(Pred, m_Specific(CIVNext), m_Value(CIVComparedTo))) &&
-        Pred == ICmpInst::ICMP_SLT)) {
-    FailureReason = "unknown latch form (not slt)";
-    return false;
+  ICmpInst::Predicate Pred = ICI->getPredicate();
+  Value *LeftValue = ICI->getOperand(0);
+  const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
+  IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());
+
+  Value *RightValue = ICI->getOperand(1);
+  const SCEV *RightSCEV = SE.getSCEV(RightValue);
+
+  // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
+  if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
+    if (isa<SCEVAddRecExpr>(RightSCEV)) {
+      std::swap(LeftSCEV, RightSCEV);
+      std::swap(LeftValue, RightValue);
+      Pred = ICmpInst::getSwappedPredicate(Pred);
+    } else {
+      FailureReason = "no add recurrences in the icmp";
+      return None;
+    }
   }
 
-  // IndVarSimplify will sometimes leave behind (in SCEV's cache) backedge-taken
-  // counts that are narrower than the canonical induction variable.  These
-  // values are still accurate, and we could probably use them after sign/zero
-  // extension; but for now we just bail out of the transformation to keep
-  // things simple.
-  const SCEV *CIVComparedToSCEV = SE.getSCEV(CIVComparedTo);
-  if (isa<SCEVCouldNotCompute>(CIVComparedToSCEV) ||
-      CIVComparedToSCEV->getType() != LatchCount->getType()) {
-    FailureReason = "could not relate CIV to latch expression";
+  auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
+    if (AR->getNoWrapFlags(SCEV::FlagNSW))
+      return true;
+
+    IntegerType *Ty = cast<IntegerType>(AR->getType());
+    IntegerType *WideTy =
+        IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
+
+    const SCEVAddRecExpr *ExtendAfterOp =
+        dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
+    if (ExtendAfterOp) {
+      const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
+      const SCEV *ExtendedStep =
+          SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);
+
+      bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
+                          ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;
+
+      if (NoSignedWrap)
+        return true;
+    }
+
+    // We may have proved this when computing the sign extension above.
+    return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
+  };
+
+  auto IsInductionVar = [&](const SCEVAddRecExpr *AR, bool &IsIncreasing) {
+    if (!AR->isAffine())
+      return false;
+
+    // Currently we only work with induction variables that have been proved to
+    // not wrap.  This restriction can potentially be lifted in the future.
+
+    if (!HasNoSignedWrap(AR))
+      return false;
+
+    if (const SCEVConstant *StepExpr =
+            dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE))) {
+      ConstantInt *StepCI = StepExpr->getValue();
+      if (StepCI->isOne() || StepCI->isMinusOne()) {
+        IsIncreasing = StepCI->isOne();
+        return true;
+      }
+    }
+
     return false;
+  };
+
+  // `ICI` is interpreted as taking the backedge if the *next* value of the
+  // induction variable satisfies some constraint.
+
+  const SCEVAddRecExpr *IndVarNext = cast<SCEVAddRecExpr>(LeftSCEV);
+  bool IsIncreasing = false;
+  if (!IsInductionVar(IndVarNext, IsIncreasing)) {
+    FailureReason = "LHS in icmp not induction variable";
+    return None;
   }
 
-  const SCEV *ShouldBeOne = SE.getMinusSCEV(CIVComparedToSCEV, LatchCount);
-  const SCEVConstant *SCEVOne = dyn_cast<SCEVConstant>(ShouldBeOne);
-  if (!SCEVOne || SCEVOne->getValue()->getValue() != 1) {
-    FailureReason = "unexpected header count in latch";
-    return false;
+  ConstantInt *One = ConstantInt::get(IndVarTy, 1);
+  // TODO: generalize the predicates here to also match their unsigned variants.
+  if (IsIncreasing) {
+    bool FoundExpectedPred =
+        (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 1) ||
+        (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 0);
+
+    if (!FoundExpectedPred) {
+      FailureReason = "expected icmp slt semantically, found something else";
+      return None;
+    }
+
+    if (LatchBrExitIdx == 0) {
+      if (CanBeSMax(SE, RightSCEV)) {
+        // TODO: this restriction is easily removable -- we just have to
+        // remember that the icmp was an slt and not an sle.
+        FailureReason = "limit may overflow when coercing sle to slt";
+        return None;
+      }
+
+      IRBuilder<> B(&*Preheader->rbegin());
+      RightValue = B.CreateAdd(RightValue, One);
+    }
+
+  } else {
+    bool FoundExpectedPred =
+        (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 1) ||
+        (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 0);
+
+    if (!FoundExpectedPred) {
+      FailureReason = "expected icmp sgt semantically, found something else";
+      return None;
+    }
+
+    if (LatchBrExitIdx == 0) {
+      if (CanBeSMin(SE, RightSCEV)) {
+        // TODO: this restriction is easily removable -- we just have to
+        // remember that the icmp was an sgt and not an sge.
+        FailureReason = "limit may overflow when coercing sge to sgt";
+        return None;
+      }
+
+      IRBuilder<> B(&*Preheader->rbegin());
+      RightValue = B.CreateSub(RightValue, One);
+    }
   }
 
-  unsigned LatchBrExitIdx = 1;
+  const SCEV *StartNext = IndVarNext->getStart();
+  const SCEV *Addend = SE.getNegativeSCEV(IndVarNext->getStepRecurrence(SE));
+  const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
+
   BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);
 
-  assert(SE.getLoopDisposition(LatchCount, &OriginalLoop) ==
+  assert(SE.getLoopDisposition(LatchCount, &L) ==
              ScalarEvolution::LoopInvariant &&
          "loop variant exit count doesn't make sense!");
 
-  assert(!OriginalLoop.contains(LatchExit) && "expected an exit block!");
+  assert(!L.contains(LatchExit) && "expected an exit block!");
+  const DataLayout &DL = Preheader->getModule()->getDataLayout();
+  Value *IndVarStartV =
+      SCEVExpander(SE, DL, "irce")
+          .expandCodeFor(IndVarStart, IndVarTy, &*Preheader->rbegin());
+  IndVarStartV->setName("indvar.start");
+
+  LoopStructure Result;
+
+  Result.Tag = "main";
+  Result.Header = Header;
+  Result.Latch = Latch;
+  Result.LatchBr = LatchBr;
+  Result.LatchExit = LatchExit;
+  Result.LatchBrExitIdx = LatchBrExitIdx;
+  Result.IndVarStart = IndVarStartV;
+  Result.IndVarNext = LeftValue;
+  Result.IndVarIncreasing = IsIncreasing;
+  Result.LoopExitAt = RightValue;
 
-  LoopStructureOut.Tag = "main";
-  LoopStructureOut.Header = Header;
-  LoopStructureOut.Latch = Latch;
-  LoopStructureOut.LatchBr = LatchBr;
-  LoopStructureOut.LatchExit = LatchExit;
-  LoopStructureOut.LatchBrExitIdx = LatchBrExitIdx;
-  LoopStructureOut.CIV = CIV;
-  LoopStructureOut.CIVNext = CIVNext;
-  LoopStructureOut.CIVStart = CIVStart;
-
-  LatchCountOut = LatchCount;
-  PreheaderOut = Preheader;
   FailureReason = nullptr;
 
-  return true;
+  return Result;
 }
 
 Optional<LoopConstrainer::SubRanges>
-LoopConstrainer::calculateSubRanges(Value *&HeaderCountOut) const {
+LoopConstrainer::calculateSubRanges() const {
   IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType());
 
   if (Range.getType() != Ty)
     return None;
 
-  SCEVExpander Expander(SE, "irce");
-  Instruction *InsertPt = OriginalPreheader->getTerminator();
-
   LoopConstrainer::SubRanges Result;
 
   // I think we can be more aggressive here and make this nuw / nsw if the
   // addition that feeds into the icmp for the latch's terminating branch is nuw
   // / nsw.  In any case, a wrapping 2's complement addition is safe.
   ConstantInt *One = ConstantInt::get(Ty, 1);
-  const SCEV *HeaderCountSCEV = SE.getAddExpr(LatchTakenCount, SE.getSCEV(One));
-  HeaderCountOut = Expander.expandCodeFor(HeaderCountSCEV, Ty, InsertPt);
+  const SCEV *Start = SE.getSCEV(MainLoopStructure.IndVarStart);
+  const SCEV *End = SE.getSCEV(MainLoopStructure.LoopExitAt);
+
+  bool Increasing = MainLoopStructure.IndVarIncreasing;
+
+  // We compute `Smallest` and `Greatest` such that [Smallest, Greatest) is the
+  // range of values the induction variable takes.
 
-  const SCEV *Zero = SE.getConstant(Ty, 0);
+  const SCEV *Smallest = nullptr, *Greatest = nullptr;
+
+  if (Increasing) {
+    Smallest = Start;
+    Greatest = End;
+  } else {
+    // These two computations may sign-overflow.  Here is why that is okay:
+    //
+    // We know that the induction variable does not sign-overflow on any
+    // iteration except the last one, and it starts at `Start` and ends at
+    // `End`, decrementing by one every time.
+    //
+    //  * if `Smallest` sign-overflows we know `End` is `INT_SMAX`. Since the
+    //    induction variable is decreasing we know that that the smallest value
+    //    the loop body is actually executed with is `INT_SMIN` == `Smallest`.
+    //
+    //  * if `Greatest` sign-overflows, we know it can only be `INT_SMIN`.  In
+    //    that case, `Clamp` will always return `Smallest` and
+    //    [`Result.LowLimit`, `Result.HighLimit`) = [`Smallest`, `Smallest`)
+    //    will be an empty range.  Returning an empty range is always safe.
+    //
+
+    Smallest = SE.getAddExpr(End, SE.getSCEV(One));
+    Greatest = SE.getAddExpr(Start, SE.getSCEV(One));
+  }
+
+  auto Clamp = [this, Smallest, Greatest](const SCEV *S) {
+    return SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S));
+  };
 
   // In some cases we can prove that we don't need a pre or post loop
 
   bool ProvablyNoPreloop =
-    SE.isKnownPredicate(ICmpInst::ICMP_SLE, Range.getBegin(), Zero);
-  if (!ProvablyNoPreloop) {
-    const SCEV *ExitPreLoopAtSCEV =
-      SE.getSMinExpr(HeaderCountSCEV, Range.getBegin());
-    Result.ExitPreLoopAt =
-      Expander.expandCodeFor(ExitPreLoopAtSCEV, Ty, InsertPt);
-  }
+      SE.isKnownPredicate(ICmpInst::ICMP_SLE, Range.getBegin(), Smallest);
+  if (!ProvablyNoPreloop)
+    Result.LowLimit = Clamp(Range.getBegin());
 
   bool ProvablyNoPostLoop =
-    SE.isKnownPredicate(ICmpInst::ICMP_SLE, HeaderCountSCEV, Range.getEnd());
-  if (!ProvablyNoPostLoop) {
-    const SCEV *ExitMainLoopAtSCEV =
-      SE.getSMinExpr(HeaderCountSCEV, Range.getEnd());
-    Result.ExitMainLoopAt =
-      Expander.expandCodeFor(ExitMainLoopAtSCEV, Ty, InsertPt);
-  }
+      SE.isKnownPredicate(ICmpInst::ICMP_SLE, Greatest, Range.getEnd());
+  if (!ProvablyNoPostLoop)
+    Result.HighLimit = Clamp(Range.getEnd());
 
   return Result;
 }
@@ -809,7 +965,7 @@ void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
 }
 
 LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
-    const LoopStructure &LS, BasicBlock *Preheader, Value *ExitLoopAt,
+    const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
     BasicBlock *ContinuationBlock) const {
 
   // We start with a loop with a single latch:
@@ -893,32 +1049,37 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
                                       BBInsertLocation);
 
   BranchInst *PreheaderJump = cast<BranchInst>(&*Preheader->rbegin());
+  bool Increasing = LS.IndVarIncreasing;
 
   IRBuilder<> B(PreheaderJump);
 
   // EnterLoopCond - is it okay to start executing this `LS'?
-  Value *EnterLoopCond = B.CreateICmpSLT(LS.CIVStart, ExitLoopAt);
+  Value *EnterLoopCond = Increasing
+                             ? B.CreateICmpSLT(LS.IndVarStart, ExitSubloopAt)
+                             : B.CreateICmpSGT(LS.IndVarStart, ExitSubloopAt);
+
   B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
   PreheaderJump->eraseFromParent();
 
-  assert(LS.LatchBrExitIdx == 1 && "generalize this as needed!");
-
+  LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
   B.SetInsertPoint(LS.LatchBr);
+  Value *TakeBackedgeLoopCond =
+      Increasing ? B.CreateICmpSLT(LS.IndVarNext, ExitSubloopAt)
+                 : B.CreateICmpSGT(LS.IndVarNext, ExitSubloopAt);
+  Value *CondForBranch = LS.LatchBrExitIdx == 1
+                             ? TakeBackedgeLoopCond
+                             : B.CreateNot(TakeBackedgeLoopCond);
 
-  // ContinueCond - is it okay to execute the next iteration in `LS'?
-  Value *ContinueCond = B.CreateICmpSLT(LS.CIVNext, ExitLoopAt);
-
-  LS.LatchBr->setCondition(ContinueCond);
-  assert(LS.LatchBr->getSuccessor(LS.LatchBrExitIdx) == LS.LatchExit &&
-         "invariant!");
-  LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
+  LS.LatchBr->setCondition(CondForBranch);
 
   B.SetInsertPoint(RRI.ExitSelector);
 
   // IterationsLeft - are there any more iterations left, given the original
   // upper bound on the induction variable?  If not, we branch to the "real"
   // exit.
-  Value *IterationsLeft = B.CreateICmpSLT(LS.CIVNext, OriginalHeaderCount);
+  Value *IterationsLeft = Increasing
+                              ? B.CreateICmpSLT(LS.IndVarNext, LS.LoopExitAt)
+                              : B.CreateICmpSGT(LS.IndVarNext, LS.LoopExitAt);
   B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);
 
   BranchInst *BranchToContinuation =
@@ -942,6 +1103,11 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
     RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
   }
 
+  RRI.IndVarEnd = PHINode::Create(LS.IndVarNext->getType(), 2, "indvar.end",
+                                  BranchToContinuation);
+  RRI.IndVarEnd->addIncoming(LS.IndVarStart, Preheader);
+  RRI.IndVarEnd->addIncoming(LS.IndVarNext, RRI.ExitSelector);
+
   // The latch exit now has a branch from `RRI.ExitSelector' instead of
   // `LS.Latch'.  The PHI nodes need to be updated to reflect that.
   for (Instruction &I : *LS.LatchExit) {
@@ -955,7 +1121,7 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
 }
 
 void LoopConstrainer::rewriteIncomingValuesForPHIs(
-    LoopConstrainer::LoopStructure &LS, BasicBlock *ContinuationBlock,
+    LoopStructure &LS, BasicBlock *ContinuationBlock,
     const LoopConstrainer::RewrittenRangeInfo &RRI) const {
 
   unsigned PHIIndex = 0;
@@ -970,13 +1136,12 @@ void LoopConstrainer::rewriteIncomingValuesForPHIs(
         PN->setIncomingValue(i, RRI.PHIValuesAtPseudoExit[PHIIndex++]);
   }
 
-  LS.CIVStart = LS.CIV->getIncomingValueForBlock(ContinuationBlock);
+  LS.IndVarStart = RRI.IndVarEnd;
 }
 
-BasicBlock *
-LoopConstrainer::createPreheader(const LoopConstrainer::LoopStructure &LS,
-                                 BasicBlock *OldPreheader,
-                                 const char *Tag) const {
+BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
+                                             BasicBlock *OldPreheader,
+                                             const char *Tag) const {
 
   BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
   BranchInst::Create(LS.Header, Preheader);
@@ -1004,30 +1169,79 @@ void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
 
 bool LoopConstrainer::run() {
   BasicBlock *Preheader = nullptr;
-  const char *CouldNotProceedBecause = nullptr;
-  if (!recognizeLoop(MainLoopStructure, LatchTakenCount, Preheader,
-                     CouldNotProceedBecause)) {
-    DEBUG(dbgs() << "irce: could not recognize loop, " << CouldNotProceedBecause
-                 << "\n";);
-    return false;
-  }
+  LatchTakenCount = SE.getExitCount(&OriginalLoop, MainLoopStructure.Latch);
+  Preheader = OriginalLoop.getLoopPreheader();
+  assert(!isa<SCEVCouldNotCompute>(LatchTakenCount) && Preheader != nullptr &&
+         "preconditions!");
 
   OriginalPreheader = Preheader;
   MainLoopPreheader = Preheader;
 
-  Optional<SubRanges> MaybeSR = calculateSubRanges(OriginalHeaderCount);
+  Optional<SubRanges> MaybeSR = calculateSubRanges();
   if (!MaybeSR.hasValue()) {
     DEBUG(dbgs() << "irce: could not compute subranges\n");
     return false;
   }
+
   SubRanges SR = MaybeSR.getValue();
+  bool Increasing = MainLoopStructure.IndVarIncreasing;
+  IntegerType *IVTy =
+      cast<IntegerType>(MainLoopStructure.IndVarNext->getType());
+
+  SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce");
+  Instruction *InsertPt = OriginalPreheader->getTerminator();
 
   // It would have been better to make `PreLoop' and `PostLoop'
   // `Optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
   // constructor.
   ClonedLoop PreLoop, PostLoop;
-  bool NeedsPreLoop = SR.ExitPreLoopAt.hasValue();
-  bool NeedsPostLoop = SR.ExitMainLoopAt.hasValue();
+  bool NeedsPreLoop =
+      Increasing ? SR.LowLimit.hasValue() : SR.HighLimit.hasValue();
+  bool NeedsPostLoop =
+      Increasing ? SR.HighLimit.hasValue() : SR.LowLimit.hasValue();
+
+  Value *ExitPreLoopAt = nullptr;
+  Value *ExitMainLoopAt = nullptr;
+  const SCEVConstant *MinusOneS =
+      cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));
+
+  if (NeedsPreLoop) {
+    const SCEV *ExitPreLoopAtSCEV = nullptr;
+
+    if (Increasing)
+      ExitPreLoopAtSCEV = *SR.LowLimit;
+    else {
+      if (CanBeSMin(SE, *SR.HighLimit)) {
+        DEBUG(dbgs() << "irce: could not prove no-overflow when computing "
+                     << "preloop exit limit.  HighLimit = " << *(*SR.HighLimit)
+                     << "\n");
+        return false;
+      }
+      ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
+    }
+
+    ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
+    ExitPreLoopAt->setName("exit.preloop.at");
+  }
+
+  if (NeedsPostLoop) {
+    const SCEV *ExitMainLoopAtSCEV = nullptr;
+
+    if (Increasing)
+      ExitMainLoopAtSCEV = *SR.HighLimit;
+    else {
+      if (CanBeSMin(SE, *SR.LowLimit)) {
+        DEBUG(dbgs() << "irce: could not prove no-overflow when computing "
+                     << "mainloop exit limit.  LowLimit = " << *(*SR.LowLimit)
+                     << "\n");
+        return false;
+      }
+      ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
+    }
+
+    ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
+    ExitMainLoopAt->setName("exit.mainloop.at");
+  }
 
   // We clone these ahead of time so that we don't have to deal with changing
   // and temporarily invalid IR as we transform the loops.
@@ -1044,9 +1258,8 @@ bool LoopConstrainer::run() {
 
     MainLoopPreheader =
         createPreheader(MainLoopStructure, Preheader, "mainloop");
-    PreLoopRRI =
-        changeIterationSpaceEnd(PreLoop.Structure, Preheader,
-                                SR.ExitPreLoopAt.getValue(), MainLoopPreheader);
+    PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
+                                         ExitPreLoopAt, MainLoopPreheader);
     rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
                                  PreLoopRRI);
   }
@@ -1058,8 +1271,7 @@ bool LoopConstrainer::run() {
     PostLoopPreheader =
         createPreheader(PostLoop.Structure, Preheader, "postloop");
     PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
-                                          SR.ExitMainLoopAt.getValue(),
-                                          PostLoopPreheader);
+                                          ExitMainLoopAt, PostLoopPreheader);
     rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
                                  PostLoopRRI);
   }
@@ -1082,44 +1294,74 @@ bool LoopConstrainer::run() {
   return true;
 }
 
-/// Computes and returns a range of values for the induction variable in which
-/// the range check can be safely elided.  If it cannot compute such a range,
-/// returns None.
+/// Computes and returns a range of values for the induction variable (IndVar)
+/// in which the range check can be safely elided.  If it cannot compute such a
+/// range, returns None.
 Optional<InductiveRangeCheck::Range>
 InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE,
-                                               IRBuilder<> &B) const {
-
-  // Currently we support inequalities of the form:
+                                               const SCEVAddRecExpr *IndVar,
+                                               IRBuilder<> &) const {
+  // IndVar is of the form "A + B * I" (where "I" is the canonical induction
+  // variable, that may or may not exist as a real llvm::Value in the loop) and
+  // this inductive range check is a range check on the "C + D * I" ("C" is
+  // getOffset() and "D" is getScale()).  We rewrite the value being range
+  // checked to "M + N * IndVar" where "N" = "D * B^(-1)" and "M" = "C - NA".
+  // Currently we support this only for "B" = "D" = { 1 or -1 }, but the code
+  // can be generalized as needed.
   //
-  //   0 <= Offset + 1 * CIV < L given L >= 0
+  // The actual inequalities we solve are of the form
   //
-  // The inequality is satisfied by -Offset <= CIV < (L - Offset) [^1].  All
-  // additions and subtractions are twos-complement wrapping and comparisons are
-  // signed.
+  //   0 <= M + 1 * IndVar < L given L >= 0  (i.e. N == 1)
+  //
+  // The inequality is satisfied by -M <= IndVar < (L - M) [^1].  All additions
+  // and subtractions are twos-complement wrapping and comparisons are signed.
   //
   // Proof:
   //
-  //   If there exists CIV such that -Offset <= CIV < (L - Offset) then it
-  //   follows that -Offset <= (-Offset + L) [== Eq. 1].  Since L >= 0, if
-  //   (-Offset + L) sign-overflows then (-Offset + L) < (-Offset).  Hence by
-  //   [Eq. 1], (-Offset + L) could not have overflown.
+  //   If there exists IndVar such that -M <= IndVar < (L - M) then it follows
+  //   that -M <= (-M + L) [== Eq. 1].  Since L >= 0, if (-M + L) sign-overflows
+  //   then (-M + L) < (-M).  Hence by [Eq. 1], (-M + L) could not have
+  //   overflown.
   //
-  //   This means CIV = t + (-Offset) for t in [0, L).  Hence (CIV + Offset) =
-  //   t.  Hence 0 <= (CIV + Offset) < L
+  //   This means IndVar = t + (-M) for t in [0, L).  Hence (IndVar + M) = t.
+  //   Hence 0 <= (IndVar + M) < L
 
-  // [^1]: Note that the solution does _not_ apply if L < 0; consider values
-  // Offset = 127, CIV = 126 and L = -2 in an i8 world.
+  // [^1]: Note that the solution does _not_ apply if L < 0; consider values M =
+  // 127, IndVar = 126 and L = -2 in an i8 world.
 
-  const SCEVConstant *ScaleC = dyn_cast<SCEVConstant>(getScale());
-  if (!(ScaleC && ScaleC->getValue()->getValue() == 1)) {
-    DEBUG(dbgs() << "irce: could not compute safe iteration space for:\n";
-          print(dbgs()));
+  if (!IndVar->isAffine())
     return None;
-  }
 
-  const SCEV *Begin = SE.getNegativeSCEV(getOffset());
-  const SCEV *End = SE.getMinusSCEV(SE.getSCEV(getLength()), getOffset());
+  const SCEV *A = IndVar->getStart();
+  const SCEVConstant *B = dyn_cast<SCEVConstant>(IndVar->getStepRecurrence(SE));
+  if (!B)
+    return None;
+
+  const SCEV *C = getOffset();
+  const SCEVConstant *D = dyn_cast<SCEVConstant>(getScale());
+  if (D != B)
+    return None;
 
+  ConstantInt *ConstD = D->getValue();
+  if (!(ConstD->isMinusOne() || ConstD->isOne()))
+    return None;
+
+  const SCEV *M = SE.getMinusSCEV(C, A);
+
+  const SCEV *Begin = SE.getNegativeSCEV(M);
+  const SCEV *UpperLimit = nullptr;
+
+  // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L".
+  // We can potentially do much better here.
+  if (Value *V = getLength()) {
+    UpperLimit = SE.getSCEV(V);
+  } else {
+    assert(Kind == InductiveRangeCheck::RANGE_CHECK_LOWER && "invariant!");
+    unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth();
+    UpperLimit = SE.getConstant(APInt::getSignedMaxValue(BitWidth));
+  }
+
+  const SCEV *End = SE.getMinusSCEV(UpperLimit, M);
   return InductiveRangeCheck::Range(Begin, End);
 }
 
@@ -1158,7 +1400,8 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) {
   InductiveRangeCheck::AllocatorTy IRCAlloc;
   SmallVector<InductiveRangeCheck *, 16> RangeChecks;
   ScalarEvolution &SE = getAnalysis<ScalarEvolution>();
-  BranchProbabilityInfo &BPI = getAnalysis<BranchProbabilityInfo>();
+  BranchProbabilityInfo &BPI =
+      getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
 
   for (auto BBI : L->getBlocks())
     if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator()))
@@ -1169,12 +1412,33 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) {
   if (RangeChecks.empty())
     return false;
 
-  DEBUG(dbgs() << "irce: looking at loop "; L->print(dbgs());
-        dbgs() << "irce: loop has " << RangeChecks.size()
-               << " inductive range checks: \n";
-        for (InductiveRangeCheck *IRC : RangeChecks)
-          IRC->print(dbgs());
-    );
+  auto PrintRecognizedRangeChecks = [&](raw_ostream &OS) {
+    OS << "irce: looking at loop "; L->print(OS);
+    OS << "irce: loop has " << RangeChecks.size()
+       << " inductive range checks: \n";
+    for (InductiveRangeCheck *IRC : RangeChecks)
+      IRC->print(OS);
+  };
+
+  DEBUG(PrintRecognizedRangeChecks(dbgs()));
+
+  if (PrintRangeChecks)
+    PrintRecognizedRangeChecks(errs());
+
+  const char *FailureReason = nullptr;
+  Optional<LoopStructure> MaybeLoopStructure =
+      LoopStructure::parseLoopStructure(SE, BPI, *L, FailureReason);
+  if (!MaybeLoopStructure.hasValue()) {
+    DEBUG(dbgs() << "irce: could not parse loop structure: " << FailureReason
+                 << "\n";);
+    return false;
+  }
+  LoopStructure LS = MaybeLoopStructure.getValue();
+  bool Increasing = LS.IndVarIncreasing;
+  const SCEV *MinusOne =
+      SE.getConstant(LS.IndVarNext->getType(), Increasing ? -1 : 1, true);
+  const SCEVAddRecExpr *IndVar =
+      cast<SCEVAddRecExpr>(SE.getAddExpr(SE.getSCEV(LS.IndVarNext), MinusOne));
 
   Optional<InductiveRangeCheck::Range> SafeIterRange;
   Instruction *ExprInsertPt = Preheader->getTerminator();
@@ -1183,7 +1447,7 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) {
 
   IRBuilder<> B(ExprInsertPt);
   for (InductiveRangeCheck *IRC : RangeChecks) {
-    auto Result = IRC->computeSafeIterationSpace(SE, B);
+    auto Result = IRC->computeSafeIterationSpace(SE, IndVar, B);
     if (Result.hasValue()) {
       auto MaybeSafeIterRange =
         IntersectRange(SE, SafeIterRange, Result.getValue(), B);
@@ -1197,8 +1461,8 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) {
   if (!SafeIterRange.hasValue())
     return false;
 
-  LoopConstrainer LC(*L, getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), SE,
-                     SafeIterRange.getValue());
+  LoopConstrainer LC(*L, getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), LS,
+                     SE, SafeIterRange.getValue());
   bool Changed = LC.run();
 
   if (Changed) {