X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=3e3032b48fe6f614ed90eff544016a45191b93fb;hb=e115abc777851b4615de8c4232698a691b9e7b2c;hp=6ccf8aac3b893c4ba3b8ec344cf980ee7eb1bad0;hpb=89e11b110afa9af4bae5392791022b466ca96a5e;p=oota-llvm.git diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 6ccf8aac3b8..3e3032b48fe 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -1,4 +1,4 @@ -//===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===// +//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===// // // The LLVM Compiler Infrastructure // @@ -58,15 +58,17 @@ // //===----------------------------------------------------------------------===// -#define DEBUG_TYPE "scalar-evolution" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" @@ -79,16 +81,18 @@ #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Target/TargetLibraryInfo.h" #include using namespace llvm; +#define DEBUG_TYPE "scalar-evolution" + STATISTIC(NumArrayLenItCounts, "Number of trip counts computed with array length"); STATISTIC(NumTripCountsComputed, @@ -112,9 +116,10 @@ VerifySCEV("verify-scev", INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution", "Scalar Evolution Analysis", false, true) -INITIALIZE_PASS_DEPENDENCY(LoopInfo) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(ScalarEvolution, "scalar-evolution", "Scalar Evolution Analysis", false, true) char ScalarEvolution::ID = 0; @@ -182,7 +187,7 @@ void SCEV::print(raw_ostream &OS) const { case scUMaxExpr: case scSMaxExpr: { const SCEVNAryExpr *NAry = cast(this); - const char *OpStr = 0; + const char *OpStr = nullptr; switch (NAry->getSCEVType()) { case scAddExpr: OpStr = " + "; break; case scMulExpr: OpStr = " * "; break; @@ -312,7 +317,7 @@ const SCEV *ScalarEvolution::getConstant(ConstantInt *V) { FoldingSetNodeID ID; ID.AddInteger(scConstant); ID.AddPointer(V); - void *IP = 0; + void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V); UniqueSCEVs.InsertNode(S, IP); @@ -365,7 +370,7 @@ void SCEVUnknown::deleted() { SE->UniqueSCEVs.RemoveNode(this); // Release the value. - setValPtr(0); + setValPtr(nullptr); } void SCEVUnknown::allUsesReplacedWith(Value *New) { @@ -670,7 +675,269 @@ static void GroupByComplexity(SmallVectorImpl &Ops, } } +namespace { +struct FindSCEVSize { + int Size; + FindSCEVSize() : Size(0) {} + + bool follow(const SCEV *S) { + ++Size; + // Keep looking at all operands of S. + return true; + } + bool isDone() const { + return false; + } +}; +} + +// Returns the size of the SCEV S. +static inline int sizeOfSCEV(const SCEV *S) { + FindSCEVSize F; + SCEVTraversal ST(F); + ST.visitAll(S); + return F.Size; +} + +namespace { + +struct SCEVDivision : public SCEVVisitor { +public: + // Computes the Quotient and Remainder of the division of Numerator by + // Denominator. + static void divide(ScalarEvolution &SE, const SCEV *Numerator, + const SCEV *Denominator, const SCEV **Quotient, + const SCEV **Remainder) { + assert(Numerator && Denominator && "Uninitialized SCEV"); + + SCEVDivision D(SE, Numerator, Denominator); + + // Check for the trivial case here to avoid having to check for it in the + // rest of the code. + if (Numerator == Denominator) { + *Quotient = D.One; + *Remainder = D.Zero; + return; + } + + if (Numerator->isZero()) { + *Quotient = D.Zero; + *Remainder = D.Zero; + return; + } + + // A simple case when N/1. The quotient is N. + if (Denominator->isOne()) { + *Quotient = Numerator; + *Remainder = D.Zero; + return; + } + + // Split the Denominator when it is a product. + if (const SCEVMulExpr *T = dyn_cast(Denominator)) { + const SCEV *Q, *R; + *Quotient = Numerator; + for (const SCEV *Op : T->operands()) { + divide(SE, *Quotient, Op, &Q, &R); + *Quotient = Q; + + // Bail out when the Numerator is not divisible by one of the terms of + // the Denominator. + if (!R->isZero()) { + *Quotient = D.Zero; + *Remainder = Numerator; + return; + } + } + *Remainder = D.Zero; + return; + } + + D.visit(Numerator); + *Quotient = D.Quotient; + *Remainder = D.Remainder; + } + + // Except in the trivial case described above, we do not know how to divide + // Expr by Denominator for the following functions with empty implementation. + void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {} + void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {} + void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {} + void visitUDivExpr(const SCEVUDivExpr *Numerator) {} + void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {} + void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} + void visitUnknown(const SCEVUnknown *Numerator) {} + void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} + + void visitConstant(const SCEVConstant *Numerator) { + if (const SCEVConstant *D = dyn_cast(Denominator)) { + APInt NumeratorVal = Numerator->getValue()->getValue(); + APInt DenominatorVal = D->getValue()->getValue(); + uint32_t NumeratorBW = NumeratorVal.getBitWidth(); + uint32_t DenominatorBW = DenominatorVal.getBitWidth(); + + if (NumeratorBW > DenominatorBW) + DenominatorVal = DenominatorVal.sext(NumeratorBW); + else if (NumeratorBW < DenominatorBW) + NumeratorVal = NumeratorVal.sext(DenominatorBW); + + APInt QuotientVal(NumeratorVal.getBitWidth(), 0); + APInt RemainderVal(NumeratorVal.getBitWidth(), 0); + APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); + Quotient = SE.getConstant(QuotientVal); + Remainder = SE.getConstant(RemainderVal); + return; + } + } + + void visitAddRecExpr(const SCEVAddRecExpr *Numerator) { + const SCEV *StartQ, *StartR, *StepQ, *StepR; + assert(Numerator->isAffine() && "Numerator should be affine"); + divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); + divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); + // Bail out if the types do not match. + Type *Ty = Denominator->getType(); + if (Ty != StartQ->getType() || Ty != StartR->getType() || + Ty != StepQ->getType() || Ty != StepR->getType()) { + Quotient = Zero; + Remainder = Numerator; + return; + } + Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), + Numerator->getNoWrapFlags()); + Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), + Numerator->getNoWrapFlags()); + } + + void visitAddExpr(const SCEVAddExpr *Numerator) { + SmallVector Qs, Rs; + Type *Ty = Denominator->getType(); + + for (const SCEV *Op : Numerator->operands()) { + const SCEV *Q, *R; + divide(SE, Op, Denominator, &Q, &R); + + // Bail out if types do not match. + if (Ty != Q->getType() || Ty != R->getType()) { + Quotient = Zero; + Remainder = Numerator; + return; + } + + Qs.push_back(Q); + Rs.push_back(R); + } + + if (Qs.size() == 1) { + Quotient = Qs[0]; + Remainder = Rs[0]; + return; + } + + Quotient = SE.getAddExpr(Qs); + Remainder = SE.getAddExpr(Rs); + } + + void visitMulExpr(const SCEVMulExpr *Numerator) { + SmallVector Qs; + Type *Ty = Denominator->getType(); + + bool FoundDenominatorTerm = false; + for (const SCEV *Op : Numerator->operands()) { + // Bail out if types do not match. + if (Ty != Op->getType()) { + Quotient = Zero; + Remainder = Numerator; + return; + } + if (FoundDenominatorTerm) { + Qs.push_back(Op); + continue; + } + + // Check whether Denominator divides one of the product operands. + const SCEV *Q, *R; + divide(SE, Op, Denominator, &Q, &R); + if (!R->isZero()) { + Qs.push_back(Op); + continue; + } + + // Bail out if types do not match. + if (Ty != Q->getType()) { + Quotient = Zero; + Remainder = Numerator; + return; + } + + FoundDenominatorTerm = true; + Qs.push_back(Q); + } + + if (FoundDenominatorTerm) { + Remainder = Zero; + if (Qs.size() == 1) + Quotient = Qs[0]; + else + Quotient = SE.getMulExpr(Qs); + return; + } + + if (!isa(Denominator)) { + Quotient = Zero; + Remainder = Numerator; + return; + } + + // The Remainder is obtained by replacing Denominator by 0 in Numerator. + ValueToValueMap RewriteMap; + RewriteMap[cast(Denominator)->getValue()] = + cast(Zero)->getValue(); + Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); + + if (Remainder->isZero()) { + // The Quotient is obtained by replacing Denominator by 1 in Numerator. + RewriteMap[cast(Denominator)->getValue()] = + cast(One)->getValue(); + Quotient = + SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); + return; + } + + // Quotient is (Numerator - Remainder) divided by Denominator. + const SCEV *Q, *R; + const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); + if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) { + // This SCEV does not seem to simplify: fail the division here. + Quotient = Zero; + Remainder = Numerator; + return; + } + divide(SE, Diff, Denominator, &Q, &R); + assert(R == Zero && + "(Numerator - Remainder) should evenly divide Denominator"); + Quotient = Q; + } + +private: + SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, + const SCEV *Denominator) + : SE(S), Denominator(Denominator) { + Zero = SE.getConstant(Denominator->getType(), 0); + One = SE.getConstant(Denominator->getType(), 1); + + // By default, we don't know how to divide Expr by Denominator. + // Providing the default here simplifies the rest of the code. + Quotient = Zero; + Remainder = Numerator; + } + + ScalarEvolution &SE; + const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One; +}; + +} //===----------------------------------------------------------------------===// // Simple SCEV method implementations @@ -829,7 +1096,7 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, ID.AddInteger(scTruncate); ID.AddPointer(Op); ID.AddPointer(Ty); - void *IP = 0; + void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; // Fold if the operand is constant. @@ -850,13 +1117,14 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, return getTruncateOrZeroExtend(SZ->getOperand(), Ty); // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can - // eliminate all the truncates. + // eliminate all the truncates, or we replace other casts with truncates. if (const SCEVAddExpr *SA = dyn_cast(Op)) { SmallVector Operands; bool hasTrunc = false; for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) { const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty); - hasTrunc = isa(S); + if (!isa(SA->getOperand(i))) + hasTrunc = isa(S); Operands.push_back(S); } if (!hasTrunc) @@ -865,13 +1133,14 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, } // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can - // eliminate all the truncates. + // eliminate all the truncates, or we replace other casts with truncates. if (const SCEVMulExpr *SM = dyn_cast(Op)) { SmallVector Operands; bool hasTrunc = false; for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) { const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty); - hasTrunc = isa(S); + if (!isa(SM->getOperand(i))) + hasTrunc = isa(S); Operands.push_back(S); } if (!hasTrunc) @@ -896,6 +1165,262 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, return S; } +// Get the limit of a recurrence such that incrementing by Step cannot cause +// signed overflow as long as the value of the recurrence within the +// loop does not exceed this limit before incrementing. +static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { + unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); + if (SE->isKnownPositive(Step)) { + *Pred = ICmpInst::ICMP_SLT; + return SE->getConstant(APInt::getSignedMinValue(BitWidth) - + SE->getSignedRange(Step).getSignedMax()); + } + if (SE->isKnownNegative(Step)) { + *Pred = ICmpInst::ICMP_SGT; + return SE->getConstant(APInt::getSignedMaxValue(BitWidth) - + SE->getSignedRange(Step).getSignedMin()); + } + return nullptr; +} + +// Get the limit of a recurrence such that incrementing by Step cannot cause +// unsigned overflow as long as the value of the recurrence within the loop does +// not exceed this limit before incrementing. +static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { + unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); + *Pred = ICmpInst::ICMP_ULT; + + return SE->getConstant(APInt::getMinValue(BitWidth) - + SE->getUnsignedRange(Step).getUnsignedMax()); +} + +namespace { + +struct ExtendOpTraitsBase { + typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *); +}; + +// Used to make code generic over signed and unsigned overflow. +template struct ExtendOpTraits { + // Members present: + // + // static const SCEV::NoWrapFlags WrapType; + // + // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr; + // + // static const SCEV *getOverflowLimitForStep(const SCEV *Step, + // ICmpInst::Predicate *Pred, + // ScalarEvolution *SE); +}; + +template <> +struct ExtendOpTraits : public ExtendOpTraitsBase { + static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW; + + static const GetExtendExprTy GetExtendExpr; + + static const SCEV *getOverflowLimitForStep(const SCEV *Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { + return getSignedOverflowLimitForStep(Step, Pred, SE); + } +}; + +const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< + SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr; + +template <> +struct ExtendOpTraits : public ExtendOpTraitsBase { + static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW; + + static const GetExtendExprTy GetExtendExpr; + + static const SCEV *getOverflowLimitForStep(const SCEV *Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { + return getUnsignedOverflowLimitForStep(Step, Pred, SE); + } +}; + +const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< + SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr; +} + +// The recurrence AR has been shown to have no signed/unsigned wrap or something +// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as +// easily prove NSW/NUW for its preincrement or postincrement sibling. This +// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step + +// Start),+,Step} => {(Step + sext/zext(Start),+,Step} As a result, the +// expression "Step + sext/zext(PreIncAR)" is congruent with +// "sext/zext(PostIncAR)" +template +static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, + ScalarEvolution *SE) { + auto WrapType = ExtendOpTraits::WrapType; + auto GetExtendExpr = ExtendOpTraits::GetExtendExpr; + + const Loop *L = AR->getLoop(); + const SCEV *Start = AR->getStart(); + const SCEV *Step = AR->getStepRecurrence(*SE); + + // Check for a simple looking step prior to loop entry. + const SCEVAddExpr *SA = dyn_cast(Start); + if (!SA) + return nullptr; + + // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV + // subtraction is expensive. For this purpose, perform a quick and dirty + // difference, by checking for Step in the operand list. + SmallVector DiffOps; + for (const SCEV *Op : SA->operands()) + if (Op != Step) + DiffOps.push_back(Op); + + if (DiffOps.size() == SA->getNumOperands()) + return nullptr; + + // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` + + // `Step`: + + // 1. NSW/NUW flags on the step increment. + const SCEV *PreStart = SE->getAddExpr(DiffOps, SA->getNoWrapFlags()); + const SCEVAddRecExpr *PreAR = dyn_cast( + SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap)); + + // "{S,+,X} is /" and "the backedge is taken at least once" implies + // "S+X does not sign/unsign-overflow". + // + + const SCEV *BECount = SE->getBackedgeTakenCount(L); + if (PreAR && PreAR->getNoWrapFlags(WrapType) && + !isa(BECount) && SE->isKnownPositive(BECount)) + return PreStart; + + // 2. Direct overflow check on the step operation's expression. + unsigned BitWidth = SE->getTypeSizeInBits(AR->getType()); + Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2); + const SCEV *OperandExtendedStart = + SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy), + (SE->*GetExtendExpr)(Step, WideTy)); + if ((SE->*GetExtendExpr)(Start, WideTy) == OperandExtendedStart) { + if (PreAR && AR->getNoWrapFlags(WrapType)) { + // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW + // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then + // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact. + const_cast(PreAR)->setNoWrapFlags(WrapType); + } + return PreStart; + } + + // 3. Loop precondition. + ICmpInst::Predicate Pred; + const SCEV *OverflowLimit = + ExtendOpTraits::getOverflowLimitForStep(Step, &Pred, SE); + + if (OverflowLimit && + SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) { + return PreStart; + } + return nullptr; +} + +// Get the normalized zero or sign extended expression for this AddRec's Start. +template +static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, + ScalarEvolution *SE) { + auto GetExtendExpr = ExtendOpTraits::GetExtendExpr; + + const SCEV *PreStart = getPreStartForExtend(AR, Ty, SE); + if (!PreStart) + return (SE->*GetExtendExpr)(AR->getStart(), Ty); + + return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty), + (SE->*GetExtendExpr)(PreStart, Ty)); +} + +// Try to prove away overflow by looking at "nearby" add recurrences. A +// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it +// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`. +// +// Formally: +// +// {S,+,X} == {S-T,+,X} + T +// => Ext({S,+,X}) == Ext({S-T,+,X} + T) +// +// If ({S-T,+,X} + T) does not overflow ... (1) +// +// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T) +// +// If {S-T,+,X} does not overflow ... (2) +// +// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T) +// == {Ext(S-T)+Ext(T),+,Ext(X)} +// +// If (S-T)+T does not overflow ... (3) +// +// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)} +// == {Ext(S),+,Ext(X)} == LHS +// +// Thus, if (1), (2) and (3) are true for some T, then +// Ext({S,+,X}) == {Ext(S),+,Ext(X)} +// +// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T) +// does not overflow" restricted to the 0th iteration. Therefore we only need +// to check for (1) and (2). +// +// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T +// is `Delta` (defined below). +// +template +bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, + const SCEV *Step, + const Loop *L) { + auto WrapType = ExtendOpTraits::WrapType; + + // We restrict `Start` to a constant to prevent SCEV from spending too much + // time here. It is correct (but more expensive) to continue with a + // non-constant `Start` and do a general SCEV subtraction to compute + // `PreStart` below. + // + const SCEVConstant *StartC = dyn_cast(Start); + if (!StartC) + return false; + + APInt StartAI = StartC->getValue()->getValue(); + + for (unsigned Delta : {-2, -1, 1, 2}) { + const SCEV *PreStart = getConstant(StartAI - Delta); + + // Give up if we don't already have the add recurrence we need because + // actually constructing an add recurrence is relatively expensive. + const SCEVAddRecExpr *PreAR = [&]() { + FoldingSetNodeID ID; + ID.AddInteger(scAddRecExpr); + ID.AddPointer(PreStart); + ID.AddPointer(Step); + ID.AddPointer(L); + void *IP = nullptr; + return static_cast( + this->UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + }(); + + if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2) + const SCEV *DeltaS = getConstant(StartC->getType(), Delta); + ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; + const SCEV *Limit = ExtendOpTraits::getOverflowLimitForStep( + DeltaS, &Pred, this); + if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1) + return true; + } + } + + return false; +} + const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && @@ -919,7 +1444,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, ID.AddInteger(scZeroExtend); ID.AddPointer(Op); ID.AddPointer(Ty); - void *IP = 0; + void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; // zext(trunc(x)) --> zext(x) or x or trunc(x) @@ -949,9 +1474,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. if (AR->getNoWrapFlags(SCEV::FlagNUW)) - return getAddRecExpr(getZeroExtendExpr(Start, Ty), - getZeroExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are @@ -988,9 +1513,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // Cache knowledge of AR NUW, which is propagated to this AddRec. const_cast(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getZeroExtendExpr(Start, Ty), - getZeroExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } // Similar to above, only this time treat the step value as signed. // This covers loops that count down. @@ -1003,9 +1528,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // Negative step causes unsigned wrap, but it still can't self-wrap. const_cast(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getZeroExtendExpr(Start, Ty), - getSignExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } @@ -1023,9 +1548,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // Cache knowledge of AR NUW, which is propagated to this AddRec. const_cast(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getZeroExtendExpr(Start, Ty), - getZeroExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } else if (isKnownNegative(Step)) { const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - @@ -1038,12 +1563,19 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // Negative step causes unsigned wrap, but it still can't self-wrap. const_cast(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getZeroExtendExpr(Start, Ty), - getSignExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } } + + if (proveNoWrapByVaryingStart(Start, Step, L)) { + const_cast(AR)->setNoWrapFlags(SCEV::FlagNUW); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + } } // The cast wasn't folded; create an explicit cast node. @@ -1055,105 +1587,6 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, return S; } -// Get the limit of a recurrence such that incrementing by Step cannot cause -// signed overflow as long as the value of the recurrence within the loop does -// not exceed this limit before incrementing. -static const SCEV *getOverflowLimitForStep(const SCEV *Step, - ICmpInst::Predicate *Pred, - ScalarEvolution *SE) { - unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); - if (SE->isKnownPositive(Step)) { - *Pred = ICmpInst::ICMP_SLT; - return SE->getConstant(APInt::getSignedMinValue(BitWidth) - - SE->getSignedRange(Step).getSignedMax()); - } - if (SE->isKnownNegative(Step)) { - *Pred = ICmpInst::ICMP_SGT; - return SE->getConstant(APInt::getSignedMaxValue(BitWidth) - - SE->getSignedRange(Step).getSignedMin()); - } - return 0; -} - -// The recurrence AR has been shown to have no signed wrap. Typically, if we can -// prove NSW for AR, then we can just as easily prove NSW for its preincrement -// or postincrement sibling. This allows normalizing a sign extended AddRec as -// such: {sext(Step + Start),+,Step} => {(Step + sext(Start),+,Step} As a -// result, the expression "Step + sext(PreIncAR)" is congruent with -// "sext(PostIncAR)" -static const SCEV *getPreStartForSignExtend(const SCEVAddRecExpr *AR, - Type *Ty, - ScalarEvolution *SE) { - const Loop *L = AR->getLoop(); - const SCEV *Start = AR->getStart(); - const SCEV *Step = AR->getStepRecurrence(*SE); - - // Check for a simple looking step prior to loop entry. - const SCEVAddExpr *SA = dyn_cast(Start); - if (!SA) - return 0; - - // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV - // subtraction is expensive. For this purpose, perform a quick and dirty - // difference, by checking for Step in the operand list. - SmallVector DiffOps; - for (SCEVAddExpr::op_iterator I = SA->op_begin(), E = SA->op_end(); - I != E; ++I) { - if (*I != Step) - DiffOps.push_back(*I); - } - if (DiffOps.size() == SA->getNumOperands()) - return 0; - - // This is a postinc AR. Check for overflow on the preinc recurrence using the - // same three conditions that getSignExtendedExpr checks. - - // 1. NSW flags on the step increment. - const SCEV *PreStart = SE->getAddExpr(DiffOps, SA->getNoWrapFlags()); - const SCEVAddRecExpr *PreAR = dyn_cast( - SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap)); - - if (PreAR && PreAR->getNoWrapFlags(SCEV::FlagNSW)) - return PreStart; - - // 2. Direct overflow check on the step operation's expression. - unsigned BitWidth = SE->getTypeSizeInBits(AR->getType()); - Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2); - const SCEV *OperandExtendedStart = - SE->getAddExpr(SE->getSignExtendExpr(PreStart, WideTy), - SE->getSignExtendExpr(Step, WideTy)); - if (SE->getSignExtendExpr(Start, WideTy) == OperandExtendedStart) { - // Cache knowledge of PreAR NSW. - if (PreAR) - const_cast(PreAR)->setNoWrapFlags(SCEV::FlagNSW); - // FIXME: this optimization needs a unit test - DEBUG(dbgs() << "SCEV: untested prestart overflow check\n"); - return PreStart; - } - - // 3. Loop precondition. - ICmpInst::Predicate Pred; - const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, SE); - - if (OverflowLimit && - SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) { - return PreStart; - } - return 0; -} - -// Get the normalized sign-extended expression for this AddRec's Start. -static const SCEV *getSignExtendAddRecStart(const SCEVAddRecExpr *AR, - Type *Ty, - ScalarEvolution *SE) { - const SCEV *PreStart = getPreStartForSignExtend(AR, Ty, SE); - if (!PreStart) - return SE->getSignExtendExpr(AR->getStart(), Ty); - - return SE->getAddExpr(SE->getSignExtendExpr(AR->getStepRecurrence(*SE), Ty), - SE->getSignExtendExpr(PreStart, Ty)); -} - const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && @@ -1181,7 +1614,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, ID.AddInteger(scSignExtend); ID.AddPointer(Op); ID.AddPointer(Ty); - void *IP = 0; + void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; // If the input value is provably positive, build a zext instead. @@ -1201,6 +1634,23 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, return getTruncateOrSignExtend(X, Ty); } + // sext(C1 + (C2 * x)) --> C1 + sext(C2 * x) if C1 < C2 + if (auto SA = dyn_cast(Op)) { + if (SA->getNumOperands() == 2) { + auto SC1 = dyn_cast(SA->getOperand(0)); + auto SMul = dyn_cast(SA->getOperand(1)); + if (SMul && SC1) { + if (auto SC2 = dyn_cast(SMul->getOperand(0))) { + const APInt &C1 = SC1->getValue()->getValue(); + const APInt &C2 = SC2->getValue()->getValue(); + if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && + C2.ugt(C1) && C2.isPowerOf2()) + return getAddExpr(getSignExtendExpr(SC1, Ty), + getSignExtendExpr(SMul, Ty)); + } + } + } + } // If the input value is a chrec scev, and we can prove that the value // did not overflow the old, smaller, value, we can sign extend all of the // operands (often constants). This allows analysis of something like @@ -1215,9 +1665,9 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. if (AR->getNoWrapFlags(SCEV::FlagNSW)) - return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this), - getSignExtendExpr(Step, Ty), - L, SCEV::FlagNSW); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW); // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are @@ -1254,9 +1704,9 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // Cache knowledge of AR NSW, which is propagated to this AddRec. const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this), - getSignExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } // Similar to above, only this time treat the step value as unsigned. // This covers loops that count up with an unsigned step. @@ -1265,12 +1715,20 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, getMulExpr(WideMaxBECount, getZeroExtendExpr(Step, WideTy))); if (SAdd == OperandExtendedAdd) { - // Cache knowledge of AR NSW, which is propagated to this AddRec. - const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); + // If AR wraps around then + // + // abs(Step) * MaxBECount > unsigned-max(AR->getType()) + // => SAdd != OperandExtendedAdd + // + // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=> + // (SAdd == OperandExtendedAdd => AR is NW) + + const_cast(AR)->setNoWrapFlags(SCEV::FlagNW); + // Return the expression with the addrec on the outside. - return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this), - getZeroExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } @@ -1279,7 +1737,8 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // with the start value and the backedge is guarded by a comparison // with the post-inc value, the addrec is safe. ICmpInst::Predicate Pred; - const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, this); + const SCEV *OverflowLimit = + getSignedOverflowLimitForStep(Step, &Pred, this); if (OverflowLimit && (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) || (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) && @@ -1287,11 +1746,34 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, OverflowLimit)))) { // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec. const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); - return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this), - getSignExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } + // If Start and Step are constants, check if we can apply this + // transformation: + // sext{C1,+,C2} --> C1 + sext{0,+,C2} if C1 < C2 + auto SC1 = dyn_cast(Start); + auto SC2 = dyn_cast(Step); + if (SC1 && SC2) { + const APInt &C1 = SC1->getValue()->getValue(); + const APInt &C2 = SC2->getValue()->getValue(); + if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && + C2.isPowerOf2()) { + Start = getSignExtendExpr(Start, Ty); + const SCEV *NewAR = getAddRecExpr(getConstant(AR->getType(), 0), Step, + L, AR->getNoWrapFlags()); + return getAddExpr(Start, getSignExtendExpr(NewAR, Ty)); + } + } + + if (proveNoWrapByVaryingStart(Start, Step, L)) { + const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + } } // The cast wasn't folded; create an explicit cast node. @@ -1340,9 +1822,8 @@ const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, // Force the cast to be folded into the operands of an addrec. if (const SCEVAddRecExpr *AR = dyn_cast(Op)) { SmallVector Ops; - for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end(); - I != E; ++I) - Ops.push_back(getAnyExtendExpr(*I, Ty)); + for (const SCEV *Op : AR->operands()) + Ops.push_back(getAnyExtendExpr(Op, Ty)); return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW); } @@ -1454,6 +1935,36 @@ namespace { }; } +// We're trying to construct a SCEV of type `Type' with `Ops' as operands and +// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of +// can't-overflow flags for the operation if possible. +static SCEV::NoWrapFlags +StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, + const SmallVectorImpl &Ops, + SCEV::NoWrapFlags OldFlags) { + using namespace std::placeholders; + + bool CanAnalyze = + Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr; + (void)CanAnalyze; + assert(CanAnalyze && "don't call from other places!"); + + int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; + SCEV::NoWrapFlags SignOrUnsignWrap = + ScalarEvolution::maskFlags(OldFlags, SignOrUnsignMask); + + // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. + auto IsKnownNonNegative = + std::bind(std::mem_fn(&ScalarEvolution::isKnownNonNegative), SE, _1); + + if (SignOrUnsignWrap == SCEV::FlagNSW && + std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative)) + return ScalarEvolution::setFlags(OldFlags, + (SCEV::NoWrapFlags)SignOrUnsignMask); + + return OldFlags; +} + /// getAddExpr - Get a canonical add expression, or something simpler if /// possible. const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, @@ -1469,20 +1980,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, "SCEVAddExpr operand types don't match!"); #endif - // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. - // And vice-versa. - int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; - SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask); - if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) { - bool All = true; - for (SmallVectorImpl::const_iterator I = Ops.begin(), - E = Ops.end(); I != E; ++I) - if (!isKnownNonNegative(*I)) { - All = false; - break; - } - if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); - } + Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags); // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, LI); @@ -1811,7 +2309,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, ID.AddInteger(scAddExpr); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); - void *IP = 0; + void *IP = nullptr; SCEVAddExpr *S = static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); if (!S) { @@ -1857,6 +2355,24 @@ static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) { return r; } +/// Determine if any of the operands in this SCEV are a constant or if +/// any of the add or multiply expressions in this SCEV contain a constant. +static bool containsConstantSomewhere(const SCEV *StartExpr) { + SmallVector Ops; + Ops.push_back(StartExpr); + while (!Ops.empty()) { + const SCEV *CurrentExpr = Ops.pop_back_val(); + if (isa(*CurrentExpr)) + return true; + + if (isa(*CurrentExpr) || isa(*CurrentExpr)) { + const auto *CurrentNAry = cast(CurrentExpr); + Ops.append(CurrentNAry->op_begin(), CurrentNAry->op_end()); + } + } + return false; +} + /// getMulExpr - Get a canonical multiply expression, or something simpler if /// possible. const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, @@ -1872,20 +2388,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, "SCEVMulExpr operand types don't match!"); #endif - // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. - // And vice-versa. - int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; - SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask); - if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) { - bool All = true; - for (SmallVectorImpl::const_iterator I = Ops.begin(), - E = Ops.end(); I != E; ++I) - if (!isKnownNonNegative(*I)) { - All = false; - break; - } - if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); - } + Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags); // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, LI); @@ -1896,11 +2399,13 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, // C1*(C2+V) -> C1*C2 + C1*V if (Ops.size() == 2) - if (const SCEVAddExpr *Add = dyn_cast(Ops[1])) - if (Add->getNumOperands() == 2 && - isa(Add->getOperand(0))) - return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)), - getMulExpr(LHSC, Add->getOperand(1))); + if (const SCEVAddExpr *Add = dyn_cast(Ops[1])) + // If any of Add's ops are Adds or Muls with a constant, + // apply this transformation as well. + if (Add->getNumOperands() == 2) + if (containsConstantSomewhere(Add)) + return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)), + getMulExpr(LHSC, Add->getOperand(1))); ++Idx; while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { @@ -2029,71 +2534,66 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, // Okay, if there weren't any loop invariants to be folded, check to see if // there are multiple AddRec's with the same loop induction variable being // multiplied together. If so, we can fold them. + + // {A1,+,A2,+,...,+,An} * {B1,+,B2,+,...,+,Bn} + // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [ + // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z + // ]]],+,...up to x=2n}. + // Note that the arguments to choose() are always integers with values + // known at compile time, never SCEV objects. + // + // The implementation avoids pointless extra computations when the two + // addrec's are of different length (mathematically, it's equivalent to + // an infinite stream of zeros on the right). + bool OpsModified = false; for (unsigned OtherIdx = Idx+1; - OtherIdx < Ops.size() && isa(Ops[OtherIdx]); + OtherIdx != Ops.size() && isa(Ops[OtherIdx]); ++OtherIdx) { - if (AddRecLoop != cast(Ops[OtherIdx])->getLoop()) + const SCEVAddRecExpr *OtherAddRec = + dyn_cast(Ops[OtherIdx]); + if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop) continue; - // {A1,+,A2,+,...,+,An} * {B1,+,B2,+,...,+,Bn} - // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [ - // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z - // ]]],+,...up to x=2n}. - // Note that the arguments to choose() are always integers with values - // known at compile time, never SCEV objects. - // - // The implementation avoids pointless extra computations when the two - // addrec's are of different length (mathematically, it's equivalent to - // an infinite stream of zeros on the right). - bool OpsModified = false; - for (; OtherIdx != Ops.size() && isa(Ops[OtherIdx]); - ++OtherIdx) { - const SCEVAddRecExpr *OtherAddRec = - dyn_cast(Ops[OtherIdx]); - if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop) - continue; - - bool Overflow = false; - Type *Ty = AddRec->getType(); - bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64; - SmallVector AddRecOps; - for (int x = 0, xe = AddRec->getNumOperands() + - OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { - const SCEV *Term = getConstant(Ty, 0); - for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { - uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); - for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), - ze = std::min(x+1, (int)OtherAddRec->getNumOperands()); - z < ze && !Overflow; ++z) { - uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow); - uint64_t Coeff; - if (LargerThan64Bits) - Coeff = umul_ov(Coeff1, Coeff2, Overflow); - else - Coeff = Coeff1*Coeff2; - const SCEV *CoeffTerm = getConstant(Ty, Coeff); - const SCEV *Term1 = AddRec->getOperand(y-z); - const SCEV *Term2 = OtherAddRec->getOperand(z); - Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2)); - } + bool Overflow = false; + Type *Ty = AddRec->getType(); + bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64; + SmallVector AddRecOps; + for (int x = 0, xe = AddRec->getNumOperands() + + OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { + const SCEV *Term = getConstant(Ty, 0); + for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { + uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); + for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), + ze = std::min(x+1, (int)OtherAddRec->getNumOperands()); + z < ze && !Overflow; ++z) { + uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow); + uint64_t Coeff; + if (LargerThan64Bits) + Coeff = umul_ov(Coeff1, Coeff2, Overflow); + else + Coeff = Coeff1*Coeff2; + const SCEV *CoeffTerm = getConstant(Ty, Coeff); + const SCEV *Term1 = AddRec->getOperand(y-z); + const SCEV *Term2 = OtherAddRec->getOperand(z); + Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2)); } - AddRecOps.push_back(Term); - } - if (!Overflow) { - const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(), - SCEV::FlagAnyWrap); - if (Ops.size() == 2) return NewAddRec; - Ops[Idx] = NewAddRec; - Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; - OpsModified = true; - AddRec = dyn_cast(NewAddRec); - if (!AddRec) - break; } + AddRecOps.push_back(Term); + } + if (!Overflow) { + const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(), + SCEV::FlagAnyWrap); + if (Ops.size() == 2) return NewAddRec; + Ops[Idx] = NewAddRec; + Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; + OpsModified = true; + AddRec = dyn_cast(NewAddRec); + if (!AddRec) + break; } - if (OpsModified) - return getMulExpr(Ops); } + if (OpsModified) + return getMulExpr(Ops); // Otherwise couldn't fold anything into this recurrence. Move onto the // next one. @@ -2105,7 +2605,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, ID.AddInteger(scMulExpr); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); - void *IP = 0; + void *IP = nullptr; SCEVMulExpr *S = static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); if (!S) { @@ -2230,7 +2730,7 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, ID.AddInteger(scUDivExpr); ID.AddPointer(LHS); ID.AddPointer(RHS); - void *IP = 0; + void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator), LHS, RHS); @@ -2354,20 +2854,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, // meaningful BE count at this point (and if we don't, we'd be stuck // with a SCEVCouldNotCompute as the cached BE count). - // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. - // And vice-versa. - int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; - SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask); - if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) { - bool All = true; - for (SmallVectorImpl::const_iterator I = Operands.begin(), - E = Operands.end(); I != E; ++I) - if (!isKnownNonNegative(*I)) { - All = false; - break; - } - if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); - } + Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags); // Canonicalize nested AddRecs in by nesting them in order of loop depth. if (const SCEVAddRecExpr *NestedAR = dyn_cast(Operands[0])) { @@ -2425,7 +2912,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, for (unsigned i = 0, e = Operands.size(); i != e; ++i) ID.AddPointer(Operands[i]); ID.AddPointer(L); - void *IP = 0; + void *IP = nullptr; SCEVAddRecExpr *S = static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); if (!S) { @@ -2439,6 +2926,57 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, return S; } +const SCEV * +ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr, + const SmallVectorImpl &IndexExprs, + bool InBounds) { + // getSCEV(Base)->getType() has the same address space as Base->getType() + // because SCEV::getType() preserves the address space. + Type *IntPtrTy = getEffectiveSCEVType(BaseExpr->getType()); + // FIXME(PR23527): Don't blindly transfer the inbounds flag from the GEP + // instruction to its SCEV, because the Instruction may be guarded by control + // flow and the no-overflow bits may not be valid for the expression in any + // context. This can be fixed similarly to how these flags are handled for + // adds. + SCEV::NoWrapFlags Wrap = InBounds ? SCEV::FlagNSW : SCEV::FlagAnyWrap; + + const SCEV *TotalOffset = getConstant(IntPtrTy, 0); + // The address space is unimportant. The first thing we do on CurTy is getting + // its element type. + Type *CurTy = PointerType::getUnqual(PointeeType); + for (const SCEV *IndexExpr : IndexExprs) { + // Compute the (potentially symbolic) offset in bytes for this index. + if (StructType *STy = dyn_cast(CurTy)) { + // For a struct, add the member offset. + ConstantInt *Index = cast(IndexExpr)->getValue(); + unsigned FieldNo = Index->getZExtValue(); + const SCEV *FieldOffset = getOffsetOfExpr(IntPtrTy, STy, FieldNo); + + // Add the field offset to the running total offset. + TotalOffset = getAddExpr(TotalOffset, FieldOffset); + + // Update CurTy to the type of the field at Index. + CurTy = STy->getTypeAtIndex(Index); + } else { + // Update CurTy to its element type. + CurTy = cast(CurTy)->getElementType(); + // For an array, add the element offset, explicitly scaled. + const SCEV *ElementSize = getSizeOfExpr(IntPtrTy, CurTy); + // Getelementptr indices are signed. + IndexExpr = getTruncateOrSignExtend(IndexExpr, IntPtrTy); + + // Multiply the index by the element size to compute the element offset. + const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, Wrap); + + // Add the element offset to the running total offset. + TotalOffset = getAddExpr(TotalOffset, LocalOffset); + } + } + + // Add the total offset from all the GEP indices to the base. + return getAddExpr(BaseExpr, TotalOffset, Wrap); +} + const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) { SmallVector Ops; @@ -2533,7 +3071,7 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { ID.AddInteger(scSMaxExpr); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); - void *IP = 0; + void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; const SCEV **O = SCEVAllocator.Allocate(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); @@ -2637,7 +3175,7 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { ID.AddInteger(scUMaxExpr); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); - void *IP = 0; + void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; const SCEV **O = SCEVAllocator.Allocate(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); @@ -2660,39 +3198,23 @@ const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, } const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { - // If we have DataLayout, we can bypass creating a target-independent + // We can bypass creating a target-independent // constant expression and then folding it back into a ConstantInt. // This is just a compile-time optimization. - if (DL) - return getConstant(IntTy, DL->getTypeAllocSize(AllocTy)); - - Constant *C = ConstantExpr::getSizeOf(AllocTy); - if (ConstantExpr *CE = dyn_cast(C)) - if (Constant *Folded = ConstantFoldConstantExpression(CE, DL, TLI)) - C = Folded; - Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy)); - assert(Ty == IntTy && "Effective SCEV type doesn't match"); - return getTruncateOrZeroExtend(getSCEV(C), Ty); + return getConstant(IntTy, + F->getParent()->getDataLayout().getTypeAllocSize(AllocTy)); } const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo) { - // If we have DataLayout, we can bypass creating a target-independent + // We can bypass creating a target-independent // constant expression and then folding it back into a ConstantInt. // This is just a compile-time optimization. - if (DL) { - return getConstant(IntTy, - DL->getStructLayout(STy)->getElementOffset(FieldNo)); - } - - Constant *C = ConstantExpr::getOffsetOf(STy, FieldNo); - if (ConstantExpr *CE = dyn_cast(C)) - if (Constant *Folded = ConstantFoldConstantExpression(CE, DL, TLI)) - C = Folded; - - Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(STy)); - return getTruncateOrZeroExtend(getSCEV(C), Ty); + return getConstant( + IntTy, + F->getParent()->getDataLayout().getStructLayout(STy)->getElementOffset( + FieldNo)); } const SCEV *ScalarEvolution::getUnknown(Value *V) { @@ -2704,7 +3226,7 @@ const SCEV *ScalarEvolution::getUnknown(Value *V) { FoldingSetNodeID ID; ID.AddInteger(scUnknown); ID.AddPointer(V); - void *IP = 0; + void *IP = nullptr; if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) { assert(cast(S)->getValue() == V && "Stale SCEVUnknown in uniquing map!"); @@ -2734,19 +3256,7 @@ bool ScalarEvolution::isSCEVable(Type *Ty) const { /// for which isSCEVable must return true. uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const { assert(isSCEVable(Ty) && "Type is not SCEVable!"); - - // If we have a DataLayout, use it! - if (DL) - return DL->getTypeSizeInBits(Ty); - - // Integer types have fixed sizes. - if (Ty->isIntegerTy()) - return Ty->getPrimitiveSizeInBits(); - - // The only other support type is pointer. Without DataLayout, conservatively - // assume pointers are 64-bit. - assert(Ty->isPointerTy() && "isSCEVable permitted a non-SCEVable type!"); - return 64; + return F->getParent()->getDataLayout().getTypeSizeInBits(Ty); } /// getEffectiveSCEVType - Return a type with the same bitwidth as @@ -2762,12 +3272,7 @@ Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const { // The only other support type is pointer. assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!"); - - if (DL) - return DL->getIntPtrType(Ty); - - // Without DataLayout, conservatively assume pointers are 64-bit. - return Type::getInt64Ty(getContext()); + return F->getParent()->getDataLayout().getIntPtrType(Ty); } const SCEV *ScalarEvolution::getCouldNotCompute() { @@ -2811,22 +3316,25 @@ bool ScalarEvolution::checkValidity(const SCEV *S) const { const SCEV *ScalarEvolution::getSCEV(Value *V) { assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); + const SCEV *S = getExistingSCEV(V); + if (S == nullptr) { + S = createSCEV(V); + ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S)); + } + return S; +} + +const SCEV *ScalarEvolution::getExistingSCEV(Value *V) { + assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); + ValueExprMapType::iterator I = ValueExprMap.find_as(V); if (I != ValueExprMap.end()) { const SCEV *S = I->second; if (checkValidity(S)) return S; - else - ValueExprMap.erase(I); + ValueExprMap.erase(I); } - const SCEV *S = createSCEV(V); - - // The process of creating a SCEV for V may have caused other SCEVs - // to have been created, so it's necessary to insert the new entry - // from scratch, rather than trying to remember the insert position - // above. - ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S)); - return S; + return nullptr; } /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V @@ -2864,8 +3372,9 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, if (LHS == RHS) return getConstant(LHS->getType(), 0); - // X - Y --> X + -Y - return getAddExpr(LHS, getNegativeSCEV(RHS), Flags); + // X - Y --> X + -Y. + // X -(nsw || nuw) Y --> X + -Y. + return getAddExpr(LHS, getNegativeSCEV(RHS)); } /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the @@ -3010,7 +3519,7 @@ const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { return getPointerBase(Cast->getOperand()); } else if (const SCEVNAryExpr *NAry = dyn_cast(V)) { - const SCEV *PtrOp = 0; + const SCEV *PtrOp = nullptr; for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); I != E; ++I) { if ((*I)->getType()->isPointerTy()) { @@ -3050,7 +3559,8 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { Visited.insert(PN); while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast(I)); @@ -3090,20 +3600,20 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // The loop may have multiple entrances or multiple exits; we can analyze // this phi as an addrec if it has a unique entry value and a unique // backedge value. - Value *BEValueV = 0, *StartValueV = 0; + Value *BEValueV = nullptr, *StartValueV = nullptr; for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { Value *V = PN->getIncomingValue(i); if (L->contains(PN->getIncomingBlock(i))) { if (!BEValueV) { BEValueV = V; } else if (BEValueV != V) { - BEValueV = 0; + BEValueV = nullptr; break; } } else if (!StartValueV) { StartValueV = V; } else if (StartValueV != V) { - StartValueV = 0; + StartValueV = nullptr; break; } } @@ -3152,10 +3662,12 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // If the increment doesn't overflow, then neither the addrec nor // the post-increment will overflow. if (const AddOperator *OBO = dyn_cast(BEValueV)) { - if (OBO->hasNoUnsignedWrap()) - Flags = setFlags(Flags, SCEV::FlagNUW); - if (OBO->hasNoSignedWrap()) - Flags = setFlags(Flags, SCEV::FlagNSW); + if (OBO->getOperand(0) == PN) { + if (OBO->hasNoUnsignedWrap()) + Flags = setFlags(Flags, SCEV::FlagNUW); + if (OBO->hasNoSignedWrap()) + Flags = setFlags(Flags, SCEV::FlagNSW); + } } else if (GEPOperator *GEP = dyn_cast(BEValueV)) { // If the increment is an inbounds GEP, then we know the address // space cannot be wrapped around. We cannot make any guarantee @@ -3163,19 +3675,17 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // unsigned but we may have a negative index from the base // pointer. We can guarantee that no unsigned wrap occurs if the // indices form a positive value. - if (GEP->isInBounds()) { + if (GEP->isInBounds() && GEP->getOperand(0) == PN) { Flags = setFlags(Flags, SCEV::FlagNW); const SCEV *Ptr = getSCEV(GEP->getPointerOperand()); if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr))) Flags = setFlags(Flags, SCEV::FlagNUW); } - } else if (const SubOperator *OBO = - dyn_cast(BEValueV)) { - if (OBO->hasNoUnsignedWrap()) - Flags = setFlags(Flags, SCEV::FlagNUW); - if (OBO->hasNoSignedWrap()) - Flags = setFlags(Flags, SCEV::FlagNSW); + + // We cannot transfer nuw and nsw flags from subtraction + // operations -- sub nuw X, Y is not the same as add nuw X, -Y + // for instance. } const SCEV *StartVal = getSCEV(StartValueV); @@ -3231,7 +3741,8 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // PHI's incoming blocks are in a different loop, in which case doing so // risks breaking LCSSA form. Instcombine would normally zap these, but // it doesn't have DominatorTree information, so it may miss cases. - if (Value *V = SimplifyInstruction(PN, DL, TLI, DT)) + if (Value *V = + SimplifyInstruction(PN, F->getParent()->getDataLayout(), TLI, DT, AC)) if (LI->replacementPreservesLCSSAForm(PN, V)) return getSCEV(V); @@ -3243,52 +3754,16 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { /// operations. This allows them to be analyzed by regular SCEV code. /// const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { - Type *IntPtrTy = getEffectiveSCEVType(GEP->getType()); Value *Base = GEP->getOperand(0); // Don't attempt to analyze GEPs over unsized objects. if (!Base->getType()->getPointerElementType()->isSized()) return getUnknown(GEP); - // Don't blindly transfer the inbounds flag from the GEP instruction to the - // Add expression, because the Instruction may be guarded by control flow - // and the no-overflow bits may not be valid for the expression in any - // context. - SCEV::NoWrapFlags Wrap = GEP->isInBounds() ? SCEV::FlagNSW : SCEV::FlagAnyWrap; - - const SCEV *TotalOffset = getConstant(IntPtrTy, 0); - gep_type_iterator GTI = gep_type_begin(GEP); - for (GetElementPtrInst::op_iterator I = std::next(GEP->op_begin()), - E = GEP->op_end(); - I != E; ++I) { - Value *Index = *I; - // Compute the (potentially symbolic) offset in bytes for this index. - if (StructType *STy = dyn_cast(*GTI++)) { - // For a struct, add the member offset. - unsigned FieldNo = cast(Index)->getZExtValue(); - const SCEV *FieldOffset = getOffsetOfExpr(IntPtrTy, STy, FieldNo); - - // Add the field offset to the running total offset. - TotalOffset = getAddExpr(TotalOffset, FieldOffset); - } else { - // For an array, add the element offset, explicitly scaled. - const SCEV *ElementSize = getSizeOfExpr(IntPtrTy, *GTI); - const SCEV *IndexS = getSCEV(Index); - // Getelementptr indices are signed. - IndexS = getTruncateOrSignExtend(IndexS, IntPtrTy); - - // Multiply the index by the element size to compute the element offset. - const SCEV *LocalOffset = getMulExpr(IndexS, ElementSize, Wrap); - - // Add the element offset to the running total offset. - TotalOffset = getAddExpr(TotalOffset, LocalOffset); - } - } - - // Get the SCEV for the GEP base. - const SCEV *BaseS = getSCEV(Base); - - // Add the total offset from all the GEP indices to the base. - return getAddExpr(BaseS, TotalOffset, Wrap); + SmallVector IndexExprs; + for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index) + IndexExprs.push_back(getSCEV(*Index)); + return getGEPExpr(GEP->getSourceElementType(), getSCEV(Base), IndexExprs, + GEP->isInBounds()); } /// GetMinTrailingZeros - Determine the minimum number of zero bits that S is @@ -3363,7 +3838,8 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { // For a SCEVUnknown, ask ValueTracking. unsigned BitWidth = getTypeSizeInBits(U->getType()); APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - ComputeMaskedBits(U->getValue(), Zeros, Ones); + computeKnownBits(U->getValue(), Zeros, Ones, + F->getParent()->getDataLayout(), 0, AC, nullptr, DT); return Zeros.countTrailingOnes(); } @@ -3371,79 +3847,120 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { return 0; } -/// getUnsignedRange - Determine the unsigned range for a particular SCEV. +/// GetRangeFromMetadata - Helper method to assign a range to V from +/// metadata present in the IR. +static Optional GetRangeFromMetadata(Value *V) { + if (Instruction *I = dyn_cast(V)) { + if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) { + ConstantRange TotalRange( + cast(I->getType())->getBitWidth(), false); + + unsigned NumRanges = MD->getNumOperands() / 2; + assert(NumRanges >= 1); + + for (unsigned i = 0; i < NumRanges; ++i) { + ConstantInt *Lower = + mdconst::extract(MD->getOperand(2 * i + 0)); + ConstantInt *Upper = + mdconst::extract(MD->getOperand(2 * i + 1)); + ConstantRange Range(Lower->getValue(), Upper->getValue()); + TotalRange = TotalRange.unionWith(Range); + } + + return TotalRange; + } + } + + return None; +} + +/// getRange - Determine the range for a particular SCEV. If SignHint is +/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges +/// with a "cleaner" unsigned (resp. signed) representation. /// ConstantRange -ScalarEvolution::getUnsignedRange(const SCEV *S) { +ScalarEvolution::getRange(const SCEV *S, + ScalarEvolution::RangeSignHint SignHint) { + DenseMap &Cache = + SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges + : SignedRanges; + // See if we've computed this range already. - DenseMap::iterator I = UnsignedRanges.find(S); - if (I != UnsignedRanges.end()) + DenseMap::iterator I = Cache.find(S); + if (I != Cache.end()) return I->second; if (const SCEVConstant *C = dyn_cast(S)) - return setUnsignedRange(C, ConstantRange(C->getValue()->getValue())); + return setRange(C, SignHint, ConstantRange(C->getValue()->getValue())); unsigned BitWidth = getTypeSizeInBits(S->getType()); ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); - // If the value has known zeros, the maximum unsigned value will have those - // known zeros as well. + // If the value has known zeros, the maximum value will have those known zeros + // as well. uint32_t TZ = GetMinTrailingZeros(S); - if (TZ != 0) - ConservativeResult = - ConstantRange(APInt::getMinValue(BitWidth), - APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1); + if (TZ != 0) { + if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) + ConservativeResult = + ConstantRange(APInt::getMinValue(BitWidth), + APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1); + else + ConservativeResult = ConstantRange( + APInt::getSignedMinValue(BitWidth), + APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1); + } if (const SCEVAddExpr *Add = dyn_cast(S)) { - ConstantRange X = getUnsignedRange(Add->getOperand(0)); + ConstantRange X = getRange(Add->getOperand(0), SignHint); for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) - X = X.add(getUnsignedRange(Add->getOperand(i))); - return setUnsignedRange(Add, ConservativeResult.intersectWith(X)); + X = X.add(getRange(Add->getOperand(i), SignHint)); + return setRange(Add, SignHint, ConservativeResult.intersectWith(X)); } if (const SCEVMulExpr *Mul = dyn_cast(S)) { - ConstantRange X = getUnsignedRange(Mul->getOperand(0)); + ConstantRange X = getRange(Mul->getOperand(0), SignHint); for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) - X = X.multiply(getUnsignedRange(Mul->getOperand(i))); - return setUnsignedRange(Mul, ConservativeResult.intersectWith(X)); + X = X.multiply(getRange(Mul->getOperand(i), SignHint)); + return setRange(Mul, SignHint, ConservativeResult.intersectWith(X)); } if (const SCEVSMaxExpr *SMax = dyn_cast(S)) { - ConstantRange X = getUnsignedRange(SMax->getOperand(0)); + ConstantRange X = getRange(SMax->getOperand(0), SignHint); for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) - X = X.smax(getUnsignedRange(SMax->getOperand(i))); - return setUnsignedRange(SMax, ConservativeResult.intersectWith(X)); + X = X.smax(getRange(SMax->getOperand(i), SignHint)); + return setRange(SMax, SignHint, ConservativeResult.intersectWith(X)); } if (const SCEVUMaxExpr *UMax = dyn_cast(S)) { - ConstantRange X = getUnsignedRange(UMax->getOperand(0)); + ConstantRange X = getRange(UMax->getOperand(0), SignHint); for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) - X = X.umax(getUnsignedRange(UMax->getOperand(i))); - return setUnsignedRange(UMax, ConservativeResult.intersectWith(X)); + X = X.umax(getRange(UMax->getOperand(i), SignHint)); + return setRange(UMax, SignHint, ConservativeResult.intersectWith(X)); } if (const SCEVUDivExpr *UDiv = dyn_cast(S)) { - ConstantRange X = getUnsignedRange(UDiv->getLHS()); - ConstantRange Y = getUnsignedRange(UDiv->getRHS()); - return setUnsignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y))); + ConstantRange X = getRange(UDiv->getLHS(), SignHint); + ConstantRange Y = getRange(UDiv->getRHS(), SignHint); + return setRange(UDiv, SignHint, + ConservativeResult.intersectWith(X.udiv(Y))); } if (const SCEVZeroExtendExpr *ZExt = dyn_cast(S)) { - ConstantRange X = getUnsignedRange(ZExt->getOperand()); - return setUnsignedRange(ZExt, - ConservativeResult.intersectWith(X.zeroExtend(BitWidth))); + ConstantRange X = getRange(ZExt->getOperand(), SignHint); + return setRange(ZExt, SignHint, + ConservativeResult.intersectWith(X.zeroExtend(BitWidth))); } if (const SCEVSignExtendExpr *SExt = dyn_cast(S)) { - ConstantRange X = getUnsignedRange(SExt->getOperand()); - return setUnsignedRange(SExt, - ConservativeResult.intersectWith(X.signExtend(BitWidth))); + ConstantRange X = getRange(SExt->getOperand(), SignHint); + return setRange(SExt, SignHint, + ConservativeResult.intersectWith(X.signExtend(BitWidth))); } if (const SCEVTruncateExpr *Trunc = dyn_cast(S)) { - ConstantRange X = getUnsignedRange(Trunc->getOperand()); - return setUnsignedRange(Trunc, - ConservativeResult.intersectWith(X.truncate(BitWidth))); + ConstantRange X = getRange(Trunc->getOperand(), SignHint); + return setRange(Trunc, SignHint, + ConservativeResult.intersectWith(X.truncate(BitWidth))); } if (const SCEVAddRecExpr *AddRec = dyn_cast(S)) { @@ -3456,138 +3973,6 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) { ConservativeResult.intersectWith( ConstantRange(C->getValue()->getValue(), APInt(BitWidth, 0))); - // TODO: non-affine addrec - if (AddRec->isAffine()) { - Type *Ty = AddRec->getType(); - const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); - if (!isa(MaxBECount) && - getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { - MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty); - - const SCEV *Start = AddRec->getStart(); - const SCEV *Step = AddRec->getStepRecurrence(*this); - - ConstantRange StartRange = getUnsignedRange(Start); - ConstantRange StepRange = getSignedRange(Step); - ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); - ConstantRange EndRange = - StartRange.add(MaxBECountRange.multiply(StepRange)); - - // Check for overflow. This must be done with ConstantRange arithmetic - // because we could be called from within the ScalarEvolution overflow - // checking code. - ConstantRange ExtStartRange = StartRange.zextOrTrunc(BitWidth*2+1); - ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1); - ConstantRange ExtMaxBECountRange = - MaxBECountRange.zextOrTrunc(BitWidth*2+1); - ConstantRange ExtEndRange = EndRange.zextOrTrunc(BitWidth*2+1); - if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) != - ExtEndRange) - return setUnsignedRange(AddRec, ConservativeResult); - - APInt Min = APIntOps::umin(StartRange.getUnsignedMin(), - EndRange.getUnsignedMin()); - APInt Max = APIntOps::umax(StartRange.getUnsignedMax(), - EndRange.getUnsignedMax()); - if (Min.isMinValue() && Max.isMaxValue()) - return setUnsignedRange(AddRec, ConservativeResult); - return setUnsignedRange(AddRec, - ConservativeResult.intersectWith(ConstantRange(Min, Max+1))); - } - } - - return setUnsignedRange(AddRec, ConservativeResult); - } - - if (const SCEVUnknown *U = dyn_cast(S)) { - // For a SCEVUnknown, ask ValueTracking. - APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - ComputeMaskedBits(U->getValue(), Zeros, Ones, DL); - if (Ones == ~Zeros + 1) - return setUnsignedRange(U, ConservativeResult); - return setUnsignedRange(U, - ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1))); - } - - return setUnsignedRange(S, ConservativeResult); -} - -/// getSignedRange - Determine the signed range for a particular SCEV. -/// -ConstantRange -ScalarEvolution::getSignedRange(const SCEV *S) { - // See if we've computed this range already. - DenseMap::iterator I = SignedRanges.find(S); - if (I != SignedRanges.end()) - return I->second; - - if (const SCEVConstant *C = dyn_cast(S)) - return setSignedRange(C, ConstantRange(C->getValue()->getValue())); - - unsigned BitWidth = getTypeSizeInBits(S->getType()); - ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); - - // If the value has known zeros, the maximum signed value will have those - // known zeros as well. - uint32_t TZ = GetMinTrailingZeros(S); - if (TZ != 0) - ConservativeResult = - ConstantRange(APInt::getSignedMinValue(BitWidth), - APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1); - - if (const SCEVAddExpr *Add = dyn_cast(S)) { - ConstantRange X = getSignedRange(Add->getOperand(0)); - for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) - X = X.add(getSignedRange(Add->getOperand(i))); - return setSignedRange(Add, ConservativeResult.intersectWith(X)); - } - - if (const SCEVMulExpr *Mul = dyn_cast(S)) { - ConstantRange X = getSignedRange(Mul->getOperand(0)); - for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) - X = X.multiply(getSignedRange(Mul->getOperand(i))); - return setSignedRange(Mul, ConservativeResult.intersectWith(X)); - } - - if (const SCEVSMaxExpr *SMax = dyn_cast(S)) { - ConstantRange X = getSignedRange(SMax->getOperand(0)); - for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) - X = X.smax(getSignedRange(SMax->getOperand(i))); - return setSignedRange(SMax, ConservativeResult.intersectWith(X)); - } - - if (const SCEVUMaxExpr *UMax = dyn_cast(S)) { - ConstantRange X = getSignedRange(UMax->getOperand(0)); - for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) - X = X.umax(getSignedRange(UMax->getOperand(i))); - return setSignedRange(UMax, ConservativeResult.intersectWith(X)); - } - - if (const SCEVUDivExpr *UDiv = dyn_cast(S)) { - ConstantRange X = getSignedRange(UDiv->getLHS()); - ConstantRange Y = getSignedRange(UDiv->getRHS()); - return setSignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y))); - } - - if (const SCEVZeroExtendExpr *ZExt = dyn_cast(S)) { - ConstantRange X = getSignedRange(ZExt->getOperand()); - return setSignedRange(ZExt, - ConservativeResult.intersectWith(X.zeroExtend(BitWidth))); - } - - if (const SCEVSignExtendExpr *SExt = dyn_cast(S)) { - ConstantRange X = getSignedRange(SExt->getOperand()); - return setSignedRange(SExt, - ConservativeResult.intersectWith(X.signExtend(BitWidth))); - } - - if (const SCEVTruncateExpr *Trunc = dyn_cast(S)) { - ConstantRange X = getSignedRange(Trunc->getOperand()); - return setSignedRange(Trunc, - ConservativeResult.intersectWith(X.truncate(BitWidth))); - } - - if (const SCEVAddRecExpr *AddRec = dyn_cast(S)) { // If there's no signed wrap, and all the operands have the same sign or // zero, the value won't ever change sign. if (AddRec->getNoWrapFlags(SCEV::FlagNSW)) { @@ -3613,60 +3998,158 @@ ScalarEvolution::getSignedRange(const SCEV *S) { const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); if (!isa(MaxBECount) && getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { + + // Check for overflow. This must be done with ConstantRange arithmetic + // because we could be called from within the ScalarEvolution overflow + // checking code. + MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty); + ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); + ConstantRange ZExtMaxBECountRange = + MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1); const SCEV *Start = AddRec->getStart(); const SCEV *Step = AddRec->getStepRecurrence(*this); + ConstantRange StepSRange = getSignedRange(Step); + ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1); + + ConstantRange StartURange = getUnsignedRange(Start); + ConstantRange EndURange = + StartURange.add(MaxBECountRange.multiply(StepSRange)); + + // Check for unsigned overflow. + ConstantRange ZExtStartURange = + StartURange.zextOrTrunc(BitWidth * 2 + 1); + ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1); + if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == + ZExtEndURange) { + APInt Min = APIntOps::umin(StartURange.getUnsignedMin(), + EndURange.getUnsignedMin()); + APInt Max = APIntOps::umax(StartURange.getUnsignedMax(), + EndURange.getUnsignedMax()); + bool IsFullRange = Min.isMinValue() && Max.isMaxValue(); + if (!IsFullRange) + ConservativeResult = + ConservativeResult.intersectWith(ConstantRange(Min, Max + 1)); + } - ConstantRange StartRange = getSignedRange(Start); - ConstantRange StepRange = getSignedRange(Step); - ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); - ConstantRange EndRange = - StartRange.add(MaxBECountRange.multiply(StepRange)); - - // Check for overflow. This must be done with ConstantRange arithmetic - // because we could be called from within the ScalarEvolution overflow - // checking code. - ConstantRange ExtStartRange = StartRange.sextOrTrunc(BitWidth*2+1); - ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1); - ConstantRange ExtMaxBECountRange = - MaxBECountRange.zextOrTrunc(BitWidth*2+1); - ConstantRange ExtEndRange = EndRange.sextOrTrunc(BitWidth*2+1); - if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) != - ExtEndRange) - return setSignedRange(AddRec, ConservativeResult); - - APInt Min = APIntOps::smin(StartRange.getSignedMin(), - EndRange.getSignedMin()); - APInt Max = APIntOps::smax(StartRange.getSignedMax(), - EndRange.getSignedMax()); - if (Min.isMinSignedValue() && Max.isMaxSignedValue()) - return setSignedRange(AddRec, ConservativeResult); - return setSignedRange(AddRec, - ConservativeResult.intersectWith(ConstantRange(Min, Max+1))); + ConstantRange StartSRange = getSignedRange(Start); + ConstantRange EndSRange = + StartSRange.add(MaxBECountRange.multiply(StepSRange)); + + // Check for signed overflow. This must be done with ConstantRange + // arithmetic because we could be called from within the ScalarEvolution + // overflow checking code. + ConstantRange SExtStartSRange = + StartSRange.sextOrTrunc(BitWidth * 2 + 1); + ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1); + if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == + SExtEndSRange) { + APInt Min = APIntOps::smin(StartSRange.getSignedMin(), + EndSRange.getSignedMin()); + APInt Max = APIntOps::smax(StartSRange.getSignedMax(), + EndSRange.getSignedMax()); + bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue(); + if (!IsFullRange) + ConservativeResult = + ConservativeResult.intersectWith(ConstantRange(Min, Max + 1)); + } } } - return setSignedRange(AddRec, ConservativeResult); + return setRange(AddRec, SignHint, ConservativeResult); } if (const SCEVUnknown *U = dyn_cast(S)) { - // For a SCEVUnknown, ask ValueTracking. - if (!U->getValue()->getType()->isIntegerTy() && !DL) - return setSignedRange(U, ConservativeResult); - unsigned NS = ComputeNumSignBits(U->getValue(), DL); - if (NS <= 1) - return setSignedRange(U, ConservativeResult); - return setSignedRange(U, ConservativeResult.intersectWith( - ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1), - APInt::getSignedMaxValue(BitWidth).ashr(NS - 1)+1))); - } + // Check if the IR explicitly contains !range metadata. + Optional MDRange = GetRangeFromMetadata(U->getValue()); + if (MDRange.hasValue()) + ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue()); + + // Split here to avoid paying the compile-time cost of calling both + // computeKnownBits and ComputeNumSignBits. This restriction can be lifted + // if needed. + const DataLayout &DL = F->getParent()->getDataLayout(); + if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) { + // For a SCEVUnknown, ask ValueTracking. + APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); + computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT); + if (Ones != ~Zeros + 1) + ConservativeResult = + ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1)); + } else { + assert(SignHint == ScalarEvolution::HINT_RANGE_SIGNED && + "generalize as needed!"); + unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AC, nullptr, DT); + if (NS > 1) + ConservativeResult = ConservativeResult.intersectWith( + ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1), + APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1)); + } - return setSignedRange(S, ConservativeResult); + return setRange(U, SignHint, ConservativeResult); + } + + return setRange(S, SignHint, ConservativeResult); +} + +SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) { + const BinaryOperator *BinOp = cast(V); + + // Return early if there are no flags to propagate to the SCEV. + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; + if (BinOp->hasNoUnsignedWrap()) + Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); + if (BinOp->hasNoSignedWrap()) + Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); + if (Flags == SCEV::FlagAnyWrap) { + return SCEV::FlagAnyWrap; + } + + // Here we check that BinOp is in the header of the innermost loop + // containing BinOp, since we only deal with instructions in the loop + // header. The actual loop we need to check later will come from an add + // recurrence, but getting that requires computing the SCEV of the operands, + // which can be expensive. This check we can do cheaply to rule out some + // cases early. + Loop *innermostContainingLoop = LI->getLoopFor(BinOp->getParent()); + if (innermostContainingLoop == nullptr || + innermostContainingLoop->getHeader() != BinOp->getParent()) + return SCEV::FlagAnyWrap; + + // Only proceed if we can prove that BinOp does not yield poison. + if (!isKnownNotFullPoison(BinOp)) return SCEV::FlagAnyWrap; + + // At this point we know that if V is executed, then it does not wrap + // according to at least one of NSW or NUW. If V is not executed, then we do + // not know if the calculation that V represents would wrap. Multiple + // instructions can map to the same SCEV. If we apply NSW or NUW from V to + // the SCEV, we must guarantee no wrapping for that SCEV also when it is + // derived from other instructions that map to the same SCEV. We cannot make + // that guarantee for cases where V is not executed. So we need to find the + // loop that V is considered in relation to and prove that V is executed for + // every iteration of that loop. That implies that the value that V + // calculates does not wrap anywhere in the loop, so then we can apply the + // flags to the SCEV. + // + // We check isLoopInvariant to disambiguate in case we are adding two + // recurrences from different loops, so that we know which loop to prove + // that V is executed in. + for (int OpIndex = 0; OpIndex < 2; ++OpIndex) { + const SCEV *Op = getSCEV(BinOp->getOperand(OpIndex)); + if (auto *AddRec = dyn_cast(Op)) { + const int OtherOpIndex = 1 - OpIndex; + const SCEV *OtherOp = getSCEV(BinOp->getOperand(OtherOpIndex)); + if (isLoopInvariant(OtherOp, AddRec->getLoop()) && + isGuaranteedToExecuteForEveryIteration(BinOp, AddRec->getLoop())) + return Flags; + } + } + return SCEV::FlagAnyWrap; } -/// createSCEV - We know that there is no SCEV for the specified value. -/// Analyze the expression. +/// createSCEV - We know that there is no SCEV for the specified value. Analyze +/// the expression. /// const SCEV *ScalarEvolution::createSCEV(Value *V) { if (!isSCEVable(V->getType())) @@ -3703,29 +4186,52 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // Instead, gather up all the operands and make a single getAddExpr call. // LLVM IR canonical form means we need only traverse the left operands. // - // Don't apply this instruction's NSW or NUW flags to the new - // expression. The instruction may be guarded by control flow that the - // no-wrap behavior depends on. Non-control-equivalent instructions can be - // mapped to the same SCEV expression, and it would be incorrect to transfer - // NSW/NUW semantics to those operations. + // FIXME: Expand this handling of NSW and NUW to other instructions, like + // sub and mul. SmallVector AddOps; - AddOps.push_back(getSCEV(U->getOperand(1))); - for (Value *Op = U->getOperand(0); ; Op = U->getOperand(0)) { - unsigned Opcode = Op->getValueID() - Value::InstructionVal; - if (Opcode != Instruction::Add && Opcode != Instruction::Sub) + for (Value *Op = U;; Op = U->getOperand(0)) { + U = dyn_cast(Op); + unsigned Opcode = U ? U->getOpcode() : 0; + if (!U || (Opcode != Instruction::Add && Opcode != Instruction::Sub)) { + assert(Op != V && "V should be an add"); + AddOps.push_back(getSCEV(Op)); break; - U = cast(Op); + } + + if (auto *OpSCEV = getExistingSCEV(Op)) { + AddOps.push_back(OpSCEV); + break; + } + + // If a NUW or NSW flag can be applied to the SCEV for this + // addition, then compute the SCEV for this addition by itself + // with a separate call to getAddExpr. We need to do that + // instead of pushing the operands of the addition onto AddOps, + // since the flags are only known to apply to this particular + // addition - they may not apply to other additions that can be + // formed with operands from AddOps. + // + // FIXME: Expand this to sub instructions. + if (Opcode == Instruction::Add && isa(U)) { + SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U); + if (Flags != SCEV::FlagAnyWrap) { + AddOps.push_back(getAddExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1)), Flags)); + break; + } + } + const SCEV *Op1 = getSCEV(U->getOperand(1)); if (Opcode == Instruction::Sub) AddOps.push_back(getNegativeSCEV(Op1)); else AddOps.push_back(Op1); } - AddOps.push_back(getSCEV(U->getOperand(0))); return getAddExpr(AddOps); } + case Instruction::Mul: { - // Don't transfer NSW/NUW for the same reason as AddExpr. + // FIXME: Transfer NSW/NUW as in AddExpr. SmallVector MulOps; MulOps.push_back(getSCEV(U->getOperand(1))); for (Value *Op = U->getOperand(0); @@ -3755,13 +4261,14 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // Instcombine's ShrinkDemandedConstant may strip bits out of // constants, obscuring what would otherwise be a low-bits mask. - // Use ComputeMaskedBits to compute what ShrinkDemandedConstant + // Use computeKnownBits to compute what ShrinkDemandedConstant // knew about to reconstruct a low-bits mask value. unsigned LZ = A.countLeadingZeros(); unsigned TZ = A.countTrailingZeros(); unsigned BitWidth = A.getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - ComputeMaskedBits(U->getOperand(0), KnownZero, KnownOne, DL); + computeKnownBits(U->getOperand(0), KnownZero, KnownOne, + F->getParent()->getDataLayout(), 0, AC, nullptr, DT); APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); @@ -3952,9 +4459,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { case ICmpInst::ICMP_SGE: // a >s b ? a+x : b+x -> smax(a, b)+x // a >s b ? b+x : a+x -> smin(a, b)+x - if (LHS->getType() == U->getType()) { - const SCEV *LS = getSCEV(LHS); - const SCEV *RS = getSCEV(RHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType())) { + const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), U->getType()); + const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, LS); @@ -3975,9 +4483,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { case ICmpInst::ICMP_UGE: // a >u b ? a+x : b+x -> umax(a, b)+x // a >u b ? b+x : a+x -> umin(a, b)+x - if (LHS->getType() == U->getType()) { - const SCEV *LS = getSCEV(LHS); - const SCEV *RS = getSCEV(RHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType())) { + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); + const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, LS); @@ -3992,11 +4501,11 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { break; case ICmpInst::ICMP_NE: // n != 0 ? n+x : 1+x -> umax(n, 1)+x - if (LHS->getType() == U->getType() && - isa(RHS) && - cast(RHS)->isZero()) { - const SCEV *One = getConstant(LHS->getType(), 1); - const SCEV *LS = getSCEV(LHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType()) && + isa(RHS) && cast(RHS)->isZero()) { + const SCEV *One = getConstant(U->getType(), 1); + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, LS); @@ -4007,11 +4516,11 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { break; case ICmpInst::ICMP_EQ: // n == 0 ? 1+x : n+x -> umax(n, 1)+x - if (LHS->getType() == U->getType() && - isa(RHS) && - cast(RHS)->isZero()) { - const SCEV *One = getConstant(LHS->getType(), 1); - const SCEV *LS = getSCEV(LHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType()) && + isa(RHS) && cast(RHS)->isZero()) { + const SCEV *One = getConstant(U->getType(), 1); + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, One); @@ -4038,6 +4547,14 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // Iteration Count Computation Code // +unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) { + if (BasicBlock *ExitingBB = L->getExitingBlock()) + return getSmallConstantTripCount(L, ExitingBB); + + // No trip count information for multiple exits. + return 0; +} + /// getSmallConstantTripCount - Returns the maximum trip count of this loop as a /// normal unsigned value. Returns 0 if the trip count is unknown or not /// constant. Will also return 0 if the maximum trip count is very large (>= @@ -4048,19 +4565,13 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { /// before taking the branch. For loops with multiple exits, it may not be the /// number times that the loop header executes because the loop may exit /// prematurely via another branch. -/// -/// FIXME: We conservatively call getBackedgeTakenCount(L) instead of -/// getExitCount(L, ExitingBlock) to compute a safe trip count considering all -/// loop exits. getExitCount() may return an exact count for this branch -/// assuming no-signed-wrap. The number of well-defined iterations may actually -/// be higher than this trip count if this exit test is skipped and the loop -/// exits via a different branch. Ideally, getExitCount() would know whether it -/// depends on a NSW assumption, and we would only fall back to a conservative -/// trip count in that case. -unsigned ScalarEvolution:: -getSmallConstantTripCount(Loop *L, BasicBlock * /*ExitingBlock*/) { +unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L, + BasicBlock *ExitingBlock) { + assert(ExitingBlock && "Must pass a non-null exiting block!"); + assert(L->isLoopExiting(ExitingBlock) && + "Exiting block must actually branch out of the loop!"); const SCEVConstant *ExitCount = - dyn_cast(getBackedgeTakenCount(L)); + dyn_cast(getExitCount(L, ExitingBlock)); if (!ExitCount) return 0; @@ -4074,6 +4585,14 @@ getSmallConstantTripCount(Loop *L, BasicBlock * /*ExitingBlock*/) { return ((unsigned)ExitConst->getZExtValue()) + 1; } +unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) { + if (BasicBlock *ExitingBB = L->getExitingBlock()) + return getSmallConstantTripMultiple(L, ExitingBB); + + // No trip multiple information for multiple exits. + return 0; +} + /// getSmallConstantTripMultiple - Returns the largest constant divisor of the /// trip count of this loop as a normal unsigned value, if possible. This /// means that the actual trip count is always a multiple of the returned @@ -4086,9 +4605,13 @@ getSmallConstantTripCount(Loop *L, BasicBlock * /*ExitingBlock*/) { /// /// As explained in the comments for getSmallConstantTripCount, this assumes /// that control exits the loop via ExitingBlock. -unsigned ScalarEvolution:: -getSmallConstantTripMultiple(Loop *L, BasicBlock * /*ExitingBlock*/) { - const SCEV *ExitCount = getBackedgeTakenCount(L); +unsigned +ScalarEvolution::getSmallConstantTripMultiple(Loop *L, + BasicBlock *ExitingBlock) { + assert(ExitingBlock && "Must pass a non-null exiting block!"); + assert(L->isLoopExiting(ExitingBlock) && + "Exiting block must actually branch out of the loop!"); + const SCEV *ExitCount = getExitCount(L, ExitingBlock); if (ExitCount == getCouldNotCompute()) return 1; @@ -4198,7 +4721,8 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { SmallPtrSet Visited; while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast(I)); @@ -4250,7 +4774,8 @@ void ScalarEvolution::forgetLoop(const Loop *L) { SmallPtrSet Visited; while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast(I)); @@ -4284,7 +4809,8 @@ void ScalarEvolution::forgetValue(Value *V) { SmallPtrSet Visited; while (!Worklist.empty()) { I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast(I)); @@ -4300,12 +4826,12 @@ void ScalarEvolution::forgetValue(Value *V) { } /// getExact - Get the exact loop backedge taken count considering all loop -/// exits. A computable result can only be return for loops with a single exit. -/// Returning the minimum taken count among all exits is incorrect because one -/// of the loop's exit limit's may have been skipped. HowFarToZero assumes that -/// the limit of each loop test is never skipped. This is a valid assumption as -/// long as the loop exits via that test. For precise results, it is the -/// caller's responsibility to specify the relevant loop exit using +/// exits. A computable result can only be returned for loops with a single +/// exit. Returning the minimum taken count among all exits is incorrect +/// because one of the loop's exit limit's may have been skipped. HowFarToZero +/// assumes that the limit of each loop test is never skipped. This is a valid +/// assumption as long as the loop exits via that test. For precise results, it +/// is the caller's responsibility to specify the relevant loop exit using /// getExact(ExitingBlock, SE). const SCEV * ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const { @@ -4316,9 +4842,9 @@ ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const { if (!ExitNotTaken.ExitingBlock) return SE->getCouldNotCompute(); assert(ExitNotTaken.ExactNotTaken && "uninitialized not-taken info"); - const SCEV *BECount = 0; + const SCEV *BECount = nullptr; for (const ExitNotTakenInfo *ENT = &ExitNotTaken; - ENT != 0; ENT = ENT->getNextExit()) { + ENT != nullptr; ENT = ENT->getNextExit()) { assert(ENT->ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV"); @@ -4336,7 +4862,7 @@ const SCEV * ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock, ScalarEvolution *SE) const { for (const ExitNotTakenInfo *ENT = &ExitNotTaken; - ENT != 0; ENT = ENT->getNextExit()) { + ENT != nullptr; ENT = ENT->getNextExit()) { if (ENT->ExitingBlock == ExitingBlock) return ENT->ExactNotTaken; @@ -4359,7 +4885,7 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S, return false; for (const ExitNotTakenInfo *ENT = &ExitNotTaken; - ENT != 0; ENT = ENT->getNextExit()) { + ENT != nullptr; ENT = ENT->getNextExit()) { if (ENT->ExactNotTaken != SE->getCouldNotCompute() && SE->hasOperand(ENT->ExactNotTaken, S)) { @@ -4398,8 +4924,8 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( /// clear - Invalidate this result and free the ExitNotTakenInfo array. void ScalarEvolution::BackedgeTakenInfo::clear() { - ExitNotTaken.ExitingBlock = 0; - ExitNotTaken.ExactNotTaken = 0; + ExitNotTaken.ExitingBlock = nullptr; + ExitNotTaken.ExactNotTaken = nullptr; delete[] ExitNotTaken.getNextExit(); } @@ -4410,38 +4936,55 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { SmallVector ExitingBlocks; L->getExitingBlocks(ExitingBlocks); - // Examine all exits and pick the most conservative values. - const SCEV *MaxBECount = getCouldNotCompute(); + SmallVector, 4> ExitCounts; bool CouldComputeBECount = true; BasicBlock *Latch = L->getLoopLatch(); // may be NULL. - const SCEV *LatchMaxCount = 0; - SmallVector, 4> ExitCounts; + const SCEV *MustExitMaxBECount = nullptr; + const SCEV *MayExitMaxBECount = nullptr; + + // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts + // and compute maxBECount. for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { - ExitLimit EL = ComputeExitLimit(L, ExitingBlocks[i]); + BasicBlock *ExitBB = ExitingBlocks[i]; + ExitLimit EL = ComputeExitLimit(L, ExitBB); + + // 1. For each exit that can be computed, add an entry to ExitCounts. + // CouldComputeBECount is true only if all exits can be computed. if (EL.Exact == getCouldNotCompute()) // We couldn't compute an exact value for this exit, so // we won't be able to compute an exact value for the loop. CouldComputeBECount = false; else - ExitCounts.push_back(std::make_pair(ExitingBlocks[i], EL.Exact)); - - if (MaxBECount == getCouldNotCompute()) - MaxBECount = EL.Max; - else if (EL.Max != getCouldNotCompute()) { - // We cannot take the "min" MaxBECount, because non-unit stride loops may - // skip some loop tests. Taking the max over the exits is sufficiently - // conservative. TODO: We could do better taking into consideration - // non-latch exits that dominate the latch. - if (EL.MustExit && ExitingBlocks[i] == Latch) - LatchMaxCount = EL.Max; - else - MaxBECount = getUMaxFromMismatchedTypes(MaxBECount, EL.Max); + ExitCounts.push_back(std::make_pair(ExitBB, EL.Exact)); + + // 2. Derive the loop's MaxBECount from each exit's max number of + // non-exiting iterations. Partition the loop exits into two kinds: + // LoopMustExits and LoopMayExits. + // + // If the exit dominates the loop latch, it is a LoopMustExit otherwise it + // is a LoopMayExit. If any computable LoopMustExit is found, then + // MaxBECount is the minimum EL.Max of computable LoopMustExits. Otherwise, + // MaxBECount is conservatively the maximum EL.Max, where CouldNotCompute is + // considered greater than any computable EL.Max. + if (EL.Max != getCouldNotCompute() && Latch && + DT->dominates(ExitBB, Latch)) { + if (!MustExitMaxBECount) + MustExitMaxBECount = EL.Max; + else { + MustExitMaxBECount = + getUMinFromMismatchedTypes(MustExitMaxBECount, EL.Max); + } + } else if (MayExitMaxBECount != getCouldNotCompute()) { + if (!MayExitMaxBECount || EL.Max == getCouldNotCompute()) + MayExitMaxBECount = EL.Max; + else { + MayExitMaxBECount = + getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.Max); + } } } - // Be more precise in the easy case of a loop latch that must exit. - if (LatchMaxCount) { - MaxBECount = getUMinFromMismatchedTypes(MaxBECount, LatchMaxCount); - } + const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount : + (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute()); return BackedgeTakenInfo(ExitCounts, CouldComputeBECount, MaxBECount); } @@ -4454,7 +4997,7 @@ ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { // exit at this block and remember the exit block and whether all other targets // lead to the loop header. bool MustExecuteLoopHeader = true; - BasicBlock *Exit = 0; + BasicBlock *Exit = nullptr; for (succ_iterator SI = succ_begin(ExitingBlock), SE = succ_end(ExitingBlock); SI != SE; ++SI) if (!L->contains(*SI)) { @@ -4491,8 +5034,7 @@ ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { if (!Pred) return getCouldNotCompute(); TerminatorInst *PredTerm = Pred->getTerminator(); - for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) { - BasicBlock *PredSucc = PredTerm->getSuccessor(i); + for (const BasicBlock *PredSucc : PredTerm->successors()) { if (PredSucc == BB) continue; // If the predecessor has a successor that isn't BB and isn't @@ -4510,18 +5052,19 @@ ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { return getCouldNotCompute(); } + bool IsOnlyExit = (L->getExitingBlock() != nullptr); TerminatorInst *Term = ExitingBlock->getTerminator(); if (BranchInst *BI = dyn_cast(Term)) { assert(BI->isConditional() && "If unconditional, it can't be in loop!"); // Proceed to the next level to examine the exit condition expression. return ComputeExitLimitFromCond(L, BI->getCondition(), BI->getSuccessor(0), BI->getSuccessor(1), - /*IsSubExpr=*/false); + /*ControlsExit=*/IsOnlyExit); } if (SwitchInst *SI = dyn_cast(Term)) return ComputeExitLimitFromSingleExitSwitch(L, SI, Exit, - /*IsSubExpr=*/false); + /*ControlsExit=*/IsOnlyExit); return getCouldNotCompute(); } @@ -4530,28 +5073,27 @@ ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { /// backedge of the specified loop will execute if its exit condition /// were a conditional branch of ExitCond, TBB, and FBB. /// -/// @param IsSubExpr is true if ExitCond does not directly control the exit -/// branch. In this case, we cannot assume that the loop only exits when the -/// condition is true and cannot infer that failing to meet the condition prior -/// to integer wraparound results in undefined behavior. +/// @param ControlsExit is true if ExitCond directly controls the exit +/// branch. In this case, we can assume that the loop exits only if the +/// condition is true and can infer that failing to meet the condition prior to +/// integer wraparound results in undefined behavior. ScalarEvolution::ExitLimit ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, Value *ExitCond, BasicBlock *TBB, BasicBlock *FBB, - bool IsSubExpr) { + bool ControlsExit) { // Check if the controlling expression for this loop is an And or Or. if (BinaryOperator *BO = dyn_cast(ExitCond)) { if (BO->getOpcode() == Instruction::And) { // Recurse on the operands of the and. bool EitherMayExit = L->contains(TBB); ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - IsSubExpr || EitherMayExit); + ControlsExit && !EitherMayExit); ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - IsSubExpr || EitherMayExit); + ControlsExit && !EitherMayExit); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); - bool MustExit = false; if (EitherMayExit) { // Both conditions must be true for the loop to continue executing. // Choose the less conservative count. @@ -4566,7 +5108,6 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, MaxBECount = EL0.Max; else MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); - MustExit = EL0.MustExit || EL1.MustExit; } else { // Both conditions must be true at the same time for the loop to exit. // For now, be conservative. @@ -4575,21 +5116,19 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, MaxBECount = EL0.Max; if (EL0.Exact == EL1.Exact) BECount = EL0.Exact; - MustExit = EL0.MustExit && EL1.MustExit; } - return ExitLimit(BECount, MaxBECount, MustExit); + return ExitLimit(BECount, MaxBECount); } if (BO->getOpcode() == Instruction::Or) { // Recurse on the operands of the or. bool EitherMayExit = L->contains(FBB); ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - IsSubExpr || EitherMayExit); + ControlsExit && !EitherMayExit); ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - IsSubExpr || EitherMayExit); + ControlsExit && !EitherMayExit); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); - bool MustExit = false; if (EitherMayExit) { // Both conditions must be false for the loop to continue executing. // Choose the less conservative count. @@ -4604,7 +5143,6 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, MaxBECount = EL0.Max; else MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); - MustExit = EL0.MustExit || EL1.MustExit; } else { // Both conditions must be false at the same time for the loop to exit. // For now, be conservative. @@ -4613,17 +5151,16 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, MaxBECount = EL0.Max; if (EL0.Exact == EL1.Exact) BECount = EL0.Exact; - MustExit = EL0.MustExit && EL1.MustExit; } - return ExitLimit(BECount, MaxBECount, MustExit); + return ExitLimit(BECount, MaxBECount); } } // With an icmp, it may be feasible to compute an exact backedge-taken count. // Proceed to the next level to examine the icmp. if (ICmpInst *ExitCondICmp = dyn_cast(ExitCond)) - return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, IsSubExpr); + return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit); // Check for a constant condition. These are normally stripped out by // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to @@ -4650,7 +5187,7 @@ ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond, BasicBlock *TBB, BasicBlock *FBB, - bool IsSubExpr) { + bool ControlsExit) { // If the condition was exit on true, convert the condition to exit on false ICmpInst::Predicate Cond; @@ -4702,7 +5239,7 @@ ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, switch (Cond) { case ICmpInst::ICMP_NE: { // while (X != Y) // Convert to: while (X-Y != 0) - ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, IsSubExpr); + ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); if (EL.hasAnyInfo()) return EL; break; } @@ -4715,14 +5252,14 @@ ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_ULT: { // while (X < Y) bool IsSigned = Cond == ICmpInst::ICMP_SLT; - ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, IsSubExpr); + ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, ControlsExit); if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_UGT: { // while (X > Y) bool IsSigned = Cond == ICmpInst::ICMP_SGT; - ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, IsSubExpr); + ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit); if (EL.hasAnyInfo()) return EL; break; } @@ -4744,7 +5281,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::ComputeExitLimitFromSingleExitSwitch(const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBlock, - bool IsSubExpr) { + bool ControlsExit) { assert(!L->contains(ExitingBlock) && "Not an exiting block!"); // Give up if the exit is the default dest of a switch. @@ -4757,7 +5294,7 @@ ScalarEvolution::ComputeExitLimitFromSingleExitSwitch(const Loop *L, const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); // while (X != Y) --> while (X-Y != 0) - ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, IsSubExpr); + ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); if (EL.hasAnyInfo()) return EL; @@ -4800,7 +5337,7 @@ ScalarEvolution::ComputeLoadConstantCompareExitLimit( return getCouldNotCompute(); // Okay, we allow one non-constant index into the GEP instruction. - Value *VarIdx = 0; + Value *VarIdx = nullptr; std::vector Indexes; unsigned VarIdxNum = 0; for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i) @@ -4810,7 +5347,7 @@ ScalarEvolution::ComputeLoadConstantCompareExitLimit( if (VarIdx) return getCouldNotCompute(); // Multiple non-constant idx's. VarIdx = GEP->getOperand(i); VarIdxNum = i-2; - Indexes.push_back(0); + Indexes.push_back(nullptr); } // Loop-invariant loads may be a byproduct of loop optimization. Skip them. @@ -4841,7 +5378,7 @@ ScalarEvolution::ComputeLoadConstantCompareExitLimit( Constant *Result = ConstantFoldLoadThroughGEPIndices(GV->getInitializer(), Indexes); - if (Result == 0) break; // Cannot compute! + if (!Result) break; // Cannot compute! // Evaluate the condition for this iteration. Result = ConstantExpr::getICmp(predicate, Result, RHS); @@ -4881,12 +5418,9 @@ static bool canConstantEvolve(Instruction *I, const Loop *L) { if (!L->contains(I)) return false; if (isa(I)) { - if (L->getHeader() == I->getParent()) - return true; - else - // We don't currently keep track of the control flow needed to evaluate - // PHIs, so we cannot handle PHIs inside of loops. - return false; + // We don't currently keep track of the control flow needed to evaluate + // PHIs, so we cannot handle PHIs inside of loops. + return L->getHeader() == I->getParent(); } // If we won't be able to constant fold this expression even if the operands @@ -4902,14 +5436,14 @@ getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, // Otherwise, we can evaluate this instruction if all of its operands are // constant or derived from a PHI node themselves. - PHINode *PHI = 0; + PHINode *PHI = nullptr; for (Instruction::op_iterator OpI = UseInst->op_begin(), OpE = UseInst->op_end(); OpI != OpE; ++OpI) { if (isa(*OpI)) continue; Instruction *OpInst = dyn_cast(*OpI); - if (!OpInst || !canConstantEvolve(OpInst, L)) return 0; + if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr; PHINode *P = dyn_cast(OpInst); if (!P) @@ -4923,8 +5457,10 @@ getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap); PHIMap[OpInst] = P; } - if (P == 0) return 0; // Not evolving from PHI - if (PHI && PHI != P) return 0; // Evolving from multiple different PHIs. + if (!P) + return nullptr; // Not evolving from PHI + if (PHI && PHI != P) + return nullptr; // Evolving from multiple different PHIs. PHI = P; } // This is a expression evolving from a constant PHI! @@ -4938,7 +5474,7 @@ getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, /// constraints, return null. static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { Instruction *I = dyn_cast(V); - if (I == 0 || !canConstantEvolve(I, L)) return 0; + if (!I || !canConstantEvolve(I, L)) return nullptr; if (PHINode *PN = dyn_cast(I)) { return PN; @@ -4955,23 +5491,23 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { /// reason, return null. static Constant *EvaluateExpression(Value *V, const Loop *L, DenseMap &Vals, - const DataLayout *DL, + const DataLayout &DL, const TargetLibraryInfo *TLI) { // Convenient constant check, but redundant for recursive calls. if (Constant *C = dyn_cast(V)) return C; Instruction *I = dyn_cast(V); - if (!I) return 0; + if (!I) return nullptr; if (Constant *C = Vals.lookup(I)) return C; // An instruction inside the loop depends on a value outside the loop that we // weren't given a mapping for, or a value such as a call inside the loop. - if (!canConstantEvolve(I, L)) return 0; + if (!canConstantEvolve(I, L)) return nullptr; // An unmapped PHI can be due to a branch or another loop inside this loop, // or due to this not being the initial iteration through a loop where we // couldn't compute the evolution of this particular PHI last time. - if (isa(I)) return 0; + if (isa(I)) return nullptr; std::vector Operands(I->getNumOperands()); @@ -4979,12 +5515,12 @@ static Constant *EvaluateExpression(Value *V, const Loop *L, Instruction *Operand = dyn_cast(I->getOperand(i)); if (!Operand) { Operands[i] = dyn_cast(I->getOperand(i)); - if (!Operands[i]) return 0; + if (!Operands[i]) return nullptr; continue; } Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI); Vals[Operand] = C; - if (!C) return 0; + if (!C) return nullptr; Operands[i] = C; } @@ -5013,7 +5549,7 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, return I->second; if (BEs.ugt(MaxBruteForceIterations)) - return ConstantEvolutionLoopExitValue[PN] = 0; // Not going to evaluate it. + return ConstantEvolutionLoopExitValue[PN] = nullptr; // Not going to evaluate it. Constant *&RetVal = ConstantEvolutionLoopExitValue[PN]; @@ -5025,25 +5561,26 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, // entry must be a constant (coming in from outside of the loop), and the // second must be derived from the same PHI. bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1)); - PHINode *PHI = 0; + PHINode *PHI = nullptr; for (BasicBlock::iterator I = Header->begin(); (PHI = dyn_cast(I)); ++I) { Constant *StartCST = dyn_cast(PHI->getIncomingValue(!SecondIsBackedge)); - if (StartCST == 0) continue; + if (!StartCST) continue; CurrentIterVals[PHI] = StartCST; } if (!CurrentIterVals.count(PN)) - return RetVal = 0; + return RetVal = nullptr; Value *BEValue = PN->getIncomingValue(SecondIsBackedge); // Execute the loop symbolically to determine the exit value. if (BEs.getActiveBits() >= 32) - return RetVal = 0; // More than 2^32-1 iterations?? Not doing it! + return RetVal = nullptr; // More than 2^32-1 iterations?? Not doing it! unsigned NumIterations = BEs.getZExtValue(); // must be in range unsigned IterationNum = 0; + const DataLayout &DL = F->getParent()->getDataLayout(); for (; ; ++IterationNum) { if (IterationNum == NumIterations) return RetVal = CurrentIterVals[PN]; // Got exit value! @@ -5051,10 +5588,10 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, // Compute the value of the PHIs for the next iteration. // EvaluateExpression adds non-phi values to the CurrentIterVals map. DenseMap NextIterVals; - Constant *NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, - TLI); - if (NextPHI == 0) - return 0; // Couldn't evaluate! + Constant *NextPHI = + EvaluateExpression(BEValue, L, CurrentIterVals, DL, TLI); + if (!NextPHI) + return nullptr; // Couldn't evaluate! NextIterVals[PN] = NextPHI; bool StoppedEvolving = NextPHI == CurrentIterVals[PN]; @@ -5101,7 +5638,7 @@ const SCEV *ScalarEvolution::ComputeExitCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) { PHINode *PN = getConstantEvolvingPHI(Cond, L); - if (PN == 0) return getCouldNotCompute(); + if (!PN) return getCouldNotCompute(); // If the loop is canonicalized, the PHI will have exactly two entries. // That's the only form we support here. @@ -5114,12 +5651,12 @@ const SCEV *ScalarEvolution::ComputeExitCountExhaustively(const Loop *L, // One entry must be a constant (coming in from outside of the loop), and the // second must be derived from the same PHI. bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1)); - PHINode *PHI = 0; + PHINode *PHI = nullptr; for (BasicBlock::iterator I = Header->begin(); (PHI = dyn_cast(I)); ++I) { Constant *StartCST = dyn_cast(PHI->getIncomingValue(!SecondIsBackedge)); - if (StartCST == 0) continue; + if (!StartCST) continue; CurrentIterVals[PHI] = StartCST; } if (!CurrentIterVals.count(PN)) @@ -5128,12 +5665,11 @@ const SCEV *ScalarEvolution::ComputeExitCountExhaustively(const Loop *L, // Okay, we find a PHI node that defines the trip count of this loop. Execute // the loop symbolically to determine when the condition gets a value of // "ExitWhen". - unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis. + const DataLayout &DL = F->getParent()->getDataLayout(); for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){ - ConstantInt *CondVal = - dyn_cast_or_null(EvaluateExpression(Cond, L, CurrentIterVals, - DL, TLI)); + ConstantInt *CondVal = dyn_cast_or_null( + EvaluateExpression(Cond, L, CurrentIterVals, DL, TLI)); // Couldn't symbolically evaluate. if (!CondVal) return getCouldNotCompute(); @@ -5189,7 +5725,7 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { if (Values[u].first == L) return Values[u].second ? Values[u].second : V; } - Values.push_back(std::make_pair(L, static_cast(0))); + Values.push_back(std::make_pair(L, static_cast(nullptr))); // Otherwise compute it. const SCEV *C = computeSCEVAtScope(V, L); SmallVector, 2> &Values2 = ValuesAtScopes[V]; @@ -5243,7 +5779,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { } for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) { Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i)); - if (!C2) return 0; + if (!C2) return nullptr; // First pointer! if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) { @@ -5258,13 +5794,13 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { // Don't bother trying to sum two pointers. We probably can't // statically compute a load that results from it anyway. if (C2->getType()->isPointerTy()) - return 0; + return nullptr; if (PointerType *PTy = dyn_cast(C->getType())) { if (PTy->getElementType()->isStructTy()) C2 = ConstantExpr::getIntegerCast( C2, Type::getInt32Ty(C->getContext()), true); - C = ConstantExpr::getGetElementPtr(C, C2); + C = ConstantExpr::getGetElementPtr(PTy->getElementType(), C, C2); } else C = ConstantExpr::getAdd(C, C2); } @@ -5276,10 +5812,10 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { const SCEVMulExpr *SM = cast(V); if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) { // Don't bother with pointers at all. - if (C->getType()->isPointerTy()) return 0; + if (C->getType()->isPointerTy()) return nullptr; for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) { Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i)); - if (!C2 || C2->getType()->isPointerTy()) return 0; + if (!C2 || C2->getType()->isPointerTy()) return nullptr; C = ConstantExpr::getMul(C, C2); } return C; @@ -5298,7 +5834,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { case scUMaxExpr: break; // TODO: smax, umax. } - return 0; + return nullptr; } const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { @@ -5365,17 +5901,17 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { // Check to see if getSCEVAtScope actually made an improvement. if (MadeImprovement) { - Constant *C = 0; + Constant *C = nullptr; + const DataLayout &DL = F->getParent()->getDataLayout(); if (const CmpInst *CI = dyn_cast(I)) - C = ConstantFoldCompareInstOperands(CI->getPredicate(), - Operands[0], Operands[1], DL, - TLI); + C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], + Operands[1], DL, TLI); else if (const LoadInst *LI = dyn_cast(I)) { if (!LI->isVolatile()) C = ConstantFoldLoadFromConstPtr(Operands[0], DL); } else - C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), - Operands, DL, TLI); + C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands, + DL, TLI); if (!C) return V; return getSCEV(C); } @@ -5628,7 +6164,7 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { /// effectively V != 0. We know and take advantage of the fact that this /// expression only being used in a comparison by zero context. ScalarEvolution::ExitLimit -ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr) { +ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { // If the value is a constant if (const SCEVConstant *C = dyn_cast(V)) { // If the value is already zero, the branch will execute zero times. @@ -5657,7 +6193,7 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr) { dyn_cast(ConstantExpr::getICmp(CmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { - if (CB->getZExtValue() == false) + if (!CB->getZExtValue()) std::swap(R1, R2); // R1 is the minimum root now. // We can only use this value if the chrec ends up with an exact zero @@ -5697,7 +6233,7 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr) { // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step. // We have not yet seen any such cases. const SCEVConstant *StepC = dyn_cast(Step); - if (StepC == 0 || StepC->getValue()->equalsInt(0)) + if (!StepC || StepC->getValue()->equalsInt(0)) return getCouldNotCompute(); // For positive steps (counting up until unsigned overflow): @@ -5722,38 +6258,34 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr) { else MaxBECount = getConstant(CountDown ? CR.getUnsignedMax() : -CR.getUnsignedMin()); - return ExitLimit(Distance, MaxBECount, /*MustExit=*/true); - } - - // If the recurrence is known not to wraparound, unsigned divide computes the - // back edge count. (Ideally we would have an "isexact" bit for udiv). We know - // that the value will either become zero (and thus the loop terminates), that - // the loop will terminate through some other exit condition first, or that - // the loop has undefined behavior. This means we can't "miss" the exit - // value, even with nonunit stride, and exit later via the same branch. Note - // that we can skip this exit if loop later exits via a different - // branch. Hence MustExit=false. - // - // This is only valid for expressions that directly compute the loop exit. It - // is invalid for subexpressions in which the loop may exit through this - // branch even if this subexpression is false. In that case, the trip count - // computed by this udiv could be smaller than the number of well-defined - // iterations. - if (!IsSubExpr && AddRec->getNoWrapFlags(SCEV::FlagNW)) { + return ExitLimit(Distance, MaxBECount); + } + + // As a special case, handle the instance where Step is a positive power of + // two. In this case, determining whether Step divides Distance evenly can be + // done by counting and comparing the number of trailing zeros of Step and + // Distance. + if (!CountDown) { + const APInt &StepV = StepC->getValue()->getValue(); + // StepV.isPowerOf2() returns true if StepV is an positive power of two. It + // also returns true if StepV is maximally negative (eg, INT_MIN), but that + // case is not handled as this code is guarded by !CountDown. + if (StepV.isPowerOf2() && + GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros()) + return getUDivExactExpr(Distance, Step); + } + + // If the condition controls loop exit (the loop exits only if the expression + // is true) and the addition is no-wrap we can use unsigned divide to + // compute the backedge count. In this case, the step may not divide the + // distance, but we don't care because if the condition is "missed" the loop + // will have undefined behavior due to wrapping. + if (ControlsExit && AddRec->getNoWrapFlags(SCEV::FlagNW)) { const SCEV *Exact = - getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); - return ExitLimit(Exact, Exact, /*MustExit=*/false); + getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); + return ExitLimit(Exact, Exact); } - // If Step is a power of two that evenly divides Start we know that the loop - // will always terminate. Start may not be a constant so we just have the - // number of trailing zeros available. This is safe even in presence of - // overflow as the recurrence will overflow to exactly 0. - const APInt &StepV = StepC->getValue()->getValue(); - if (StepV.isPowerOf2() && - GetMinTrailingZeros(getNegativeSCEV(Start)) >= StepV.countTrailingZeros()) - return getUDivExactExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); - // Then, try to solve the above equation provided that Start is constant. if (const SCEVConstant *StartC = dyn_cast(Start)) return SolveLinEquationWithOverflow(StepC->getValue()->getValue(), @@ -6129,28 +6661,167 @@ bool ScalarEvolution::isKnownNonZero(const SCEV *S) { return isKnownNegative(S) || isKnownPositive(S); } -bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS) { - // Canonicalize the inputs first. - (void)SimplifyICmpOperands(Pred, LHS, RHS); +bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { + // Canonicalize the inputs first. + (void)SimplifyICmpOperands(Pred, LHS, RHS); + + // If LHS or RHS is an addrec, check to see if the condition is true in + // every iteration of the loop. + // If LHS and RHS are both addrec, both conditions must be true in + // every iteration of the loop. + const SCEVAddRecExpr *LAR = dyn_cast(LHS); + const SCEVAddRecExpr *RAR = dyn_cast(RHS); + bool LeftGuarded = false; + bool RightGuarded = false; + if (LAR) { + const Loop *L = LAR->getLoop(); + if (isLoopEntryGuardedByCond(L, Pred, LAR->getStart(), RHS) && + isLoopBackedgeGuardedByCond(L, Pred, LAR->getPostIncExpr(*this), RHS)) { + if (!RAR) return true; + LeftGuarded = true; + } + } + if (RAR) { + const Loop *L = RAR->getLoop(); + if (isLoopEntryGuardedByCond(L, Pred, LHS, RAR->getStart()) && + isLoopBackedgeGuardedByCond(L, Pred, LHS, RAR->getPostIncExpr(*this))) { + if (!LAR) return true; + RightGuarded = true; + } + } + if (LeftGuarded && RightGuarded) + return true; + + // Otherwise see what can be done with known constant ranges. + return isKnownPredicateWithRanges(Pred, LHS, RHS); +} + +bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS, + ICmpInst::Predicate Pred, + bool &Increasing) { + bool Result = isMonotonicPredicateImpl(LHS, Pred, Increasing); + +#ifndef NDEBUG + // Verify an invariant: inverting the predicate should turn a monotonically + // increasing change to a monotonically decreasing one, and vice versa. + bool IncreasingSwapped; + bool ResultSwapped = isMonotonicPredicateImpl( + LHS, ICmpInst::getSwappedPredicate(Pred), IncreasingSwapped); + + assert(Result == ResultSwapped && "should be able to analyze both!"); + if (ResultSwapped) + assert(Increasing == !IncreasingSwapped && + "monotonicity should flip as we flip the predicate"); +#endif + + return Result; +} + +bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS, + ICmpInst::Predicate Pred, + bool &Increasing) { + + // A zero step value for LHS means the induction variable is essentially a + // loop invariant value. We don't really depend on the predicate actually + // flipping from false to true (for increasing predicates, and the other way + // around for decreasing predicates), all we care about is that *if* the + // predicate changes then it only changes from false to true. + // + // A zero step value in itself is not very useful, but there may be places + // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be + // as general as possible. + + switch (Pred) { + default: + return false; // Conservative answer + + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + if (!LHS->getNoWrapFlags(SCEV::FlagNUW)) + return false; + + Increasing = Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE; + return true; + + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: { + if (!LHS->getNoWrapFlags(SCEV::FlagNSW)) + return false; + + const SCEV *Step = LHS->getStepRecurrence(*this); + + if (isKnownNonNegative(Step)) { + Increasing = Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE; + return true; + } + + if (isKnownNonPositive(Step)) { + Increasing = Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE; + return true; + } + + return false; + } + + } + + llvm_unreachable("switch has default clause!"); +} + +bool ScalarEvolution::isLoopInvariantPredicate( + ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, + ICmpInst::Predicate &InvariantPred, const SCEV *&InvariantLHS, + const SCEV *&InvariantRHS) { + + // If there is a loop-invariant, force it into the RHS, otherwise bail out. + if (!isLoopInvariant(RHS, L)) { + if (!isLoopInvariant(LHS, L)) + return false; + + std::swap(LHS, RHS); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + const SCEVAddRecExpr *ArLHS = dyn_cast(LHS); + if (!ArLHS || ArLHS->getLoop() != L) + return false; + + bool Increasing; + if (!isMonotonicPredicate(ArLHS, Pred, Increasing)) + return false; + + // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to + // true as the loop iterates, and the backedge is control dependent on + // "ArLHS `Pred` RHS" == true then we can reason as follows: + // + // * if the predicate was false in the first iteration then the predicate + // is never evaluated again, since the loop exits without taking the + // backedge. + // * if the predicate was true in the first iteration then it will + // continue to be true for all future iterations since it is + // monotonically increasing. + // + // For both the above possibilities, we can replace the loop varying + // predicate with its value on the first iteration of the loop (which is + // loop invariant). + // + // A similar reasoning applies for a monotonically decreasing predicate, by + // replacing true with false and false with true in the above two bullets. + + auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred); - // If LHS or RHS is an addrec, check to see if the condition is true in - // every iteration of the loop. - if (const SCEVAddRecExpr *AR = dyn_cast(LHS)) - if (isLoopEntryGuardedByCond( - AR->getLoop(), Pred, AR->getStart(), RHS) && - isLoopBackedgeGuardedByCond( - AR->getLoop(), Pred, AR->getPostIncExpr(*this), RHS)) - return true; - if (const SCEVAddRecExpr *AR = dyn_cast(RHS)) - if (isLoopEntryGuardedByCond( - AR->getLoop(), Pred, LHS, AR->getStart()) && - isLoopBackedgeGuardedByCond( - AR->getLoop(), Pred, LHS, AR->getPostIncExpr(*this))) - return true; + if (!isLoopBackedgeGuardedByCond(L, P, LHS, RHS)) + return false; - // Otherwise see what can be done with known constant ranges. - return isKnownPredicateWithRanges(Pred, LHS, RHS); + InvariantPred = Pred; + InvariantLHS = ArLHS->getStart(); + InvariantRHS = RHS; + return true; } bool @@ -6238,19 +6909,92 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, // (interprocedural conditions notwithstanding). if (!L) return true; + if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true; + BasicBlock *Latch = L->getLoopLatch(); if (!Latch) return false; BranchInst *LoopContinuePredicate = dyn_cast(Latch->getTerminator()); - if (!LoopContinuePredicate || - LoopContinuePredicate->isUnconditional()) + if (LoopContinuePredicate && LoopContinuePredicate->isConditional() && + isImpliedCond(Pred, LHS, RHS, + LoopContinuePredicate->getCondition(), + LoopContinuePredicate->getSuccessor(0) != L->getHeader())) + return true; + + // Check conditions due to any @llvm.assume intrinsics. + for (auto &AssumeVH : AC->assumptions()) { + if (!AssumeVH) + continue; + auto *CI = cast(AssumeVH); + if (!DT->dominates(CI, Latch->getTerminator())) + continue; + + if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false)) + return true; + } + + struct ClearWalkingBEDominatingCondsOnExit { + ScalarEvolution &SE; + + explicit ClearWalkingBEDominatingCondsOnExit(ScalarEvolution &SE) + : SE(SE){} + + ~ClearWalkingBEDominatingCondsOnExit() { + SE.WalkingBEDominatingConds = false; + } + }; + + // We don't want more than one activation of the following loop on the stack + // -- that can lead to O(n!) time complexity. + if (WalkingBEDominatingConds) + return false; + + WalkingBEDominatingConds = true; + ClearWalkingBEDominatingCondsOnExit ClearOnExit(*this); + + // If the loop is not reachable from the entry block, we risk running into an + // infinite loop as we walk up into the dom tree. These loops do not matter + // anyway, so we just return a conservative answer when we see them. + if (!DT->isReachableFromEntry(L->getHeader())) return false; - return isImpliedCond(Pred, LHS, RHS, - LoopContinuePredicate->getCondition(), - LoopContinuePredicate->getSuccessor(0) != L->getHeader()); + for (DomTreeNode *DTN = (*DT)[Latch], *HeaderDTN = (*DT)[L->getHeader()]; + DTN != HeaderDTN; + DTN = DTN->getIDom()) { + + assert(DTN && "should reach the loop header before reaching the root!"); + + BasicBlock *BB = DTN->getBlock(); + BasicBlock *PBB = BB->getSinglePredecessor(); + if (!PBB) + continue; + + BranchInst *ContinuePredicate = dyn_cast(PBB->getTerminator()); + if (!ContinuePredicate || !ContinuePredicate->isConditional()) + continue; + + Value *Condition = ContinuePredicate->getCondition(); + + // If we have an edge `E` within the loop body that dominates the only + // latch, the condition guarding `E` also guards the backedge. This + // reasoning works only for loops with a single latch. + + BasicBlockEdge DominatingEdge(PBB, BB); + if (DominatingEdge.isSingleEdge()) { + // We're constructively (and conservatively) enumerating edges within the + // loop body that dominate the latch. The dominator tree better agree + // with us on this: + assert(DT->dominates(DominatingEdge, Latch) && "should be!"); + + if (isImpliedCond(Pred, LHS, RHS, Condition, + BB != ContinuePredicate->getSuccessor(0))) + return true; + } + } + + return false; } /// isLoopEntryGuardedByCond - Test whether entry to the loop is protected @@ -6264,6 +7008,8 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, // (interprocedural conditions notwithstanding). if (!L) return false; + if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true; + // Starting at the loop predecessor, climb up the predecessor chain, as long // as there are predecessors that can be found that have unique successors // leading to the original header. @@ -6284,6 +7030,18 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, return true; } + // Check conditions due to any @llvm.assume intrinsics. + for (auto &AssumeVH : AC->assumptions()) { + if (!AssumeVH) + continue; + auto *CI = cast(AssumeVH); + if (!DT->dominates(CI, L->getHeader())) + continue; + + if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false)) + return true; + } + return false; } @@ -6331,15 +7089,6 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, ICmpInst *ICI = dyn_cast(FoundCondValue); if (!ICI) return false; - // Bail if the ICmp's operands' types are wider than the needed type - // before attempting to call getSCEV on them. This avoids infinite - // recursion, since the analysis of widening casts can require loop - // exit condition information for overflow checking, which would - // lead back here. - if (getTypeSizeInBits(LHS->getType()) < - getTypeSizeInBits(ICI->getOperand(0)->getType())) - return false; - // Now that we found a conditional branch that dominates the loop or controls // the loop latch. Check to see if it is the comparison we are looking for. ICmpInst::Predicate FoundPred; @@ -6351,9 +7100,17 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *FoundLHS = getSCEV(ICI->getOperand(0)); const SCEV *FoundRHS = getSCEV(ICI->getOperand(1)); - // Balance the types. The case where FoundLHS' type is wider than - // LHS' type is checked for above. - if (getTypeSizeInBits(LHS->getType()) > + // Balance the types. + if (getTypeSizeInBits(LHS->getType()) < + getTypeSizeInBits(FoundLHS->getType())) { + if (CmpInst::isSigned(Pred)) { + LHS = getSignExtendExpr(LHS, FoundLHS->getType()); + RHS = getSignExtendExpr(RHS, FoundLHS->getType()); + } else { + LHS = getZeroExtendExpr(LHS, FoundLHS->getType()); + RHS = getZeroExtendExpr(RHS, FoundLHS->getType()); + } + } else if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(FoundLHS->getType())) { if (CmpInst::isSigned(FoundPred)) { FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType()); @@ -6398,6 +7155,66 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, RHS, LHS, FoundLHS, FoundRHS); } + // Check if we can make progress by sharpening ranges. + if (FoundPred == ICmpInst::ICMP_NE && + (isa(FoundLHS) || isa(FoundRHS))) { + + const SCEVConstant *C = nullptr; + const SCEV *V = nullptr; + + if (isa(FoundLHS)) { + C = cast(FoundLHS); + V = FoundRHS; + } else { + C = cast(FoundRHS); + V = FoundLHS; + } + + // The guarding predicate tells us that C != V. If the known range + // of V is [C, t), we can sharpen the range to [C + 1, t). The + // range we consider has to correspond to same signedness as the + // predicate we're interested in folding. + + APInt Min = ICmpInst::isSigned(Pred) ? + getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin(); + + if (Min == C->getValue()->getValue()) { + // Given (V >= Min && V != Min) we conclude V >= (Min + 1). + // This is true even if (Min + 1) wraps around -- in case of + // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)). + + APInt SharperMin = Min + 1; + + switch (Pred) { + case ICmpInst::ICMP_SGE: + case ICmpInst::ICMP_UGE: + // We know V `Pred` SharperMin. If this implies LHS `Pred` + // RHS, we're done. + if (isImpliedCondOperands(Pred, LHS, RHS, V, + getConstant(SharperMin))) + return true; + + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_UGT: + // We know from the range information that (V `Pred` Min || + // V == Min). We know from the guarding condition that !(V + // == Min). This gives us + // + // V `Pred` Min || V == Min && !(V == Min) + // => V `Pred` Min + // + // If V `Pred` Min implies LHS `Pred` RHS, we're done. + + if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min))) + return true; + + default: + // No change + break; + } + } + } + // Check whether the actual condition is beyond sufficient. if (FoundPred == ICmpInst::ICMP_EQ) if (ICmpInst::isTrueWhenEqual(Pred)) @@ -6419,6 +7236,9 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { + if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS)) + return true; + return isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS) || // ~x < ~y --> x > y @@ -6427,6 +7247,85 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred, getNotSCEV(FoundLHS)); } + +/// If Expr computes ~A, return A else return nullptr +static const SCEV *MatchNotExpr(const SCEV *Expr) { + const SCEVAddExpr *Add = dyn_cast(Expr); + if (!Add || Add->getNumOperands() != 2) return nullptr; + + const SCEVConstant *AddLHS = dyn_cast(Add->getOperand(0)); + if (!(AddLHS && AddLHS->getValue()->getValue().isAllOnesValue())) + return nullptr; + + const SCEVMulExpr *AddRHS = dyn_cast(Add->getOperand(1)); + if (!AddRHS || AddRHS->getNumOperands() != 2) return nullptr; + + const SCEVConstant *MulLHS = dyn_cast(AddRHS->getOperand(0)); + if (!(MulLHS && MulLHS->getValue()->getValue().isAllOnesValue())) + return nullptr; + + return AddRHS->getOperand(1); +} + + +/// Is MaybeMaxExpr an SMax or UMax of Candidate and some other values? +template +static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr, + const SCEV *Candidate) { + const MaxExprType *MaxExpr = dyn_cast(MaybeMaxExpr); + if (!MaxExpr) return false; + + auto It = std::find(MaxExpr->op_begin(), MaxExpr->op_end(), Candidate); + return It != MaxExpr->op_end(); +} + + +/// Is MaybeMinExpr an SMin or UMin of Candidate and some other values? +template +static bool IsMinConsistingOf(ScalarEvolution &SE, + const SCEV *MaybeMinExpr, + const SCEV *Candidate) { + const SCEV *MaybeMaxExpr = MatchNotExpr(MaybeMinExpr); + if (!MaybeMaxExpr) + return false; + + return IsMaxConsistingOf(MaybeMaxExpr, SE.getNotSCEV(Candidate)); +} + + +/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max +/// expression? +static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, + ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { + switch (Pred) { + default: + return false; + + case ICmpInst::ICMP_SGE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_SLE: + return + // min(A, ...) <= A + IsMinConsistingOf(SE, LHS, RHS) || + // A <= max(A, ...) + IsMaxConsistingOf(RHS, LHS); + + case ICmpInst::ICMP_UGE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_ULE: + return + // min(A, ...) <= A + IsMinConsistingOf(SE, LHS, RHS) || + // A <= max(A, ...) + IsMaxConsistingOf(RHS, LHS); + } + + llvm_unreachable("covered switch fell through?!"); +} + /// isImpliedCondOperandsHelper - Test whether the condition described by /// Pred, LHS, and RHS is true whenever the condition described by Pred, /// FoundLHS, and FoundRHS is true. @@ -6435,6 +7334,12 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { + auto IsKnownPredicateFull = + [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { + return isKnownPredicateWithRanges(Pred, LHS, RHS) || + IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS); + }; + switch (Pred) { default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); case ICmpInst::ICMP_EQ: @@ -6444,26 +7349,26 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, break; case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - if (isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, LHS, FoundLHS) && - isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, RHS, FoundRHS)) + if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) && + IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - if (isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, LHS, FoundLHS) && - isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, RHS, FoundRHS)) + if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) && + IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: - if (isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, LHS, FoundLHS) && - isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, RHS, FoundRHS)) + if (IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) && + IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: - if (isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, LHS, FoundLHS) && - isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, RHS, FoundRHS)) + if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) && + IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS)) return true; break; } @@ -6471,8 +7376,49 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, return false; } -// Verify if an linear IV with positive stride can overflow when in a -// less-than comparison, knowing the invariant term of the comparison, the +/// isImpliedCondOperandsViaRanges - helper function for isImpliedCondOperands. +/// Tries to get cases like "X `sgt` 0 => X - 1 `sgt` -1". +bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, + const SCEV *LHS, + const SCEV *RHS, + const SCEV *FoundLHS, + const SCEV *FoundRHS) { + if (!isa(RHS) || !isa(FoundRHS)) + // The restriction on `FoundRHS` be lifted easily -- it exists only to + // reduce the compile time impact of this optimization. + return false; + + const SCEVAddExpr *AddLHS = dyn_cast(LHS); + if (!AddLHS || AddLHS->getOperand(1) != FoundLHS || + !isa(AddLHS->getOperand(0))) + return false; + + APInt ConstFoundRHS = cast(FoundRHS)->getValue()->getValue(); + + // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the + // antecedent "`FoundLHS` `Pred` `FoundRHS`". + ConstantRange FoundLHSRange = + ConstantRange::makeAllowedICmpRegion(Pred, ConstFoundRHS); + + // Since `LHS` is `FoundLHS` + `AddLHS->getOperand(0)`, we can compute a range + // for `LHS`: + APInt Addend = + cast(AddLHS->getOperand(0))->getValue()->getValue(); + ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(Addend)); + + // We can also compute the range of values for `LHS` that satisfy the + // consequent, "`LHS` `Pred` `RHS`": + APInt ConstRHS = cast(RHS)->getValue()->getValue(); + ConstantRange SatisfyingLHSRange = + ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS); + + // The antecedent implies the consequent if every value of `LHS` that + // satisfies the antecedent also satisfies the consequent. + return SatisfyingLHSRange.contains(LHSRange); +} + +// Verify if an linear IV with positive stride can overflow when in a +// less-than comparison, knowing the invariant term of the comparison, the // stride and the knowledge of NSW/NUW flags on the recurrence. bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, bool IsSigned, bool NoWrap) { @@ -6500,7 +7446,7 @@ bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, return (MaxValue - MaxStrideMinusOne).ult(MaxRHS); } -// Verify if an linear IV with negative stride can overflow when in a +// Verify if an linear IV with negative stride can overflow when in a // greater-than comparison, knowing the invariant term of the comparison, // the stride and the knowledge of NSW/NUW flags on the recurrence. bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, @@ -6531,7 +7477,7 @@ bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, // Compute the backedge taken count knowing the interval difference, the // stride and presence of the equality in the comparison. -const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, +const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, bool Equality) { const SCEV *One = getConstant(Step->getType(), 1); Delta = Equality ? getAddExpr(Delta, Step) @@ -6543,13 +7489,13 @@ const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, /// specified less-than comparison will execute. If not computable, return /// CouldNotCompute. /// -/// @param IsSubExpr is true when the LHS < RHS condition does not directly -/// control the branch. In this case, we can only compute an iteration count for -/// a subexpression that cannot overflow before evaluating true. +/// @param ControlsExit is true when the LHS < RHS condition directly controls +/// the branch (loops exits only if condition is true). In this case, we can use +/// NoWrapFlags to skip overflow checks. ScalarEvolution::ExitLimit ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, - bool IsSubExpr) { + bool ControlsExit) { // We handle only IV < Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); @@ -6560,7 +7506,7 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, if (!IV || IV->getLoop() != L || !IV->isAffine()) return getCouldNotCompute(); - bool NoWrap = !IsSubExpr && + bool NoWrap = ControlsExit && IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW); const SCEV *Stride = IV->getStepRecurrence(*this); @@ -6571,7 +7517,7 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, // Avoid proven overflow cases: this will ensure that the backedge taken count // will not generate any unsigned overflow. Relaxed no-overflow conditions - // exploit NoWrapFlags, allowing to optimize in presence of undefined + // exploit NoWrapFlags, allowing to optimize in presence of undefined // behaviors like the case of C language. if (!Stride->isOne() && doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap)) return getCouldNotCompute(); @@ -6580,9 +7526,19 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, : ICmpInst::ICMP_ULT; const SCEV *Start = IV->getStart(); const SCEV *End = RHS; - if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) - End = IsSigned ? getSMaxExpr(RHS, Start) - : getUMaxExpr(RHS, Start); + if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) { + const SCEV *Diff = getMinusSCEV(RHS, Start); + // If we have NoWrap set, then we can assume that the increment won't + // overflow, in which case if RHS - Start is a constant, we don't need to + // do a max operation since we can just figure it out statically + if (NoWrap && isa(Diff)) { + APInt D = dyn_cast(Diff)->getValue()->getValue(); + if (D.isNegative()) + End = Start; + } else + End = IsSigned ? getSMaxExpr(RHS, Start) + : getUMaxExpr(RHS, Start); + } const SCEV *BECount = computeBECount(getMinusSCEV(End, Start), Stride, false); @@ -6613,13 +7569,13 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, if (isa(MaxBECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount, /*MustExit=*/true); + return ExitLimit(BECount, MaxBECount); } ScalarEvolution::ExitLimit ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, - bool IsSubExpr) { + bool ControlsExit) { // We handle only IV > Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); @@ -6630,7 +7586,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, if (!IV || IV->getLoop() != L || !IV->isAffine()) return getCouldNotCompute(); - bool NoWrap = !IsSubExpr && + bool NoWrap = ControlsExit && IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW); const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this)); @@ -6641,7 +7597,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, // Avoid proven overflow cases: this will ensure that the backedge taken count // will not generate any unsigned overflow. Relaxed no-overflow conditions - // exploit NoWrapFlags, allowing to optimize in presence of undefined + // exploit NoWrapFlags, allowing to optimize in presence of undefined // behaviors like the case of C language. if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap)) return getCouldNotCompute(); @@ -6651,9 +7607,19 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const SCEV *Start = IV->getStart(); const SCEV *End = RHS; - if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) - End = IsSigned ? getSMinExpr(RHS, Start) - : getUMinExpr(RHS, Start); + if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) { + const SCEV *Diff = getMinusSCEV(RHS, Start); + // If we have NoWrap set, then we can assume that the increment won't + // overflow, in which case if RHS - Start is a constant, we don't need to + // do a max operation since we can just figure it out statically + if (NoWrap && isa(Diff)) { + APInt D = dyn_cast(Diff)->getValue()->getValue(); + if (!D.isNegative()) + End = Start; + } else + End = IsSigned ? getSMinExpr(RHS, Start) + : getUMinExpr(RHS, Start); + } const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false); @@ -6679,13 +7645,13 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, if (isa(BECount)) MaxBECount = BECount; else - MaxBECount = computeBECount(getConstant(MaxStart - MinEnd), + MaxBECount = computeBECount(getConstant(MaxStart - MinEnd), getConstant(MinStride), false); if (isa(MaxBECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount, /*MustExit=*/true); + return ExitLimit(BECount, MaxBECount); } /// getNumIterationsInRange - Return the number of iterations of this loop that @@ -6779,7 +7745,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, if (ConstantInt *CB = dyn_cast(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { - if (CB->getZExtValue() == false) + if (!CB->getZExtValue()) std::swap(R1, R2); // R1 is the minimum root now. // Make sure the root is not off by one. The returned iteration should @@ -6814,387 +7780,358 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, return SE.getCouldNotCompute(); } -static const APInt srem(const SCEVConstant *C1, const SCEVConstant *C2) { - APInt A = C1->getValue()->getValue(); - APInt B = C2->getValue()->getValue(); - uint32_t ABW = A.getBitWidth(); - uint32_t BBW = B.getBitWidth(); +namespace { +struct FindUndefs { + bool Found; + FindUndefs() : Found(false) {} - if (ABW > BBW) - B = B.sext(ABW); - else if (ABW < BBW) - A = A.sext(BBW); + bool follow(const SCEV *S) { + if (const SCEVUnknown *C = dyn_cast(S)) { + if (isa(C->getValue())) + Found = true; + } else if (const SCEVConstant *C = dyn_cast(S)) { + if (isa(C->getValue())) + Found = true; + } - return APIntOps::srem(A, B); + // Keep looking if we haven't found it yet. + return !Found; + } + bool isDone() const { + // Stop recursion if we have found an undef. + return Found; + } +}; } -static const APInt sdiv(const SCEVConstant *C1, const SCEVConstant *C2) { - APInt A = C1->getValue()->getValue(); - APInt B = C2->getValue()->getValue(); - uint32_t ABW = A.getBitWidth(); - uint32_t BBW = B.getBitWidth(); - - if (ABW > BBW) - B = B.sext(ABW); - else if (ABW < BBW) - A = A.sext(BBW); +// Return true when S contains at least an undef value. +static inline bool +containsUndefs(const SCEV *S) { + FindUndefs F; + SCEVTraversal ST(F); + ST.visitAll(S); - return APIntOps::sdiv(A, B); + return F.Found; } namespace { -struct SCEVGCD : public SCEVVisitor { -public: - // Pattern match Step into Start. When Step is a multiply expression, find - // the largest subexpression of Step that appears in Start. When Start is an - // add expression, try to match Step in the subexpressions of Start, non - // matching subexpressions are returned under Remainder. - static const SCEV *findGCD(ScalarEvolution &SE, const SCEV *Start, - const SCEV *Step, const SCEV **Remainder) { - assert(Remainder && "Remainder should not be NULL"); - SCEVGCD R(SE, Step, SE.getConstant(Step->getType(), 0)); - const SCEV *Res = R.visit(Start); - *Remainder = R.Remainder; - return Res; - } +// Collect all steps of SCEV expressions. +struct SCEVCollectStrides { + ScalarEvolution &SE; + SmallVectorImpl &Strides; - SCEVGCD(ScalarEvolution &S, const SCEV *G, const SCEV *R) - : SE(S), GCD(G), Remainder(R) { - Zero = SE.getConstant(GCD->getType(), 0); - One = SE.getConstant(GCD->getType(), 1); - } + SCEVCollectStrides(ScalarEvolution &SE, SmallVectorImpl &S) + : SE(SE), Strides(S) {} - const SCEV *visitConstant(const SCEVConstant *Constant) { - if (GCD == Constant || Constant == Zero) - return GCD; + bool follow(const SCEV *S) { + if (const SCEVAddRecExpr *AR = dyn_cast(S)) + Strides.push_back(AR->getStepRecurrence(SE)); + return true; + } + bool isDone() const { return false; } +}; - if (const SCEVConstant *CGCD = dyn_cast(GCD)) { - const SCEV *Res = SE.getConstant(gcd(Constant, CGCD)); - if (Res != One) - return Res; +// Collect all SCEVUnknown and SCEVMulExpr expressions. +struct SCEVCollectTerms { + SmallVectorImpl &Terms; - Remainder = SE.getConstant(srem(Constant, CGCD)); - Constant = cast(SE.getMinusSCEV(Constant, Remainder)); - Res = SE.getConstant(gcd(Constant, CGCD)); - return Res; - } + SCEVCollectTerms(SmallVectorImpl &T) + : Terms(T) {} - // When GCD is not a constant, it could be that the GCD is an Add, Mul, - // AddRec, etc., in which case we want to find out how many times the - // Constant divides the GCD: we then return that as the new GCD. - const SCEV *Rem = Zero; - const SCEV *Res = findGCD(SE, GCD, Constant, &Rem); + bool follow(const SCEV *S) { + if (isa(S) || isa(S)) { + if (!containsUndefs(S)) + Terms.push_back(S); - if (Res == One || Rem != Zero) { - Remainder = Constant; - return One; + // Stop recursion: once we collected a term, do not walk its operands. + return false; } - assert(isa(Res) && "Res should be a constant"); - Remainder = SE.getConstant(srem(Constant, cast(Res))); - return Res; + // Keep looking. + return true; } + bool isDone() const { return false; } +}; +} - const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { - if (GCD != Expr) - Remainder = Expr; - return GCD; - } +/// Find parametric terms in this SCEVAddRecExpr. +void ScalarEvolution::collectParametricTerms(const SCEV *Expr, + SmallVectorImpl &Terms) { + SmallVector Strides; + SCEVCollectStrides StrideCollector(*this, Strides); + visitAll(Expr, StrideCollector); - const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { - if (GCD != Expr) - Remainder = Expr; - return GCD; - } + DEBUG({ + dbgs() << "Strides:\n"; + for (const SCEV *S : Strides) + dbgs() << *S << "\n"; + }); - const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { - if (GCD != Expr) - Remainder = Expr; - return GCD; + for (const SCEV *S : Strides) { + SCEVCollectTerms TermCollector(Terms); + visitAll(S, TermCollector); } - const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { - if (GCD == Expr) - return GCD; + DEBUG({ + dbgs() << "Terms:\n"; + for (const SCEV *T : Terms) + dbgs() << *T << "\n"; + }); +} + +static bool findArrayDimensionsRec(ScalarEvolution &SE, + SmallVectorImpl &Terms, + SmallVectorImpl &Sizes) { + int Last = Terms.size() - 1; + const SCEV *Step = Terms[Last]; - for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { - const SCEV *Rem = Zero; - const SCEV *Res = findGCD(SE, Expr->getOperand(e - 1 - i), GCD, &Rem); + // End of recursion. + if (Last == 0) { + if (const SCEVMulExpr *M = dyn_cast(Step)) { + SmallVector Qs; + for (const SCEV *Op : M->operands()) + if (!isa(Op)) + Qs.push_back(Op); - // FIXME: There may be ambiguous situations: for instance, - // GCD(-4 + (3 * %m), 2 * %m) where 2 divides -4 and %m divides (3 * %m). - // The order in which the AddExpr is traversed computes a different GCD - // and Remainder. - if (Res != One) - GCD = Res; - if (Rem != Zero) - Remainder = SE.getAddExpr(Remainder, Rem); + Step = SE.getMulExpr(Qs); } - return GCD; + Sizes.push_back(Step); + return true; } - const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { - if (GCD == Expr) - return GCD; + for (const SCEV *&Term : Terms) { + // Normalize the terms before the next call to findArrayDimensionsRec. + const SCEV *Q, *R; + SCEVDivision::divide(SE, Term, Step, &Q, &R); - for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { - if (Expr->getOperand(i) == GCD) - return GCD; - } + // Bail out when GCD does not evenly divide one of the terms. + if (!R->isZero()) + return false; - // If we have not returned yet, it means that GCD is not part of Expr. - const SCEV *PartialGCD = One; - for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { - const SCEV *Rem = Zero; - const SCEV *Res = findGCD(SE, Expr->getOperand(i), GCD, &Rem); - if (Rem != Zero) - // GCD does not divide Expr->getOperand(i). - continue; + Term = Q; + } - if (Res == GCD) - return GCD; - PartialGCD = SE.getMulExpr(PartialGCD, Res); - if (PartialGCD == GCD) - return GCD; - } - - if (PartialGCD != One) - return PartialGCD; - - // Failed to find a PartialGCD: set the Remainder to the full expression, - // and return the GCD. - Remainder = Expr; - const SCEVMulExpr *Mul = dyn_cast(GCD); - if (!Mul) - return GCD; - - // When the GCD is a multiply expression, try to decompose it: - // this occurs when Step does not divide the Start expression - // as in: {(-4 + (3 * %m)),+,(2 * %m)} - for (int i = 0, e = Mul->getNumOperands(); i < e; ++i) { - const SCEV *Rem = Zero; - const SCEV *Res = findGCD(SE, Expr, Mul->getOperand(i), &Rem); - if (Rem == Zero) { - Remainder = Rem; - return Res; - } - } + // Remove all SCEVConstants. + Terms.erase(std::remove_if(Terms.begin(), Terms.end(), [](const SCEV *E) { + return isa(E); + }), + Terms.end()); - return GCD; - } + if (Terms.size() > 0) + if (!findArrayDimensionsRec(SE, Terms, Sizes)) + return false; - const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { - if (GCD != Expr) - Remainder = Expr; - return GCD; - } + Sizes.push_back(Step); + return true; +} - const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { - if (GCD == Expr) - return GCD; +namespace { +struct FindParameter { + bool FoundParameter; + FindParameter() : FoundParameter(false) {} - if (!Expr->isAffine()) { - Remainder = Expr; - return GCD; + bool follow(const SCEV *S) { + if (isa(S)) { + FoundParameter = true; + // Stop recursion: we found a parameter. + return false; } + // Keep looking. + return true; + } + bool isDone() const { + // Stop recursion if we have found a parameter. + return FoundParameter; + } +}; +} - const SCEV *Rem = Zero; - const SCEV *Res = findGCD(SE, Expr->getOperand(0), GCD, &Rem); - if (Res == One || Res->isAllOnesValue()) { - Remainder = Expr; - return GCD; - } +// Returns true when S contains at least a SCEVUnknown parameter. +static inline bool +containsParameters(const SCEV *S) { + FindParameter F; + SCEVTraversal ST(F); + ST.visitAll(S); - if (Rem != Zero) - Remainder = SE.getAddExpr(Remainder, Rem); + return F.FoundParameter; +} - Rem = Zero; - Res = findGCD(SE, Expr->getOperand(1), Res, &Rem); - if (Rem != Zero || Res == One || Res->isAllOnesValue()) { - Remainder = Expr; - return GCD; - } +// Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter. +static inline bool +containsParameters(SmallVectorImpl &Terms) { + for (const SCEV *T : Terms) + if (containsParameters(T)) + return true; + return false; +} - return Res; - } +// Return the number of product terms in S. +static inline int numberOfTerms(const SCEV *S) { + if (const SCEVMulExpr *Expr = dyn_cast(S)) + return Expr->getNumOperands(); + return 1; +} - const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { - if (GCD != Expr) - Remainder = Expr; - return GCD; - } +static const SCEV *removeConstantFactors(ScalarEvolution &SE, const SCEV *T) { + if (isa(T)) + return nullptr; - const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { - if (GCD != Expr) - Remainder = Expr; - return GCD; - } + if (isa(T)) + return T; - const SCEV *visitUnknown(const SCEVUnknown *Expr) { - if (GCD != Expr) - Remainder = Expr; - return GCD; - } + if (const SCEVMulExpr *M = dyn_cast(T)) { + SmallVector Factors; + for (const SCEV *Op : M->operands()) + if (!isa(Op)) + Factors.push_back(Op); - const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { - return One; + return SE.getMulExpr(Factors); } -private: - ScalarEvolution &SE; - const SCEV *GCD, *Remainder, *Zero, *One; -}; + return T; +} -struct SCEVDivision : public SCEVVisitor { -public: - // Remove from Start all multiples of Step. - static const SCEV *divide(ScalarEvolution &SE, const SCEV *Start, - const SCEV *Step) { - SCEVDivision D(SE, Step); - const SCEV *Rem = D.Zero; - (void)Rem; - // The division is guaranteed to succeed: Step should divide Start with no - // remainder. - assert(Step == SCEVGCD::findGCD(SE, Start, Step, &Rem) && Rem == D.Zero && - "Step should divide Start with no remainder."); - return D.visit(Start); - } +/// Return the size of an element read or written by Inst. +const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) { + Type *Ty; + if (StoreInst *Store = dyn_cast(Inst)) + Ty = Store->getValueOperand()->getType(); + else if (LoadInst *Load = dyn_cast(Inst)) + Ty = Load->getType(); + else + return nullptr; - SCEVDivision(ScalarEvolution &S, const SCEV *G) : SE(S), GCD(G) { - Zero = SE.getConstant(GCD->getType(), 0); - One = SE.getConstant(GCD->getType(), 1); - } + Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty)); + return getSizeOfExpr(ETy, Ty); +} - const SCEV *visitConstant(const SCEVConstant *Constant) { - if (GCD == Constant) - return One; +/// Second step of delinearization: compute the array dimensions Sizes from the +/// set of Terms extracted from the memory access function of this SCEVAddRec. +void ScalarEvolution::findArrayDimensions(SmallVectorImpl &Terms, + SmallVectorImpl &Sizes, + const SCEV *ElementSize) const { - if (const SCEVConstant *CGCD = dyn_cast(GCD)) - return SE.getConstant(sdiv(Constant, CGCD)); - return Constant; - } + if (Terms.size() < 1 || !ElementSize) + return; - const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { - if (GCD == Expr) - return One; - return Expr; - } + // Early return when Terms do not contain parameters: we do not delinearize + // non parametric SCEVs. + if (!containsParameters(Terms)) + return; - const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { - if (GCD == Expr) - return One; - return Expr; - } + DEBUG({ + dbgs() << "Terms:\n"; + for (const SCEV *T : Terms) + dbgs() << *T << "\n"; + }); - const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { - if (GCD == Expr) - return One; - return Expr; - } + // Remove duplicates. + std::sort(Terms.begin(), Terms.end()); + Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end()); - const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { - if (GCD == Expr) - return One; + // Put larger terms first. + std::sort(Terms.begin(), Terms.end(), [](const SCEV *LHS, const SCEV *RHS) { + return numberOfTerms(LHS) > numberOfTerms(RHS); + }); - SmallVector Operands; - for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(divide(SE, Expr->getOperand(i), GCD)); + ScalarEvolution &SE = *const_cast(this); - if (Operands.size() == 1) - return Operands[0]; - return SE.getAddExpr(Operands); + // Divide all terms by the element size. + for (const SCEV *&Term : Terms) { + const SCEV *Q, *R; + SCEVDivision::divide(SE, Term, ElementSize, &Q, &R); + Term = Q; } - const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { - if (GCD == Expr) - return One; - - bool FoundGCDTerm = false; - for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - if (Expr->getOperand(i) == GCD) - FoundGCDTerm = true; + SmallVector NewTerms; - SmallVector Operands; - if (FoundGCDTerm) { - FoundGCDTerm = false; - for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { - if (FoundGCDTerm) - Operands.push_back(Expr->getOperand(i)); - else if (Expr->getOperand(i) == GCD) - FoundGCDTerm = true; - else - Operands.push_back(Expr->getOperand(i)); - } - } else { - const SCEV *PartialGCD = One; - for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { - if (PartialGCD == GCD) { - Operands.push_back(Expr->getOperand(i)); - continue; - } + // Remove constant factors. + for (const SCEV *T : Terms) + if (const SCEV *NewT = removeConstantFactors(SE, T)) + NewTerms.push_back(NewT); - const SCEV *Rem = Zero; - const SCEV *Res = SCEVGCD::findGCD(SE, Expr->getOperand(i), GCD, &Rem); - if (Rem == Zero) { - PartialGCD = SE.getMulExpr(PartialGCD, Res); - Operands.push_back(divide(SE, Expr->getOperand(i), GCD)); - } else { - Operands.push_back(Expr->getOperand(i)); - } - } - } + DEBUG({ + dbgs() << "Terms after sorting:\n"; + for (const SCEV *T : NewTerms) + dbgs() << *T << "\n"; + }); - if (Operands.size() == 1) - return Operands[0]; - return SE.getMulExpr(Operands); + if (NewTerms.empty() || + !findArrayDimensionsRec(SE, NewTerms, Sizes)) { + Sizes.clear(); + return; } - const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { - if (GCD == Expr) - return One; - return Expr; - } + // The last element to be pushed into Sizes is the size of an element. + Sizes.push_back(ElementSize); + + DEBUG({ + dbgs() << "Sizes:\n"; + for (const SCEV *S : Sizes) + dbgs() << *S << "\n"; + }); +} - const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { - if (GCD == Expr) - return One; +/// Third step of delinearization: compute the access functions for the +/// Subscripts based on the dimensions in Sizes. +void ScalarEvolution::computeAccessFunctions( + const SCEV *Expr, SmallVectorImpl &Subscripts, + SmallVectorImpl &Sizes) { - assert(Expr->isAffine() && "Expr should be affine"); + // Early exit in case this SCEV is not an affine multivariate function. + if (Sizes.empty()) + return; - const SCEV *Start = divide(SE, Expr->getStart(), GCD); - const SCEV *Step = divide(SE, Expr->getStepRecurrence(SE), GCD); + if (auto AR = dyn_cast(Expr)) + if (!AR->isAffine()) + return; - return SE.getAddRecExpr(Start, Step, Expr->getLoop(), - Expr->getNoWrapFlags()); - } + const SCEV *Res = Expr; + int Last = Sizes.size() - 1; + for (int i = Last; i >= 0; i--) { + const SCEV *Q, *R; + SCEVDivision::divide(*this, Res, Sizes[i], &Q, &R); + + DEBUG({ + dbgs() << "Res: " << *Res << "\n"; + dbgs() << "Sizes[i]: " << *Sizes[i] << "\n"; + dbgs() << "Res divided by Sizes[i]:\n"; + dbgs() << "Quotient: " << *Q << "\n"; + dbgs() << "Remainder: " << *R << "\n"; + }); + + Res = Q; + + // Do not record the last subscript corresponding to the size of elements in + // the array. + if (i == Last) { + + // Bail out if the remainder is too complex. + if (isa(R)) { + Subscripts.clear(); + Sizes.clear(); + return; + } - const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { - if (GCD == Expr) - return One; - return Expr; - } + continue; + } - const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { - if (GCD == Expr) - return One; - return Expr; + // Record the access function for the current subscript. + Subscripts.push_back(R); } - const SCEV *visitUnknown(const SCEVUnknown *Expr) { - if (GCD == Expr) - return One; - return Expr; - } + // Also push in last position the remainder of the last division: it will be + // the access function of the innermost dimension. + Subscripts.push_back(Res); - const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { - return Expr; - } + std::reverse(Subscripts.begin(), Subscripts.end()); -private: - ScalarEvolution &SE; - const SCEV *GCD, *Zero, *One; -}; + DEBUG({ + dbgs() << "Subscripts:\n"; + for (const SCEV *S : Subscripts) + dbgs() << *S << "\n"; + }); } /// Splits the SCEV into two vectors of SCEVs representing the subscripts and @@ -7246,84 +8183,40 @@ private: /// asking for the SCEV of the memory access with respect to all enclosing /// loops, calling SCEV->delinearize on that and printing the results. -const SCEV * -SCEVAddRecExpr::delinearize(ScalarEvolution &SE, - SmallVectorImpl &Subscripts, - SmallVectorImpl &Sizes) const { - // Early exit in case this SCEV is not an affine multivariate function. - if (!this->isAffine()) - return this; - - const SCEV *Start = this->getStart(); - const SCEV *Step = this->getStepRecurrence(SE); - - // Build the SCEV representation of the canonical induction variable in the - // loop of this SCEV. - const SCEV *Zero = SE.getConstant(this->getType(), 0); - const SCEV *One = SE.getConstant(this->getType(), 1); - const SCEV *IV = - SE.getAddRecExpr(Zero, One, this->getLoop(), this->getNoWrapFlags()); - - DEBUG(dbgs() << "(delinearize: " << *this << "\n"); - - // When the stride of this SCEV is 1, do not compute the GCD: the size of this - // subscript is 1, and this same SCEV for the access function. - const SCEV *Remainder = Zero; - const SCEV *GCD = One; - - // Find the GCD and Remainder of the Start and Step coefficients of this SCEV. - if (Step != One && !Step->isAllOnesValue()) - GCD = SCEVGCD::findGCD(SE, Start, Step, &Remainder); - - DEBUG(dbgs() << "GCD: " << *GCD << "\n"); - DEBUG(dbgs() << "Remainder: " << *Remainder << "\n"); - - const SCEV *Quotient = Start; - if (GCD != One && !GCD->isAllOnesValue()) - // As findGCD computed Remainder, GCD divides "Start - Remainder." The - // Quotient is then this SCEV without Remainder, scaled down by the GCD. The - // Quotient is what will be used in the next subscript delinearization. - Quotient = SCEVDivision::divide(SE, SE.getMinusSCEV(Start, Remainder), GCD); - - DEBUG(dbgs() << "Quotient: " << *Quotient << "\n"); - - const SCEV *Rem = Quotient; - if (const SCEVAddRecExpr *AR = dyn_cast(Quotient)) - // Recursively call delinearize on the Quotient until there are no more - // multiples that can be recognized. - Rem = AR->delinearize(SE, Subscripts, Sizes); - - // Scale up the canonical induction variable IV by whatever remains from the - // Step after division by the GCD: the GCD is the size of all the sub-array. - if (Step != One && !Step->isAllOnesValue() && GCD != One && - !GCD->isAllOnesValue() && Step != GCD) { - Step = SCEVDivision::divide(SE, Step, GCD); - IV = SE.getMulExpr(IV, Step); - } - // The access function in the current subscript is computed as the canonical - // induction variable IV (potentially scaled up by the step) and offset by - // Rem, the offset of delinearization in the sub-array. - const SCEV *Index = SE.getAddExpr(IV, Rem); - - // Record the access function and the size of the current subscript. - Subscripts.push_back(Index); - Sizes.push_back(GCD); +void ScalarEvolution::delinearize(const SCEV *Expr, + SmallVectorImpl &Subscripts, + SmallVectorImpl &Sizes, + const SCEV *ElementSize) { + // First step: collect parametric terms. + SmallVector Terms; + collectParametricTerms(Expr, Terms); -#ifndef NDEBUG - int Size = Sizes.size(); - DEBUG(dbgs() << "succeeded to delinearize " << *this << "\n"); - DEBUG(dbgs() << "ArrayDecl[UnknownSize]"); - for (int i = 0; i < Size - 1; i++) - DEBUG(dbgs() << "[" << *Sizes[i] << "]"); - DEBUG(dbgs() << " with elements of " << *Sizes[Size - 1] << " bytes.\n"); - - DEBUG(dbgs() << "ArrayRef"); - for (int i = 0; i < Size; i++) - DEBUG(dbgs() << "[" << *Subscripts[i] << "]"); - DEBUG(dbgs() << "\n)\n"); -#endif + if (Terms.empty()) + return; + + // Second step: find subscript sizes. + findArrayDimensions(Terms, Sizes, ElementSize); + + if (Sizes.empty()) + return; + + // Third step: compute the access functions for each subscript. + computeAccessFunctions(Expr, Subscripts, Sizes); + + if (Subscripts.empty()) + return; + + DEBUG({ + dbgs() << "succeeded to delinearize " << *Expr << "\n"; + dbgs() << "ArrayDecl[UnknownSize]"; + for (const SCEV *S : Sizes) + dbgs() << "[" << *S << "]"; - return Remainder; + dbgs() << "\nArrayRef"; + for (const SCEV *S : Subscripts) + dbgs() << "[" << *S << "]"; + dbgs() << "\n"; + }); } //===----------------------------------------------------------------------===// @@ -7353,7 +8246,7 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) { // that until everything else is done. if (U == Old) continue; - if (!Visited.insert(U)) + if (!Visited.insert(U).second) continue; if (PHINode *PN = dyn_cast(U)) SE->ConstantEvolutionLoopExitValue.erase(PN); @@ -7375,16 +8268,16 @@ ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se) //===----------------------------------------------------------------------===// ScalarEvolution::ScalarEvolution() - : FunctionPass(ID), ValuesAtScopes(64), LoopDispositions(64), BlockDispositions(64), FirstUnknown(0) { + : FunctionPass(ID), WalkingBEDominatingConds(false), ValuesAtScopes(64), + LoopDispositions(64), BlockDispositions(64), FirstUnknown(nullptr) { initializeScalarEvolutionPass(*PassRegistry::getPassRegistry()); } bool ScalarEvolution::runOnFunction(Function &F) { this->F = &F; - LI = &getAnalysis(); - DataLayoutPass *DLP = getAnalysisIfAvailable(); - DL = DLP ? &DLP->getDataLayout() : 0; - TLI = &getAnalysis(); + AC = &getAnalysis().getAssumptionCache(F); + LI = &getAnalysis().getLoopInfo(); + TLI = &getAnalysis().getTLI(); DT = &getAnalysis().getDomTree(); return false; } @@ -7394,7 +8287,7 @@ void ScalarEvolution::releaseMemory() { // destructors, so that they release their references to their values. for (SCEVUnknown *U = FirstUnknown; U; U = U->Next) U->~SCEVUnknown(); - FirstUnknown = 0; + FirstUnknown = nullptr; ValueExprMap.clear(); @@ -7407,6 +8300,7 @@ void ScalarEvolution::releaseMemory() { } assert(PendingLoopPredicates.empty() && "isImpliedCond garbage"); + assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!"); BackedgeTakenCounts.clear(); ConstantEvolutionLoopExitValue.clear(); @@ -7421,9 +8315,10 @@ void ScalarEvolution::releaseMemory() { void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); - AU.addRequiredTransitive(); + AU.addRequiredTransitive(); + AU.addRequiredTransitive(); AU.addRequiredTransitive(); - AU.addRequired(); + AU.addRequiredTransitive(); } bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) { @@ -7483,6 +8378,12 @@ void ScalarEvolution::print(raw_ostream &OS, const Module *) const { OS << " --> "; const SCEV *SV = SE.getSCEV(&*I); SV->print(OS); + if (!isa(SV)) { + OS << " U: "; + SE.getUnsignedRange(SV).print(OS); + OS << " S: "; + SE.getSignedRange(SV).print(OS); + } const Loop *L = LI->getLoopFor((*I).getParent()); @@ -7490,6 +8391,12 @@ void ScalarEvolution::print(raw_ostream &OS, const Module *) const { if (AtUse != SV) { OS << " --> "; AtUse->print(OS); + if (!isa(AtUse)) { + OS << " U: "; + SE.getUnsignedRange(AtUse).print(OS); + OS << " S: "; + SE.getSignedRange(AtUse).print(OS); + } } if (L) { @@ -7514,17 +8421,17 @@ void ScalarEvolution::print(raw_ostream &OS, const Module *) const { ScalarEvolution::LoopDisposition ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) { - SmallVector, 2> &Values = LoopDispositions[S]; - for (unsigned u = 0; u < Values.size(); u++) { - if (Values[u].first == L) - return Values[u].second; + auto &Values = LoopDispositions[S]; + for (auto &V : Values) { + if (V.getPointer() == L) + return V.getInt(); } - Values.push_back(std::make_pair(L, LoopVariant)); + Values.emplace_back(L, LoopVariant); LoopDisposition D = computeLoopDisposition(S, L); - SmallVector, 2> &Values2 = LoopDispositions[S]; - for (unsigned u = Values2.size(); u > 0; u--) { - if (Values2[u - 1].first == L) { - Values2[u - 1].second = D; + auto &Values2 = LoopDispositions[S]; + for (auto &V : make_range(Values2.rbegin(), Values2.rend())) { + if (V.getPointer() == L) { + V.setInt(D); break; } } @@ -7620,17 +8527,17 @@ bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) { ScalarEvolution::BlockDisposition ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) { - SmallVector, 2> &Values = BlockDispositions[S]; - for (unsigned u = 0; u < Values.size(); u++) { - if (Values[u].first == BB) - return Values[u].second; + auto &Values = BlockDispositions[S]; + for (auto &V : Values) { + if (V.getPointer() == BB) + return V.getInt(); } - Values.push_back(std::make_pair(BB, DoesNotDominateBlock)); + Values.emplace_back(BB, DoesNotDominateBlock); BlockDisposition D = computeBlockDisposition(S, BB); - SmallVector, 2> &Values2 = BlockDispositions[S]; - for (unsigned u = Values2.size(); u > 0; u--) { - if (Values2[u - 1].first == BB) { - Values2[u - 1].second = D; + auto &Values2 = BlockDispositions[S]; + for (auto &V : make_range(Values2.rbegin(), Values2.rend())) { + if (V.getPointer() == BB) { + V.setInt(D); break; } }