[IRCE] Support half-range checks.
authorSanjoy Das <sanjoy@playingwithpointers.com>
Tue, 17 Mar 2015 00:42:13 +0000 (00:42 +0000)
committerSanjoy Das <sanjoy@playingwithpointers.com>
Tue, 17 Mar 2015 00:42:13 +0000 (00:42 +0000)
This change to IRCE gets it to recognize "half" range checks.  Half
range checks are range checks that only either check if the index is
`slt` some positive integer ("length") or if the index is `sge` `0`.

The range solver does not try to be clever / aggressive about solving
half-range checks -- it transforms "I < L" to "0 <= I < L" and "0 <= I"
to "0 <= I < INT_SMAX".  This is safe, but not always optimal.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@232444 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
test/Transforms/IRCE/only-lower-check.ll [new file with mode: 0644]
test/Transforms/IRCE/only-upper-check.ll [new file with mode: 0644]

index 61f3cc5025774d0b34b65a071d8c7cff137d9f31..0429aeeb49e83986edb00a5215c13b71ac39779f 100644 (file)
@@ -96,23 +96,40 @@ namespace {
 ///
 ///  and
 ///
-///  2. a condition that is provably true for some range of values taken by the
-///     containing loop's induction variable.
+///  2. a condition that is provably true for some contiguous range of values
+///     taken by the containing loop's induction variable.
 ///
-/// Currently all inductive range checks are branches conditional on an
-/// expression of the form
-///
-///   0 <= (Offset + Scale * I) < Length
-///
-/// where `I' is the canonical induction variable of a loop to which Offset and
-/// Scale are loop invariant, and Length is >= 0.  Currently the 'false' branch
-/// is considered cold, looking at profiling data to verify that is a TODO.
-
 class InductiveRangeCheck {
+  // Classifies a range check
+  enum RangeCheckKind {
+    // Range check of the form "0 <= I".
+    RANGE_CHECK_LOWER = 1,
+
+    // Range check of the form "I < L" where L is known positive.
+    RANGE_CHECK_UPPER = 2,
+
+    // The logical and of the RANGE_CHECK_LOWER and RANGE_CHECK_UPPER
+    // conditions.
+    RANGE_CHECK_BOTH = RANGE_CHECK_LOWER | RANGE_CHECK_UPPER,
+
+    // Unrecognized range check condition.
+    RANGE_CHECK_UNKNOWN = (unsigned)-1
+  };
+
+  static const char *rangeCheckKindToStr(RangeCheckKind);
+
   const SCEV *Offset;
   const SCEV *Scale;
   Value *Length;
   BranchInst *Branch;
+  RangeCheckKind Kind;
+
+  static RangeCheckKind parseRangeCheckICmp(ICmpInst *ICI, ScalarEvolution &SE,
+                                            Value *&Index, Value *&Length);
+
+  static InductiveRangeCheck::RangeCheckKind
+  parseRangeCheck(Loop *L, ScalarEvolution &SE, Value *Condition,
+                  const SCEV *&Index, Value *&UpperLimit);
 
   InductiveRangeCheck() :
     Offset(nullptr), Scale(nullptr), Length(nullptr), Branch(nullptr) { }
@@ -124,13 +141,17 @@ public:
 
   void print(raw_ostream &OS) const {
     OS << "InductiveRangeCheck:\n";
+    OS << "  Kind: " << rangeCheckKindToStr(Kind) << "\n";
     OS << "  Offset: ";
     Offset->print(OS);
     OS << "  Scale: ";
     Scale->print(OS);
     OS << "  Length: ";
-    Length->print(OS);
-    OS << "  Branch: ";
+    if (Length)
+      Length->print(OS);
+    else
+      OS << "(null)";
+    OS << "\n  Branch: ";
     getBranch()->print(OS);
     OS << "\n";
   }
@@ -207,160 +228,146 @@ char InductiveRangeCheckElimination::ID = 0;
 INITIALIZE_PASS(InductiveRangeCheckElimination, "irce",
                 "Inductive range check elimination", false, false)
 
-static bool IsLowerBoundCheck(Value *Check, Value *&IndexV) {
-  using namespace llvm::PatternMatch;
+const char *InductiveRangeCheck::rangeCheckKindToStr(
+    InductiveRangeCheck::RangeCheckKind RCK) {
+  switch (RCK) {
+  case InductiveRangeCheck::RANGE_CHECK_UNKNOWN:
+    return "RANGE_CHECK_UNKNOWN";
 
-  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
-  Value *LHS = nullptr, *RHS = nullptr;
+  case InductiveRangeCheck::RANGE_CHECK_UPPER:
+    return "RANGE_CHECK_UPPER";
 
-  if (!match(Check, m_ICmp(Pred, m_Value(LHS), m_Value(RHS))))
-    return false;
+  case InductiveRangeCheck::RANGE_CHECK_LOWER:
+    return "RANGE_CHECK_LOWER";
+
+  case InductiveRangeCheck::RANGE_CHECK_BOTH:
+    return "RANGE_CHECK_BOTH";
+  }
+
+  llvm_unreachable("unknown range check type!");
+}
+
+/// Parse a single ICmp instruction, `ICI`, into a range check.  If `ICI`
+/// cannot
+/// be interpreted as a range check, return `RANGE_CHECK_UNKNOWN` and set
+/// `Index` and `Length` to `nullptr`.  Otherwise set `Index` to the value
+/// being
+/// range checked, and set `Length` to the upper limit `Index` is being range
+/// checked with if (and only if) the range check type is stronger or equal to
+/// RANGE_CHECK_UPPER.
+///
+InductiveRangeCheck::RangeCheckKind
+InductiveRangeCheck::parseRangeCheckICmp(ICmpInst *ICI, ScalarEvolution &SE,
+                                         Value *&Index, Value *&Length) {
+
+  using namespace llvm::PatternMatch;
+
+  ICmpInst::Predicate Pred = ICI->getPredicate();
+  Value *LHS = ICI->getOperand(0);
+  Value *RHS = ICI->getOperand(1);
 
   switch (Pred) {
   default:
-    return false;
+    return RANGE_CHECK_UNKNOWN;
 
   case ICmpInst::ICMP_SLE:
     std::swap(LHS, RHS);
   // fallthrough
   case ICmpInst::ICMP_SGE:
-    if (!match(RHS, m_ConstantInt<0>()))
-      return false;
-    IndexV = LHS;
-    return true;
+    if (match(RHS, m_ConstantInt<0>())) {
+      Index = LHS;
+      return RANGE_CHECK_LOWER;
+    }
+    return RANGE_CHECK_UNKNOWN;
 
   case ICmpInst::ICMP_SLT:
     std::swap(LHS, RHS);
   // fallthrough
   case ICmpInst::ICMP_SGT:
-    if (!match(RHS, m_ConstantInt<-1>()))
-      return false;
-    IndexV = LHS;
-    return true;
-  }
-}
-
-static bool IsUpperBoundCheck(Value *Check, Value *Index, Value *&UpperLimit) {
-  using namespace llvm::PatternMatch;
-
-  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
-  Value *LHS = nullptr, *RHS = nullptr;
-
-  if (!match(Check, m_ICmp(Pred, m_Value(LHS), m_Value(RHS))))
-    return false;
+    if (match(RHS, m_ConstantInt<-1>())) {
+      Index = LHS;
+      return RANGE_CHECK_LOWER;
+    }
 
-  switch (Pred) {
-  default:
-    return false;
+    if (SE.isKnownNonNegative(SE.getSCEV(LHS))) {
+      Index = RHS;
+      Length = LHS;
+      return RANGE_CHECK_UPPER;
+    }
+    return RANGE_CHECK_UNKNOWN;
 
-  case ICmpInst::ICMP_SGT:
+  case ICmpInst::ICMP_ULT:
     std::swap(LHS, RHS);
   // fallthrough
-  case ICmpInst::ICMP_SLT:
-    if (LHS != Index)
-      return false;
-    UpperLimit = RHS;
-    return true;
-
   case ICmpInst::ICMP_UGT:
-    std::swap(LHS, RHS);
-  // fallthrough
-  case ICmpInst::ICMP_ULT:
-    if (LHS != Index)
-      return false;
-    UpperLimit = RHS;
-    return true;
+    if (SE.isKnownNonNegative(SE.getSCEV(LHS))) {
+      Index = RHS;
+      Length = LHS;
+      return RANGE_CHECK_BOTH;
+    }
+    return RANGE_CHECK_UNKNOWN;
   }
+
+  llvm_unreachable("default clause returns!");
 }
 
-/// Split a condition into something semantically equivalent to (0 <= I <
-/// Limit), both comparisons signed and Len loop invariant on L and positive.
-/// On success, return true and set Index to I and UpperLimit to Limit.  Return
-/// false on failure (we may still write to UpperLimit and Index on failure).
-/// It does not try to interpret I as a loop index.
-///
-static bool SplitRangeCheckCondition(Loop *L, ScalarEvolution &SE,
+/// Parses an arbitrary condition into a range check.  `Length` is set only if
+/// the range check is recognized to be `RANGE_CHECK_UPPER` or stronger.
+InductiveRangeCheck::RangeCheckKind
+InductiveRangeCheck::parseRangeCheck(Loop *L, ScalarEvolution &SE,
                                      Value *Condition, const SCEV *&Index,
-                                     Value *&UpperLimit) {
-
-  // TODO: currently this catches some silly cases like comparing "%idx slt 1".
-  // Our transformations are still correct, but less likely to be profitable in
-  // those cases.  We have to come up with some heuristics that pick out the
-  // range checks that are more profitable to clone a loop for.  This function
-  // in general can be made more robust.
-
+                                     Value *&Length) {
   using namespace llvm::PatternMatch;
 
   Value *A = nullptr;
   Value *B = nullptr;
-  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
-
-  // In these early checks we assume that the matched UpperLimit is positive.
-  // We'll verify that fact later, before returning true.
 
   if (match(Condition, m_And(m_Value(A), m_Value(B)))) {
-    Value *IndexV = nullptr;
-    Value *ExpectedUpperBoundCheck = nullptr;
+    Value *IndexA = nullptr, *IndexB = nullptr;
+    Value *LengthA = nullptr, *LengthB = nullptr;
+    ICmpInst *ICmpA = dyn_cast<ICmpInst>(A), *ICmpB = dyn_cast<ICmpInst>(B);
 
-    if (IsLowerBoundCheck(A, IndexV))
-      ExpectedUpperBoundCheck = B;
-    else if (IsLowerBoundCheck(B, IndexV))
-      ExpectedUpperBoundCheck = A;
-    else
-      return false;
+    if (!ICmpA || !ICmpB)
+      return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
 
-    if (!IsUpperBoundCheck(ExpectedUpperBoundCheck, IndexV, UpperLimit))
-      return false;
+    auto RCKindA = parseRangeCheckICmp(ICmpA, SE, IndexA, LengthA);
+    auto RCKindB = parseRangeCheckICmp(ICmpB, SE, IndexB, LengthB);
 
-    Index = SE.getSCEV(IndexV);
+    if (RCKindA == InductiveRangeCheck::RANGE_CHECK_UNKNOWN ||
+        RCKindB == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
+      return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
 
-    if (isa<SCEVCouldNotCompute>(Index))
-      return false;
+    if (IndexA != IndexB)
+      return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
 
-  } else if (match(Condition, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
-    switch (Pred) {
-    default:
-      return false;
+    if (LengthA != nullptr && LengthB != nullptr && LengthA != LengthB)
+      return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
 
-    case ICmpInst::ICMP_SGT:
-      std::swap(A, B);
-    // fall through
-    case ICmpInst::ICMP_SLT:
-      UpperLimit = B;
-      Index = SE.getSCEV(A);
-      if (isa<SCEVCouldNotCompute>(Index) || !SE.isKnownNonNegative(Index))
-        return false;
-      break;
+    Index = SE.getSCEV(IndexA);
+    if (isa<SCEVCouldNotCompute>(Index))
+      return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
 
-    case ICmpInst::ICMP_UGT:
-      std::swap(A, B);
-    // fall through
-    case ICmpInst::ICMP_ULT:
-      UpperLimit = B;
-      Index = SE.getSCEV(A);
-      if (isa<SCEVCouldNotCompute>(Index))
-        return false;
-      break;
-    }
-  } else {
-    return false;
+    Length = LengthA == nullptr ? LengthB : LengthA;
+
+    return (InductiveRangeCheck::RangeCheckKind)(RCKindA | RCKindB);
   }
 
-  const SCEV *UpperLimitSCEV = SE.getSCEV(UpperLimit);
-  if (isa<SCEVCouldNotCompute>(UpperLimitSCEV) ||
-      !SE.isKnownNonNegative(UpperLimitSCEV))
-    return false;
+  if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
+    Value *IndexVal = nullptr;
 
-  if (SE.getLoopDisposition(UpperLimitSCEV, L) !=
-      ScalarEvolution::LoopInvariant) {
-    DEBUG(dbgs() << " in function: " << L->getHeader()->getParent()->getName()
-                 << " ";
-          dbgs() << " UpperLimit is not loop invariant: "
-                 << UpperLimit->getName() << "\n";);
-    return false;
+    auto RCKind = parseRangeCheckICmp(ICI, SE, IndexVal, Length);
+
+    if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
+      return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
+
+    Index = SE.getSCEV(IndexVal);
+    if (isa<SCEVCouldNotCompute>(Index))
+      return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
+
+    return RCKind;
   }
 
-  return true;
+  return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
 }
 
 
@@ -380,10 +387,15 @@ InductiveRangeCheck::create(InductiveRangeCheck::AllocatorTy &A, BranchInst *BI,
   Value *Length = nullptr;
   const SCEV *IndexSCEV = nullptr;
 
-  if (!SplitRangeCheckCondition(L, SE, BI->getCondition(), IndexSCEV, Length))
+  auto RCKind = InductiveRangeCheck::parseRangeCheck(L, SE, BI->getCondition(),
+                                                     IndexSCEV, Length);
+
+  if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
     return nullptr;
 
-  assert(IndexSCEV && Length && "contract with SplitRangeCheckCondition!");
+  assert(IndexSCEV && "contract with SplitRangeCheckCondition!");
+  assert(!(RCKind & InductiveRangeCheck::RANGE_CHECK_UPPER) ||
+         Length && "contract with SplitRangeCheckCondition!");
 
   const SCEVAddRecExpr *IndexAddRec = dyn_cast<SCEVAddRecExpr>(IndexSCEV);
   bool IsAffineIndex =
@@ -397,6 +409,7 @@ InductiveRangeCheck::create(InductiveRangeCheck::AllocatorTy &A, BranchInst *BI,
   IRC->Offset = IndexAddRec->getStart();
   IRC->Scale = IndexAddRec->getStepRecurrence(SE);
   IRC->Branch = BI;
+  IRC->Kind = RCKind;
   return IRC;
 }
 
@@ -1294,8 +1307,19 @@ InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE,
   const SCEV *M = SE.getMinusSCEV(C, A);
 
   const SCEV *Begin = SE.getNegativeSCEV(M);
-  const SCEV *End = SE.getMinusSCEV(SE.getSCEV(getLength()), M);
+  const SCEV *UpperLimit = nullptr;
+
+  // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L".
+  // We can potentially do much better here.
+  if (Value *V = getLength()) {
+    UpperLimit = SE.getSCEV(V);
+  } else {
+    assert(Kind == InductiveRangeCheck::RANGE_CHECK_LOWER && "invariant!");
+    unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth();
+    UpperLimit = SE.getConstant(APInt::getSignedMaxValue(BitWidth));
+  }
 
+  const SCEV *End = SE.getMinusSCEV(UpperLimit, M);
   return InductiveRangeCheck::Range(Begin, End);
 }
 
diff --git a/test/Transforms/IRCE/only-lower-check.ll b/test/Transforms/IRCE/only-lower-check.ll
new file mode 100644 (file)
index 0000000..4cbf5da
--- /dev/null
@@ -0,0 +1,37 @@
+; RUN: opt -debug-only=irce -irce < %s 2>&1 | FileCheck %s
+
+; CHECK: irce: loop has 1 inductive range checks:
+; CHECK-NEXT: InductiveRangeCheck:
+; CHECK-NEXT:   Kind: RANGE_CHECK_LOWER
+; CHECK-NEXT:   Offset: (-1 + %n)  Scale: -1  Length: (null)
+; CHECK-NEXT:   Branch:   br i1 %abc, label %in.bounds, label %out.of.bounds
+; CHECK-NEXT: irce: in function only_lower_check: constrained Loop at depth 1 containing: %loop<header><exiting>,%in.bounds<latch><exiting>
+
+define void @only_lower_check(i32 *%arr, i32 *%a_len_ptr, i32 %n) {
+ entry:
+  %len = load i32, i32* %a_len_ptr, !range !0
+  %first.itr.check = icmp sgt i32 %n, 0
+  %start = sub i32 %n, 1
+  br i1 %first.itr.check, label %loop, label %exit
+
+ loop:
+  %idx = phi i32 [ %start, %entry ] , [ %idx.dec, %in.bounds ]
+  %idx.dec = sub i32 %idx, 1
+  %abc = icmp sge i32 %idx, 0
+  br i1 %abc, label %in.bounds, label %out.of.bounds, !prof !1
+
+ in.bounds:
+  %addr = getelementptr i32, i32* %arr, i32 %idx
+  store i32 0, i32* %addr
+  %next = icmp sgt i32 %idx.dec, -1
+  br i1 %next, label %loop, label %exit
+
+ out.of.bounds:
+  ret void
+
+ exit:
+  ret void
+}
+
+!0 = !{i32 0, i32 2147483647}
+!1 = !{!"branch_weights", i32 64, i32 4}
diff --git a/test/Transforms/IRCE/only-upper-check.ll b/test/Transforms/IRCE/only-upper-check.ll
new file mode 100644 (file)
index 0000000..eb18408
--- /dev/null
@@ -0,0 +1,37 @@
+; RUN: opt -irce -debug-only=irce %s -S 2>&1 | FileCheck %s
+
+; CHECK: irce: loop has 1 inductive range checks:
+; CHECK-NEXT:InductiveRangeCheck:
+; CHECK-NEXT:  Kind: RANGE_CHECK_UPPER
+; CHECK-NEXT:  Offset: %offset  Scale: 1  Length:   %len = load i32, i32* %a_len_ptr, !range !0
+; CHECK-NEXT:  Branch:   br i1 %abc, label %in.bounds, label %out.of.bounds, !prof !1
+; CHECK-NEXT: irce: in function incrementing: constrained Loop at depth 1 containing: %loop<header><exiting>,%in.bounds<latch><exiting>
+
+define void @incrementing(i32 *%arr, i32 *%a_len_ptr, i32 %n, i32 %offset) {
+ entry:
+  %len = load i32, i32* %a_len_ptr, !range !0
+  %first.itr.check = icmp sgt i32 %n, 0
+  br i1 %first.itr.check, label %loop, label %exit
+
+ loop:
+  %idx = phi i32 [ 0, %entry ] , [ %idx.next, %in.bounds ]
+  %idx.next = add i32 %idx, 1
+  %array.idx = add i32 %idx, %offset
+  %abc = icmp slt i32 %array.idx, %len
+  br i1 %abc, label %in.bounds, label %out.of.bounds, !prof !1
+
+ in.bounds:
+  %addr = getelementptr i32, i32* %arr, i32 %array.idx
+  store i32 0, i32* %addr
+  %next = icmp slt i32 %idx.next, %n
+  br i1 %next, label %loop, label %exit
+
+ out.of.bounds:
+  ret void
+
+ exit:
+  ret void
+}
+
+!0 = !{i32 0, i32 2147483647}
+!1 = !{!"branch_weights", i32 64, i32 4}