X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=8fb46dd883b9089c28eed9dd4185a4861c7695e0;hb=dd643f26c43d162e905a07bf0826680aa10f7161;hp=0aeecb76bcc0c24a99e74f22062063fb47d1f7a1;hpb=3e6307698084e7adfc10b739442ae29742beefd0;p=oota-llvm.git diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 0aeecb76bcc..8fb46dd883b 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -83,9 +83,6 @@ #include using namespace llvm; -STATISTIC(NumBruteForceEvaluations, - "Number of brute force evaluations needed to " - "calculate high-order polynomial exit values"); STATISTIC(NumArrayLenItCounts, "Number of trip counts computed with array length"); STATISTIC(NumTripCountsComputed, @@ -95,16 +92,14 @@ STATISTIC(NumTripCountsNotComputed, STATISTIC(NumBruteForceTripCountsComputed, "Number of loops with trip counts computed by force"); -cl::opt +static cl::opt MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant derived loop"), cl::init(100)); -namespace { - RegisterPass - R("scalar-evolution", "Scalar Evolution Analysis"); -} +static RegisterPass +R("scalar-evolution", "Scalar Evolution Analysis", false, true); char ScalarEvolution::ID = 0; //===----------------------------------------------------------------------===// @@ -119,21 +114,18 @@ void SCEV::dump() const { print(cerr); } -/// getValueRange - Return the tightest constant bounds that this value is -/// known to have. This method is only valid on integer SCEV objects. -ConstantRange SCEV::getValueRange() const { - const Type *Ty = getType(); - assert(Ty->isInteger() && "Can't get range for a non-integer SCEV!"); - // Default to a full range if no better information is available. - return ConstantRange(getBitWidth()); -} - uint32_t SCEV::getBitWidth() const { if (const IntegerType* ITy = dyn_cast(getType())) return ITy->getBitWidth(); return 0; } +bool SCEV::isZero() const { + if (const SCEVConstant *SC = dyn_cast(this)) + return SC->getValue()->isZero(); + return false; +} + SCEVCouldNotCompute::SCEVCouldNotCompute() : SCEV(scCouldNotCompute) {} @@ -188,10 +180,6 @@ SCEVHandle ScalarEvolution::getConstant(const APInt& Val) { return getConstant(ConstantInt::get(Val)); } -ConstantRange SCEVConstant::getValueRange() const { - return ConstantRange(V->getValue()); -} - const Type *SCEVConstant::getType() const { return V->getType(); } void SCEVConstant::print(std::ostream &OS) const { @@ -216,10 +204,6 @@ SCEVTruncateExpr::~SCEVTruncateExpr() { SCEVTruncates->erase(std::make_pair(Op, Ty)); } -ConstantRange SCEVTruncateExpr::getValueRange() const { - return getOperand()->getValueRange().truncate(getBitWidth()); -} - void SCEVTruncateExpr::print(std::ostream &OS) const { OS << "(truncate " << *Op << " to " << *Ty << ")"; } @@ -242,10 +226,6 @@ SCEVZeroExtendExpr::~SCEVZeroExtendExpr() { SCEVZeroExtends->erase(std::make_pair(Op, Ty)); } -ConstantRange SCEVZeroExtendExpr::getValueRange() const { - return getOperand()->getValueRange().zeroExtend(getBitWidth()); -} - void SCEVZeroExtendExpr::print(std::ostream &OS) const { OS << "(zeroextend " << *Op << " to " << *Ty << ")"; } @@ -268,10 +248,6 @@ SCEVSignExtendExpr::~SCEVSignExtendExpr() { SCEVSignExtends->erase(std::make_pair(Op, Ty)); } -ConstantRange SCEVSignExtendExpr::getValueRange() const { - return getOperand()->getValueRange().signExtend(getBitWidth()); -} - void SCEVSignExtendExpr::print(std::ostream &OS) const { OS << "(signextend " << *Op << " to " << *Ty << ")"; } @@ -431,7 +407,7 @@ namespace { /// than the complexity of the RHS. This comparator is used to canonicalize /// expressions. struct VISIBILITY_HIDDEN SCEVComplexityCompare { - bool operator()(SCEV *LHS, SCEV *RHS) { + bool operator()(const SCEV *LHS, const SCEV *RHS) const { return LHS->getSCEVType() < RHS->getSCEVType(); } }; @@ -452,7 +428,7 @@ static void GroupByComplexity(std::vector &Ops) { if (Ops.size() == 2) { // This is the common case, which also happens to be trivially simple. // Special case it. - if (Ops[0]->getSCEVType() > Ops[1]->getSCEVType()) + if (SCEVComplexityCompare()(Ops[1], Ops[0])) std::swap(Ops[0], Ops[1]); return; } @@ -494,35 +470,20 @@ SCEVHandle ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) { if (Val == 0) C = Constant::getNullValue(Ty); else if (Ty->isFloatingPoint()) - C = ConstantFP::get(Ty, APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle : - APFloat::IEEEdouble, Val)); + C = ConstantFP::get(APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle : + APFloat::IEEEdouble, Val)); else C = ConstantInt::get(Ty, Val); return getUnknown(C); } -/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the -/// input value to the specified type. If the type must be extended, it is zero -/// extended. -static SCEVHandle getTruncateOrZeroExtend(const SCEVHandle &V, const Type *Ty, - ScalarEvolution &SE) { - const Type *SrcTy = V->getType(); - assert(SrcTy->isInteger() && Ty->isInteger() && - "Cannot truncate or zero extend with non-integer arguments!"); - if (SrcTy->getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits()) - return V; // No conversion - if (SrcTy->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits()) - return SE.getTruncateExpr(V, Ty); - return SE.getZeroExtendExpr(V, Ty); -} - /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V /// SCEVHandle ScalarEvolution::getNegativeSCEV(const SCEVHandle &V) { if (SCEVConstant *VC = dyn_cast(V)) return getUnknown(ConstantExpr::getNeg(VC->getValue())); - return getMulExpr(V, getUnknown(ConstantInt::getAllOnesValue(V->getType()))); + return getMulExpr(V, getConstant(ConstantInt::getAllOnesValue(V->getType()))); } /// getNotSCEV - Return a SCEV corresponding to ~V = -1-V @@ -530,7 +491,7 @@ SCEVHandle ScalarEvolution::getNotSCEV(const SCEVHandle &V) { if (SCEVConstant *VC = dyn_cast(V)) return getUnknown(ConstantExpr::getNot(VC->getValue())); - SCEVHandle AllOnes = getUnknown(ConstantInt::getAllOnesValue(V->getType())); + SCEVHandle AllOnes = getConstant(ConstantInt::getAllOnesValue(V->getType())); return getMinusSCEV(AllOnes, V); } @@ -543,85 +504,125 @@ SCEVHandle ScalarEvolution::getMinusSCEV(const SCEVHandle &LHS, } -/// BinomialCoefficient - Compute BC(It, K). The result is of the same type as -/// It. Assume, K > 0. +/// BinomialCoefficient - Compute BC(It, K). The result has width W. +// Assume, K > 0. static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K, - ScalarEvolution &SE) { + ScalarEvolution &SE, + const IntegerType* ResultTy) { + // Handle the simplest case efficiently. + if (K == 1) + return SE.getTruncateOrZeroExtend(It, ResultTy); + // We are using the following formula for BC(It, K): // // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K! // - // Suppose, W is the bitwidth of It (and of the return value as well). We - // must be prepared for overflow. Hence, we must assure that the result of - // our computation is equal to the accurate one modulo 2^W. Unfortunately, - // division isn't safe in modular arithmetic. This means we must perform the - // whole computation accurately and then truncate the result to W bits. + // Suppose, W is the bitwidth of the return value. We must be prepared for + // overflow. Hence, we must assure that the result of our computation is + // equal to the accurate one modulo 2^W. Unfortunately, division isn't + // safe in modular arithmetic. + // + // However, this code doesn't use exactly that formula; the formula it uses + // is something like the following, where T is the number of factors of 2 in + // K! (i.e. trailing zeros in the binary representation of K!), and ^ is + // exponentiation: + // + // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T) // - // The dividend of the formula is a multiplication of K integers of bitwidth - // W. K*W bits suffice to compute it accurately. + // This formula is trivially equivalent to the previous formula. However, + // this formula can be implemented much more efficiently. The trick is that + // K! / 2^T is odd, and exact division by an odd number *is* safe in modular + // arithmetic. To do exact division in modular arithmetic, all we have + // to do is multiply by the inverse. Therefore, this step can be done at + // width W. + // + // The next issue is how to safely do the division by 2^T. The way this + // is done is by doing the multiplication step at a width of at least W + T + // bits. This way, the bottom W+T bits of the product are accurate. Then, + // when we perform the division by 2^T (which is equivalent to a right shift + // by T), the bottom W bits are accurate. Extra bits are okay; they'll get + // truncated out after the division by 2^T. // - // FIXME: We assume the divisor can be accurately computed using 16-bit - // unsigned integer type. It is true up to K = 8 (AddRecs of length 9). In - // future we may use APInt to use the minimum number of bits necessary to - // compute it accurately. + // In comparison to just directly using the first formula, this technique + // is much more efficient; using the first formula requires W * K bits, + // but this formula less than W + K bits. Also, the first formula requires + // a division step, whereas this formula only requires multiplies and shifts. // - // It is safe to use unsigned division here: the dividend is nonnegative and - // the divisor is positive. + // It doesn't matter whether the subtraction step is done in the calculation + // width or the input iteration count's width; if the subtraction overflows, + // the result must be zero anyway. We prefer here to do it in the width of + // the induction variable because it helps a lot for certain cases; CodeGen + // isn't smart enough to ignore the overflow, which leads to much less + // efficient code if the width of the subtraction is wider than the native + // register width. + // + // (It's possible to not widen at all by pulling out factors of 2 before + // the multiplication; for example, K=2 can be calculated as + // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires + // extra arithmetic, so it's not an obvious win, and it gets + // much more complicated for K > 3.) + + // Protection from insane SCEVs; this bound is conservative, + // but it probably doesn't matter. + if (K > 1000) + return new SCEVCouldNotCompute(); - // Handle the simplest case efficiently. - if (K == 1) - return It; + unsigned W = ResultTy->getBitWidth(); + + // Calculate K! / 2^T and T; we divide out the factors of two before + // multiplying for calculating K! / 2^T to avoid overflow. + // Other overflow doesn't matter because we only care about the bottom + // W bits of the result. + APInt OddFactorial(W, 1); + unsigned T = 1; + for (unsigned i = 3; i <= K; ++i) { + APInt Mult(W, i); + unsigned TwoFactors = Mult.countTrailingZeros(); + T += TwoFactors; + Mult = Mult.lshr(TwoFactors); + OddFactorial *= Mult; + } - assert(K < 9 && "We cannot handle such long AddRecs yet."); - - // FIXME: A temporary hack to remove in future. Arbitrary precision integers - // aren't supported by the code generator yet. For the dividend, the bitwidth - // we use is the smallest power of 2 greater or equal to K*W and less or equal - // to 64. Note that setting the upper bound for bitwidth may still lead to - // miscompilation in some cases. - unsigned DividendBits = 1U << Log2_32_Ceil(K * It->getBitWidth()); - if (DividendBits > 64) - DividendBits = 64; -#if 0 // Waiting for the APInt support in the code generator... - unsigned DividendBits = K * It->getBitWidth(); -#endif + // We need at least W + T bits for the multiplication step + // FIXME: A temporary hack; we round up the bitwidths + // to the nearest power of 2 to be nice to the code generator. + unsigned CalculationBits = 1U << Log2_32_Ceil(W + T); + // FIXME: Temporary hack to avoid generating integers that are too wide. + // Although, it's not completely clear how to determine how much + // widening is safe; for example, on X86, we can't really widen + // beyond 64 because we need to be able to do multiplication + // that's CalculationBits wide, but on X86-64, we can safely widen up to + // 128 bits. + if (CalculationBits > 64) + return new SCEVCouldNotCompute(); - const IntegerType *DividendTy = IntegerType::get(DividendBits); - const SCEVHandle ExIt = SE.getZeroExtendExpr(It, DividendTy); - - // The final number of bits we need to perform the division is the maximum of - // dividend and divisor bitwidths. - const IntegerType *DivisionTy = - IntegerType::get(std::max(DividendBits, 16U)); - - // Compute K! We know K >= 2 here. - unsigned F = 2; - for (unsigned i = 3; i <= K; ++i) - F *= i; - APInt Divisor(DivisionTy->getBitWidth(), F); - - // Handle this case efficiently, it is common to have constant iteration - // counts while computing loop exit values. - if (SCEVConstant *SC = dyn_cast(ExIt)) { - const APInt& N = SC->getValue()->getValue(); - APInt Dividend(N.getBitWidth(), 1); - for (; K; --K) - Dividend *= N-(K-1); - if (DividendTy != DivisionTy) - Dividend = Dividend.zext(DivisionTy->getBitWidth()); - return SE.getConstant(Dividend.udiv(Divisor).trunc(It->getBitWidth())); + // Calcuate 2^T, at width T+W. + APInt DivFactor = APInt(CalculationBits, 1).shl(T); + + // Calculate the multiplicative inverse of K! / 2^T; + // this multiplication factor will perform the exact division by + // K! / 2^T. + APInt Mod = APInt::getSignedMinValue(W+1); + APInt MultiplyFactor = OddFactorial.zext(W+1); + MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod); + MultiplyFactor = MultiplyFactor.trunc(W); + + // Calculate the product, at width T+W + const IntegerType *CalculationTy = IntegerType::get(CalculationBits); + SCEVHandle Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy); + for (unsigned i = 1; i != K; ++i) { + SCEVHandle S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType())); + Dividend = SE.getMulExpr(Dividend, + SE.getTruncateOrZeroExtend(S, CalculationTy)); } - - SCEVHandle Dividend = ExIt; - for (unsigned i = 1; i != K; ++i) - Dividend = - SE.getMulExpr(Dividend, - SE.getMinusSCEV(ExIt, SE.getIntegerSCEV(i, DividendTy))); - if (DividendTy != DivisionTy) - Dividend = SE.getZeroExtendExpr(Dividend, DivisionTy); - return - SE.getTruncateExpr(SE.getUDivExpr(Dividend, SE.getConstant(Divisor)), - It->getType()); + + // Divide by 2^T + SCEVHandle DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor)); + + // Truncate the result, and divide by K! / 2^T. + + return SE.getMulExpr(SE.getConstant(MultiplyFactor), + SE.getTruncateOrZeroExtend(DivResult, ResultTy)); } /// evaluateAtIteration - Return the value of this chain of recurrences at @@ -640,9 +641,12 @@ SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It, // The computation is correct in the face of overflow provided that the // multiplication is performed _after_ the evaluation of the binomial // coefficient. - SCEVHandle Val = SE.getMulExpr(getOperand(i), - BinomialCoefficient(It, i, SE)); - Result = SE.getAddExpr(Result, Val); + SCEVHandle Coeff = BinomialCoefficient(It, i, SE, + cast(getType())); + if (isa(Coeff)) + return Coeff; + + Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff)); } return Result; } @@ -705,6 +709,21 @@ SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op, const Type * return Result; } +/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion +/// of the input value to the specified type. If the type must be +/// extended, it is zero extended. +SCEVHandle ScalarEvolution::getTruncateOrZeroExtend(const SCEVHandle &V, + const Type *Ty) { + const Type *SrcTy = V->getType(); + assert(SrcTy->isInteger() && Ty->isInteger() && + "Cannot truncate or zero extend with non-integer arguments!"); + if (SrcTy->getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits()) + return V; // No conversion + if (SrcTy->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits()) + return getTruncateExpr(V, Ty); + return getZeroExtendExpr(V, Ty); +} + // get - Get a canonical add expression, or something simpler if possible. SCEVHandle ScalarEvolution::getAddExpr(std::vector &Ops) { assert(!Ops.empty() && "Cannot get empty add!"); @@ -870,7 +889,7 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector &Ops) { // If we found some loop invariants, fold them into the recurrence. if (!LIOps.empty()) { - // NLI + LI + { Start,+,Step} --> NLI + { LI+Start,+,Step } + // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step} LIOps.push_back(AddRec->getStart()); std::vector AddRecOps(AddRec->op_begin(), AddRec->op_end()); @@ -1018,7 +1037,7 @@ SCEVHandle ScalarEvolution::getMulExpr(std::vector &Ops) { // If we found some loop invariants, fold them into the recurrence. if (!LIOps.empty()) { - // NLI * LI * { Start,+,Step} --> NLI * { LI*Start,+,LI*Step } + // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step} std::vector NewOps; NewOps.reserve(AddRec->getNumOperands()); if (LIOps.size() == 1) { @@ -1132,11 +1151,23 @@ SCEVHandle ScalarEvolution::getAddRecExpr(std::vector &Operands, const Loop *L) { if (Operands.size() == 1) return Operands[0]; - if (SCEVConstant *StepC = dyn_cast(Operands.back())) - if (StepC->getValue()->isZero()) { - Operands.pop_back(); - return getAddRecExpr(Operands, L); // { X,+,0 } --> X + if (Operands.back()->isZero()) { + Operands.pop_back(); + return getAddRecExpr(Operands, L); // {X,+,0} --> X + } + + // Canonicalize nested AddRecs in by nesting them in order of loop depth. + if (SCEVAddRecExpr *NestedAR = dyn_cast(Operands[0])) { + const Loop* NestedLoop = NestedAR->getLoop(); + if (L->getLoopDepth() < NestedLoop->getLoopDepth()) { + std::vector NestedOperands(NestedAR->op_begin(), + NestedAR->op_end()); + SCEVHandle NestedARHandle(NestedAR); + Operands[0] = NestedAR->getStart(); + NestedOperands[0] = getAddRecExpr(Operands, L); + return getAddRecExpr(NestedOperands, NestedLoop); } + } SCEVAddRecExpr *&Result = (*SCEVAddRecExprs)[std::make_pair(L, std::vector(Operands.begin(), @@ -1372,6 +1403,7 @@ namespace { void setSCEV(Value *V, const SCEVHandle &H) { bool isNew = Scalars.insert(std::make_pair(V, H)).second; assert(isNew && "This entry already existed!"); + isNew = false; } @@ -1445,7 +1477,23 @@ namespace { /// specified less-than comparison will execute. If not computable, return /// UnknownValue. isSigned specifies whether the less-than is signed. SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, - bool isSigned); + bool isSigned, bool trueWhenEqual); + + /// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB + /// (which may not be an immediate predecessor) which has exactly one + /// successor from which BB is reachable, or null if no such block is + /// found. + BasicBlock* getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB); + + /// executesAtLeastOnce - Test whether entry to the loop is protected by + /// a conditional between LHS and RHS. + bool executesAtLeastOnce(const Loop *L, bool isSigned, bool trueWhenEqual, + SCEV *LHS, SCEV *RHS); + + /// potentialInfiniteLoop - Test whether the loop might jump over the exit value + /// due to wrapping. + bool potentialInfiniteLoop(SCEV *Stride, SCEV *RHS, bool isSigned, + bool trueWhenEqual); /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is /// in the header of its containing loop, we know the loop executes a @@ -1695,118 +1743,137 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { if (!isa(V->getType())) return SE.getUnknown(V); - if (Instruction *I = dyn_cast(V)) { - switch (I->getOpcode()) { - case Instruction::Add: - return SE.getAddExpr(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - case Instruction::Mul: - return SE.getMulExpr(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - case Instruction::UDiv: - return SE.getUDivExpr(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - case Instruction::Sub: - return SE.getMinusSCEV(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - case Instruction::Or: - // If the RHS of the Or is a constant, we may have something like: - // X*4+1 which got turned into X*4|1. Handle this as an Add so loop - // optimizations will transparently handle this case. - // - // In order for this transformation to be safe, the LHS must be of the - // form X*(2^n) and the Or constant must be less than 2^n. - if (ConstantInt *CI = dyn_cast(I->getOperand(1))) { - SCEVHandle LHS = getSCEV(I->getOperand(0)); - const APInt &CIVal = CI->getValue(); - if (GetMinTrailingZeros(LHS) >= - (CIVal.getBitWidth() - CIVal.countLeadingZeros())) - return SE.getAddExpr(LHS, getSCEV(I->getOperand(1))); - } - break; - case Instruction::Xor: + unsigned Opcode = Instruction::UserOp1; + if (Instruction *I = dyn_cast(V)) + Opcode = I->getOpcode(); + else if (ConstantExpr *CE = dyn_cast(V)) + Opcode = CE->getOpcode(); + else + return SE.getUnknown(V); + + User *U = cast(V); + switch (Opcode) { + case Instruction::Add: + return SE.getAddExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::Mul: + return SE.getMulExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::UDiv: + return SE.getUDivExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::Sub: + return SE.getMinusSCEV(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::Or: + // If the RHS of the Or is a constant, we may have something like: + // X*4+1 which got turned into X*4|1. Handle this as an Add so loop + // optimizations will transparently handle this case. + // + // In order for this transformation to be safe, the LHS must be of the + // form X*(2^n) and the Or constant must be less than 2^n. + if (ConstantInt *CI = dyn_cast(U->getOperand(1))) { + SCEVHandle LHS = getSCEV(U->getOperand(0)); + const APInt &CIVal = CI->getValue(); + if (GetMinTrailingZeros(LHS) >= + (CIVal.getBitWidth() - CIVal.countLeadingZeros())) + return SE.getAddExpr(LHS, getSCEV(U->getOperand(1))); + } + break; + case Instruction::Xor: + if (ConstantInt *CI = dyn_cast(U->getOperand(1))) { // If the RHS of the xor is a signbit, then this is just an add. // Instcombine turns add of signbit into xor as a strength reduction step. - if (ConstantInt *CI = dyn_cast(I->getOperand(1))) { - if (CI->getValue().isSignBit()) - return SE.getAddExpr(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - else if (CI->isAllOnesValue()) - return SE.getNotSCEV(getSCEV(I->getOperand(0))); - } - break; + if (CI->getValue().isSignBit()) + return SE.getAddExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); - case Instruction::Shl: - // Turn shift left of a constant amount into a multiply. - if (ConstantInt *SA = dyn_cast(I->getOperand(1))) { - uint32_t BitWidth = cast(V->getType())->getBitWidth(); - Constant *X = ConstantInt::get( - APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth))); - return SE.getMulExpr(getSCEV(I->getOperand(0)), getSCEV(X)); - } - break; + // If the RHS of xor is -1, then this is a not operation. + else if (CI->isAllOnesValue()) + return SE.getNotSCEV(getSCEV(U->getOperand(0))); + } + break; - case Instruction::Trunc: - return SE.getTruncateExpr(getSCEV(I->getOperand(0)), I->getType()); + case Instruction::Shl: + // Turn shift left of a constant amount into a multiply. + if (ConstantInt *SA = dyn_cast(U->getOperand(1))) { + uint32_t BitWidth = cast(V->getType())->getBitWidth(); + Constant *X = ConstantInt::get( + APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth))); + return SE.getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X)); + } + break; - case Instruction::ZExt: - return SE.getZeroExtendExpr(getSCEV(I->getOperand(0)), I->getType()); + case Instruction::LShr: + // Turn logical shift right of a constant into a unsigned divide. + if (ConstantInt *SA = dyn_cast(U->getOperand(1))) { + uint32_t BitWidth = cast(V->getType())->getBitWidth(); + Constant *X = ConstantInt::get( + APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth))); + return SE.getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X)); + } + break; - case Instruction::SExt: - return SE.getSignExtendExpr(getSCEV(I->getOperand(0)), I->getType()); + case Instruction::Trunc: + return SE.getTruncateExpr(getSCEV(U->getOperand(0)), U->getType()); - case Instruction::BitCast: - // BitCasts are no-op casts so we just eliminate the cast. - if (I->getType()->isInteger() && - I->getOperand(0)->getType()->isInteger()) - return getSCEV(I->getOperand(0)); - break; + case Instruction::ZExt: + return SE.getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType()); - case Instruction::PHI: - return createNodeForPHI(cast(I)); - - case Instruction::Select: - // This could be a smax or umax that was lowered earlier. - // Try to recover it. - if (ICmpInst *ICI = dyn_cast(I->getOperand(0))) { - Value *LHS = ICI->getOperand(0); - Value *RHS = ICI->getOperand(1); - switch (ICI->getPredicate()) { - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: - std::swap(LHS, RHS); - // fall through - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - if (LHS == I->getOperand(1) && RHS == I->getOperand(2)) - return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS)); - else if (LHS == I->getOperand(2) && RHS == I->getOperand(1)) - // -smax(-x, -y) == smin(x, y). - return SE.getNegativeSCEV(SE.getSMaxExpr( - SE.getNegativeSCEV(getSCEV(LHS)), - SE.getNegativeSCEV(getSCEV(RHS)))); - break; - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: - std::swap(LHS, RHS); - // fall through - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: - if (LHS == I->getOperand(1) && RHS == I->getOperand(2)) - return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS)); - else if (LHS == I->getOperand(2) && RHS == I->getOperand(1)) - // ~umax(~x, ~y) == umin(x, y) - return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)), - SE.getNotSCEV(getSCEV(RHS)))); - break; - default: - break; - } - } + case Instruction::SExt: + return SE.getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType()); - default: // We cannot analyze this expression. - break; + case Instruction::BitCast: + // BitCasts are no-op casts so we just eliminate the cast. + if (U->getType()->isInteger() && + U->getOperand(0)->getType()->isInteger()) + return getSCEV(U->getOperand(0)); + break; + + case Instruction::PHI: + return createNodeForPHI(cast(U)); + + case Instruction::Select: + // This could be a smax or umax that was lowered earlier. + // Try to recover it. + if (ICmpInst *ICI = dyn_cast(U->getOperand(0))) { + Value *LHS = ICI->getOperand(0); + Value *RHS = ICI->getOperand(1); + switch (ICI->getPredicate()) { + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + if (LHS == U->getOperand(1) && RHS == U->getOperand(2)) + return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS)); + else if (LHS == U->getOperand(2) && RHS == U->getOperand(1)) + // ~smax(~x, ~y) == smin(x, y). + return SE.getNotSCEV(SE.getSMaxExpr( + SE.getNotSCEV(getSCEV(LHS)), + SE.getNotSCEV(getSCEV(RHS)))); + break; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + if (LHS == U->getOperand(1) && RHS == U->getOperand(2)) + return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS)); + else if (LHS == U->getOperand(2) && RHS == U->getOperand(1)) + // ~umax(~x, ~y) == umin(x, y) + return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)), + SE.getNotSCEV(getSCEV(RHS)))); + break; + default: + break; + } } + + default: // We cannot analyze this expression. + break; } return SE.getUnknown(V); @@ -1886,7 +1953,7 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { ICmpInst *ExitCond = dyn_cast(ExitBr->getCondition()); - // If its not an integer comparison then compute it the hard way. + // If it's not an integer comparison then compute it the hard way. // Note that ICmpInst deals with pointer comparisons too so we must check // the type of the operand. if (ExitCond == 0 || isa(ExitCond->getOperand(0)->getType())) @@ -1964,24 +2031,46 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { break; } case ICmpInst::ICMP_SLT: { - SCEVHandle TC = HowManyLessThans(LHS, RHS, L, true); + SCEVHandle TC = HowManyLessThans(LHS, RHS, L, true, false); if (!isa(TC)) return TC; break; } case ICmpInst::ICMP_SGT: { - SCEVHandle TC = HowManyLessThans(SE.getNegativeSCEV(LHS), - SE.getNegativeSCEV(RHS), L, true); + SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS), + SE.getNotSCEV(RHS), L, true, false); if (!isa(TC)) return TC; break; } case ICmpInst::ICMP_ULT: { - SCEVHandle TC = HowManyLessThans(LHS, RHS, L, false); + SCEVHandle TC = HowManyLessThans(LHS, RHS, L, false, false); if (!isa(TC)) return TC; break; } case ICmpInst::ICMP_UGT: { - SCEVHandle TC = HowManyLessThans(SE.getNegativeSCEV(LHS), - SE.getNegativeSCEV(RHS), L, false); + SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS), + SE.getNotSCEV(RHS), L, false, false); + if (!isa(TC)) return TC; + break; + } + case ICmpInst::ICMP_SLE: { + SCEVHandle TC = HowManyLessThans(LHS, RHS, L, true, true); + if (!isa(TC)) return TC; + break; + } + case ICmpInst::ICMP_SGE: { + SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS), + SE.getNotSCEV(RHS), L, true, true); + if (!isa(TC)) return TC; + break; + } + case ICmpInst::ICMP_ULE: { + SCEVHandle TC = HowManyLessThans(LHS, RHS, L, false, true); + if (!isa(TC)) return TC; + break; + } + case ICmpInst::ICMP_UGE: { + SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS), + SE.getNotSCEV(RHS), L, false, true); if (!isa(TC)) return TC; break; } @@ -2045,7 +2134,7 @@ GetAddressedElementFromGlobal(GlobalVariable *GV, } /// ComputeLoadConstantCompareIterationCount - Given an exit condition of -/// 'icmp op load X, cst', try to se if we can compute the trip count. +/// 'icmp op load X, cst', try to see if we can compute the trip count. SCEVHandle ScalarEvolutionsImpl:: ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS, const Loop *L, @@ -2145,13 +2234,14 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { Instruction *I = dyn_cast(V); if (I == 0 || !L->contains(I->getParent())) return 0; - if (PHINode *PN = dyn_cast(I)) + if (PHINode *PN = dyn_cast(I)) { if (L->getHeader() == I->getParent()) return PN; else // We don't currently keep track of the control flow needed to evaluate // PHIs, so we cannot handle PHIs inside of loops. return 0; + } // If we won't be able to constant fold this expression even if the operands // are constants, return early. @@ -2181,8 +2271,6 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { /// reason, return null. static Constant *EvaluateExpression(Value *V, Constant *PHIVal) { if (isa(V)) return PHIVal; - if (GlobalValue *GV = dyn_cast(V)) - return GV; if (Constant *C = dyn_cast(V)) return C; Instruction *I = cast(V); @@ -2435,17 +2523,8 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { // loop iterates. Compute this now. SCEVHandle IterationCount = getIterationCount(AddRec->getLoop()); if (IterationCount == UnknownValue) return UnknownValue; - IterationCount = getTruncateOrZeroExtend(IterationCount, - AddRec->getType(), SE); - - // If the value is affine, simplify the expression evaluation to just - // Start + Step*IterationCount. - if (AddRec->isAffine()) - return SE.getAddExpr(AddRec->getStart(), - SE.getMulExpr(IterationCount, - AddRec->getOperand(1))); - // Otherwise, evaluate it the hard way. + // Then, evaluate the AddRec. return AddRec->evaluateAtIteration(IterationCount, SE); } return UnknownValue; @@ -2455,6 +2534,53 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { return UnknownValue; } +/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the +/// following equation: +/// +/// A * X = B (mod N) +/// +/// where N = 2^BW and BW is the common bit width of A and B. The signedness of +/// A and B isn't important. +/// +/// If the equation does not have a solution, SCEVCouldNotCompute is returned. +static SCEVHandle SolveLinEquationWithOverflow(const APInt &A, const APInt &B, + ScalarEvolution &SE) { + uint32_t BW = A.getBitWidth(); + assert(BW == B.getBitWidth() && "Bit widths must be the same."); + assert(A != 0 && "A must be non-zero."); + + // 1. D = gcd(A, N) + // + // The gcd of A and N may have only one prime factor: 2. The number of + // trailing zeros in A is its multiplicity + uint32_t Mult2 = A.countTrailingZeros(); + // D = 2^Mult2 + + // 2. Check if B is divisible by D. + // + // B is divisible by D if and only if the multiplicity of prime factor 2 for B + // is not less than multiplicity of this prime factor for D. + if (B.countTrailingZeros() < Mult2) + return new SCEVCouldNotCompute(); + + // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic + // modulo (N / D). + // + // (N / D) may need BW+1 bits in its representation. Hence, we'll use this + // bit width during computations. + APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D + APInt Mod(BW + 1, 0); + Mod.set(BW - Mult2); // Mod = N / D + APInt I = AD.multiplicativeInverse(Mod); + + // 4. Compute the minimum unsigned root of the equation: + // I * (B / D) mod (N / D) + APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod); + + // The result is guaranteed to be less than 2^BW so we may truncate it to BW + // bits. + return SE.getConstant(Result.trunc(BW)); +} /// SolveQuadraticEquation - Find the roots of the quadratic equation for the /// given quadratic chrec {L,+,M,+,N}. This returns either the two roots (which @@ -2504,6 +2630,11 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { // The divisions must be performed as signed divisions. APInt NegB(-B); APInt TwoA( A << 1 ); + if (TwoA.isMinValue()) { + SCEV *CNC = new SCEVCouldNotCompute(); + return std::make_pair(CNC, CNC); + } + ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA)); ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA)); @@ -2527,36 +2658,36 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) { return UnknownValue; if (AddRec->isAffine()) { - // If this is an affine expression the execution count of this branch is - // equal to: + // If this is an affine expression, the execution count of this branch is + // the minimum unsigned root of the following equation: + // + // Start + Step*N = 0 (mod 2^BW) + // + // equivalent to: // - // (0 - Start/Step) iff Start % Step == 0 + // Step*N = -Start (mod 2^BW) // + // where BW is the common bit width of Start and Step. + // Get the initial value for the loop. SCEVHandle Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop()); if (isa(Start)) return UnknownValue; - SCEVHandle Step = AddRec->getOperand(1); - Step = getSCEVAtScope(Step, L->getParentLoop()); + SCEVHandle Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop()); - // Figure out if Start % Step == 0. - // FIXME: We should add DivExpr and RemExpr operations to our AST. if (SCEVConstant *StepC = dyn_cast(Step)) { - if (StepC->getValue()->equalsInt(1)) // N % 1 == 0 - return SE.getNegativeSCEV(Start); // 0 - Start/1 == -Start - if (StepC->getValue()->isAllOnesValue()) // N % -1 == 0 - return Start; // 0 - Start/-1 == Start - - // Check to see if Start is divisible by SC with no remainder. - if (SCEVConstant *StartC = dyn_cast(Start)) { - ConstantInt *StartCC = StartC->getValue(); - Constant *StartNegC = ConstantExpr::getNeg(StartCC); - Constant *Rem = ConstantExpr::getSRem(StartNegC, StepC->getValue()); - if (Rem->isNullValue()) { - Constant *Result =ConstantExpr::getSDiv(StartNegC,StepC->getValue()); - return SE.getUnknown(Result); - } - } + // For now we handle only constant steps. + + // First, handle unitary steps. + if (StepC->getValue()->equalsInt(1)) // 1*N = -Start (mod 2^BW), so: + return SE.getNegativeSCEV(Start); // N = -Start (as unsigned) + if (StepC->getValue()->isAllOnesValue()) // -1*N = -Start (mod 2^BW), so: + return Start; // N = Start (as unsigned) + + // Then, try to solve the above equation provided that Start is constant. + if (SCEVConstant *StartC = dyn_cast(Start)) + return SolveLinEquationWithOverflow(StepC->getValue()->getValue(), + -StartC->getValue()->getValue(),SE); } } else if (AddRec->isQuadratic() && AddRec->getType()->isInteger()) { // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of @@ -2580,9 +2711,8 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) { // value at this index. When solving for "X*X != 5", for example, we // should not accept a root of 2. SCEVHandle Val = AddRec->evaluateAtIteration(R1, SE); - if (SCEVConstant *EvalVal = dyn_cast(Val)) - if (EvalVal->getValue()->isZero()) - return R1; // We found a quadratic root! + if (Val->isZero()) + return R1; // We found a quadratic root! } } } @@ -2601,11 +2731,8 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) { // If the value is a constant, check to see if it is known to be non-zero // already. If so, the backedge will execute zero times. if (SCEVConstant *C = dyn_cast(V)) { - Constant *Zero = Constant::getNullValue(C->getValue()->getType()); - Constant *NonZero = - ConstantExpr::getICmp(ICmpInst::ICMP_NE, C->getValue(), Zero); - if (NonZero == ConstantInt::getTrue()) - return getSCEV(Zero); + if (!C->getValue()->isNullValue()) + return SE.getIntegerSCEV(0, C->getType()); return UnknownValue; // Otherwise it will loop infinitely. } @@ -2614,11 +2741,152 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) { return UnknownValue; } +/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB +/// (which may not be an immediate predecessor) which has exactly one +/// successor from which BB is reachable, or null if no such block is +/// found. +/// +BasicBlock * +ScalarEvolutionsImpl::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { + // If the block has a unique predecessor, the predecessor must have + // no other successors from which BB is reachable. + if (BasicBlock *Pred = BB->getSinglePredecessor()) + return Pred; + + // A loop's header is defined to be a block that dominates the loop. + // If the loop has a preheader, it must be a block that has exactly + // one successor that can reach BB. This is slightly more strict + // than necessary, but works if critical edges are split. + if (Loop *L = LI.getLoopFor(BB)) + return L->getLoopPreheader(); + + return 0; +} + +/// executesAtLeastOnce - Test whether entry to the loop is protected by +/// a conditional between LHS and RHS. +bool ScalarEvolutionsImpl::executesAtLeastOnce(const Loop *L, bool isSigned, + bool trueWhenEqual, + SCEV *LHS, SCEV *RHS) { + BasicBlock *Preheader = L->getLoopPreheader(); + BasicBlock *PreheaderDest = L->getHeader(); + + // Starting at the preheader, climb up the predecessor chain, as long as + // there are predecessors that can be found that have unique successors + // leading to the original header. + for (; Preheader; + PreheaderDest = Preheader, + Preheader = getPredecessorWithUniqueSuccessorForBB(Preheader)) { + + BranchInst *LoopEntryPredicate = + dyn_cast(Preheader->getTerminator()); + if (!LoopEntryPredicate || + LoopEntryPredicate->isUnconditional()) + continue; + + ICmpInst *ICI = dyn_cast(LoopEntryPredicate->getCondition()); + if (!ICI) continue; + + // Now that we found a conditional branch that dominates the loop, check to + // see if it is the comparison we are looking for. + Value *PreCondLHS = ICI->getOperand(0); + Value *PreCondRHS = ICI->getOperand(1); + ICmpInst::Predicate Cond; + if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest) + Cond = ICI->getPredicate(); + else + Cond = ICI->getInversePredicate(); + + switch (Cond) { + case ICmpInst::ICMP_UGT: + if (isSigned || trueWhenEqual) continue; + std::swap(PreCondLHS, PreCondRHS); + Cond = ICmpInst::ICMP_ULT; + break; + case ICmpInst::ICMP_SGT: + if (!isSigned || trueWhenEqual) continue; + std::swap(PreCondLHS, PreCondRHS); + Cond = ICmpInst::ICMP_SLT; + break; + case ICmpInst::ICMP_ULT: + if (isSigned || trueWhenEqual) continue; + break; + case ICmpInst::ICMP_SLT: + if (!isSigned || trueWhenEqual) continue; + break; + case ICmpInst::ICMP_UGE: + if (isSigned || !trueWhenEqual) continue; + std::swap(PreCondLHS, PreCondRHS); + Cond = ICmpInst::ICMP_ULE; + break; + case ICmpInst::ICMP_SGE: + if (!isSigned || !trueWhenEqual) continue; + std::swap(PreCondLHS, PreCondRHS); + Cond = ICmpInst::ICMP_SLE; + break; + case ICmpInst::ICMP_ULE: + if (isSigned || !trueWhenEqual) continue; + break; + case ICmpInst::ICMP_SLE: + if (!isSigned || !trueWhenEqual) continue; + break; + default: + continue; + } + + if (!PreCondLHS->getType()->isInteger()) continue; + + SCEVHandle PreCondLHSSCEV = getSCEV(PreCondLHS); + SCEVHandle PreCondRHSSCEV = getSCEV(PreCondRHS); + if ((LHS == PreCondLHSSCEV && RHS == PreCondRHSSCEV) || + (LHS == SE.getNotSCEV(PreCondRHSSCEV) && + RHS == SE.getNotSCEV(PreCondLHSSCEV))) + return true; + } + + return false; +} + +/// potentialInfiniteLoop - Test whether the loop might jump over the exit value +/// due to wrapping around 2^n. +bool ScalarEvolutionsImpl::potentialInfiniteLoop(SCEV *Stride, SCEV *RHS, + bool isSigned, bool trueWhenEqual) { + // Return true when the distance from RHS to maxint > Stride. + + if (!isa(Stride)) + return true; + SCEVConstant *SC = cast(Stride); + + if (SC->getValue()->isZero()) + return true; + if (!trueWhenEqual && SC->getValue()->isOne()) + return false; + + if (!isa(RHS)) + return true; + SCEVConstant *R = cast(RHS); + + if (isSigned) + return true; // XXX: because we don't have an sdiv scev. + + // If negative, it wraps around every iteration, but we don't care about that. + APInt S = SC->getValue()->getValue().abs(); + + APInt Dist = APInt::getMaxValue(R->getValue()->getBitWidth()) - + R->getValue()->getValue(); + + if (trueWhenEqual) + return !S.ult(Dist); + else + return !S.ule(Dist); +} + /// HowManyLessThans - Return the number of times a backedge containing the /// specified less-than comparison will execute. If not computable, return /// UnknownValue. SCEVHandle ScalarEvolutionsImpl:: -HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) { +HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, + bool isSigned, bool trueWhenEqual) { // Only handle: "ADDREC < LoopInvariant". if (!RHS->isLoopInvariant(L)) return UnknownValue; @@ -2627,27 +2895,50 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) { return UnknownValue; if (AddRec->isAffine()) { - // FORNOW: We only support unit strides. - SCEVHandle One = SE.getIntegerSCEV(1, RHS->getType()); - if (AddRec->getOperand(1) != One) + SCEVHandle Stride = AddRec->getOperand(1); + if (potentialInfiniteLoop(Stride, RHS, isSigned, trueWhenEqual)) return UnknownValue; - // We know the LHS is of the form {n,+,1} and the RHS is some loop-invariant - // m. So, we count the number of iterations in which {n,+,1} < m is true. - // Note that we cannot simply return max(m-n,0) because it's not safe to + // We know the LHS is of the form {n,+,s} and the RHS is some loop-invariant + // m. So, we count the number of iterations in which {n,+,s} < m is true. + // Note that we cannot simply return max(m-n,0)/s because it's not safe to // treat m-n as signed nor unsigned due to overflow possibility. // First, we get the value of the LHS in the first iteration: n SCEVHandle Start = AddRec->getOperand(0); - // Then, we get the value of the LHS in the first iteration in which the - // above condition doesn't hold. This equals to max(m,n). - SCEVHandle End = isSigned ? SE.getSMaxExpr(RHS, Start) - : SE.getUMaxExpr(RHS, Start); + SCEVHandle One = SE.getIntegerSCEV(1, RHS->getType()); + + // Assuming that the loop will run at least once, we know that it will + // run (m-n)/s times. + SCEVHandle End = RHS; + + if (!executesAtLeastOnce(L, isSigned, trueWhenEqual, + SE.getMinusSCEV(Start, One), RHS)) { + // If not, we get the value of the LHS in the first iteration in which + // the above condition doesn't hold. This equals to max(m,n). + End = isSigned ? SE.getSMaxExpr(RHS, Start) + : SE.getUMaxExpr(RHS, Start); + } + + // If the expression is less-than-or-equal to, we need to extend the + // loop by one iteration. + // + // The loop won't actually run (m-n)/s times because the loop iterations + // won't divide evenly. For example, if you have {2,+,5} u< 10 the + // division would equal one, but the loop runs twice putting the + // induction variable at 12. + + if (!trueWhenEqual) + // (Stride - 1) is correct only because we know it's unsigned. + // What we really want is to decrease the magnitude of Stride by one. + Start = SE.getMinusSCEV(Start, SE.getMinusSCEV(Stride, One)); + else + Start = SE.getMinusSCEV(Start, Stride); // Finally, we subtract these two values to get the number of times the // backedge is executed: max(m,n)-n. - return SE.getMinusSCEV(End, Start); + return SE.getUDivExpr(SE.getMinusSCEV(End, Start), Stride); } return UnknownValue; @@ -2769,27 +3060,6 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, } } - // Fallback, if this is a general polynomial, figure out the progression - // through brute force: evaluate until we find an iteration that fails the - // test. This is likely to be slow, but getting an accurate trip count is - // incredibly important, we will be able to simplify the exit test a lot, and - // we are almost guaranteed to get a trip count in this case. - ConstantInt *TestVal = ConstantInt::get(getType(), 0); - ConstantInt *EndVal = TestVal; // Stop when we wrap around. - do { - ++NumBruteForceEvaluations; - SCEVHandle Val = evaluateAtIteration(SE.getConstant(TestVal), SE); - if (!isa(Val)) // This shouldn't happen. - return new SCEVCouldNotCompute(); - - // Check to see if we found the value! - if (!Range.contains(cast(Val)->getValue()->getValue())) - return SE.getConstant(TestVal); - - // Increment to test the next index. - TestVal = ConstantInt::get(TestVal->getValue()+1); - } while (TestVal != EndVal); - return new SCEVCouldNotCompute(); } @@ -2878,17 +3148,11 @@ void ScalarEvolution::print(std::ostream &OS, const Module* ) const { for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) if (I->getType()->isInteger()) { OS << *I; - OS << " --> "; + OS << " --> "; SCEVHandle SV = getSCEV(&*I); SV->print(OS); OS << "\t\t"; - if ((*I).getType()->isInteger()) { - ConstantRange Bounds = SV->getValueRange(); - if (!Bounds.isFullSet()) - OS << "Bounds: " << Bounds << " "; - } - if (const Loop *L = LI.getLoopFor((*I).getParent())) { OS << "Exits: "; SCEVHandle ExitValue = getSCEVAtScope(&*I, L->getParentLoop());