X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=8fb46dd883b9089c28eed9dd4185a4861c7695e0;hb=dd643f26c43d162e905a07bf0826680aa10f7161;hp=364e1223ba4492476feecb79e7ed3bef645a4a86;hpb=581b0d453a63f7f657248f80317976995262be11;p=oota-llvm.git diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 364e1223ba4..8fb46dd883b 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -2,8 +2,8 @@ // // The LLVM Compiler Infrastructure // -// This file was developed by the LLVM research group and is distributed under -// the University of Illinois Open Source License. See LICENSE.TXT for details. +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// // @@ -83,9 +83,6 @@ #include using namespace llvm; -STATISTIC(NumBruteForceEvaluations, - "Number of brute force evaluations needed to " - "calculate high-order polynomial exit values"); STATISTIC(NumArrayLenItCounts, "Number of trip counts computed with array length"); STATISTIC(NumTripCountsComputed, @@ -95,16 +92,15 @@ STATISTIC(NumTripCountsNotComputed, STATISTIC(NumBruteForceTripCountsComputed, "Number of loops with trip counts computed by force"); -cl::opt +static cl::opt MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant derived loop"), cl::init(100)); -namespace { - RegisterPass - R("scalar-evolution", "Scalar Evolution Analysis"); -} +static RegisterPass +R("scalar-evolution", "Scalar Evolution Analysis", false, true); +char ScalarEvolution::ID = 0; //===----------------------------------------------------------------------===// // SCEV class definitions @@ -118,21 +114,18 @@ void SCEV::dump() const { print(cerr); } -/// getValueRange - Return the tightest constant bounds that this value is -/// known to have. This method is only valid on integer SCEV objects. -ConstantRange SCEV::getValueRange() const { - const Type *Ty = getType(); - assert(Ty->isInteger() && "Can't get range for a non-integer SCEV!"); - // Default to a full range if no better information is available. - return ConstantRange(getType()); -} - uint32_t SCEV::getBitWidth() const { if (const IntegerType* ITy = dyn_cast(getType())) return ITy->getBitWidth(); return 0; } +bool SCEV::isZero() const { + if (const SCEVConstant *SC = dyn_cast(this)) + return SC->getValue()->isZero(); + return false; +} + SCEVCouldNotCompute::SCEVCouldNotCompute() : SCEV(scCouldNotCompute) {} @@ -153,7 +146,8 @@ bool SCEVCouldNotCompute::hasComputableLoopEvolution(const Loop *L) const { SCEVHandle SCEVCouldNotCompute:: replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, - const SCEVHandle &Conc) const { + const SCEVHandle &Conc, + ScalarEvolution &SE) const { return this; } @@ -176,14 +170,14 @@ SCEVConstant::~SCEVConstant() { SCEVConstants->erase(V); } -SCEVHandle SCEVConstant::get(ConstantInt *V) { +SCEVHandle ScalarEvolution::getConstant(ConstantInt *V) { SCEVConstant *&R = (*SCEVConstants)[V]; if (R == 0) R = new SCEVConstant(V); return R; } -ConstantRange SCEVConstant::getValueRange() const { - return ConstantRange(V->getValue()); +SCEVHandle ScalarEvolution::getConstant(const APInt& Val) { + return getConstant(ConstantInt::get(Val)); } const Type *SCEVConstant::getType() const { return V->getType(); } @@ -210,10 +204,6 @@ SCEVTruncateExpr::~SCEVTruncateExpr() { SCEVTruncates->erase(std::make_pair(Op, Ty)); } -ConstantRange SCEVTruncateExpr::getValueRange() const { - return getOperand()->getValueRange().truncate(getType()); -} - void SCEVTruncateExpr::print(std::ostream &OS) const { OS << "(truncate " << *Op << " to " << *Ty << ")"; } @@ -236,14 +226,32 @@ SCEVZeroExtendExpr::~SCEVZeroExtendExpr() { SCEVZeroExtends->erase(std::make_pair(Op, Ty)); } -ConstantRange SCEVZeroExtendExpr::getValueRange() const { - return getOperand()->getValueRange().zeroExtend(getType()); -} - void SCEVZeroExtendExpr::print(std::ostream &OS) const { OS << "(zeroextend " << *Op << " to " << *Ty << ")"; } +// SCEVSignExtends - Only allow the creation of one SCEVSignExtendExpr for any +// particular input. Don't use a SCEVHandle here, or else the object will never +// be deleted! +static ManagedStatic, + SCEVSignExtendExpr*> > SCEVSignExtends; + +SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty) + : SCEV(scSignExtend), Op(op), Ty(ty) { + assert(Op->getType()->isInteger() && Ty->isInteger() && + "Cannot sign extend non-integer value!"); + assert(Op->getType()->getPrimitiveSizeInBits() < Ty->getPrimitiveSizeInBits() + && "This is not an extending conversion!"); +} + +SCEVSignExtendExpr::~SCEVSignExtendExpr() { + SCEVSignExtends->erase(std::make_pair(Op, Ty)); +} + +void SCEVSignExtendExpr::print(std::ostream &OS) const { + OS << "(signextend " << *Op << " to " << *Ty << ")"; +} + // SCEVCommExprs - Only allow the creation of one SCEVCommutativeExpr for any // particular input. Don't use a SCEVHandle here, or else the object will never // be deleted! @@ -267,9 +275,11 @@ void SCEVCommutativeExpr::print(std::ostream &OS) const { SCEVHandle SCEVCommutativeExpr:: replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, - const SCEVHandle &Conc) const { + const SCEVHandle &Conc, + ScalarEvolution &SE) const { for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { - SCEVHandle H = getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc); + SCEVHandle H = + getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE); if (H != getOperand(i)) { std::vector NewOps; NewOps.reserve(getNumOperands()); @@ -278,12 +288,16 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, NewOps.push_back(H); for (++i; i != e; ++i) NewOps.push_back(getOperand(i)-> - replaceSymbolicValuesWithConcrete(Sym, Conc)); + replaceSymbolicValuesWithConcrete(Sym, Conc, SE)); if (isa(this)) - return SCEVAddExpr::get(NewOps); + return SE.getAddExpr(NewOps); else if (isa(this)) - return SCEVMulExpr::get(NewOps); + return SE.getMulExpr(NewOps); + else if (isa(this)) + return SE.getSMaxExpr(NewOps); + else if (isa(this)) + return SE.getUMaxExpr(NewOps); else assert(0 && "Unknown commutative expr!"); } @@ -292,21 +306,21 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, } -// SCEVSDivs - Only allow the creation of one SCEVSDivExpr for any particular +// SCEVUDivs - Only allow the creation of one SCEVUDivExpr for any particular // input. Don't use a SCEVHandle here, or else the object will never be // deleted! static ManagedStatic, - SCEVSDivExpr*> > SCEVSDivs; + SCEVUDivExpr*> > SCEVUDivs; -SCEVSDivExpr::~SCEVSDivExpr() { - SCEVSDivs->erase(std::make_pair(LHS, RHS)); +SCEVUDivExpr::~SCEVUDivExpr() { + SCEVUDivs->erase(std::make_pair(LHS, RHS)); } -void SCEVSDivExpr::print(std::ostream &OS) const { - OS << "(" << *LHS << " /s " << *RHS << ")"; +void SCEVUDivExpr::print(std::ostream &OS) const { + OS << "(" << *LHS << " /u " << *RHS << ")"; } -const Type *SCEVSDivExpr::getType() const { +const Type *SCEVUDivExpr::getType() const { return LHS->getType(); } @@ -324,9 +338,11 @@ SCEVAddRecExpr::~SCEVAddRecExpr() { SCEVHandle SCEVAddRecExpr:: replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, - const SCEVHandle &Conc) const { + const SCEVHandle &Conc, + ScalarEvolution &SE) const { for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { - SCEVHandle H = getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc); + SCEVHandle H = + getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE); if (H != getOperand(i)) { std::vector NewOps; NewOps.reserve(getNumOperands()); @@ -335,9 +351,9 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, NewOps.push_back(H); for (++i; i != e; ++i) NewOps.push_back(getOperand(i)-> - replaceSymbolicValuesWithConcrete(Sym, Conc)); + replaceSymbolicValuesWithConcrete(Sym, Conc, SE)); - return get(NewOps, L); + return SE.getAddRecExpr(NewOps, L); } } return this; @@ -391,7 +407,7 @@ namespace { /// than the complexity of the RHS. This comparator is used to canonicalize /// expressions. struct VISIBILITY_HIDDEN SCEVComplexityCompare { - bool operator()(SCEV *LHS, SCEV *RHS) { + bool operator()(const SCEV *LHS, const SCEV *RHS) const { return LHS->getSCEVType() < RHS->getSCEVType(); } }; @@ -412,7 +428,7 @@ static void GroupByComplexity(std::vector &Ops) { if (Ops.size() == 2) { // This is the common case, which also happens to be trivially simple. // Special case it. - if (Ops[0]->getSCEVType() > Ops[1]->getSCEVType()) + if (SCEVComplexityCompare()(Ops[1], Ops[0])) std::swap(Ops[0], Ops[1]); return; } @@ -449,104 +465,199 @@ static void GroupByComplexity(std::vector &Ops) { /// getIntegerSCEV - Given an integer or FP type, create a constant for the /// specified signed integer value and return a SCEV for the constant. -SCEVHandle SCEVUnknown::getIntegerSCEV(int Val, const Type *Ty) { +SCEVHandle ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) { Constant *C; if (Val == 0) C = Constant::getNullValue(Ty); else if (Ty->isFloatingPoint()) - C = ConstantFP::get(Ty, Val); + C = ConstantFP::get(APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle : + APFloat::IEEEdouble, Val)); else C = ConstantInt::get(Ty, Val); - return SCEVUnknown::get(C); -} - -/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the -/// input value to the specified type. If the type must be extended, it is zero -/// extended. -static SCEVHandle getTruncateOrZeroExtend(const SCEVHandle &V, const Type *Ty) { - const Type *SrcTy = V->getType(); - assert(SrcTy->isInteger() && Ty->isInteger() && - "Cannot truncate or zero extend with non-integer arguments!"); - if (SrcTy->getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits()) - return V; // No conversion - if (SrcTy->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits()) - return SCEVTruncateExpr::get(V, Ty); - return SCEVZeroExtendExpr::get(V, Ty); + return getUnknown(C); } /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V /// -SCEVHandle SCEV::getNegativeSCEV(const SCEVHandle &V) { +SCEVHandle ScalarEvolution::getNegativeSCEV(const SCEVHandle &V) { if (SCEVConstant *VC = dyn_cast(V)) - return SCEVUnknown::get(ConstantExpr::getNeg(VC->getValue())); + return getUnknown(ConstantExpr::getNeg(VC->getValue())); - return SCEVMulExpr::get(V, SCEVUnknown::getIntegerSCEV(-1, V->getType())); + return getMulExpr(V, getConstant(ConstantInt::getAllOnesValue(V->getType()))); +} + +/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V +SCEVHandle ScalarEvolution::getNotSCEV(const SCEVHandle &V) { + if (SCEVConstant *VC = dyn_cast(V)) + return getUnknown(ConstantExpr::getNot(VC->getValue())); + + SCEVHandle AllOnes = getConstant(ConstantInt::getAllOnesValue(V->getType())); + return getMinusSCEV(AllOnes, V); } /// getMinusSCEV - Return a SCEV corresponding to LHS - RHS. /// -SCEVHandle SCEV::getMinusSCEV(const SCEVHandle &LHS, const SCEVHandle &RHS) { +SCEVHandle ScalarEvolution::getMinusSCEV(const SCEVHandle &LHS, + const SCEVHandle &RHS) { // X - Y --> X + -Y - return SCEVAddExpr::get(LHS, SCEV::getNegativeSCEV(RHS)); + return getAddExpr(LHS, getNegativeSCEV(RHS)); } -/// PartialFact - Compute V!/(V-NumSteps)! -static SCEVHandle PartialFact(SCEVHandle V, unsigned NumSteps) { - // Handle this case efficiently, it is common to have constant iteration - // counts while computing loop exit values. - if (SCEVConstant *SC = dyn_cast(V)) { - APInt Val = SC->getValue()->getValue(); - APInt Result(Val.getBitWidth(), 1); - for (; NumSteps; --NumSteps) - Result *= Val-(NumSteps-1); - return SCEVUnknown::get(ConstantInt::get(V->getType(), Result)); +/// BinomialCoefficient - Compute BC(It, K). The result has width W. +// Assume, K > 0. +static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K, + ScalarEvolution &SE, + const IntegerType* ResultTy) { + // Handle the simplest case efficiently. + if (K == 1) + return SE.getTruncateOrZeroExtend(It, ResultTy); + + // We are using the following formula for BC(It, K): + // + // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K! + // + // Suppose, W is the bitwidth of the return value. We must be prepared for + // overflow. Hence, we must assure that the result of our computation is + // equal to the accurate one modulo 2^W. Unfortunately, division isn't + // safe in modular arithmetic. + // + // However, this code doesn't use exactly that formula; the formula it uses + // is something like the following, where T is the number of factors of 2 in + // K! (i.e. trailing zeros in the binary representation of K!), and ^ is + // exponentiation: + // + // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T) + // + // This formula is trivially equivalent to the previous formula. However, + // this formula can be implemented much more efficiently. The trick is that + // K! / 2^T is odd, and exact division by an odd number *is* safe in modular + // arithmetic. To do exact division in modular arithmetic, all we have + // to do is multiply by the inverse. Therefore, this step can be done at + // width W. + // + // The next issue is how to safely do the division by 2^T. The way this + // is done is by doing the multiplication step at a width of at least W + T + // bits. This way, the bottom W+T bits of the product are accurate. Then, + // when we perform the division by 2^T (which is equivalent to a right shift + // by T), the bottom W bits are accurate. Extra bits are okay; they'll get + // truncated out after the division by 2^T. + // + // In comparison to just directly using the first formula, this technique + // is much more efficient; using the first formula requires W * K bits, + // but this formula less than W + K bits. Also, the first formula requires + // a division step, whereas this formula only requires multiplies and shifts. + // + // It doesn't matter whether the subtraction step is done in the calculation + // width or the input iteration count's width; if the subtraction overflows, + // the result must be zero anyway. We prefer here to do it in the width of + // the induction variable because it helps a lot for certain cases; CodeGen + // isn't smart enough to ignore the overflow, which leads to much less + // efficient code if the width of the subtraction is wider than the native + // register width. + // + // (It's possible to not widen at all by pulling out factors of 2 before + // the multiplication; for example, K=2 can be calculated as + // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires + // extra arithmetic, so it's not an obvious win, and it gets + // much more complicated for K > 3.) + + // Protection from insane SCEVs; this bound is conservative, + // but it probably doesn't matter. + if (K > 1000) + return new SCEVCouldNotCompute(); + + unsigned W = ResultTy->getBitWidth(); + + // Calculate K! / 2^T and T; we divide out the factors of two before + // multiplying for calculating K! / 2^T to avoid overflow. + // Other overflow doesn't matter because we only care about the bottom + // W bits of the result. + APInt OddFactorial(W, 1); + unsigned T = 1; + for (unsigned i = 3; i <= K; ++i) { + APInt Mult(W, i); + unsigned TwoFactors = Mult.countTrailingZeros(); + T += TwoFactors; + Mult = Mult.lshr(TwoFactors); + OddFactorial *= Mult; } - const Type *Ty = V->getType(); - if (NumSteps == 0) - return SCEVUnknown::getIntegerSCEV(1, Ty); + // We need at least W + T bits for the multiplication step + // FIXME: A temporary hack; we round up the bitwidths + // to the nearest power of 2 to be nice to the code generator. + unsigned CalculationBits = 1U << Log2_32_Ceil(W + T); + // FIXME: Temporary hack to avoid generating integers that are too wide. + // Although, it's not completely clear how to determine how much + // widening is safe; for example, on X86, we can't really widen + // beyond 64 because we need to be able to do multiplication + // that's CalculationBits wide, but on X86-64, we can safely widen up to + // 128 bits. + if (CalculationBits > 64) + return new SCEVCouldNotCompute(); - SCEVHandle Result = V; - for (unsigned i = 1; i != NumSteps; ++i) - Result = SCEVMulExpr::get(Result, SCEV::getMinusSCEV(V, - SCEVUnknown::getIntegerSCEV(i, Ty))); - return Result; -} + // Calcuate 2^T, at width T+W. + APInt DivFactor = APInt(CalculationBits, 1).shl(T); + + // Calculate the multiplicative inverse of K! / 2^T; + // this multiplication factor will perform the exact division by + // K! / 2^T. + APInt Mod = APInt::getSignedMinValue(W+1); + APInt MultiplyFactor = OddFactorial.zext(W+1); + MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod); + MultiplyFactor = MultiplyFactor.trunc(W); + + // Calculate the product, at width T+W + const IntegerType *CalculationTy = IntegerType::get(CalculationBits); + SCEVHandle Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy); + for (unsigned i = 1; i != K; ++i) { + SCEVHandle S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType())); + Dividend = SE.getMulExpr(Dividend, + SE.getTruncateOrZeroExtend(S, CalculationTy)); + } + + // Divide by 2^T + SCEVHandle DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor)); + // Truncate the result, and divide by K! / 2^T. + + return SE.getMulExpr(SE.getConstant(MultiplyFactor), + SE.getTruncateOrZeroExtend(DivResult, ResultTy)); +} /// evaluateAtIteration - Return the value of this chain of recurrences at /// the specified iteration number. We can evaluate this recurrence by /// multiplying each element in the chain by the binomial coefficient /// corresponding to it. In other words, we can evaluate {A,+,B,+,C,+,D} as: /// -/// A*choose(It, 0) + B*choose(It, 1) + C*choose(It, 2) + D*choose(It, 3) +/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3) /// -/// FIXME/VERIFY: I don't trust that this is correct in the face of overflow. -/// Is the binomial equation safe using modular arithmetic?? +/// where BC(It, k) stands for binomial coefficient. /// -SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It) const { +SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It, + ScalarEvolution &SE) const { SCEVHandle Result = getStart(); - int Divisor = 1; - const Type *Ty = It->getType(); for (unsigned i = 1, e = getNumOperands(); i != e; ++i) { - SCEVHandle BC = PartialFact(It, i); - Divisor *= i; - SCEVHandle Val = SCEVSDivExpr::get(SCEVMulExpr::get(BC, getOperand(i)), - SCEVUnknown::getIntegerSCEV(Divisor,Ty)); - Result = SCEVAddExpr::get(Result, Val); + // The computation is correct in the face of overflow provided that the + // multiplication is performed _after_ the evaluation of the binomial + // coefficient. + SCEVHandle Coeff = BinomialCoefficient(It, i, SE, + cast(getType())); + if (isa(Coeff)) + return Coeff; + + Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff)); } return Result; } - //===----------------------------------------------------------------------===// // SCEV Expression folder implementations //===----------------------------------------------------------------------===// -SCEVHandle SCEVTruncateExpr::get(const SCEVHandle &Op, const Type *Ty) { +SCEVHandle ScalarEvolution::getTruncateExpr(const SCEVHandle &Op, const Type *Ty) { if (SCEVConstant *SC = dyn_cast(Op)) - return SCEVUnknown::get( + return getUnknown( ConstantExpr::getTrunc(SC->getValue(), Ty)); // If the input value is a chrec scev made out of constants, truncate @@ -556,11 +667,11 @@ SCEVHandle SCEVTruncateExpr::get(const SCEVHandle &Op, const Type *Ty) { for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) // FIXME: This should allow truncation of other expression types! if (isa(AddRec->getOperand(i))) - Operands.push_back(get(AddRec->getOperand(i), Ty)); + Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty)); else break; if (Operands.size() == AddRec->getNumOperands()) - return SCEVAddRecExpr::get(Operands, AddRec->getLoop()); + return getAddRecExpr(Operands, AddRec->getLoop()); } SCEVTruncateExpr *&Result = (*SCEVTruncates)[std::make_pair(Op, Ty)]; @@ -568,9 +679,9 @@ SCEVHandle SCEVTruncateExpr::get(const SCEVHandle &Op, const Type *Ty) { return Result; } -SCEVHandle SCEVZeroExtendExpr::get(const SCEVHandle &Op, const Type *Ty) { +SCEVHandle ScalarEvolution::getZeroExtendExpr(const SCEVHandle &Op, const Type *Ty) { if (SCEVConstant *SC = dyn_cast(Op)) - return SCEVUnknown::get( + return getUnknown( ConstantExpr::getZExt(SC->getValue(), Ty)); // FIXME: If the input value is a chrec scev, and we can prove that the value @@ -583,8 +694,38 @@ SCEVHandle SCEVZeroExtendExpr::get(const SCEVHandle &Op, const Type *Ty) { return Result; } +SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op, const Type *Ty) { + if (SCEVConstant *SC = dyn_cast(Op)) + return getUnknown( + ConstantExpr::getSExt(SC->getValue(), Ty)); + + // FIXME: If the input value is a chrec scev, and we can prove that the value + // did not overflow the old, smaller, value, we can sign extend all of the + // operands (often constants). This would allow analysis of something like + // this: for (signed char X = 0; X < 100; ++X) { int Y = X; } + + SCEVSignExtendExpr *&Result = (*SCEVSignExtends)[std::make_pair(Op, Ty)]; + if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty); + return Result; +} + +/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion +/// of the input value to the specified type. If the type must be +/// extended, it is zero extended. +SCEVHandle ScalarEvolution::getTruncateOrZeroExtend(const SCEVHandle &V, + const Type *Ty) { + const Type *SrcTy = V->getType(); + assert(SrcTy->isInteger() && Ty->isInteger() && + "Cannot truncate or zero extend with non-integer arguments!"); + if (SrcTy->getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits()) + return V; // No conversion + if (SrcTy->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits()) + return getTruncateExpr(V, Ty); + return getZeroExtendExpr(V, Ty); +} + // get - Get a canonical add expression, or something simpler if possible. -SCEVHandle SCEVAddExpr::get(std::vector &Ops) { +SCEVHandle ScalarEvolution::getAddExpr(std::vector &Ops) { assert(!Ops.empty() && "Cannot get empty add!"); if (Ops.size() == 1) return Ops[0]; @@ -598,22 +739,16 @@ SCEVHandle SCEVAddExpr::get(std::vector &Ops) { assert(Idx < Ops.size()); while (SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - Constant *Fold = ConstantExpr::getAdd(LHSC->getValue(), RHSC->getValue()); - if (ConstantInt *CI = dyn_cast(Fold)) { - Ops[0] = SCEVConstant::get(CI); - Ops.erase(Ops.begin()+1); // Erase the folded element - if (Ops.size() == 1) return Ops[0]; - LHSC = cast(Ops[0]); - } else { - // If we couldn't fold the expression, move to the next constant. Note - // that this is impossible to happen in practice because we always - // constant fold constant ints to constant ints. - ++Idx; - } + ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() + + RHSC->getValue()->getValue()); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast(Ops[0]); } // If we are left with a constant zero being added, strip it off. - if (cast(Ops[0])->getValue()->isNullValue()) { + if (cast(Ops[0])->getValue()->isZero()) { Ops.erase(Ops.begin()); --Idx; } @@ -629,17 +764,20 @@ SCEVHandle SCEVAddExpr::get(std::vector &Ops) { if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2 // Found a match, merge the two values into a multiply, and add any // remaining values to the result. - SCEVHandle Two = SCEVUnknown::getIntegerSCEV(2, Ty); - SCEVHandle Mul = SCEVMulExpr::get(Ops[i], Two); + SCEVHandle Two = getIntegerSCEV(2, Ty); + SCEVHandle Mul = getMulExpr(Ops[i], Two); if (Ops.size() == 2) return Mul; Ops.erase(Ops.begin()+i, Ops.begin()+i+2); Ops.push_back(Mul); - return SCEVAddExpr::get(Ops); + return getAddExpr(Ops); } - // Okay, now we know the first non-constant operand. If there are add - // operands they would be next. + // Now we know the first non-constant operand. Skip past any cast SCEVs. + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr) + ++Idx; + + // If there are add operands they would be next. if (Idx < Ops.size()) { bool DeletedAdd = false; while (SCEVAddExpr *Add = dyn_cast(Ops[Idx])) { @@ -654,7 +792,7 @@ SCEVHandle SCEVAddExpr::get(std::vector &Ops) { // and they are not necessarily sorted. Recurse to resort and resimplify // any operands we just aquired. if (DeletedAdd) - return get(Ops); + return getAddExpr(Ops); } // Skip over the add expression until we get to a multiply. @@ -677,11 +815,11 @@ SCEVHandle SCEVAddExpr::get(std::vector &Ops) { // Y*Z term. std::vector MulOps(Mul->op_begin(), Mul->op_end()); MulOps.erase(MulOps.begin()+MulOp); - InnerMul = SCEVMulExpr::get(MulOps); + InnerMul = getMulExpr(MulOps); } - SCEVHandle One = SCEVUnknown::getIntegerSCEV(1, Ty); - SCEVHandle AddOne = SCEVAddExpr::get(InnerMul, One); - SCEVHandle OuterMul = SCEVMulExpr::get(AddOne, Ops[AddOp]); + SCEVHandle One = getIntegerSCEV(1, Ty); + SCEVHandle AddOne = getAddExpr(InnerMul, One); + SCEVHandle OuterMul = getMulExpr(AddOne, Ops[AddOp]); if (Ops.size() == 2) return OuterMul; if (AddOp < Idx) { Ops.erase(Ops.begin()+AddOp); @@ -691,7 +829,7 @@ SCEVHandle SCEVAddExpr::get(std::vector &Ops) { Ops.erase(Ops.begin()+AddOp-1); } Ops.push_back(OuterMul); - return SCEVAddExpr::get(Ops); + return getAddExpr(Ops); } // Check this multiply against other multiplies being added together. @@ -709,22 +847,22 @@ SCEVHandle SCEVAddExpr::get(std::vector &Ops) { if (Mul->getNumOperands() != 2) { std::vector MulOps(Mul->op_begin(), Mul->op_end()); MulOps.erase(MulOps.begin()+MulOp); - InnerMul1 = SCEVMulExpr::get(MulOps); + InnerMul1 = getMulExpr(MulOps); } SCEVHandle InnerMul2 = OtherMul->getOperand(OMulOp == 0); if (OtherMul->getNumOperands() != 2) { std::vector MulOps(OtherMul->op_begin(), OtherMul->op_end()); MulOps.erase(MulOps.begin()+OMulOp); - InnerMul2 = SCEVMulExpr::get(MulOps); + InnerMul2 = getMulExpr(MulOps); } - SCEVHandle InnerMulSum = SCEVAddExpr::get(InnerMul1,InnerMul2); - SCEVHandle OuterMul = SCEVMulExpr::get(MulOpSCEV, InnerMulSum); + SCEVHandle InnerMulSum = getAddExpr(InnerMul1,InnerMul2); + SCEVHandle OuterMul = getMulExpr(MulOpSCEV, InnerMulSum); if (Ops.size() == 2) return OuterMul; Ops.erase(Ops.begin()+Idx); Ops.erase(Ops.begin()+OtherMulIdx-1); Ops.push_back(OuterMul); - return SCEVAddExpr::get(Ops); + return getAddExpr(Ops); } } } @@ -751,13 +889,13 @@ SCEVHandle SCEVAddExpr::get(std::vector &Ops) { // If we found some loop invariants, fold them into the recurrence. if (!LIOps.empty()) { - // NLI + LI + { Start,+,Step} --> NLI + { LI+Start,+,Step } + // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step} LIOps.push_back(AddRec->getStart()); std::vector AddRecOps(AddRec->op_begin(), AddRec->op_end()); - AddRecOps[0] = SCEVAddExpr::get(LIOps); + AddRecOps[0] = getAddExpr(LIOps); - SCEVHandle NewRec = SCEVAddRecExpr::get(AddRecOps, AddRec->getLoop()); + SCEVHandle NewRec = getAddRecExpr(AddRecOps, AddRec->getLoop()); // If all of the other operands were loop invariant, we are done. if (Ops.size() == 1) return NewRec; @@ -767,7 +905,7 @@ SCEVHandle SCEVAddExpr::get(std::vector &Ops) { Ops[i] = NewRec; break; } - return SCEVAddExpr::get(Ops); + return getAddExpr(Ops); } // Okay, if there weren't any loop invariants to be folded, check to see if @@ -786,16 +924,16 @@ SCEVHandle SCEVAddExpr::get(std::vector &Ops) { OtherAddRec->op_end()); break; } - NewOps[i] = SCEVAddExpr::get(NewOps[i], OtherAddRec->getOperand(i)); + NewOps[i] = getAddExpr(NewOps[i], OtherAddRec->getOperand(i)); } - SCEVHandle NewAddRec = SCEVAddRecExpr::get(NewOps, AddRec->getLoop()); + SCEVHandle NewAddRec = getAddRecExpr(NewOps, AddRec->getLoop()); if (Ops.size() == 2) return NewAddRec; Ops.erase(Ops.begin()+Idx); Ops.erase(Ops.begin()+OtherIdx-1); Ops.push_back(NewAddRec); - return SCEVAddExpr::get(Ops); + return getAddExpr(Ops); } } @@ -813,7 +951,7 @@ SCEVHandle SCEVAddExpr::get(std::vector &Ops) { } -SCEVHandle SCEVMulExpr::get(std::vector &Ops) { +SCEVHandle ScalarEvolution::getMulExpr(std::vector &Ops) { assert(!Ops.empty() && "Cannot get empty mul!"); // Sort by complexity, this groups all similar expression types together. @@ -828,32 +966,26 @@ SCEVHandle SCEVMulExpr::get(std::vector &Ops) { if (SCEVAddExpr *Add = dyn_cast(Ops[1])) if (Add->getNumOperands() == 2 && isa(Add->getOperand(0))) - return SCEVAddExpr::get(SCEVMulExpr::get(LHSC, Add->getOperand(0)), - SCEVMulExpr::get(LHSC, Add->getOperand(1))); + return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)), + getMulExpr(LHSC, Add->getOperand(1))); ++Idx; while (SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - Constant *Fold = ConstantExpr::getMul(LHSC->getValue(), RHSC->getValue()); - if (ConstantInt *CI = dyn_cast(Fold)) { - Ops[0] = SCEVConstant::get(CI); - Ops.erase(Ops.begin()+1); // Erase the folded element - if (Ops.size() == 1) return Ops[0]; - LHSC = cast(Ops[0]); - } else { - // If we couldn't fold the expression, move to the next constant. Note - // that this is impossible to happen in practice because we always - // constant fold constant ints to constant ints. - ++Idx; - } + ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() * + RHSC->getValue()->getValue()); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast(Ops[0]); } // If we are left with a constant one being multiplied, strip it off. if (cast(Ops[0])->getValue()->equalsInt(1)) { Ops.erase(Ops.begin()); --Idx; - } else if (cast(Ops[0])->getValue()->isNullValue()) { + } else if (cast(Ops[0])->getValue()->isZero()) { // If we have a multiply of zero, it will always be zero. return Ops[0]; } @@ -881,7 +1013,7 @@ SCEVHandle SCEVMulExpr::get(std::vector &Ops) { // and they are not necessarily sorted. Recurse to resort and resimplify // any operands we just aquired. if (DeletedMul) - return get(Ops); + return getMulExpr(Ops); } // If there are any add recurrences in the operands list, see if any other @@ -905,22 +1037,22 @@ SCEVHandle SCEVMulExpr::get(std::vector &Ops) { // If we found some loop invariants, fold them into the recurrence. if (!LIOps.empty()) { - // NLI * LI * { Start,+,Step} --> NLI * { LI*Start,+,LI*Step } + // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step} std::vector NewOps; NewOps.reserve(AddRec->getNumOperands()); if (LIOps.size() == 1) { SCEV *Scale = LIOps[0]; for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) - NewOps.push_back(SCEVMulExpr::get(Scale, AddRec->getOperand(i))); + NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i))); } else { for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { std::vector MulOps(LIOps); MulOps.push_back(AddRec->getOperand(i)); - NewOps.push_back(SCEVMulExpr::get(MulOps)); + NewOps.push_back(getMulExpr(MulOps)); } } - SCEVHandle NewRec = SCEVAddRecExpr::get(NewOps, AddRec->getLoop()); + SCEVHandle NewRec = getAddRecExpr(NewOps, AddRec->getLoop()); // If all of the other operands were loop invariant, we are done. if (Ops.size() == 1) return NewRec; @@ -931,7 +1063,7 @@ SCEVHandle SCEVMulExpr::get(std::vector &Ops) { Ops[i] = NewRec; break; } - return SCEVMulExpr::get(Ops); + return getMulExpr(Ops); } // Okay, if there weren't any loop invariants to be folded, check to see if @@ -944,21 +1076,21 @@ SCEVHandle SCEVMulExpr::get(std::vector &Ops) { if (AddRec->getLoop() == OtherAddRec->getLoop()) { // F * G --> {A,+,B} * {C,+,D} --> {A*C,+,F*D + G*B + B*D} SCEVAddRecExpr *F = AddRec, *G = OtherAddRec; - SCEVHandle NewStart = SCEVMulExpr::get(F->getStart(), + SCEVHandle NewStart = getMulExpr(F->getStart(), G->getStart()); - SCEVHandle B = F->getStepRecurrence(); - SCEVHandle D = G->getStepRecurrence(); - SCEVHandle NewStep = SCEVAddExpr::get(SCEVMulExpr::get(F, D), - SCEVMulExpr::get(G, B), - SCEVMulExpr::get(B, D)); - SCEVHandle NewAddRec = SCEVAddRecExpr::get(NewStart, NewStep, - F->getLoop()); + SCEVHandle B = F->getStepRecurrence(*this); + SCEVHandle D = G->getStepRecurrence(*this); + SCEVHandle NewStep = getAddExpr(getMulExpr(F, D), + getMulExpr(G, B), + getMulExpr(B, D)); + SCEVHandle NewAddRec = getAddRecExpr(NewStart, NewStep, + F->getLoop()); if (Ops.size() == 2) return NewAddRec; Ops.erase(Ops.begin()+Idx); Ops.erase(Ops.begin()+OtherIdx-1); Ops.push_back(NewAddRec); - return SCEVMulExpr::get(Ops); + return getMulExpr(Ops); } } @@ -976,31 +1108,29 @@ SCEVHandle SCEVMulExpr::get(std::vector &Ops) { return Result; } -SCEVHandle SCEVSDivExpr::get(const SCEVHandle &LHS, const SCEVHandle &RHS) { +SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS) { if (SCEVConstant *RHSC = dyn_cast(RHS)) { if (RHSC->getValue()->equalsInt(1)) - return LHS; // X sdiv 1 --> x - if (RHSC->getValue()->isAllOnesValue()) - return SCEV::getNegativeSCEV(LHS); // X sdiv -1 --> -x + return LHS; // X udiv 1 --> x if (SCEVConstant *LHSC = dyn_cast(LHS)) { Constant *LHSCV = LHSC->getValue(); Constant *RHSCV = RHSC->getValue(); - return SCEVUnknown::get(ConstantExpr::getSDiv(LHSCV, RHSCV)); + return getUnknown(ConstantExpr::getUDiv(LHSCV, RHSCV)); } } // FIXME: implement folding of (X*4)/4 when we know X*4 doesn't overflow. - SCEVSDivExpr *&Result = (*SCEVSDivs)[std::make_pair(LHS, RHS)]; - if (Result == 0) Result = new SCEVSDivExpr(LHS, RHS); + SCEVUDivExpr *&Result = (*SCEVUDivs)[std::make_pair(LHS, RHS)]; + if (Result == 0) Result = new SCEVUDivExpr(LHS, RHS); return Result; } /// SCEVAddRecExpr::get - Get a add recurrence expression for the /// specified loop. Simplify the expression as much as possible. -SCEVHandle SCEVAddRecExpr::get(const SCEVHandle &Start, +SCEVHandle ScalarEvolution::getAddRecExpr(const SCEVHandle &Start, const SCEVHandle &Step, const Loop *L) { std::vector Operands; Operands.push_back(Start); @@ -1008,24 +1138,36 @@ SCEVHandle SCEVAddRecExpr::get(const SCEVHandle &Start, if (StepChrec->getLoop() == L) { Operands.insert(Operands.end(), StepChrec->op_begin(), StepChrec->op_end()); - return get(Operands, L); + return getAddRecExpr(Operands, L); } Operands.push_back(Step); - return get(Operands, L); + return getAddRecExpr(Operands, L); } /// SCEVAddRecExpr::get - Get a add recurrence expression for the /// specified loop. Simplify the expression as much as possible. -SCEVHandle SCEVAddRecExpr::get(std::vector &Operands, +SCEVHandle ScalarEvolution::getAddRecExpr(std::vector &Operands, const Loop *L) { if (Operands.size() == 1) return Operands[0]; - if (SCEVConstant *StepC = dyn_cast(Operands.back())) - if (StepC->getValue()->isNullValue()) { - Operands.pop_back(); - return get(Operands, L); // { X,+,0 } --> X + if (Operands.back()->isZero()) { + Operands.pop_back(); + return getAddRecExpr(Operands, L); // {X,+,0} --> X + } + + // Canonicalize nested AddRecs in by nesting them in order of loop depth. + if (SCEVAddRecExpr *NestedAR = dyn_cast(Operands[0])) { + const Loop* NestedLoop = NestedAR->getLoop(); + if (L->getLoopDepth() < NestedLoop->getLoopDepth()) { + std::vector NestedOperands(NestedAR->op_begin(), + NestedAR->op_end()); + SCEVHandle NestedARHandle(NestedAR); + Operands[0] = NestedAR->getStart(); + NestedOperands[0] = getAddRecExpr(Operands, L); + return getAddRecExpr(NestedOperands, NestedLoop); } + } SCEVAddRecExpr *&Result = (*SCEVAddRecExprs)[std::make_pair(L, std::vector(Operands.begin(), @@ -1034,9 +1176,169 @@ SCEVHandle SCEVAddRecExpr::get(std::vector &Operands, return Result; } -SCEVHandle SCEVUnknown::get(Value *V) { +SCEVHandle ScalarEvolution::getSMaxExpr(const SCEVHandle &LHS, + const SCEVHandle &RHS) { + std::vector Ops; + Ops.push_back(LHS); + Ops.push_back(RHS); + return getSMaxExpr(Ops); +} + +SCEVHandle ScalarEvolution::getSMaxExpr(std::vector Ops) { + assert(!Ops.empty() && "Cannot get empty smax!"); + if (Ops.size() == 1) return Ops[0]; + + // Sort by complexity, this groups all similar expression types together. + GroupByComplexity(Ops); + + // If there are any constants, fold them together. + unsigned Idx = 0; + if (SCEVConstant *LHSC = dyn_cast(Ops[0])) { + ++Idx; + assert(Idx < Ops.size()); + while (SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { + // We found two constants, fold them together! + ConstantInt *Fold = ConstantInt::get( + APIntOps::smax(LHSC->getValue()->getValue(), + RHSC->getValue()->getValue())); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast(Ops[0]); + } + + // If we are left with a constant -inf, strip it off. + if (cast(Ops[0])->getValue()->isMinValue(true)) { + Ops.erase(Ops.begin()); + --Idx; + } + } + + if (Ops.size() == 1) return Ops[0]; + + // Find the first SMax + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr) + ++Idx; + + // Check to see if one of the operands is an SMax. If so, expand its operands + // onto our operand list, and recurse to simplify. + if (Idx < Ops.size()) { + bool DeletedSMax = false; + while (SCEVSMaxExpr *SMax = dyn_cast(Ops[Idx])) { + Ops.insert(Ops.end(), SMax->op_begin(), SMax->op_end()); + Ops.erase(Ops.begin()+Idx); + DeletedSMax = true; + } + + if (DeletedSMax) + return getSMaxExpr(Ops); + } + + // Okay, check to see if the same value occurs in the operand list twice. If + // so, delete one. Since we sorted the list, these values are required to + // be adjacent. + for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) + if (Ops[i] == Ops[i+1]) { // X smax Y smax Y --> X smax Y + Ops.erase(Ops.begin()+i, Ops.begin()+i+1); + --i; --e; + } + + if (Ops.size() == 1) return Ops[0]; + + assert(!Ops.empty() && "Reduced smax down to nothing!"); + + // Okay, it looks like we really DO need an smax expr. Check to see if we + // already have one, otherwise create a new one. + std::vector SCEVOps(Ops.begin(), Ops.end()); + SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr, + SCEVOps)]; + if (Result == 0) Result = new SCEVSMaxExpr(Ops); + return Result; +} + +SCEVHandle ScalarEvolution::getUMaxExpr(const SCEVHandle &LHS, + const SCEVHandle &RHS) { + std::vector Ops; + Ops.push_back(LHS); + Ops.push_back(RHS); + return getUMaxExpr(Ops); +} + +SCEVHandle ScalarEvolution::getUMaxExpr(std::vector Ops) { + assert(!Ops.empty() && "Cannot get empty umax!"); + if (Ops.size() == 1) return Ops[0]; + + // Sort by complexity, this groups all similar expression types together. + GroupByComplexity(Ops); + + // If there are any constants, fold them together. + unsigned Idx = 0; + if (SCEVConstant *LHSC = dyn_cast(Ops[0])) { + ++Idx; + assert(Idx < Ops.size()); + while (SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { + // We found two constants, fold them together! + ConstantInt *Fold = ConstantInt::get( + APIntOps::umax(LHSC->getValue()->getValue(), + RHSC->getValue()->getValue())); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast(Ops[0]); + } + + // If we are left with a constant zero, strip it off. + if (cast(Ops[0])->getValue()->isMinValue(false)) { + Ops.erase(Ops.begin()); + --Idx; + } + } + + if (Ops.size() == 1) return Ops[0]; + + // Find the first UMax + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr) + ++Idx; + + // Check to see if one of the operands is a UMax. If so, expand its operands + // onto our operand list, and recurse to simplify. + if (Idx < Ops.size()) { + bool DeletedUMax = false; + while (SCEVUMaxExpr *UMax = dyn_cast(Ops[Idx])) { + Ops.insert(Ops.end(), UMax->op_begin(), UMax->op_end()); + Ops.erase(Ops.begin()+Idx); + DeletedUMax = true; + } + + if (DeletedUMax) + return getUMaxExpr(Ops); + } + + // Okay, check to see if the same value occurs in the operand list twice. If + // so, delete one. Since we sorted the list, these values are required to + // be adjacent. + for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) + if (Ops[i] == Ops[i+1]) { // X umax Y umax Y --> X umax Y + Ops.erase(Ops.begin()+i, Ops.begin()+i+1); + --i; --e; + } + + if (Ops.size() == 1) return Ops[0]; + + assert(!Ops.empty() && "Reduced umax down to nothing!"); + + // Okay, it looks like we really DO need a umax expr. Check to see if we + // already have one, otherwise create a new one. + std::vector SCEVOps(Ops.begin(), Ops.end()); + SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scUMaxExpr, + SCEVOps)]; + if (Result == 0) Result = new SCEVUMaxExpr(Ops); + return Result; +} + +SCEVHandle ScalarEvolution::getUnknown(Value *V) { if (ConstantInt *CI = dyn_cast(V)) - return SCEVConstant::get(CI); + return getConstant(CI); SCEVUnknown *&Result = (*SCEVUnknowns)[V]; if (Result == 0) Result = new SCEVUnknown(V); return Result; @@ -1052,6 +1354,9 @@ SCEVHandle SCEVUnknown::get(Value *V) { /// namespace { struct VISIBILITY_HIDDEN ScalarEvolutionsImpl { + /// SE - A reference to the public ScalarEvolution object. + ScalarEvolution &SE; + /// F - The function we are analyzing. /// Function &F; @@ -1080,8 +1385,8 @@ namespace { std::map ConstantEvolutionLoopExitValue; public: - ScalarEvolutionsImpl(Function &f, LoopInfo &li) - : F(f), LI(li), UnknownValue(new SCEVCouldNotCompute()) {} + ScalarEvolutionsImpl(ScalarEvolution &se, Function &f, LoopInfo &li) + : SE(se), F(f), LI(li), UnknownValue(new SCEVCouldNotCompute()) {} /// getSCEV - Return an existing SCEV if it exists, otherwise analyze the /// expression and create a new one. @@ -1098,6 +1403,7 @@ namespace { void setSCEV(Value *V, const SCEVHandle &H) { bool isNew = Scalars.insert(std::make_pair(V, H)).second; assert(isNew && "This entry already existed!"); + isNew = false; } @@ -1116,10 +1422,10 @@ namespace { /// loop without a loop-invariant iteration count. SCEVHandle getIterationCount(const Loop *L); - /// deleteInstructionFromRecords - This method should be called by the - /// client before it removes an instruction from the program, to make sure + /// deleteValueFromRecords - This method should be called by the + /// client before it removes a value from the program, to make sure /// that no dangling references are left around. - void deleteInstructionFromRecords(Instruction *I); + void deleteValueFromRecords(Value *V); private: /// createSCEV - We know that there is no SCEV for the specified value. @@ -1143,7 +1449,7 @@ namespace { SCEVHandle ComputeIterationCount(const Loop *L); /// ComputeLoadConstantCompareIterationCount - Given an exit condition of - /// 'setcc load X, cst', try to se if we can compute the trip count. + /// 'icmp op load X, cst', try to see if we can compute the trip count. SCEVHandle ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS, const Loop *L, @@ -1169,14 +1475,31 @@ namespace { /// HowManyLessThans - Return the number of times a backedge containing the /// specified less-than comparison will execute. If not computable, return - /// UnknownValue. - SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L); + /// UnknownValue. isSigned specifies whether the less-than is signed. + SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, + bool isSigned, bool trueWhenEqual); + + /// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB + /// (which may not be an immediate predecessor) which has exactly one + /// successor from which BB is reachable, or null if no such block is + /// found. + BasicBlock* getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB); + + /// executesAtLeastOnce - Test whether entry to the loop is protected by + /// a conditional between LHS and RHS. + bool executesAtLeastOnce(const Loop *L, bool isSigned, bool trueWhenEqual, + SCEV *LHS, SCEV *RHS); + + /// potentialInfiniteLoop - Test whether the loop might jump over the exit value + /// due to wrapping. + bool potentialInfiniteLoop(SCEV *Stride, SCEV *RHS, bool isSigned, + bool trueWhenEqual); /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is /// in the header of its containing loop, we know the loop executes a /// constant number of times, and the PHI node is just a recurrence /// involving constants, fold it. - Constant *getConstantEvolutionLoopExitValue(PHINode *PN, uint64_t Its, + Constant *getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& Its, const Loop *L); }; } @@ -1185,13 +1508,32 @@ namespace { // Basic SCEV Analysis and PHI Idiom Recognition Code // -/// deleteInstructionFromRecords - This method should be called by the +/// deleteValueFromRecords - This method should be called by the /// client before it removes an instruction from the program, to make sure /// that no dangling references are left around. -void ScalarEvolutionsImpl::deleteInstructionFromRecords(Instruction *I) { - Scalars.erase(I); - if (PHINode *PN = dyn_cast(I)) - ConstantEvolutionLoopExitValue.erase(PN); +void ScalarEvolutionsImpl::deleteValueFromRecords(Value *V) { + SmallVector Worklist; + + if (Scalars.erase(V)) { + if (PHINode *PN = dyn_cast(V)) + ConstantEvolutionLoopExitValue.erase(PN); + Worklist.push_back(V); + } + + while (!Worklist.empty()) { + Value *VV = Worklist.back(); + Worklist.pop_back(); + + for (Instruction::use_iterator UI = VV->use_begin(), UE = VV->use_end(); + UI != UE; ++UI) { + Instruction *Inst = cast(*UI); + if (Scalars.erase(Inst)) { + if (PHINode *PN = dyn_cast(VV)) + ConstantEvolutionLoopExitValue.erase(PN); + Worklist.push_back(Inst); + } + } + } } @@ -1217,7 +1559,7 @@ ReplaceSymbolicValueWithConcrete(Instruction *I, const SCEVHandle &SymName, if (SI == Scalars.end()) return; SCEVHandle NV = - SI->second->replaceSymbolicValuesWithConcrete(SymName, NewVal); + SI->second->replaceSymbolicValuesWithConcrete(SymName, NewVal, SE); if (NV == SI->second) return; // No change. SI->second = NV; // Update the scalars map! @@ -1242,7 +1584,7 @@ SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN) { unsigned BackEdge = IncomingEdge^1; // While we are analyzing this PHI node, handle its value symbolically. - SCEVHandle SymbolicName = SCEVUnknown::get(PN); + SCEVHandle SymbolicName = SE.getUnknown(PN); assert(Scalars.find(PN) == Scalars.end() && "PHI node already processed?"); Scalars.insert(std::make_pair(PN, SymbolicName)); @@ -1273,7 +1615,7 @@ SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN) { for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) if (i != FoundIndex) Ops.push_back(Add->getOperand(i)); - SCEVHandle Accum = SCEVAddExpr::get(Ops); + SCEVHandle Accum = SE.getAddExpr(Ops); // This is not a valid addrec if the step amount is varying each // loop iteration, but is not itself an addrec in this loop. @@ -1281,7 +1623,7 @@ SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN) { (isa(Accum) && cast(Accum)->getLoop() == L)) { SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge)); - SCEVHandle PHISCEV = SCEVAddRecExpr::get(StartVal, Accum, L); + SCEVHandle PHISCEV = SE.getAddRecExpr(StartVal, Accum, L); // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and update all of the @@ -1303,10 +1645,10 @@ SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN) { // If StartVal = j.start - j.stride, we can use StartVal as the // initial step of the addrec evolution. - if (StartVal == SCEV::getMinusSCEV(AddRec->getOperand(0), - AddRec->getOperand(1))) { + if (StartVal == SE.getMinusSCEV(AddRec->getOperand(0), + AddRec->getOperand(1))) { SCEVHandle PHISCEV = - SCEVAddRecExpr::get(StartVal, AddRec->getOperand(1), L); + SE.getAddRecExpr(StartVal, AddRec->getOperand(1), L); // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and update all of the @@ -1323,123 +1665,218 @@ SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN) { } // If it's not a loop phi, we can't handle it yet. - return SCEVUnknown::get(PN); + return SE.getUnknown(PN); } -/// GetConstantFactor - Determine the largest constant factor that S has. For -/// example, turn {4,+,8} -> 4. (S umod result) should always equal zero. -static uint64_t GetConstantFactor(SCEVHandle S) { - if (SCEVConstant *C = dyn_cast(S)) { - if (uint64_t V = C->getValue()->getZExtValue()) - return V; - else // Zero is a multiple of everything. - return 1ULL << (S->getType()->getPrimitiveSizeInBits()-1); - } +/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is +/// guaranteed to end in (at every loop iteration). It is, at the same time, +/// the minimum number of times S is divisible by 2. For example, given {4,+,8} +/// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S. +static uint32_t GetMinTrailingZeros(SCEVHandle S) { + if (SCEVConstant *C = dyn_cast(S)) + return C->getValue()->getValue().countTrailingZeros(); if (SCEVTruncateExpr *T = dyn_cast(S)) - return GetConstantFactor(T->getOperand()) & - cast(T->getType())->getBitMask(); - if (SCEVZeroExtendExpr *E = dyn_cast(S)) - return GetConstantFactor(E->getOperand()); - + return std::min(GetMinTrailingZeros(T->getOperand()), T->getBitWidth()); + + if (SCEVZeroExtendExpr *E = dyn_cast(S)) { + uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); + return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes; + } + + if (SCEVSignExtendExpr *E = dyn_cast(S)) { + uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); + return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes; + } + if (SCEVAddExpr *A = dyn_cast(S)) { - // The result is the min of all operands. - uint64_t Res = GetConstantFactor(A->getOperand(0)); - for (unsigned i = 1, e = A->getNumOperands(); i != e && Res > 1; ++i) - Res = std::min(Res, GetConstantFactor(A->getOperand(i))); - return Res; + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); + for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); + return MinOpRes; } if (SCEVMulExpr *M = dyn_cast(S)) { - // The result is the product of all the operands. - uint64_t Res = GetConstantFactor(M->getOperand(0)); - for (unsigned i = 1, e = M->getNumOperands(); i != e; ++i) - Res *= GetConstantFactor(M->getOperand(i)); - return Res; + // The result is the sum of all operands results. + uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0)); + uint32_t BitWidth = M->getBitWidth(); + for (unsigned i = 1, e = M->getNumOperands(); + SumOpRes != BitWidth && i != e; ++i) + SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), + BitWidth); + return SumOpRes; } - + if (SCEVAddRecExpr *A = dyn_cast(S)) { - // For now, we just handle linear expressions. - if (A->getNumOperands() == 2) { - // We want the GCD between the start and the stride value. - uint64_t Start = GetConstantFactor(A->getOperand(0)); - if (Start == 1) return 1; - uint64_t Stride = GetConstantFactor(A->getOperand(1)); - return GreatestCommonDivisor64(Start, Stride); - } + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); + for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); + return MinOpRes; } - - // SCEVSDivExpr, SCEVUnknown. - return 1; + + if (SCEVSMaxExpr *M = dyn_cast(S)) { + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); + for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); + return MinOpRes; + } + + if (SCEVUMaxExpr *M = dyn_cast(S)) { + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); + for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); + return MinOpRes; + } + + // SCEVUDivExpr, SCEVUnknown + return 0; } /// createSCEV - We know that there is no SCEV for the specified value. /// Analyze the expression. /// SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { - if (Instruction *I = dyn_cast(V)) { - switch (I->getOpcode()) { - case Instruction::Add: - return SCEVAddExpr::get(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - case Instruction::Mul: - return SCEVMulExpr::get(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - case Instruction::SDiv: - return SCEVSDivExpr::get(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - break; + if (!isa(V->getType())) + return SE.getUnknown(V); + + unsigned Opcode = Instruction::UserOp1; + if (Instruction *I = dyn_cast(V)) + Opcode = I->getOpcode(); + else if (ConstantExpr *CE = dyn_cast(V)) + Opcode = CE->getOpcode(); + else + return SE.getUnknown(V); + + User *U = cast(V); + switch (Opcode) { + case Instruction::Add: + return SE.getAddExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::Mul: + return SE.getMulExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::UDiv: + return SE.getUDivExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::Sub: + return SE.getMinusSCEV(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::Or: + // If the RHS of the Or is a constant, we may have something like: + // X*4+1 which got turned into X*4|1. Handle this as an Add so loop + // optimizations will transparently handle this case. + // + // In order for this transformation to be safe, the LHS must be of the + // form X*(2^n) and the Or constant must be less than 2^n. + if (ConstantInt *CI = dyn_cast(U->getOperand(1))) { + SCEVHandle LHS = getSCEV(U->getOperand(0)); + const APInt &CIVal = CI->getValue(); + if (GetMinTrailingZeros(LHS) >= + (CIVal.getBitWidth() - CIVal.countLeadingZeros())) + return SE.getAddExpr(LHS, getSCEV(U->getOperand(1))); + } + break; + case Instruction::Xor: + if (ConstantInt *CI = dyn_cast(U->getOperand(1))) { + // If the RHS of the xor is a signbit, then this is just an add. + // Instcombine turns add of signbit into xor as a strength reduction step. + if (CI->getValue().isSignBit()) + return SE.getAddExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + + // If the RHS of xor is -1, then this is a not operation. + else if (CI->isAllOnesValue()) + return SE.getNotSCEV(getSCEV(U->getOperand(0))); + } + break; - case Instruction::Sub: - return SCEV::getMinusSCEV(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - case Instruction::Or: - // If the RHS of the Or is a constant, we may have something like: - // X*4+1 which got turned into X*4|1. Handle this as an add so loop - // optimizations will transparently handle this case. - if (ConstantInt *CI = dyn_cast(I->getOperand(1))) { - SCEVHandle LHS = getSCEV(I->getOperand(0)); - uint64_t CommonFact = GetConstantFactor(LHS); - assert(CommonFact && "Common factor should at least be 1!"); - if (CommonFact > CI->getZExtValue()) { - // If the LHS is a multiple that is larger than the RHS, use +. - return SCEVAddExpr::get(LHS, - getSCEV(I->getOperand(1))); - } - } - break; - - case Instruction::Shl: - // Turn shift left of a constant amount into a multiply. - if (ConstantInt *SA = dyn_cast(I->getOperand(1))) { - Constant *X = ConstantInt::get(V->getType(), 1); - X = ConstantExpr::getShl(X, SA); - return SCEVMulExpr::get(getSCEV(I->getOperand(0)), getSCEV(X)); - } - break; + case Instruction::Shl: + // Turn shift left of a constant amount into a multiply. + if (ConstantInt *SA = dyn_cast(U->getOperand(1))) { + uint32_t BitWidth = cast(V->getType())->getBitWidth(); + Constant *X = ConstantInt::get( + APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth))); + return SE.getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X)); + } + break; - case Instruction::Trunc: - return SCEVTruncateExpr::get(getSCEV(I->getOperand(0)), I->getType()); + case Instruction::LShr: + // Turn logical shift right of a constant into a unsigned divide. + if (ConstantInt *SA = dyn_cast(U->getOperand(1))) { + uint32_t BitWidth = cast(V->getType())->getBitWidth(); + Constant *X = ConstantInt::get( + APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth))); + return SE.getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X)); + } + break; - case Instruction::ZExt: - return SCEVZeroExtendExpr::get(getSCEV(I->getOperand(0)), I->getType()); + case Instruction::Trunc: + return SE.getTruncateExpr(getSCEV(U->getOperand(0)), U->getType()); - case Instruction::BitCast: - // BitCasts are no-op casts so we just eliminate the cast. - if (I->getType()->isInteger() && - I->getOperand(0)->getType()->isInteger()) - return getSCEV(I->getOperand(0)); - break; + case Instruction::ZExt: + return SE.getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType()); - case Instruction::PHI: - return createNodeForPHI(cast(I)); + case Instruction::SExt: + return SE.getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType()); - default: // We cannot analyze this expression. - break; + case Instruction::BitCast: + // BitCasts are no-op casts so we just eliminate the cast. + if (U->getType()->isInteger() && + U->getOperand(0)->getType()->isInteger()) + return getSCEV(U->getOperand(0)); + break; + + case Instruction::PHI: + return createNodeForPHI(cast(U)); + + case Instruction::Select: + // This could be a smax or umax that was lowered earlier. + // Try to recover it. + if (ICmpInst *ICI = dyn_cast(U->getOperand(0))) { + Value *LHS = ICI->getOperand(0); + Value *RHS = ICI->getOperand(1); + switch (ICI->getPredicate()) { + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + if (LHS == U->getOperand(1) && RHS == U->getOperand(2)) + return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS)); + else if (LHS == U->getOperand(2) && RHS == U->getOperand(1)) + // ~smax(~x, ~y) == smin(x, y). + return SE.getNotSCEV(SE.getSMaxExpr( + SE.getNotSCEV(getSCEV(LHS)), + SE.getNotSCEV(getSCEV(RHS)))); + break; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + if (LHS == U->getOperand(1) && RHS == U->getOperand(2)) + return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS)); + else if (LHS == U->getOperand(2) && RHS == U->getOperand(1)) + // ~umax(~x, ~y) == umin(x, y) + return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)), + SE.getNotSCEV(getSCEV(RHS)))); + break; + default: + break; + } } + + default: // We cannot analyze this expression. + break; } - return SCEVUnknown::get(V); + return SE.getUnknown(V); } @@ -1472,7 +1909,7 @@ SCEVHandle ScalarEvolutionsImpl::getIterationCount(const Loop *L) { /// will iterate. SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { // If the loop has a non-one exit block count, we can't analyze it. - std::vector ExitBlocks; + SmallVector ExitBlocks; L->getExitBlocks(ExitBlocks); if (ExitBlocks.size() != 1) return UnknownValue; @@ -1516,7 +1953,7 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { ICmpInst *ExitCond = dyn_cast(ExitBr->getCondition()); - // If its not an integer comparison then compute it the hard way. + // If it's not an integer comparison then compute it the hard way. // Note that ICmpInst deals with pointer comparisons too so we must check // the type of the operand. if (ExitCond == 0 || isa(ExitCond->getOperand(0)->getType())) @@ -1549,8 +1986,8 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { // At this point, we would like to compute how many iterations of the // loop the predicate will return true for these inputs. - if (isa(LHS) && !isa(RHS)) { - // If there is a constant, force it into the RHS. + if (LHS->isLoopInvariant(L) && !RHS->isLoopInvariant(L)) { + // If there is a loop-invariant, force it into the RHS. std::swap(LHS, RHS); Cond = ICmpInst::getSwappedPredicate(Cond); } @@ -1572,10 +2009,10 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { ConstantExpr::getBitCast(CompVal, RealTy)); if (CompVal) { // Form the constant range. - ConstantRange CompRange(Cond, CompVal->getValue()); + ConstantRange CompRange( + ICmpInst::makeConstantRange(Cond, CompVal->getValue())); - SCEVHandle Ret = AddRec->getNumIterationsInRange(CompRange, - false /*Always treat as unsigned range*/); + SCEVHandle Ret = AddRec->getNumIterationsInRange(CompRange, SE); if (!isa(Ret)) return Ret; } } @@ -1583,23 +2020,57 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { switch (Cond) { case ICmpInst::ICMP_NE: { // while (X != Y) // Convert to: while (X-Y != 0) - SCEVHandle TC = HowFarToZero(SCEV::getMinusSCEV(LHS, RHS), L); + SCEVHandle TC = HowFarToZero(SE.getMinusSCEV(LHS, RHS), L); if (!isa(TC)) return TC; break; } case ICmpInst::ICMP_EQ: { // Convert to: while (X-Y == 0) // while (X == Y) - SCEVHandle TC = HowFarToNonZero(SCEV::getMinusSCEV(LHS, RHS), L); + SCEVHandle TC = HowFarToNonZero(SE.getMinusSCEV(LHS, RHS), L); if (!isa(TC)) return TC; break; } case ICmpInst::ICMP_SLT: { - SCEVHandle TC = HowManyLessThans(LHS, RHS, L); + SCEVHandle TC = HowManyLessThans(LHS, RHS, L, true, false); if (!isa(TC)) return TC; break; } case ICmpInst::ICMP_SGT: { - SCEVHandle TC = HowManyLessThans(RHS, LHS, L); + SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS), + SE.getNotSCEV(RHS), L, true, false); + if (!isa(TC)) return TC; + break; + } + case ICmpInst::ICMP_ULT: { + SCEVHandle TC = HowManyLessThans(LHS, RHS, L, false, false); + if (!isa(TC)) return TC; + break; + } + case ICmpInst::ICMP_UGT: { + SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS), + SE.getNotSCEV(RHS), L, false, false); + if (!isa(TC)) return TC; + break; + } + case ICmpInst::ICMP_SLE: { + SCEVHandle TC = HowManyLessThans(LHS, RHS, L, true, true); + if (!isa(TC)) return TC; + break; + } + case ICmpInst::ICMP_SGE: { + SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS), + SE.getNotSCEV(RHS), L, true, true); + if (!isa(TC)) return TC; + break; + } + case ICmpInst::ICMP_ULE: { + SCEVHandle TC = HowManyLessThans(LHS, RHS, L, false, true); + if (!isa(TC)) return TC; + break; + } + case ICmpInst::ICMP_UGE: { + SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS), + SE.getNotSCEV(RHS), L, false, true); if (!isa(TC)) return TC; break; } @@ -1619,9 +2090,10 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { } static ConstantInt * -EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, Constant *C) { - SCEVHandle InVal = SCEVConstant::get(cast(C)); - SCEVHandle Val = AddRec->evaluateAtIteration(InVal); +EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, + ScalarEvolution &SE) { + SCEVHandle InVal = SE.getConstant(C); + SCEVHandle Val = AddRec->evaluateAtIteration(InVal, SE); assert(isa(Val) && "Evaluation of SCEV at constant didn't fold correctly?"); return cast(Val)->getValue(); @@ -1662,7 +2134,7 @@ GetAddressedElementFromGlobal(GlobalVariable *GV, } /// ComputeLoadConstantCompareIterationCount - Given an exit condition of -/// 'setcc load X, cst', try to se if we can compute the trip count. +/// 'icmp op load X, cst', try to see if we can compute the trip count. SCEVHandle ScalarEvolutionsImpl:: ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS, const Loop *L, @@ -1713,7 +2185,7 @@ ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS, for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) { ConstantInt *ItCst = ConstantInt::get(IdxExpr->getType(), IterationNum); - ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst); + ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, SE); // Form the GEP offset. Indexes[VarIdxNum] = Val; @@ -1724,14 +2196,14 @@ ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS, // Evaluate the condition for this iteration. Result = ConstantExpr::getICmp(predicate, Result, RHS); if (!isa(Result)) break; // Couldn't decide for sure - if (cast(Result)->getZExtValue() == false) { + if (cast(Result)->getValue().isMinValue()) { #if 0 cerr << "\n***\n*** Computed loop count " << *ItCst << "\n*** From global " << *GV << "*** BB: " << *L->getHeader() << "***\n"; #endif ++NumArrayLenItCounts; - return SCEVConstant::get(ItCst); // Found terminating iteration! + return SE.getConstant(ItCst); // Found terminating iteration! } } return UnknownValue; @@ -1747,7 +2219,7 @@ static bool CanConstantFold(const Instruction *I) { if (const CallInst *CI = dyn_cast(I)) if (const Function *F = CI->getCalledFunction()) - return canConstantFoldCallTo((Function*)F); // FIXME: elim cast + return canConstantFoldCallTo(F); return false; } @@ -1762,13 +2234,14 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { Instruction *I = dyn_cast(V); if (I == 0 || !L->contains(I->getParent())) return 0; - if (PHINode *PN = dyn_cast(I)) + if (PHINode *PN = dyn_cast(I)) { if (L->getHeader() == I->getParent()) return PN; else // We don't currently keep track of the control flow needed to evaluate // PHIs, so we cannot handle PHIs inside of loops. return 0; + } // If we won't be able to constant fold this expression even if the operands // are constants, return early. @@ -1798,8 +2271,6 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { /// reason, return null. static Constant *EvaluateExpression(Value *V, Constant *PHIVal) { if (isa(V)) return PHIVal; - if (GlobalValue *GV = dyn_cast(V)) - return GV; if (Constant *C = dyn_cast(V)) return C; Instruction *I = cast(V); @@ -1811,7 +2282,12 @@ static Constant *EvaluateExpression(Value *V, Constant *PHIVal) { if (Operands[i] == 0) return 0; } - return ConstantFoldInstOperands(I, &Operands[0], Operands.size()); + if (const CmpInst *CI = dyn_cast(I)) + return ConstantFoldCompareInstOperands(CI->getPredicate(), + &Operands[0], Operands.size()); + else + return ConstantFoldInstOperands(I->getOpcode(), I->getType(), + &Operands[0], Operands.size()); } /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is @@ -1819,13 +2295,13 @@ static Constant *EvaluateExpression(Value *V, Constant *PHIVal) { /// constant number of times, and the PHI node is just a recurrence /// involving constants, fold it. Constant *ScalarEvolutionsImpl:: -getConstantEvolutionLoopExitValue(PHINode *PN, uint64_t Its, const Loop *L) { +getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& Its, const Loop *L){ std::map::iterator I = ConstantEvolutionLoopExitValue.find(PN); if (I != ConstantEvolutionLoopExitValue.end()) return I->second; - if (Its > MaxBruteForceIterations) + if (Its.ugt(APInt(Its.getBitWidth(),MaxBruteForceIterations))) return ConstantEvolutionLoopExitValue[PN] = 0; // Not going to evaluate it. Constant *&RetVal = ConstantEvolutionLoopExitValue[PN]; @@ -1845,11 +2321,11 @@ getConstantEvolutionLoopExitValue(PHINode *PN, uint64_t Its, const Loop *L) { return RetVal = 0; // Not derived from same PHI. // Execute the loop symbolically to determine the exit value. - unsigned IterationNum = 0; - unsigned NumIterations = Its; - if (NumIterations != Its) - return RetVal = 0; // More than 2^32 iterations?? + if (Its.getActiveBits() >= 32) + return RetVal = 0; // More than 2^32-1 iterations?? Not doing it! + unsigned NumIterations = Its.getZExtValue(); // must be in range + unsigned IterationNum = 0; for (Constant *PHIVal = StartCST; ; ++IterationNum) { if (IterationNum == NumIterations) return RetVal = PHIVal; // Got exit value! @@ -1899,10 +2375,10 @@ ComputeIterationCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) { // Couldn't symbolically evaluate. if (!CondVal) return UnknownValue; - if (CondVal->getZExtValue() == uint64_t(ExitWhen)) { + if (CondVal->getValue() == uint64_t(ExitWhen)) { ConstantEvolutionLoopExitValue[PN] = PHIVal; ++NumBruteForceTripCountsComputed; - return SCEVConstant::get(ConstantInt::get(Type::Int32Ty, IterationNum)); + return SE.getConstant(ConstantInt::get(Type::Int32Ty, IterationNum)); } // Compute the value of the PHI node for the next iteration. @@ -1924,7 +2400,7 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { if (isa(V)) return V; - // If this instruction is evolves from a constant-evolving PHI, compute the + // If this instruction is evolved from a constant-evolving PHI, compute the // exit value from the loop without using SCEVs. if (SCEVUnknown *SU = dyn_cast(V)) { if (Instruction *I = dyn_cast(SU->getValue())) { @@ -1941,9 +2417,9 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { // this is a constant evolving PHI node, get the final value at // the specified iteration number. Constant *RV = getConstantEvolutionLoopExitValue(PN, - ICC->getValue()->getZExtValue(), + ICC->getValue()->getValue(), LI); - if (RV) return SCEVUnknown::get(RV); + if (RV) return SE.getUnknown(RV); } } @@ -1959,6 +2435,11 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { if (Constant *C = dyn_cast(Op)) { Operands.push_back(C); } else { + // If any of the operands is non-constant and if they are + // non-integer, don't even try to analyze them with scev techniques. + if (!isa(Op->getType())) + return V; + SCEVHandle OpV = getSCEVAtScope(getSCEV(Op), L); if (SCEVConstant *SC = dyn_cast(OpV)) Operands.push_back(ConstantExpr::getIntegerCast(SC->getValue(), @@ -1976,8 +2457,15 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { } } } - Constant *C =ConstantFoldInstOperands(I, &Operands[0], Operands.size()); - return SCEVUnknown::get(C); + + Constant *C; + if (const CmpInst *CI = dyn_cast(I)) + C = ConstantFoldCompareInstOperands(CI->getPredicate(), + &Operands[0], Operands.size()); + else + C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), + &Operands[0], Operands.size()); + return SE.getUnknown(C); } } @@ -2003,23 +2491,28 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { NewOps.push_back(OpAtScope); } if (isa(Comm)) - return SCEVAddExpr::get(NewOps); - assert(isa(Comm) && "Only know about add and mul!"); - return SCEVMulExpr::get(NewOps); + return SE.getAddExpr(NewOps); + if (isa(Comm)) + return SE.getMulExpr(NewOps); + if (isa(Comm)) + return SE.getSMaxExpr(NewOps); + if (isa(Comm)) + return SE.getUMaxExpr(NewOps); + assert(0 && "Unknown commutative SCEV type!"); } } // If we got here, all operands are loop invariant. return Comm; } - if (SCEVSDivExpr *Div = dyn_cast(V)) { + if (SCEVUDivExpr *Div = dyn_cast(V)) { SCEVHandle LHS = getSCEVAtScope(Div->getLHS(), L); if (LHS == UnknownValue) return LHS; SCEVHandle RHS = getSCEVAtScope(Div->getRHS(), L); if (RHS == UnknownValue) return RHS; if (LHS == Div->getLHS() && RHS == Div->getRHS()) return Div; // must be loop invariant - return SCEVSDivExpr::get(LHS, RHS); + return SE.getUDivExpr(LHS, RHS); } // If this is a loop recurrence for a loop that does not contain L, then we @@ -2030,18 +2523,9 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { // loop iterates. Compute this now. SCEVHandle IterationCount = getIterationCount(AddRec->getLoop()); if (IterationCount == UnknownValue) return UnknownValue; - IterationCount = getTruncateOrZeroExtend(IterationCount, - AddRec->getType()); - - // If the value is affine, simplify the expression evaluation to just - // Start + Step*IterationCount. - if (AddRec->isAffine()) - return SCEVAddExpr::get(AddRec->getStart(), - SCEVMulExpr::get(IterationCount, - AddRec->getOperand(1))); - - // Otherwise, evaluate it the hard way. - return AddRec->evaluateAtIteration(IterationCount); + + // Then, evaluate the AddRec. + return AddRec->evaluateAtIteration(IterationCount, SE); } return UnknownValue; } @@ -2050,65 +2534,113 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { return UnknownValue; } +/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the +/// following equation: +/// +/// A * X = B (mod N) +/// +/// where N = 2^BW and BW is the common bit width of A and B. The signedness of +/// A and B isn't important. +/// +/// If the equation does not have a solution, SCEVCouldNotCompute is returned. +static SCEVHandle SolveLinEquationWithOverflow(const APInt &A, const APInt &B, + ScalarEvolution &SE) { + uint32_t BW = A.getBitWidth(); + assert(BW == B.getBitWidth() && "Bit widths must be the same."); + assert(A != 0 && "A must be non-zero."); + + // 1. D = gcd(A, N) + // + // The gcd of A and N may have only one prime factor: 2. The number of + // trailing zeros in A is its multiplicity + uint32_t Mult2 = A.countTrailingZeros(); + // D = 2^Mult2 + + // 2. Check if B is divisible by D. + // + // B is divisible by D if and only if the multiplicity of prime factor 2 for B + // is not less than multiplicity of this prime factor for D. + if (B.countTrailingZeros() < Mult2) + return new SCEVCouldNotCompute(); + + // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic + // modulo (N / D). + // + // (N / D) may need BW+1 bits in its representation. Hence, we'll use this + // bit width during computations. + APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D + APInt Mod(BW + 1, 0); + Mod.set(BW - Mult2); // Mod = N / D + APInt I = AD.multiplicativeInverse(Mod); + + // 4. Compute the minimum unsigned root of the equation: + // I * (B / D) mod (N / D) + APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod); + + // The result is guaranteed to be less than 2^BW so we may truncate it to BW + // bits. + return SE.getConstant(Result.trunc(BW)); +} /// SolveQuadraticEquation - Find the roots of the quadratic equation for the /// given quadratic chrec {L,+,M,+,N}. This returns either the two roots (which /// might be the same) or two SCEVCouldNotCompute objects. /// static std::pair -SolveQuadraticEquation(const SCEVAddRecExpr *AddRec) { +SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); - SCEVConstant *L = dyn_cast(AddRec->getOperand(0)); - SCEVConstant *M = dyn_cast(AddRec->getOperand(1)); - SCEVConstant *N = dyn_cast(AddRec->getOperand(2)); + SCEVConstant *LC = dyn_cast(AddRec->getOperand(0)); + SCEVConstant *MC = dyn_cast(AddRec->getOperand(1)); + SCEVConstant *NC = dyn_cast(AddRec->getOperand(2)); // We currently can only solve this if the coefficients are constants. - if (!L || !M || !N) { + if (!LC || !MC || !NC) { SCEV *CNC = new SCEVCouldNotCompute(); return std::make_pair(CNC, CNC); } - Constant *C = L->getValue(); - Constant *Two = ConstantInt::get(C->getType(), 2); - - // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C - // The B coefficient is M-N/2 - Constant *B = ConstantExpr::getSub(M->getValue(), - ConstantExpr::getSDiv(N->getValue(), - Two)); - // The A coefficient is N/2 - Constant *A = ConstantExpr::getSDiv(N->getValue(), Two); - - // Compute the B^2-4ac term. - Constant *SqrtTerm = - ConstantExpr::getMul(ConstantInt::get(C->getType(), 4), - ConstantExpr::getMul(A, C)); - SqrtTerm = ConstantExpr::getSub(ConstantExpr::getMul(B, B), SqrtTerm); - - // Compute floor(sqrt(B^2-4ac)) - uint64_t SqrtValV = cast(SqrtTerm)->getZExtValue(); - uint64_t SqrtValV2 = (uint64_t)sqrt((double)SqrtValV); - // The square root might not be precise for arbitrary 64-bit integer - // values. Do some sanity checks to ensure it's correct. - if (SqrtValV2*SqrtValV2 > SqrtValV || - (SqrtValV2+1)*(SqrtValV2+1) <= SqrtValV) { - SCEV *CNC = new SCEVCouldNotCompute(); - return std::make_pair(CNC, CNC); - } - - ConstantInt *SqrtVal = ConstantInt::get(Type::Int64Ty, SqrtValV2); - SqrtTerm = ConstantExpr::getTruncOrBitCast(SqrtVal, SqrtTerm->getType()); + uint32_t BitWidth = LC->getValue()->getValue().getBitWidth(); + const APInt &L = LC->getValue()->getValue(); + const APInt &M = MC->getValue()->getValue(); + const APInt &N = NC->getValue()->getValue(); + APInt Two(BitWidth, 2); + APInt Four(BitWidth, 4); + + { + using namespace APIntOps; + const APInt& C = L; + // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C + // The B coefficient is M-N/2 + APInt B(M); + B -= sdiv(N,Two); + + // The A coefficient is N/2 + APInt A(N.sdiv(Two)); + + // Compute the B^2-4ac term. + APInt SqrtTerm(B); + SqrtTerm *= B; + SqrtTerm -= Four * (A * C); + + // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest + // integer value or else APInt::sqrt() will assert. + APInt SqrtVal(SqrtTerm.sqrt()); + + // Compute the two solutions for the quadratic formula. + // The divisions must be performed as signed divisions. + APInt NegB(-B); + APInt TwoA( A << 1 ); + if (TwoA.isMinValue()) { + SCEV *CNC = new SCEVCouldNotCompute(); + return std::make_pair(CNC, CNC); + } - Constant *NegB = ConstantExpr::getNeg(B); - Constant *TwoA = ConstantExpr::getMul(A, Two); + ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA)); + ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA)); - // The divisions must be performed as signed divisions. - Constant *Solution1 = - ConstantExpr::getSDiv(ConstantExpr::getAdd(NegB, SqrtTerm), TwoA); - Constant *Solution2 = - ConstantExpr::getSDiv(ConstantExpr::getSub(NegB, SqrtTerm), TwoA); - return std::make_pair(SCEVUnknown::get(Solution1), - SCEVUnknown::get(Solution2)); + return std::make_pair(SE.getConstant(Solution1), + SE.getConstant(Solution2)); + } // end APIntOps namespace } /// HowFarToZero - Return the number of times a backedge comparing the specified @@ -2117,7 +2649,7 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) { // If the value is a constant if (SCEVConstant *C = dyn_cast(V)) { // If the value is already zero, the branch will execute zero times. - if (C->getValue()->isNullValue()) return C; + if (C->getValue()->isZero()) return C; return UnknownValue; // Otherwise it will loop infinitely. } @@ -2126,41 +2658,41 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) { return UnknownValue; if (AddRec->isAffine()) { - // If this is an affine expression the execution count of this branch is - // equal to: + // If this is an affine expression, the execution count of this branch is + // the minimum unsigned root of the following equation: + // + // Start + Step*N = 0 (mod 2^BW) + // + // equivalent to: // - // (0 - Start/Step) iff Start % Step == 0 + // Step*N = -Start (mod 2^BW) // + // where BW is the common bit width of Start and Step. + // Get the initial value for the loop. SCEVHandle Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop()); if (isa(Start)) return UnknownValue; - SCEVHandle Step = AddRec->getOperand(1); - Step = getSCEVAtScope(Step, L->getParentLoop()); + SCEVHandle Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop()); - // Figure out if Start % Step == 0. - // FIXME: We should add DivExpr and RemExpr operations to our AST. if (SCEVConstant *StepC = dyn_cast(Step)) { - if (StepC->getValue()->equalsInt(1)) // N % 1 == 0 - return SCEV::getNegativeSCEV(Start); // 0 - Start/1 == -Start - if (StepC->getValue()->isAllOnesValue()) // N % -1 == 0 - return Start; // 0 - Start/-1 == Start - - // Check to see if Start is divisible by SC with no remainder. - if (SCEVConstant *StartC = dyn_cast(Start)) { - ConstantInt *StartCC = StartC->getValue(); - Constant *StartNegC = ConstantExpr::getNeg(StartCC); - Constant *Rem = ConstantExpr::getSRem(StartNegC, StepC->getValue()); - if (Rem->isNullValue()) { - Constant *Result =ConstantExpr::getSDiv(StartNegC,StepC->getValue()); - return SCEVUnknown::get(Result); - } - } + // For now we handle only constant steps. + + // First, handle unitary steps. + if (StepC->getValue()->equalsInt(1)) // 1*N = -Start (mod 2^BW), so: + return SE.getNegativeSCEV(Start); // N = -Start (as unsigned) + if (StepC->getValue()->isAllOnesValue()) // -1*N = -Start (mod 2^BW), so: + return Start; // N = Start (as unsigned) + + // Then, try to solve the above equation provided that Start is constant. + if (SCEVConstant *StartC = dyn_cast(Start)) + return SolveLinEquationWithOverflow(StepC->getValue()->getValue(), + -StartC->getValue()->getValue(),SE); } } else if (AddRec->isQuadratic() && AddRec->getType()->isInteger()) { // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of // the quadratic equation to solve it. - std::pair Roots = SolveQuadraticEquation(AddRec); + std::pair Roots = SolveQuadraticEquation(AddRec, SE); SCEVConstant *R1 = dyn_cast(Roots.first); SCEVConstant *R2 = dyn_cast(Roots.second); if (R1) { @@ -2178,10 +2710,9 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) { // We can only use this value if the chrec ends up with an exact zero // value at this index. When solving for "X*X != 5", for example, we // should not accept a root of 2. - SCEVHandle Val = AddRec->evaluateAtIteration(R1); - if (SCEVConstant *EvalVal = dyn_cast(Val)) - if (EvalVal->getValue()->isNullValue()) - return R1; // We found a quadratic root! + SCEVHandle Val = AddRec->evaluateAtIteration(R1, SE); + if (Val->isZero()) + return R1; // We found a quadratic root! } } } @@ -2200,11 +2731,8 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) { // If the value is a constant, check to see if it is known to be non-zero // already. If so, the backedge will execute zero times. if (SCEVConstant *C = dyn_cast(V)) { - Constant *Zero = Constant::getNullValue(C->getValue()->getType()); - Constant *NonZero = - ConstantExpr::getICmp(ICmpInst::ICMP_NE, C->getValue(), Zero); - if (NonZero == ConstantInt::getTrue()) - return getSCEV(Zero); + if (!C->getValue()->isNullValue()) + return SE.getIntegerSCEV(0, C->getType()); return UnknownValue; // Otherwise it will loop infinitely. } @@ -2213,11 +2741,152 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) { return UnknownValue; } +/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB +/// (which may not be an immediate predecessor) which has exactly one +/// successor from which BB is reachable, or null if no such block is +/// found. +/// +BasicBlock * +ScalarEvolutionsImpl::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { + // If the block has a unique predecessor, the predecessor must have + // no other successors from which BB is reachable. + if (BasicBlock *Pred = BB->getSinglePredecessor()) + return Pred; + + // A loop's header is defined to be a block that dominates the loop. + // If the loop has a preheader, it must be a block that has exactly + // one successor that can reach BB. This is slightly more strict + // than necessary, but works if critical edges are split. + if (Loop *L = LI.getLoopFor(BB)) + return L->getLoopPreheader(); + + return 0; +} + +/// executesAtLeastOnce - Test whether entry to the loop is protected by +/// a conditional between LHS and RHS. +bool ScalarEvolutionsImpl::executesAtLeastOnce(const Loop *L, bool isSigned, + bool trueWhenEqual, + SCEV *LHS, SCEV *RHS) { + BasicBlock *Preheader = L->getLoopPreheader(); + BasicBlock *PreheaderDest = L->getHeader(); + + // Starting at the preheader, climb up the predecessor chain, as long as + // there are predecessors that can be found that have unique successors + // leading to the original header. + for (; Preheader; + PreheaderDest = Preheader, + Preheader = getPredecessorWithUniqueSuccessorForBB(Preheader)) { + + BranchInst *LoopEntryPredicate = + dyn_cast(Preheader->getTerminator()); + if (!LoopEntryPredicate || + LoopEntryPredicate->isUnconditional()) + continue; + + ICmpInst *ICI = dyn_cast(LoopEntryPredicate->getCondition()); + if (!ICI) continue; + + // Now that we found a conditional branch that dominates the loop, check to + // see if it is the comparison we are looking for. + Value *PreCondLHS = ICI->getOperand(0); + Value *PreCondRHS = ICI->getOperand(1); + ICmpInst::Predicate Cond; + if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest) + Cond = ICI->getPredicate(); + else + Cond = ICI->getInversePredicate(); + + switch (Cond) { + case ICmpInst::ICMP_UGT: + if (isSigned || trueWhenEqual) continue; + std::swap(PreCondLHS, PreCondRHS); + Cond = ICmpInst::ICMP_ULT; + break; + case ICmpInst::ICMP_SGT: + if (!isSigned || trueWhenEqual) continue; + std::swap(PreCondLHS, PreCondRHS); + Cond = ICmpInst::ICMP_SLT; + break; + case ICmpInst::ICMP_ULT: + if (isSigned || trueWhenEqual) continue; + break; + case ICmpInst::ICMP_SLT: + if (!isSigned || trueWhenEqual) continue; + break; + case ICmpInst::ICMP_UGE: + if (isSigned || !trueWhenEqual) continue; + std::swap(PreCondLHS, PreCondRHS); + Cond = ICmpInst::ICMP_ULE; + break; + case ICmpInst::ICMP_SGE: + if (!isSigned || !trueWhenEqual) continue; + std::swap(PreCondLHS, PreCondRHS); + Cond = ICmpInst::ICMP_SLE; + break; + case ICmpInst::ICMP_ULE: + if (isSigned || !trueWhenEqual) continue; + break; + case ICmpInst::ICMP_SLE: + if (!isSigned || !trueWhenEqual) continue; + break; + default: + continue; + } + + if (!PreCondLHS->getType()->isInteger()) continue; + + SCEVHandle PreCondLHSSCEV = getSCEV(PreCondLHS); + SCEVHandle PreCondRHSSCEV = getSCEV(PreCondRHS); + if ((LHS == PreCondLHSSCEV && RHS == PreCondRHSSCEV) || + (LHS == SE.getNotSCEV(PreCondRHSSCEV) && + RHS == SE.getNotSCEV(PreCondLHSSCEV))) + return true; + } + + return false; +} + +/// potentialInfiniteLoop - Test whether the loop might jump over the exit value +/// due to wrapping around 2^n. +bool ScalarEvolutionsImpl::potentialInfiniteLoop(SCEV *Stride, SCEV *RHS, + bool isSigned, bool trueWhenEqual) { + // Return true when the distance from RHS to maxint > Stride. + + if (!isa(Stride)) + return true; + SCEVConstant *SC = cast(Stride); + + if (SC->getValue()->isZero()) + return true; + if (!trueWhenEqual && SC->getValue()->isOne()) + return false; + + if (!isa(RHS)) + return true; + SCEVConstant *R = cast(RHS); + + if (isSigned) + return true; // XXX: because we don't have an sdiv scev. + + // If negative, it wraps around every iteration, but we don't care about that. + APInt S = SC->getValue()->getValue().abs(); + + APInt Dist = APInt::getMaxValue(R->getValue()->getBitWidth()) - + R->getValue()->getValue(); + + if (trueWhenEqual) + return !S.ult(Dist); + else + return !S.ule(Dist); +} + /// HowManyLessThans - Return the number of times a backedge containing the /// specified less-than comparison will execute. If not computable, return /// UnknownValue. SCEVHandle ScalarEvolutionsImpl:: -HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L) { +HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, + bool isSigned, bool trueWhenEqual) { // Only handle: "ADDREC < LoopInvariant". if (!RHS->isLoopInvariant(L)) return UnknownValue; @@ -2226,83 +2895,50 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L) { return UnknownValue; if (AddRec->isAffine()) { - // FORNOW: We only support unit strides. - SCEVHandle One = SCEVUnknown::getIntegerSCEV(1, RHS->getType()); - if (AddRec->getOperand(1) != One) + SCEVHandle Stride = AddRec->getOperand(1); + if (potentialInfiniteLoop(Stride, RHS, isSigned, trueWhenEqual)) return UnknownValue; - // The number of iterations for "[n,+,1] < m", is m-n. However, we don't - // know that m is >= n on input to the loop. If it is, the condition return - // true zero times. What we really should return, for full generality, is - // SMAX(0, m-n). Since we cannot check this, we will instead check for a - // canonical loop form: most do-loops will have a check that dominates the - // loop, that only enters the loop if [n-1]= n. + // We know the LHS is of the form {n,+,s} and the RHS is some loop-invariant + // m. So, we count the number of iterations in which {n,+,s} < m is true. + // Note that we cannot simply return max(m-n,0)/s because it's not safe to + // treat m-n as signed nor unsigned due to overflow possibility. - // Search for the check. - BasicBlock *Preheader = L->getLoopPreheader(); - BasicBlock *PreheaderDest = L->getHeader(); - if (Preheader == 0) return UnknownValue; + // First, we get the value of the LHS in the first iteration: n + SCEVHandle Start = AddRec->getOperand(0); - BranchInst *LoopEntryPredicate = - dyn_cast(Preheader->getTerminator()); - if (!LoopEntryPredicate) return UnknownValue; - - // This might be a critical edge broken out. If the loop preheader ends in - // an unconditional branch to the loop, check to see if the preheader has a - // single predecessor, and if so, look for its terminator. - while (LoopEntryPredicate->isUnconditional()) { - PreheaderDest = Preheader; - Preheader = Preheader->getSinglePredecessor(); - if (!Preheader) return UnknownValue; // Multiple preds. - - LoopEntryPredicate = - dyn_cast(Preheader->getTerminator()); - if (!LoopEntryPredicate) return UnknownValue; - } + SCEVHandle One = SE.getIntegerSCEV(1, RHS->getType()); - // Now that we found a conditional branch that dominates the loop, check to - // see if it is the comparison we are looking for. - if (ICmpInst *ICI = dyn_cast(LoopEntryPredicate->getCondition())){ - Value *PreCondLHS = ICI->getOperand(0); - Value *PreCondRHS = ICI->getOperand(1); - ICmpInst::Predicate Cond; - if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest) - Cond = ICI->getPredicate(); - else - Cond = ICI->getInversePredicate(); - - switch (Cond) { - case ICmpInst::ICMP_UGT: - std::swap(PreCondLHS, PreCondRHS); - Cond = ICmpInst::ICMP_ULT; - break; - case ICmpInst::ICMP_SGT: - std::swap(PreCondLHS, PreCondRHS); - Cond = ICmpInst::ICMP_SLT; - break; - default: break; - } + // Assuming that the loop will run at least once, we know that it will + // run (m-n)/s times. + SCEVHandle End = RHS; - if (Cond == ICmpInst::ICMP_SLT) { - if (PreCondLHS->getType()->isInteger()) { - if (RHS != getSCEV(PreCondRHS)) - return UnknownValue; // Not a comparison against 'm'. + if (!executesAtLeastOnce(L, isSigned, trueWhenEqual, + SE.getMinusSCEV(Start, One), RHS)) { + // If not, we get the value of the LHS in the first iteration in which + // the above condition doesn't hold. This equals to max(m,n). + End = isSigned ? SE.getSMaxExpr(RHS, Start) + : SE.getUMaxExpr(RHS, Start); + } - if (SCEV::getMinusSCEV(AddRec->getOperand(0), One) - != getSCEV(PreCondLHS)) - return UnknownValue; // Not a comparison against 'n-1'. - } - else return UnknownValue; - } else if (Cond == ICmpInst::ICMP_ULT) - return UnknownValue; + // If the expression is less-than-or-equal to, we need to extend the + // loop by one iteration. + // + // The loop won't actually run (m-n)/s times because the loop iterations + // won't divide evenly. For example, if you have {2,+,5} u< 10 the + // division would equal one, but the loop runs twice putting the + // induction variable at 12. + + if (!trueWhenEqual) + // (Stride - 1) is correct only because we know it's unsigned. + // What we really want is to decrease the magnitude of Stride by one. + Start = SE.getMinusSCEV(Start, SE.getMinusSCEV(Stride, One)); + else + Start = SE.getMinusSCEV(Start, Stride); - // cerr << "Computed Loop Trip Count as: " - // << // *SCEV::getMinusSCEV(RHS, AddRec->getOperand(0)) << "\n"; - return SCEV::getMinusSCEV(RHS, AddRec->getOperand(0)); - } - else - return UnknownValue; + // Finally, we subtract these two values to get the number of times the + // backedge is executed: max(m,n)-n. + return SE.getUDivExpr(SE.getMinusSCEV(End, Start), Stride); } return UnknownValue; @@ -2313,20 +2949,20 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L) { /// this is that it returns the first iteration number where the value is not in /// the condition, thus computing the exit count. If the iteration count can't /// be computed, an instance of SCEVCouldNotCompute is returned. -SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, - bool isSigned) const { +SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, + ScalarEvolution &SE) const { if (Range.isFullSet()) // Infinite loop. return new SCEVCouldNotCompute(); // If the start is a non-zero constant, shift the range to simplify things. if (SCEVConstant *SC = dyn_cast(getStart())) - if (!SC->getValue()->isNullValue()) { + if (!SC->getValue()->isZero()) { std::vector Operands(op_begin(), op_end()); - Operands[0] = SCEVUnknown::getIntegerSCEV(0, SC->getType()); - SCEVHandle Shifted = SCEVAddRecExpr::get(Operands, getLoop()); + Operands[0] = SE.getIntegerSCEV(0, SC->getType()); + SCEVHandle Shifted = SE.getAddRecExpr(Operands, getLoop()); if (SCEVAddRecExpr *ShiftedAddRec = dyn_cast(Shifted)) return ShiftedAddRec->getNumIterationsInRange( - Range.subtract(SC->getValue()->getValue()),isSigned); + Range.subtract(SC->getValue()->getValue()), SE); // This is strange and shouldn't happen. return new SCEVCouldNotCompute(); } @@ -2343,52 +2979,50 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // First check to see if the range contains zero. If not, the first // iteration exits. - if (!Range.contains(APInt(getBitWidth(),0), isSigned)) - return SCEVConstant::get(ConstantInt::get(getType(),0)); + if (!Range.contains(APInt(getBitWidth(),0))) + return SE.getConstant(ConstantInt::get(getType(),0)); if (isAffine()) { // If this is an affine expression then we have this situation: // Solve {0,+,A} in Range === Ax in Range - // Since we know that zero is in the range, we know that the upper value of - // the range must be the first possible exit value. Also note that we - // already checked for a full range. - const APInt &Upper = Range.getUpper(); - APInt A = cast(getOperand(1))->getValue()->getValue(); + // We know that zero is in the range. If A is positive then we know that + // the upper value of the range must be the first possible exit value. + // If A is negative then the lower of the range is the last possible loop + // value. Also note that we already checked for a full range. APInt One(getBitWidth(),1); + APInt A = cast(getOperand(1))->getValue()->getValue(); + APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower(); - // The exit value should be (Upper+A-1)/A. - APInt ExitVal(Upper); - if (A != One) - ExitVal = (Upper + A - One).sdiv(A); - ConstantInt *ExitValue = ConstantInt::get(getType(), ExitVal); + // The exit value should be (End+A)/A. + APInt ExitVal = (End + A).udiv(A); + ConstantInt *ExitValue = ConstantInt::get(ExitVal); // Evaluate at the exit value. If we really did fall out of the valid // range, then we computed our trip count, otherwise wrap around or other // things must have happened. - ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue); - if (Range.contains(Val->getValue(), isSigned)) + ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE); + if (Range.contains(Val->getValue())) return new SCEVCouldNotCompute(); // Something strange happened // Ensure that the previous value is in the range. This is a sanity check. assert(Range.contains( EvaluateConstantChrecAtConstant(this, - ConstantInt::get(getType(), ExitVal - One))->getValue(), isSigned) && + ConstantInt::get(ExitVal - One), SE)->getValue()) && "Linear scev computation is off in a bad way!"); - return SCEVConstant::get(cast(ExitValue)); + return SE.getConstant(ExitValue); } else if (isQuadratic()) { // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the // quadratic equation to solve it. To do this, we must frame our problem in // terms of figuring out when zero is crossed, instead of when // Range.getUpper() is crossed. std::vector NewOps(op_begin(), op_end()); - NewOps[0] = SCEV::getNegativeSCEV(SCEVUnknown::get( - ConstantInt::get(getType(), Range.getUpper()))); - SCEVHandle NewAddRec = SCEVAddRecExpr::get(NewOps, getLoop()); + NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper())); + SCEVHandle NewAddRec = SE.getAddRecExpr(NewOps, getLoop()); // Next, solve the constructed addrec std::pair Roots = - SolveQuadraticEquation(cast(NewAddRec)); + SolveQuadraticEquation(cast(NewAddRec), SE); SCEVConstant *R1 = dyn_cast(Roots.first); SCEVConstant *R2 = dyn_cast(Roots.second); if (R1) { @@ -2403,55 +3037,29 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // not be in the range, but the previous one should be. When solving // for "X*X < 5", for example, we should not return a root of 2. ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this, - R1->getValue()); - if (Range.contains(R1Val->getValue(), isSigned)) { + R1->getValue(), + SE); + if (Range.contains(R1Val->getValue())) { // The next iteration must be out of the range... - Constant *NextVal = - ConstantExpr::getAdd(R1->getValue(), - ConstantInt::get(R1->getType(), 1)); + ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()+1); - R1Val = EvaluateConstantChrecAtConstant(this, NextVal); - if (!Range.contains(R1Val->getValue(), isSigned)) - return SCEVUnknown::get(NextVal); + R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); + if (!Range.contains(R1Val->getValue())) + return SE.getConstant(NextVal); return new SCEVCouldNotCompute(); // Something strange happened } // If R1 was not in the range, then it is a good return value. Make // sure that R1-1 WAS in the range though, just in case. - Constant *NextVal = - ConstantExpr::getSub(R1->getValue(), - ConstantInt::get(R1->getType(), 1)); - R1Val = EvaluateConstantChrecAtConstant(this, NextVal); - if (Range.contains(R1Val->getValue(), isSigned)) + ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()-1); + R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); + if (Range.contains(R1Val->getValue())) return R1; return new SCEVCouldNotCompute(); // Something strange happened } } } - // Fallback, if this is a general polynomial, figure out the progression - // through brute force: evaluate until we find an iteration that fails the - // test. This is likely to be slow, but getting an accurate trip count is - // incredibly important, we will be able to simplify the exit test a lot, and - // we are almost guaranteed to get a trip count in this case. - ConstantInt *TestVal = ConstantInt::get(getType(), 0); - ConstantInt *One = ConstantInt::get(getType(), 1); - ConstantInt *EndVal = TestVal; // Stop when we wrap around. - do { - ++NumBruteForceEvaluations; - SCEVHandle Val = evaluateAtIteration(SCEVConstant::get(TestVal)); - if (!isa(Val)) // This shouldn't happen. - return new SCEVCouldNotCompute(); - - // Check to see if we found the value! - if (!Range.contains(cast(Val)->getValue()->getValue(), - isSigned)) - return SCEVConstant::get(TestVal); - - // Increment to test the next index. - TestVal = cast(ConstantExpr::getAdd(TestVal, One)); - } while (TestVal != EndVal); - return new SCEVCouldNotCompute(); } @@ -2462,7 +3070,7 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, //===----------------------------------------------------------------------===// bool ScalarEvolution::runOnFunction(Function &F) { - Impl = new ScalarEvolutionsImpl(F, getAnalysis()); + Impl = new ScalarEvolutionsImpl(*this, F, getAnalysis()); return false; } @@ -2506,8 +3114,8 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) const { return ((ScalarEvolutionsImpl*)Impl)->getSCEVAtScope(getSCEV(V), L); } -void ScalarEvolution::deleteInstructionFromRecords(Instruction *I) const { - return ((ScalarEvolutionsImpl*)Impl)->deleteInstructionFromRecords(I); +void ScalarEvolution::deleteValueFromRecords(Value *V) const { + return ((ScalarEvolutionsImpl*)Impl)->deleteValueFromRecords(V); } static void PrintLoopInfo(std::ostream &OS, const ScalarEvolution *SE, @@ -2516,20 +3124,20 @@ static void PrintLoopInfo(std::ostream &OS, const ScalarEvolution *SE, for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I) PrintLoopInfo(OS, SE, *I); - cerr << "Loop " << L->getHeader()->getName() << ": "; + OS << "Loop " << L->getHeader()->getName() << ": "; - std::vector ExitBlocks; + SmallVector ExitBlocks; L->getExitBlocks(ExitBlocks); if (ExitBlocks.size() != 1) - cerr << " "; + OS << " "; if (SE->hasLoopInvariantIterationCount(L)) { - cerr << *SE->getIterationCount(L) << " iterations! "; + OS << *SE->getIterationCount(L) << " iterations! "; } else { - cerr << "Unpredictable iteration count. "; + OS << "Unpredictable iteration count. "; } - cerr << "\n"; + OS << "\n"; } void ScalarEvolution::print(std::ostream &OS, const Module* ) const { @@ -2540,17 +3148,11 @@ void ScalarEvolution::print(std::ostream &OS, const Module* ) const { for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) if (I->getType()->isInteger()) { OS << *I; - OS << " --> "; + OS << " --> "; SCEVHandle SV = getSCEV(&*I); SV->print(OS); OS << "\t\t"; - if ((*I).getType()->isInteger()) { - ConstantRange Bounds = SV->getValueRange(); - if (!Bounds.isFullSet()) - OS << "Bounds: " << Bounds << " "; - } - if (const Loop *L = LI.getLoopFor((*I).getParent())) { OS << "Exits: "; SCEVHandle ExitValue = getSCEVAtScope(&*I, L->getParentLoop()); @@ -2569,4 +3171,3 @@ void ScalarEvolution::print(std::ostream &OS, const Module* ) const { for (LoopInfo::iterator I = LI.begin(), E = LI.end(); I != E; ++I) PrintLoopInfo(OS, this, *I); } -