[mips] Merge disassemblers into a single implementation.
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index 44d0b11ac721615c1df1ca29897bfeb85b43b80f..9aefe8c33f7cbd7243b78b7638200c6e689aa61a 100644 (file)
@@ -68,6 +68,7 @@
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
@@ -87,7 +88,6 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
-#include "llvm/Target/TargetLibraryInfo.h"
 #include <algorithm>
 using namespace llvm;
 
@@ -117,9 +117,9 @@ VerifySCEV("verify-scev",
 INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution",
                 "Scalar Evolution Analysis", false, true)
 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(LoopInfo)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
 INITIALIZE_PASS_END(ScalarEvolution, "scalar-evolution",
                 "Scalar Evolution Analysis", false, true)
 char ScalarEvolution::ID = 0;
@@ -1364,7 +1364,24 @@ static const SCEV *getPreStartForSignExtend(const SCEVAddRecExpr *AR,
   const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
     SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
 
-  if (PreAR && PreAR->getNoWrapFlags(SCEV::FlagNSW))
+  // WARNING: FIXME: the optimization below assumes that a sign-overflowing nsw
+  // operation is undefined behavior.  This is strictly more aggressive than the
+  // interpretation of nsw in other parts of LLVM (for instance, they may
+  // unconditionally hoist nsw arithmetic through control flow).  This logic
+  // needs to be revisited once we have a consistent semantics for poison
+  // values.
+  //
+  // "{S,+,X} is <nsw>" and "{S,+,X} is evaluated at least once" implies "S+X
+  // does not sign-overflow" (we'd have undefined behavior if it did).  If
+  // `L->getExitingBlock() == L->getLoopLatch()` then `PreAR` (= {S,+,X}<nsw>)
+  // is evaluated every-time `AR` (= {S+X,+,X}) is evaluated, and hence within
+  // `AR` we are safe to assume that "S+X" will not sign-overflow.
+  //
+
+  BasicBlock *ExitingBlock = L->getExitingBlock();
+  BasicBlock *LatchBlock = L->getLoopLatch();
+  if (PreAR && PreAR->getNoWrapFlags(SCEV::FlagNSW) &&
+      ExitingBlock != nullptr && ExitingBlock == LatchBlock)
     return PreStart;
 
   // 2. Direct overflow check on the step operation's expression.
@@ -1533,8 +1550,16 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
                        getMulExpr(WideMaxBECount,
                                   getZeroExtendExpr(Step, WideTy)));
           if (SAdd == OperandExtendedAdd) {
-            // Cache knowledge of AR NSW, which is propagated to this AddRec.
-            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
+            // If AR wraps around then
+            //
+            //    abs(Step) * MaxBECount > unsigned-max(AR->getType())
+            // => SAdd != OperandExtendedAdd
+            //
+            // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
+            // (SAdd == OperandExtendedAdd => AR is NW)
+
+            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
+
             // Return the expression with the addrec on the outside.
             return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
                                  getZeroExtendExpr(Step, Ty),
@@ -1737,6 +1762,36 @@ namespace {
   };
 }
 
+// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
+// `OldFlags' as can't-wrap behavior.  Infer a more aggressive set of
+// can't-overflow flags for the operation if possible.
+static SCEV::NoWrapFlags
+StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
+                      const SmallVectorImpl<const SCEV *> &Ops,
+                      SCEV::NoWrapFlags OldFlags) {
+  using namespace std::placeholders;
+
+  bool CanAnalyze =
+      Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
+  (void)CanAnalyze;
+  assert(CanAnalyze && "don't call from other places!");
+
+  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
+  SCEV::NoWrapFlags SignOrUnsignWrap =
+      ScalarEvolution::maskFlags(OldFlags, SignOrUnsignMask);
+
+  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
+  auto IsKnownNonNegative =
+    std::bind(std::mem_fn(&ScalarEvolution::isKnownNonNegative), SE, _1);
+
+  if (SignOrUnsignWrap == SCEV::FlagNSW &&
+      std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative))
+    return ScalarEvolution::setFlags(OldFlags,
+                                     (SCEV::NoWrapFlags)SignOrUnsignMask);
+
+  return OldFlags;
+}
+
 /// getAddExpr - Get a canonical add expression, or something simpler if
 /// possible.
 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
