Create a wrapper pass for BranchProbabilityInfo.
[oota-llvm.git] / lib / Transforms / Scalar / InductiveRangeCheckElimination.cpp
index 22ce7119cb70259d4c417dce50bf4019a1cea2ec..08fdcc38c045d8a48545756efe2f1f0930d82b42 100644 (file)
@@ -42,7 +42,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/Optional.h"
-
 #include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/ValueTracking.h"
-
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
-#include "llvm/IR/Instructions.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/ValueHandle.h"
 #include "llvm/IR/Verifier.h"
-
+#include "llvm/Pass.h"
 #include "llvm/Support/Debug.h"
-
+#include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
 #include "llvm/Transforms/Utils/SimplifyIndVar.h"
 #include "llvm/Transforms/Utils/UnrollLoop.h"
-
-#include "llvm/Pass.h"
-
 #include <array>
 
 using namespace llvm;
@@ -82,6 +77,9 @@ static cl::opt<unsigned> LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden,
 static cl::opt<bool> PrintChangedLoops("irce-print-changed-loops", cl::Hidden,
                                        cl::init(false));
 
+static cl::opt<bool> PrintRangeChecks("irce-print-range-checks", cl::Hidden,
+                                      cl::init(false));
+
 static cl::opt<int> MaxExitProbReciprocal("irce-max-exit-prob-reciprocal",
                                           cl::Hidden, cl::init(10));
 
