[SCEV] Exploit A < B => (A+K) < (B+K) when possible
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index df1e3eb183a1c10e2760f1c03eae9176060ab3d4..a3763c354d2d3ae9bd419f62be9424665e616506 100644 (file)
@@ -898,8 +898,8 @@ private:
   SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
                const SCEV *Denominator)
       : SE(S), Denominator(Denominator) {
-    Zero = SE.getConstant(Denominator->getType(), 0);
-    One = SE.getConstant(Denominator->getType(), 1);
+    Zero = SE.getZero(Denominator->getType());
+    One = SE.getOne(Denominator->getType());
 
     // We generally do not know how to divide Expr by Denominator. We
     // initialize the division to a "cannot divide" state to simplify the rest
@@ -1743,8 +1743,8 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
         if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) &&
             C2.isPowerOf2()) {
           Start = getSignExtendExpr(Start, Ty);
-          const SCEV *NewAR = getAddRecExpr(getConstant(AR->getType(), 0), Step,
-                                            L, AR->getNoWrapFlags());
+          const SCEV *NewAR = getAddRecExpr(getZero(AR->getType()), Step, L,
+                                            AR->getNoWrapFlags());
           return getAddExpr(Start, getSignExtendExpr(NewAR, Ty));
         }
       }
@@ -2120,7 +2120,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
           Ops.push_back(getMulExpr(getConstant(I->first),
                                    getAddExpr(I->second)));
       if (Ops.empty())
-        return getConstant(Ty, 0);
+        return getZero(Ty);
       if (Ops.size() == 1)
         return Ops[0];
       return getAddExpr(Ops);
@@ -2148,7 +2148,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
             MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
             InnerMul = getMulExpr(MulOps);
           }
-          const SCEV *One = getConstant(Ty, 1);
+          const SCEV *One = getOne(Ty);
           const SCEV *AddOne = getAddExpr(One, InnerMul);
           const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV);
           if (Ops.size() == 2) return OuterMul;
@@ -2540,7 +2540,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
       SmallVector<const SCEV*, 7> AddRecOps;
       for (int x = 0, xe = AddRec->getNumOperands() +
              OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
-        const SCEV *Term = getConstant(Ty, 0);
+        const SCEV *Term = getZero(Ty);
         for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
           uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
           for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
@@ -2920,7 +2920,7 @@ ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr,
   // adds.
   SCEV::NoWrapFlags Wrap = InBounds ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
 
-  const SCEV *TotalOffset = getConstant(IntPtrTy, 0);
+  const SCEV *TotalOffset = getZero(IntPtrTy);
   // The address space is unimportant. The first thing we do on CurTy is getting
   // its element type.
   Type *CurTy = PointerType::getUnqual(PointeeType);
@@ -3349,7 +3349,7 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
                                           SCEV::NoWrapFlags Flags) {
   // Fast path: X - X --> 0.
   if (LHS == RHS)
-    return getConstant(LHS->getType(), 0);
+    return getZero(LHS->getType());
 
   // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
   // makes it so that we cannot make much use of NUW.
@@ -4177,7 +4177,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
   else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
     return getConstant(CI);
   else if (isa<ConstantPointerNull>(V))
-    return getConstant(V->getType(), 0);
+    return getZero(V->getType());
   else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
     return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee());
   else
