Step #2 to improve trip count analysis for loops like this:
authorChris Lattner <sabre@nondot.org>
Sun, 9 Jan 2011 22:26:35 +0000 (22:26 +0000)
committerChris Lattner <sabre@nondot.org>
Sun, 9 Jan 2011 22:26:35 +0000 (22:26 +0000)
void f(int* begin, int* end) { std::fill(begin, end, 0); }

which turns into a != exit expression where one pointer is
strided and (thanks to step #1) known to not overflow, and
the other is loop invariant.

The observation here is that, though the IV is strided by
4 in this case, that the IV *has* to become equal to the
end value.  It cannot "miss" the end value by stepping over
it, because if it did, the strided IV expression would
eventually wrap around.

Handle this by turning A != B into "A-B != 0" where the A-B
part is known to be NUW.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@123131 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolution.h
lib/Analysis/ScalarEvolution.cpp

index d9104067711f86c58daa0598115f9793d13d3607..4706c09e7916c0ee31a9aa77f3c70f968aa80c4a 100644 (file)
@@ -539,7 +539,8 @@ namespace llvm {
 
     /// getMinusSCEV - Return LHS-RHS.
     ///
-    const SCEV *getMinusSCEV(const SCEV *LHS, const SCEV *RHS);
+    const SCEV *getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
+                             bool HasNUW = false, bool HasNSW = false);
 
     /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion
     /// of the input value to the specified type.  If the type must be
index b675e286bad87fe3f023c80dc16d332612ec114d..d3d65847cbca195c4f2cc5380e68166e1155504d 100644 (file)
@@ -2448,22 +2448,21 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
 
 /// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
 ///
-const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS,
-                                          const SCEV *RHS) {
+const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
+                                          bool HasNUW, bool HasNSW) {
   // Fast path: X - X --> 0.
   if (LHS == RHS)
     return getConstant(LHS->getType(), 0);
 
   // X - Y --> X + -Y
-  return getAddExpr(LHS, getNegativeSCEV(RHS));
+  return getAddExpr(LHS, getNegativeSCEV(RHS), HasNUW, HasNSW);
 }
 
 /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
 /// input value to the specified type.  If the type must be extended, it is zero
 /// extended.
 const SCEV *
-ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V,
-                                         const Type *Ty) {
+ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, const Type *Ty) {
   const Type *SrcTy = V->getType();
   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
@@ -3942,6 +3941,105 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCond(const Loop *L,
   return ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB));
 }
 