@@ -101,7 +99,7 @@ namespace {
 ///
 class InductiveRangeCheck {
   // Classifies a range check
-  enum RangeCheckKind {
+  enum RangeCheckKind : unsigned {
     // Range check of the form "0 <= I".
     RANGE_CHECK_LOWER = 1,
 
@@ -124,8 +122,9 @@ class InductiveRangeCheck {
   BranchInst *Branch;
   RangeCheckKind Kind;
 
-  static RangeCheckKind parseRangeCheckICmp(ICmpInst *ICI, ScalarEvolution &SE,
-                                            Value *&Index, Value *&Length);
+  static RangeCheckKind parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
+                                            ScalarEvolution &SE, Value *&Index,
+                                            Value *&Length);
 
   static InductiveRangeCheck::RangeCheckKind
   parseRangeCheck(Loop *L, ScalarEvolution &SE, Value *Condition,
@@ -216,7 +215,7 @@ public:
     AU.addRequiredID(LoopSimplifyID);
     AU.addRequiredID(LCSSAID);
     AU.addRequired<ScalarEvolution>();
-    AU.addRequired<BranchProbabilityInfo>();
+    AU.addRequired<BranchProbabilityInfoWrapperPass>();
   }
 
   bool runOnLoop(Loop *L, LPPassManager &LPM) override;
@@ -257,8 +256,18 @@ const char *InductiveRangeCheck::rangeCheckKindToStr(
 /// RANGE_CHECK_UPPER.
 ///
 InductiveRangeCheck::RangeCheckKind
-InductiveRangeCheck::parseRangeCheckICmp(ICmpInst *ICI, ScalarEvolution &SE,
-                                         Value *&Index, Value *&Length) {
+InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
+                                         ScalarEvolution &SE, Value *&Index,
+                                         Value *&Length) {
+
+  auto IsNonNegativeAndNotLoopVarying = [&SE, L](Value *V) {
+    const SCEV *S = SE.getSCEV(V);
+    if (isa<SCEVCouldNotCompute>(S))
+      return false;
+
+    return SE.getLoopDisposition(S, L) == ScalarEvolution::LoopInvariant &&
+           SE.isKnownNonNegative(S);
+  };
 
   using namespace llvm::PatternMatch;
 
@@ -289,7 +298,7 @@ InductiveRangeCheck::parseRangeCheckICmp(ICmpInst *ICI, ScalarEvolution &SE,
       return RANGE_CHECK_LOWER;
     }
 
-    if (SE.isKnownNonNegative(SE.getSCEV(LHS))) {
+    if (IsNonNegativeAndNotLoopVarying(LHS)) {
       Index = RHS;
       Length = LHS;
       return RANGE_CHECK_UPPER;
@@ -300,7 +309,7 @@ InductiveRangeCheck::parseRangeCheckICmp(ICmpInst *ICI, ScalarEvolution &SE,
     std::swap(LHS, RHS);
   // fallthrough
   case ICmpInst::ICMP_UGT:
-    if (SE.isKnownNonNegative(SE.getSCEV(LHS))) {
+    if (IsNonNegativeAndNotLoopVarying(LHS)) {
       Index = RHS;
       Length = LHS;
       return RANGE_CHECK_BOTH;
@@ -330,8 +339,8 @@ InductiveRangeCheck::parseRangeCheck(Loop *L, ScalarEvolution &SE,
     if (!ICmpA || !ICmpB)
       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
 
-    auto RCKindA = parseRangeCheckICmp(ICmpA, SE, IndexA, LengthA);
-    auto RCKindB = parseRangeCheckICmp(ICmpB, SE, IndexB, LengthB);
+    auto RCKindA = parseRangeCheckICmp(L, ICmpA, SE, IndexA, LengthA);
+    auto RCKindB = parseRangeCheckICmp(L, ICmpB, SE, IndexB, LengthB);
 
     if (RCKindA == InductiveRangeCheck::RANGE_CHECK_UNKNOWN ||
         RCKindB == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
@@ -355,7 +364,7 @@ InductiveRangeCheck::parseRangeCheck(Loop *L, ScalarEvolution &SE,
   if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
     Value *IndexVal = nullptr;
 
-    auto RCKind = parseRangeCheckICmp(ICI, SE, IndexVal, Length);
+    auto RCKind = parseRangeCheckICmp(L, ICI, SE, IndexVal, Length);
 
     if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
@@ -394,8 +403,8 @@ InductiveRangeCheck::create(InductiveRangeCheck::AllocatorTy &A, BranchInst *BI,
     return nullptr;
 
   assert(IndexSCEV && "contract with SplitRangeCheckCondition!");
-  assert(!(RCKind & InductiveRangeCheck::RANGE_CHECK_UPPER) ||
-         Length && "contract with SplitRangeCheckCondition!");
+  assert((!(RCKind & InductiveRangeCheck::RANGE_CHECK_UPPER) || Length) &&
+         "contract with SplitRangeCheckCondition!");
 
   const SCEVAddRecExpr *IndexAddRec = dyn_cast<SCEVAddRecExpr>(IndexSCEV);
   bool IsAffineIndex =
@@ -698,30 +707,40 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP
     }
   }
 
-  auto IsInductionVar = [&SE](const SCEVAddRecExpr *AR, bool &IsIncreasing) {
-    if (!AR->isAffine())
-      return false;
+  auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
+    if (AR->getNoWrapFlags(SCEV::FlagNSW))
+      return true;
 
     IntegerType *Ty = cast<IntegerType>(AR->getType());
     IntegerType *WideTy =
         IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
 
-    // Currently we only work with induction variables that have been proved to
-    // not wrap.  This restriction can potentially be lifted in the future.
-
     const SCEVAddRecExpr *ExtendAfterOp =
         dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
-    if (!ExtendAfterOp)
-      return false;
+    if (ExtendAfterOp) {
+      const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
+      const SCEV *ExtendedStep =
+          SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);
 
-    const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
-    const SCEV *ExtendedStep =
-        SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);
+      bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
+                          ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;
 
-    bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
-                        ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;
+      if (NoSignedWrap)
+        return true;
+    }
+
+    // We may have proved this when computing the sign extension above.
+    return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
+  };
 
-    if (!NoSignedWrap)
+  auto IsInductionVar = [&](const SCEVAddRecExpr *AR, bool &IsIncreasing) {
+    if (!AR->isAffine())
+      return false;
+
+    // Currently we only work with induction variables that have been proved to
+    // not wrap.  This restriction can potentially be lifted in the future.
+
+    if (!HasNoSignedWrap(AR))
       return false;
 
     if (const SCEVConstant *StepExpr =
@@ -1381,7 +1400,8 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) {
   InductiveRangeCheck::AllocatorTy IRCAlloc;
   SmallVector<InductiveRangeCheck *, 16> RangeChecks;
   ScalarEvolution &SE = getAnalysis<ScalarEvolution>();
-  BranchProbabilityInfo &BPI = getAnalysis<BranchProbabilityInfo>();
+  BranchProbabilityInfo &BPI =
+      getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
 
   for (auto BBI : L->getBlocks())
     if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator()))
@@ -1392,12 +1412,18 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) {
   if (RangeChecks.empty())
     return false;
 
-  DEBUG(dbgs() << "irce: looking at loop "; L->print(dbgs());
-        dbgs() << "irce: loop has " << RangeChecks.size()
-               << " inductive range checks: \n";
-        for (InductiveRangeCheck *IRC : RangeChecks)
-          IRC->print(dbgs());
-    );
+  auto PrintRecognizedRangeChecks = [&](raw_ostream &OS) {
+    OS << "irce: looking at loop "; L->print(OS);
+    OS << "irce: loop has " << RangeChecks.size()
+       << " inductive range checks: \n";
+    for (InductiveRangeCheck *IRC : RangeChecks)
+      IRC->print(OS);
+  };
+
+  DEBUG(PrintRecognizedRangeChecks(dbgs()));
+
+  if (PrintRangeChecks)
+    PrintRecognizedRangeChecks(errs());
 
   const char *FailureReason = nullptr;
   Optional<LoopStructure> MaybeLoopStructure =