[SCEV] Teach SCEV some axioms about non-wrapping arithmetic
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index d3574b97bf6c1edb0f3217b8e7066ca8f90e64e0..d30e982c8d5cb206aefed0ef7c68bd0e2cba48dd 100644 (file)
@@ -1303,9 +1303,9 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
       ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
 
   if (OverflowLimit &&
-      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) {
+      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
     return PreStart;
-  }
+
   return nullptr;
 }
 
@@ -1631,6 +1631,16 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
         }
       }
     }
+
+    // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
+    if (SA->getNoWrapFlags(SCEV::FlagNSW)) {
+      // If the addition does not sign overflow then we can, by definition,
+      // commute the sign extension with the addition operation.
+      SmallVector<const SCEV *, 4> Ops;
+      for (const auto *Op : SA->operands())
+        Ops.push_back(getSignExtendExpr(Op, Ty));
+      return getAddExpr(Ops, SCEV::FlagNSW);
+    }
   }
   // If the input value is a chrec scev, and we can prove that the value
   // did not overflow the old, smaller, value, we can sign extend all of the
@@ -1921,8 +1931,9 @@ namespace {
 static SCEV::NoWrapFlags
 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
                       const SmallVectorImpl<const SCEV *> &Ops,
-                      SCEV::NoWrapFlags OldFlags) {
+                      SCEV::NoWrapFlags Flags) {
   using namespace std::placeholders;
+  typedef OverflowingBinaryOperator OBO;
 
   bool CanAnalyze =
       Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
@@ -1931,7 +1942,7 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
 
   int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
   SCEV::NoWrapFlags SignOrUnsignWrap =
-      ScalarEvolution::maskFlags(OldFlags, SignOrUnsignMask);
+      ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
 
   // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
   auto IsKnownNonNegative =
@@ -1939,10 +1950,34 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
 
   if (SignOrUnsignWrap == SCEV::FlagNSW &&
       std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative))
-    return ScalarEvolution::setFlags(OldFlags,
-                                     (SCEV::NoWrapFlags)SignOrUnsignMask);
+    Flags =
+        ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
+
+  SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
+
+  if (SignOrUnsignWrap != SignOrUnsignMask && Type == scAddExpr &&
+      Ops.size() == 2 && isa<SCEVConstant>(Ops[0])) {
+
+    // (A + C) --> (A + C)<nsw> if the addition does not sign overflow
+    // (A + C) --> (A + C)<nuw> if the addition does not unsign overflow
 
-  return OldFlags;
+    const APInt &C = cast<SCEVConstant>(Ops[0])->getValue()->getValue();
+    if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
+      auto NSWRegion =
+        ConstantRange::makeNoWrapRegion(Instruction::Add, C, OBO::NoSignedWrap);
+      if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
+        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
+    }
+    if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
+      auto NUWRegion =
+        ConstantRange::makeNoWrapRegion(Instruction::Add, C,
+                                        OBO::NoUnsignedWrap);
+      if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
+        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
+    }
+  }
+
+  return Flags;
 }
 
 /// getAddExpr - Get a canonical add expression, or something simpler if
@@ -2043,8 +2078,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
               break;
             }
             LargeMulOps.push_back(T->getOperand());
-          } else if (const SCEVConstant *C =
-                       dyn_cast<SCEVConstant>(M->getOperand(j))) {
+          } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
             LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
           } else {
             Ok = false;
@@ -2421,9 +2455,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
           }
           if (AnyFolded)
             return getAddExpr(NewOps);
-        }
-        else if (const SCEVAddRecExpr *
-                 AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
+        } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
           // Negation preserves a recurrence's no self-wrap property.
           SmallVector<const SCEV *, 4> Operands;
           for (SCEVAddRecExpr::op_iterator I = AddRec->op_begin(),
@@ -2748,8 +2780,7 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
   if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
     // If the mulexpr multiplies by a constant, then that constant must be the
     // first element of the mulexpr.
-    if (const SCEVConstant *LHSCst =
-            dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
+    if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
       if (LHSCst == RHSCst) {
         SmallVector<const SCEV *, 2> Operands;
         Operands.append(Mul->op_begin() + 1, Mul->op_end());
@@ -2848,12 +2879,10 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
       // AddRecs require their operands be loop-invariant with respect to their
       // loops. Don't perform this transformation if it would break this
       // requirement.
-      bool AllInvariant = true;
-      for (unsigned i = 0, e = Operands.size(); i != e; ++i)
-        if (!isLoopInvariant(Operands[i], L)) {
-          AllInvariant = false;
-          break;
-        }
+      bool AllInvariant =
+          std::all_of(Operands.begin(), Operands.end(),
+                      [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
+
       if (AllInvariant) {
         // Create a recurrence for the outer loop with the same step size.
         //
@@ -2863,12 +2892,10 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
           maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
 
         NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
-        AllInvariant = true;
-        for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i)
-          if (!isLoopInvariant(NestedOperands[i], NestedLoop)) {
-            AllInvariant = false;
-            break;
-          }
+        AllInvariant = std::all_of(
+            NestedOperands.begin(), NestedOperands.end(),
+            [&](const SCEV *Op) { return isLoopInvariant(Op, NestedLoop); });
+
         if (AllInvariant) {
           // Ok, both add recurrences are valid after the transformation.
           //
@@ -3245,9 +3272,8 @@ uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
   assert(isSCEVable(Ty) && "Type is not SCEVable!");
 
-  if (Ty->isIntegerTy()) {
+  if (Ty->isIntegerTy())
     return Ty;
-  }
 
   // The only other support type is pointer.
   assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
@@ -3522,8 +3548,7 @@ const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
 
   if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
     return getPointerBase(Cast->getOperand());
-  }
-  else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
+  } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
     const SCEV *PtrOp = nullptr;
     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
          I != E; ++I) {
@@ -3567,8 +3592,7 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
     if (!Visited.insert(I).second)
       continue;
 
-    ValueExprMapType::iterator It =
-      ValueExprMap.find_as(static_cast<Value *>(I));
+    auto It = ValueExprMap.find_as(static_cast<Value *>(I));
     if (It != ValueExprMap.end()) {
       const SCEV *Old = It->second;
 
@@ -3708,8 +3732,7 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
           return PHISCEV;
         }
       }
-    } else if (const SCEVAddRecExpr *AddRec =
-                   dyn_cast<SCEVAddRecExpr>(BEValue)) {
+    } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(BEValue)) {
       // Otherwise, this could be a loop like this:
       //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
       // In this case, j = {1,+,1}  and BEValue is j.
@@ -3798,8 +3821,8 @@ static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
 
       case scUDivExpr:
       case scCouldNotCompute:
-      // We do not try to smart about these at all.
-      return setUnavailable();
+        // We do not try to smart about these at all.
+        return setUnavailable();
       }
       llvm_unreachable("switch should be fully covered!");
     }
@@ -5656,9 +5679,8 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
   Instruction *I = dyn_cast<Instruction>(V);
   if (!I || !canConstantEvolve(I, L)) return nullptr;
 
-  if (PHINode *PN = dyn_cast<PHINode>(I)) {
+  if (PHINode *PN = dyn_cast<PHINode>(I))
     return PN;
-  }
 
   // Record non-constant instructions contained by the loop.
   DenseMap<Instruction *, PHINode *> PHIMap;
@@ -7140,6 +7162,60 @@ ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred,
   return false;
 }
 
+bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
+                                                    const SCEV *LHS,
+                                                    const SCEV *RHS) {
+
+  // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer.
+  // Return Y via OutY.
+  auto MatchBinaryAddToConst =
+      [this](const SCEV *Result, const SCEV *X, APInt &OutY,
+             SCEV::NoWrapFlags ExpectedFlags) {
+    const SCEV *NonConstOp, *ConstOp;
+    SCEV::NoWrapFlags FlagsPresent;
+
+    if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) ||
+        !isa<SCEVConstant>(ConstOp) || NonConstOp != X)
+      return false;
+
+    OutY = cast<SCEVConstant>(ConstOp)->getValue()->getValue();
+    return (FlagsPresent & ExpectedFlags) != 0;
+  };
+
+  APInt C;
+
+  switch (Pred) {
+  default:
+    break;
+
+  case ICmpInst::ICMP_SGE:
+    std::swap(LHS, RHS);
+  case ICmpInst::ICMP_SLE:
+    // X s<= (X + C)<nsw> if C >= 0
+    if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative())
+      return true;
+
+    // (X + C)<nsw> s<= X if C <= 0
+    if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
+        !C.isStrictlyPositive())
+      return true;
+
+  case ICmpInst::ICMP_SGT:
+    std::swap(LHS, RHS);
+  case ICmpInst::ICMP_SLT:
+    // X s< (X + C)<nsw> if C > 0
+    if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
+        C.isStrictlyPositive())
+      return true;
+
+    // (X + C)<nsw> s< X if C < 0
+    if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
+      return true;
+  }
+
+  return false;
+}
+
 bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
                                                    const SCEV *LHS,
                                                    const SCEV *RHS) {
@@ -7789,8 +7865,9 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
   auto IsKnownPredicateFull =
       [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
     return isKnownPredicateWithRanges(Pred, LHS, RHS) ||
-        IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
-        IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS);
+           IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
+           IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
+           isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
   };
 
   switch (Pred) {
@@ -8124,8 +8201,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
       Operands[0] = SE.getZero(SC->getType());
       const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
                                              getNoWrapFlags(FlagNW));
-      if (const SCEVAddRecExpr *ShiftedAddRec =
-            dyn_cast<SCEVAddRecExpr>(Shifted))
+      if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
         return ShiftedAddRec->getNumIterationsInRange(
                            Range.subtract(SC->getValue()->getValue()), SE);
       // This is strange and shouldn't happen.
@@ -8134,10 +8210,9 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
 
   // The only time we can solve this is when we have all constant indices.
   // Otherwise, we cannot determine the overflow conditions.
-  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
-    if (!isa<SCEVConstant>(getOperand(i)))
-      return SE.getCouldNotCompute();
-
+  if (std::any_of(op_begin(), op_end(),
+                  [](const SCEV *Op) { return !isa<SCEVConstant>(Op);}))
+    return SE.getCouldNotCompute();
 
   // Okay at this point we know that all elements of the chrec are constants and
   // that the start element is zero.