X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=4652e449f4bfb0a164ff5e0d2a3ca8a24d18c725;hb=59970e6f528230b1d97a95facc5c4a281dcc2535;hp=3ccbd14438d1223b21672af799cc8b959a43e720;hpb=cd5029d00108df094b1ed225fb81f4eeece580a9;p=oota-llvm.git diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 3ccbd14438d..4652e449f4b 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -83,11 +83,13 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Support/SaveAndRestore.h" #include using namespace llvm; @@ -114,16 +116,6 @@ static cl::opt VerifySCEV("verify-scev", cl::desc("Verify ScalarEvolution's backedge taken counts (slow)")); -INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution", - "Scalar Evolution Analysis", false, true) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(ScalarEvolution, "scalar-evolution", - "Scalar Evolution Analysis", false, true) -char ScalarEvolution::ID = 0; - //===----------------------------------------------------------------------===// // SCEV class definitions //===----------------------------------------------------------------------===// @@ -132,12 +124,11 @@ char ScalarEvolution::ID = 0; // Implementation of the SCEV class. // -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void SCEV::dump() const { print(dbgs()); dbgs() << '\n'; } -#endif void SCEV::print(raw_ostream &OS) const { switch (static_cast(getSCEVType())) { @@ -726,6 +717,13 @@ public: return; } + // A simple case when N/1. The quotient is N. + if (Denominator->isOne()) { + *Quotient = Numerator; + *Remainder = D.Zero; + return; + } + // Split the Denominator when it is a product. if (const SCEVMulExpr *T = dyn_cast(Denominator)) { const SCEV *Q, *R; @@ -785,9 +783,15 @@ public: void visitAddRecExpr(const SCEVAddRecExpr *Numerator) { const SCEV *StartQ, *StartR, *StepQ, *StepR; - assert(Numerator->isAffine() && "Numerator should be affine"); + if (!Numerator->isAffine()) + return cannotDivide(Numerator); divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); + // Bail out if the types do not match. + Type *Ty = Denominator->getType(); + if (Ty != StartQ->getType() || Ty != StartR->getType() || + Ty != StepQ->getType() || Ty != StepR->getType()) + return cannotDivide(Numerator); Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), Numerator->getNoWrapFlags()); Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), @@ -803,11 +807,8 @@ public: divide(SE, Op, Denominator, &Q, &R); // Bail out if types do not match. - if (Ty != Q->getType() || Ty != R->getType()) { - Quotient = Zero; - Remainder = Numerator; - return; - } + if (Ty != Q->getType() || Ty != R->getType()) + return cannotDivide(Numerator); Qs.push_back(Q); Rs.push_back(R); @@ -830,11 +831,8 @@ public: bool FoundDenominatorTerm = false; for (const SCEV *Op : Numerator->operands()) { // Bail out if types do not match. - if (Ty != Op->getType()) { - Quotient = Zero; - Remainder = Numerator; - return; - } + if (Ty != Op->getType()) + return cannotDivide(Numerator); if (FoundDenominatorTerm) { Qs.push_back(Op); @@ -850,11 +848,8 @@ public: } // Bail out if types do not match. - if (Ty != Q->getType()) { - Quotient = Zero; - Remainder = Numerator; - return; - } + if (Ty != Q->getType()) + return cannotDivide(Numerator); FoundDenominatorTerm = true; Qs.push_back(Q); @@ -869,11 +864,8 @@ public: return; } - if (!isa(Denominator)) { - Quotient = Zero; - Remainder = Numerator; - return; - } + if (!isa(Denominator)) + return cannotDivide(Numerator); // The Remainder is obtained by replacing Denominator by 0 in Numerator. ValueToValueMap RewriteMap; @@ -893,15 +885,12 @@ public: // Quotient is (Numerator - Remainder) divided by Denominator. const SCEV *Q, *R; const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); - if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) { - // This SCEV does not seem to simplify: fail the division here. - Quotient = Zero; - Remainder = Numerator; - return; - } + // This SCEV does not seem to simplify: fail the division here. + if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) + return cannotDivide(Numerator); divide(SE, Diff, Denominator, &Q, &R); - assert(R == Zero && - "(Numerator - Remainder) should evenly divide Denominator"); + if (R != Zero) + return cannotDivide(Numerator); Quotient = Q; } @@ -909,11 +898,18 @@ private: SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, const SCEV *Denominator) : SE(S), Denominator(Denominator) { - Zero = SE.getConstant(Denominator->getType(), 0); - One = SE.getConstant(Denominator->getType(), 1); + Zero = SE.getZero(Denominator->getType()); + One = SE.getOne(Denominator->getType()); + + // We generally do not know how to divide Expr by Denominator. We + // initialize the division to a "cannot divide" state to simplify the rest + // of the code. + cannotDivide(Numerator); + } - // By default, we don't know how to divide Expr by Denominator. - // Providing the default here simplifies the rest of the code. + // Convenience function for giving up on the division. We set the quotient to + // be equal to zero and the remainder to be equal to the numerator. + void cannotDivide(const SCEV *Numerator) { Quotient = Zero; Remainder = Numerator; } @@ -1102,13 +1098,14 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, return getTruncateOrZeroExtend(SZ->getOperand(), Ty); // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can - // eliminate all the truncates. + // eliminate all the truncates, or we replace other casts with truncates. if (const SCEVAddExpr *SA = dyn_cast(Op)) { SmallVector Operands; bool hasTrunc = false; for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) { const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty); - hasTrunc = isa(S); + if (!isa(SA->getOperand(i))) + hasTrunc = isa(S); Operands.push_back(S); } if (!hasTrunc) @@ -1117,13 +1114,14 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, } // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can - // eliminate all the truncates. + // eliminate all the truncates, or we replace other casts with truncates. if (const SCEVMulExpr *SM = dyn_cast(Op)) { SmallVector Operands; bool hasTrunc = false; for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) { const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty); - hasTrunc = isa(S); + if (!isa(SM->getOperand(i))) + hasTrunc = isa(S); Operands.push_back(S); } if (!hasTrunc) @@ -1134,8 +1132,8 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, // If the input value is a chrec scev, truncate the chrec's operands. if (const SCEVAddRecExpr *AddRec = dyn_cast(Op)) { SmallVector Operands; - for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) - Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty)); + for (const SCEV *Op : AddRec->operands()) + Operands.push_back(getTruncateExpr(Op, Ty)); return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap); } @@ -1270,7 +1268,9 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // `Step`: // 1. NSW/NUW flags on the step increment. - const SCEV *PreStart = SE->getAddExpr(DiffOps, SA->getNoWrapFlags()); + auto PreStartFlags = + ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW); + const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags); const SCEVAddRecExpr *PreAR = dyn_cast( SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap)); @@ -1305,9 +1305,9 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ExtendOpTraits::getOverflowLimitForStep(Step, &Pred, SE); if (OverflowLimit && - SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) { + SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) return PreStart; - } + return nullptr; } @@ -1378,19 +1378,17 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, for (unsigned Delta : {-2, -1, 1, 2}) { const SCEV *PreStart = getConstant(StartAI - Delta); + FoldingSetNodeID ID; + ID.AddInteger(scAddRecExpr); + ID.AddPointer(PreStart); + ID.AddPointer(Step); + ID.AddPointer(L); + void *IP = nullptr; + const auto *PreAR = + static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + // Give up if we don't already have the add recurrence we need because // actually constructing an add recurrence is relatively expensive. - const SCEVAddRecExpr *PreAR = [&]() { - FoldingSetNodeID ID; - ID.AddInteger(scAddRecExpr); - ID.AddPointer(PreStart); - ID.AddPointer(Step); - ID.AddPointer(L); - void *IP = nullptr; - return static_cast( - this->UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); - }(); - if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2) const SCEV *DeltaS = getConstant(StartC->getType(), Delta); ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; @@ -1561,6 +1559,18 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, } } + if (auto *SA = dyn_cast(Op)) { + // zext((A + B + ...)) --> (zext(A) + zext(B) + ...) + if (SA->getNoWrapFlags(SCEV::FlagNUW)) { + // If the addition does not unsign overflow then we can, by definition, + // commute the zero extension with the addition operation. + SmallVector Ops; + for (const auto *Op : SA->operands()) + Ops.push_back(getZeroExtendExpr(Op, Ty)); + return getAddExpr(Ops, SCEV::FlagNUW); + } + } + // The cast wasn't folded; create an explicit cast node. // Recompute the insert position, as it may have been invalidated. if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; @@ -1618,12 +1628,12 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, } // sext(C1 + (C2 * x)) --> C1 + sext(C2 * x) if C1 < C2 - if (auto SA = dyn_cast(Op)) { + if (auto *SA = dyn_cast(Op)) { if (SA->getNumOperands() == 2) { - auto SC1 = dyn_cast(SA->getOperand(0)); - auto SMul = dyn_cast(SA->getOperand(1)); + auto *SC1 = dyn_cast(SA->getOperand(0)); + auto *SMul = dyn_cast(SA->getOperand(1)); if (SMul && SC1) { - if (auto SC2 = dyn_cast(SMul->getOperand(0))) { + if (auto *SC2 = dyn_cast(SMul->getOperand(0))) { const APInt &C1 = SC1->getValue()->getValue(); const APInt &C2 = SC2->getValue()->getValue(); if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && @@ -1633,6 +1643,16 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, } } } + + // sext((A + B + ...)) --> (sext(A) + sext(B) + ...) + if (SA->getNoWrapFlags(SCEV::FlagNSW)) { + // If the addition does not sign overflow then we can, by definition, + // commute the sign extension with the addition operation. + SmallVector Ops; + for (const auto *Op : SA->operands()) + Ops.push_back(getSignExtendExpr(Op, Ty)); + return getAddExpr(Ops, SCEV::FlagNSW); + } } // If the input value is a chrec scev, and we can prove that the value // did not overflow the old, smaller, value, we can sign extend all of the @@ -1737,16 +1757,16 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // If Start and Step are constants, check if we can apply this // transformation: // sext{C1,+,C2} --> C1 + sext{0,+,C2} if C1 < C2 - auto SC1 = dyn_cast(Start); - auto SC2 = dyn_cast(Step); + auto *SC1 = dyn_cast(Start); + auto *SC2 = dyn_cast(Step); if (SC1 && SC2) { const APInt &C1 = SC1->getValue()->getValue(); const APInt &C2 = SC2->getValue()->getValue(); if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && C2.isPowerOf2()) { Start = getSignExtendExpr(Start, Ty); - const SCEV *NewAR = getAddRecExpr(getConstant(AR->getType(), 0), Step, - L, AR->getNoWrapFlags()); + const SCEV *NewAR = getAddRecExpr(getZero(AR->getType()), Step, L, + AR->getNoWrapFlags()); return getAddExpr(Start, getSignExtendExpr(NewAR, Ty)); } } @@ -1881,8 +1901,7 @@ CollectAddOperandsWithScales(DenseMap &M, // the map. SmallVector MulOps(Mul->op_begin()+1, Mul->op_end()); const SCEV *Key = SE.getMulExpr(MulOps); - std::pair::iterator, bool> Pair = - M.insert(std::make_pair(Key, NewScale)); + auto Pair = M.insert(std::make_pair(Key, NewScale)); if (Pair.second) { NewOps.push_back(Pair.first->first); } else { @@ -1924,8 +1943,9 @@ namespace { static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const SmallVectorImpl &Ops, - SCEV::NoWrapFlags OldFlags) { + SCEV::NoWrapFlags Flags) { using namespace std::placeholders; + typedef OverflowingBinaryOperator OBO; bool CanAnalyze = Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr; @@ -1934,7 +1954,7 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; SCEV::NoWrapFlags SignOrUnsignWrap = - ScalarEvolution::maskFlags(OldFlags, SignOrUnsignMask); + ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. auto IsKnownNonNegative = @@ -1942,10 +1962,34 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, if (SignOrUnsignWrap == SCEV::FlagNSW && std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative)) - return ScalarEvolution::setFlags(OldFlags, - (SCEV::NoWrapFlags)SignOrUnsignMask); + Flags = + ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); + + SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); + + if (SignOrUnsignWrap != SignOrUnsignMask && Type == scAddExpr && + Ops.size() == 2 && isa(Ops[0])) { + + // (A + C) --> (A + C) if the addition does not sign overflow + // (A + C) --> (A + C) if the addition does not unsign overflow + + const APInt &C = cast(Ops[0])->getValue()->getValue(); + if (!(SignOrUnsignWrap & SCEV::FlagNSW)) { + auto NSWRegion = + ConstantRange::makeNoWrapRegion(Instruction::Add, C, OBO::NoSignedWrap); + if (NSWRegion.contains(SE->getSignedRange(Ops[1]))) + Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); + } + if (!(SignOrUnsignWrap & SCEV::FlagNUW)) { + auto NUWRegion = + ConstantRange::makeNoWrapRegion(Instruction::Add, C, + OBO::NoUnsignedWrap); + if (NUWRegion.contains(SE->getUnsignedRange(Ops[1]))) + Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); + } + } - return OldFlags; + return Flags; } /// getAddExpr - Get a canonical add expression, or something simpler if @@ -1963,10 +2007,10 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, "SCEVAddExpr operand types don't match!"); #endif - Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags); - // Sort by complexity, this groups all similar expression types together. - GroupByComplexity(Ops, LI); + GroupByComplexity(Ops, &LI); + + Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags); // If there are any constants, fold them together. unsigned Idx = 0; @@ -2046,8 +2090,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, break; } LargeMulOps.push_back(T->getOperand()); - } else if (const SCEVConstant *C = - dyn_cast(M->getOperand(j))) { + } else if (const auto *C = dyn_cast(M->getOperand(j))) { LargeMulOps.push_back(getAnyExtendExpr(C, SrcType)); } else { Ok = false; @@ -2110,20 +2153,18 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, // re-generate the operands list. Group the operands by constant scale, // to avoid multiplying by the same constant scale multiple times. std::map, APIntCompare> MulOpLists; - for (SmallVectorImpl::const_iterator I = NewOps.begin(), - E = NewOps.end(); I != E; ++I) - MulOpLists[M.find(*I)->second].push_back(*I); + for (const SCEV *NewOp : NewOps) + MulOpLists[M.find(NewOp)->second].push_back(NewOp); // Re-generate the operands list. Ops.clear(); if (AccumulatedConstant != 0) Ops.push_back(getConstant(AccumulatedConstant)); - for (std::map, APIntCompare>::iterator - I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I) - if (I->first != 0) - Ops.push_back(getMulExpr(getConstant(I->first), - getAddExpr(I->second))); + for (auto &MulOp : MulOpLists) + if (MulOp.first != 0) + Ops.push_back(getMulExpr(getConstant(MulOp.first), + getAddExpr(MulOp.second))); if (Ops.empty()) - return getConstant(Ty, 0); + return getZero(Ty); if (Ops.size() == 1) return Ops[0]; return getAddExpr(Ops); @@ -2151,7 +2192,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); InnerMul = getMulExpr(MulOps); } - const SCEV *One = getConstant(Ty, 1); + const SCEV *One = getOne(Ty); const SCEV *AddOne = getAddExpr(One, InnerMul); const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV); if (Ops.size() == 2) return OuterMul; @@ -2262,8 +2303,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, AddRec->op_end()); for (; OtherIdx != Ops.size() && isa(Ops[OtherIdx]); ++OtherIdx) - if (const SCEVAddRecExpr *OtherAddRec = - dyn_cast(Ops[OtherIdx])) + if (const auto *OtherAddRec = dyn_cast(Ops[OtherIdx])) if (OtherAddRec->getLoop() == AddRecLoop) { for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) { @@ -2371,10 +2411,10 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, "SCEVMulExpr operand types don't match!"); #endif - Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags); - // Sort by complexity, this groups all similar expression types together. - GroupByComplexity(Ops, LI); + GroupByComplexity(Ops, &LI); + + Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags); // If there are any constants, fold them together. unsigned Idx = 0; @@ -2424,9 +2464,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, } if (AnyFolded) return getAddExpr(NewOps); - } - else if (const SCEVAddRecExpr * - AddRec = dyn_cast(Ops[1])) { + } else if (const auto *AddRec = dyn_cast(Ops[1])) { // Negation preserves a recurrence's no self-wrap property. SmallVector Operands; for (SCEVAddRecExpr::op_iterator I = AddRec->op_begin(), @@ -2543,7 +2581,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, SmallVector AddRecOps; for (int x = 0, xe = AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { - const SCEV *Term = getConstant(Ty, 0); + const SCEV *Term = getZero(Ty); for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), @@ -2641,10 +2679,9 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, getZeroExtendExpr(Step, ExtTy), AR->getLoop(), SCEV::FlagAnyWrap)) { SmallVector Operands; - for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i) - Operands.push_back(getUDivExpr(AR->getOperand(i), RHS)); - return getAddRecExpr(Operands, AR->getLoop(), - SCEV::FlagNW); + for (const SCEV *Op : AR->operands()) + Operands.push_back(getUDivExpr(Op, RHS)); + return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW); } /// Get a canonical UDivExpr for a recurrence. /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0. @@ -2665,8 +2702,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, // (A*B)/C --> A*(B/C) if safe and B/C can be folded. if (const SCEVMulExpr *M = dyn_cast(LHS)) { SmallVector Operands; - for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) - Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy)); + for (const SCEV *Op : M->operands()) + Operands.push_back(getZeroExtendExpr(Op, ExtTy)); if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands)) // Find an operand that's safely divisible. for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) { @@ -2683,8 +2720,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded. if (const SCEVAddExpr *A = dyn_cast(LHS)) { SmallVector Operands; - for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) - Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy)); + for (const SCEV *Op : A->operands()) + Operands.push_back(getZeroExtendExpr(Op, ExtTy)); if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) { Operands.clear(); for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) { @@ -2752,8 +2789,7 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, if (const SCEVConstant *RHSCst = dyn_cast(RHS)) { // If the mulexpr multiplies by a constant, then that constant must be the // first element of the mulexpr. - if (const SCEVConstant *LHSCst = - dyn_cast(Mul->getOperand(0))) { + if (const auto *LHSCst = dyn_cast(Mul->getOperand(0))) { if (LHSCst == RHSCst) { SmallVector Operands; Operands.append(Mul->op_begin() + 1, Mul->op_end()); @@ -2842,22 +2878,20 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, // Canonicalize nested AddRecs in by nesting them in order of loop depth. if (const SCEVAddRecExpr *NestedAR = dyn_cast(Operands[0])) { const Loop *NestedLoop = NestedAR->getLoop(); - if (L->contains(NestedLoop) ? - (L->getLoopDepth() < NestedLoop->getLoopDepth()) : - (!NestedLoop->contains(L) && - DT->dominates(L->getHeader(), NestedLoop->getHeader()))) { + if (L->contains(NestedLoop) + ? (L->getLoopDepth() < NestedLoop->getLoopDepth()) + : (!NestedLoop->contains(L) && + DT.dominates(L->getHeader(), NestedLoop->getHeader()))) { SmallVector NestedOperands(NestedAR->op_begin(), NestedAR->op_end()); Operands[0] = NestedAR->getStart(); // AddRecs require their operands be loop-invariant with respect to their // loops. Don't perform this transformation if it would break this // requirement. - bool AllInvariant = true; - for (unsigned i = 0, e = Operands.size(); i != e; ++i) - if (!isLoopInvariant(Operands[i], L)) { - AllInvariant = false; - break; - } + bool AllInvariant = + std::all_of(Operands.begin(), Operands.end(), + [&](const SCEV *Op) { return isLoopInvariant(Op, L); }); + if (AllInvariant) { // Create a recurrence for the outer loop with the same step size. // @@ -2867,12 +2901,10 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags()); NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags); - AllInvariant = true; - for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i) - if (!isLoopInvariant(NestedOperands[i], NestedLoop)) { - AllInvariant = false; - break; - } + AllInvariant = std::all_of( + NestedOperands.begin(), NestedOperands.end(), + [&](const SCEV *Op) { return isLoopInvariant(Op, NestedLoop); }); + if (AllInvariant) { // Ok, both add recurrences are valid after the transformation. // @@ -2909,6 +2941,57 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, return S; } +const SCEV * +ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr, + const SmallVectorImpl &IndexExprs, + bool InBounds) { + // getSCEV(Base)->getType() has the same address space as Base->getType() + // because SCEV::getType() preserves the address space. + Type *IntPtrTy = getEffectiveSCEVType(BaseExpr->getType()); + // FIXME(PR23527): Don't blindly transfer the inbounds flag from the GEP + // instruction to its SCEV, because the Instruction may be guarded by control + // flow and the no-overflow bits may not be valid for the expression in any + // context. This can be fixed similarly to how these flags are handled for + // adds. + SCEV::NoWrapFlags Wrap = InBounds ? SCEV::FlagNSW : SCEV::FlagAnyWrap; + + const SCEV *TotalOffset = getZero(IntPtrTy); + // The address space is unimportant. The first thing we do on CurTy is getting + // its element type. + Type *CurTy = PointerType::getUnqual(PointeeType); + for (const SCEV *IndexExpr : IndexExprs) { + // Compute the (potentially symbolic) offset in bytes for this index. + if (StructType *STy = dyn_cast(CurTy)) { + // For a struct, add the member offset. + ConstantInt *Index = cast(IndexExpr)->getValue(); + unsigned FieldNo = Index->getZExtValue(); + const SCEV *FieldOffset = getOffsetOfExpr(IntPtrTy, STy, FieldNo); + + // Add the field offset to the running total offset. + TotalOffset = getAddExpr(TotalOffset, FieldOffset); + + // Update CurTy to the type of the field at Index. + CurTy = STy->getTypeAtIndex(Index); + } else { + // Update CurTy to its element type. + CurTy = cast(CurTy)->getElementType(); + // For an array, add the element offset, explicitly scaled. + const SCEV *ElementSize = getSizeOfExpr(IntPtrTy, CurTy); + // Getelementptr indices are signed. + IndexExpr = getTruncateOrSignExtend(IndexExpr, IntPtrTy); + + // Multiply the index by the element size to compute the element offset. + const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, Wrap); + + // Add the element offset to the running total offset. + TotalOffset = getAddExpr(TotalOffset, LocalOffset); + } + } + + // Add the total offset from all the GEP indices to the base. + return getAddExpr(BaseExpr, TotalOffset, Wrap); +} + const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) { SmallVector Ops; @@ -2929,7 +3012,7 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { #endif // Sort by complexity, this groups all similar expression types together. - GroupByComplexity(Ops, LI); + GroupByComplexity(Ops, &LI); // If there are any constants, fold them together. unsigned Idx = 0; @@ -3033,7 +3116,7 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { #endif // Sort by complexity, this groups all similar expression types together. - GroupByComplexity(Ops, LI); + GroupByComplexity(Ops, &LI); // If there are any constants, fold them together. unsigned Idx = 0; @@ -3130,39 +3213,20 @@ const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, } const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { - // If we have DataLayout, we can bypass creating a target-independent + // We can bypass creating a target-independent // constant expression and then folding it back into a ConstantInt. // This is just a compile-time optimization. - if (DL) - return getConstant(IntTy, DL->getTypeAllocSize(AllocTy)); - - Constant *C = ConstantExpr::getSizeOf(AllocTy); - if (ConstantExpr *CE = dyn_cast(C)) - if (Constant *Folded = ConstantFoldConstantExpression(CE, DL, TLI)) - C = Folded; - Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy)); - assert(Ty == IntTy && "Effective SCEV type doesn't match"); - return getTruncateOrZeroExtend(getSCEV(C), Ty); + return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy)); } const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo) { - // If we have DataLayout, we can bypass creating a target-independent + // We can bypass creating a target-independent // constant expression and then folding it back into a ConstantInt. // This is just a compile-time optimization. - if (DL) { - return getConstant(IntTy, - DL->getStructLayout(STy)->getElementOffset(FieldNo)); - } - - Constant *C = ConstantExpr::getOffsetOf(STy, FieldNo); - if (ConstantExpr *CE = dyn_cast(C)) - if (Constant *Folded = ConstantFoldConstantExpression(CE, DL, TLI)) - C = Folded; - - Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(STy)); - return getTruncateOrZeroExtend(getSCEV(C), Ty); + return getConstant( + IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo)); } const SCEV *ScalarEvolution::getUnknown(Value *V) { @@ -3204,19 +3268,7 @@ bool ScalarEvolution::isSCEVable(Type *Ty) const { /// for which isSCEVable must return true. uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const { assert(isSCEVable(Ty) && "Type is not SCEVable!"); - - // If we have a DataLayout, use it! - if (DL) - return DL->getTypeSizeInBits(Ty); - - // Integer types have fixed sizes. - if (Ty->isIntegerTy()) - return Ty->getPrimitiveSizeInBits(); - - // The only other support type is pointer. Without DataLayout, conservatively - // assume pointers are 64-bit. - assert(Ty->isPointerTy() && "isSCEVable permitted a non-SCEVable type!"); - return 64; + return getDataLayout().getTypeSizeInBits(Ty); } /// getEffectiveSCEVType - Return a type with the same bitwidth as @@ -3226,22 +3278,16 @@ uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const { Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const { assert(isSCEVable(Ty) && "Type is not SCEVable!"); - if (Ty->isIntegerTy()) { + if (Ty->isIntegerTy()) return Ty; - } // The only other support type is pointer. assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!"); - - if (DL) - return DL->getIntPtrType(Ty); - - // Without DataLayout, conservatively assume pointers are 64-bit. - return Type::getInt64Ty(getContext()); + return getDataLayout().getIntPtrType(Ty); } const SCEV *ScalarEvolution::getCouldNotCompute() { - return &CouldNotCompute; + return CouldNotCompute.get(); } namespace { @@ -3281,35 +3327,39 @@ bool ScalarEvolution::checkValidity(const SCEV *S) const { const SCEV *ScalarEvolution::getSCEV(Value *V) { assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); + const SCEV *S = getExistingSCEV(V); + if (S == nullptr) { + S = createSCEV(V); + ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S)); + } + return S; +} + +const SCEV *ScalarEvolution::getExistingSCEV(Value *V) { + assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); + ValueExprMapType::iterator I = ValueExprMap.find_as(V); if (I != ValueExprMap.end()) { const SCEV *S = I->second; if (checkValidity(S)) return S; - else - ValueExprMap.erase(I); + ValueExprMap.erase(I); } - const SCEV *S = createSCEV(V); - - // The process of creating a SCEV for V may have caused other SCEVs - // to have been created, so it's necessary to insert the new entry - // from scratch, rather than trying to remember the insert position - // above. - ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S)); - return S; + return nullptr; } /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V /// -const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V) { +const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V, + SCEV::NoWrapFlags Flags) { if (const SCEVConstant *VC = dyn_cast(V)) return getConstant( cast(ConstantExpr::getNeg(VC->getValue()))); Type *Ty = V->getType(); Ty = getEffectiveSCEVType(Ty); - return getMulExpr(V, - getConstant(cast(Constant::getAllOnesValue(Ty)))); + return getMulExpr( + V, getConstant(cast(Constant::getAllOnesValue(Ty))), Flags); } /// getNotSCEV - Return a SCEV corresponding to ~V = -1-V @@ -3328,15 +3378,40 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { /// getMinusSCEV - Return LHS-RHS. Minus is represented in SCEV as A+B*-1. const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags) { - assert(!maskFlags(Flags, SCEV::FlagNUW) && "subtraction does not have NUW"); - // Fast path: X - X --> 0. if (LHS == RHS) - return getConstant(LHS->getType(), 0); + return getZero(LHS->getType()); + + // We represent LHS - RHS as LHS + (-1)*RHS. This transformation + // makes it so that we cannot make much use of NUW. + auto AddFlags = SCEV::FlagAnyWrap; + const bool RHSIsNotMinSigned = + !getSignedRange(RHS).getSignedMin().isMinSignedValue(); + if (maskFlags(Flags, SCEV::FlagNSW) == SCEV::FlagNSW) { + // Let M be the minimum representable signed value. Then (-1)*RHS + // signed-wraps if and only if RHS is M. That can happen even for + // a NSW subtraction because e.g. (-1)*M signed-wraps even though + // -1 - M does not. So to transfer NSW from LHS - RHS to LHS + + // (-1)*RHS, we need to prove that RHS != M. + // + // If LHS is non-negative and we know that LHS - RHS does not + // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap + // either by proving that RHS > M or that LHS >= 0. + if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) { + AddFlags = SCEV::FlagNSW; + } + } + + // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS - + // RHS is NSW and LHS >= 0. + // + // The difficulty here is that the NSW flag may have been proven + // relative to a loop that is to be found in a recurrence in LHS and + // not in RHS. Applying NSW to (-1)*M may then let the NSW have a + // larger scope than intended. + auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap; - // X - Y --> X + -Y. - // X -(nsw || nuw) Y --> X + -Y. - return getAddExpr(LHS, getNegativeSCEV(RHS)); + return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags); } /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the @@ -3479,8 +3554,7 @@ const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { if (const SCEVCastExpr *Cast = dyn_cast(V)) { return getPointerBase(Cast->getOperand()); - } - else if (const SCEVNAryExpr *NAry = dyn_cast(V)) { + } else if (const SCEVNAryExpr *NAry = dyn_cast(V)) { const SCEV *PtrOp = nullptr; for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); I != E; ++I) { @@ -3524,8 +3598,7 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { if (!Visited.insert(I).second) continue; - ValueExprMapType::iterator It = - ValueExprMap.find_as(static_cast(I)); + auto It = ValueExprMap.find_as(static_cast(I)); if (It != ValueExprMap.end()) { const SCEV *Old = It->second; @@ -3553,267 +3626,543 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { } } -/// createNodeForPHI - PHI nodes have two cases. Either the PHI node exists in -/// a loop header, making it a potential recurrence, or it doesn't. -/// -const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { - if (const Loop *L = LI->getLoopFor(PN->getParent())) - if (L->getHeader() == PN->getParent()) { - // The loop may have multiple entrances or multiple exits; we can analyze - // this phi as an addrec if it has a unique entry value and a unique - // backedge value. - Value *BEValueV = nullptr, *StartValueV = nullptr; - for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { - Value *V = PN->getIncomingValue(i); - if (L->contains(PN->getIncomingBlock(i))) { - if (!BEValueV) { - BEValueV = V; - } else if (BEValueV != V) { - BEValueV = nullptr; - break; - } - } else if (!StartValueV) { - StartValueV = V; - } else if (StartValueV != V) { - StartValueV = nullptr; - break; - } - } - if (BEValueV && StartValueV) { - // While we are analyzing this PHI node, handle its value symbolically. - const SCEV *SymbolicName = getUnknown(PN); - assert(ValueExprMap.find_as(PN) == ValueExprMap.end() && - "PHI node already processed?"); - ValueExprMap.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName)); - - // Using this symbolic name for the PHI, analyze the value coming around - // the back-edge. - const SCEV *BEValue = getSCEV(BEValueV); - - // NOTE: If BEValue is loop invariant, we know that the PHI node just - // has a special value for the first iteration of the loop. - - // If the value coming around the backedge is an add with the symbolic - // value we just inserted, then we found a simple induction variable! - if (const SCEVAddExpr *Add = dyn_cast(BEValue)) { - // If there is a single occurrence of the symbolic value, replace it - // with a recurrence. - unsigned FoundIndex = Add->getNumOperands(); - for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) - if (Add->getOperand(i) == SymbolicName) - if (FoundIndex == e) { - FoundIndex = i; - break; - } +class SCEVInitRewriter : public SCEVRewriteVisitor { +public: + static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + ScalarEvolution &SE) { + SCEVInitRewriter Rewriter(L, SE); + const SCEV *Result = Rewriter.visit(Scev); + return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); + } - if (FoundIndex != Add->getNumOperands()) { - // Create an add with everything but the specified operand. - SmallVector Ops; - for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) - if (i != FoundIndex) - Ops.push_back(Add->getOperand(i)); - const SCEV *Accum = getAddExpr(Ops); - - // This is not a valid addrec if the step amount is varying each - // loop iteration, but is not itself an addrec in this loop. - if (isLoopInvariant(Accum, L) || - (isa(Accum) && - cast(Accum)->getLoop() == L)) { - SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; - - // If the increment doesn't overflow, then neither the addrec nor - // the post-increment will overflow. - if (const AddOperator *OBO = dyn_cast(BEValueV)) { - if (OBO->hasNoUnsignedWrap()) - Flags = setFlags(Flags, SCEV::FlagNUW); - if (OBO->hasNoSignedWrap()) - Flags = setFlags(Flags, SCEV::FlagNSW); - } else if (GEPOperator *GEP = dyn_cast(BEValueV)) { - // If the increment is an inbounds GEP, then we know the address - // space cannot be wrapped around. We cannot make any guarantee - // about signed or unsigned overflow because pointers are - // unsigned but we may have a negative index from the base - // pointer. We can guarantee that no unsigned wrap occurs if the - // indices form a positive value. - if (GEP->isInBounds()) { - Flags = setFlags(Flags, SCEV::FlagNW); - - const SCEV *Ptr = getSCEV(GEP->getPointerOperand()); - if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr))) - Flags = setFlags(Flags, SCEV::FlagNUW); - } + SCEVInitRewriter(const Loop *L, ScalarEvolution &SE) + : SCEVRewriteVisitor(SE), L(L), Valid(true) {} - // We cannot transfer nuw and nsw flags from subtraction - // operations -- sub nuw X, Y is not the same as add nuw X, -Y - // for instance. - } + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant)) + Valid = false; + return Expr; + } - const SCEV *StartVal = getSCEV(StartValueV); - const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); - - // Since the no-wrap flags are on the increment, they apply to the - // post-incremented value as well. - if (isLoopInvariant(Accum, L)) - (void)getAddRecExpr(getAddExpr(StartVal, Accum), - Accum, L, Flags); - - // Okay, for the entire analysis of this edge we assumed the PHI - // to be symbolic. We now need to go back and purge all of the - // entries for the scalars that use the symbolic expression. - ForgetSymbolicName(PN, SymbolicName); - ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; - return PHISCEV; - } + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + // Only allow AddRecExprs for this loop. + if (Expr->getLoop() == L) + return Expr->getStart(); + Valid = false; + return Expr; + } + + bool isValid() { return Valid; } + +private: + const Loop *L; + bool Valid; +}; + +class SCEVShiftRewriter : public SCEVRewriteVisitor { +public: + static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + ScalarEvolution &SE) { + SCEVShiftRewriter Rewriter(L, SE); + const SCEV *Result = Rewriter.visit(Scev); + return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); + } + + SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE) + : SCEVRewriteVisitor(SE), L(L), Valid(true) {} + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + // Only allow AddRecExprs for this loop. + if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant)) + Valid = false; + return Expr; + } + + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + if (Expr->getLoop() == L && Expr->isAffine()) + return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE)); + Valid = false; + return Expr; + } + bool isValid() { return Valid; } + +private: + const Loop *L; + bool Valid; +}; + +const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { + const Loop *L = LI.getLoopFor(PN->getParent()); + if (!L || L->getHeader() != PN->getParent()) + return nullptr; + + // The loop may have multiple entrances or multiple exits; we can analyze + // this phi as an addrec if it has a unique entry value and a unique + // backedge value. + Value *BEValueV = nullptr, *StartValueV = nullptr; + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + Value *V = PN->getIncomingValue(i); + if (L->contains(PN->getIncomingBlock(i))) { + if (!BEValueV) { + BEValueV = V; + } else if (BEValueV != V) { + BEValueV = nullptr; + break; + } + } else if (!StartValueV) { + StartValueV = V; + } else if (StartValueV != V) { + StartValueV = nullptr; + break; + } + } + if (BEValueV && StartValueV) { + // While we are analyzing this PHI node, handle its value symbolically. + const SCEV *SymbolicName = getUnknown(PN); + assert(ValueExprMap.find_as(PN) == ValueExprMap.end() && + "PHI node already processed?"); + ValueExprMap.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName)); + + // Using this symbolic name for the PHI, analyze the value coming around + // the back-edge. + const SCEV *BEValue = getSCEV(BEValueV); + + // NOTE: If BEValue is loop invariant, we know that the PHI node just + // has a special value for the first iteration of the loop. + + // If the value coming around the backedge is an add with the symbolic + // value we just inserted, then we found a simple induction variable! + if (const SCEVAddExpr *Add = dyn_cast(BEValue)) { + // If there is a single occurrence of the symbolic value, replace it + // with a recurrence. + unsigned FoundIndex = Add->getNumOperands(); + for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) + if (Add->getOperand(i) == SymbolicName) + if (FoundIndex == e) { + FoundIndex = i; + break; } - } else if (const SCEVAddRecExpr *AddRec = - dyn_cast(BEValue)) { - // Otherwise, this could be a loop like this: - // i = 0; for (j = 1; ..; ++j) { .... i = j; } - // In this case, j = {1,+,1} and BEValue is j. - // Because the other in-value of i (0) fits the evolution of BEValue - // i really is an addrec evolution. - if (AddRec->getLoop() == L && AddRec->isAffine()) { - const SCEV *StartVal = getSCEV(StartValueV); - - // If StartVal = j.start - j.stride, we can use StartVal as the - // initial step of the addrec evolution. - if (StartVal == getMinusSCEV(AddRec->getOperand(0), - AddRec->getOperand(1))) { - // FIXME: For constant StartVal, we should be able to infer - // no-wrap flags. - const SCEV *PHISCEV = - getAddRecExpr(StartVal, AddRec->getOperand(1), L, - SCEV::FlagAnyWrap); - - // Okay, for the entire analysis of this edge we assumed the PHI - // to be symbolic. We now need to go back and purge all of the - // entries for the scalars that use the symbolic expression. - ForgetSymbolicName(PN, SymbolicName); - ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; - return PHISCEV; + + if (FoundIndex != Add->getNumOperands()) { + // Create an add with everything but the specified operand. + SmallVector Ops; + for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) + if (i != FoundIndex) + Ops.push_back(Add->getOperand(i)); + const SCEV *Accum = getAddExpr(Ops); + + // This is not a valid addrec if the step amount is varying each + // loop iteration, but is not itself an addrec in this loop. + if (isLoopInvariant(Accum, L) || + (isa(Accum) && + cast(Accum)->getLoop() == L)) { + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; + + // If the increment doesn't overflow, then neither the addrec nor + // the post-increment will overflow. + if (const AddOperator *OBO = dyn_cast(BEValueV)) { + if (OBO->getOperand(0) == PN) { + if (OBO->hasNoUnsignedWrap()) + Flags = setFlags(Flags, SCEV::FlagNUW); + if (OBO->hasNoSignedWrap()) + Flags = setFlags(Flags, SCEV::FlagNSW); } + } else if (GEPOperator *GEP = dyn_cast(BEValueV)) { + // If the increment is an inbounds GEP, then we know the address + // space cannot be wrapped around. We cannot make any guarantee + // about signed or unsigned overflow because pointers are + // unsigned but we may have a negative index from the base + // pointer. We can guarantee that no unsigned wrap occurs if the + // indices form a positive value. + if (GEP->isInBounds() && GEP->getOperand(0) == PN) { + Flags = setFlags(Flags, SCEV::FlagNW); + + const SCEV *Ptr = getSCEV(GEP->getPointerOperand()); + if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr))) + Flags = setFlags(Flags, SCEV::FlagNUW); + } + + // We cannot transfer nuw and nsw flags from subtraction + // operations -- sub nuw X, Y is not the same as add nuw X, -Y + // for instance. } + + const SCEV *StartVal = getSCEV(StartValueV); + const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); + + // Since the no-wrap flags are on the increment, they apply to the + // post-incremented value as well. + if (isLoopInvariant(Accum, L)) + (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags); + + // Okay, for the entire analysis of this edge we assumed the PHI + // to be symbolic. We now need to go back and purge all of the + // entries for the scalars that use the symbolic expression. + ForgetSymbolicName(PN, SymbolicName); + ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; + return PHISCEV; + } + } + } else { + // Otherwise, this could be a loop like this: + // i = 0; for (j = 1; ..; ++j) { .... i = j; } + // In this case, j = {1,+,1} and BEValue is j. + // Because the other in-value of i (0) fits the evolution of BEValue + // i really is an addrec evolution. + // + // We can generalize this saying that i is the shifted value of BEValue + // by one iteration: + // PHI(f(0), f({1,+,1})) --> f({0,+,1}) + const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this); + const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this); + if (Shifted != getCouldNotCompute() && + Start != getCouldNotCompute()) { + const SCEV *StartVal = getSCEV(StartValueV); + if (Start == StartVal) { + // Okay, for the entire analysis of this edge we assumed the PHI + // to be symbolic. We now need to go back and purge all of the + // entries for the scalars that use the symbolic expression. + ForgetSymbolicName(PN, SymbolicName); + ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted; + return Shifted; } } } + } - // If the PHI has a single incoming value, follow that value, unless the - // PHI's incoming blocks are in a different loop, in which case doing so - // risks breaking LCSSA form. Instcombine would normally zap these, but - // it doesn't have DominatorTree information, so it may miss cases. - if (Value *V = SimplifyInstruction(PN, DL, TLI, DT, AC)) - if (LI->replacementPreservesLCSSAForm(PN, V)) - return getSCEV(V); - - // If it's not a loop phi, we can't handle it yet. - return getUnknown(PN); + return nullptr; } -/// createNodeForGEP - Expand GEP instructions into add and multiply -/// operations. This allows them to be analyzed by regular SCEV code. -/// -const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { - Type *IntPtrTy = getEffectiveSCEVType(GEP->getType()); - Value *Base = GEP->getOperand(0); - // Don't attempt to analyze GEPs over unsized objects. - if (!Base->getType()->getPointerElementType()->isSized()) - return getUnknown(GEP); +// Checks if the SCEV S is available at BB. S is considered available at BB +// if S can be materialized at BB without introducing a fault. +static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S, + BasicBlock *BB) { + struct CheckAvailable { + bool TraversalDone = false; + bool Available = true; - // Don't blindly transfer the inbounds flag from the GEP instruction to the - // Add expression, because the Instruction may be guarded by control flow - // and the no-overflow bits may not be valid for the expression in any - // context. - SCEV::NoWrapFlags Wrap = GEP->isInBounds() ? SCEV::FlagNSW : SCEV::FlagAnyWrap; - - const SCEV *TotalOffset = getConstant(IntPtrTy, 0); - gep_type_iterator GTI = gep_type_begin(GEP); - for (GetElementPtrInst::op_iterator I = std::next(GEP->op_begin()), - E = GEP->op_end(); - I != E; ++I) { - Value *Index = *I; - // Compute the (potentially symbolic) offset in bytes for this index. - if (StructType *STy = dyn_cast(*GTI++)) { - // For a struct, add the member offset. - unsigned FieldNo = cast(Index)->getZExtValue(); - const SCEV *FieldOffset = getOffsetOfExpr(IntPtrTy, STy, FieldNo); + const Loop *L = nullptr; // The loop BB is in (can be nullptr) + BasicBlock *BB = nullptr; + DominatorTree &DT; - // Add the field offset to the running total offset. - TotalOffset = getAddExpr(TotalOffset, FieldOffset); - } else { - // For an array, add the element offset, explicitly scaled. - const SCEV *ElementSize = getSizeOfExpr(IntPtrTy, *GTI); - const SCEV *IndexS = getSCEV(Index); - // Getelementptr indices are signed. - IndexS = getTruncateOrSignExtend(IndexS, IntPtrTy); + CheckAvailable(const Loop *L, BasicBlock *BB, DominatorTree &DT) + : L(L), BB(BB), DT(DT) {} - // Multiply the index by the element size to compute the element offset. - const SCEV *LocalOffset = getMulExpr(IndexS, ElementSize, Wrap); + bool setUnavailable() { + TraversalDone = true; + Available = false; + return false; + } - // Add the element offset to the running total offset. - TotalOffset = getAddExpr(TotalOffset, LocalOffset); + bool follow(const SCEV *S) { + switch (S->getSCEVType()) { + case scConstant: case scTruncate: case scZeroExtend: case scSignExtend: + case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr: + // These expressions are available if their operand(s) is/are. + return true; + + case scAddRecExpr: { + // We allow add recurrences that are on the loop BB is in, or some + // outer loop. This guarantees availability because the value of the + // add recurrence at BB is simply the "current" value of the induction + // variable. We can relax this in the future; for instance an add + // recurrence on a sibling dominating loop is also available at BB. + const auto *ARLoop = cast(S)->getLoop(); + if (L && (ARLoop == L || ARLoop->contains(L))) + return true; + + return setUnavailable(); + } + + case scUnknown: { + // For SCEVUnknown, we check for simple dominance. + const auto *SU = cast(S); + Value *V = SU->getValue(); + + if (isa(V)) + return false; + + if (isa(V) && DT.dominates(cast(V), BB)) + return false; + + return setUnavailable(); + } + + case scUDivExpr: + case scCouldNotCompute: + // We do not try to smart about these at all. + return setUnavailable(); + } + llvm_unreachable("switch should be fully covered!"); } - } - // Get the SCEV for the GEP base. - const SCEV *BaseS = getSCEV(Base); + bool isDone() { return TraversalDone; } + }; + + CheckAvailable CA(L, BB, DT); + SCEVTraversal ST(CA); - // Add the total offset from all the GEP indices to the base. - return getAddExpr(BaseS, TotalOffset, Wrap); + ST.visitAll(S); + return CA.Available; } -/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is -/// guaranteed to end in (at every loop iteration). It is, at the same time, -/// the minimum number of times S is divisible by 2. For example, given {4,+,8} -/// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S. -uint32_t -ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { - if (const SCEVConstant *C = dyn_cast(S)) - return C->getValue()->getValue().countTrailingZeros(); +// Try to match a control flow sequence that branches out at BI and merges back +// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful +// match. +static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, + Value *&C, Value *&LHS, Value *&RHS) { + C = BI->getCondition(); - if (const SCEVTruncateExpr *T = dyn_cast(S)) - return std::min(GetMinTrailingZeros(T->getOperand()), - (uint32_t)getTypeSizeInBits(T->getType())); + BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0)); + BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1)); - if (const SCEVZeroExtendExpr *E = dyn_cast(S)) { - uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); - return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? - getTypeSizeInBits(E->getType()) : OpRes; - } + if (!LeftEdge.isSingleEdge()) + return false; - if (const SCEVSignExtendExpr *E = dyn_cast(S)) { - uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); - return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? - getTypeSizeInBits(E->getType()) : OpRes; - } + assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()"); - if (const SCEVAddExpr *A = dyn_cast(S)) { - // The result is the min of all operands results. - uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); - for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) - MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); - return MinOpRes; - } + Use &LeftUse = Merge->getOperandUse(0); + Use &RightUse = Merge->getOperandUse(1); - if (const SCEVMulExpr *M = dyn_cast(S)) { - // The result is the sum of all operands results. - uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0)); - uint32_t BitWidth = getTypeSizeInBits(M->getType()); - for (unsigned i = 1, e = M->getNumOperands(); - SumOpRes != BitWidth && i != e; ++i) - SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), - BitWidth); - return SumOpRes; + if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) { + LHS = LeftUse; + RHS = RightUse; + return true; } - if (const SCEVAddRecExpr *A = dyn_cast(S)) { - // The result is the min of all operands results. - uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); - for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) - MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); - return MinOpRes; + if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) { + LHS = RightUse; + RHS = LeftUse; + return true; } - if (const SCEVSMaxExpr *M = dyn_cast(S)) { + return false; +} + +const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) { + if (PN->getNumIncomingValues() == 2) { + const Loop *L = LI.getLoopFor(PN->getParent()); + + // We don't want to break LCSSA, even in a SCEV expression tree. + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (LI.getLoopFor(PN->getIncomingBlock(i)) != L) + return nullptr; + + // Try to match + // + // br %cond, label %left, label %right + // left: + // br label %merge + // right: + // br label %merge + // merge: + // V = phi [ %x, %left ], [ %y, %right ] + // + // as "select %cond, %x, %y" + + BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock(); + assert(IDom && "At least the entry block should dominate PN"); + + auto *BI = dyn_cast(IDom->getTerminator()); + Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr; + + if (BI && BI->isConditional() && + BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) && + IsAvailableOnEntry(L, DT, getSCEV(LHS), PN->getParent()) && + IsAvailableOnEntry(L, DT, getSCEV(RHS), PN->getParent())) + return createNodeForSelectOrPHI(PN, Cond, LHS, RHS); + } + + return nullptr; +} + +const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { + if (const SCEV *S = createAddRecFromPHI(PN)) + return S; + + if (const SCEV *S = createNodeFromSelectLikePHI(PN)) + return S; + + // If the PHI has a single incoming value, follow that value, unless the + // PHI's incoming blocks are in a different loop, in which case doing so + // risks breaking LCSSA form. Instcombine would normally zap these, but + // it doesn't have DominatorTree information, so it may miss cases. + if (Value *V = SimplifyInstruction(PN, getDataLayout(), &TLI, &DT, &AC)) + if (LI.replacementPreservesLCSSAForm(PN, V)) + return getSCEV(V); + + // If it's not a loop phi, we can't handle it yet. + return getUnknown(PN); +} + +const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I, + Value *Cond, + Value *TrueVal, + Value *FalseVal) { + // Handle "constant" branch or select. This can occur for instance when a + // loop pass transforms an inner loop and moves on to process the outer loop. + if (auto *CI = dyn_cast(Cond)) + return getSCEV(CI->isOne() ? TrueVal : FalseVal); + + // Try to match some simple smax or umax patterns. + auto *ICI = dyn_cast(Cond); + if (!ICI) + return getUnknown(I); + + Value *LHS = ICI->getOperand(0); + Value *RHS = ICI->getOperand(1); + + switch (ICI->getPredicate()) { + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + // a >s b ? a+x : b+x -> smax(a, b)+x + // a >s b ? b+x : a+x -> smin(a, b)+x + if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) { + const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), I->getType()); + const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), I->getType()); + const SCEV *LA = getSCEV(TrueVal); + const SCEV *RA = getSCEV(FalseVal); + const SCEV *LDiff = getMinusSCEV(LA, LS); + const SCEV *RDiff = getMinusSCEV(RA, RS); + if (LDiff == RDiff) + return getAddExpr(getSMaxExpr(LS, RS), LDiff); + LDiff = getMinusSCEV(LA, RS); + RDiff = getMinusSCEV(RA, LS); + if (LDiff == RDiff) + return getAddExpr(getSMinExpr(LS, RS), LDiff); + } + break; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + // a >u b ? a+x : b+x -> umax(a, b)+x + // a >u b ? b+x : a+x -> umin(a, b)+x + if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) { + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); + const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), I->getType()); + const SCEV *LA = getSCEV(TrueVal); + const SCEV *RA = getSCEV(FalseVal); + const SCEV *LDiff = getMinusSCEV(LA, LS); + const SCEV *RDiff = getMinusSCEV(RA, RS); + if (LDiff == RDiff) + return getAddExpr(getUMaxExpr(LS, RS), LDiff); + LDiff = getMinusSCEV(LA, RS); + RDiff = getMinusSCEV(RA, LS); + if (LDiff == RDiff) + return getAddExpr(getUMinExpr(LS, RS), LDiff); + } + break; + case ICmpInst::ICMP_NE: + // n != 0 ? n+x : 1+x -> umax(n, 1)+x + if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) && + isa(RHS) && cast(RHS)->isZero()) { + const SCEV *One = getOne(I->getType()); + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); + const SCEV *LA = getSCEV(TrueVal); + const SCEV *RA = getSCEV(FalseVal); + const SCEV *LDiff = getMinusSCEV(LA, LS); + const SCEV *RDiff = getMinusSCEV(RA, One); + if (LDiff == RDiff) + return getAddExpr(getUMaxExpr(One, LS), LDiff); + } + break; + case ICmpInst::ICMP_EQ: + // n == 0 ? 1+x : n+x -> umax(n, 1)+x + if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) && + isa(RHS) && cast(RHS)->isZero()) { + const SCEV *One = getOne(I->getType()); + const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); + const SCEV *LA = getSCEV(TrueVal); + const SCEV *RA = getSCEV(FalseVal); + const SCEV *LDiff = getMinusSCEV(LA, One); + const SCEV *RDiff = getMinusSCEV(RA, LS); + if (LDiff == RDiff) + return getAddExpr(getUMaxExpr(One, LS), LDiff); + } + break; + default: + break; + } + + return getUnknown(I); +} + +/// createNodeForGEP - Expand GEP instructions into add and multiply +/// operations. This allows them to be analyzed by regular SCEV code. +/// +const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { + Value *Base = GEP->getOperand(0); + // Don't attempt to analyze GEPs over unsized objects. + if (!Base->getType()->getPointerElementType()->isSized()) + return getUnknown(GEP); + + SmallVector IndexExprs; + for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index) + IndexExprs.push_back(getSCEV(*Index)); + return getGEPExpr(GEP->getSourceElementType(), getSCEV(Base), IndexExprs, + GEP->isInBounds()); +} + +/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is +/// guaranteed to end in (at every loop iteration). It is, at the same time, +/// the minimum number of times S is divisible by 2. For example, given {4,+,8} +/// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S. +uint32_t +ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { + if (const SCEVConstant *C = dyn_cast(S)) + return C->getValue()->getValue().countTrailingZeros(); + + if (const SCEVTruncateExpr *T = dyn_cast(S)) + return std::min(GetMinTrailingZeros(T->getOperand()), + (uint32_t)getTypeSizeInBits(T->getType())); + + if (const SCEVZeroExtendExpr *E = dyn_cast(S)) { + uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); + return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? + getTypeSizeInBits(E->getType()) : OpRes; + } + + if (const SCEVSignExtendExpr *E = dyn_cast(S)) { + uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); + return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? + getTypeSizeInBits(E->getType()) : OpRes; + } + + if (const SCEVAddExpr *A = dyn_cast(S)) { + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); + for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); + return MinOpRes; + } + + if (const SCEVMulExpr *M = dyn_cast(S)) { + // The result is the sum of all operands results. + uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0)); + uint32_t BitWidth = getTypeSizeInBits(M->getType()); + for (unsigned i = 1, e = M->getNumOperands(); + SumOpRes != BitWidth && i != e; ++i) + SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), + BitWidth); + return SumOpRes; + } + + if (const SCEVAddRecExpr *A = dyn_cast(S)) { + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); + for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); + return MinOpRes; + } + + if (const SCEVSMaxExpr *M = dyn_cast(S)) { // The result is the min of all operands results. uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) @@ -3833,7 +4182,8 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { // For a SCEVUnknown, ask ValueTracking. unsigned BitWidth = getTypeSizeInBits(U->getType()); APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT); + computeKnownBits(U->getValue(), Zeros, Ones, getDataLayout(), 0, &AC, + nullptr, &DT); return Zeros.countTrailingOnes(); } @@ -3844,26 +4194,9 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { /// GetRangeFromMetadata - Helper method to assign a range to V from /// metadata present in the IR. static Optional GetRangeFromMetadata(Value *V) { - if (Instruction *I = dyn_cast(V)) { - if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) { - ConstantRange TotalRange( - cast(I->getType())->getBitWidth(), false); - - unsigned NumRanges = MD->getNumOperands() / 2; - assert(NumRanges >= 1); - - for (unsigned i = 0; i < NumRanges; ++i) { - ConstantInt *Lower = - mdconst::extract(MD->getOperand(2 * i + 0)); - ConstantInt *Upper = - mdconst::extract(MD->getOperand(2 * i + 1)); - ConstantRange Range(Lower->getValue(), Upper->getValue()); - TotalRange = TotalRange.unionWith(Range); - } - - return TotalRange; - } - } + if (Instruction *I = dyn_cast(V)) + if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) + return getConstantRangeFromMetadata(*MD); return None; } @@ -4063,24 +4396,22 @@ ScalarEvolution::getRange(const SCEV *S, // Split here to avoid paying the compile-time cost of calling both // computeKnownBits and ComputeNumSignBits. This restriction can be lifted // if needed. - + const DataLayout &DL = getDataLayout(); if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) { // For a SCEVUnknown, ask ValueTracking. APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT); + computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, &AC, nullptr, &DT); if (Ones != ~Zeros + 1) ConservativeResult = ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1)); } else { assert(SignHint == ScalarEvolution::HINT_RANGE_SIGNED && "generalize as needed!"); - if (U->getValue()->getType()->isIntegerTy() || DL) { - unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AC, nullptr, DT); - if (NS > 1) - ConservativeResult = ConservativeResult.intersectWith(ConstantRange( - APInt::getSignedMinValue(BitWidth).ashr(NS - 1), - APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1)); - } + unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, &AC, nullptr, &DT); + if (NS > 1) + ConservativeResult = ConservativeResult.intersectWith( + ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1), + APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1)); } return setRange(U, SignHint, ConservativeResult); @@ -4089,8 +4420,64 @@ ScalarEvolution::getRange(const SCEV *S, return setRange(S, SignHint, ConservativeResult); } -/// createSCEV - We know that there is no SCEV for the specified value. -/// Analyze the expression. +SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) { + if (isa(V)) return SCEV::FlagAnyWrap; + const BinaryOperator *BinOp = cast(V); + + // Return early if there are no flags to propagate to the SCEV. + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; + if (BinOp->hasNoUnsignedWrap()) + Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); + if (BinOp->hasNoSignedWrap()) + Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); + if (Flags == SCEV::FlagAnyWrap) { + return SCEV::FlagAnyWrap; + } + + // Here we check that BinOp is in the header of the innermost loop + // containing BinOp, since we only deal with instructions in the loop + // header. The actual loop we need to check later will come from an add + // recurrence, but getting that requires computing the SCEV of the operands, + // which can be expensive. This check we can do cheaply to rule out some + // cases early. + Loop *innermostContainingLoop = LI.getLoopFor(BinOp->getParent()); + if (innermostContainingLoop == nullptr || + innermostContainingLoop->getHeader() != BinOp->getParent()) + return SCEV::FlagAnyWrap; + + // Only proceed if we can prove that BinOp does not yield poison. + if (!isKnownNotFullPoison(BinOp)) return SCEV::FlagAnyWrap; + + // At this point we know that if V is executed, then it does not wrap + // according to at least one of NSW or NUW. If V is not executed, then we do + // not know if the calculation that V represents would wrap. Multiple + // instructions can map to the same SCEV. If we apply NSW or NUW from V to + // the SCEV, we must guarantee no wrapping for that SCEV also when it is + // derived from other instructions that map to the same SCEV. We cannot make + // that guarantee for cases where V is not executed. So we need to find the + // loop that V is considered in relation to and prove that V is executed for + // every iteration of that loop. That implies that the value that V + // calculates does not wrap anywhere in the loop, so then we can apply the + // flags to the SCEV. + // + // We check isLoopInvariant to disambiguate in case we are adding two + // recurrences from different loops, so that we know which loop to prove + // that V is executed in. + for (int OpIndex = 0; OpIndex < 2; ++OpIndex) { + const SCEV *Op = getSCEV(BinOp->getOperand(OpIndex)); + if (auto *AddRec = dyn_cast(Op)) { + const int OtherOpIndex = 1 - OpIndex; + const SCEV *OtherOp = getSCEV(BinOp->getOperand(OtherOpIndex)); + if (isLoopInvariant(OtherOp, AddRec->getLoop()) && + isGuaranteedToExecuteForEveryIteration(BinOp, AddRec->getLoop())) + return Flags; + } + } + return SCEV::FlagAnyWrap; +} + +/// createSCEV - We know that there is no SCEV for the specified value. Analyze +/// the expression. /// const SCEV *ScalarEvolution::createSCEV(Value *V) { if (!isSCEVable(V->getType())) @@ -4104,14 +4491,14 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // reachable. Such instructions don't matter, and they aren't required // to obey basic rules for definitions dominating uses which this // analysis depends on. - if (!DT->isReachableFromEntry(I->getParent())) + if (!DT.isReachableFromEntry(I->getParent())) return getUnknown(V); } else if (ConstantExpr *CE = dyn_cast(V)) Opcode = CE->getOpcode(); else if (ConstantInt *CI = dyn_cast(V)) return getConstant(CI); else if (isa(V)) - return getConstant(V->getType(), 0); + return getZero(V->getType()); else if (GlobalAlias *GA = dyn_cast(V)) return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee()); else @@ -4126,47 +4513,79 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // because it leads to N-1 getAddExpr calls for N ultimate operands. // Instead, gather up all the operands and make a single getAddExpr call. // LLVM IR canonical form means we need only traverse the left operands. - // - // Don't apply this instruction's NSW or NUW flags to the new - // expression. The instruction may be guarded by control flow that the - // no-wrap behavior depends on. Non-control-equivalent instructions can be - // mapped to the same SCEV expression, and it would be incorrect to transfer - // NSW/NUW semantics to those operations. SmallVector AddOps; - AddOps.push_back(getSCEV(U->getOperand(1))); - for (Value *Op = U->getOperand(0); ; Op = U->getOperand(0)) { - unsigned Opcode = Op->getValueID() - Value::InstructionVal; - if (Opcode != Instruction::Add && Opcode != Instruction::Sub) + for (Value *Op = U;; Op = U->getOperand(0)) { + U = dyn_cast(Op); + unsigned Opcode = U ? U->getOpcode() : 0; + if (!U || (Opcode != Instruction::Add && Opcode != Instruction::Sub)) { + assert(Op != V && "V should be an add"); + AddOps.push_back(getSCEV(Op)); + break; + } + + if (auto *OpSCEV = getExistingSCEV(U)) { + AddOps.push_back(OpSCEV); + break; + } + + // If a NUW or NSW flag can be applied to the SCEV for this + // addition, then compute the SCEV for this addition by itself + // with a separate call to getAddExpr. We need to do that + // instead of pushing the operands of the addition onto AddOps, + // since the flags are only known to apply to this particular + // addition - they may not apply to other additions that can be + // formed with operands from AddOps. + const SCEV *RHS = getSCEV(U->getOperand(1)); + SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U); + if (Flags != SCEV::FlagAnyWrap) { + const SCEV *LHS = getSCEV(U->getOperand(0)); + if (Opcode == Instruction::Sub) + AddOps.push_back(getMinusSCEV(LHS, RHS, Flags)); + else + AddOps.push_back(getAddExpr(LHS, RHS, Flags)); break; - U = cast(Op); - const SCEV *Op1 = getSCEV(U->getOperand(1)); + } + if (Opcode == Instruction::Sub) - AddOps.push_back(getNegativeSCEV(Op1)); + AddOps.push_back(getNegativeSCEV(RHS)); else - AddOps.push_back(Op1); + AddOps.push_back(RHS); } - AddOps.push_back(getSCEV(U->getOperand(0))); return getAddExpr(AddOps); } + case Instruction::Mul: { - // Don't transfer NSW/NUW for the same reason as AddExpr. SmallVector MulOps; - MulOps.push_back(getSCEV(U->getOperand(1))); - for (Value *Op = U->getOperand(0); - Op->getValueID() == Instruction::Mul + Value::InstructionVal; - Op = U->getOperand(0)) { - U = cast(Op); + for (Value *Op = U;; Op = U->getOperand(0)) { + U = dyn_cast(Op); + if (!U || U->getOpcode() != Instruction::Mul) { + assert(Op != V && "V should be a mul"); + MulOps.push_back(getSCEV(Op)); + break; + } + + if (auto *OpSCEV = getExistingSCEV(U)) { + MulOps.push_back(OpSCEV); + break; + } + + SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U); + if (Flags != SCEV::FlagAnyWrap) { + MulOps.push_back(getMulExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1)), Flags)); + break; + } + MulOps.push_back(getSCEV(U->getOperand(1))); } - MulOps.push_back(getSCEV(U->getOperand(0))); return getMulExpr(MulOps); } case Instruction::UDiv: return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1))); case Instruction::Sub: - return getMinusSCEV(getSCEV(U->getOperand(0)), - getSCEV(U->getOperand(1))); + return getMinusSCEV(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)), + getNoWrapFlagsFromUB(U)); case Instruction::And: // For an expression like x&255 that merely masks off the high bits, // use zext(trunc(x)) as the SCEV expression. @@ -4185,8 +4604,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { unsigned TZ = A.countTrailingZeros(); unsigned BitWidth = A.getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL, 0, AC, - nullptr, DT); + computeKnownBits(U->getOperand(0), KnownZero, KnownOne, getDataLayout(), + 0, &AC, nullptr, &DT); APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); @@ -4286,9 +4705,18 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { if (SA->getValue().uge(BitWidth)) break; + // It is currently not resolved how to interpret NSW for left + // shift by BitWidth - 1, so we avoid applying flags in that + // case. Remove this check (or this comment) once the situation + // is resolved. See + // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html + // and http://reviews.llvm.org/D8890 . + auto Flags = SCEV::FlagAnyWrap; + if (SA->getValue().ult(BitWidth - 1)) Flags = getNoWrapFlagsFromUB(U); + Constant *X = ConstantInt::get(getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue())); - return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X)); + return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X), Flags); } break; @@ -4363,94 +4791,13 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { return createNodeForPHI(cast(U)); case Instruction::Select: - // This could be a smax or umax that was lowered earlier. - // Try to recover it. - if (ICmpInst *ICI = dyn_cast(U->getOperand(0))) { - Value *LHS = ICI->getOperand(0); - Value *RHS = ICI->getOperand(1); - switch (ICI->getPredicate()) { - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: - std::swap(LHS, RHS); - // fall through - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - // a >s b ? a+x : b+x -> smax(a, b)+x - // a >s b ? b+x : a+x -> smin(a, b)+x - if (getTypeSizeInBits(LHS->getType()) <= - getTypeSizeInBits(U->getType())) { - const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), U->getType()); - const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), U->getType()); - const SCEV *LA = getSCEV(U->getOperand(1)); - const SCEV *RA = getSCEV(U->getOperand(2)); - const SCEV *LDiff = getMinusSCEV(LA, LS); - const SCEV *RDiff = getMinusSCEV(RA, RS); - if (LDiff == RDiff) - return getAddExpr(getSMaxExpr(LS, RS), LDiff); - LDiff = getMinusSCEV(LA, RS); - RDiff = getMinusSCEV(RA, LS); - if (LDiff == RDiff) - return getAddExpr(getSMinExpr(LS, RS), LDiff); - } - break; - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: - std::swap(LHS, RHS); - // fall through - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: - // a >u b ? a+x : b+x -> umax(a, b)+x - // a >u b ? b+x : a+x -> umin(a, b)+x - if (getTypeSizeInBits(LHS->getType()) <= - getTypeSizeInBits(U->getType())) { - const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); - const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), U->getType()); - const SCEV *LA = getSCEV(U->getOperand(1)); - const SCEV *RA = getSCEV(U->getOperand(2)); - const SCEV *LDiff = getMinusSCEV(LA, LS); - const SCEV *RDiff = getMinusSCEV(RA, RS); - if (LDiff == RDiff) - return getAddExpr(getUMaxExpr(LS, RS), LDiff); - LDiff = getMinusSCEV(LA, RS); - RDiff = getMinusSCEV(RA, LS); - if (LDiff == RDiff) - return getAddExpr(getUMinExpr(LS, RS), LDiff); - } - break; - case ICmpInst::ICMP_NE: - // n != 0 ? n+x : 1+x -> umax(n, 1)+x - if (getTypeSizeInBits(LHS->getType()) <= - getTypeSizeInBits(U->getType()) && - isa(RHS) && cast(RHS)->isZero()) { - const SCEV *One = getConstant(U->getType(), 1); - const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); - const SCEV *LA = getSCEV(U->getOperand(1)); - const SCEV *RA = getSCEV(U->getOperand(2)); - const SCEV *LDiff = getMinusSCEV(LA, LS); - const SCEV *RDiff = getMinusSCEV(RA, One); - if (LDiff == RDiff) - return getAddExpr(getUMaxExpr(One, LS), LDiff); - } - break; - case ICmpInst::ICMP_EQ: - // n == 0 ? 1+x : n+x -> umax(n, 1)+x - if (getTypeSizeInBits(LHS->getType()) <= - getTypeSizeInBits(U->getType()) && - isa(RHS) && cast(RHS)->isZero()) { - const SCEV *One = getConstant(U->getType(), 1); - const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType()); - const SCEV *LA = getSCEV(U->getOperand(1)); - const SCEV *RA = getSCEV(U->getOperand(2)); - const SCEV *LDiff = getMinusSCEV(LA, One); - const SCEV *RDiff = getMinusSCEV(RA, LS); - if (LDiff == RDiff) - return getAddExpr(getUMaxExpr(One, LS), LDiff); - } - break; - default: - break; - } - } + // U can also be a select constant expr, which let fall through. Since + // createNodeForSelect only works for a condition that is an `ICmpInst`, and + // constant expressions cannot have instructions as operands, we'd have + // returned getUnknown for a select constant expressions anyway. + if (isa(U)) + return createNodeForSelectOrPHI(cast(U), U->getOperand(0), + U->getOperand(1), U->getOperand(2)); default: // We cannot analyze this expression. break; @@ -4534,8 +4881,7 @@ ScalarEvolution::getSmallConstantTripMultiple(Loop *L, return 1; // Get the trip count from the BE count by adding 1. - const SCEV *TCMul = getAddExpr(ExitCount, - getConstant(ExitCount->getType(), 1)); + const SCEV *TCMul = getAddExpr(ExitCount, getOne(ExitCount->getType())); // FIXME: SCEV distributes multiplication as V1*C1 + V2*C1. We could attempt // to factor simple cases. if (const SCEVMulExpr *Mul = dyn_cast(TCMul)) @@ -4610,10 +4956,10 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { if (!Pair.second) return Pair.first->second; - // ComputeBackedgeTakenCount may allocate memory for its result. Inserting it + // computeBackedgeTakenCount may allocate memory for its result. Inserting it // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result // must be cleared in this scope. - BackedgeTakenInfo Result = ComputeBackedgeTakenCount(L); + BackedgeTakenInfo Result = computeBackedgeTakenCount(L); if (Result.getExact(this) != getCouldNotCompute()) { assert(isLoopInvariant(Result.getExact(this), L) && @@ -4666,7 +5012,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { } // Re-lookup the insert position, since the call to - // ComputeBackedgeTakenCount above could result in a + // computeBackedgeTakenCount above could result in a // recusive call to getBackedgeTakenInfo (on a different // loop), which would invalidate the iterator computed // earlier. @@ -4744,12 +5090,12 @@ void ScalarEvolution::forgetValue(Value *V) { } /// getExact - Get the exact loop backedge taken count considering all loop -/// exits. A computable result can only be return for loops with a single exit. -/// Returning the minimum taken count among all exits is incorrect because one -/// of the loop's exit limit's may have been skipped. HowFarToZero assumes that -/// the limit of each loop test is never skipped. This is a valid assumption as -/// long as the loop exits via that test. For precise results, it is the -/// caller's responsibility to specify the relevant loop exit using +/// exits. A computable result can only be returned for loops with a single +/// exit. Returning the minimum taken count among all exits is incorrect +/// because one of the loop's exit limit's may have been skipped. HowFarToZero +/// assumes that the limit of each loop test is never skipped. This is a valid +/// assumption as long as the loop exits via that test. For precise results, it +/// is the caller's responsibility to specify the relevant loop exit using /// getExact(ExitingBlock, SE). const SCEV * ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const { @@ -4847,10 +5193,10 @@ void ScalarEvolution::BackedgeTakenInfo::clear() { delete[] ExitNotTaken.getNextExit(); } -/// ComputeBackedgeTakenCount - Compute the number of times the backedge +/// computeBackedgeTakenCount - Compute the number of times the backedge /// of the specified loop will execute. ScalarEvolution::BackedgeTakenInfo -ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { +ScalarEvolution::computeBackedgeTakenCount(const Loop *L) { SmallVector ExitingBlocks; L->getExitingBlocks(ExitingBlocks); @@ -4864,7 +5210,7 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { // and compute maxBECount. for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { BasicBlock *ExitBB = ExitingBlocks[i]; - ExitLimit EL = ComputeExitLimit(L, ExitBB); + ExitLimit EL = computeExitLimit(L, ExitBB); // 1. For each exit that can be computed, add an entry to ExitCounts. // CouldComputeBECount is true only if all exits can be computed. @@ -4885,7 +5231,7 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { // MaxBECount is conservatively the maximum EL.Max, where CouldNotCompute is // considered greater than any computable EL.Max. if (EL.Max != getCouldNotCompute() && Latch && - DT->dominates(ExitBB, Latch)) { + DT.dominates(ExitBB, Latch)) { if (!MustExitMaxBECount) MustExitMaxBECount = EL.Max; else { @@ -4906,13 +5252,11 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { return BackedgeTakenInfo(ExitCounts, CouldComputeBECount, MaxBECount); } -/// ComputeExitLimit - Compute the number of times the backedge of the specified -/// loop will execute if it exits via the specified block. ScalarEvolution::ExitLimit -ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { +ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { - // Okay, we've chosen an exiting block. See what condition causes us to - // exit at this block and remember the exit block and whether all other targets + // Okay, we've chosen an exiting block. See what condition causes us to exit + // at this block and remember the exit block and whether all other targets // lead to the loop header. bool MustExecuteLoopHeader = true; BasicBlock *Exit = nullptr; @@ -4952,8 +5296,7 @@ ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { if (!Pred) return getCouldNotCompute(); TerminatorInst *PredTerm = Pred->getTerminator(); - for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) { - BasicBlock *PredSucc = PredTerm->getSuccessor(i); + for (const BasicBlock *PredSucc : PredTerm->successors()) { if (PredSucc == BB) continue; // If the predecessor has a successor that isn't BB and isn't @@ -4976,19 +5319,19 @@ ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { if (BranchInst *BI = dyn_cast(Term)) { assert(BI->isConditional() && "If unconditional, it can't be in loop!"); // Proceed to the next level to examine the exit condition expression. - return ComputeExitLimitFromCond(L, BI->getCondition(), BI->getSuccessor(0), + return computeExitLimitFromCond(L, BI->getCondition(), BI->getSuccessor(0), BI->getSuccessor(1), /*ControlsExit=*/IsOnlyExit); } if (SwitchInst *SI = dyn_cast(Term)) - return ComputeExitLimitFromSingleExitSwitch(L, SI, Exit, + return computeExitLimitFromSingleExitSwitch(L, SI, Exit, /*ControlsExit=*/IsOnlyExit); return getCouldNotCompute(); } -/// ComputeExitLimitFromCond - Compute the number of times the +/// computeExitLimitFromCond - Compute the number of times the /// backedge of the specified loop will execute if its exit condition /// were a conditional branch of ExitCond, TBB, and FBB. /// @@ -4997,7 +5340,7 @@ ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { /// condition is true and can infer that failing to meet the condition prior to /// integer wraparound results in undefined behavior. ScalarEvolution::ExitLimit -ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, +ScalarEvolution::computeExitLimitFromCond(const Loop *L, Value *ExitCond, BasicBlock *TBB, BasicBlock *FBB, @@ -5007,9 +5350,9 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, if (BO->getOpcode() == Instruction::And) { // Recurse on the operands of the and. bool EitherMayExit = L->contains(TBB); - ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, + ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, ControlsExit && !EitherMayExit); - ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, + ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, ControlsExit && !EitherMayExit); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); @@ -5042,9 +5385,9 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, if (BO->getOpcode() == Instruction::Or) { // Recurse on the operands of the or. bool EitherMayExit = L->contains(FBB); - ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, + ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, ControlsExit && !EitherMayExit); - ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, + ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, ControlsExit && !EitherMayExit); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); @@ -5079,7 +5422,7 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, // With an icmp, it may be feasible to compute an exact backedge-taken count. // Proceed to the next level to examine the icmp. if (ICmpInst *ExitCondICmp = dyn_cast(ExitCond)) - return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit); + return computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit); // Check for a constant condition. These are normally stripped out by // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to @@ -5091,18 +5434,15 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, return getCouldNotCompute(); else // The backedge is never taken. - return getConstant(CI->getType(), 0); + return getZero(CI->getType()); } // If it's not an integer or pointer comparison then compute it the hard way. - return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); + return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); } -/// ComputeExitLimitFromICmp - Compute the number of times the -/// backedge of the specified loop will execute if its exit condition -/// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB. ScalarEvolution::ExitLimit -ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, +ScalarEvolution::computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond, BasicBlock *TBB, BasicBlock *FBB, @@ -5119,11 +5459,16 @@ ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, if (LoadInst *LI = dyn_cast(ExitCond->getOperand(0))) if (Constant *RHS = dyn_cast(ExitCond->getOperand(1))) { ExitLimit ItCnt = - ComputeLoadConstantCompareExitLimit(LI, RHS, L, Cond); + computeLoadConstantCompareExitLimit(LI, RHS, L, Cond); if (ItCnt.hasAnyInfo()) return ItCnt; } + ExitLimit ShiftEL = computeShiftCompareExitLimit( + ExitCond->getOperand(0), ExitCond->getOperand(1), L, Cond); + if (ShiftEL.hasAnyInfo()) + return ShiftEL; + const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); @@ -5183,21 +5528,13 @@ ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, break; } default: -#if 0 - dbgs() << "ComputeBackedgeTakenCount "; - if (ExitCond->getOperand(0)->getType()->isUnsigned()) - dbgs() << "[unsigned] "; - dbgs() << *LHS << " " - << Instruction::getOpcodeName(Instruction::ICmp) - << " " << *RHS << "\n"; -#endif break; } - return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); + return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); } ScalarEvolution::ExitLimit -ScalarEvolution::ComputeExitLimitFromSingleExitSwitch(const Loop *L, +ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBlock, bool ControlsExit) { @@ -5230,11 +5567,11 @@ EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, return cast(Val)->getValue(); } -/// ComputeLoadConstantCompareExitLimit - Given an exit condition of +/// computeLoadConstantCompareExitLimit - Given an exit condition of /// 'icmp op load X, cst', try to see if we can compute the backedge /// execution count. ScalarEvolution::ExitLimit -ScalarEvolution::ComputeLoadConstantCompareExitLimit( +ScalarEvolution::computeLoadConstantCompareExitLimit( LoadInst *LI, Constant *RHS, const Loop *L, @@ -5303,11 +5640,6 @@ ScalarEvolution::ComputeLoadConstantCompareExitLimit( Result = ConstantExpr::getICmp(predicate, Result, RHS); if (!isa(Result)) break; // Couldn't decide for sure if (cast(Result)->getValue().isMinValue()) { -#if 0 - dbgs() << "\n***\n*** Computed loop count " << *ItCst - << "\n*** From global " << *GV << "*** BB: " << *L->getHeader() - << "***\n"; -#endif ++NumArrayLenItCounts; return getConstant(ItCst); // Found terminating iteration! } @@ -5315,6 +5647,149 @@ ScalarEvolution::ComputeLoadConstantCompareExitLimit( return getCouldNotCompute(); } +ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit( + Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) { + ConstantInt *RHS = dyn_cast(RHSV); + if (!RHS) + return getCouldNotCompute(); + + const BasicBlock *Latch = L->getLoopLatch(); + if (!Latch) + return getCouldNotCompute(); + + const BasicBlock *Predecessor = L->getLoopPredecessor(); + if (!Predecessor) + return getCouldNotCompute(); + + // Return true if V is of the form "LHS `shift_op` ". + // Return LHS in OutLHS and shift_opt in OutOpCode. + auto MatchPositiveShift = + [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) { + + using namespace PatternMatch; + + ConstantInt *ShiftAmt; + if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt)))) + OutOpCode = Instruction::LShr; + else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt)))) + OutOpCode = Instruction::AShr; + else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt)))) + OutOpCode = Instruction::Shl; + else + return false; + + return ShiftAmt->getValue().isStrictlyPositive(); + }; + + // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in + // + // loop: + // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ] + // %iv.shifted = lshr i32 %iv, + // + // Return true on a succesful match. Return the corresponding PHI node (%iv + // above) in PNOut and the opcode of the shift operation in OpCodeOut. + auto MatchShiftRecurrence = + [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) { + Optional PostShiftOpCode; + + { + Instruction::BinaryOps OpC; + Value *V; + + // If we encounter a shift instruction, "peel off" the shift operation, + // and remember that we did so. Later when we inspect %iv's backedge + // value, we will make sure that the backedge value uses the same + // operation. + // + // Note: the peeled shift operation does not have to be the same + // instruction as the one feeding into the PHI's backedge value. We only + // really care about it being the same *kind* of shift instruction -- + // that's all that is required for our later inferences to hold. + if (MatchPositiveShift(LHS, V, OpC)) { + PostShiftOpCode = OpC; + LHS = V; + } + } + + PNOut = dyn_cast(LHS); + if (!PNOut || PNOut->getParent() != L->getHeader()) + return false; + + Value *BEValue = PNOut->getIncomingValueForBlock(Latch); + Value *OpLHS; + + return + // The backedge value for the PHI node must be a shift by a positive + // amount + MatchPositiveShift(BEValue, OpLHS, OpCodeOut) && + + // of the PHI node itself + OpLHS == PNOut && + + // and the kind of shift should be match the kind of shift we peeled + // off, if any. + (!PostShiftOpCode.hasValue() || *PostShiftOpCode == OpCodeOut); + }; + + PHINode *PN; + Instruction::BinaryOps OpCode; + if (!MatchShiftRecurrence(LHS, PN, OpCode)) + return getCouldNotCompute(); + + const DataLayout &DL = getDataLayout(); + + // The key rationale for this optimization is that for some kinds of shift + // recurrences, the value of the recurrence "stabilizes" to either 0 or -1 + // within a finite number of iterations. If the condition guarding the + // backedge (in the sense that the backedge is taken if the condition is true) + // is false for the value the shift recurrence stabilizes to, then we know + // that the backedge is taken only a finite number of times. + + ConstantInt *StableValue = nullptr; + switch (OpCode) { + default: + llvm_unreachable("Impossible case!"); + + case Instruction::AShr: { + // {K,ashr,} stabilizes to signum(K) in at most + // bitwidth(K) iterations. + Value *FirstValue = PN->getIncomingValueForBlock(Predecessor); + bool KnownZero, KnownOne; + ComputeSignBit(FirstValue, KnownZero, KnownOne, DL, 0, nullptr, + Predecessor->getTerminator(), &DT); + auto *Ty = cast(RHS->getType()); + if (KnownZero) + StableValue = ConstantInt::get(Ty, 0); + else if (KnownOne) + StableValue = ConstantInt::get(Ty, -1, true); + else + return getCouldNotCompute(); + + break; + } + case Instruction::LShr: + case Instruction::Shl: + // Both {K,lshr,} and {K,shl,} + // stabilize to 0 in at most bitwidth(K) iterations. + StableValue = ConstantInt::get(cast(RHS->getType()), 0); + break; + } + + auto *Result = + ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI); + assert(Result->getType()->isIntegerTy(1) && + "Otherwise cannot be an operand to a branch instruction"); + + if (Result->isZeroValue()) { + unsigned BitWidth = getTypeSizeInBits(RHS->getType()); + const SCEV *UpperBound = + getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); + return ExitLimit(getCouldNotCompute(), UpperBound); + } + + return getCouldNotCompute(); +} /// CanConstantFold - Return true if we can constant fold an instruction of the /// specified type, assuming that all operands were constants. @@ -5337,12 +5812,9 @@ static bool canConstantEvolve(Instruction *I, const Loop *L) { if (!L->contains(I)) return false; if (isa(I)) { - if (L->getHeader() == I->getParent()) - return true; - else - // We don't currently keep track of the control flow needed to evaluate - // PHIs, so we cannot handle PHIs inside of loops. - return false; + // We don't currently keep track of the control flow needed to evaluate + // PHIs, so we cannot handle PHIs inside of loops. + return L->getHeader() == I->getParent(); } // If we won't be able to constant fold this expression even if the operands @@ -5398,9 +5870,8 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { Instruction *I = dyn_cast(V); if (!I || !canConstantEvolve(I, L)) return nullptr; - if (PHINode *PN = dyn_cast(I)) { + if (PHINode *PN = dyn_cast(I)) return PN; - } // Record non-constant instructions contained by the loop. DenseMap PHIMap; @@ -5413,7 +5884,7 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { /// reason, return null. static Constant *EvaluateExpression(Value *V, const Loop *L, DenseMap &Vals, - const DataLayout *DL, + const DataLayout &DL, const TargetLibraryInfo *TLI) { // Convenient constant check, but redundant for recursive calls. if (Constant *C = dyn_cast(V)) return C; @@ -5457,6 +5928,30 @@ static Constant *EvaluateExpression(Value *V, const Loop *L, TLI); } + +// If every incoming value to PN except the one for BB is a specific Constant, +// return that, else return nullptr. +static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) { + Constant *IncomingVal = nullptr; + + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + if (PN->getIncomingBlock(i) == BB) + continue; + + auto *CurrentVal = dyn_cast(PN->getIncomingValue(i)); + if (!CurrentVal) + return nullptr; + + if (IncomingVal != CurrentVal) { + if (IncomingVal) + return nullptr; + IncomingVal = CurrentVal; + } + } + + return IncomingVal; +} + /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is /// in the header of its containing loop, we know the loop executes a /// constant number of times, and the PHI node is just a recurrence @@ -5465,8 +5960,7 @@ Constant * ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, const APInt &BEs, const Loop *L) { - DenseMap::const_iterator I = - ConstantEvolutionLoopExitValue.find(PN); + auto I = ConstantEvolutionLoopExitValue.find(PN); if (I != ConstantEvolutionLoopExitValue.end()) return I->second; @@ -5479,22 +5973,21 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, BasicBlock *Header = L->getHeader(); assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!"); - // Since the loop is canonicalized, the PHI node must have two entries. One - // entry must be a constant (coming in from outside of the loop), and the - // second must be derived from the same PHI. - bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1)); - PHINode *PHI = nullptr; - for (BasicBlock::iterator I = Header->begin(); - (PHI = dyn_cast(I)); ++I) { - Constant *StartCST = - dyn_cast(PHI->getIncomingValue(!SecondIsBackedge)); + BasicBlock *Latch = L->getLoopLatch(); + if (!Latch) + return nullptr; + + for (auto &I : *Header) { + PHINode *PHI = dyn_cast(&I); + if (!PHI) break; + auto *StartCST = getOtherIncomingValue(PHI, Latch); if (!StartCST) continue; CurrentIterVals[PHI] = StartCST; } if (!CurrentIterVals.count(PN)) return RetVal = nullptr; - Value *BEValue = PN->getIncomingValue(SecondIsBackedge); + Value *BEValue = PN->getIncomingValueForBlock(Latch); // Execute the loop symbolically to determine the exit value. if (BEs.getActiveBits() >= 32) @@ -5502,6 +5995,7 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, unsigned NumIterations = BEs.getZExtValue(); // must be in range unsigned IterationNum = 0; + const DataLayout &DL = getDataLayout(); for (; ; ++IterationNum) { if (IterationNum == NumIterations) return RetVal = CurrentIterVals[PN]; // Got exit value! @@ -5509,8 +6003,8 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, // Compute the value of the PHIs for the next iteration. // EvaluateExpression adds non-phi values to the CurrentIterVals map. DenseMap NextIterVals; - Constant *NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, - TLI); + Constant *NextPHI = + EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI); if (!NextPHI) return nullptr; // Couldn't evaluate! NextIterVals[PN] = NextPHI; @@ -5521,23 +6015,21 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, // cease to be able to evaluate one of them or if they stop evolving, // because that doesn't necessarily prevent us from computing PN. SmallVector, 8> PHIsToCompute; - for (DenseMap::const_iterator - I = CurrentIterVals.begin(), E = CurrentIterVals.end(); I != E; ++I){ - PHINode *PHI = dyn_cast(I->first); + for (const auto &I : CurrentIterVals) { + PHINode *PHI = dyn_cast(I.first); if (!PHI || PHI == PN || PHI->getParent() != Header) continue; - PHIsToCompute.push_back(std::make_pair(PHI, I->second)); + PHIsToCompute.emplace_back(PHI, I.second); } // We use two distinct loops because EvaluateExpression may invalidate any // iterators into CurrentIterVals. - for (SmallVectorImpl >::const_iterator - I = PHIsToCompute.begin(), E = PHIsToCompute.end(); I != E; ++I) { - PHINode *PHI = I->first; + for (const auto &I : PHIsToCompute) { + PHINode *PHI = I.first; Constant *&NextPHI = NextIterVals[PHI]; if (!NextPHI) { // Not already computed. - Value *BEValue = PHI->getIncomingValue(SecondIsBackedge); - NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, TLI); + Value *BEValue = PHI->getIncomingValueForBlock(Latch); + NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI); } - if (NextPHI != I->second) + if (NextPHI != I.second) StoppedEvolving = false; } @@ -5550,12 +6042,7 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, } } -/// ComputeExitCountExhaustively - If the loop is known to execute a -/// constant number of times (the condition evolves only from constants), -/// try to evaluate a few iterations of the loop until we get the exit -/// condition gets a value of ExitWhen (true or false). If we cannot -/// evaluate the trip count of the loop, return getCouldNotCompute(). -const SCEV *ScalarEvolution::ComputeExitCountExhaustively(const Loop *L, +const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) { PHINode *PN = getConstantEvolvingPHI(Cond, L); @@ -5569,14 +6056,14 @@ const SCEV *ScalarEvolution::ComputeExitCountExhaustively(const Loop *L, BasicBlock *Header = L->getHeader(); assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!"); - // One entry must be a constant (coming in from outside of the loop), and the - // second must be derived from the same PHI. - bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1)); - PHINode *PHI = nullptr; - for (BasicBlock::iterator I = Header->begin(); - (PHI = dyn_cast(I)); ++I) { - Constant *StartCST = - dyn_cast(PHI->getIncomingValue(!SecondIsBackedge)); + BasicBlock *Latch = L->getLoopLatch(); + assert(Latch && "Should follow from NumIncomingValues == 2!"); + + for (auto &I : *Header) { + PHINode *PHI = dyn_cast(&I); + if (!PHI) + break; + auto *StartCST = getOtherIncomingValue(PHI, Latch); if (!StartCST) continue; CurrentIterVals[PHI] = StartCST; } @@ -5586,12 +6073,11 @@ const SCEV *ScalarEvolution::ComputeExitCountExhaustively(const Loop *L, // Okay, we find a PHI node that defines the trip count of this loop. Execute // the loop symbolically to determine when the condition gets a value of // "ExitWhen". - unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis. + const DataLayout &DL = getDataLayout(); for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){ - ConstantInt *CondVal = - dyn_cast_or_null(EvaluateExpression(Cond, L, CurrentIterVals, - DL, TLI)); + auto *CondVal = dyn_cast_or_null( + EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI)); // Couldn't symbolically evaluate. if (!CondVal) return getCouldNotCompute(); @@ -5608,20 +6094,17 @@ const SCEV *ScalarEvolution::ComputeExitCountExhaustively(const Loop *L, // calling EvaluateExpression on them because that may invalidate iterators // into CurrentIterVals. SmallVector PHIsToCompute; - for (DenseMap::const_iterator - I = CurrentIterVals.begin(), E = CurrentIterVals.end(); I != E; ++I){ - PHINode *PHI = dyn_cast(I->first); + for (const auto &I : CurrentIterVals) { + PHINode *PHI = dyn_cast(I.first); if (!PHI || PHI->getParent() != Header) continue; PHIsToCompute.push_back(PHI); } - for (SmallVectorImpl::const_iterator I = PHIsToCompute.begin(), - E = PHIsToCompute.end(); I != E; ++I) { - PHINode *PHI = *I; + for (PHINode *PHI : PHIsToCompute) { Constant *&NextPHI = NextIterVals[PHI]; if (NextPHI) continue; // Already computed! - Value *BEValue = PHI->getIncomingValue(SecondIsBackedge); - NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, TLI); + Value *BEValue = PHI->getIncomingValueForBlock(Latch); + NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI); } CurrentIterVals.swap(NextIterVals); } @@ -5722,7 +6205,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { if (PTy->getElementType()->isStructTy()) C2 = ConstantExpr::getIntegerCast( C2, Type::getInt32Ty(C->getContext()), true); - C = ConstantExpr::getGetElementPtr(C, C2); + C = ConstantExpr::getGetElementPtr(PTy->getElementType(), C, C2); } else C = ConstantExpr::getAdd(C, C2); } @@ -5766,7 +6249,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { // exit value from the loop without using SCEVs. if (const SCEVUnknown *SU = dyn_cast(V)) { if (Instruction *I = dyn_cast(SU->getValue())) { - const Loop *LI = (*this->LI)[I->getParent()]; + const Loop *LI = this->LI[I->getParent()]; if (LI && LI->getParentLoop() == L) // Looking for loop exit value. if (PHINode *PN = dyn_cast(I)) if (PN->getParent() == LI->getHeader()) { @@ -5794,8 +6277,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { if (CanConstantFold(I)) { SmallVector Operands; bool MadeImprovement = false; - for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { - Value *Op = I->getOperand(i); + for (Value *Op : I->operands()) { if (Constant *C = dyn_cast(Op)) { Operands.push_back(C); continue; @@ -5824,16 +6306,16 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { // Check to see if getSCEVAtScope actually made an improvement. if (MadeImprovement) { Constant *C = nullptr; + const DataLayout &DL = getDataLayout(); if (const CmpInst *CI = dyn_cast(I)) - C = ConstantFoldCompareInstOperands(CI->getPredicate(), - Operands[0], Operands[1], DL, - TLI); + C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], + Operands[1], DL, &TLI); else if (const LoadInst *LI = dyn_cast(I)) { if (!LI->isVolatile()) C = ConstantFoldLoadFromConstPtr(Operands[0], DL); } else - C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), - Operands, DL, TLI); + C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands, + DL, &TLI); if (!C) return V; return getSCEV(C); } @@ -6106,10 +6588,6 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { const SCEVConstant *R1 = dyn_cast(Roots.first); const SCEVConstant *R2 = dyn_cast(Roots.second); if (R1 && R2) { -#if 0 - dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1 - << " sol#2: " << *R2 << "\n"; -#endif // Pick the smallest positive root value. if (ConstantInt *CB = dyn_cast(ConstantExpr::getICmp(CmpInst::ICMP_ULT, @@ -6193,16 +6671,56 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { // also returns true if StepV is maximally negative (eg, INT_MIN), but that // case is not handled as this code is guarded by !CountDown. if (StepV.isPowerOf2() && - GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros()) - return getUDivExactExpr(Distance, Step); - } - - // If the condition controls loop exit (the loop exits only if the expression - // is true) and the addition is no-wrap we can use unsigned divide to - // compute the backedge count. In this case, the step may not divide the - // distance, but we don't care because if the condition is "missed" the loop - // will have undefined behavior due to wrapping. - if (ControlsExit && AddRec->getNoWrapFlags(SCEV::FlagNW)) { + GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros()) { + // Here we've constrained the equation to be of the form + // + // 2^(N + k) * Distance' = (StepV == 2^N) * X (mod 2^W) ... (0) + // + // where we're operating on a W bit wide integer domain and k is + // non-negative. The smallest unsigned solution for X is the trip count. + // + // (0) is equivalent to: + // + // 2^(N + k) * Distance' - 2^N * X = L * 2^W + // <=> 2^N(2^k * Distance' - X) = L * 2^(W - N) * 2^N + // <=> 2^k * Distance' - X = L * 2^(W - N) + // <=> 2^k * Distance' = L * 2^(W - N) + X ... (1) + // + // The smallest X satisfying (1) is unsigned remainder of dividing the LHS + // by 2^(W - N). + // + // <=> X = 2^k * Distance' URem 2^(W - N) ... (2) + // + // E.g. say we're solving + // + // 2 * Val = 2 * X (in i8) ... (3) + // + // then from (2), we get X = Val URem i8 128 (k = 0 in this case). + // + // Note: It is tempting to solve (3) by setting X = Val, but Val is not + // necessarily the smallest unsigned value of X that satisfies (3). + // E.g. if Val is i8 -127 then the smallest value of X that satisfies (3) + // is i8 1, not i8 -127 + + const auto *ModuloResult = getUDivExactExpr(Distance, Step); + + // Since SCEV does not have a URem node, we construct one using a truncate + // and a zero extend. + + unsigned NarrowWidth = StepV.getBitWidth() - StepV.countTrailingZeros(); + auto *NarrowTy = IntegerType::get(getContext(), NarrowWidth); + auto *WideTy = Distance->getType(); + + return getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy); + } + } + + // If the condition controls loop exit (the loop exits only if the expression + // is true) and the addition is no-wrap we can use unsigned divide to + // compute the backedge count. In this case, the step may not divide the + // distance, but we don't care because if the condition is "missed" the loop + // will have undefined behavior due to wrapping. + if (ControlsExit && AddRec->getNoWrapFlags(SCEV::FlagNW)) { const SCEV *Exact = getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); return ExitLimit(Exact, Exact); @@ -6229,7 +6747,7 @@ ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { // already. If so, the backedge will execute zero times. if (const SCEVConstant *C = dyn_cast(V)) { if (!C->getValue()->isNullValue()) - return getConstant(C->getType(), 0); + return getZero(C->getType()); return getCouldNotCompute(); // Otherwise it will loop infinitely. } @@ -6254,7 +6772,7 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { // A loop's header is defined to be a block that dominates the loop. // If the header has a unique predecessor outside the loop, it must be // a block that has exactly one successor that can reach the loop. - if (Loop *L = LI->getLoopFor(BB)) + if (Loop *L = LI.getLoopFor(BB)) return std::make_pair(L->getLoopPredecessor(), L->getHeader()); return std::pair(); @@ -6270,13 +6788,20 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) { // Quick check to see if they are the same SCEV. if (A == B) return true; + auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) { + // Not all instructions that are "identical" compute the same value. For + // instance, two distinct alloca instructions allocating the same type are + // identical and do not read memory; but compute distinct values. + return A->isIdenticalTo(B) && (isa(A) || isa(A)); + }; + // Otherwise, if they're both SCEVUnknown, it's possible that they hold // two different instructions with the same value. Check for this case. if (const SCEVUnknown *AU = dyn_cast(A)) if (const SCEVUnknown *BU = dyn_cast(B)) if (const Instruction *AI = dyn_cast(AU->getValue())) if (const Instruction *BI = dyn_cast(BU->getValue())) - if (AI->isIdenticalTo(BI) && !AI->mayReadFromMemory()) + if (ComputesEqualValues(AI, BI)) return true; // Otherwise assume they may have a different value. @@ -6615,10 +7140,140 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, if (LeftGuarded && RightGuarded) return true; + if (isKnownPredicateViaSplitting(Pred, LHS, RHS)) + return true; + // Otherwise see what can be done with known constant ranges. return isKnownPredicateWithRanges(Pred, LHS, RHS); } +bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS, + ICmpInst::Predicate Pred, + bool &Increasing) { + bool Result = isMonotonicPredicateImpl(LHS, Pred, Increasing); + +#ifndef NDEBUG + // Verify an invariant: inverting the predicate should turn a monotonically + // increasing change to a monotonically decreasing one, and vice versa. + bool IncreasingSwapped; + bool ResultSwapped = isMonotonicPredicateImpl( + LHS, ICmpInst::getSwappedPredicate(Pred), IncreasingSwapped); + + assert(Result == ResultSwapped && "should be able to analyze both!"); + if (ResultSwapped) + assert(Increasing == !IncreasingSwapped && + "monotonicity should flip as we flip the predicate"); +#endif + + return Result; +} + +bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS, + ICmpInst::Predicate Pred, + bool &Increasing) { + + // A zero step value for LHS means the induction variable is essentially a + // loop invariant value. We don't really depend on the predicate actually + // flipping from false to true (for increasing predicates, and the other way + // around for decreasing predicates), all we care about is that *if* the + // predicate changes then it only changes from false to true. + // + // A zero step value in itself is not very useful, but there may be places + // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be + // as general as possible. + + switch (Pred) { + default: + return false; // Conservative answer + + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + if (!LHS->getNoWrapFlags(SCEV::FlagNUW)) + return false; + + Increasing = Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE; + return true; + + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: { + if (!LHS->getNoWrapFlags(SCEV::FlagNSW)) + return false; + + const SCEV *Step = LHS->getStepRecurrence(*this); + + if (isKnownNonNegative(Step)) { + Increasing = Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE; + return true; + } + + if (isKnownNonPositive(Step)) { + Increasing = Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE; + return true; + } + + return false; + } + + } + + llvm_unreachable("switch has default clause!"); +} + +bool ScalarEvolution::isLoopInvariantPredicate( + ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, + ICmpInst::Predicate &InvariantPred, const SCEV *&InvariantLHS, + const SCEV *&InvariantRHS) { + + // If there is a loop-invariant, force it into the RHS, otherwise bail out. + if (!isLoopInvariant(RHS, L)) { + if (!isLoopInvariant(LHS, L)) + return false; + + std::swap(LHS, RHS); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + const SCEVAddRecExpr *ArLHS = dyn_cast(LHS); + if (!ArLHS || ArLHS->getLoop() != L) + return false; + + bool Increasing; + if (!isMonotonicPredicate(ArLHS, Pred, Increasing)) + return false; + + // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to + // true as the loop iterates, and the backedge is control dependent on + // "ArLHS `Pred` RHS" == true then we can reason as follows: + // + // * if the predicate was false in the first iteration then the predicate + // is never evaluated again, since the loop exits without taking the + // backedge. + // * if the predicate was true in the first iteration then it will + // continue to be true for all future iterations since it is + // monotonically increasing. + // + // For both the above possibilities, we can replace the loop varying + // predicate with its value on the first iteration of the loop (which is + // loop invariant). + // + // A similar reasoning applies for a monotonically decreasing predicate, by + // replacing true with false and false with true in the above two bullets. + + auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred); + + if (!isLoopBackedgeGuardedByCond(L, P, LHS, RHS)) + return false; + + InvariantPred = Pred; + InvariantLHS = ArLHS->getStart(); + InvariantRHS = RHS; + return true; +} + bool ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { @@ -6693,6 +7348,28 @@ ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred, return false; } +bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred, + const SCEV *LHS, + const SCEV *RHS) { + if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate) + return false; + + // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on + // the stack can result in exponential time complexity. + SaveAndRestore Restore(ProvingSplitPredicate, true); + + // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L + // + // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use + // isKnownPredicate. isKnownPredicate is more powerful, but also more + // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the + // interesting cases seen in practice. We can consider "upgrading" L >= 0 to + // use isKnownPredicate later if needed. + return isKnownNonNegative(RHS) && + isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) && + isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS); +} + /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is /// protected by a conditional between LHS and RHS. This is used to /// to eliminate casts. @@ -6718,18 +7395,80 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, LoopContinuePredicate->getSuccessor(0) != L->getHeader())) return true; + // We don't want more than one activation of the following loops on the stack + // -- that can lead to O(n!) time complexity. + if (WalkingBEDominatingConds) + return false; + + SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true); + + // See if we can exploit a trip count to prove the predicate. + const auto &BETakenInfo = getBackedgeTakenInfo(L); + const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this); + if (LatchBECount != getCouldNotCompute()) { + // We know that Latch branches back to the loop header exactly + // LatchBECount times. This means the backdege condition at Latch is + // equivalent to "{0,+,1} u< LatchBECount". + Type *Ty = LatchBECount->getType(); + auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW); + const SCEV *LoopCounter = + getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags); + if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter, + LatchBECount)) + return true; + } + // Check conditions due to any @llvm.assume intrinsics. - for (auto &AssumeVH : AC->assumptions()) { + for (auto &AssumeVH : AC.assumptions()) { if (!AssumeVH) continue; auto *CI = cast(AssumeVH); - if (!DT->dominates(CI, Latch->getTerminator())) + if (!DT.dominates(CI, Latch->getTerminator())) continue; if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false)) return true; } + // If the loop is not reachable from the entry block, we risk running into an + // infinite loop as we walk up into the dom tree. These loops do not matter + // anyway, so we just return a conservative answer when we see them. + if (!DT.isReachableFromEntry(L->getHeader())) + return false; + + for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()]; + DTN != HeaderDTN; DTN = DTN->getIDom()) { + + assert(DTN && "should reach the loop header before reaching the root!"); + + BasicBlock *BB = DTN->getBlock(); + BasicBlock *PBB = BB->getSinglePredecessor(); + if (!PBB) + continue; + + BranchInst *ContinuePredicate = dyn_cast(PBB->getTerminator()); + if (!ContinuePredicate || !ContinuePredicate->isConditional()) + continue; + + Value *Condition = ContinuePredicate->getCondition(); + + // If we have an edge `E` within the loop body that dominates the only + // latch, the condition guarding `E` also guards the backedge. This + // reasoning works only for loops with a single latch. + + BasicBlockEdge DominatingEdge(PBB, BB); + if (DominatingEdge.isSingleEdge()) { + // We're constructively (and conservatively) enumerating edges within the + // loop body that dominate the latch. The dominator tree better agree + // with us on this: + assert(DT.dominates(DominatingEdge, Latch) && "should be!"); + + if (isImpliedCond(Pred, LHS, RHS, Condition, + BB != ContinuePredicate->getSuccessor(0))) + return true; + } + } + return false; } @@ -6767,11 +7506,11 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, } // Check conditions due to any @llvm.assume intrinsics. - for (auto &AssumeVH : AC->assumptions()) { + for (auto &AssumeVH : AC.assumptions()) { if (!AssumeVH) continue; auto *CI = cast(AssumeVH); - if (!DT->dominates(CI, L->getHeader())) + if (!DT.dominates(CI, L->getHeader())) continue; if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false)) @@ -6781,6 +7520,7 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, return false; } +namespace { /// RAII wrapper to prevent recursive application of isImpliedCond. /// ScalarEvolution's PendingLoopPredicates set must be empty unless we are /// currently evaluating isImpliedCond. @@ -6798,6 +7538,7 @@ struct MarkPendingLoopPredicate { LoopPreds.erase(Cond); } }; +} // end anonymous namespace /// isImpliedCond - Test whether the condition described by Pred, LHS, /// and RHS is true whenever the given Cond value evaluates to true. @@ -6825,15 +7566,6 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, ICmpInst *ICI = dyn_cast(FoundCondValue); if (!ICI) return false; - // Bail if the ICmp's operands' types are wider than the needed type - // before attempting to call getSCEV on them. This avoids infinite - // recursion, since the analysis of widening casts can require loop - // exit condition information for overflow checking, which would - // lead back here. - if (getTypeSizeInBits(LHS->getType()) < - getTypeSizeInBits(ICI->getOperand(0)->getType())) - return false; - // Now that we found a conditional branch that dominates the loop or controls // the loop latch. Check to see if it is the comparison we are looking for. ICmpInst::Predicate FoundPred; @@ -6845,9 +7577,25 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *FoundLHS = getSCEV(ICI->getOperand(0)); const SCEV *FoundRHS = getSCEV(ICI->getOperand(1)); - // Balance the types. The case where FoundLHS' type is wider than - // LHS' type is checked for above. - if (getTypeSizeInBits(LHS->getType()) > + return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS); +} + +bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, + const SCEV *RHS, + ICmpInst::Predicate FoundPred, + const SCEV *FoundLHS, + const SCEV *FoundRHS) { + // Balance the types. + if (getTypeSizeInBits(LHS->getType()) < + getTypeSizeInBits(FoundLHS->getType())) { + if (CmpInst::isSigned(Pred)) { + LHS = getSignExtendExpr(LHS, FoundLHS->getType()); + RHS = getSignExtendExpr(RHS, FoundLHS->getType()); + } else { + LHS = getZeroExtendExpr(LHS, FoundLHS->getType()); + RHS = getZeroExtendExpr(RHS, FoundLHS->getType()); + } + } else if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(FoundLHS->getType())) { if (CmpInst::isSigned(FoundPred)) { FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType()); @@ -6892,6 +7640,13 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, RHS, LHS, FoundLHS, FoundRHS); } + // Unsigned comparison is the same as signed comparison when both the operands + // are non-negative. + if (CmpInst::isUnsigned(FoundPred) && + CmpInst::getSignedPredicate(FoundPred) == Pred && + isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) + return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS); + // Check if we can make progress by sharpening ranges. if (FoundPred == ICmpInst::ICMP_NE && (isa(FoundLHS) || isa(FoundRHS))) { @@ -6966,6 +7721,149 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, return false; } +bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr, + const SCEV *&L, const SCEV *&R, + SCEV::NoWrapFlags &Flags) { + const auto *AE = dyn_cast(Expr); + if (!AE || AE->getNumOperands() != 2) + return false; + + L = AE->getOperand(0); + R = AE->getOperand(1); + Flags = AE->getNoWrapFlags(); + return true; +} + +bool ScalarEvolution::computeConstantDifference(const SCEV *Less, + const SCEV *More, + APInt &C) { + // We avoid subtracting expressions here because this function is usually + // fairly deep in the call stack (i.e. is called many times). + + if (isa(Less) && isa(More)) { + const auto *LAR = cast(Less); + const auto *MAR = cast(More); + + if (LAR->getLoop() != MAR->getLoop()) + return false; + + // We look at affine expressions only; not for correctness but to keep + // getStepRecurrence cheap. + if (!LAR->isAffine() || !MAR->isAffine()) + return false; + + if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this)) + return false; + + Less = LAR->getStart(); + More = MAR->getStart(); + + // fall through + } + + if (isa(Less) && isa(More)) { + const auto &M = cast(More)->getValue()->getValue(); + const auto &L = cast(Less)->getValue()->getValue(); + C = M - L; + return true; + } + + const SCEV *L, *R; + SCEV::NoWrapFlags Flags; + if (splitBinaryAdd(Less, L, R, Flags)) + if (const auto *LC = dyn_cast(L)) + if (R == More) { + C = -(LC->getValue()->getValue()); + return true; + } + + if (splitBinaryAdd(More, L, R, Flags)) + if (const auto *LC = dyn_cast(L)) + if (R == Less) { + C = LC->getValue()->getValue(); + return true; + } + + return false; +} + +bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow( + ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, + const SCEV *FoundLHS, const SCEV *FoundRHS) { + if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT) + return false; + + const auto *AddRecLHS = dyn_cast(LHS); + if (!AddRecLHS) + return false; + + const auto *AddRecFoundLHS = dyn_cast(FoundLHS); + if (!AddRecFoundLHS) + return false; + + // We'd like to let SCEV reason about control dependencies, so we constrain + // both the inequalities to be about add recurrences on the same loop. This + // way we can use isLoopEntryGuardedByCond later. + + const Loop *L = AddRecFoundLHS->getLoop(); + if (L != AddRecLHS->getLoop()) + return false; + + // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1) + // + // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C) + // ... (2) + // + // Informal proof for (2), assuming (1) [*]: + // + // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**] + // + // Then + // + // FoundLHS s< FoundRHS s< INT_MIN - C + // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ] + // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ] + // <=> (FoundLHS + INT_MIN + C + INT_MIN) s< + // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ] + // <=> FoundLHS + C s< FoundRHS + C + // + // [*]: (1) can be proved by ruling out overflow. + // + // [**]: This can be proved by analyzing all the four possibilities: + // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and + // (A s>= 0, B s>= 0). + // + // Note: + // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C" + // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS + // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS + // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is + // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS + + // C)". + + APInt LDiff, RDiff; + if (!computeConstantDifference(FoundLHS, LHS, LDiff) || + !computeConstantDifference(FoundRHS, RHS, RDiff) || + LDiff != RDiff) + return false; + + if (LDiff == 0) + return true; + + APInt FoundRHSLimit; + + if (Pred == CmpInst::ICMP_ULT) { + FoundRHSLimit = -RDiff; + } else { + assert(Pred == CmpInst::ICMP_SLT && "Checked above!"); + FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - RDiff; + } + + // Try to prove (1) or (2), as needed. + return isLoopEntryGuardedByCond(L, Pred, FoundRHS, + getConstant(FoundRHSLimit)); +} + /// isImpliedCondOperands - Test whether the condition described by Pred, /// LHS, and RHS is true whenever the condition described by Pred, FoundLHS, /// and FoundRHS is true. @@ -6973,6 +7871,12 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { + if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS)) + return true; + + if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS)) + return true; + return isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS) || // ~x < ~y --> x > y @@ -6985,17 +7889,13 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred, /// If Expr computes ~A, return A else return nullptr static const SCEV *MatchNotExpr(const SCEV *Expr) { const SCEVAddExpr *Add = dyn_cast(Expr); - if (!Add || Add->getNumOperands() != 2) return nullptr; - - const SCEVConstant *AddLHS = dyn_cast(Add->getOperand(0)); - if (!(AddLHS && AddLHS->getValue()->getValue().isAllOnesValue())) + if (!Add || Add->getNumOperands() != 2 || + !Add->getOperand(0)->isAllOnesValue()) return nullptr; const SCEVMulExpr *AddRHS = dyn_cast(Add->getOperand(1)); - if (!AddRHS || AddRHS->getNumOperands() != 2) return nullptr; - - const SCEVConstant *MulLHS = dyn_cast(AddRHS->getOperand(0)); - if (!(MulLHS && MulLHS->getValue()->getValue().isAllOnesValue())) + if (!AddRHS || AddRHS->getNumOperands() != 2 || + !AddRHS->getOperand(0)->isAllOnesValue()) return nullptr; return AddRHS->getOperand(1); @@ -7026,6 +7926,38 @@ static bool IsMinConsistingOf(ScalarEvolution &SE, return IsMaxConsistingOf(MaybeMaxExpr, SE.getNotSCEV(Candidate)); } +static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, + ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { + + // If both sides are affine addrecs for the same loop, with equal + // steps, and we know the recurrences don't wrap, then we only + // need to check the predicate on the starting values. + + if (!ICmpInst::isRelational(Pred)) + return false; + + const SCEVAddRecExpr *LAR = dyn_cast(LHS); + if (!LAR) + return false; + const SCEVAddRecExpr *RAR = dyn_cast(RHS); + if (!RAR) + return false; + if (LAR->getLoop() != RAR->getLoop()) + return false; + if (!LAR->isAffine() || !RAR->isAffine()) + return false; + + if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE)) + return false; + + SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ? + SCEV::FlagNSW : SCEV::FlagNUW; + if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW)) + return false; + + return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart()); +} /// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max /// expression? @@ -7071,7 +8003,8 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, auto IsKnownPredicateFull = [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { return isKnownPredicateWithRanges(Pred, LHS, RHS) || - IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS); + IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) || + IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS); }; switch (Pred) { @@ -7110,6 +8043,47 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, return false; } +/// isImpliedCondOperandsViaRanges - helper function for isImpliedCondOperands. +/// Tries to get cases like "X `sgt` 0 => X - 1 `sgt` -1". +bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, + const SCEV *LHS, + const SCEV *RHS, + const SCEV *FoundLHS, + const SCEV *FoundRHS) { + if (!isa(RHS) || !isa(FoundRHS)) + // The restriction on `FoundRHS` be lifted easily -- it exists only to + // reduce the compile time impact of this optimization. + return false; + + const SCEVAddExpr *AddLHS = dyn_cast(LHS); + if (!AddLHS || AddLHS->getOperand(1) != FoundLHS || + !isa(AddLHS->getOperand(0))) + return false; + + APInt ConstFoundRHS = cast(FoundRHS)->getValue()->getValue(); + + // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the + // antecedent "`FoundLHS` `Pred` `FoundRHS`". + ConstantRange FoundLHSRange = + ConstantRange::makeAllowedICmpRegion(Pred, ConstFoundRHS); + + // Since `LHS` is `FoundLHS` + `AddLHS->getOperand(0)`, we can compute a range + // for `LHS`: + APInt Addend = + cast(AddLHS->getOperand(0))->getValue()->getValue(); + ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(Addend)); + + // We can also compute the range of values for `LHS` that satisfy the + // consequent, "`LHS` `Pred` `RHS`": + APInt ConstRHS = cast(RHS)->getValue()->getValue(); + ConstantRange SatisfyingLHSRange = + ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS); + + // The antecedent implies the consequent if every value of `LHS` that + // satisfies the antecedent also satisfies the consequent. + return SatisfyingLHSRange.contains(LHSRange); +} + // Verify if an linear IV with positive stride can overflow when in a // less-than comparison, knowing the invariant term of the comparison, the // stride and the knowledge of NSW/NUW flags on the recurrence. @@ -7118,7 +8092,7 @@ bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, if (NoWrap) return false; unsigned BitWidth = getTypeSizeInBits(RHS->getType()); - const SCEV *One = getConstant(Stride->getType(), 1); + const SCEV *One = getOne(Stride->getType()); if (IsSigned) { APInt MaxRHS = getSignedRange(RHS).getSignedMax(); @@ -7147,7 +8121,7 @@ bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, if (NoWrap) return false; unsigned BitWidth = getTypeSizeInBits(RHS->getType()); - const SCEV *One = getConstant(Stride->getType(), 1); + const SCEV *One = getOne(Stride->getType()); if (IsSigned) { APInt MinRHS = getSignedRange(RHS).getSignedMin(); @@ -7172,7 +8146,7 @@ bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, // stride and presence of the equality in the comparison. const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, bool Equality) { - const SCEV *One = getConstant(Step->getType(), 1); + const SCEV *One = getOne(Step->getType()); Delta = Equality ? getAddExpr(Delta, Step) : getAddExpr(Delta, getMinusSCEV(Step, One)); return getUDivExpr(Delta, Step); @@ -7361,11 +8335,10 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, if (const SCEVConstant *SC = dyn_cast(getStart())) if (!SC->getValue()->isZero()) { SmallVector Operands(op_begin(), op_end()); - Operands[0] = SE.getConstant(SC->getType(), 0); + Operands[0] = SE.getZero(SC->getType()); const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(), getNoWrapFlags(FlagNW)); - if (const SCEVAddRecExpr *ShiftedAddRec = - dyn_cast(Shifted)) + if (const auto *ShiftedAddRec = dyn_cast(Shifted)) return ShiftedAddRec->getNumIterationsInRange( Range.subtract(SC->getValue()->getValue()), SE); // This is strange and shouldn't happen. @@ -7374,10 +8347,9 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // The only time we can solve this is when we have all constant indices. // Otherwise, we cannot determine the overflow conditions. - for (unsigned i = 0, e = getNumOperands(); i != e; ++i) - if (!isa(getOperand(i))) - return SE.getCouldNotCompute(); - + if (std::any_of(op_begin(), op_end(), + [](const SCEV *Op) { return !isa(Op);})) + return SE.getCouldNotCompute(); // Okay at this point we know that all elements of the chrec are constants and // that the start element is zero. @@ -7386,7 +8358,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // iteration exits. unsigned BitWidth = SE.getTypeSizeInBits(getType()); if (!Range.contains(APInt(BitWidth, 0))) - return SE.getConstant(getType(), 0); + return SE.getZero(getType()); if (isAffine()) { // If this is an affine expression then we have this situation: @@ -7545,14 +8517,89 @@ struct SCEVCollectTerms { } bool isDone() const { return false; } }; + +// Check if a SCEV contains an AddRecExpr. +struct SCEVHasAddRec { + bool &ContainsAddRec; + + SCEVHasAddRec(bool &ContainsAddRec) : ContainsAddRec(ContainsAddRec) { + ContainsAddRec = false; + } + + bool follow(const SCEV *S) { + if (isa(S)) { + ContainsAddRec = true; + + // Stop recursion: once we collected a term, do not walk its operands. + return false; + } + + // Keep looking. + return true; + } + bool isDone() const { return false; } +}; + +// Find factors that are multiplied with an expression that (possibly as a +// subexpression) contains an AddRecExpr. In the expression: +// +// 8 * (100 + %p * %q * (%a + {0, +, 1}_loop)) +// +// "%p * %q" are factors multiplied by the expression "(%a + {0, +, 1}_loop)" +// that contains the AddRec {0, +, 1}_loop. %p * %q are likely to be array size +// parameters as they form a product with an induction variable. +// +// This collector expects all array size parameters to be in the same MulExpr. +// It might be necessary to later add support for collecting parameters that are +// spread over different nested MulExpr. +struct SCEVCollectAddRecMultiplies { + SmallVectorImpl &Terms; + ScalarEvolution &SE; + + SCEVCollectAddRecMultiplies(SmallVectorImpl &T, ScalarEvolution &SE) + : Terms(T), SE(SE) {} + + bool follow(const SCEV *S) { + if (auto *Mul = dyn_cast(S)) { + bool HasAddRec = false; + SmallVector Operands; + for (auto Op : Mul->operands()) { + if (isa(Op)) { + Operands.push_back(Op); + } else { + bool ContainsAddRec; + SCEVHasAddRec ContiansAddRec(ContainsAddRec); + visitAll(Op, ContiansAddRec); + HasAddRec |= ContainsAddRec; + } + } + if (Operands.size() == 0) + return true; + + if (!HasAddRec) + return false; + + Terms.push_back(SE.getMulExpr(Operands)); + // Stop recursion: once we collected a term, do not walk its operands. + return false; + } + + // Keep looking. + return true; + } + bool isDone() const { return false; } +}; } -/// Find parametric terms in this SCEVAddRecExpr. -void SCEVAddRecExpr::collectParametricTerms( - ScalarEvolution &SE, SmallVectorImpl &Terms) const { +/// Find parametric terms in this SCEVAddRecExpr. We first for parameters in +/// two places: +/// 1) The strides of AddRec expressions. +/// 2) Unknowns that are multiplied with AddRec expressions. +void ScalarEvolution::collectParametricTerms(const SCEV *Expr, + SmallVectorImpl &Terms) { SmallVector Strides; - SCEVCollectStrides StrideCollector(SE, Strides); - visitAll(this, StrideCollector); + SCEVCollectStrides StrideCollector(*this, Strides); + visitAll(Expr, StrideCollector); DEBUG({ dbgs() << "Strides:\n"; @@ -7570,6 +8617,9 @@ void SCEVAddRecExpr::collectParametricTerms( for (const SCEV *T : Terms) dbgs() << *T << "\n"; }); + + SCEVCollectAddRecMultiplies MulCollector(Terms, *this); + visitAll(Expr, MulCollector); } static bool findArrayDimensionsRec(ScalarEvolution &SE, @@ -7730,11 +8780,13 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl &Terms, ScalarEvolution &SE = *const_cast(this); - // Divide all terms by the element size. + // Try to divide all terms by the element size. If term is not divisible by + // element size, proceed with the original term. for (const SCEV *&Term : Terms) { const SCEV *Q, *R; SCEVDivision::divide(SE, Term, ElementSize, &Q, &R); - Term = Q; + if (!Q->isZero()) + Term = Q; } SmallVector NewTerms; @@ -7768,19 +8820,23 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl &Terms, /// Third step of delinearization: compute the access functions for the /// Subscripts based on the dimensions in Sizes. -void SCEVAddRecExpr::computeAccessFunctions( - ScalarEvolution &SE, SmallVectorImpl &Subscripts, - SmallVectorImpl &Sizes) const { +void ScalarEvolution::computeAccessFunctions( + const SCEV *Expr, SmallVectorImpl &Subscripts, + SmallVectorImpl &Sizes) { // Early exit in case this SCEV is not an affine multivariate function. - if (Sizes.empty() || !this->isAffine()) + if (Sizes.empty()) return; - const SCEV *Res = this; + if (auto *AR = dyn_cast(Expr)) + if (!AR->isAffine()) + return; + + const SCEV *Res = Expr; int Last = Sizes.size() - 1; for (int i = Last; i >= 0; i--) { const SCEV *Q, *R; - SCEVDivision::divide(SE, Res, Sizes[i], &Q, &R); + SCEVDivision::divide(*this, Res, Sizes[i], &Q, &R); DEBUG({ dbgs() << "Res: " << *Res << "\n"; @@ -7872,31 +8928,31 @@ void SCEVAddRecExpr::computeAccessFunctions( /// asking for the SCEV of the memory access with respect to all enclosing /// loops, calling SCEV->delinearize on that and printing the results. -void SCEVAddRecExpr::delinearize(ScalarEvolution &SE, +void ScalarEvolution::delinearize(const SCEV *Expr, SmallVectorImpl &Subscripts, SmallVectorImpl &Sizes, - const SCEV *ElementSize) const { + const SCEV *ElementSize) { // First step: collect parametric terms. SmallVector Terms; - collectParametricTerms(SE, Terms); + collectParametricTerms(Expr, Terms); if (Terms.empty()) return; // Second step: find subscript sizes. - SE.findArrayDimensions(Terms, Sizes, ElementSize); + findArrayDimensions(Terms, Sizes, ElementSize); if (Sizes.empty()) return; // Third step: compute the access functions for each subscript. - computeAccessFunctions(SE, Subscripts, Sizes); + computeAccessFunctions(Expr, Subscripts, Sizes); if (Subscripts.empty()) return; DEBUG({ - dbgs() << "succeeded to delinearize " << *this << "\n"; + dbgs() << "succeeded to delinearize " << *Expr << "\n"; dbgs() << "ArrayDecl[UnknownSize]"; for (const SCEV *S : Sizes) dbgs() << "[" << *S << "]"; @@ -7956,58 +9012,55 @@ ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se) // ScalarEvolution Class Implementation //===----------------------------------------------------------------------===// -ScalarEvolution::ScalarEvolution() - : FunctionPass(ID), ValuesAtScopes(64), LoopDispositions(64), - BlockDispositions(64), FirstUnknown(nullptr) { - initializeScalarEvolutionPass(*PassRegistry::getPassRegistry()); -} - -bool ScalarEvolution::runOnFunction(Function &F) { - this->F = &F; - AC = &getAnalysis().getAssumptionCache(F); - LI = &getAnalysis().getLoopInfo(); - DL = &F.getParent()->getDataLayout(); - TLI = &getAnalysis().getTLI(); - DT = &getAnalysis().getDomTree(); - return false; -} - -void ScalarEvolution::releaseMemory() { +ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI, + AssumptionCache &AC, DominatorTree &DT, + LoopInfo &LI) + : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI), + CouldNotCompute(new SCEVCouldNotCompute()), + WalkingBEDominatingConds(false), ProvingSplitPredicate(false), + ValuesAtScopes(64), LoopDispositions(64), BlockDispositions(64), + FirstUnknown(nullptr) {} + +ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) + : F(Arg.F), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), LI(Arg.LI), + CouldNotCompute(std::move(Arg.CouldNotCompute)), + ValueExprMap(std::move(Arg.ValueExprMap)), + WalkingBEDominatingConds(false), ProvingSplitPredicate(false), + BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)), + ConstantEvolutionLoopExitValue( + std::move(Arg.ConstantEvolutionLoopExitValue)), + ValuesAtScopes(std::move(Arg.ValuesAtScopes)), + LoopDispositions(std::move(Arg.LoopDispositions)), + BlockDispositions(std::move(Arg.BlockDispositions)), + UnsignedRanges(std::move(Arg.UnsignedRanges)), + SignedRanges(std::move(Arg.SignedRanges)), + UniqueSCEVs(std::move(Arg.UniqueSCEVs)), + UniquePreds(std::move(Arg.UniquePreds)), + SCEVAllocator(std::move(Arg.SCEVAllocator)), + FirstUnknown(Arg.FirstUnknown) { + Arg.FirstUnknown = nullptr; +} + +ScalarEvolution::~ScalarEvolution() { // Iterate through all the SCEVUnknown instances and call their // destructors, so that they release their references to their values. - for (SCEVUnknown *U = FirstUnknown; U; U = U->Next) - U->~SCEVUnknown(); + for (SCEVUnknown *U = FirstUnknown; U;) { + SCEVUnknown *Tmp = U; + U = U->Next; + Tmp->~SCEVUnknown(); + } FirstUnknown = nullptr; ValueExprMap.clear(); // Free any extra memory created for ExitNotTakenInfo in the unlikely event // that a loop had multiple computable exits. - for (DenseMap::iterator I = - BackedgeTakenCounts.begin(), E = BackedgeTakenCounts.end(); - I != E; ++I) { - I->second.clear(); - } + for (auto &BTCI : BackedgeTakenCounts) + BTCI.second.clear(); assert(PendingLoopPredicates.empty() && "isImpliedCond garbage"); - - BackedgeTakenCounts.clear(); - ConstantEvolutionLoopExitValue.clear(); - ValuesAtScopes.clear(); - LoopDispositions.clear(); - BlockDispositions.clear(); - UnsignedRanges.clear(); - SignedRanges.clear(); - UniqueSCEVs.clear(); - SCEVAllocator.Reset(); -} - -void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const { - AU.setPreservesAll(); - AU.addRequired(); - AU.addRequiredTransitive(); - AU.addRequiredTransitive(); - AU.addRequired(); + assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!"); + assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!"); } bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) { @@ -8049,7 +9102,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, OS << "\n"; } -void ScalarEvolution::print(raw_ostream &OS, const Module *) const { +void ScalarEvolution::print(raw_ostream &OS) const { // ScalarEvolution's implementation of the print method is to print // out SCEV values of all instructions that are interesting. Doing // this potentially causes it to create new SCEV objects though, @@ -8059,13 +9112,13 @@ void ScalarEvolution::print(raw_ostream &OS, const Module *) const { ScalarEvolution &SE = *const_cast(this); OS << "Classifying expressions for: "; - F->printAsOperand(OS, /*PrintType=*/false); + F.printAsOperand(OS, /*PrintType=*/false); OS << "\n"; - for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) - if (isSCEVable(I->getType()) && !isa(*I)) { - OS << *I << '\n'; + for (Instruction &I : instructions(F)) + if (isSCEVable(I.getType()) && !isa(I)) { + OS << I << '\n'; OS << " --> "; - const SCEV *SV = SE.getSCEV(&*I); + const SCEV *SV = SE.getSCEV(&I); SV->print(OS); if (!isa(SV)) { OS << " U: "; @@ -8074,7 +9127,7 @@ void ScalarEvolution::print(raw_ostream &OS, const Module *) const { SE.getSignedRange(SV).print(OS); } - const Loop *L = LI->getLoopFor((*I).getParent()); + const Loop *L = LI.getLoopFor(I.getParent()); const SCEV *AtUse = SE.getSCEVAtScope(SV, L); if (AtUse != SV) { @@ -8102,9 +9155,9 @@ void ScalarEvolution::print(raw_ostream &OS, const Module *) const { } OS << "Determining loop execution counts for: "; - F->printAsOperand(OS, /*PrintType=*/false); + F.printAsOperand(OS, /*PrintType=*/false); OS << "\n"; - for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) + for (LoopInfo::iterator I = LI.begin(), E = LI.end(); I != E; ++I) PrintLoopInfo(OS, &SE, *I); } @@ -8248,7 +9301,7 @@ ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { // produces the addrec's value is a PHI, and a PHI effectively properly // dominates its entire containing block. const SCEVAddRecExpr *AR = cast(S); - if (!DT->dominates(AR->getLoop()->getHeader(), BB)) + if (!DT.dominates(AR->getLoop()->getHeader(), BB)) return DoesNotDominateBlock; } // FALL THROUGH into SCEVNAryExpr handling. @@ -8285,7 +9338,7 @@ ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { dyn_cast(cast(S)->getValue())) { if (I->getParent() == BB) return DominatesBlock; - if (DT->properlyDominates(I->getParent(), BB)) + if (DT.properlyDominates(I->getParent(), BB)) return ProperlyDominatesBlock; return DoesNotDominateBlock; } @@ -8379,24 +9432,21 @@ getLoopBackedgeTakenCounts(Loop *L, VerifyMap &Map, ScalarEvolution &SE) { } } -void ScalarEvolution::verifyAnalysis() const { - if (!VerifySCEV) - return; - +void ScalarEvolution::verify() const { ScalarEvolution &SE = *const_cast(this); // Gather stringified backedge taken counts for all loops using SCEV's caches. // FIXME: It would be much better to store actual values instead of strings, // but SCEV pointers will change if we drop the caches. VerifyMap BackedgeDumpsOld, BackedgeDumpsNew; - for (LoopInfo::reverse_iterator I = LI->rbegin(), E = LI->rend(); I != E; ++I) + for (LoopInfo::reverse_iterator I = LI.rbegin(), E = LI.rend(); I != E; ++I) getLoopBackedgeTakenCounts(*I, BackedgeDumpsOld, SE); - // Gather stringified backedge taken counts for all loops without using - // SCEV's caches. - SE.releaseMemory(); - for (LoopInfo::reverse_iterator I = LI->rbegin(), E = LI->rend(); I != E; ++I) - getLoopBackedgeTakenCounts(*I, BackedgeDumpsNew, SE); + // Gather stringified backedge taken counts for all loops using a fresh + // ScalarEvolution object. + ScalarEvolution SE2(F, TLI, AC, DT, LI); + for (LoopInfo::reverse_iterator I = LI.rbegin(), E = LI.rend(); I != E; ++I) + getLoopBackedgeTakenCounts(*I, BackedgeDumpsNew, SE2); // Now compare whether they're the same with and without caches. This allows // verifying that no pass changed the cache. @@ -8429,3 +9479,194 @@ void ScalarEvolution::verifyAnalysis() const { // TODO: Verify more things. } + +char ScalarEvolutionAnalysis::PassID; + +ScalarEvolution ScalarEvolutionAnalysis::run(Function &F, + AnalysisManager *AM) { + return ScalarEvolution(F, AM->getResult(F), + AM->getResult(F), + AM->getResult(F), + AM->getResult(F)); +} + +PreservedAnalyses +ScalarEvolutionPrinterPass::run(Function &F, AnalysisManager *AM) { + AM->getResult(F).print(OS); + return PreservedAnalyses::all(); +} + +INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution", + "Scalar Evolution Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution", + "Scalar Evolution Analysis", false, true) +char ScalarEvolutionWrapperPass::ID = 0; + +ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) { + initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) { + SE.reset(new ScalarEvolution( + F, getAnalysis().getTLI(), + getAnalysis().getAssumptionCache(F), + getAnalysis().getDomTree(), + getAnalysis().getLoopInfo())); + return false; +} + +void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); } + +void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const { + SE->print(OS); +} + +void ScalarEvolutionWrapperPass::verifyAnalysis() const { + if (!VerifySCEV) + return; + + SE->verify(); +} + +void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequiredTransitive(); + AU.addRequiredTransitive(); + AU.addRequiredTransitive(); + AU.addRequiredTransitive(); +} + +const SCEVPredicate * +ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS, + const SCEVConstant *RHS) { + FoldingSetNodeID ID; + // Unique this node based on the arguments + ID.AddInteger(SCEVPredicate::P_Equal); + ID.AddPointer(LHS); + ID.AddPointer(RHS); + void *IP = nullptr; + if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) + return S; + SCEVEqualPredicate *Eq = new (SCEVAllocator) + SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS); + UniquePreds.InsertNode(Eq, IP); + return Eq; +} + +class SCEVPredicateRewriter : public SCEVRewriteVisitor { +public: + static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, + SCEVUnionPredicate &A) { + SCEVPredicateRewriter Rewriter(SE, A); + return Rewriter.visit(Scev); + } + + SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P) + : SCEVRewriteVisitor(SE), P(P) {} + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + auto ExprPreds = P.getPredicatesForExpr(Expr); + for (auto *Pred : ExprPreds) + if (const auto *IPred = dyn_cast(Pred)) + if (IPred->getLHS() == Expr) + return IPred->getRHS(); + + return Expr; + } + +private: + SCEVUnionPredicate &P; +}; + +const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev, + SCEVUnionPredicate &Preds) { + return SCEVPredicateRewriter::rewrite(Scev, *this, Preds); +} + +/// SCEV predicates +SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID, + SCEVPredicateKind Kind) + : FastID(ID), Kind(Kind) {} + +SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID, + const SCEVUnknown *LHS, + const SCEVConstant *RHS) + : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {} + +bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const { + const auto *Op = dyn_cast(N); + + if (!Op) + return false; + + return Op->LHS == LHS && Op->RHS == RHS; +} + +bool SCEVEqualPredicate::isAlwaysTrue() const { return false; } + +const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; } + +void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const { + OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n"; +} + +/// Union predicates don't get cached so create a dummy set ID for it. +SCEVUnionPredicate::SCEVUnionPredicate() + : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {} + +bool SCEVUnionPredicate::isAlwaysTrue() const { + return std::all_of(Preds.begin(), Preds.end(), + [](const SCEVPredicate *I) { return I->isAlwaysTrue(); }); +} + +ArrayRef +SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) { + auto I = SCEVToPreds.find(Expr); + if (I == SCEVToPreds.end()) + return ArrayRef(); + return I->second; +} + +bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const { + if (const auto *Set = dyn_cast(N)) + return std::all_of( + Set->Preds.begin(), Set->Preds.end(), + [this](const SCEVPredicate *I) { return this->implies(I); }); + + auto ScevPredsIt = SCEVToPreds.find(N->getExpr()); + if (ScevPredsIt == SCEVToPreds.end()) + return false; + auto &SCEVPreds = ScevPredsIt->second; + + return std::any_of(SCEVPreds.begin(), SCEVPreds.end(), + [N](const SCEVPredicate *I) { return I->implies(N); }); +} + +const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; } + +void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { + for (auto Pred : Preds) + Pred->print(OS, Depth); +} + +void SCEVUnionPredicate::add(const SCEVPredicate *N) { + if (const auto *Set = dyn_cast(N)) { + for (auto Pred : Set->Preds) + add(Pred); + return; + } + + if (implies(N)) + return; + + const SCEV *Key = N->getExpr(); + assert(Key && "Only SCEVUnionPredicate doesn't have an " + " associated expression!"); + + SCEVToPreds[Key].push_back(N); + Preds.push_back(N); +}