X-Git-Url: http://plrg.eecs.uci.edu/git/?p=oota-llvm.git;a=blobdiff_plain;f=lib%2FTransforms%2FScalar%2FInductiveRangeCheckElimination.cpp;h=08fdcc38c045d8a48545756efe2f1f0930d82b42;hp=809e9ee99c12ecffca5cddd4e4a9c4c8aea90cf9;hb=8770f7af5f46c0d34a79cf0beeeef80b1a2ab690;hpb=e003f1ac8cb8e921b50eae9a997dfc9258cc998f diff --git a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 809e9ee99c1..08fdcc38c04 100644 --- a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -42,7 +42,6 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/Optional.h" - #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" @@ -51,27 +50,23 @@ #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ValueTracking.h" - #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" -#include "llvm/IR/Instructions.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/ValueHandle.h" #include "llvm/IR/Verifier.h" - +#include "llvm/Pass.h" #include "llvm/Support/Debug.h" - +#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SimplifyIndVar.h" #include "llvm/Transforms/Utils/UnrollLoop.h" - -#include "llvm/Pass.h" - #include using namespace llvm; @@ -82,6 +77,12 @@ static cl::opt LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden, static cl::opt PrintChangedLoops("irce-print-changed-loops", cl::Hidden, cl::init(false)); +static cl::opt PrintRangeChecks("irce-print-range-checks", cl::Hidden, + cl::init(false)); + +static cl::opt MaxExitProbReciprocal("irce-max-exit-prob-reciprocal", + cl::Hidden, cl::init(10)); + #define DEBUG_TYPE "irce" namespace { @@ -93,23 +94,41 @@ namespace { /// /// and /// -/// 2. a condition that is provably true for some range of values taken by the -/// containing loop's induction variable. +/// 2. a condition that is provably true for some contiguous range of values +/// taken by the containing loop's induction variable. /// -/// Currently all inductive range checks are branches conditional on an -/// expression of the form -/// -/// 0 <= (Offset + Scale * I) < Length -/// -/// where `I' is the canonical induction variable of a loop to which Offset and -/// Scale are loop invariant, and Length is >= 0. Currently the 'false' branch -/// is considered cold, looking at profiling data to verify that is a TODO. - class InductiveRangeCheck { + // Classifies a range check + enum RangeCheckKind : unsigned { + // Range check of the form "0 <= I". + RANGE_CHECK_LOWER = 1, + + // Range check of the form "I < L" where L is known positive. + RANGE_CHECK_UPPER = 2, + + // The logical and of the RANGE_CHECK_LOWER and RANGE_CHECK_UPPER + // conditions. + RANGE_CHECK_BOTH = RANGE_CHECK_LOWER | RANGE_CHECK_UPPER, + + // Unrecognized range check condition. + RANGE_CHECK_UNKNOWN = (unsigned)-1 + }; + + static const char *rangeCheckKindToStr(RangeCheckKind); + const SCEV *Offset; const SCEV *Scale; Value *Length; BranchInst *Branch; + RangeCheckKind Kind; + + static RangeCheckKind parseRangeCheckICmp(Loop *L, ICmpInst *ICI, + ScalarEvolution &SE, Value *&Index, + Value *&Length); + + static InductiveRangeCheck::RangeCheckKind + parseRangeCheck(Loop *L, ScalarEvolution &SE, Value *Condition, + const SCEV *&Index, Value *&UpperLimit); InductiveRangeCheck() : Offset(nullptr), Scale(nullptr), Length(nullptr), Branch(nullptr) { } @@ -121,14 +140,19 @@ public: void print(raw_ostream &OS) const { OS << "InductiveRangeCheck:\n"; + OS << " Kind: " << rangeCheckKindToStr(Kind) << "\n"; OS << " Offset: "; Offset->print(OS); OS << " Scale: "; Scale->print(OS); OS << " Length: "; - Length->print(OS); - OS << " Branch: "; + if (Length) + Length->print(OS); + else + OS << "(null)"; + OS << "\n Branch: "; getBranch()->print(OS); + OS << "\n"; } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -143,17 +167,17 @@ public: /// R.getEnd() sle R.getBegin(), then R denotes the empty range. class Range { - Value *Begin; - Value *End; + const SCEV *Begin; + const SCEV *End; public: - Range(Value *Begin, Value *End) : Begin(Begin), End(End) { + Range(const SCEV *Begin, const SCEV *End) : Begin(Begin), End(End) { assert(Begin->getType() == End->getType() && "ill-typed range!"); } Type *getType() const { return Begin->getType(); } - Value *getBegin() const { return Begin; } - Value *getEnd() const { return End; } + const SCEV *getBegin() const { return Begin; } + const SCEV *getEnd() const { return End; } }; typedef SpecificBumpPtrAllocator AllocatorTy; @@ -162,9 +186,11 @@ public: /// branch to take the hot successor (see (1) above). bool getPassingDirection() { return true; } - /// Computes a range for the induction variable in which the range check is - /// redundant and can be constant-folded away. + /// Computes a range for the induction variable (IndVar) in which the range + /// check is redundant and can be constant-folded away. The induction + /// variable is not required to be the canonical {0,+,1} induction variable. Optional computeSafeIterationSpace(ScalarEvolution &SE, + const SCEVAddRecExpr *IndVar, IRBuilder<> &B) const; /// Create an inductive range check out of BI if possible, else return @@ -189,7 +215,7 @@ public: AU.addRequiredID(LoopSimplifyID); AU.addRequiredID(LCSSAID); AU.addRequired(); - AU.addRequired(); + AU.addRequired(); } bool runOnLoop(Loop *L, LPPassManager &LPM) override; @@ -201,160 +227,156 @@ char InductiveRangeCheckElimination::ID = 0; INITIALIZE_PASS(InductiveRangeCheckElimination, "irce", "Inductive range check elimination", false, false) -static bool IsLowerBoundCheck(Value *Check, Value *&IndexV) { - using namespace llvm::PatternMatch; +const char *InductiveRangeCheck::rangeCheckKindToStr( + InductiveRangeCheck::RangeCheckKind RCK) { + switch (RCK) { + case InductiveRangeCheck::RANGE_CHECK_UNKNOWN: + return "RANGE_CHECK_UNKNOWN"; - ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; - Value *LHS = nullptr, *RHS = nullptr; + case InductiveRangeCheck::RANGE_CHECK_UPPER: + return "RANGE_CHECK_UPPER"; - if (!match(Check, m_ICmp(Pred, m_Value(LHS), m_Value(RHS)))) - return false; + case InductiveRangeCheck::RANGE_CHECK_LOWER: + return "RANGE_CHECK_LOWER"; + + case InductiveRangeCheck::RANGE_CHECK_BOTH: + return "RANGE_CHECK_BOTH"; + } + + llvm_unreachable("unknown range check type!"); +} + +/// Parse a single ICmp instruction, `ICI`, into a range check. If `ICI` +/// cannot +/// be interpreted as a range check, return `RANGE_CHECK_UNKNOWN` and set +/// `Index` and `Length` to `nullptr`. Otherwise set `Index` to the value +/// being +/// range checked, and set `Length` to the upper limit `Index` is being range +/// checked with if (and only if) the range check type is stronger or equal to +/// RANGE_CHECK_UPPER. +/// +InductiveRangeCheck::RangeCheckKind +InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, + ScalarEvolution &SE, Value *&Index, + Value *&Length) { + + auto IsNonNegativeAndNotLoopVarying = [&SE, L](Value *V) { + const SCEV *S = SE.getSCEV(V); + if (isa(S)) + return false; + + return SE.getLoopDisposition(S, L) == ScalarEvolution::LoopInvariant && + SE.isKnownNonNegative(S); + }; + + using namespace llvm::PatternMatch; + + ICmpInst::Predicate Pred = ICI->getPredicate(); + Value *LHS = ICI->getOperand(0); + Value *RHS = ICI->getOperand(1); switch (Pred) { default: - return false; + return RANGE_CHECK_UNKNOWN; case ICmpInst::ICMP_SLE: std::swap(LHS, RHS); // fallthrough case ICmpInst::ICMP_SGE: - if (!match(RHS, m_ConstantInt<0>())) - return false; - IndexV = LHS; - return true; + if (match(RHS, m_ConstantInt<0>())) { + Index = LHS; + return RANGE_CHECK_LOWER; + } + return RANGE_CHECK_UNKNOWN; case ICmpInst::ICMP_SLT: std::swap(LHS, RHS); // fallthrough case ICmpInst::ICMP_SGT: - if (!match(RHS, m_ConstantInt<-1>())) - return false; - IndexV = LHS; - return true; - } -} - -static bool IsUpperBoundCheck(Value *Check, Value *Index, Value *&UpperLimit) { - using namespace llvm::PatternMatch; - - ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; - Value *LHS = nullptr, *RHS = nullptr; - - if (!match(Check, m_ICmp(Pred, m_Value(LHS), m_Value(RHS)))) - return false; + if (match(RHS, m_ConstantInt<-1>())) { + Index = LHS; + return RANGE_CHECK_LOWER; + } - switch (Pred) { - default: - return false; + if (IsNonNegativeAndNotLoopVarying(LHS)) { + Index = RHS; + Length = LHS; + return RANGE_CHECK_UPPER; + } + return RANGE_CHECK_UNKNOWN; - case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_ULT: std::swap(LHS, RHS); // fallthrough - case ICmpInst::ICMP_SLT: - if (LHS != Index) - return false; - UpperLimit = RHS; - return true; - case ICmpInst::ICMP_UGT: - std::swap(LHS, RHS); - // fallthrough - case ICmpInst::ICMP_ULT: - if (LHS != Index) - return false; - UpperLimit = RHS; - return true; + if (IsNonNegativeAndNotLoopVarying(LHS)) { + Index = RHS; + Length = LHS; + return RANGE_CHECK_BOTH; + } + return RANGE_CHECK_UNKNOWN; } + + llvm_unreachable("default clause returns!"); } -/// Split a condition into something semantically equivalent to (0 <= I < -/// Limit), both comparisons signed and Len loop invariant on L and positive. -/// On success, return true and set Index to I and UpperLimit to Limit. Return -/// false on failure (we may still write to UpperLimit and Index on failure). -/// It does not try to interpret I as a loop index. -/// -static bool SplitRangeCheckCondition(Loop *L, ScalarEvolution &SE, +/// Parses an arbitrary condition into a range check. `Length` is set only if +/// the range check is recognized to be `RANGE_CHECK_UPPER` or stronger. +InductiveRangeCheck::RangeCheckKind +InductiveRangeCheck::parseRangeCheck(Loop *L, ScalarEvolution &SE, Value *Condition, const SCEV *&Index, - Value *&UpperLimit) { - - // TODO: currently this catches some silly cases like comparing "%idx slt 1". - // Our transformations are still correct, but less likely to be profitable in - // those cases. We have to come up with some heuristics that pick out the - // range checks that are more profitable to clone a loop for. This function - // in general can be made more robust. - + Value *&Length) { using namespace llvm::PatternMatch; Value *A = nullptr; Value *B = nullptr; - ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; - - // In these early checks we assume that the matched UpperLimit is positive. - // We'll verify that fact later, before returning true. if (match(Condition, m_And(m_Value(A), m_Value(B)))) { - Value *IndexV = nullptr; - Value *ExpectedUpperBoundCheck = nullptr; + Value *IndexA = nullptr, *IndexB = nullptr; + Value *LengthA = nullptr, *LengthB = nullptr; + ICmpInst *ICmpA = dyn_cast(A), *ICmpB = dyn_cast(B); - if (IsLowerBoundCheck(A, IndexV)) - ExpectedUpperBoundCheck = B; - else if (IsLowerBoundCheck(B, IndexV)) - ExpectedUpperBoundCheck = A; - else - return false; + if (!ICmpA || !ICmpB) + return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; - if (!IsUpperBoundCheck(ExpectedUpperBoundCheck, IndexV, UpperLimit)) - return false; + auto RCKindA = parseRangeCheckICmp(L, ICmpA, SE, IndexA, LengthA); + auto RCKindB = parseRangeCheckICmp(L, ICmpB, SE, IndexB, LengthB); - Index = SE.getSCEV(IndexV); + if (RCKindA == InductiveRangeCheck::RANGE_CHECK_UNKNOWN || + RCKindB == InductiveRangeCheck::RANGE_CHECK_UNKNOWN) + return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; - if (isa(Index)) - return false; + if (IndexA != IndexB) + return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; - } else if (match(Condition, m_ICmp(Pred, m_Value(A), m_Value(B)))) { - switch (Pred) { - default: - return false; + if (LengthA != nullptr && LengthB != nullptr && LengthA != LengthB) + return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; - case ICmpInst::ICMP_SGT: - std::swap(A, B); - // fall through - case ICmpInst::ICMP_SLT: - UpperLimit = B; - Index = SE.getSCEV(A); - if (isa(Index) || !SE.isKnownNonNegative(Index)) - return false; - break; + Index = SE.getSCEV(IndexA); + if (isa(Index)) + return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; - case ICmpInst::ICMP_UGT: - std::swap(A, B); - // fall through - case ICmpInst::ICMP_ULT: - UpperLimit = B; - Index = SE.getSCEV(A); - if (isa(Index)) - return false; - break; - } - } else { - return false; + Length = LengthA == nullptr ? LengthB : LengthA; + + return (InductiveRangeCheck::RangeCheckKind)(RCKindA | RCKindB); } - const SCEV *UpperLimitSCEV = SE.getSCEV(UpperLimit); - if (isa(UpperLimitSCEV) || - !SE.isKnownNonNegative(UpperLimitSCEV)) - return false; + if (ICmpInst *ICI = dyn_cast(Condition)) { + Value *IndexVal = nullptr; - if (SE.getLoopDisposition(UpperLimitSCEV, L) != - ScalarEvolution::LoopInvariant) { - DEBUG(dbgs() << " in function: " << L->getHeader()->getParent()->getName() - << " "; - dbgs() << " UpperLimit is not loop invariant: " - << UpperLimit->getName() << "\n";); - return false; + auto RCKind = parseRangeCheckICmp(L, ICI, SE, IndexVal, Length); + + if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN) + return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; + + Index = SE.getSCEV(IndexVal); + if (isa(Index)) + return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; + + return RCKind; } - return true; + return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; } @@ -374,10 +396,15 @@ InductiveRangeCheck::create(InductiveRangeCheck::AllocatorTy &A, BranchInst *BI, Value *Length = nullptr; const SCEV *IndexSCEV = nullptr; - if (!SplitRangeCheckCondition(L, SE, BI->getCondition(), IndexSCEV, Length)) + auto RCKind = InductiveRangeCheck::parseRangeCheck(L, SE, BI->getCondition(), + IndexSCEV, Length); + + if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN) return nullptr; - assert(IndexSCEV && Length && "contract with SplitRangeCheckCondition!"); + assert(IndexSCEV && "contract with SplitRangeCheckCondition!"); + assert((!(RCKind & InductiveRangeCheck::RANGE_CHECK_UPPER) || Length) && + "contract with SplitRangeCheckCondition!"); const SCEVAddRecExpr *IndexAddRec = dyn_cast(IndexSCEV); bool IsAffineIndex = @@ -391,25 +418,59 @@ InductiveRangeCheck::create(InductiveRangeCheck::AllocatorTy &A, BranchInst *BI, IRC->Offset = IndexAddRec->getStart(); IRC->Scale = IndexAddRec->getStepRecurrence(SE); IRC->Branch = BI; + IRC->Kind = RCKind; return IRC; } -static Value *MaybeSimplify(Value *V) { - if (Instruction *I = dyn_cast(V)) - if (Value *Simplified = SimplifyInstruction(I)) - return Simplified; - return V; -} - -static Value *ConstructSMinOf(Value *X, Value *Y, IRBuilder<> &B) { - return MaybeSimplify(B.CreateSelect(B.CreateICmpSLT(X, Y), X, Y)); -} +namespace { -static Value *ConstructSMaxOf(Value *X, Value *Y, IRBuilder<> &B) { - return MaybeSimplify(B.CreateSelect(B.CreateICmpSGT(X, Y), X, Y)); -} +// Keeps track of the structure of a loop. This is similar to llvm::Loop, +// except that it is more lightweight and can track the state of a loop through +// changing and potentially invalid IR. This structure also formalizes the +// kinds of loops we can deal with -- ones that have a single latch that is also +// an exiting block *and* have a canonical induction variable. +struct LoopStructure { + const char *Tag; + + BasicBlock *Header; + BasicBlock *Latch; + + // `Latch's terminator instruction is `LatchBr', and it's `LatchBrExitIdx'th + // successor is `LatchExit', the exit block of the loop. + BranchInst *LatchBr; + BasicBlock *LatchExit; + unsigned LatchBrExitIdx; + + Value *IndVarNext; + Value *IndVarStart; + Value *LoopExitAt; + bool IndVarIncreasing; + + LoopStructure() + : Tag(""), Header(nullptr), Latch(nullptr), LatchBr(nullptr), + LatchExit(nullptr), LatchBrExitIdx(-1), IndVarNext(nullptr), + IndVarStart(nullptr), LoopExitAt(nullptr), IndVarIncreasing(false) {} + + template LoopStructure map(M Map) const { + LoopStructure Result; + Result.Tag = Tag; + Result.Header = cast(Map(Header)); + Result.Latch = cast(Map(Latch)); + Result.LatchBr = cast(Map(LatchBr)); + Result.LatchExit = cast(Map(LatchExit)); + Result.LatchBrExitIdx = LatchBrExitIdx; + Result.IndVarNext = Map(IndVarNext); + Result.IndVarStart = Map(IndVarStart); + Result.LoopExitAt = Map(LoopExitAt); + Result.IndVarIncreasing = IndVarIncreasing; + return Result; + } -namespace { + static Optional parseLoopStructure(ScalarEvolution &, + BranchProbabilityInfo &BPI, + Loop &, + const char *&); +}; /// This class is used to constrain loops to run within a given iteration space. /// The algorithm this class implements is given a Loop and a range [Begin, @@ -421,51 +482,6 @@ namespace { /// iterations in which the induction variable is >= End. /// class LoopConstrainer { - - // Keeps track of the structure of a loop. This is similar to llvm::Loop, - // except that it is more lightweight and can track the state of a loop - // through changing and potentially invalid IR. This structure also - // formalizes the kinds of loops we can deal with -- ones that have a single - // latch that is also an exiting block *and* have a canonical induction - // variable. - struct LoopStructure { - const char *Tag; - - BasicBlock *Header; - BasicBlock *Latch; - - // `Latch's terminator instruction is `LatchBr', and it's `LatchBrExitIdx'th - // successor is `LatchExit', the exit block of the loop. - BranchInst *LatchBr; - BasicBlock *LatchExit; - unsigned LatchBrExitIdx; - - // The canonical induction variable. It's value is `CIVStart` on the 0th - // itertion and `CIVNext` for all iterations after that. - PHINode *CIV; - Value *CIVStart; - Value *CIVNext; - - LoopStructure() : Tag(""), Header(nullptr), Latch(nullptr), - LatchBr(nullptr), LatchExit(nullptr), - LatchBrExitIdx(-1), CIV(nullptr), - CIVStart(nullptr), CIVNext(nullptr) { } - - template LoopStructure map(M Map) const { - LoopStructure Result; - Result.Tag = Tag; - Result.Header = cast(Map(Header)); - Result.Latch = cast(Map(Latch)); - Result.LatchBr = cast(Map(LatchBr)); - Result.LatchExit = cast(Map(LatchExit)); - Result.LatchBrExitIdx = LatchBrExitIdx; - Result.CIV = cast(Map(CIV)); - Result.CIVNext = Map(CIVNext); - Result.CIVStart = Map(CIVStart); - return Result; - } - }; - // The representation of a clone of the original loop we started out with. struct ClonedLoop { // The cloned blocks @@ -484,17 +500,22 @@ class LoopConstrainer { BasicBlock *PseudoExit; BasicBlock *ExitSelector; std::vector PHIValuesAtPseudoExit; + PHINode *IndVarEnd; - RewrittenRangeInfo() : PseudoExit(nullptr), ExitSelector(nullptr) { } + RewrittenRangeInfo() + : PseudoExit(nullptr), ExitSelector(nullptr), IndVarEnd(nullptr) {} }; // Calculated subranges we restrict the iteration space of the main loop to. // See the implementation of `calculateSubRanges' for more details on how - // these fields are computed. `ExitPreLoopAt' is `None' if we don't need a - // pre loop. `ExitMainLoopAt' is `None' if we don't need a post loop. + // these fields are computed. `LowLimit` is None if there is no restriction + // on low end of the restricted iteration space of the main loop. `HighLimit` + // is None if there is no restriction on high end of the restricted iteration + // space of the main loop. + struct SubRanges { - Optional ExitPreLoopAt; - Optional ExitMainLoopAt; + Optional LowLimit; + Optional HighLimit; }; // A utility function that does a `replaceUsesOfWith' on the incoming block @@ -503,19 +524,11 @@ class LoopConstrainer { static void replacePHIBlock(PHINode *PN, BasicBlock *Block, BasicBlock *ReplaceBy); - // Try to "parse" `OriginalLoop' and populate the various out parameters. - // Returns true on success, false on failure. - // - bool recognizeLoop(LoopStructure &LoopStructureOut, - const SCEV *&LatchCountOut, BasicBlock *&PreHeaderOut, - const char *&FailureReasonOut) const; - // Compute a safe set of limits for the main loop to run in -- effectively the // intersection of `Range' and the iteration space of the original loop. - // Return the header count (1 + the latch taken count) in `HeaderCount'. // Return None if unable to compute the set of subranges. // - Optional calculateSubRanges(Value *&HeaderCount) const; + Optional calculateSubRanges() const; // Clone `OriginalLoop' and return the result in CLResult. The IR after // running `cloneLoop' is well formed except for the PHI nodes in CLResult -- @@ -554,16 +567,15 @@ class LoopConstrainer { // The loop denoted by `LS' has `OldPreheader' as its preheader. This // function creates a new preheader for `LS' and returns it. // - BasicBlock *createPreheader(const LoopConstrainer::LoopStructure &LS, - BasicBlock *OldPreheader, const char *Tag) const; + BasicBlock *createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader, + const char *Tag) const; // `ContinuationBlockAndPreheader' was the continuation block for some call to // `changeIterationSpaceEnd' and is the preheader to the loop denoted by `LS'. // This function rewrites the PHI nodes in `LS.Header' to start with the // correct value. void rewriteIncomingValuesForPHIs( - LoopConstrainer::LoopStructure &LS, - BasicBlock *ContinuationBlockAndPreheader, + LoopStructure &LS, BasicBlock *ContinuationBlockAndPreheader, const LoopConstrainer::RewrittenRangeInfo &RRI) const; // Even though we do not preserve any passes at this time, we at least need to @@ -582,7 +594,6 @@ class LoopConstrainer { LoopInfo &OriginalLoopInfo; const SCEV *LatchTakenCount; BasicBlock *OriginalPreheader; - Value *OriginalHeaderCount; // The preheader of the main loop. This may or may not be different from // `OriginalPreheader'. @@ -596,12 +607,12 @@ class LoopConstrainer { LoopStructure MainLoopStructure; public: - LoopConstrainer(Loop &L, LoopInfo &LI, ScalarEvolution &SE, - InductiveRangeCheck::Range R) - : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE), - OriginalLoop(L), OriginalLoopInfo(LI), LatchTakenCount(nullptr), - OriginalPreheader(nullptr), OriginalHeaderCount(nullptr), - MainLoopPreheader(nullptr), Range(R) { } + LoopConstrainer(Loop &L, LoopInfo &LI, const LoopStructure &LS, + ScalarEvolution &SE, InductiveRangeCheck::Range R) + : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), + SE(SE), OriginalLoop(L), OriginalLoopInfo(LI), LatchTakenCount(nullptr), + OriginalPreheader(nullptr), MainLoopPreheader(nullptr), Range(R), + MainLoopStructure(LS) {} // Entry point for the algorithm. Returns true on success. bool run(); @@ -616,157 +627,288 @@ void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block, PN->setIncomingBlock(i, ReplaceBy); } -bool LoopConstrainer::recognizeLoop(LoopStructure &LoopStructureOut, - const SCEV *&LatchCountOut, - BasicBlock *&PreheaderOut, - const char *&FailureReason) const { - using namespace llvm::PatternMatch; +static bool CanBeSMax(ScalarEvolution &SE, const SCEV *S) { + APInt SMax = + APInt::getSignedMaxValue(cast(S->getType())->getBitWidth()); + return SE.getSignedRange(S).contains(SMax) && + SE.getUnsignedRange(S).contains(SMax); +} - assert(OriginalLoop.isLoopSimplifyForm() && - "should follow from addRequired<>"); +static bool CanBeSMin(ScalarEvolution &SE, const SCEV *S) { + APInt SMin = + APInt::getSignedMinValue(cast(S->getType())->getBitWidth()); + return SE.getSignedRange(S).contains(SMin) && + SE.getUnsignedRange(S).contains(SMin); +} - BasicBlock *Latch = OriginalLoop.getLoopLatch(); - if (!OriginalLoop.isLoopExiting(Latch)) { - FailureReason = "no loop latch"; - return false; - } +Optional +LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BPI, + Loop &L, const char *&FailureReason) { + assert(L.isLoopSimplifyForm() && "should follow from addRequired<>"); - PHINode *CIV = OriginalLoop.getCanonicalInductionVariable(); - if (!CIV) { - FailureReason = "no CIV"; - return false; + BasicBlock *Latch = L.getLoopLatch(); + if (!L.isLoopExiting(Latch)) { + FailureReason = "no loop latch"; + return None; } - BasicBlock *Header = OriginalLoop.getHeader(); - BasicBlock *Preheader = OriginalLoop.getLoopPreheader(); + BasicBlock *Header = L.getHeader(); + BasicBlock *Preheader = L.getLoopPreheader(); if (!Preheader) { FailureReason = "no preheader"; - return false; + return None; } - Value *CIVNext = CIV->getIncomingValueForBlock(Latch); - Value *CIVStart = CIV->getIncomingValueForBlock(Preheader); - - const SCEV *LatchCount = SE.getExitCount(&OriginalLoop, Latch); - if (isa(LatchCount)) { - FailureReason = "could not compute latch count"; - return false; + BranchInst *LatchBr = dyn_cast(&*Latch->rbegin()); + if (!LatchBr || LatchBr->isUnconditional()) { + FailureReason = "latch terminator not conditional branch"; + return None; } - // While SCEV does most of the analysis for us, we still have to - // modify the latch; and currently we can only deal with certain - // kinds of latches. This can be made more sophisticated as needed. + unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0; - BranchInst *LatchBr = dyn_cast(&*Latch->rbegin()); + BranchProbability ExitProbability = + BPI.getEdgeProbability(LatchBr->getParent(), LatchBrExitIdx); - if (!LatchBr || LatchBr->isUnconditional()) { - FailureReason = "latch terminator not conditional branch"; - return false; + if (ExitProbability > BranchProbability(1, MaxExitProbReciprocal)) { + FailureReason = "short running loop, not profitable"; + return None; } - // Currently we only support a latch condition of the form: - // - // %condition = icmp slt %civNext, %limit - // br i1 %condition, label %header, label %exit + ICmpInst *ICI = dyn_cast(LatchBr->getCondition()); + if (!ICI || !isa(ICI->getOperand(0)->getType())) { + FailureReason = "latch terminator branch not conditional on integral icmp"; + return None; + } - if (LatchBr->getSuccessor(0) != Header) { - FailureReason = "unknown latch form (header not first successor)"; - return false; + const SCEV *LatchCount = SE.getExitCount(&L, Latch); + if (isa(LatchCount)) { + FailureReason = "could not compute latch count"; + return None; } - Value *CIVComparedTo = nullptr; - ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; - if (!(match(LatchBr->getCondition(), - m_ICmp(Pred, m_Specific(CIVNext), m_Value(CIVComparedTo))) && - Pred == ICmpInst::ICMP_SLT)) { - FailureReason = "unknown latch form (not slt)"; - return false; + ICmpInst::Predicate Pred = ICI->getPredicate(); + Value *LeftValue = ICI->getOperand(0); + const SCEV *LeftSCEV = SE.getSCEV(LeftValue); + IntegerType *IndVarTy = cast(LeftValue->getType()); + + Value *RightValue = ICI->getOperand(1); + const SCEV *RightSCEV = SE.getSCEV(RightValue); + + // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence. + if (!isa(LeftSCEV)) { + if (isa(RightSCEV)) { + std::swap(LeftSCEV, RightSCEV); + std::swap(LeftValue, RightValue); + Pred = ICmpInst::getSwappedPredicate(Pred); + } else { + FailureReason = "no add recurrences in the icmp"; + return None; + } } - // IndVarSimplify will sometimes leave behind (in SCEV's cache) backedge-taken - // counts that are narrower than the canonical induction variable. These - // values are still accurate, and we could probably use them after sign/zero - // extension; but for now we just bail out of the transformation to keep - // things simple. - const SCEV *CIVComparedToSCEV = SE.getSCEV(CIVComparedTo); - if (isa(CIVComparedToSCEV) || - CIVComparedToSCEV->getType() != LatchCount->getType()) { - FailureReason = "could not relate CIV to latch expression"; + auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) { + if (AR->getNoWrapFlags(SCEV::FlagNSW)) + return true; + + IntegerType *Ty = cast(AR->getType()); + IntegerType *WideTy = + IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2); + + const SCEVAddRecExpr *ExtendAfterOp = + dyn_cast(SE.getSignExtendExpr(AR, WideTy)); + if (ExtendAfterOp) { + const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy); + const SCEV *ExtendedStep = + SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy); + + bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart && + ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep; + + if (NoSignedWrap) + return true; + } + + // We may have proved this when computing the sign extension above. + return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap; + }; + + auto IsInductionVar = [&](const SCEVAddRecExpr *AR, bool &IsIncreasing) { + if (!AR->isAffine()) + return false; + + // Currently we only work with induction variables that have been proved to + // not wrap. This restriction can potentially be lifted in the future. + + if (!HasNoSignedWrap(AR)) + return false; + + if (const SCEVConstant *StepExpr = + dyn_cast(AR->getStepRecurrence(SE))) { + ConstantInt *StepCI = StepExpr->getValue(); + if (StepCI->isOne() || StepCI->isMinusOne()) { + IsIncreasing = StepCI->isOne(); + return true; + } + } + return false; + }; + + // `ICI` is interpreted as taking the backedge if the *next* value of the + // induction variable satisfies some constraint. + + const SCEVAddRecExpr *IndVarNext = cast(LeftSCEV); + bool IsIncreasing = false; + if (!IsInductionVar(IndVarNext, IsIncreasing)) { + FailureReason = "LHS in icmp not induction variable"; + return None; } - const SCEV *ShouldBeOne = SE.getMinusSCEV(CIVComparedToSCEV, LatchCount); - const SCEVConstant *SCEVOne = dyn_cast(ShouldBeOne); - if (!SCEVOne || SCEVOne->getValue()->getValue() != 1) { - FailureReason = "unexpected header count in latch"; - return false; + ConstantInt *One = ConstantInt::get(IndVarTy, 1); + // TODO: generalize the predicates here to also match their unsigned variants. + if (IsIncreasing) { + bool FoundExpectedPred = + (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 1) || + (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 0); + + if (!FoundExpectedPred) { + FailureReason = "expected icmp slt semantically, found something else"; + return None; + } + + if (LatchBrExitIdx == 0) { + if (CanBeSMax(SE, RightSCEV)) { + // TODO: this restriction is easily removable -- we just have to + // remember that the icmp was an slt and not an sle. + FailureReason = "limit may overflow when coercing sle to slt"; + return None; + } + + IRBuilder<> B(&*Preheader->rbegin()); + RightValue = B.CreateAdd(RightValue, One); + } + + } else { + bool FoundExpectedPred = + (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 1) || + (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 0); + + if (!FoundExpectedPred) { + FailureReason = "expected icmp sgt semantically, found something else"; + return None; + } + + if (LatchBrExitIdx == 0) { + if (CanBeSMin(SE, RightSCEV)) { + // TODO: this restriction is easily removable -- we just have to + // remember that the icmp was an sgt and not an sge. + FailureReason = "limit may overflow when coercing sge to sgt"; + return None; + } + + IRBuilder<> B(&*Preheader->rbegin()); + RightValue = B.CreateSub(RightValue, One); + } } - unsigned LatchBrExitIdx = 1; + const SCEV *StartNext = IndVarNext->getStart(); + const SCEV *Addend = SE.getNegativeSCEV(IndVarNext->getStepRecurrence(SE)); + const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); + BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx); - assert(SE.getLoopDisposition(LatchCount, &OriginalLoop) == + assert(SE.getLoopDisposition(LatchCount, &L) == ScalarEvolution::LoopInvariant && "loop variant exit count doesn't make sense!"); - assert(!OriginalLoop.contains(LatchExit) && "expected an exit block!"); - - LoopStructureOut.Tag = "main"; - LoopStructureOut.Header = Header; - LoopStructureOut.Latch = Latch; - LoopStructureOut.LatchBr = LatchBr; - LoopStructureOut.LatchExit = LatchExit; - LoopStructureOut.LatchBrExitIdx = LatchBrExitIdx; - LoopStructureOut.CIV = CIV; - LoopStructureOut.CIVNext = CIVNext; - LoopStructureOut.CIVStart = CIVStart; + assert(!L.contains(LatchExit) && "expected an exit block!"); + const DataLayout &DL = Preheader->getModule()->getDataLayout(); + Value *IndVarStartV = + SCEVExpander(SE, DL, "irce") + .expandCodeFor(IndVarStart, IndVarTy, &*Preheader->rbegin()); + IndVarStartV->setName("indvar.start"); + + LoopStructure Result; + + Result.Tag = "main"; + Result.Header = Header; + Result.Latch = Latch; + Result.LatchBr = LatchBr; + Result.LatchExit = LatchExit; + Result.LatchBrExitIdx = LatchBrExitIdx; + Result.IndVarStart = IndVarStartV; + Result.IndVarNext = LeftValue; + Result.IndVarIncreasing = IsIncreasing; + Result.LoopExitAt = RightValue; - LatchCountOut = LatchCount; - PreheaderOut = Preheader; FailureReason = nullptr; - return true; + return Result; } Optional -LoopConstrainer::calculateSubRanges(Value *&HeaderCountOut) const { +LoopConstrainer::calculateSubRanges() const { IntegerType *Ty = cast(LatchTakenCount->getType()); if (Range.getType() != Ty) return None; - SCEVExpander Expander(SE, "irce"); - Instruction *InsertPt = OriginalPreheader->getTerminator(); - - Value *LatchCountV = - MaybeSimplify(Expander.expandCodeFor(LatchTakenCount, Ty, InsertPt)); - - IRBuilder<> B(InsertPt); - LoopConstrainer::SubRanges Result; // I think we can be more aggressive here and make this nuw / nsw if the // addition that feeds into the icmp for the latch's terminating branch is nuw // / nsw. In any case, a wrapping 2's complement addition is safe. ConstantInt *One = ConstantInt::get(Ty, 1); - HeaderCountOut = MaybeSimplify(B.CreateAdd(LatchCountV, One, "header.count")); + const SCEV *Start = SE.getSCEV(MainLoopStructure.IndVarStart); + const SCEV *End = SE.getSCEV(MainLoopStructure.LoopExitAt); + + bool Increasing = MainLoopStructure.IndVarIncreasing; - const SCEV *RangeBegin = SE.getSCEV(Range.getBegin()); - const SCEV *RangeEnd = SE.getSCEV(Range.getEnd()); - const SCEV *HeaderCountSCEV = SE.getSCEV(HeaderCountOut); - const SCEV *Zero = SE.getConstant(Ty, 0); + // We compute `Smallest` and `Greatest` such that [Smallest, Greatest) is the + // range of values the induction variable takes. + + const SCEV *Smallest = nullptr, *Greatest = nullptr; + + if (Increasing) { + Smallest = Start; + Greatest = End; + } else { + // These two computations may sign-overflow. Here is why that is okay: + // + // We know that the induction variable does not sign-overflow on any + // iteration except the last one, and it starts at `Start` and ends at + // `End`, decrementing by one every time. + // + // * if `Smallest` sign-overflows we know `End` is `INT_SMAX`. Since the + // induction variable is decreasing we know that that the smallest value + // the loop body is actually executed with is `INT_SMIN` == `Smallest`. + // + // * if `Greatest` sign-overflows, we know it can only be `INT_SMIN`. In + // that case, `Clamp` will always return `Smallest` and + // [`Result.LowLimit`, `Result.HighLimit`) = [`Smallest`, `Smallest`) + // will be an empty range. Returning an empty range is always safe. + // + + Smallest = SE.getAddExpr(End, SE.getSCEV(One)); + Greatest = SE.getAddExpr(Start, SE.getSCEV(One)); + } + + auto Clamp = [this, Smallest, Greatest](const SCEV *S) { + return SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S)); + }; // In some cases we can prove that we don't need a pre or post loop bool ProvablyNoPreloop = - SE.isKnownPredicate(ICmpInst::ICMP_SLE, RangeBegin, Zero); + SE.isKnownPredicate(ICmpInst::ICMP_SLE, Range.getBegin(), Smallest); if (!ProvablyNoPreloop) - Result.ExitPreLoopAt = ConstructSMinOf(HeaderCountOut, Range.getBegin(), B); + Result.LowLimit = Clamp(Range.getBegin()); bool ProvablyNoPostLoop = - SE.isKnownPredicate(ICmpInst::ICMP_SLE, HeaderCountSCEV, RangeEnd); + SE.isKnownPredicate(ICmpInst::ICMP_SLE, Greatest, Range.getEnd()); if (!ProvablyNoPostLoop) - Result.ExitMainLoopAt = ConstructSMinOf(HeaderCountOut, Range.getEnd(), B); + Result.HighLimit = Clamp(Range.getEnd()); return Result; } @@ -823,7 +965,7 @@ void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result, } LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( - const LoopStructure &LS, BasicBlock *Preheader, Value *ExitLoopAt, + const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt, BasicBlock *ContinuationBlock) const { // We start with a loop with a single latch: @@ -907,32 +1049,37 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( BBInsertLocation); BranchInst *PreheaderJump = cast(&*Preheader->rbegin()); + bool Increasing = LS.IndVarIncreasing; IRBuilder<> B(PreheaderJump); // EnterLoopCond - is it okay to start executing this `LS'? - Value *EnterLoopCond = B.CreateICmpSLT(LS.CIVStart, ExitLoopAt); + Value *EnterLoopCond = Increasing + ? B.CreateICmpSLT(LS.IndVarStart, ExitSubloopAt) + : B.CreateICmpSGT(LS.IndVarStart, ExitSubloopAt); + B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit); PreheaderJump->eraseFromParent(); - assert(LS.LatchBrExitIdx == 1 && "generalize this as needed!"); - + LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector); B.SetInsertPoint(LS.LatchBr); + Value *TakeBackedgeLoopCond = + Increasing ? B.CreateICmpSLT(LS.IndVarNext, ExitSubloopAt) + : B.CreateICmpSGT(LS.IndVarNext, ExitSubloopAt); + Value *CondForBranch = LS.LatchBrExitIdx == 1 + ? TakeBackedgeLoopCond + : B.CreateNot(TakeBackedgeLoopCond); - // ContinueCond - is it okay to execute the next iteration in `LS'? - Value *ContinueCond = B.CreateICmpSLT(LS.CIVNext, ExitLoopAt); - - LS.LatchBr->setCondition(ContinueCond); - assert(LS.LatchBr->getSuccessor(LS.LatchBrExitIdx) == LS.LatchExit && - "invariant!"); - LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector); + LS.LatchBr->setCondition(CondForBranch); B.SetInsertPoint(RRI.ExitSelector); // IterationsLeft - are there any more iterations left, given the original // upper bound on the induction variable? If not, we branch to the "real" // exit. - Value *IterationsLeft = B.CreateICmpSLT(LS.CIVNext, OriginalHeaderCount); + Value *IterationsLeft = Increasing + ? B.CreateICmpSLT(LS.IndVarNext, LS.LoopExitAt) + : B.CreateICmpSGT(LS.IndVarNext, LS.LoopExitAt); B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit); BranchInst *BranchToContinuation = @@ -956,6 +1103,11 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( RRI.PHIValuesAtPseudoExit.push_back(NewPHI); } + RRI.IndVarEnd = PHINode::Create(LS.IndVarNext->getType(), 2, "indvar.end", + BranchToContinuation); + RRI.IndVarEnd->addIncoming(LS.IndVarStart, Preheader); + RRI.IndVarEnd->addIncoming(LS.IndVarNext, RRI.ExitSelector); + // The latch exit now has a branch from `RRI.ExitSelector' instead of // `LS.Latch'. The PHI nodes need to be updated to reflect that. for (Instruction &I : *LS.LatchExit) { @@ -969,7 +1121,7 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( } void LoopConstrainer::rewriteIncomingValuesForPHIs( - LoopConstrainer::LoopStructure &LS, BasicBlock *ContinuationBlock, + LoopStructure &LS, BasicBlock *ContinuationBlock, const LoopConstrainer::RewrittenRangeInfo &RRI) const { unsigned PHIIndex = 0; @@ -984,13 +1136,12 @@ void LoopConstrainer::rewriteIncomingValuesForPHIs( PN->setIncomingValue(i, RRI.PHIValuesAtPseudoExit[PHIIndex++]); } - LS.CIVStart = LS.CIV->getIncomingValueForBlock(ContinuationBlock); + LS.IndVarStart = RRI.IndVarEnd; } -BasicBlock * -LoopConstrainer::createPreheader(const LoopConstrainer::LoopStructure &LS, - BasicBlock *OldPreheader, - const char *Tag) const { +BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS, + BasicBlock *OldPreheader, + const char *Tag) const { BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header); BranchInst::Create(LS.Header, Preheader); @@ -1018,30 +1169,79 @@ void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef BBs) { bool LoopConstrainer::run() { BasicBlock *Preheader = nullptr; - const char *CouldNotProceedBecause = nullptr; - if (!recognizeLoop(MainLoopStructure, LatchTakenCount, Preheader, - CouldNotProceedBecause)) { - DEBUG(dbgs() << "irce: could not recognize loop, " << CouldNotProceedBecause - << "\n";); - return false; - } + LatchTakenCount = SE.getExitCount(&OriginalLoop, MainLoopStructure.Latch); + Preheader = OriginalLoop.getLoopPreheader(); + assert(!isa(LatchTakenCount) && Preheader != nullptr && + "preconditions!"); OriginalPreheader = Preheader; MainLoopPreheader = Preheader; - Optional MaybeSR = calculateSubRanges(OriginalHeaderCount); + Optional MaybeSR = calculateSubRanges(); if (!MaybeSR.hasValue()) { DEBUG(dbgs() << "irce: could not compute subranges\n"); return false; } + SubRanges SR = MaybeSR.getValue(); + bool Increasing = MainLoopStructure.IndVarIncreasing; + IntegerType *IVTy = + cast(MainLoopStructure.IndVarNext->getType()); + + SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce"); + Instruction *InsertPt = OriginalPreheader->getTerminator(); // It would have been better to make `PreLoop' and `PostLoop' // `Optional's, but `ValueToValueMapTy' does not have a copy // constructor. ClonedLoop PreLoop, PostLoop; - bool NeedsPreLoop = SR.ExitPreLoopAt.hasValue(); - bool NeedsPostLoop = SR.ExitMainLoopAt.hasValue(); + bool NeedsPreLoop = + Increasing ? SR.LowLimit.hasValue() : SR.HighLimit.hasValue(); + bool NeedsPostLoop = + Increasing ? SR.HighLimit.hasValue() : SR.LowLimit.hasValue(); + + Value *ExitPreLoopAt = nullptr; + Value *ExitMainLoopAt = nullptr; + const SCEVConstant *MinusOneS = + cast(SE.getConstant(IVTy, -1, true /* isSigned */)); + + if (NeedsPreLoop) { + const SCEV *ExitPreLoopAtSCEV = nullptr; + + if (Increasing) + ExitPreLoopAtSCEV = *SR.LowLimit; + else { + if (CanBeSMin(SE, *SR.HighLimit)) { + DEBUG(dbgs() << "irce: could not prove no-overflow when computing " + << "preloop exit limit. HighLimit = " << *(*SR.HighLimit) + << "\n"); + return false; + } + ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS); + } + + ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt); + ExitPreLoopAt->setName("exit.preloop.at"); + } + + if (NeedsPostLoop) { + const SCEV *ExitMainLoopAtSCEV = nullptr; + + if (Increasing) + ExitMainLoopAtSCEV = *SR.HighLimit; + else { + if (CanBeSMin(SE, *SR.LowLimit)) { + DEBUG(dbgs() << "irce: could not prove no-overflow when computing " + << "mainloop exit limit. LowLimit = " << *(*SR.LowLimit) + << "\n"); + return false; + } + ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS); + } + + ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt); + ExitMainLoopAt->setName("exit.mainloop.at"); + } // We clone these ahead of time so that we don't have to deal with changing // and temporarily invalid IR as we transform the loops. @@ -1058,9 +1258,8 @@ bool LoopConstrainer::run() { MainLoopPreheader = createPreheader(MainLoopStructure, Preheader, "mainloop"); - PreLoopRRI = - changeIterationSpaceEnd(PreLoop.Structure, Preheader, - SR.ExitPreLoopAt.getValue(), MainLoopPreheader); + PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader, + ExitPreLoopAt, MainLoopPreheader); rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader, PreLoopRRI); } @@ -1072,8 +1271,7 @@ bool LoopConstrainer::run() { PostLoopPreheader = createPreheader(PostLoop.Structure, Preheader, "postloop"); PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader, - SR.ExitMainLoopAt.getValue(), - PostLoopPreheader); + ExitMainLoopAt, PostLoopPreheader); rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader, PostLoopRRI); } @@ -1096,53 +1294,80 @@ bool LoopConstrainer::run() { return true; } -/// Computes and returns a range of values for the induction variable in which -/// the range check can be safely elided. If it cannot compute such a range, -/// returns None. +/// Computes and returns a range of values for the induction variable (IndVar) +/// in which the range check can be safely elided. If it cannot compute such a +/// range, returns None. Optional InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE, - IRBuilder<> &B) const { - - // Currently we support inequalities of the form: + const SCEVAddRecExpr *IndVar, + IRBuilder<> &) const { + // IndVar is of the form "A + B * I" (where "I" is the canonical induction + // variable, that may or may not exist as a real llvm::Value in the loop) and + // this inductive range check is a range check on the "C + D * I" ("C" is + // getOffset() and "D" is getScale()). We rewrite the value being range + // checked to "M + N * IndVar" where "N" = "D * B^(-1)" and "M" = "C - NA". + // Currently we support this only for "B" = "D" = { 1 or -1 }, but the code + // can be generalized as needed. // - // 0 <= Offset + 1 * CIV < L given L >= 0 + // The actual inequalities we solve are of the form // - // The inequality is satisfied by -Offset <= CIV < (L - Offset) [^1]. All - // additions and subtractions are twos-complement wrapping and comparisons are - // signed. + // 0 <= M + 1 * IndVar < L given L >= 0 (i.e. N == 1) + // + // The inequality is satisfied by -M <= IndVar < (L - M) [^1]. All additions + // and subtractions are twos-complement wrapping and comparisons are signed. // // Proof: // - // If there exists CIV such that -Offset <= CIV < (L - Offset) then it - // follows that -Offset <= (-Offset + L) [== Eq. 1]. Since L >= 0, if - // (-Offset + L) sign-overflows then (-Offset + L) < (-Offset). Hence by - // [Eq. 1], (-Offset + L) could not have overflown. + // If there exists IndVar such that -M <= IndVar < (L - M) then it follows + // that -M <= (-M + L) [== Eq. 1]. Since L >= 0, if (-M + L) sign-overflows + // then (-M + L) < (-M). Hence by [Eq. 1], (-M + L) could not have + // overflown. // - // This means CIV = t + (-Offset) for t in [0, L). Hence (CIV + Offset) = - // t. Hence 0 <= (CIV + Offset) < L + // This means IndVar = t + (-M) for t in [0, L). Hence (IndVar + M) = t. + // Hence 0 <= (IndVar + M) < L - // [^1]: Note that the solution does _not_ apply if L < 0; consider values - // Offset = 127, CIV = 126 and L = -2 in an i8 world. + // [^1]: Note that the solution does _not_ apply if L < 0; consider values M = + // 127, IndVar = 126 and L = -2 in an i8 world. - const SCEVConstant *ScaleC = dyn_cast(getScale()); - if (!(ScaleC && ScaleC->getValue()->getValue() == 1)) { - DEBUG(dbgs() << "irce: could not compute safe iteration space for:\n"; - print(dbgs())); + if (!IndVar->isAffine()) return None; - } - Value *OffsetV = SCEVExpander(SE, "safe.itr.space").expandCodeFor( - getOffset(), getOffset()->getType(), B.GetInsertPoint()); - OffsetV = MaybeSimplify(OffsetV); + const SCEV *A = IndVar->getStart(); + const SCEVConstant *B = dyn_cast(IndVar->getStepRecurrence(SE)); + if (!B) + return None; - Value *Begin = MaybeSimplify(B.CreateNeg(OffsetV)); - Value *End = MaybeSimplify(B.CreateSub(getLength(), OffsetV)); + const SCEV *C = getOffset(); + const SCEVConstant *D = dyn_cast(getScale()); + if (D != B) + return None; + ConstantInt *ConstD = D->getValue(); + if (!(ConstD->isMinusOne() || ConstD->isOne())) + return None; + + const SCEV *M = SE.getMinusSCEV(C, A); + + const SCEV *Begin = SE.getNegativeSCEV(M); + const SCEV *UpperLimit = nullptr; + + // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L". + // We can potentially do much better here. + if (Value *V = getLength()) { + UpperLimit = SE.getSCEV(V); + } else { + assert(Kind == InductiveRangeCheck::RANGE_CHECK_LOWER && "invariant!"); + unsigned BitWidth = cast(IndVar->getType())->getBitWidth(); + UpperLimit = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); + } + + const SCEV *End = SE.getMinusSCEV(UpperLimit, M); return InductiveRangeCheck::Range(Begin, End); } static Optional -IntersectRange(const Optional &R1, +IntersectRange(ScalarEvolution &SE, + const Optional &R1, const InductiveRangeCheck::Range &R2, IRBuilder<> &B) { if (!R1.hasValue()) return R2; @@ -1153,9 +1378,10 @@ IntersectRange(const Optional &R1, if (R1Value.getType() != R2.getType()) return None; - Value *NewMin = ConstructSMaxOf(R1Value.getBegin(), R2.getBegin(), B); - Value *NewMax = ConstructSMinOf(R1Value.getEnd(), R2.getEnd(), B); - return InductiveRangeCheck::Range(NewMin, NewMax); + const SCEV *NewBegin = SE.getSMaxExpr(R1Value.getBegin(), R2.getBegin()); + const SCEV *NewEnd = SE.getSMinExpr(R1Value.getEnd(), R2.getEnd()); + + return InductiveRangeCheck::Range(NewBegin, NewEnd); } bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { @@ -1174,7 +1400,8 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { InductiveRangeCheck::AllocatorTy IRCAlloc; SmallVector RangeChecks; ScalarEvolution &SE = getAnalysis(); - BranchProbabilityInfo &BPI = getAnalysis(); + BranchProbabilityInfo &BPI = + getAnalysis().getBPI(); for (auto BBI : L->getBlocks()) if (BranchInst *TBI = dyn_cast(BBI->getTerminator())) @@ -1185,12 +1412,33 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { if (RangeChecks.empty()) return false; - DEBUG(dbgs() << "irce: looking at loop "; L->print(dbgs()); - dbgs() << "irce: loop has " << RangeChecks.size() - << " inductive range checks: \n"; - for (InductiveRangeCheck *IRC : RangeChecks) - IRC->print(dbgs()); - ); + auto PrintRecognizedRangeChecks = [&](raw_ostream &OS) { + OS << "irce: looking at loop "; L->print(OS); + OS << "irce: loop has " << RangeChecks.size() + << " inductive range checks: \n"; + for (InductiveRangeCheck *IRC : RangeChecks) + IRC->print(OS); + }; + + DEBUG(PrintRecognizedRangeChecks(dbgs())); + + if (PrintRangeChecks) + PrintRecognizedRangeChecks(errs()); + + const char *FailureReason = nullptr; + Optional MaybeLoopStructure = + LoopStructure::parseLoopStructure(SE, BPI, *L, FailureReason); + if (!MaybeLoopStructure.hasValue()) { + DEBUG(dbgs() << "irce: could not parse loop structure: " << FailureReason + << "\n";); + return false; + } + LoopStructure LS = MaybeLoopStructure.getValue(); + bool Increasing = LS.IndVarIncreasing; + const SCEV *MinusOne = + SE.getConstant(LS.IndVarNext->getType(), Increasing ? -1 : 1, true); + const SCEVAddRecExpr *IndVar = + cast(SE.getAddExpr(SE.getSCEV(LS.IndVarNext), MinusOne)); Optional SafeIterRange; Instruction *ExprInsertPt = Preheader->getTerminator(); @@ -1199,10 +1447,10 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { IRBuilder<> B(ExprInsertPt); for (InductiveRangeCheck *IRC : RangeChecks) { - auto Result = IRC->computeSafeIterationSpace(SE, B); + auto Result = IRC->computeSafeIterationSpace(SE, IndVar, B); if (Result.hasValue()) { auto MaybeSafeIterRange = - IntersectRange(SafeIterRange, Result.getValue(), B); + IntersectRange(SE, SafeIterRange, Result.getValue(), B); if (MaybeSafeIterRange.hasValue()) { RangeChecksToEliminate.push_back(IRC); SafeIterRange = MaybeSafeIterRange.getValue(); @@ -1213,8 +1461,8 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { if (!SafeIterRange.hasValue()) return false; - LoopConstrainer LC(*L, getAnalysis().getLoopInfo(), SE, - SafeIterRange.getValue()); + LoopConstrainer LC(*L, getAnalysis().getLoopInfo(), LS, + SE, SafeIterRange.getValue()); bool Changed = LC.run(); if (Changed) {