@@ -1752,20 +1807,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
            "SCEVAddExpr operand types don't match!");
 #endif
 
-  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
-  // And vice-versa.
-  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
-  SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
-  if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
-    bool All = true;
-    for (SmallVectorImpl<const SCEV *>::const_iterator I = Ops.begin(),
-         E = Ops.end(); I != E; ++I)
-      if (!isKnownNonNegative(*I)) {
-        All = false;
-        break;
-      }
-    if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
-  }
+  Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags);
 
   // Sort by complexity, this groups all similar expression types together.
   GroupByComplexity(Ops, LI);
@@ -2174,20 +2216,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
            "SCEVMulExpr operand types don't match!");
 #endif
 
-  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
-  // And vice-versa.
-  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
-  SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
-  if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
-    bool All = true;
-    for (SmallVectorImpl<const SCEV *>::const_iterator I = Ops.begin(),
-         E = Ops.end(); I != E; ++I)
-      if (!isKnownNonNegative(*I)) {
-        All = false;
-        break;
-      }
-    if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
-  }
+  Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags);
 
   // Sort by complexity, this groups all similar expression types together.
   GroupByComplexity(Ops, LI);
@@ -2653,20 +2682,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
   // meaningful BE count at this point (and if we don't, we'd be stuck
   // with a SCEVCouldNotCompute as the cached BE count).
 
-  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
-  // And vice-versa.
-  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
-  SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
-  if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
-    bool All = true;
-    for (SmallVectorImpl<const SCEV *>::const_iterator I = Operands.begin(),
-         E = Operands.end(); I != E; ++I)
-      if (!isKnownNonNegative(*I)) {
-        All = false;
-        break;
-      }
-    if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
-  }
+  Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
 
   // Canonicalize nested AddRecs in by nesting them in order of loop depth.
   if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
@@ -3163,8 +3179,9 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
   if (LHS == RHS)
     return getConstant(LHS->getType(), 0);
 
-  // X - Y --> X + -Y
-  return getAddExpr(LHS, getNegativeSCEV(RHS), Flags);
+  // X - Y --> X + -Y.
+  // X -(nsw || nuw) Y --> X + -Y.
+  return getAddExpr(LHS, getNegativeSCEV(RHS));
 }
 
 /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
@@ -3470,12 +3487,10 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
                   if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr)))
                     Flags = setFlags(Flags, SCEV::FlagNUW);
                 }
-              } else if (const SubOperator *OBO =
-                           dyn_cast<SubOperator>(BEValueV)) {
-                if (OBO->hasNoUnsignedWrap())
-                  Flags = setFlags(Flags, SCEV::FlagNUW);
-                if (OBO->hasNoSignedWrap())
-                  Flags = setFlags(Flags, SCEV::FlagNSW);
+
+                // We cannot transfer nuw and nsw flags from subtraction
+                // operations -- sub nuw X, Y is not the same as add nuw X, -Y
+                // for instance.
               }
 
               const SCEV *StartVal = getSCEV(StartValueV);
@@ -4290,9 +4305,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       case ICmpInst::ICMP_SGE:
         // a >s b ? a+x : b+x  ->  smax(a, b)+x
         // a >s b ? b+x : a+x  ->  smin(a, b)+x
-        if (LHS->getType() == U->getType()) {
-          const SCEV *LS = getSCEV(LHS);
-          const SCEV *RS = getSCEV(RHS);
+        if (getTypeSizeInBits(LHS->getType()) <=
+            getTypeSizeInBits(U->getType())) {
+          const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), U->getType());
+          const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), U->getType());
           const SCEV *LA = getSCEV(U->getOperand(1));
           const SCEV *RA = getSCEV(U->getOperand(2));
           const SCEV *LDiff = getMinusSCEV(LA, LS);
@@ -4313,9 +4329,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       case ICmpInst::ICMP_UGE:
         // a >u b ? a+x : b+x  ->  umax(a, b)+x
         // a >u b ? b+x : a+x  ->  umin(a, b)+x
