X-Git-Url: http://plrg.eecs.uci.edu/git/?p=oota-llvm.git;a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=19e3633fcc5cc9fde107506af0a6a283e6b4f520;hp=7324344c3e0e71ec421ce8d4061fd2985bf628bb;hb=d12ce78ca9ad070c87625ab014dc3caf257dae7d;hpb=8cff277de2400451755c835348f036733978eac6 diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 7324344c3e0..19e3633fcc5 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -63,11 +63,12 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AssumptionTracker.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" @@ -87,7 +88,6 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Target/TargetLibraryInfo.h" #include using namespace llvm; @@ -116,10 +116,10 @@ VerifySCEV("verify-scev", INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution", "Scalar Evolution Analysis", false, true) -INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) -INITIALIZE_PASS_DEPENDENCY(LoopInfo) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(ScalarEvolution, "scalar-evolution", "Scalar Evolution Analysis", false, true) char ScalarEvolution::ID = 0; @@ -675,34 +675,6 @@ static void GroupByComplexity(SmallVectorImpl &Ops, } } -static const APInt srem(const SCEVConstant *C1, const SCEVConstant *C2) { - APInt A = C1->getValue()->getValue(); - APInt B = C2->getValue()->getValue(); - uint32_t ABW = A.getBitWidth(); - uint32_t BBW = B.getBitWidth(); - - if (ABW > BBW) - B = B.sext(ABW); - else if (ABW < BBW) - A = A.sext(BBW); - - return APIntOps::srem(A, B); -} - -static const APInt sdiv(const SCEVConstant *C1, const SCEVConstant *C2) { - APInt A = C1->getValue()->getValue(); - APInt B = C2->getValue()->getValue(); - uint32_t ABW = A.getBitWidth(); - uint32_t BBW = B.getBitWidth(); - - if (ABW > BBW) - B = B.sext(ABW); - else if (ABW < BBW) - A = A.sext(BBW); - - return APIntOps::sdiv(A, B); -} - namespace { struct FindSCEVSize { int Size; @@ -779,17 +751,6 @@ public: *Remainder = D.Remainder; } - SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, const SCEV *Denominator) - : SE(S), Denominator(Denominator) { - Zero = SE.getConstant(Denominator->getType(), 0); - One = SE.getConstant(Denominator->getType(), 1); - - // By default, we don't know how to divide Expr by Denominator. - // Providing the default here simplifies the rest of the code. - Quotient = Zero; - Remainder = Numerator; - } - // Except in the trivial case described above, we do not know how to divide // Expr by Denominator for the following functions with empty implementation. void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {} @@ -803,8 +764,21 @@ public: void visitConstant(const SCEVConstant *Numerator) { if (const SCEVConstant *D = dyn_cast(Denominator)) { - Quotient = SE.getConstant(sdiv(Numerator, D)); - Remainder = SE.getConstant(srem(Numerator, D)); + APInt NumeratorVal = Numerator->getValue()->getValue(); + APInt DenominatorVal = D->getValue()->getValue(); + uint32_t NumeratorBW = NumeratorVal.getBitWidth(); + uint32_t DenominatorBW = DenominatorVal.getBitWidth(); + + if (NumeratorBW > DenominatorBW) + DenominatorVal = DenominatorVal.sext(NumeratorBW); + else if (NumeratorBW < DenominatorBW) + NumeratorVal = NumeratorVal.sext(DenominatorBW); + + APInt QuotientVal(NumeratorVal.getBitWidth(), 0); + APInt RemainderVal(NumeratorVal.getBitWidth(), 0); + APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); + Quotient = SE.getConstant(QuotientVal); + Remainder = SE.getConstant(RemainderVal); return; } } @@ -932,12 +906,23 @@ public: } private: + SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, + const SCEV *Denominator) + : SE(S), Denominator(Denominator) { + Zero = SE.getConstant(Denominator->getType(), 0); + One = SE.getConstant(Denominator->getType(), 1); + + // By default, we don't know how to divide Expr by Denominator. + // Providing the default here simplifies the rest of the code. + Quotient = Zero; + Remainder = Numerator; + } + ScalarEvolution &SE; const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One; }; -} - +} //===----------------------------------------------------------------------===// // Simple SCEV method implementations @@ -1163,6 +1148,262 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, return S; } +// Get the limit of a recurrence such that incrementing by Step cannot cause +// signed overflow as long as the value of the recurrence within the +// loop does not exceed this limit before incrementing. +static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { + unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); + if (SE->isKnownPositive(Step)) { + *Pred = ICmpInst::ICMP_SLT; + return SE->getConstant(APInt::getSignedMinValue(BitWidth) - + SE->getSignedRange(Step).getSignedMax()); + } + if (SE->isKnownNegative(Step)) { + *Pred = ICmpInst::ICMP_SGT; + return SE->getConstant(APInt::getSignedMaxValue(BitWidth) - + SE->getSignedRange(Step).getSignedMin()); + } + return nullptr; +} + +// Get the limit of a recurrence such that incrementing by Step cannot cause +// unsigned overflow as long as the value of the recurrence within the loop does +// not exceed this limit before incrementing. +static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { + unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); + *Pred = ICmpInst::ICMP_ULT; + + return SE->getConstant(APInt::getMinValue(BitWidth) - + SE->getUnsignedRange(Step).getUnsignedMax()); +} + +namespace { + +struct ExtendOpTraitsBase { + typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *); +}; + +// Used to make code generic over signed and unsigned overflow. +template struct ExtendOpTraits { + // Members present: + // + // static const SCEV::NoWrapFlags WrapType; + // + // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr; + // + // static const SCEV *getOverflowLimitForStep(const SCEV *Step, + // ICmpInst::Predicate *Pred, + // ScalarEvolution *SE); +}; + +template <> +struct ExtendOpTraits : public ExtendOpTraitsBase { + static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW; + + static const GetExtendExprTy GetExtendExpr; + + static const SCEV *getOverflowLimitForStep(const SCEV *Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { + return getSignedOverflowLimitForStep(Step, Pred, SE); + } +}; + +const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< + SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr; + +template <> +struct ExtendOpTraits : public ExtendOpTraitsBase { + static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW; + + static const GetExtendExprTy GetExtendExpr; + + static const SCEV *getOverflowLimitForStep(const SCEV *Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { + return getUnsignedOverflowLimitForStep(Step, Pred, SE); + } +}; + +const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< + SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr; +} + +// The recurrence AR has been shown to have no signed/unsigned wrap or something +// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as +// easily prove NSW/NUW for its preincrement or postincrement sibling. This +// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step + +// Start),+,Step} => {(Step + sext/zext(Start),+,Step} As a result, the +// expression "Step + sext/zext(PreIncAR)" is congruent with +// "sext/zext(PostIncAR)" +template +static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, + ScalarEvolution *SE) { + auto WrapType = ExtendOpTraits::WrapType; + auto GetExtendExpr = ExtendOpTraits::GetExtendExpr; + + const Loop *L = AR->getLoop(); + const SCEV *Start = AR->getStart(); + const SCEV *Step = AR->getStepRecurrence(*SE); + + // Check for a simple looking step prior to loop entry. + const SCEVAddExpr *SA = dyn_cast(Start); + if (!SA) + return nullptr; + + // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV + // subtraction is expensive. For this purpose, perform a quick and dirty + // difference, by checking for Step in the operand list. + SmallVector DiffOps; + for (const SCEV *Op : SA->operands()) + if (Op != Step) + DiffOps.push_back(Op); + + if (DiffOps.size() == SA->getNumOperands()) + return nullptr; + + // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` + + // `Step`: + + // 1. NSW/NUW flags on the step increment. + const SCEV *PreStart = SE->getAddExpr(DiffOps, SA->getNoWrapFlags()); + const SCEVAddRecExpr *PreAR = dyn_cast( + SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap)); + + // "{S,+,X} is /" and "the backedge is taken at least once" implies + // "S+X does not sign/unsign-overflow". + // + + const SCEV *BECount = SE->getBackedgeTakenCount(L); + if (PreAR && PreAR->getNoWrapFlags(WrapType) && + !isa(BECount) && SE->isKnownPositive(BECount)) + return PreStart; + + // 2. Direct overflow check on the step operation's expression. + unsigned BitWidth = SE->getTypeSizeInBits(AR->getType()); + Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2); + const SCEV *OperandExtendedStart = + SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy), + (SE->*GetExtendExpr)(Step, WideTy)); + if ((SE->*GetExtendExpr)(Start, WideTy) == OperandExtendedStart) { + if (PreAR && AR->getNoWrapFlags(WrapType)) { + // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW + // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then + // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact. + const_cast(PreAR)->setNoWrapFlags(WrapType); + } + return PreStart; + } + + // 3. Loop precondition. + ICmpInst::Predicate Pred; + const SCEV *OverflowLimit = + ExtendOpTraits::getOverflowLimitForStep(Step, &Pred, SE); + + if (OverflowLimit && + SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) { + return PreStart; + } + return nullptr; +} + +// Get the normalized zero or sign extended expression for this AddRec's Start. +template +static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, + ScalarEvolution *SE) { + auto GetExtendExpr = ExtendOpTraits::GetExtendExpr; + + const SCEV *PreStart = getPreStartForExtend(AR, Ty, SE); + if (!PreStart) + return (SE->*GetExtendExpr)(AR->getStart(), Ty); + + return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty), + (SE->*GetExtendExpr)(PreStart, Ty)); +} + +// Try to prove away overflow by looking at "nearby" add recurrences. A +// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it +// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`. +// +// Formally: +// +// {S,+,X} == {S-T,+,X} + T +// => Ext({S,+,X}) == Ext({S-T,+,X} + T) +// +// If ({S-T,+,X} + T) does not overflow ... (1) +// +// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T) +// +// If {S-T,+,X} does not overflow ... (2) +// +// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T) +// == {Ext(S-T)+Ext(T),+,Ext(X)} +// +// If (S-T)+T does not overflow ... (3) +// +// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)} +// == {Ext(S),+,Ext(X)} == LHS +// +// Thus, if (1), (2) and (3) are true for some T, then +// Ext({S,+,X}) == {Ext(S),+,Ext(X)} +// +// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T) +// does not overflow" restricted to the 0th iteration. Therefore we only need +// to check for (1) and (2). +// +// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T +// is `Delta` (defined below). +// +template +bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, + const SCEV *Step, + const Loop *L) { + auto WrapType = ExtendOpTraits::WrapType; + + // We restrict `Start` to a constant to prevent SCEV from spending too much + // time here. It is correct (but more expensive) to continue with a + // non-constant `Start` and do a general SCEV subtraction to compute + // `PreStart` below. + // + const SCEVConstant *StartC = dyn_cast(Start); + if (!StartC) + return false; + + APInt StartAI = StartC->getValue()->getValue(); + + for (unsigned Delta : {-2, -1, 1, 2}) { + const SCEV *PreStart = getConstant(StartAI - Delta); + + // Give up if we don't already have the add recurrence we need because + // actually constructing an add recurrence is relatively expensive. + const SCEVAddRecExpr *PreAR = [&]() { + FoldingSetNodeID ID; + ID.AddInteger(scAddRecExpr); + ID.AddPointer(PreStart); + ID.AddPointer(Step); + ID.AddPointer(L); + void *IP = nullptr; + return static_cast( + this->UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + }(); + + if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2) + const SCEV *DeltaS = getConstant(StartC->getType(), Delta); + ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; + const SCEV *Limit = ExtendOpTraits::getOverflowLimitForStep( + DeltaS, &Pred, this); + if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1) + return true; + } + } + + return false; +} + const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && @@ -1216,9 +1457,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. if (AR->getNoWrapFlags(SCEV::FlagNUW)) - return getAddRecExpr(getZeroExtendExpr(Start, Ty), - getZeroExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are @@ -1255,9 +1496,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // Cache knowledge of AR NUW, which is propagated to this AddRec. const_cast(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getZeroExtendExpr(Start, Ty), - getZeroExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } // Similar to above, only this time treat the step value as signed. // This covers loops that count down. @@ -1270,9 +1511,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // Negative step causes unsigned wrap, but it still can't self-wrap. const_cast(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getZeroExtendExpr(Start, Ty), - getSignExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } @@ -1290,9 +1531,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // Cache knowledge of AR NUW, which is propagated to this AddRec. const_cast(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getZeroExtendExpr(Start, Ty), - getZeroExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } else if (isKnownNegative(Step)) { const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - @@ -1305,12 +1546,19 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // Negative step causes unsigned wrap, but it still can't self-wrap. const_cast(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getZeroExtendExpr(Start, Ty), - getSignExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } } + + if (proveNoWrapByVaryingStart(Start, Step, L)) { + const_cast(AR)->setNoWrapFlags(SCEV::FlagNUW); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + } } // The cast wasn't folded; create an explicit cast node. @@ -1322,104 +1570,6 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, return S; } -// Get the limit of a recurrence such that incrementing by Step cannot cause -// signed overflow as long as the value of the recurrence within the loop does -// not exceed this limit before incrementing. -static const SCEV *getOverflowLimitForStep(const SCEV *Step, - ICmpInst::Predicate *Pred, - ScalarEvolution *SE) { - unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); - if (SE->isKnownPositive(Step)) { - *Pred = ICmpInst::ICMP_SLT; - return SE->getConstant(APInt::getSignedMinValue(BitWidth) - - SE->getSignedRange(Step).getSignedMax()); - } - if (SE->isKnownNegative(Step)) { - *Pred = ICmpInst::ICMP_SGT; - return SE->getConstant(APInt::getSignedMaxValue(BitWidth) - - SE->getSignedRange(Step).getSignedMin()); - } - return nullptr; -} - -// The recurrence AR has been shown to have no signed wrap. Typically, if we can -// prove NSW for AR, then we can just as easily prove NSW for its preincrement -// or postincrement sibling. This allows normalizing a sign extended AddRec as -// such: {sext(Step + Start),+,Step} => {(Step + sext(Start),+,Step} As a -// result, the expression "Step + sext(PreIncAR)" is congruent with -// "sext(PostIncAR)" -static const SCEV *getPreStartForSignExtend(const SCEVAddRecExpr *AR, - Type *Ty, - ScalarEvolution *SE) { - const Loop *L = AR->getLoop(); - const SCEV *Start = AR->getStart(); - const SCEV *Step = AR->getStepRecurrence(*SE); - - // Check for a simple looking step prior to loop entry. - const SCEVAddExpr *SA = dyn_cast(Start); - if (!SA) - return nullptr; - - // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV - // subtraction is expensive. For this purpose, perform a quick and dirty - // difference, by checking for Step in the operand list. - SmallVector DiffOps; - for (const SCEV *Op : SA->operands()) - if (Op != Step) - DiffOps.push_back(Op); - - if (DiffOps.size() == SA->getNumOperands()) - return nullptr; - - // This is a postinc AR. Check for overflow on the preinc recurrence using the - // same three conditions that getSignExtendedExpr checks. - - // 1. NSW flags on the step increment. - const SCEV *PreStart = SE->getAddExpr(DiffOps, SA->getNoWrapFlags()); - const SCEVAddRecExpr *PreAR = dyn_cast( - SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap)); - - if (PreAR && PreAR->getNoWrapFlags(SCEV::FlagNSW)) - return PreStart; - - // 2. Direct overflow check on the step operation's expression. - unsigned BitWidth = SE->getTypeSizeInBits(AR->getType()); - Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2); - const SCEV *OperandExtendedStart = - SE->getAddExpr(SE->getSignExtendExpr(PreStart, WideTy), - SE->getSignExtendExpr(Step, WideTy)); - if (SE->getSignExtendExpr(Start, WideTy) == OperandExtendedStart) { - // Cache knowledge of PreAR NSW. - if (PreAR) - const_cast(PreAR)->setNoWrapFlags(SCEV::FlagNSW); - // FIXME: this optimization needs a unit test - DEBUG(dbgs() << "SCEV: untested prestart overflow check\n"); - return PreStart; - } - - // 3. Loop precondition. - ICmpInst::Predicate Pred; - const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, SE); - - if (OverflowLimit && - SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) { - return PreStart; - } - return nullptr; -} - -// Get the normalized sign-extended expression for this AddRec's Start. -static const SCEV *getSignExtendAddRecStart(const SCEVAddRecExpr *AR, - Type *Ty, - ScalarEvolution *SE) { - const SCEV *PreStart = getPreStartForSignExtend(AR, Ty, SE); - if (!PreStart) - return SE->getSignExtendExpr(AR->getStart(), Ty); - - return SE->getAddExpr(SE->getSignExtendExpr(AR->getStepRecurrence(*SE), Ty), - SE->getSignExtendExpr(PreStart, Ty)); -} - const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && @@ -1498,9 +1648,9 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. if (AR->getNoWrapFlags(SCEV::FlagNSW)) - return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this), - getSignExtendExpr(Step, Ty), - L, SCEV::FlagNSW); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW); // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are @@ -1537,9 +1687,9 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // Cache knowledge of AR NSW, which is propagated to this AddRec. const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this), - getSignExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } // Similar to above, only this time treat the step value as unsigned. // This covers loops that count up with an unsigned step. @@ -1548,12 +1698,20 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, getMulExpr(WideMaxBECount, getZeroExtendExpr(Step, WideTy))); if (SAdd == OperandExtendedAdd) { - // Cache knowledge of AR NSW, which is propagated to this AddRec. - const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); + // If AR wraps around then + // + // abs(Step) * MaxBECount > unsigned-max(AR->getType()) + // => SAdd != OperandExtendedAdd + // + // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=> + // (SAdd == OperandExtendedAdd => AR is NW) + + const_cast(AR)->setNoWrapFlags(SCEV::FlagNW); + // Return the expression with the addrec on the outside. - return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this), - getZeroExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } @@ -1562,7 +1720,8 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // with the start value and the backedge is guarded by a comparison // with the post-inc value, the addrec is safe. ICmpInst::Predicate Pred; - const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, this); + const SCEV *OverflowLimit = + getSignedOverflowLimitForStep(Step, &Pred, this); if (OverflowLimit && (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) || (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) && @@ -1570,9 +1729,9 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, OverflowLimit)))) { // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec. const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); - return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this), - getSignExtendExpr(Step, Ty), - L, AR->getNoWrapFlags()); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } // If Start and Step are constants, check if we can apply this @@ -1591,6 +1750,13 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, return getAddExpr(Start, getSignExtendExpr(NewAR, Ty)); } } + + if (proveNoWrapByVaryingStart(Start, Step, L)) { + const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); + return getAddRecExpr( + getExtendAddRecStart(AR, Ty, this), + getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + } } // The cast wasn't folded; create an explicit cast node. @@ -1752,6 +1918,36 @@ namespace { }; } +// We're trying to construct a SCEV of type `Type' with `Ops' as operands and +// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of +// can't-overflow flags for the operation if possible. +static SCEV::NoWrapFlags +StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, + const SmallVectorImpl &Ops, + SCEV::NoWrapFlags OldFlags) { + using namespace std::placeholders; + + bool CanAnalyze = + Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr; + (void)CanAnalyze; + assert(CanAnalyze && "don't call from other places!"); + + int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; + SCEV::NoWrapFlags SignOrUnsignWrap = + ScalarEvolution::maskFlags(OldFlags, SignOrUnsignMask); + + // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. + auto IsKnownNonNegative = + std::bind(std::mem_fn(&ScalarEvolution::isKnownNonNegative), SE, _1); + + if (SignOrUnsignWrap == SCEV::FlagNSW && + std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative)) + return ScalarEvolution::setFlags(OldFlags, + (SCEV::NoWrapFlags)SignOrUnsignMask); + + return OldFlags; +} + /// getAddExpr - Get a canonical add expression, or something simpler if /// possible. const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, @@ -1767,20 +1963,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, "SCEVAddExpr operand types don't match!"); #endif - // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. - // And vice-versa. - int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; - SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask); - if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) { - bool All = true; - for (SmallVectorImpl::const_iterator I = Ops.begin(), - E = Ops.end(); I != E; ++I) - if (!isKnownNonNegative(*I)) { - All = false; - break; - } - if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); - } + Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags); // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, LI); @@ -2155,6 +2338,24 @@ static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) { return r; } +/// Determine if any of the operands in this SCEV are a constant or if +/// any of the add or multiply expressions in this SCEV contain a constant. +static bool containsConstantSomewhere(const SCEV *StartExpr) { + SmallVector Ops; + Ops.push_back(StartExpr); + while (!Ops.empty()) { + const SCEV *CurrentExpr = Ops.pop_back_val(); + if (isa(*CurrentExpr)) + return true; + + if (isa(*CurrentExpr) || isa(*CurrentExpr)) { + const auto *CurrentNAry = cast(CurrentExpr); + Ops.append(CurrentNAry->op_begin(), CurrentNAry->op_end()); + } + } + return false; +} + /// getMulExpr - Get a canonical multiply expression, or something simpler if /// possible. const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, @@ -2170,20 +2371,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, "SCEVMulExpr operand types don't match!"); #endif - // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. - // And vice-versa. - int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; - SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask); - if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) { - bool All = true; - for (SmallVectorImpl::const_iterator I = Ops.begin(), - E = Ops.end(); I != E; ++I) - if (!isKnownNonNegative(*I)) { - All = false; - break; - } - if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); - } + Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags); // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, LI); @@ -2194,11 +2382,13 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, // C1*(C2+V) -> C1*C2 + C1*V if (Ops.size() == 2) - if (const SCEVAddExpr *Add = dyn_cast(Ops[1])) - if (Add->getNumOperands() == 2 && - isa(Add->getOperand(0))) - return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)), - getMulExpr(LHSC, Add->getOperand(1))); + if (const SCEVAddExpr *Add = dyn_cast(Ops[1])) + // If any of Add's ops are Adds or Muls with a constant, + // apply this transformation as well. + if (Add->getNumOperands() == 2) + if (containsConstantSomewhere(Add)) + return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)), + getMulExpr(LHSC, Add->getOperand(1))); ++Idx; while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { @@ -2647,20 +2837,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, // meaningful BE count at this point (and if we don't, we'd be stuck // with a SCEVCouldNotCompute as the cached BE count). - // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. - // And vice-versa. - int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; - SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask); - if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) { - bool All = true; - for (SmallVectorImpl::const_iterator I = Operands.begin(), - E = Operands.end(); I != E; ++I) - if (!isKnownNonNegative(*I)) { - All = false; - break; - } - if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); - } + Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags); // Canonicalize nested AddRecs in by nesting them in order of loop depth. if (const SCEVAddRecExpr *NestedAR = dyn_cast(Operands[0])) { @@ -3157,8 +3334,9 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, if (LHS == RHS) return getConstant(LHS->getType(), 0); - // X - Y --> X + -Y - return getAddExpr(LHS, getNegativeSCEV(RHS), Flags); + // X - Y --> X + -Y. + // X -(nsw || nuw) Y --> X + -Y. + return getAddExpr(LHS, getNegativeSCEV(RHS)); } /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the @@ -3343,7 +3521,8 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { Visited.insert(PN); while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast(I)); @@ -3463,12 +3642,10 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr))) Flags = setFlags(Flags, SCEV::FlagNUW); } - } else if (const SubOperator *OBO = - dyn_cast(BEValueV)) { - if (OBO->hasNoUnsignedWrap()) - Flags = setFlags(Flags, SCEV::FlagNUW); - if (OBO->hasNoSignedWrap()) - Flags = setFlags(Flags, SCEV::FlagNSW); + + // We cannot transfer nuw and nsw flags from subtraction + // operations -- sub nuw X, Y is not the same as add nuw X, -Y + // for instance. } const SCEV *StartVal = getSCEV(StartValueV); @@ -3524,7 +3701,7 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // PHI's incoming blocks are in a different loop, in which case doing so // risks breaking LCSSA form. Instcombine would normally zap these, but // it doesn't have DominatorTree information, so it may miss cases. - if (Value *V = SimplifyInstruction(PN, DL, TLI, DT, AT)) + if (Value *V = SimplifyInstruction(PN, DL, TLI, DT, AC)) if (LI->replacementPreservesLCSSAForm(PN, V)) return getSCEV(V); @@ -3656,7 +3833,7 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { // For a SCEVUnknown, ask ValueTracking. unsigned BitWidth = getTypeSizeInBits(U->getType()); APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AT, nullptr, DT); + computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT); return Zeros.countTrailingOnes(); } @@ -3676,8 +3853,10 @@ static Optional GetRangeFromMetadata(Value *V) { assert(NumRanges >= 1); for (unsigned i = 0; i < NumRanges; ++i) { - ConstantInt *Lower = cast(MD->getOperand(2*i + 0)); - ConstantInt *Upper = cast(MD->getOperand(2*i + 1)); + ConstantInt *Lower = + mdconst::extract(MD->getOperand(2 * i + 0)); + ConstantInt *Upper = + mdconst::extract(MD->getOperand(2 * i + 1)); ConstantRange Range(Lower->getValue(), Upper->getValue()); TotalRange = TotalRange.unionWith(Range); } @@ -3825,7 +4004,7 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) { // For a SCEVUnknown, ask ValueTracking. APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AT, nullptr, DT); + computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT); if (Ones == ~Zeros + 1) return setUnsignedRange(U, ConservativeResult); return setUnsignedRange(U, @@ -3982,7 +4161,7 @@ ScalarEvolution::getSignedRange(const SCEV *S) { // For a SCEVUnknown, ask ValueTracking. if (!U->getValue()->getType()->isIntegerTy() && !DL) return setSignedRange(U, ConservativeResult); - unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AT, nullptr, DT); + unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AC, nullptr, DT); if (NS <= 1) return setSignedRange(U, ConservativeResult); return setSignedRange(U, ConservativeResult.intersectWith( @@ -4089,8 +4268,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { unsigned TZ = A.countTrailingZeros(); unsigned BitWidth = A.getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL, - 0, AT, nullptr, DT); + computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL, 0, AC, + nullptr, DT); APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); @@ -4281,9 +4460,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { case ICmpInst::ICMP_SGE: // a >s b ? a+x : b+x -> smax(a, b)+x // a >s b ? b+x : a+x -> smin(a, b)+x - if (LHS->getType() == U->getType()) { - const SCEV *LS = getSCEV(LHS); - const SCEV *RS = getSCEV(RHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType())) { + const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), U->getType()); + const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, LS); @@ -4304,9 +4484,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { case ICmpInst::ICMP_UGE: // a >u b ? a+x : b+x -> umax(a, b)+x // a >u b ? b+x : a+x -> umin(a, b)+x - if (LHS->getType() == U->getType()) { - const SCEV *LS = getSCEV(LHS); - const SCEV *RS = getSCEV(RHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType())) { + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); + const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, LS); @@ -4321,11 +4502,11 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { break; case ICmpInst::ICMP_NE: // n != 0 ? n+x : 1+x -> umax(n, 1)+x - if (LHS->getType() == U->getType() && - isa(RHS) && - cast(RHS)->isZero()) { - const SCEV *One = getConstant(LHS->getType(), 1); - const SCEV *LS = getSCEV(LHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType()) && + isa(RHS) && cast(RHS)->isZero()) { + const SCEV *One = getConstant(U->getType(), 1); + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, LS); @@ -4336,11 +4517,11 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { break; case ICmpInst::ICMP_EQ: // n == 0 ? 1+x : n+x -> umax(n, 1)+x - if (LHS->getType() == U->getType() && - isa(RHS) && - cast(RHS)->isZero()) { - const SCEV *One = getConstant(LHS->getType(), 1); - const SCEV *LS = getSCEV(LHS); + if (getTypeSizeInBits(LHS->getType()) <= + getTypeSizeInBits(U->getType()) && + isa(RHS) && cast(RHS)->isZero()) { + const SCEV *One = getConstant(U->getType(), 1); + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); const SCEV *LA = getSCEV(U->getOperand(1)); const SCEV *RA = getSCEV(U->getOperand(2)); const SCEV *LDiff = getMinusSCEV(LA, One); @@ -4541,7 +4722,8 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { SmallPtrSet Visited; while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast(I)); @@ -4593,7 +4775,8 @@ void ScalarEvolution::forgetLoop(const Loop *L) { SmallPtrSet Visited; while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast(I)); @@ -4627,7 +4810,8 @@ void ScalarEvolution::forgetValue(Value *V) { SmallPtrSet Visited; while (!Worklist.empty()) { I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast(I)); @@ -6082,15 +6266,18 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { return ExitLimit(Distance, MaxBECount); } - // If the step exactly divides the distance then unsigned divide computes the - // backedge count. - const SCEV *Q, *R; - ScalarEvolution &SE = *const_cast(this); - SCEVDivision::divide(SE, Distance, Step, &Q, &R); - if (R->isZero()) { - const SCEV *Exact = - getUDivExactExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); - return ExitLimit(Exact, Exact); + // As a special case, handle the instance where Step is a positive power of + // two. In this case, determining whether Step divides Distance evenly can be + // done by counting and comparing the number of trailing zeros of Step and + // Distance. + if (!CountDown) { + const APInt &StepV = StepC->getValue()->getValue(); + // StepV.isPowerOf2() returns true if StepV is an positive power of two. It + // also returns true if StepV is maximally negative (eg, INT_MIN), but that + // case is not handled as this code is guarded by !CountDown. + if (StepV.isPowerOf2() && + GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros()) + return getUDivExactExpr(Distance, Step); } // If the condition controls loop exit (the loop exits only if the expression @@ -6615,7 +6802,10 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, return true; // Check conditions due to any @llvm.assume intrinsics. - for (auto &CI : AT->assumptions(F)) { + for (auto &AssumeVH : AC->assumptions()) { + if (!AssumeVH) + continue; + auto *CI = cast(AssumeVH); if (!DT->dominates(CI, Latch->getTerminator())) continue; @@ -6660,7 +6850,10 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, } // Check conditions due to any @llvm.assume intrinsics. - for (auto &CI : AT->assumptions(F)) { + for (auto &AssumeVH : AC->assumptions()) { + if (!AssumeVH) + continue; + auto *CI = cast(AssumeVH); if (!DT->dominates(CI, L->getHeader())) continue; @@ -6782,6 +6975,66 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, RHS, LHS, FoundLHS, FoundRHS); } + // Check if we can make progress by sharpening ranges. + if (FoundPred == ICmpInst::ICMP_NE && + (isa(FoundLHS) || isa(FoundRHS))) { + + const SCEVConstant *C = nullptr; + const SCEV *V = nullptr; + + if (isa(FoundLHS)) { + C = cast(FoundLHS); + V = FoundRHS; + } else { + C = cast(FoundRHS); + V = FoundLHS; + } + + // The guarding predicate tells us that C != V. If the known range + // of V is [C, t), we can sharpen the range to [C + 1, t). The + // range we consider has to correspond to same signedness as the + // predicate we're interested in folding. + + APInt Min = ICmpInst::isSigned(Pred) ? + getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin(); + + if (Min == C->getValue()->getValue()) { + // Given (V >= Min && V != Min) we conclude V >= (Min + 1). + // This is true even if (Min + 1) wraps around -- in case of + // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)). + + APInt SharperMin = Min + 1; + + switch (Pred) { + case ICmpInst::ICMP_SGE: + case ICmpInst::ICMP_UGE: + // We know V `Pred` SharperMin. If this implies LHS `Pred` + // RHS, we're done. + if (isImpliedCondOperands(Pred, LHS, RHS, V, + getConstant(SharperMin))) + return true; + + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_UGT: + // We know from the range information that (V `Pred` Min || + // V == Min). We know from the guarding condition that !(V + // == Min). This gives us + // + // V `Pred` Min || V == Min && !(V == Min) + // => V `Pred` Min + // + // If V `Pred` Min implies LHS `Pred` RHS, we're done. + + if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min))) + return true; + + default: + // No change + break; + } + } + } + // Check whether the actual condition is beyond sufficient. if (FoundPred == ICmpInst::ICMP_EQ) if (ICmpInst::isTrueWhenEqual(Pred)) @@ -6811,6 +7064,85 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred, getNotSCEV(FoundLHS)); } + +/// If Expr computes ~A, return A else return nullptr +static const SCEV *MatchNotExpr(const SCEV *Expr) { + const SCEVAddExpr *Add = dyn_cast(Expr); + if (!Add || Add->getNumOperands() != 2) return nullptr; + + const SCEVConstant *AddLHS = dyn_cast(Add->getOperand(0)); + if (!(AddLHS && AddLHS->getValue()->getValue().isAllOnesValue())) + return nullptr; + + const SCEVMulExpr *AddRHS = dyn_cast(Add->getOperand(1)); + if (!AddRHS || AddRHS->getNumOperands() != 2) return nullptr; + + const SCEVConstant *MulLHS = dyn_cast(AddRHS->getOperand(0)); + if (!(MulLHS && MulLHS->getValue()->getValue().isAllOnesValue())) + return nullptr; + + return AddRHS->getOperand(1); +} + + +/// Is MaybeMaxExpr an SMax or UMax of Candidate and some other values? +template +static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr, + const SCEV *Candidate) { + const MaxExprType *MaxExpr = dyn_cast(MaybeMaxExpr); + if (!MaxExpr) return false; + + auto It = std::find(MaxExpr->op_begin(), MaxExpr->op_end(), Candidate); + return It != MaxExpr->op_end(); +} + + +/// Is MaybeMinExpr an SMin or UMin of Candidate and some other values? +template +static bool IsMinConsistingOf(ScalarEvolution &SE, + const SCEV *MaybeMinExpr, + const SCEV *Candidate) { + const SCEV *MaybeMaxExpr = MatchNotExpr(MaybeMinExpr); + if (!MaybeMaxExpr) + return false; + + return IsMaxConsistingOf(MaybeMaxExpr, SE.getNotSCEV(Candidate)); +} + + +/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max +/// expression? +static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, + ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { + switch (Pred) { + default: + return false; + + case ICmpInst::ICMP_SGE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_SLE: + return + // min(A, ...) <= A + IsMinConsistingOf(SE, LHS, RHS) || + // A <= max(A, ...) + IsMaxConsistingOf(RHS, LHS); + + case ICmpInst::ICMP_UGE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_ULE: + return + // min(A, ...) <= A + IsMinConsistingOf(SE, LHS, RHS) || + // A <= max(A, ...) + IsMaxConsistingOf(RHS, LHS); + } + + llvm_unreachable("covered switch fell through?!"); +} + /// isImpliedCondOperandsHelper - Test whether the condition described by /// Pred, LHS, and RHS is true whenever the condition described by Pred, /// FoundLHS, and FoundRHS is true. @@ -6819,6 +7151,12 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { + auto IsKnownPredicateFull = + [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { + return isKnownPredicateWithRanges(Pred, LHS, RHS) || + IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS); + }; + switch (Pred) { default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); case ICmpInst::ICMP_EQ: @@ -6828,26 +7166,26 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, break; case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - if (isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, LHS, FoundLHS) && - isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, RHS, FoundRHS)) + if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) && + IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - if (isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, LHS, FoundLHS) && - isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, RHS, FoundRHS)) + if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) && + IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: - if (isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, LHS, FoundLHS) && - isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, RHS, FoundRHS)) + if (IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) && + IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: - if (isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, LHS, FoundLHS) && - isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, RHS, FoundRHS)) + if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) && + IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS)) return true; break; } @@ -6855,8 +7193,8 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, return false; } -// Verify if an linear IV with positive stride can overflow when in a -// less-than comparison, knowing the invariant term of the comparison, the +// Verify if an linear IV with positive stride can overflow when in a +// less-than comparison, knowing the invariant term of the comparison, the // stride and the knowledge of NSW/NUW flags on the recurrence. bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, bool IsSigned, bool NoWrap) { @@ -6884,7 +7222,7 @@ bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, return (MaxValue - MaxStrideMinusOne).ult(MaxRHS); } -// Verify if an linear IV with negative stride can overflow when in a +// Verify if an linear IV with negative stride can overflow when in a // greater-than comparison, knowing the invariant term of the comparison, // the stride and the knowledge of NSW/NUW flags on the recurrence. bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, @@ -6915,7 +7253,7 @@ bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, // Compute the backedge taken count knowing the interval difference, the // stride and presence of the equality in the comparison. -const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, +const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, bool Equality) { const SCEV *One = getConstant(Step->getType(), 1); Delta = Equality ? getAddExpr(Delta, Step) @@ -6955,7 +7293,7 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, // Avoid proven overflow cases: this will ensure that the backedge taken count // will not generate any unsigned overflow. Relaxed no-overflow conditions - // exploit NoWrapFlags, allowing to optimize in presence of undefined + // exploit NoWrapFlags, allowing to optimize in presence of undefined // behaviors like the case of C language. if (!Stride->isOne() && doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap)) return getCouldNotCompute(); @@ -7035,7 +7373,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, // Avoid proven overflow cases: this will ensure that the backedge taken count // will not generate any unsigned overflow. Relaxed no-overflow conditions - // exploit NoWrapFlags, allowing to optimize in presence of undefined + // exploit NoWrapFlags, allowing to optimize in presence of undefined // behaviors like the case of C language. if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap)) return getCouldNotCompute(); @@ -7083,7 +7421,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, if (isa(BECount)) MaxBECount = BECount; else - MaxBECount = computeBECount(getConstant(MaxStart - MinEnd), + MaxBECount = computeBECount(getConstant(MaxStart - MinEnd), getConstant(MinStride), false); if (isa(MaxBECount)) @@ -7680,7 +8018,7 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) { // that until everything else is done. if (U == Old) continue; - if (!Visited.insert(U)) + if (!Visited.insert(U).second) continue; if (PHINode *PN = dyn_cast(U)) SE->ConstantEvolutionLoopExitValue.erase(PN); @@ -7709,11 +8047,10 @@ ScalarEvolution::ScalarEvolution() bool ScalarEvolution::runOnFunction(Function &F) { this->F = &F; - AT = &getAnalysis(); - LI = &getAnalysis(); - DataLayoutPass *DLP = getAnalysisIfAvailable(); - DL = DLP ? &DLP->getDataLayout() : nullptr; - TLI = &getAnalysis(); + AC = &getAnalysis().getAssumptionCache(F); + LI = &getAnalysis().getLoopInfo(); + DL = &F.getParent()->getDataLayout(); + TLI = &getAnalysis().getTLI(); DT = &getAnalysis().getDomTree(); return false; } @@ -7750,10 +8087,10 @@ void ScalarEvolution::releaseMemory() { void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); - AU.addRequired(); - AU.addRequiredTransitive(); + AU.addRequired(); + AU.addRequiredTransitive(); AU.addRequiredTransitive(); - AU.addRequired(); + AU.addRequired(); } bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) { @@ -7844,17 +8181,17 @@ void ScalarEvolution::print(raw_ostream &OS, const Module *) const { ScalarEvolution::LoopDisposition ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) { - SmallVector, 2> &Values = LoopDispositions[S]; - for (unsigned u = 0; u < Values.size(); u++) { - if (Values[u].first == L) - return Values[u].second; + auto &Values = LoopDispositions[S]; + for (auto &V : Values) { + if (V.getPointer() == L) + return V.getInt(); } - Values.push_back(std::make_pair(L, LoopVariant)); + Values.emplace_back(L, LoopVariant); LoopDisposition D = computeLoopDisposition(S, L); - SmallVector, 2> &Values2 = LoopDispositions[S]; - for (unsigned u = Values2.size(); u > 0; u--) { - if (Values2[u - 1].first == L) { - Values2[u - 1].second = D; + auto &Values2 = LoopDispositions[S]; + for (auto &V : make_range(Values2.rbegin(), Values2.rend())) { + if (V.getPointer() == L) { + V.setInt(D); break; } } @@ -7950,17 +8287,17 @@ bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) { ScalarEvolution::BlockDisposition ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) { - SmallVector, 2> &Values = BlockDispositions[S]; - for (unsigned u = 0; u < Values.size(); u++) { - if (Values[u].first == BB) - return Values[u].second; + auto &Values = BlockDispositions[S]; + for (auto &V : Values) { + if (V.getPointer() == BB) + return V.getInt(); } - Values.push_back(std::make_pair(BB, DoesNotDominateBlock)); + Values.emplace_back(BB, DoesNotDominateBlock); BlockDisposition D = computeBlockDisposition(S, BB); - SmallVector, 2> &Values2 = BlockDispositions[S]; - for (unsigned u = Values2.size(); u > 0; u--) { - if (Values2[u - 1].first == BB) { - Values2[u - 1].second = D; + auto &Values2 = BlockDispositions[S]; + for (auto &V : make_range(Values2.rbegin(), Values2.rend())) { + if (V.getPointer() == BB) { + V.setInt(D); break; } }