X-Git-Url: http://plrg.eecs.uci.edu/git/?p=oota-llvm.git;a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=177bd23030a7bec57f22db21cafa2d7e5c8462e2;hp=06dbde58c1084488ae087b9efbc0633394c0c11e;hb=dad20b2ae2544708d6a33abdb9bddd0a329f50e0;hpb=facdfc67815a9e8a67170568dd2dd1439006cd27 diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 06dbde58c10..177bd23030a 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -1,4 +1,4 @@ -//===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===// +//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===// // // The LLVM Compiler Infrastructure // @@ -59,9 +59,11 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionTracker.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" @@ -78,6 +80,7 @@ #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -113,6 +116,7 @@ VerifySCEV("verify-scev", INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution", "Scalar Evolution Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionTracker) INITIALIZE_PASS_DEPENDENCY(LoopInfo) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) @@ -671,7 +675,321 @@ static void GroupByComplexity(SmallVectorImpl &Ops, } } +static const APInt srem(const SCEVConstant *C1, const SCEVConstant *C2) { + APInt A = C1->getValue()->getValue(); + APInt B = C2->getValue()->getValue(); + uint32_t ABW = A.getBitWidth(); + uint32_t BBW = B.getBitWidth(); + + if (ABW > BBW) + B = B.sext(ABW); + else if (ABW < BBW) + A = A.sext(BBW); + + return APIntOps::srem(A, B); +} + +static const APInt sdiv(const SCEVConstant *C1, const SCEVConstant *C2) { + APInt A = C1->getValue()->getValue(); + APInt B = C2->getValue()->getValue(); + uint32_t ABW = A.getBitWidth(); + uint32_t BBW = B.getBitWidth(); + + if (ABW > BBW) + B = B.sext(ABW); + else if (ABW < BBW) + A = A.sext(BBW); + + return APIntOps::sdiv(A, B); +} + +static const APInt urem(const SCEVConstant *C1, const SCEVConstant *C2) { + APInt A = C1->getValue()->getValue(); + APInt B = C2->getValue()->getValue(); + uint32_t ABW = A.getBitWidth(); + uint32_t BBW = B.getBitWidth(); + + if (ABW > BBW) + B = B.zext(ABW); + else if (ABW < BBW) + A = A.zext(BBW); + + return APIntOps::urem(A, B); +} + +static const APInt udiv(const SCEVConstant *C1, const SCEVConstant *C2) { + APInt A = C1->getValue()->getValue(); + APInt B = C2->getValue()->getValue(); + uint32_t ABW = A.getBitWidth(); + uint32_t BBW = B.getBitWidth(); + + if (ABW > BBW) + B = B.zext(ABW); + else if (ABW < BBW) + A = A.zext(BBW); + + return APIntOps::udiv(A, B); +} + +namespace { +struct FindSCEVSize { + int Size; + FindSCEVSize() : Size(0) {} + + bool follow(const SCEV *S) { + ++Size; + // Keep looking at all operands of S. + return true; + } + bool isDone() const { + return false; + } +}; +} + +// Returns the size of the SCEV S. +static inline int sizeOfSCEV(const SCEV *S) { + FindSCEVSize F; + SCEVTraversal ST(F); + ST.visitAll(S); + return F.Size; +} + +namespace { + +template +struct SCEVDivision : public SCEVVisitor { +public: + // Computes the Quotient and Remainder of the division of Numerator by + // Denominator. + static void divide(ScalarEvolution &SE, const SCEV *Numerator, + const SCEV *Denominator, const SCEV **Quotient, + const SCEV **Remainder) { + assert(Numerator && Denominator && "Uninitialized SCEV"); + + Derived D(SE, Numerator, Denominator); + + // Check for the trivial case here to avoid having to check for it in the + // rest of the code. + if (Numerator == Denominator) { + *Quotient = D.One; + *Remainder = D.Zero; + return; + } + + if (Numerator->isZero()) { + *Quotient = D.Zero; + *Remainder = D.Zero; + return; + } + + // Split the Denominator when it is a product. + if (const SCEVMulExpr *T = dyn_cast(Denominator)) { + const SCEV *Q, *R; + *Quotient = Numerator; + for (const SCEV *Op : T->operands()) { + divide(SE, *Quotient, Op, &Q, &R); + *Quotient = Q; + + // Bail out when the Numerator is not divisible by one of the terms of + // the Denominator. + if (!R->isZero()) { + *Quotient = D.Zero; + *Remainder = Numerator; + return; + } + } + *Remainder = D.Zero; + return; + } + + D.visit(Numerator); + *Quotient = D.Quotient; + *Remainder = D.Remainder; + } + + // Except in the trivial case described above, we do not know how to divide + // Expr by Denominator for the following functions with empty implementation. + void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {} + void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {} + void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {} + void visitUDivExpr(const SCEVUDivExpr *Numerator) {} + void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {} + void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} + void visitUnknown(const SCEVUnknown *Numerator) {} + void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} + + void visitAddRecExpr(const SCEVAddRecExpr *Numerator) { + const SCEV *StartQ, *StartR, *StepQ, *StepR; + assert(Numerator->isAffine() && "Numerator should be affine"); + divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); + divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); + Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), + Numerator->getNoWrapFlags()); + Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), + Numerator->getNoWrapFlags()); + } + + void visitAddExpr(const SCEVAddExpr *Numerator) { + SmallVector Qs, Rs; + Type *Ty = Denominator->getType(); + + for (const SCEV *Op : Numerator->operands()) { + const SCEV *Q, *R; + divide(SE, Op, Denominator, &Q, &R); + + // Bail out if types do not match. + if (Ty != Q->getType() || Ty != R->getType()) { + Quotient = Zero; + Remainder = Numerator; + return; + } + + Qs.push_back(Q); + Rs.push_back(R); + } + + if (Qs.size() == 1) { + Quotient = Qs[0]; + Remainder = Rs[0]; + return; + } + + Quotient = SE.getAddExpr(Qs); + Remainder = SE.getAddExpr(Rs); + } + + void visitMulExpr(const SCEVMulExpr *Numerator) { + SmallVector Qs; + Type *Ty = Denominator->getType(); + + bool FoundDenominatorTerm = false; + for (const SCEV *Op : Numerator->operands()) { + // Bail out if types do not match. + if (Ty != Op->getType()) { + Quotient = Zero; + Remainder = Numerator; + return; + } + + if (FoundDenominatorTerm) { + Qs.push_back(Op); + continue; + } + + // Check whether Denominator divides one of the product operands. + const SCEV *Q, *R; + divide(SE, Op, Denominator, &Q, &R); + if (!R->isZero()) { + Qs.push_back(Op); + continue; + } + + // Bail out if types do not match. + if (Ty != Q->getType()) { + Quotient = Zero; + Remainder = Numerator; + return; + } + + FoundDenominatorTerm = true; + Qs.push_back(Q); + } + + if (FoundDenominatorTerm) { + Remainder = Zero; + if (Qs.size() == 1) + Quotient = Qs[0]; + else + Quotient = SE.getMulExpr(Qs); + return; + } + + if (!isa(Denominator)) { + Quotient = Zero; + Remainder = Numerator; + return; + } + + // The Remainder is obtained by replacing Denominator by 0 in Numerator. + ValueToValueMap RewriteMap; + RewriteMap[cast(Denominator)->getValue()] = + cast(Zero)->getValue(); + Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); + + if (Remainder->isZero()) { + // The Quotient is obtained by replacing Denominator by 1 in Numerator. + RewriteMap[cast(Denominator)->getValue()] = + cast(One)->getValue(); + Quotient = + SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); + return; + } + + // Quotient is (Numerator - Remainder) divided by Denominator. + const SCEV *Q, *R; + const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); + if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) { + // This SCEV does not seem to simplify: fail the division here. + Quotient = Zero; + Remainder = Numerator; + return; + } + divide(SE, Diff, Denominator, &Q, &R); + assert(R == Zero && + "(Numerator - Remainder) should evenly divide Denominator"); + Quotient = Q; + } + +private: + SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, + const SCEV *Denominator) + : SE(S), Denominator(Denominator) { + Zero = SE.getConstant(Denominator->getType(), 0); + One = SE.getConstant(Denominator->getType(), 1); + + // By default, we don't know how to divide Expr by Denominator. + // Providing the default here simplifies the rest of the code. + Quotient = Zero; + Remainder = Numerator; + } + + ScalarEvolution &SE; + const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One; + + friend struct SCEVSDivision; + friend struct SCEVUDivision; +}; + +struct SCEVSDivision : public SCEVDivision { + SCEVSDivision(ScalarEvolution &S, const SCEV *Numerator, + const SCEV *Denominator) + : SCEVDivision(S, Numerator, Denominator) {} + + void visitConstant(const SCEVConstant *Numerator) { + if (const SCEVConstant *D = dyn_cast(Denominator)) { + Quotient = SE.getConstant(sdiv(Numerator, D)); + Remainder = SE.getConstant(srem(Numerator, D)); + return; + } + } +}; +struct SCEVUDivision : public SCEVDivision { + SCEVUDivision(ScalarEvolution &S, const SCEV *Numerator, + const SCEV *Denominator) + : SCEVDivision(S, Numerator, Denominator) {} + + void visitConstant(const SCEVConstant *Numerator) { + if (const SCEVConstant *D = dyn_cast(Denominator)) { + Quotient = SE.getConstant(udiv(Numerator, D)); + Remainder = SE.getConstant(urem(Numerator, D)); + return; + } + } +}; + +} //===----------------------------------------------------------------------===// // Simple SCEV method implementations @@ -1889,6 +2207,25 @@ static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) { return r; } +/// Determine if any of the operands in this SCEV are a constant or if +/// any of the add or multiply expressions in this SCEV contain a constant. +static bool containsConstantSomewhere(const SCEV *StartExpr) { + SmallVector Ops; + Ops.push_back(StartExpr); + while (!Ops.empty()) { + const SCEV *CurrentExpr = Ops.pop_back_val(); + if (isa(*CurrentExpr)) + return true; + + if (isa(*CurrentExpr) || isa(*CurrentExpr)) { + const auto *CurrentNAry = cast(CurrentExpr); + for (const SCEV *Operand : CurrentNAry->operands()) + Ops.push_back(Operand); + } + } + return false; +} + /// getMulExpr - Get a canonical multiply expression, or something simpler if /// possible. const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, @@ -1928,11 +2265,13 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, // C1*(C2+V) -> C1*C2 + C1*V if (Ops.size() == 2) - if (const SCEVAddExpr *Add = dyn_cast(Ops[1])) - if (Add->getNumOperands() == 2 && - isa(Add->getOperand(0))) - return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)), - getMulExpr(LHSC, Add->getOperand(1))); + if (const SCEVAddExpr *Add = dyn_cast(Ops[1])) + // If any of Add's ops are Adds or Muls with a constant, + // apply this transformation as well. + if (Add->getNumOperands() == 2) + if (containsConstantSomewhere(Add)) + return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)), + getMulExpr(LHSC, Add->getOperand(1))); ++Idx; while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { @@ -2061,71 +2400,66 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, // Okay, if there weren't any loop invariants to be folded, check to see if // there are multiple AddRec's with the same loop induction variable being // multiplied together. If so, we can fold them. + + // {A1,+,A2,+,...,+,An} * {B1,+,B2,+,...,+,Bn} + // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [ + // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z + // ]]],+,...up to x=2n}. + // Note that the arguments to choose() are always integers with values + // known at compile time, never SCEV objects. + // + // The implementation avoids pointless extra computations when the two + // addrec's are of different length (mathematically, it's equivalent to + // an infinite stream of zeros on the right). + bool OpsModified = false; for (unsigned OtherIdx = Idx+1; - OtherIdx < Ops.size() && isa(Ops[OtherIdx]); + OtherIdx != Ops.size() && isa(Ops[OtherIdx]); ++OtherIdx) { - if (AddRecLoop != cast(Ops[OtherIdx])->getLoop()) + const SCEVAddRecExpr *OtherAddRec = + dyn_cast(Ops[OtherIdx]); + if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop) continue; - // {A1,+,A2,+,...,+,An} * {B1,+,B2,+,...,+,Bn} - // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [ - // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z - // ]]],+,...up to x=2n}. - // Note that the arguments to choose() are always integers with values - // known at compile time, never SCEV objects. - // - // The implementation avoids pointless extra computations when the two - // addrec's are of different length (mathematically, it's equivalent to - // an infinite stream of zeros on the right). - bool OpsModified = false; - for (; OtherIdx != Ops.size() && isa(Ops[OtherIdx]); - ++OtherIdx) { - const SCEVAddRecExpr *OtherAddRec = - dyn_cast(Ops[OtherIdx]); - if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop) - continue; - - bool Overflow = false; - Type *Ty = AddRec->getType(); - bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64; - SmallVector AddRecOps; - for (int x = 0, xe = AddRec->getNumOperands() + - OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { - const SCEV *Term = getConstant(Ty, 0); - for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { - uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); - for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), - ze = std::min(x+1, (int)OtherAddRec->getNumOperands()); - z < ze && !Overflow; ++z) { - uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow); - uint64_t Coeff; - if (LargerThan64Bits) - Coeff = umul_ov(Coeff1, Coeff2, Overflow); - else - Coeff = Coeff1*Coeff2; - const SCEV *CoeffTerm = getConstant(Ty, Coeff); - const SCEV *Term1 = AddRec->getOperand(y-z); - const SCEV *Term2 = OtherAddRec->getOperand(z); - Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2)); - } + bool Overflow = false; + Type *Ty = AddRec->getType(); + bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64; + SmallVector AddRecOps; + for (int x = 0, xe = AddRec->getNumOperands() + + OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { + const SCEV *Term = getConstant(Ty, 0); + for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { + uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); + for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), + ze = std::min(x+1, (int)OtherAddRec->getNumOperands()); + z < ze && !Overflow; ++z) { + uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow); + uint64_t Coeff; + if (LargerThan64Bits) + Coeff = umul_ov(Coeff1, Coeff2, Overflow); + else + Coeff = Coeff1*Coeff2; + const SCEV *CoeffTerm = getConstant(Ty, Coeff); + const SCEV *Term1 = AddRec->getOperand(y-z); + const SCEV *Term2 = OtherAddRec->getOperand(z); + Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2)); } - AddRecOps.push_back(Term); - } - if (!Overflow) { - const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(), - SCEV::FlagAnyWrap); - if (Ops.size() == 2) return NewAddRec; - Ops[Idx] = NewAddRec; - Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; - OpsModified = true; - AddRec = dyn_cast(NewAddRec); - if (!AddRec) - break; } + AddRecOps.push_back(Term); + } + if (!Overflow) { + const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(), + SCEV::FlagAnyWrap); + if (Ops.size() == 2) return NewAddRec; + Ops[Idx] = NewAddRec; + Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; + OpsModified = true; + AddRec = dyn_cast(NewAddRec); + if (!AddRec) + break; } - if (OpsModified) - return getMulExpr(Ops); } + if (OpsModified) + return getMulExpr(Ops); // Otherwise couldn't fold anything into this recurrence. Move onto the // next one. @@ -3082,7 +3416,8 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { Visited.insert(PN); while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast(I)); @@ -3263,7 +3598,7 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // PHI's incoming blocks are in a different loop, in which case doing so // risks breaking LCSSA form. Instcombine would normally zap these, but // it doesn't have DominatorTree information, so it may miss cases. - if (Value *V = SimplifyInstruction(PN, DL, TLI, DT)) + if (Value *V = SimplifyInstruction(PN, DL, TLI, DT, AT)) if (LI->replacementPreservesLCSSAForm(PN, V)) return getSCEV(V); @@ -3395,7 +3730,7 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { // For a SCEVUnknown, ask ValueTracking. unsigned BitWidth = getTypeSizeInBits(U->getType()); APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones); + computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AT, nullptr, DT); return Zeros.countTrailingOnes(); } @@ -3403,6 +3738,33 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { return 0; } +/// GetRangeFromMetadata - Helper method to assign a range to V from +/// metadata present in the IR. +static Optional GetRangeFromMetadata(Value *V) { + if (Instruction *I = dyn_cast(V)) { + if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) { + ConstantRange TotalRange( + cast(I->getType())->getBitWidth(), false); + + unsigned NumRanges = MD->getNumOperands() / 2; + assert(NumRanges >= 1); + + for (unsigned i = 0; i < NumRanges; ++i) { + ConstantInt *Lower = + mdconst::extract(MD->getOperand(2 * i + 0)); + ConstantInt *Upper = + mdconst::extract(MD->getOperand(2 * i + 1)); + ConstantRange Range(Lower->getValue(), Upper->getValue()); + TotalRange = TotalRange.unionWith(Range); + } + + return TotalRange; + } + } + + return None; +} + /// getUnsignedRange - Determine the unsigned range for a particular SCEV. /// ConstantRange @@ -3532,9 +3894,14 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) { } if (const SCEVUnknown *U = dyn_cast(S)) { + // Check if the IR explicitly contains !range metadata. + Optional MDRange = GetRangeFromMetadata(U->getValue()); + if (MDRange.hasValue()) + ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue()); + // For a SCEVUnknown, ask ValueTracking. APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones, DL); + computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AT, nullptr, DT); if (Ones == ~Zeros + 1) return setUnsignedRange(U, ConservativeResult); return setUnsignedRange(U, @@ -3683,10 +4050,15 @@ ScalarEvolution::getSignedRange(const SCEV *S) { } if (const SCEVUnknown *U = dyn_cast(S)) { + // Check if the IR explicitly contains !range metadata. + Optional MDRange = GetRangeFromMetadata(U->getValue()); + if (MDRange.hasValue()) + ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue()); + // For a SCEVUnknown, ask ValueTracking. if (!U->getValue()->getType()->isIntegerTy() && !DL) return setSignedRange(U, ConservativeResult); - unsigned NS = ComputeNumSignBits(U->getValue(), DL); + unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AT, nullptr, DT); if (NS <= 1) return setSignedRange(U, ConservativeResult); return setSignedRange(U, ConservativeResult.intersectWith( @@ -3793,7 +4165,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { unsigned TZ = A.countTrailingZeros(); unsigned BitWidth = A.getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL); + computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL, + 0, AT, nullptr, DT); APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); @@ -4070,6 +4443,14 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // Iteration Count Computation Code // +unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) { + if (BasicBlock *ExitingBB = L->getExitingBlock()) + return getSmallConstantTripCount(L, ExitingBB); + + // No trip count information for multiple exits. + return 0; +} + /// getSmallConstantTripCount - Returns the maximum trip count of this loop as a /// normal unsigned value. Returns 0 if the trip count is unknown or not /// constant. Will also return 0 if the maximum trip count is very large (>= @@ -4080,19 +4461,13 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { /// before taking the branch. For loops with multiple exits, it may not be the /// number times that the loop header executes because the loop may exit /// prematurely via another branch. -/// -/// FIXME: We conservatively call getBackedgeTakenCount(L) instead of -/// getExitCount(L, ExitingBlock) to compute a safe trip count considering all -/// loop exits. getExitCount() may return an exact count for this branch -/// assuming no-signed-wrap. The number of well-defined iterations may actually -/// be higher than this trip count if this exit test is skipped and the loop -/// exits via a different branch. Ideally, getExitCount() would know whether it -/// depends on a NSW assumption, and we would only fall back to a conservative -/// trip count in that case. -unsigned ScalarEvolution:: -getSmallConstantTripCount(Loop *L, BasicBlock * /*ExitingBlock*/) { +unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L, + BasicBlock *ExitingBlock) { + assert(ExitingBlock && "Must pass a non-null exiting block!"); + assert(L->isLoopExiting(ExitingBlock) && + "Exiting block must actually branch out of the loop!"); const SCEVConstant *ExitCount = - dyn_cast(getBackedgeTakenCount(L)); + dyn_cast(getExitCount(L, ExitingBlock)); if (!ExitCount) return 0; @@ -4106,6 +4481,14 @@ getSmallConstantTripCount(Loop *L, BasicBlock * /*ExitingBlock*/) { return ((unsigned)ExitConst->getZExtValue()) + 1; } +unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) { + if (BasicBlock *ExitingBB = L->getExitingBlock()) + return getSmallConstantTripMultiple(L, ExitingBB); + + // No trip multiple information for multiple exits. + return 0; +} + /// getSmallConstantTripMultiple - Returns the largest constant divisor of the /// trip count of this loop as a normal unsigned value, if possible. This /// means that the actual trip count is always a multiple of the returned @@ -4118,9 +4501,13 @@ getSmallConstantTripCount(Loop *L, BasicBlock * /*ExitingBlock*/) { /// /// As explained in the comments for getSmallConstantTripCount, this assumes /// that control exits the loop via ExitingBlock. -unsigned ScalarEvolution:: -getSmallConstantTripMultiple(Loop *L, BasicBlock * /*ExitingBlock*/) { - const SCEV *ExitCount = getBackedgeTakenCount(L); +unsigned +ScalarEvolution::getSmallConstantTripMultiple(Loop *L, + BasicBlock *ExitingBlock) { + assert(ExitingBlock && "Must pass a non-null exiting block!"); + assert(L->isLoopExiting(ExitingBlock) && + "Exiting block must actually branch out of the loop!"); + const SCEV *ExitCount = getExitCount(L, ExitingBlock); if (ExitCount == getCouldNotCompute()) return 1; @@ -4230,7 +4617,8 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { SmallPtrSet Visited; while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast(I)); @@ -4282,7 +4670,8 @@ void ScalarEvolution::forgetLoop(const Loop *L) { SmallPtrSet Visited; while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast(I)); @@ -4316,7 +4705,8 @@ void ScalarEvolution::forgetValue(Value *V) { SmallPtrSet Visited; while (!Worklist.empty()) { I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast(I)); @@ -4467,20 +4857,12 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { // non-exiting iterations. Partition the loop exits into two kinds: // LoopMustExits and LoopMayExits. // - // A LoopMustExit meets two requirements: - // - // (a) Its ExitLimit.MustExit flag must be set which indicates that the exit - // test condition cannot be skipped (the tested variable has unit stride or - // the test is less-than or greater-than, rather than a strict inequality). - // - // (b) It must dominate the loop latch, hence must be tested on every loop - // iteration. - // - // If any computable LoopMustExit is found, then MaxBECount is the minimum - // EL.Max of computable LoopMustExits. Otherwise, MaxBECount is - // conservatively the maximum EL.Max, where CouldNotCompute is considered - // greater than any computable EL.Max. - if (EL.MustExit && EL.Max != getCouldNotCompute() && Latch && + // If the exit dominates the loop latch, it is a LoopMustExit otherwise it + // is a LoopMayExit. If any computable LoopMustExit is found, then + // MaxBECount is the minimum EL.Max of computable LoopMustExits. Otherwise, + // MaxBECount is conservatively the maximum EL.Max, where CouldNotCompute is + // considered greater than any computable EL.Max. + if (EL.Max != getCouldNotCompute() && Latch && DT->dominates(ExitBB, Latch)) { if (!MustExitMaxBECount) MustExitMaxBECount = EL.Max; @@ -4567,18 +4949,19 @@ ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { return getCouldNotCompute(); } + bool IsOnlyExit = (L->getExitingBlock() != nullptr); TerminatorInst *Term = ExitingBlock->getTerminator(); if (BranchInst *BI = dyn_cast(Term)) { assert(BI->isConditional() && "If unconditional, it can't be in loop!"); // Proceed to the next level to examine the exit condition expression. return ComputeExitLimitFromCond(L, BI->getCondition(), BI->getSuccessor(0), BI->getSuccessor(1), - /*IsSubExpr=*/false); + /*ControlsExit=*/IsOnlyExit); } if (SwitchInst *SI = dyn_cast(Term)) return ComputeExitLimitFromSingleExitSwitch(L, SI, Exit, - /*IsSubExpr=*/false); + /*ControlsExit=*/IsOnlyExit); return getCouldNotCompute(); } @@ -4587,28 +4970,27 @@ ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { /// backedge of the specified loop will execute if its exit condition /// were a conditional branch of ExitCond, TBB, and FBB. /// -/// @param IsSubExpr is true if ExitCond does not directly control the exit -/// branch. In this case, we cannot assume that the loop only exits when the -/// condition is true and cannot infer that failing to meet the condition prior -/// to integer wraparound results in undefined behavior. +/// @param ControlsExit is true if ExitCond directly controls the exit +/// branch. In this case, we can assume that the loop exits only if the +/// condition is true and can infer that failing to meet the condition prior to +/// integer wraparound results in undefined behavior. ScalarEvolution::ExitLimit ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, Value *ExitCond, BasicBlock *TBB, BasicBlock *FBB, - bool IsSubExpr) { + bool ControlsExit) { // Check if the controlling expression for this loop is an And or Or. if (BinaryOperator *BO = dyn_cast(ExitCond)) { if (BO->getOpcode() == Instruction::And) { // Recurse on the operands of the and. bool EitherMayExit = L->contains(TBB); ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - IsSubExpr || EitherMayExit); + ControlsExit && !EitherMayExit); ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - IsSubExpr || EitherMayExit); + ControlsExit && !EitherMayExit); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); - bool MustExit = false; if (EitherMayExit) { // Both conditions must be true for the loop to continue executing. // Choose the less conservative count. @@ -4623,7 +5005,6 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, MaxBECount = EL0.Max; else MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); - MustExit = EL0.MustExit || EL1.MustExit; } else { // Both conditions must be true at the same time for the loop to exit. // For now, be conservative. @@ -4632,21 +5013,19 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, MaxBECount = EL0.Max; if (EL0.Exact == EL1.Exact) BECount = EL0.Exact; - MustExit = EL0.MustExit && EL1.MustExit; } - return ExitLimit(BECount, MaxBECount, MustExit); + return ExitLimit(BECount, MaxBECount); } if (BO->getOpcode() == Instruction::Or) { // Recurse on the operands of the or. bool EitherMayExit = L->contains(FBB); ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - IsSubExpr || EitherMayExit); + ControlsExit && !EitherMayExit); ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - IsSubExpr || EitherMayExit); + ControlsExit && !EitherMayExit); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); - bool MustExit = false; if (EitherMayExit) { // Both conditions must be false for the loop to continue executing. // Choose the less conservative count. @@ -4661,7 +5040,6 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, MaxBECount = EL0.Max; else MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); - MustExit = EL0.MustExit || EL1.MustExit; } else { // Both conditions must be false at the same time for the loop to exit. // For now, be conservative. @@ -4670,17 +5048,16 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, MaxBECount = EL0.Max; if (EL0.Exact == EL1.Exact) BECount = EL0.Exact; - MustExit = EL0.MustExit && EL1.MustExit; } - return ExitLimit(BECount, MaxBECount, MustExit); + return ExitLimit(BECount, MaxBECount); } } // With an icmp, it may be feasible to compute an exact backedge-taken count. // Proceed to the next level to examine the icmp. if (ICmpInst *ExitCondICmp = dyn_cast(ExitCond)) - return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, IsSubExpr); + return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit); // Check for a constant condition. These are normally stripped out by // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to @@ -4707,7 +5084,7 @@ ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond, BasicBlock *TBB, BasicBlock *FBB, - bool IsSubExpr) { + bool ControlsExit) { // If the condition was exit on true, convert the condition to exit on false ICmpInst::Predicate Cond; @@ -4759,7 +5136,7 @@ ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, switch (Cond) { case ICmpInst::ICMP_NE: { // while (X != Y) // Convert to: while (X-Y != 0) - ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, IsSubExpr); + ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); if (EL.hasAnyInfo()) return EL; break; } @@ -4772,14 +5149,14 @@ ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_ULT: { // while (X < Y) bool IsSigned = Cond == ICmpInst::ICMP_SLT; - ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, IsSubExpr); + ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, ControlsExit); if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_UGT: { // while (X > Y) bool IsSigned = Cond == ICmpInst::ICMP_SGT; - ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, IsSubExpr); + ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit); if (EL.hasAnyInfo()) return EL; break; } @@ -4801,7 +5178,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::ComputeExitLimitFromSingleExitSwitch(const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBlock, - bool IsSubExpr) { + bool ControlsExit) { assert(!L->contains(ExitingBlock) && "Not an exiting block!"); // Give up if the exit is the default dest of a switch. @@ -4814,7 +5191,7 @@ ScalarEvolution::ComputeExitLimitFromSingleExitSwitch(const Loop *L, const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); // while (X != Y) --> while (X-Y != 0) - ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, IsSubExpr); + ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); if (EL.hasAnyInfo()) return EL; @@ -5687,7 +6064,7 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { /// effectively V != 0. We know and take advantage of the fact that this /// expression only being used in a comparison by zero context. ScalarEvolution::ExitLimit -ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr) { +ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { // If the value is a constant if (const SCEVConstant *C = dyn_cast(V)) { // If the value is already zero, the branch will execute zero times. @@ -5781,37 +6158,30 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr) { else MaxBECount = getConstant(CountDown ? CR.getUnsignedMax() : -CR.getUnsignedMin()); - return ExitLimit(Distance, MaxBECount, /*MustExit=*/true); + return ExitLimit(Distance, MaxBECount); } - // If the recurrence is known not to wraparound, unsigned divide computes the - // back edge count. (Ideally we would have an "isexact" bit for udiv). We know - // that the value will either become zero (and thus the loop terminates), that - // the loop will terminate through some other exit condition first, or that - // the loop has undefined behavior. This means we can't "miss" the exit - // value, even with nonunit stride, and exit later via the same branch. Note - // that we can skip this exit if loop later exits via a different - // branch. Hence MustExit=false. - // - // This is only valid for expressions that directly compute the loop exit. It - // is invalid for subexpressions in which the loop may exit through this - // branch even if this subexpression is false. In that case, the trip count - // computed by this udiv could be smaller than the number of well-defined - // iterations. - if (!IsSubExpr && AddRec->getNoWrapFlags(SCEV::FlagNW)) { + // If the step exactly divides the distance then unsigned divide computes the + // backedge count. + const SCEV *Q, *R; + ScalarEvolution &SE = *const_cast(this); + SCEVUDivision::divide(SE, Distance, Step, &Q, &R); + if (R->isZero()) { const SCEV *Exact = - getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); - return ExitLimit(Exact, Exact, /*MustExit=*/false); + getUDivExactExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); + return ExitLimit(Exact, Exact); } - // If Step is a power of two that evenly divides Start we know that the loop - // will always terminate. Start may not be a constant so we just have the - // number of trailing zeros available. This is safe even in presence of - // overflow as the recurrence will overflow to exactly 0. - const APInt &StepV = StepC->getValue()->getValue(); - if (StepV.isPowerOf2() && - GetMinTrailingZeros(getNegativeSCEV(Start)) >= StepV.countTrailingZeros()) - return getUDivExactExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); + // If the condition controls loop exit (the loop exits only if the expression + // is true) and the addition is no-wrap we can use unsigned divide to + // compute the backedge count. In this case, the step may not divide the + // distance, but we don't care because if the condition is "missed" the loop + // will have undefined behavior due to wrapping. + if (ControlsExit && AddRec->getNoWrapFlags(SCEV::FlagNW)) { + const SCEV *Exact = + getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); + return ExitLimit(Exact, Exact); + } // Then, try to solve the above equation provided that Start is constant. if (const SCEVConstant *StartC = dyn_cast(Start)) @@ -6309,19 +6679,30 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, // (interprocedural conditions notwithstanding). if (!L) return true; + if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true; + BasicBlock *Latch = L->getLoopLatch(); if (!Latch) return false; BranchInst *LoopContinuePredicate = dyn_cast(Latch->getTerminator()); - if (!LoopContinuePredicate || - LoopContinuePredicate->isUnconditional()) - return false; + if (LoopContinuePredicate && LoopContinuePredicate->isConditional() && + isImpliedCond(Pred, LHS, RHS, + LoopContinuePredicate->getCondition(), + LoopContinuePredicate->getSuccessor(0) != L->getHeader())) + return true; + + // Check conditions due to any @llvm.assume intrinsics. + for (auto &CI : AT->assumptions(F)) { + if (!DT->dominates(CI, Latch->getTerminator())) + continue; - return isImpliedCond(Pred, LHS, RHS, - LoopContinuePredicate->getCondition(), - LoopContinuePredicate->getSuccessor(0) != L->getHeader()); + if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false)) + return true; + } + + return false; } /// isLoopEntryGuardedByCond - Test whether entry to the loop is protected @@ -6335,6 +6716,8 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, // (interprocedural conditions notwithstanding). if (!L) return false; + if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true; + // Starting at the loop predecessor, climb up the predecessor chain, as long // as there are predecessors that can be found that have unique successors // leading to the original header. @@ -6355,6 +6738,15 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, return true; } + // Check conditions due to any @llvm.assume intrinsics. + for (auto &CI : AT->assumptions(F)) { + if (!DT->dominates(CI, L->getHeader())) + continue; + + if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false)) + return true; + } + return false; } @@ -6469,6 +6861,66 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, RHS, LHS, FoundLHS, FoundRHS); } + // Check if we can make progress by sharpening ranges. + if (FoundPred == ICmpInst::ICMP_NE && + (isa(FoundLHS) || isa(FoundRHS))) { + + const SCEVConstant *C = nullptr; + const SCEV *V = nullptr; + + if (isa(FoundLHS)) { + C = cast(FoundLHS); + V = FoundRHS; + } else { + C = cast(FoundRHS); + V = FoundLHS; + } + + // The guarding predicate tells us that C != V. If the known range + // of V is [C, t), we can sharpen the range to [C + 1, t). The + // range we consider has to correspond to same signedness as the + // predicate we're interested in folding. + + APInt Min = ICmpInst::isSigned(Pred) ? + getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin(); + + if (Min == C->getValue()->getValue()) { + // Given (V >= Min && V != Min) we conclude V >= (Min + 1). + // This is true even if (Min + 1) wraps around -- in case of + // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)). + + APInt SharperMin = Min + 1; + + switch (Pred) { + case ICmpInst::ICMP_SGE: + case ICmpInst::ICMP_UGE: + // We know V `Pred` SharperMin. If this implies LHS `Pred` + // RHS, we're done. + if (isImpliedCondOperands(Pred, LHS, RHS, V, + getConstant(SharperMin))) + return true; + + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_UGT: + // We know from the range information that (V `Pred` Min || + // V == Min). We know from the guarding condition that !(V + // == Min). This gives us + // + // V `Pred` Min || V == Min && !(V == Min) + // => V `Pred` Min + // + // If V `Pred` Min implies LHS `Pred` RHS, we're done. + + if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min))) + return true; + + default: + // No change + break; + } + } + } + // Check whether the actual condition is beyond sufficient. if (FoundPred == ICmpInst::ICMP_EQ) if (ICmpInst::isTrueWhenEqual(Pred)) @@ -6614,13 +7066,13 @@ const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, /// specified less-than comparison will execute. If not computable, return /// CouldNotCompute. /// -/// @param IsSubExpr is true when the LHS < RHS condition does not directly -/// control the branch. In this case, we can only compute an iteration count for -/// a subexpression that cannot overflow before evaluating true. +/// @param ControlsExit is true when the LHS < RHS condition directly controls +/// the branch (loops exits only if condition is true). In this case, we can use +/// NoWrapFlags to skip overflow checks. ScalarEvolution::ExitLimit ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, - bool IsSubExpr) { + bool ControlsExit) { // We handle only IV < Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); @@ -6631,7 +7083,7 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, if (!IV || IV->getLoop() != L || !IV->isAffine()) return getCouldNotCompute(); - bool NoWrap = !IsSubExpr && + bool NoWrap = ControlsExit && IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW); const SCEV *Stride = IV->getStepRecurrence(*this); @@ -6651,9 +7103,19 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, : ICmpInst::ICMP_ULT; const SCEV *Start = IV->getStart(); const SCEV *End = RHS; - if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) - End = IsSigned ? getSMaxExpr(RHS, Start) - : getUMaxExpr(RHS, Start); + if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) { + const SCEV *Diff = getMinusSCEV(RHS, Start); + // If we have NoWrap set, then we can assume that the increment won't + // overflow, in which case if RHS - Start is a constant, we don't need to + // do a max operation since we can just figure it out statically + if (NoWrap && isa(Diff)) { + APInt D = dyn_cast(Diff)->getValue()->getValue(); + if (D.isNegative()) + End = Start; + } else + End = IsSigned ? getSMaxExpr(RHS, Start) + : getUMaxExpr(RHS, Start); + } const SCEV *BECount = computeBECount(getMinusSCEV(End, Start), Stride, false); @@ -6684,13 +7146,13 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, if (isa(MaxBECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount, /*MustExit=*/true); + return ExitLimit(BECount, MaxBECount); } ScalarEvolution::ExitLimit ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, - bool IsSubExpr) { + bool ControlsExit) { // We handle only IV > Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); @@ -6701,7 +7163,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, if (!IV || IV->getLoop() != L || !IV->isAffine()) return getCouldNotCompute(); - bool NoWrap = !IsSubExpr && + bool NoWrap = ControlsExit && IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW); const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this)); @@ -6722,9 +7184,19 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const SCEV *Start = IV->getStart(); const SCEV *End = RHS; - if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) - End = IsSigned ? getSMinExpr(RHS, Start) - : getUMinExpr(RHS, Start); + if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) { + const SCEV *Diff = getMinusSCEV(RHS, Start); + // If we have NoWrap set, then we can assume that the increment won't + // overflow, in which case if RHS - Start is a constant, we don't need to + // do a max operation since we can just figure it out statically + if (NoWrap && isa(Diff)) { + APInt D = dyn_cast(Diff)->getValue()->getValue(); + if (!D.isNegative()) + End = Start; + } else + End = IsSigned ? getSMinExpr(RHS, Start) + : getUMinExpr(RHS, Start); + } const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false); @@ -6756,7 +7228,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, if (isa(MaxBECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount, /*MustExit=*/true); + return ExitLimit(BECount, MaxBECount); } /// getNumIterationsInRange - Return the number of iterations of this loop that @@ -6849,401 +7321,139 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // Pick the smallest positive root value. if (ConstantInt *CB = dyn_cast(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, - R1->getValue(), R2->getValue()))) { - if (CB->getZExtValue() == false) - std::swap(R1, R2); // R1 is the minimum root now. - - // Make sure the root is not off by one. The returned iteration should - // not be in the range, but the previous one should be. When solving - // for "X*X < 5", for example, we should not return a root of 2. - ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this, - R1->getValue(), - SE); - if (Range.contains(R1Val->getValue())) { - // The next iteration must be out of the range... - ConstantInt *NextVal = - ConstantInt::get(SE.getContext(), R1->getValue()->getValue()+1); - - R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); - if (!Range.contains(R1Val->getValue())) - return SE.getConstant(NextVal); - return SE.getCouldNotCompute(); // Something strange happened - } - - // If R1 was not in the range, then it is a good return value. Make - // sure that R1-1 WAS in the range though, just in case. - ConstantInt *NextVal = - ConstantInt::get(SE.getContext(), R1->getValue()->getValue()-1); - R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); - if (Range.contains(R1Val->getValue())) - return R1; - return SE.getCouldNotCompute(); // Something strange happened - } - } - } - - return SE.getCouldNotCompute(); -} - -namespace { -struct FindUndefs { - bool Found; - FindUndefs() : Found(false) {} - - bool follow(const SCEV *S) { - if (const SCEVUnknown *C = dyn_cast(S)) { - if (isa(C->getValue())) - Found = true; - } else if (const SCEVConstant *C = dyn_cast(S)) { - if (isa(C->getValue())) - Found = true; - } - - // Keep looking if we haven't found it yet. - return !Found; - } - bool isDone() const { - // Stop recursion if we have found an undef. - return Found; - } -}; -} - -// Return true when S contains at least an undef value. -static inline bool -containsUndefs(const SCEV *S) { - FindUndefs F; - SCEVTraversal ST(F); - ST.visitAll(S); - - return F.Found; -} - -namespace { -// Collect all steps of SCEV expressions. -struct SCEVCollectStrides { - ScalarEvolution &SE; - SmallVectorImpl &Strides; - - SCEVCollectStrides(ScalarEvolution &SE, SmallVectorImpl &S) - : SE(SE), Strides(S) {} - - bool follow(const SCEV *S) { - if (const SCEVAddRecExpr *AR = dyn_cast(S)) - Strides.push_back(AR->getStepRecurrence(SE)); - return true; - } - bool isDone() const { return false; } -}; - -// Collect all SCEVUnknown and SCEVMulExpr expressions. -struct SCEVCollectTerms { - SmallVectorImpl &Terms; - - SCEVCollectTerms(SmallVectorImpl &T) - : Terms(T) {} - - bool follow(const SCEV *S) { - if (isa(S) || isa(S)) { - if (!containsUndefs(S)) - Terms.push_back(S); - - // Stop recursion: once we collected a term, do not walk its operands. - return false; - } - - // Keep looking. - return true; - } - bool isDone() const { return false; } -}; -} - -/// Find parametric terms in this SCEVAddRecExpr. -void SCEVAddRecExpr::collectParametricTerms( - ScalarEvolution &SE, SmallVectorImpl &Terms) const { - SmallVector Strides; - SCEVCollectStrides StrideCollector(SE, Strides); - visitAll(this, StrideCollector); - - DEBUG({ - dbgs() << "Strides:\n"; - for (const SCEV *S : Strides) - dbgs() << *S << "\n"; - }); - - for (const SCEV *S : Strides) { - SCEVCollectTerms TermCollector(Terms); - visitAll(S, TermCollector); - } - - DEBUG({ - dbgs() << "Terms:\n"; - for (const SCEV *T : Terms) - dbgs() << *T << "\n"; - }); -} - -static const APInt srem(const SCEVConstant *C1, const SCEVConstant *C2) { - APInt A = C1->getValue()->getValue(); - APInt B = C2->getValue()->getValue(); - uint32_t ABW = A.getBitWidth(); - uint32_t BBW = B.getBitWidth(); - - if (ABW > BBW) - B = B.sext(ABW); - else if (ABW < BBW) - A = A.sext(BBW); - - return APIntOps::srem(A, B); -} - -static const APInt sdiv(const SCEVConstant *C1, const SCEVConstant *C2) { - APInt A = C1->getValue()->getValue(); - APInt B = C2->getValue()->getValue(); - uint32_t ABW = A.getBitWidth(); - uint32_t BBW = B.getBitWidth(); - - if (ABW > BBW) - B = B.sext(ABW); - else if (ABW < BBW) - A = A.sext(BBW); - - return APIntOps::sdiv(A, B); -} - -namespace { -struct FindSCEVSize { - int Size; - FindSCEVSize() : Size(0) {} - - bool follow(const SCEV *S) { - ++Size; - // Keep looking at all operands of S. - return true; - } - bool isDone() const { - return false; - } -}; -} - -// Returns the size of the SCEV S. -static inline int sizeOfSCEV(const SCEV *S) { - FindSCEVSize F; - SCEVTraversal ST(F); - ST.visitAll(S); - return F.Size; -} - -namespace { - -struct SCEVDivision : public SCEVVisitor { -public: - // Computes the Quotient and Remainder of the division of Numerator by - // Denominator. - static void divide(ScalarEvolution &SE, const SCEV *Numerator, - const SCEV *Denominator, const SCEV **Quotient, - const SCEV **Remainder) { - assert(Numerator && Denominator && "Uninitialized SCEV"); - - SCEVDivision D(SE, Numerator, Denominator); - - // Check for the trivial case here to avoid having to check for it in the - // rest of the code. - if (Numerator == Denominator) { - *Quotient = D.One; - *Remainder = D.Zero; - return; - } - - if (Numerator->isZero()) { - *Quotient = D.Zero; - *Remainder = D.Zero; - return; - } + R1->getValue(), R2->getValue()))) { + if (CB->getZExtValue() == false) + std::swap(R1, R2); // R1 is the minimum root now. - // Split the Denominator when it is a product. - if (const SCEVMulExpr *T = dyn_cast(Denominator)) { - const SCEV *Q, *R; - *Quotient = Numerator; - for (const SCEV *Op : T->operands()) { - divide(SE, *Quotient, Op, &Q, &R); - *Quotient = Q; + // Make sure the root is not off by one. The returned iteration should + // not be in the range, but the previous one should be. When solving + // for "X*X < 5", for example, we should not return a root of 2. + ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this, + R1->getValue(), + SE); + if (Range.contains(R1Val->getValue())) { + // The next iteration must be out of the range... + ConstantInt *NextVal = + ConstantInt::get(SE.getContext(), R1->getValue()->getValue()+1); - // Bail out when the Numerator is not divisible by one of the terms of - // the Denominator. - if (!R->isZero()) { - *Quotient = D.Zero; - *Remainder = Numerator; - return; + R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); + if (!Range.contains(R1Val->getValue())) + return SE.getConstant(NextVal); + return SE.getCouldNotCompute(); // Something strange happened } + + // If R1 was not in the range, then it is a good return value. Make + // sure that R1-1 WAS in the range though, just in case. + ConstantInt *NextVal = + ConstantInt::get(SE.getContext(), R1->getValue()->getValue()-1); + R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); + if (Range.contains(R1Val->getValue())) + return R1; + return SE.getCouldNotCompute(); // Something strange happened } - *Remainder = D.Zero; - return; } - - D.visit(Numerator); - *Quotient = D.Quotient; - *Remainder = D.Remainder; } - SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, const SCEV *Denominator) - : SE(S), Denominator(Denominator) { - Zero = SE.getConstant(Denominator->getType(), 0); - One = SE.getConstant(Denominator->getType(), 1); - - // By default, we don't know how to divide Expr by Denominator. - // Providing the default here simplifies the rest of the code. - Quotient = Zero; - Remainder = Numerator; - } + return SE.getCouldNotCompute(); +} - // Except in the trivial case described above, we do not know how to divide - // Expr by Denominator for the following functions with empty implementation. - void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {} - void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {} - void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {} - void visitUDivExpr(const SCEVUDivExpr *Numerator) {} - void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {} - void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} - void visitUnknown(const SCEVUnknown *Numerator) {} - void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} +namespace { +struct FindUndefs { + bool Found; + FindUndefs() : Found(false) {} - void visitConstant(const SCEVConstant *Numerator) { - if (const SCEVConstant *D = dyn_cast(Denominator)) { - Quotient = SE.getConstant(sdiv(Numerator, D)); - Remainder = SE.getConstant(srem(Numerator, D)); - return; + bool follow(const SCEV *S) { + if (const SCEVUnknown *C = dyn_cast(S)) { + if (isa(C->getValue())) + Found = true; + } else if (const SCEVConstant *C = dyn_cast(S)) { + if (isa(C->getValue())) + Found = true; } - } - void visitAddRecExpr(const SCEVAddRecExpr *Numerator) { - const SCEV *StartQ, *StartR, *StepQ, *StepR; - assert(Numerator->isAffine() && "Numerator should be affine"); - divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); - divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); - Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), - Numerator->getNoWrapFlags()); - Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), - Numerator->getNoWrapFlags()); + // Keep looking if we haven't found it yet. + return !Found; } + bool isDone() const { + // Stop recursion if we have found an undef. + return Found; + } +}; +} - void visitAddExpr(const SCEVAddExpr *Numerator) { - SmallVector Qs, Rs; - Type *Ty = Denominator->getType(); - - for (const SCEV *Op : Numerator->operands()) { - const SCEV *Q, *R; - divide(SE, Op, Denominator, &Q, &R); +// Return true when S contains at least an undef value. +static inline bool +containsUndefs(const SCEV *S) { + FindUndefs F; + SCEVTraversal ST(F); + ST.visitAll(S); - // Bail out if types do not match. - if (Ty != Q->getType() || Ty != R->getType()) { - Quotient = Zero; - Remainder = Numerator; - return; - } + return F.Found; +} - Qs.push_back(Q); - Rs.push_back(R); - } +namespace { +// Collect all steps of SCEV expressions. +struct SCEVCollectStrides { + ScalarEvolution &SE; + SmallVectorImpl &Strides; - if (Qs.size() == 1) { - Quotient = Qs[0]; - Remainder = Rs[0]; - return; - } + SCEVCollectStrides(ScalarEvolution &SE, SmallVectorImpl &S) + : SE(SE), Strides(S) {} - Quotient = SE.getAddExpr(Qs); - Remainder = SE.getAddExpr(Rs); + bool follow(const SCEV *S) { + if (const SCEVAddRecExpr *AR = dyn_cast(S)) + Strides.push_back(AR->getStepRecurrence(SE)); + return true; } + bool isDone() const { return false; } +}; - void visitMulExpr(const SCEVMulExpr *Numerator) { - SmallVector Qs; - Type *Ty = Denominator->getType(); - - bool FoundDenominatorTerm = false; - for (const SCEV *Op : Numerator->operands()) { - // Bail out if types do not match. - if (Ty != Op->getType()) { - Quotient = Zero; - Remainder = Numerator; - return; - } - - if (FoundDenominatorTerm) { - Qs.push_back(Op); - continue; - } - - // Check whether Denominator divides one of the product operands. - const SCEV *Q, *R; - divide(SE, Op, Denominator, &Q, &R); - if (!R->isZero()) { - Qs.push_back(Op); - continue; - } +// Collect all SCEVUnknown and SCEVMulExpr expressions. +struct SCEVCollectTerms { + SmallVectorImpl &Terms; - // Bail out if types do not match. - if (Ty != Q->getType()) { - Quotient = Zero; - Remainder = Numerator; - return; - } + SCEVCollectTerms(SmallVectorImpl &T) + : Terms(T) {} - FoundDenominatorTerm = true; - Qs.push_back(Q); - } + bool follow(const SCEV *S) { + if (isa(S) || isa(S)) { + if (!containsUndefs(S)) + Terms.push_back(S); - if (FoundDenominatorTerm) { - Remainder = Zero; - if (Qs.size() == 1) - Quotient = Qs[0]; - else - Quotient = SE.getMulExpr(Qs); - return; + // Stop recursion: once we collected a term, do not walk its operands. + return false; } - if (!isa(Denominator)) { - Quotient = Zero; - Remainder = Numerator; - return; - } + // Keep looking. + return true; + } + bool isDone() const { return false; } +}; +} - // The Remainder is obtained by replacing Denominator by 0 in Numerator. - ValueToValueMap RewriteMap; - RewriteMap[cast(Denominator)->getValue()] = - cast(Zero)->getValue(); - Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); +/// Find parametric terms in this SCEVAddRecExpr. +void SCEVAddRecExpr::collectParametricTerms( + ScalarEvolution &SE, SmallVectorImpl &Terms) const { + SmallVector Strides; + SCEVCollectStrides StrideCollector(SE, Strides); + visitAll(this, StrideCollector); - if (Remainder->isZero()) { - // The Quotient is obtained by replacing Denominator by 1 in Numerator. - RewriteMap[cast(Denominator)->getValue()] = - cast(One)->getValue(); - Quotient = - SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); - return; - } + DEBUG({ + dbgs() << "Strides:\n"; + for (const SCEV *S : Strides) + dbgs() << *S << "\n"; + }); - // Quotient is (Numerator - Remainder) divided by Denominator. - const SCEV *Q, *R; - const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); - if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) { - // This SCEV does not seem to simplify: fail the division here. - Quotient = Zero; - Remainder = Numerator; - return; - } - divide(SE, Diff, Denominator, &Q, &R); - assert(R == Zero && - "(Numerator - Remainder) should evenly divide Denominator"); - Quotient = Q; + for (const SCEV *S : Strides) { + SCEVCollectTerms TermCollector(Terms); + visitAll(S, TermCollector); } -private: - ScalarEvolution &SE; - const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One; -}; + DEBUG({ + dbgs() << "Terms:\n"; + for (const SCEV *T : Terms) + dbgs() << *T << "\n"; + }); } static bool findArrayDimensionsRec(ScalarEvolution &SE, @@ -7270,7 +7480,7 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE, for (const SCEV *&Term : Terms) { // Normalize the terms before the next call to findArrayDimensionsRec. const SCEV *Q, *R; - SCEVDivision::divide(SE, Term, Step, &Q, &R); + SCEVSDivision::divide(SE, Term, Step, &Q, &R); // Bail out when GCD does not evenly divide one of the terms. if (!R->isZero()) @@ -7407,7 +7617,7 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl &Terms, // Divide all terms by the element size. for (const SCEV *&Term : Terms) { const SCEV *Q, *R; - SCEVDivision::divide(SE, Term, ElementSize, &Q, &R); + SCEVSDivision::divide(SE, Term, ElementSize, &Q, &R); Term = Q; } @@ -7454,7 +7664,7 @@ void SCEVAddRecExpr::computeAccessFunctions( int Last = Sizes.size() - 1; for (int i = Last; i >= 0; i--) { const SCEV *Q, *R; - SCEVDivision::divide(SE, Res, Sizes[i], &Q, &R); + SCEVSDivision::divide(SE, Res, Sizes[i], &Q, &R); DEBUG({ dbgs() << "Res: " << *Res << "\n"; @@ -7609,7 +7819,7 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) { // that until everything else is done. if (U == Old) continue; - if (!Visited.insert(U)) + if (!Visited.insert(U).second) continue; if (PHINode *PN = dyn_cast(U)) SE->ConstantEvolutionLoopExitValue.erase(PN); @@ -7638,6 +7848,7 @@ ScalarEvolution::ScalarEvolution() bool ScalarEvolution::runOnFunction(Function &F) { this->F = &F; + AT = &getAnalysis(); LI = &getAnalysis(); DataLayoutPass *DLP = getAnalysisIfAvailable(); DL = DLP ? &DLP->getDataLayout() : nullptr; @@ -7678,6 +7889,7 @@ void ScalarEvolution::releaseMemory() { void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); + AU.addRequired(); AU.addRequiredTransitive(); AU.addRequiredTransitive(); AU.addRequired();