+static const SCEVAddRecExpr *
+isSimpleUnwrappingAddRec(const SCEV *S, const Loop *L) {
+  const SCEVAddRecExpr *SA = dyn_cast<SCEVAddRecExpr>(S);
+  
+  // The SCEV must be an addrec of this loop.
+  if (!SA || SA->getLoop() != L || !SA->isAffine())
+    return 0;
+  
+  // The SCEV must be known to not wrap in some way to be interesting.
+  if (!SA->hasNoUnsignedWrap() && !SA->hasNoSignedWrap())
+    return 0;
+
+  // The stride must be a constant so that we know if it is striding up or down.
+  if (!isa<SCEVConstant>(SA->getOperand(1)))
+    return 0;
+  return SA;
+}
+
+/// getMinusSCEVForExitTest - When considering an exit test for a loop with a
+/// "x != y" exit test, we turn this into a computation that evaluates x-y != 0,
+/// and this function returns the expression to use for x-y.  We know and take
+/// advantage of the fact that this subtraction is only being used in a
+/// comparison by zero context.
+///
+static const SCEV *getMinusSCEVForExitTest(const SCEV *LHS, const SCEV *RHS,
+                                           const Loop *L, ScalarEvolution &SE) {
+  // If either LHS or RHS is an AddRec SCEV (of this loop) that is known to not
+  // wrap (either NSW or NUW), then we know that the value will either become
+  // the other one (and thus the loop terminates), that the loop will terminate
+  // through some other exit condition first, or that the loop has undefined
+  // behavior.  This information is useful when the addrec has a stride that is
+  // != 1 or -1, because it means we can't "miss" the exit value.
+  //
+  // In any of these three cases, it is safe to turn the exit condition into a
+  // "counting down" AddRec (to zero) by subtracting the two inputs as normal,
+  // but since we know that the "end cannot be missed" we can force the
+  // resulting AddRec to be a NUW addrec.  Since it is counting down, this means
+  // that the AddRec *cannot* pass zero.
+
+  // See if LHS and RHS are addrec's we can handle.
+  const SCEVAddRecExpr *LHSA = isSimpleUnwrappingAddRec(LHS, L);
+  const SCEVAddRecExpr *RHSA = isSimpleUnwrappingAddRec(RHS, L);
+  
+  // If neither addrec is interesting, just return a minus.
+  if (RHSA == 0 && LHSA == 0)
+    return SE.getMinusSCEV(LHS, RHS);
+  
+  // If only one of LHS and RHS are an AddRec of this loop, make sure it is LHS.
+  if (RHSA && LHSA == 0) {
+    // Safe because a-b === b-a for comparisons against zero.
+    std::swap(LHS, RHS);
+    std::swap(LHSA, RHSA);
+  }
+  
+  // Handle the case when only one is advancing in a non-overflowing way.
+  if (RHSA == 0) {
+    // If RHS is loop varying, then we can't predict when LHS will cross it.
+    if (!SE.isLoopInvariant(RHS, L))
+      return SE.getMinusSCEV(LHS, RHS);
+    
+    // If LHS has a positive stride, then we compute RHS-LHS, because the loop
+    // is counting up until it crosses RHS (which must be larger than LHS).  If
+    // it is negative, we compute LHS-RHS because we're counting down to RHS.
+    const ConstantInt *Stride =
+      cast<SCEVConstant>(LHSA->getOperand(1))->getValue();
+    if (Stride->getValue().isNegative())
+      std::swap(LHS, RHS);
+
+    return SE.getMinusSCEV(RHS, LHS, true /*HasNUW*/);
+  }
+  
+  // If both LHS and RHS are interesting, we have something like:
+  //  a+i*4 != b+i*8.
+  const ConstantInt *LHSStride =
+    cast<SCEVConstant>(LHSA->getOperand(1))->getValue();
+  const ConstantInt *RHSStride =
+    cast<SCEVConstant>(RHSA->getOperand(1))->getValue();
+  
+  // If the strides are equal, then this is just a (complex) loop invariant
+  // comparison of a/b.
+  if (LHSStride == RHSStride)
+    return SE.getMinusSCEV(LHSA->getStart(), RHSA->getStart());
+  
+  // If the signs of the strides differ, then the negative stride is counting
+  // down to the positive stride.
+  if (LHSStride->getValue().isNegative() != RHSStride->getValue().isNegative()){
+    if (RHSStride->getValue().isNegative())
+      std::swap(LHS, RHS);
+  } else {
+    // If LHS's stride is smaller than RHS's stride, then "b" must be less than
+    // "a" and "b" is RHS is counting up (catching up) to LHS.  This is true
+    // whether the strides are positive or negative.
+    if (RHSStride->getValue().slt(LHSStride->getValue()))
+      std::swap(LHS, RHS);
+  }
+    
+  return SE.getMinusSCEV(LHS, RHS, true /*HasNUW*/);
+}
+
 /// ComputeBackedgeTakenCountFromExitCondICmp - Compute the number of times the
 /// backedge of the specified loop will execute if its exit condition
 /// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB.
@@ -4001,7 +4099,8 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L,
   switch (Cond) {
   case ICmpInst::ICMP_NE: {                     // while (X != Y)
     // Convert to: while (X-Y != 0)
-    BackedgeTakenInfo BTI = HowFarToZero(getMinusSCEV(LHS, RHS), L);
+    BackedgeTakenInfo BTI = HowFarToZero(getMinusSCEVForExitTest(LHS, RHS, L,
+                                                                 *this), L);
     if (BTI.hasAnyInfo()) return BTI;
     break;
   }