@@ -4529,7 +4529,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
         if (getTypeSizeInBits(LHS->getType()) <=
                 getTypeSizeInBits(U->getType()) &&
             isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
-          const SCEV *One = getConstant(U->getType(), 1);
+          const SCEV *One = getOne(U->getType());
           const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
           const SCEV *LA = getSCEV(U->getOperand(1));
           const SCEV *RA = getSCEV(U->getOperand(2));
@@ -4544,7 +4544,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
         if (getTypeSizeInBits(LHS->getType()) <=
                 getTypeSizeInBits(U->getType()) &&
             isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
-          const SCEV *One = getConstant(U->getType(), 1);
+          const SCEV *One = getOne(U->getType());
           const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
           const SCEV *LA = getSCEV(U->getOperand(1));
           const SCEV *RA = getSCEV(U->getOperand(2));
@@ -4641,8 +4641,7 @@ ScalarEvolution::getSmallConstantTripMultiple(Loop *L,
     return 1;
 
   // Get the trip count from the BE count by adding 1.
-  const SCEV *TCMul = getAddExpr(ExitCount,
-                                 getConstant(ExitCount->getType(), 1));
+  const SCEV *TCMul = getAddExpr(ExitCount, getOne(ExitCount->getType()));
   // FIXME: SCEV distributes multiplication as V1*C1 + V2*C1. We could attempt
   // to factor simple cases.
   if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(TCMul))
@@ -5197,7 +5196,7 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L,
       return getCouldNotCompute();
     else
       // The backedge is never taken.
-      return getConstant(CI->getType(), 0);
+      return getZero(CI->getType());
   }
 
   // If it's not an integer or pointer comparison then compute it the hard way.
@@ -6372,7 +6371,7 @@ ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) {
   // already.  If so, the backedge will execute zero times.
   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
     if (!C->getValue()->isNullValue())
-      return getConstant(C->getType(), 0);
+      return getZero(C->getType());
     return getCouldNotCompute();  // Otherwise it will loop infinitely.
   }
 
@@ -7281,6 +7280,146 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
   return false;
 }
 
+// Return true if More == (Less + C), where C is a constant.
+static bool IsConstDiff(ScalarEvolution &SE, const SCEV *Less, const SCEV *More,
+                        APInt &C) {
+  // We avoid subtracting expressions here because this function is usually
+  // fairly deep in the call stack (i.e. is called many times).
+
+  auto SplitBinaryAdd = [](const SCEV *Expr, const SCEV *&L, const SCEV *&R) {
+    const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
+    if (!AE || AE->getNumOperands() != 2)
+      return false;
+
+    L = AE->getOperand(0);
+    R = AE->getOperand(1);
+    return true;
+  };
+
+  if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
+    const auto *LAR = cast<SCEVAddRecExpr>(Less);
+    const auto *MAR = cast<SCEVAddRecExpr>(More);
+
+    if (LAR->getLoop() != MAR->getLoop())
+      return false;
+
+    // We look at affine expressions only; not for correctness but to keep
+    // getStepRecurrence cheap.
+    if (!LAR->isAffine() || !MAR->isAffine())
+      return false;
+
+    if (LAR->getStepRecurrence(SE) != MAR->getStepRecurrence(SE))
+      return false;
+
+    Less = LAR->getStart();
+    More = MAR->getStart();
+
+    // fall through
+  }
+
+  if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
+    const auto &M = cast<SCEVConstant>(More)->getValue()->getValue();
+    const auto &L = cast<SCEVConstant>(Less)->getValue()->getValue();
+    C = M - L;
+    return true;
+  }
+
+  const SCEV *L, *R;
+  if (SplitBinaryAdd(Less, L, R))
+    if (const auto *LC = dyn_cast<SCEVConstant>(L))
+      if (R == More) {
+        C = -(LC->getValue()->getValue());
+        return true;
+      }
+
+  if (SplitBinaryAdd(More, L, R))
+    if (const auto *LC = dyn_cast<SCEVConstant>(L))
+      if (R == Less) {
+        C = LC->getValue()->getValue();
+        return true;
+      }
+
+  return false;
+}
+
+bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
+    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
+    const SCEV *FoundLHS, const SCEV *FoundRHS) {
+  if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
+    return false;
+
+  const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
+  if (!AddRecLHS)
+    return false;
+
+  const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
+  if (!AddRecFoundLHS)
+    return false;
+
+  // We'd like to let SCEV reason about control dependencies, so we constrain
+  // both the inequalities to be about add recurrences on the same loop.  This
+  // way we can use isLoopEntryGuardedByCond later.
+
+  const Loop *L = AddRecFoundLHS->getLoop();
+  if (L != AddRecLHS->getLoop())
+    return false;
+
+  //  FoundLHS u< FoundRHS u< -C =>  (FoundLHS + C) u< (FoundRHS + C) ... (1)
+  //
+  //  FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
+  //                                                                  ... (2)
+  //
+  // Informal proof for (2), assuming (1) [*]:
+  //
+  // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
+  //
+  // Then
+  //
+  //       FoundLHS s< FoundRHS s< INT_MIN - C
+  // <=>  (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C   [ using (3) ]
+  // <=>  (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
+  // <=>  (FoundLHS + INT_MIN + C + INT_MIN) s<
+  //                        (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
+  // <=>  FoundLHS + C s< FoundRHS + C
+  //
+  // [*]: (1) can be proved by ruling out overflow.
+  //
+  // [**]: This can be proved by analyzing all the four possibilities:
+  //    (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
+  //    (A s>= 0, B s>= 0).
+  //
+  // Note:
+  // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
+  // will not sign underflow.  For instance, say FoundLHS = (i8 -128), FoundRHS
+  // = (i8 -127) and C = (i8 -100).  Then INT_MIN - C = (i8 -28), and FoundRHS
+  // s< (INT_MIN - C).  Lack of sign overflow / underflow in "FoundRHS + C" is
+  // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
+  // C)".
+
+  APInt LDiff, RDiff;
+  if (!IsConstDiff(*this, FoundLHS, LHS, LDiff) ||
+      !IsConstDiff(*this, FoundRHS, RHS, RDiff) ||
+      LDiff != RDiff)
+    return false;
+
+  if (LDiff == 0)
+    return true;
+
+  unsigned Width = cast<IntegerType>(RHS->getType())->getBitWidth();
+  APInt FoundRHSLimit;
+
+  if (Pred == CmpInst::ICMP_ULT) {
+    FoundRHSLimit = -RDiff;
+  } else {
+    assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
+    FoundRHSLimit = APInt::getSignedMinValue(Width) - RDiff;
+  }
+
+  // Try to prove (1) or (2), as needed.
+  return isLoopEntryGuardedByCond(L, Pred, FoundRHS,
+                                  getConstant(FoundRHSLimit));
+}
+
 /// isImpliedCondOperands - Test whether the condition described by Pred,
 /// LHS, and RHS is true whenever the condition described by Pred, FoundLHS,
 /// and FoundRHS is true.