-        if (LHS->getType() == U->getType()) {
-          const SCEV *LS = getSCEV(LHS);
-          const SCEV *RS = getSCEV(RHS);
+        if (getTypeSizeInBits(LHS->getType()) <=
+            getTypeSizeInBits(U->getType())) {
+          const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
+          const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), U->getType());
           const SCEV *LA = getSCEV(U->getOperand(1));
           const SCEV *RA = getSCEV(U->getOperand(2));
           const SCEV *LDiff = getMinusSCEV(LA, LS);
@@ -4330,11 +4347,11 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
         break;
       case ICmpInst::ICMP_NE:
         // n != 0 ? n+x : 1+x  ->  umax(n, 1)+x
-        if (LHS->getType() == U->getType() &&
-            isa<ConstantInt>(RHS) &&
-            cast<ConstantInt>(RHS)->isZero()) {
-          const SCEV *One = getConstant(LHS->getType(), 1);
-          const SCEV *LS = getSCEV(LHS);
+        if (getTypeSizeInBits(LHS->getType()) <=
+                getTypeSizeInBits(U->getType()) &&
+            isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
+          const SCEV *One = getConstant(U->getType(), 1);
+          const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
           const SCEV *LA = getSCEV(U->getOperand(1));
           const SCEV *RA = getSCEV(U->getOperand(2));
           const SCEV *LDiff = getMinusSCEV(LA, LS);
@@ -4345,11 +4362,11 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
         break;
       case ICmpInst::ICMP_EQ:
         // n == 0 ? 1+x : n+x  ->  umax(n, 1)+x
-        if (LHS->getType() == U->getType() &&
-            isa<ConstantInt>(RHS) &&
-            cast<ConstantInt>(RHS)->isZero()) {
-          const SCEV *One = getConstant(LHS->getType(), 1);
-          const SCEV *LS = getSCEV(LHS);
+        if (getTypeSizeInBits(LHS->getType()) <=
+                getTypeSizeInBits(U->getType()) &&
+            isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
+          const SCEV *One = getConstant(U->getType(), 1);
+          const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
           const SCEV *LA = getSCEV(U->getOperand(1));
           const SCEV *RA = getSCEV(U->getOperand(2));
           const SCEV *LDiff = getMinusSCEV(LA, One);
@@ -7021,8 +7038,8 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
   return false;
 }
 
-// Verify if an linear IV with positive stride can overflow when in a 
-// less-than comparison, knowing the invariant term of the comparison, the 
+// Verify if an linear IV with positive stride can overflow when in a
+// less-than comparison, knowing the invariant term of the comparison, the
 // stride and the knowledge of NSW/NUW flags on the recurrence.
 bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
                                          bool IsSigned, bool NoWrap) {
@@ -7050,7 +7067,7 @@ bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
   return (MaxValue - MaxStrideMinusOne).ult(MaxRHS);
 }
 
-// Verify if an linear IV with negative stride can overflow when in a 
+// Verify if an linear IV with negative stride can overflow when in a
 // greater-than comparison, knowing the invariant term of the comparison,
 // the stride and the knowledge of NSW/NUW flags on the recurrence.
 bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
@@ -7081,7 +7098,7 @@ bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
 
 // Compute the backedge taken count knowing the interval difference, the
 // stride and presence of the equality in the comparison.
-const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, 
+const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
                                             bool Equality) {
   const SCEV *One = getConstant(Step->getType(), 1);
   Delta = Equality ? getAddExpr(Delta, Step)
@@ -7121,7 +7138,7 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
 
   // Avoid proven overflow cases: this will ensure that the backedge taken count
   // will not generate any unsigned overflow. Relaxed no-overflow conditions
-  // exploit NoWrapFlags, allowing to optimize in presence of undefined 
+  // exploit NoWrapFlags, allowing to optimize in presence of undefined
   // behaviors like the case of C language.
   if (!Stride->isOne() && doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap))
     return getCouldNotCompute();
@@ -7201,7 +7218,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
 
   // Avoid proven overflow cases: this will ensure that the backedge taken count
   // will not generate any unsigned overflow. Relaxed no-overflow conditions
-  // exploit NoWrapFlags, allowing to optimize in presence of undefined 
+  // exploit NoWrapFlags, allowing to optimize in presence of undefined
   // behaviors like the case of C language.
   if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap))
     return getCouldNotCompute();
@@ -7249,7 +7266,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
   if (isa<SCEVConstant>(BECount))
     MaxBECount = BECount;
   else
-    MaxBECount = computeBECount(getConstant(MaxStart - MinEnd), 
+    MaxBECount = computeBECount(getConstant(MaxStart - MinEnd),
                                 getConstant(MinStride), false);
 
   if (isa<SCEVCouldNotCompute>(MaxBECount))
@@ -7876,10 +7893,10 @@ ScalarEvolution::ScalarEvolution()
 bool ScalarEvolution::runOnFunction(Function &F) {
   this->F = &F;
   AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
-  LI = &getAnalysis<LoopInfo>();
+  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
   DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
   DL = DLP ? &DLP->getDataLayout() : nullptr;
-  TLI = &getAnalysis<TargetLibraryInfo>();
+  TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
   DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
   return false;
 }
@@ -7917,9 +7934,9 @@ void ScalarEvolution::releaseMemory() {
 void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesAll();
   AU.addRequired<AssumptionCacheTracker>();
-  AU.addRequiredTransitive<LoopInfo>();
+  AU.addRequiredTransitive<LoopInfoWrapperPass>();
   AU.addRequiredTransitive<DominatorTreeWrapperPass>();
-  AU.addRequired<TargetLibraryInfo>();
+  AU.addRequired<TargetLibraryInfoWrapperPass>();
 }
 
 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
@@ -8010,17 +8027,17 @@ void ScalarEvolution::print(raw_ostream &OS, const Module *) const {
 
 ScalarEvolution::LoopDisposition
 ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
-  SmallVector<std::pair<const Loop *, LoopDisposition>, 2> &Values = LoopDispositions[S];
-  for (unsigned u = 0; u < Values.size(); u++) {
-    if (Values[u].first == L)
-      return Values[u].second;
+  auto &Values = LoopDispositions[S];
+  for (auto &V : Values) {
+    if (V.getPointer() == L)
+      return V.getInt();
   }
-  Values.push_back(std::make_pair(L, LoopVariant));
+  Values.emplace_back(L, LoopVariant);
   LoopDisposition D = computeLoopDisposition(S, L);
-  SmallVector<std::pair<const Loop *, LoopDisposition>, 2> &Values2 = LoopDispositions[S];
-  for (unsigned u = Values2.size(); u > 0; u--) {
-    if (Values2[u - 1].first == L) {
-      Values2[u - 1].second = D;
+  auto &Values2 = LoopDispositions[S];
+  for (auto &V : make_range(Values2.rbegin(), Values2.rend())) {
+    if (V.getPointer() == L) {
+      V.setInt(D);
       break;
     }
   }
@@ -8116,17 +8133,17 @@ bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
 
 ScalarEvolution::BlockDisposition
 ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
-  SmallVector<std::pair<const BasicBlock *, BlockDisposition>, 2> &Values = BlockDispositions[S];
-  for (unsigned u = 0; u < Values.size(); u++) {
-    if (Values[u].first == BB)
-      return Values[u].second;
+  auto &Values = BlockDispositions[S];
+  for (auto &V : Values) {
+    if (V.getPointer() == BB)
+      return V.getInt();
   }
-  Values.push_back(std::make_pair(BB, DoesNotDominateBlock));
+  Values.emplace_back(BB, DoesNotDominateBlock);
   BlockDisposition D = computeBlockDisposition(S, BB);
-  SmallVector<std::pair<const BasicBlock *, BlockDisposition>, 2> &Values2 = BlockDispositions[S];
-  for (unsigned u = Values2.size(); u > 0; u--) {
-    if (Values2[u - 1].first == BB) {
-      Values2[u - 1].second = D;
+  auto &Values2 = BlockDispositions[S];
+  for (auto &V : make_range(Values2.rbegin(), Values2.rend())) {
+    if (V.getPointer() == BB) {
+      V.setInt(D);
       break;
     }
   }