Create a wrapper pass for BranchProbabilityInfo.
[oota-llvm.git] / lib / Transforms / Scalar / InductiveRangeCheckElimination.cpp
index aa3a029fc882ffcbd046eb3f4c174bb75ce5c16c..08fdcc38c045d8a48545756efe2f1f0930d82b42 100644 (file)
@@ -122,8 +122,9 @@ class InductiveRangeCheck {
   BranchInst *Branch;
   RangeCheckKind Kind;
 
-  static RangeCheckKind parseRangeCheckICmp(ICmpInst *ICI, ScalarEvolution &SE,
-                                            Value *&Index, Value *&Length);
+  static RangeCheckKind parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
+                                            ScalarEvolution &SE, Value *&Index,
+                                            Value *&Length);
 
   static InductiveRangeCheck::RangeCheckKind
   parseRangeCheck(Loop *L, ScalarEvolution &SE, Value *Condition,
@@ -214,7 +215,7 @@ public:
     AU.addRequiredID(LoopSimplifyID);
     AU.addRequiredID(LCSSAID);
     AU.addRequired<ScalarEvolution>();
-    AU.addRequired<BranchProbabilityInfo>();
+    AU.addRequired<BranchProbabilityInfoWrapperPass>();
   }
 
   bool runOnLoop(Loop *L, LPPassManager &LPM) override;
@@ -255,8 +256,18 @@ const char *InductiveRangeCheck::rangeCheckKindToStr(
 /// RANGE_CHECK_UPPER.
 ///
 InductiveRangeCheck::RangeCheckKind
-InductiveRangeCheck::parseRangeCheckICmp(ICmpInst *ICI, ScalarEvolution &SE,
-                                         Value *&Index, Value *&Length) {
+InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
+                                         ScalarEvolution &SE, Value *&Index,
+                                         Value *&Length) {
+
+  auto IsNonNegativeAndNotLoopVarying = [&SE, L](Value *V) {
+    const SCEV *S = SE.getSCEV(V);
+    if (isa<SCEVCouldNotCompute>(S))
+      return false;
+
+    return SE.getLoopDisposition(S, L) == ScalarEvolution::LoopInvariant &&
+           SE.isKnownNonNegative(S);
+  };
 
   using namespace llvm::PatternMatch;
 
@@ -287,7 +298,7 @@ InductiveRangeCheck::parseRangeCheckICmp(ICmpInst *ICI, ScalarEvolution &SE,
       return RANGE_CHECK_LOWER;
     }
 
-    if (SE.isKnownNonNegative(SE.getSCEV(LHS))) {
+    if (IsNonNegativeAndNotLoopVarying(LHS)) {
       Index = RHS;
       Length = LHS;
       return RANGE_CHECK_UPPER;
@@ -298,7 +309,7 @@ InductiveRangeCheck::parseRangeCheckICmp(ICmpInst *ICI, ScalarEvolution &SE,
     std::swap(LHS, RHS);
   // fallthrough
   case ICmpInst::ICMP_UGT:
-    if (SE.isKnownNonNegative(SE.getSCEV(LHS))) {
+    if (IsNonNegativeAndNotLoopVarying(LHS)) {
       Index = RHS;
       Length = LHS;
       return RANGE_CHECK_BOTH;
@@ -328,8 +339,8 @@ InductiveRangeCheck::parseRangeCheck(Loop *L, ScalarEvolution &SE,
     if (!ICmpA || !ICmpB)
       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
 
-    auto RCKindA = parseRangeCheckICmp(ICmpA, SE, IndexA, LengthA);
-    auto RCKindB = parseRangeCheckICmp(ICmpB, SE, IndexB, LengthB);
+    auto RCKindA = parseRangeCheckICmp(L, ICmpA, SE, IndexA, LengthA);
+    auto RCKindB = parseRangeCheckICmp(L, ICmpB, SE, IndexB, LengthB);
 
     if (RCKindA == InductiveRangeCheck::RANGE_CHECK_UNKNOWN ||
         RCKindB == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
@@ -353,7 +364,7 @@ InductiveRangeCheck::parseRangeCheck(Loop *L, ScalarEvolution &SE,
   if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
     Value *IndexVal = nullptr;
 
-    auto RCKind = parseRangeCheckICmp(ICI, SE, IndexVal, Length);
+    auto RCKind = parseRangeCheckICmp(L, ICI, SE, IndexVal, Length);
 
     if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
@@ -696,30 +707,40 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP
     }
   }
 
-  auto IsInductionVar = [&SE](const SCEVAddRecExpr *AR, bool &IsIncreasing) {
-    if (!AR->isAffine())
-      return false;
+  auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
+    if (AR->getNoWrapFlags(SCEV::FlagNSW))
+      return true;
 
     IntegerType *Ty = cast<IntegerType>(AR->getType());
     IntegerType *WideTy =
         IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
 
-    // Currently we only work with induction variables that have been proved to
-    // not wrap.  This restriction can potentially be lifted in the future.
-
     const SCEVAddRecExpr *ExtendAfterOp =
         dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
-    if (!ExtendAfterOp)
-      return false;
+    if (ExtendAfterOp) {
+      const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
+      const SCEV *ExtendedStep =
+          SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);
 
-    const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
-    const SCEV *ExtendedStep =
-        SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);
+      bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
+                          ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;
 
-    bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
-                        ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;
+      if (NoSignedWrap)
+        return true;
+    }
+
+    // We may have proved this when computing the sign extension above.
+    return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
+  };
+
+  auto IsInductionVar = [&](const SCEVAddRecExpr *AR, bool &IsIncreasing) {
+    if (!AR->isAffine())
+      return false;
+
+    // Currently we only work with induction variables that have been proved to
+    // not wrap.  This restriction can potentially be lifted in the future.
 
-    if (!NoSignedWrap)
+    if (!HasNoSignedWrap(AR))
       return false;
 
     if (const SCEVConstant *StepExpr =
@@ -1379,7 +1400,8 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) {
   InductiveRangeCheck::AllocatorTy IRCAlloc;
   SmallVector<InductiveRangeCheck *, 16> RangeChecks;
   ScalarEvolution &SE = getAnalysis<ScalarEvolution>();
-  BranchProbabilityInfo &BPI = getAnalysis<BranchProbabilityInfo>();
+  BranchProbabilityInfo &BPI =
+      getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
 
   for (auto BBI : L->getBlocks())
     if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator()))