@@ -7291,6 +7430,9 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
   if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
     return true;
 
+  if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
+    return true;
+
   return isImpliedCondOperandsHelper(Pred, LHS, RHS,
                                      FoundLHS, FoundRHS) ||
          // ~x < ~y --> x > y
@@ -7510,7 +7652,7 @@ bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
   if (NoWrap) return false;
 
   unsigned BitWidth = getTypeSizeInBits(RHS->getType());
-  const SCEV *One = getConstant(Stride->getType(), 1);
+  const SCEV *One = getOne(Stride->getType());
 
   if (IsSigned) {
     APInt MaxRHS = getSignedRange(RHS).getSignedMax();
@@ -7539,7 +7681,7 @@ bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
   if (NoWrap) return false;
 
   unsigned BitWidth = getTypeSizeInBits(RHS->getType());
-  const SCEV *One = getConstant(Stride->getType(), 1);
+  const SCEV *One = getOne(Stride->getType());
 
   if (IsSigned) {
     APInt MinRHS = getSignedRange(RHS).getSignedMin();
@@ -7564,7 +7706,7 @@ bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
 // stride and presence of the equality in the comparison.
 const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
                                             bool Equality) {
-  const SCEV *One = getConstant(Step->getType(), 1);
+  const SCEV *One = getOne(Step->getType());
   Delta = Equality ? getAddExpr(Delta, Step)
                    : getAddExpr(Delta, getMinusSCEV(Step, One));
   return getUDivExpr(Delta, Step);
@@ -7753,7 +7895,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
     if (!SC->getValue()->isZero()) {
       SmallVector<const SCEV *, 4> Operands(op_begin(), op_end());
-      Operands[0] = SE.getConstant(SC->getType(), 0);
+      Operands[0] = SE.getZero(SC->getType());
       const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
                                              getNoWrapFlags(FlagNW));
       if (const SCEVAddRecExpr *ShiftedAddRec =
@@ -7778,7 +7920,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
   // iteration exits.
   unsigned BitWidth = SE.getTypeSizeInBits(getType());
   if (!Range.contains(APInt(BitWidth, 0)))
-    return SE.getConstant(getType(), 0);
+    return SE.getZero(getType());
 
   if (isAffine()) {
     // If this is an affine expression then we have this situation: