X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=069f6ec714cc54a01d21df51ad5885a0b0387927;hb=9a2f93121b31bf6345d1552bdc43037f89714d86;hp=7d73d7d39819916cfcb503ce4a154ff6deb10337;hpb=d977d8651a5cd26a3e1088267f31cade405f2adf;p=oota-llvm.git diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 7d73d7d3981..069f6ec714c 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -59,6 +59,7 @@ // //===----------------------------------------------------------------------===// +#define DEBUG_TYPE "scalar-evolution" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Constants.h" #include "llvm/DerivedTypes.h" @@ -74,6 +75,7 @@ #include "llvm/Support/ConstantRange.h" #include "llvm/Support/InstIterator.h" #include "llvm/Support/ManagedStatic.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/Streams.h" #include "llvm/ADT/Statistic.h" #include @@ -81,33 +83,29 @@ #include using namespace llvm; +STATISTIC(NumBruteForceEvaluations, + "Number of brute force evaluations needed to " + "calculate high-order polynomial exit values"); +STATISTIC(NumArrayLenItCounts, + "Number of trip counts computed with array length"); +STATISTIC(NumTripCountsComputed, + "Number of loops with predictable loop counts"); +STATISTIC(NumTripCountsNotComputed, + "Number of loops without predictable loop counts"); +STATISTIC(NumBruteForceTripCountsComputed, + "Number of loops with trip counts computed by force"); + +cl::opt +MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, + cl::desc("Maximum number of iterations SCEV will " + "symbolically execute a constant derived loop"), + cl::init(100)); + namespace { RegisterPass R("scalar-evolution", "Scalar Evolution Analysis"); - - Statistic - NumBruteForceEvaluations("scalar-evolution", - "Number of brute force evaluations needed to " - "calculate high-order polynomial exit values"); - Statistic - NumArrayLenItCounts("scalar-evolution", - "Number of trip counts computed with array length"); - Statistic - NumTripCountsComputed("scalar-evolution", - "Number of loops with predictable loop counts"); - Statistic - NumTripCountsNotComputed("scalar-evolution", - "Number of loops without predictable loop counts"); - Statistic - NumBruteForceTripCountsComputed("scalar-evolution", - "Number of loops with trip counts computed by force"); - - cl::opt - MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, - cl::desc("Maximum number of iterations SCEV will " - "symbolically execute a constant derived loop"), - cl::init(100)); } +char ScalarEvolution::ID = 0; //===----------------------------------------------------------------------===// // SCEV class definitions @@ -126,9 +124,14 @@ void SCEV::dump() const { ConstantRange SCEV::getValueRange() const { const Type *Ty = getType(); assert(Ty->isInteger() && "Can't get range for a non-integer SCEV!"); - Ty = Ty->getUnsignedVersion(); // Default to a full range if no better information is available. - return ConstantRange(getType()); + return ConstantRange(getBitWidth()); +} + +uint32_t SCEV::getBitWidth() const { + if (const IntegerType* ITy = dyn_cast(getType())) + return ITy->getBitWidth(); + return 0; } @@ -175,20 +178,17 @@ SCEVConstant::~SCEVConstant() { } SCEVHandle SCEVConstant::get(ConstantInt *V) { - // Make sure that SCEVConstant instances are all unsigned. - if (V->getType()->isSigned()) { - const Type *NewTy = V->getType()->getUnsignedVersion(); - V = cast( - ConstantExpr::getBitCast(V, NewTy)); - } - SCEVConstant *&R = (*SCEVConstants)[V]; if (R == 0) R = new SCEVConstant(V); return R; } +SCEVHandle SCEVConstant::get(const APInt& Val) { + return get(ConstantInt::get(Val)); +} + ConstantRange SCEVConstant::getValueRange() const { - return ConstantRange(V); + return ConstantRange(V->getValue()); } const Type *SCEVConstant::getType() const { return V->getType(); } @@ -207,8 +207,8 @@ SCEVTruncateExpr::SCEVTruncateExpr(const SCEVHandle &op, const Type *ty) : SCEV(scTruncate), Op(op), Ty(ty) { assert(Op->getType()->isInteger() && Ty->isInteger() && "Cannot truncate non-integer value!"); - assert(Op->getType()->getPrimitiveSize() > Ty->getPrimitiveSize() && - "This is not a truncating conversion!"); + assert(Op->getType()->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits() + && "This is not a truncating conversion!"); } SCEVTruncateExpr::~SCEVTruncateExpr() { @@ -216,7 +216,7 @@ SCEVTruncateExpr::~SCEVTruncateExpr() { } ConstantRange SCEVTruncateExpr::getValueRange() const { - return getOperand()->getValueRange().truncate(getType()); + return getOperand()->getValueRange().truncate(getBitWidth()); } void SCEVTruncateExpr::print(std::ostream &OS) const { @@ -233,8 +233,8 @@ SCEVZeroExtendExpr::SCEVZeroExtendExpr(const SCEVHandle &op, const Type *ty) : SCEV(scZeroExtend), Op(op), Ty(ty) { assert(Op->getType()->isInteger() && Ty->isInteger() && "Cannot zero extend non-integer value!"); - assert(Op->getType()->getPrimitiveSize() < Ty->getPrimitiveSize() && - "This is not an extending conversion!"); + assert(Op->getType()->getPrimitiveSizeInBits() < Ty->getPrimitiveSizeInBits() + && "This is not an extending conversion!"); } SCEVZeroExtendExpr::~SCEVZeroExtendExpr() { @@ -242,13 +242,39 @@ SCEVZeroExtendExpr::~SCEVZeroExtendExpr() { } ConstantRange SCEVZeroExtendExpr::getValueRange() const { - return getOperand()->getValueRange().zeroExtend(getType()); + return getOperand()->getValueRange().zeroExtend(getBitWidth()); } void SCEVZeroExtendExpr::print(std::ostream &OS) const { OS << "(zeroextend " << *Op << " to " << *Ty << ")"; } +// SCEVSignExtends - Only allow the creation of one SCEVSignExtendExpr for any +// particular input. Don't use a SCEVHandle here, or else the object will never +// be deleted! +static ManagedStatic, + SCEVSignExtendExpr*> > SCEVSignExtends; + +SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty) + : SCEV(scSignExtend), Op(op), Ty(ty) { + assert(Op->getType()->isInteger() && Ty->isInteger() && + "Cannot sign extend non-integer value!"); + assert(Op->getType()->getPrimitiveSizeInBits() < Ty->getPrimitiveSizeInBits() + && "This is not an extending conversion!"); +} + +SCEVSignExtendExpr::~SCEVSignExtendExpr() { + SCEVSignExtends->erase(std::make_pair(Op, Ty)); +} + +ConstantRange SCEVSignExtendExpr::getValueRange() const { + return getOperand()->getValueRange().signExtend(getBitWidth()); +} + +void SCEVSignExtendExpr::print(std::ostream &OS) const { + OS << "(signextend " << *Op << " to " << *Ty << ")"; +} + // SCEVCommExprs - Only allow the creation of one SCEVCommutativeExpr for any // particular input. Don't use a SCEVHandle here, or else the object will never // be deleted! @@ -312,9 +338,7 @@ void SCEVSDivExpr::print(std::ostream &OS) const { } const Type *SCEVSDivExpr::getType() const { - const Type *Ty = LHS->getType(); - if (Ty->isUnsigned()) Ty = Ty->getSignedVersion(); - return Ty; + return LHS->getType(); } // SCEVAddRecExprs - Only allow the creation of one SCEVAddRecExpr for any @@ -461,13 +485,10 @@ SCEVHandle SCEVUnknown::getIntegerSCEV(int Val, const Type *Ty) { if (Val == 0) C = Constant::getNullValue(Ty); else if (Ty->isFloatingPoint()) - C = ConstantFP::get(Ty, Val); - else if (Ty->isSigned()) + C = ConstantFP::get(Ty, APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle : + APFloat::IEEEdouble, Val)); + else C = ConstantInt::get(Ty, Val); - else { - C = ConstantInt::get(Ty->getSignedVersion(), Val); - C = ConstantExpr::getBitCast(C, Ty); - } return SCEVUnknown::get(C); } @@ -478,9 +499,9 @@ static SCEVHandle getTruncateOrZeroExtend(const SCEVHandle &V, const Type *Ty) { const Type *SrcTy = V->getType(); assert(SrcTy->isInteger() && Ty->isInteger() && "Cannot truncate or zero extend with non-integer arguments!"); - if (SrcTy->getPrimitiveSize() == Ty->getPrimitiveSize()) + if (SrcTy->getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits()) return V; // No conversion - if (SrcTy->getPrimitiveSize() > Ty->getPrimitiveSize()) + if (SrcTy->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits()) return SCEVTruncateExpr::get(V, Ty); return SCEVZeroExtendExpr::get(V, Ty); } @@ -507,13 +528,11 @@ static SCEVHandle PartialFact(SCEVHandle V, unsigned NumSteps) { // Handle this case efficiently, it is common to have constant iteration // counts while computing loop exit values. if (SCEVConstant *SC = dyn_cast(V)) { - uint64_t Val = SC->getValue()->getZExtValue(); - uint64_t Result = 1; + const APInt& Val = SC->getValue()->getValue(); + APInt Result(Val.getBitWidth(), 1); for (; NumSteps; --NumSteps) Result *= Val-(NumSteps-1); - Constant *Res = ConstantInt::get(Type::ULongTy, Result); - return SCEVUnknown::get( - ConstantExpr::getTruncOrBitCast(Res, V->getType())); + return SCEVConstant::get(Result); } const Type *Ty = V->getType(); @@ -596,6 +615,21 @@ SCEVHandle SCEVZeroExtendExpr::get(const SCEVHandle &Op, const Type *Ty) { return Result; } +SCEVHandle SCEVSignExtendExpr::get(const SCEVHandle &Op, const Type *Ty) { + if (SCEVConstant *SC = dyn_cast(Op)) + return SCEVUnknown::get( + ConstantExpr::getSExt(SC->getValue(), Ty)); + + // FIXME: If the input value is a chrec scev, and we can prove that the value + // did not overflow the old, smaller, value, we can sign extend all of the + // operands (often constants). This would allow analysis of something like + // this: for (signed char X = 0; X < 100; ++X) { int Y = X; } + + SCEVSignExtendExpr *&Result = (*SCEVSignExtends)[std::make_pair(Op, Ty)]; + if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty); + return Result; +} + // get - Get a canonical add expression, or something simpler if possible. SCEVHandle SCEVAddExpr::get(std::vector &Ops) { assert(!Ops.empty() && "Cannot get empty add!"); @@ -611,7 +645,8 @@ SCEVHandle SCEVAddExpr::get(std::vector &Ops) { assert(Idx < Ops.size()); while (SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - Constant *Fold = ConstantExpr::getAdd(LHSC->getValue(), RHSC->getValue()); + Constant *Fold = ConstantInt::get(LHSC->getValue()->getValue() + + RHSC->getValue()->getValue()); if (ConstantInt *CI = dyn_cast(Fold)) { Ops[0] = SCEVConstant::get(CI); Ops.erase(Ops.begin()+1); // Erase the folded element @@ -626,7 +661,7 @@ SCEVHandle SCEVAddExpr::get(std::vector &Ops) { } // If we are left with a constant zero being added, strip it off. - if (cast(Ops[0])->getValue()->isNullValue()) { + if (cast(Ops[0])->getValue()->isZero()) { Ops.erase(Ops.begin()); --Idx; } @@ -651,8 +686,11 @@ SCEVHandle SCEVAddExpr::get(std::vector &Ops) { return SCEVAddExpr::get(Ops); } - // Okay, now we know the first non-constant operand. If there are add - // operands they would be next. + // Now we know the first non-constant operand. Skip past any cast SCEVs. + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr) + ++Idx; + + // If there are add operands they would be next. if (Idx < Ops.size()) { bool DeletedAdd = false; while (SCEVAddExpr *Add = dyn_cast(Ops[Idx])) { @@ -848,7 +886,8 @@ SCEVHandle SCEVMulExpr::get(std::vector &Ops) { ++Idx; while (SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - Constant *Fold = ConstantExpr::getMul(LHSC->getValue(), RHSC->getValue()); + Constant *Fold = ConstantInt::get(LHSC->getValue()->getValue() * + RHSC->getValue()->getValue()); if (ConstantInt *CI = dyn_cast(Fold)) { Ops[0] = SCEVConstant::get(CI); Ops.erase(Ops.begin()+1); // Erase the folded element @@ -866,7 +905,7 @@ SCEVHandle SCEVMulExpr::get(std::vector &Ops) { if (cast(Ops[0])->getValue()->equalsInt(1)) { Ops.erase(Ops.begin()); --Idx; - } else if (cast(Ops[0])->getValue()->isNullValue()) { + } else if (cast(Ops[0])->getValue()->isZero()) { // If we have a multiply of zero, it will always be zero. return Ops[0]; } @@ -1035,7 +1074,7 @@ SCEVHandle SCEVAddRecExpr::get(std::vector &Operands, if (Operands.size() == 1) return Operands[0]; if (SCEVConstant *StepC = dyn_cast(Operands.back())) - if (StepC->getValue()->isNullValue()) { + if (StepC->getValue()->isZero()) { Operands.pop_back(); return get(Operands, L); // { X,+,0 } --> X } @@ -1129,10 +1168,10 @@ namespace { /// loop without a loop-invariant iteration count. SCEVHandle getIterationCount(const Loop *L); - /// deleteInstructionFromRecords - This method should be called by the - /// client before it removes an instruction from the program, to make sure + /// deleteValueFromRecords - This method should be called by the + /// client before it removes a value from the program, to make sure /// that no dangling references are left around. - void deleteInstructionFromRecords(Instruction *I); + void deleteValueFromRecords(Value *V); private: /// createSCEV - We know that there is no SCEV for the specified value. @@ -1156,11 +1195,11 @@ namespace { SCEVHandle ComputeIterationCount(const Loop *L); /// ComputeLoadConstantCompareIterationCount - Given an exit condition of - /// 'setcc load X, cst', try to se if we can compute the trip count. + /// 'setcc load X, cst', try to see if we can compute the trip count. SCEVHandle ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS, const Loop *L, - unsigned SetCCOpcode); + ICmpInst::Predicate p); /// ComputeIterationCountExhaustively - If the trip is known to execute a /// constant number of times (the condition evolves only from constants), @@ -1182,14 +1221,15 @@ namespace { /// HowManyLessThans - Return the number of times a backedge containing the /// specified less-than comparison will execute. If not computable, return - /// UnknownValue. - SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L); + /// UnknownValue. isSigned specifies whether the less-than is signed. + SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, + bool isSigned); /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is /// in the header of its containing loop, we know the loop executes a /// constant number of times, and the PHI node is just a recurrence /// involving constants, fold it. - Constant *getConstantEvolutionLoopExitValue(PHINode *PN, uint64_t Its, + Constant *getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& Its, const Loop *L); }; } @@ -1198,13 +1238,32 @@ namespace { // Basic SCEV Analysis and PHI Idiom Recognition Code // -/// deleteInstructionFromRecords - This method should be called by the +/// deleteValueFromRecords - This method should be called by the /// client before it removes an instruction from the program, to make sure /// that no dangling references are left around. -void ScalarEvolutionsImpl::deleteInstructionFromRecords(Instruction *I) { - Scalars.erase(I); - if (PHINode *PN = dyn_cast(I)) - ConstantEvolutionLoopExitValue.erase(PN); +void ScalarEvolutionsImpl::deleteValueFromRecords(Value *V) { + SmallVector Worklist; + + if (Scalars.erase(V)) { + if (PHINode *PN = dyn_cast(V)) + ConstantEvolutionLoopExitValue.erase(PN); + Worklist.push_back(V); + } + + while (!Worklist.empty()) { + Value *VV = Worklist.back(); + Worklist.pop_back(); + + for (Instruction::use_iterator UI = VV->use_begin(), UE = VV->use_end(); + UI != UE; ++UI) { + Instruction *Inst = cast(*UI); + if (Scalars.erase(Inst)) { + if (PHINode *PN = dyn_cast(VV)) + ConstantEvolutionLoopExitValue.erase(PN); + Worklist.push_back(Inst); + } + } + } } @@ -1341,46 +1400,61 @@ SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN) { /// GetConstantFactor - Determine the largest constant factor that S has. For /// example, turn {4,+,8} -> 4. (S umod result) should always equal zero. -static uint64_t GetConstantFactor(SCEVHandle S) { +static APInt GetConstantFactor(SCEVHandle S) { if (SCEVConstant *C = dyn_cast(S)) { - if (uint64_t V = C->getValue()->getZExtValue()) + const APInt& V = C->getValue()->getValue(); + if (!V.isMinValue()) return V; else // Zero is a multiple of everything. - return 1ULL << (S->getType()->getPrimitiveSizeInBits()-1); + return APInt(C->getBitWidth(), 1).shl(C->getBitWidth()-1); } - if (SCEVTruncateExpr *T = dyn_cast(S)) - return GetConstantFactor(T->getOperand()) & - T->getType()->getIntegralTypeMask(); + if (SCEVTruncateExpr *T = dyn_cast(S)) { + return GetConstantFactor(T->getOperand()).trunc( + cast(T->getType())->getBitWidth()); + } if (SCEVZeroExtendExpr *E = dyn_cast(S)) - return GetConstantFactor(E->getOperand()); + return GetConstantFactor(E->getOperand()).zext( + cast(E->getType())->getBitWidth()); + if (SCEVSignExtendExpr *E = dyn_cast(S)) + return GetConstantFactor(E->getOperand()).sext( + cast(E->getType())->getBitWidth()); if (SCEVAddExpr *A = dyn_cast(S)) { // The result is the min of all operands. - uint64_t Res = GetConstantFactor(A->getOperand(0)); - for (unsigned i = 1, e = A->getNumOperands(); i != e && Res > 1; ++i) - Res = std::min(Res, GetConstantFactor(A->getOperand(i))); + APInt Res(GetConstantFactor(A->getOperand(0))); + for (unsigned i = 1, e = A->getNumOperands(); + i != e && Res.ugt(APInt(Res.getBitWidth(),1)); ++i) { + APInt Tmp(GetConstantFactor(A->getOperand(i))); + Res = APIntOps::umin(Res, Tmp); + } return Res; } if (SCEVMulExpr *M = dyn_cast(S)) { // The result is the product of all the operands. - uint64_t Res = GetConstantFactor(M->getOperand(0)); - for (unsigned i = 1, e = M->getNumOperands(); i != e; ++i) - Res *= GetConstantFactor(M->getOperand(i)); + APInt Res(GetConstantFactor(M->getOperand(0))); + for (unsigned i = 1, e = M->getNumOperands(); i != e; ++i) { + APInt Tmp(GetConstantFactor(M->getOperand(i))); + Res *= Tmp; + } return Res; } if (SCEVAddRecExpr *A = dyn_cast(S)) { - // FIXME: Generalize. - if (A->getNumOperands() == 2) - return std::min(GetConstantFactor(A->getOperand(0)), - GetConstantFactor(A->getOperand(1))); - // ? + // For now, we just handle linear expressions. + if (A->getNumOperands() == 2) { + // We want the GCD between the start and the stride value. + APInt Start(GetConstantFactor(A->getOperand(0))); + if (Start == 1) + return Start; + APInt Stride(GetConstantFactor(A->getOperand(1))); + return APIntOps::GreatestCommonDivisor(Start, Stride); + } } // SCEVSDivExpr, SCEVUnknown. - return 1; + return APInt(S->getBitWidth(), 1); } /// createSCEV - We know that there is no SCEV for the specified value. @@ -1409,42 +1483,49 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { // optimizations will transparently handle this case. if (ConstantInt *CI = dyn_cast(I->getOperand(1))) { SCEVHandle LHS = getSCEV(I->getOperand(0)); - uint64_t CommonFact = GetConstantFactor(LHS); - assert(CommonFact && "Common factor should at least be 1!"); - if (CommonFact > CI->getZExtValue()) { + APInt CommonFact(GetConstantFactor(LHS)); + assert(!CommonFact.isMinValue() && + "Common factor should at least be 1!"); + if (CommonFact.ugt(CI->getValue())) { // If the LHS is a multiple that is larger than the RHS, use +. return SCEVAddExpr::get(LHS, getSCEV(I->getOperand(1))); } } break; - + case Instruction::Xor: + // If the RHS of the xor is a signbit, then this is just an add. + // Instcombine turns add of signbit into xor as a strength reduction step. + if (ConstantInt *CI = dyn_cast(I->getOperand(1))) { + if (CI->getValue().isSignBit()) + return SCEVAddExpr::get(getSCEV(I->getOperand(0)), + getSCEV(I->getOperand(1))); + } + break; + case Instruction::Shl: // Turn shift left of a constant amount into a multiply. if (ConstantInt *SA = dyn_cast(I->getOperand(1))) { - Constant *X = ConstantInt::get(V->getType(), 1); - X = ConstantExpr::getShl(X, SA); + uint32_t BitWidth = cast(V->getType())->getBitWidth(); + Constant *X = ConstantInt::get( + APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth))); return SCEVMulExpr::get(getSCEV(I->getOperand(0)), getSCEV(X)); } break; case Instruction::Trunc: - // We don't handle trunc to bool yet. - if (I->getType()->isInteger()) - return SCEVTruncateExpr::get(getSCEV(I->getOperand(0)), - I->getType()->getUnsignedVersion()); - break; + return SCEVTruncateExpr::get(getSCEV(I->getOperand(0)), I->getType()); case Instruction::ZExt: - // We don't handle zext from bool yet. - if (I->getOperand(0)->getType()->isInteger()) - return SCEVZeroExtendExpr::get(getSCEV(I->getOperand(0)), - I->getType()->getUnsignedVersion()); - break; + return SCEVZeroExtendExpr::get(getSCEV(I->getOperand(0)), I->getType()); + + case Instruction::SExt: + return SCEVSignExtendExpr::get(getSCEV(I->getOperand(0)), I->getType()); case Instruction::BitCast: // BitCasts are no-op casts so we just eliminate the cast. - if (I->getType()->isInteger() && I->getOperand(0)->getType()->isInteger()) + if (I->getType()->isInteger() && + I->getOperand(0)->getType()->isInteger()) return getSCEV(I->getOperand(0)); break; @@ -1489,7 +1570,7 @@ SCEVHandle ScalarEvolutionsImpl::getIterationCount(const Loop *L) { /// will iterate. SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { // If the loop has a non-one exit block count, we can't analyze it. - std::vector ExitBlocks; + SmallVector ExitBlocks; L->getExitBlocks(ExitBlocks); if (ExitBlocks.size() != 1) return UnknownValue; @@ -1512,21 +1593,40 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { // exit. // // FIXME: we should be able to handle switch instructions (with a single exit) - // FIXME: We should handle cast of int to bool as well BranchInst *ExitBr = dyn_cast(ExitingBlock->getTerminator()); if (ExitBr == 0) return UnknownValue; assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!"); - SetCondInst *ExitCond = dyn_cast(ExitBr->getCondition()); - if (ExitCond == 0) // Not a setcc + + // At this point, we know we have a conditional branch that determines whether + // the loop is exited. However, we don't know if the branch is executed each + // time through the loop. If not, then the execution count of the branch will + // not be equal to the trip count of the loop. + // + // Currently we check for this by checking to see if the Exit branch goes to + // the loop header. If so, we know it will always execute the same number of + // times as the loop. We also handle the case where the exit block *is* the + // loop header. This is common for un-rotated loops. More extensive analysis + // could be done to handle more cases here. + if (ExitBr->getSuccessor(0) != L->getHeader() && + ExitBr->getSuccessor(1) != L->getHeader() && + ExitBr->getParent() != L->getHeader()) + return UnknownValue; + + ICmpInst *ExitCond = dyn_cast(ExitBr->getCondition()); + + // If its not an integer comparison then compute it the hard way. + // Note that ICmpInst deals with pointer comparisons too so we must check + // the type of the operand. + if (ExitCond == 0 || isa(ExitCond->getOperand(0)->getType())) return ComputeIterationCountExhaustively(L, ExitBr->getCondition(), ExitBr->getSuccessor(0) == ExitBlock); - // If the condition was exit on true, convert the condition to exit on false. - Instruction::BinaryOps Cond; + // If the condition was exit on true, convert the condition to exit on false + ICmpInst::Predicate Cond; if (ExitBr->getSuccessor(1) == ExitBlock) - Cond = ExitCond->getOpcode(); + Cond = ExitCond->getPredicate(); else - Cond = ExitCond->getInverseCondition(); + Cond = ExitCond->getInversePredicate(); // Handle common loops like: for (X = "string"; *X; ++X) if (LoadInst *LI = dyn_cast(ExitCond->getOperand(0))) @@ -1545,12 +1645,12 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { Tmp = getSCEVAtScope(RHS, L); if (!isa(Tmp)) RHS = Tmp; - // At this point, we would like to compute how many iterations of the loop the - // predicate will return true for these inputs. + // At this point, we would like to compute how many iterations of the + // loop the predicate will return true for these inputs. if (isa(LHS) && !isa(RHS)) { // If there is a constant, force it into the RHS. std::swap(LHS, RHS); - Cond = SetCondInst::getSwappedCondition(Cond); + Cond = ICmpInst::getSwappedPredicate(Cond); } // FIXME: think about handling pointer comparisons! i.e.: @@ -1570,18 +1670,8 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { ConstantExpr::getBitCast(CompVal, RealTy)); if (CompVal) { // Form the constant range. - ConstantRange CompRange(Cond, CompVal); - - // Now that we have it, if it's signed, convert it to an unsigned - // range. - if (CompRange.getLower()->getType()->isSigned()) { - const Type *NewTy = RHSC->getValue()->getType(); - Constant *NewL = ConstantExpr::getBitCast(CompRange.getLower(), - NewTy); - Constant *NewU = ConstantExpr::getBitCast(CompRange.getUpper(), - NewTy); - CompRange = ConstantRange(NewL, NewU); - } + ConstantRange CompRange( + ICmpInst::makeConstantRange(Cond, CompVal->getValue())); SCEVHandle Ret = AddRec->getNumIterationsInRange(CompRange); if (!isa(Ret)) return Ret; @@ -1589,52 +1679,58 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { } switch (Cond) { - case Instruction::SetNE: // while (X != Y) + case ICmpInst::ICMP_NE: { // while (X != Y) // Convert to: while (X-Y != 0) - if (LHS->getType()->isInteger()) { - SCEVHandle TC = HowFarToZero(SCEV::getMinusSCEV(LHS, RHS), L); - if (!isa(TC)) return TC; - } + SCEVHandle TC = HowFarToZero(SCEV::getMinusSCEV(LHS, RHS), L); + if (!isa(TC)) return TC; break; - case Instruction::SetEQ: + } + case ICmpInst::ICMP_EQ: { // Convert to: while (X-Y == 0) // while (X == Y) - if (LHS->getType()->isInteger()) { - SCEVHandle TC = HowFarToNonZero(SCEV::getMinusSCEV(LHS, RHS), L); - if (!isa(TC)) return TC; - } + SCEVHandle TC = HowFarToNonZero(SCEV::getMinusSCEV(LHS, RHS), L); + if (!isa(TC)) return TC; break; - case Instruction::SetLT: - if (LHS->getType()->isInteger() && - ExitCond->getOperand(0)->getType()->isSigned()) { - SCEVHandle TC = HowManyLessThans(LHS, RHS, L); - if (!isa(TC)) return TC; - } + } + case ICmpInst::ICMP_SLT: { + SCEVHandle TC = HowManyLessThans(LHS, RHS, L, true); + if (!isa(TC)) return TC; break; - case Instruction::SetGT: - if (LHS->getType()->isInteger() && - ExitCond->getOperand(0)->getType()->isSigned()) { - SCEVHandle TC = HowManyLessThans(RHS, LHS, L); - if (!isa(TC)) return TC; - } + } + case ICmpInst::ICMP_SGT: { + SCEVHandle TC = HowManyLessThans(SCEV::getNegativeSCEV(LHS), + SCEV::getNegativeSCEV(RHS), L, true); + if (!isa(TC)) return TC; break; + } + case ICmpInst::ICMP_ULT: { + SCEVHandle TC = HowManyLessThans(LHS, RHS, L, false); + if (!isa(TC)) return TC; + break; + } + case ICmpInst::ICMP_UGT: { + SCEVHandle TC = HowManyLessThans(SCEV::getNegativeSCEV(LHS), + SCEV::getNegativeSCEV(RHS), L, false); + if (!isa(TC)) return TC; + break; + } default: #if 0 cerr << "ComputeIterationCount "; if (ExitCond->getOperand(0)->getType()->isUnsigned()) cerr << "[unsigned] "; cerr << *LHS << " " - << Instruction::getOpcodeName(Cond) << " " << *RHS << "\n"; + << Instruction::getOpcodeName(Instruction::ICmp) + << " " << *RHS << "\n"; #endif break; } - return ComputeIterationCountExhaustively(L, ExitCond, - ExitBr->getSuccessor(0) == ExitBlock); + ExitBr->getSuccessor(0) == ExitBlock); } static ConstantInt * -EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, Constant *C) { - SCEVHandle InVal = SCEVConstant::get(cast(C)); +EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C) { + SCEVHandle InVal = SCEVConstant::get(C); SCEVHandle Val = AddRec->evaluateAtIteration(InVal); assert(isa(Val) && "Evaluation of SCEV at constant didn't fold correctly?"); @@ -1679,7 +1775,8 @@ GetAddressedElementFromGlobal(GlobalVariable *GV, /// 'setcc load X, cst', try to se if we can compute the trip count. SCEVHandle ScalarEvolutionsImpl:: ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS, - const Loop *L, unsigned SetCCOpcode) { + const Loop *L, + ICmpInst::Predicate predicate) { if (LI->isVolatile()) return UnknownValue; // Check to see if the loaded pointer is a getelementptr of a global. @@ -1725,7 +1822,7 @@ ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS, unsigned MaxSteps = MaxBruteForceIterations; for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) { ConstantInt *ItCst = - ConstantInt::get(IdxExpr->getType()->getUnsignedVersion(), IterationNum); + ConstantInt::get(IdxExpr->getType(), IterationNum); ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst); // Form the GEP offset. @@ -1735,9 +1832,9 @@ ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS, if (Result == 0) break; // Cannot compute! // Evaluate the condition for this iteration. - Result = ConstantExpr::get(SetCCOpcode, Result, RHS); - if (!isa(Result)) break; // Couldn't decide for sure - if (cast(Result)->getValue() == false) { + Result = ConstantExpr::getICmp(predicate, Result, RHS); + if (!isa(Result)) break; // Couldn't decide for sure + if (cast(Result)->getValue().isMinValue()) { #if 0 cerr << "\n***\n*** Computed loop count " << *ItCst << "\n*** From global " << *GV << "*** BB: " << *L->getHeader() @@ -1754,7 +1851,7 @@ ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS, /// CanConstantFold - Return true if we can constant fold an instruction of the /// specified type, assuming that all operands were constants. static bool CanConstantFold(const Instruction *I) { - if (isa(I) || isa(I) || + if (isa(I) || isa(I) || isa(I) || isa(I) || isa(I)) return true; @@ -1764,34 +1861,6 @@ static bool CanConstantFold(const Instruction *I) { return false; } -/// ConstantFold - Constant fold an instruction of the specified type with the -/// specified constant operands. This function may modify the operands vector. -static Constant *ConstantFold(const Instruction *I, - std::vector &Operands) { - if (isa(I) || isa(I)) - return ConstantExpr::get(I->getOpcode(), Operands[0], Operands[1]); - - if (isa(I)) - return ConstantExpr::getCast(I->getOpcode(), Operands[0], I->getType()); - - switch (I->getOpcode()) { - case Instruction::Select: - return ConstantExpr::getSelect(Operands[0], Operands[1], Operands[2]); - case Instruction::Call: - if (Function *GV = dyn_cast(Operands[0])) { - Operands.erase(Operands.begin()); - return ConstantFoldCall(cast(GV), Operands); - } - return 0; - case Instruction::GetElementPtr: - Constant *Base = Operands[0]; - Operands.erase(Operands.begin()); - return ConstantExpr::getGetElementPtr(Base, Operands); - } - return 0; -} - - /// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node /// in the loop that V is derived from. We allow arbitrary operations along the /// way, but the operands of an operation must either be constants or a value @@ -1852,7 +1921,7 @@ static Constant *EvaluateExpression(Value *V, Constant *PHIVal) { if (Operands[i] == 0) return 0; } - return ConstantFold(I, Operands); + return ConstantFoldInstOperands(I, &Operands[0], Operands.size()); } /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is @@ -1860,13 +1929,13 @@ static Constant *EvaluateExpression(Value *V, Constant *PHIVal) { /// constant number of times, and the PHI node is just a recurrence /// involving constants, fold it. Constant *ScalarEvolutionsImpl:: -getConstantEvolutionLoopExitValue(PHINode *PN, uint64_t Its, const Loop *L) { +getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& Its, const Loop *L){ std::map::iterator I = ConstantEvolutionLoopExitValue.find(PN); if (I != ConstantEvolutionLoopExitValue.end()) return I->second; - if (Its > MaxBruteForceIterations) + if (Its.ugt(APInt(Its.getBitWidth(),MaxBruteForceIterations))) return ConstantEvolutionLoopExitValue[PN] = 0; // Not going to evaluate it. Constant *&RetVal = ConstantEvolutionLoopExitValue[PN]; @@ -1886,11 +1955,11 @@ getConstantEvolutionLoopExitValue(PHINode *PN, uint64_t Its, const Loop *L) { return RetVal = 0; // Not derived from same PHI. // Execute the loop symbolically to determine the exit value. - unsigned IterationNum = 0; - unsigned NumIterations = Its; - if (NumIterations != Its) - return RetVal = 0; // More than 2^32 iterations?? + if (Its.getActiveBits() >= 32) + return RetVal = 0; // More than 2^32-1 iterations?? Not doing it! + unsigned NumIterations = Its.getZExtValue(); // must be in range + unsigned IterationNum = 0; for (Constant *PHIVal = StartCST; ; ++IterationNum) { if (IterationNum == NumIterations) return RetVal = PHIVal; // Got exit value! @@ -1934,14 +2003,16 @@ ComputeIterationCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) { unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis. for (Constant *PHIVal = StartCST; IterationNum != MaxIterations; ++IterationNum) { - ConstantBool *CondVal = - dyn_cast_or_null(EvaluateExpression(Cond, PHIVal)); - if (!CondVal) return UnknownValue; // Couldn't symbolically evaluate. + ConstantInt *CondVal = + dyn_cast_or_null(EvaluateExpression(Cond, PHIVal)); - if (CondVal->getValue() == ExitWhen) { + // Couldn't symbolically evaluate. + if (!CondVal) return UnknownValue; + + if (CondVal->getValue() == uint64_t(ExitWhen)) { ConstantEvolutionLoopExitValue[PN] = PHIVal; ++NumBruteForceTripCountsComputed; - return SCEVConstant::get(ConstantInt::get(Type::UIntTy, IterationNum)); + return SCEVConstant::get(ConstantInt::get(Type::Int32Ty, IterationNum)); } // Compute the value of the PHI node for the next iteration. @@ -1980,7 +2051,7 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { // this is a constant evolving PHI node, get the final value at // the specified iteration number. Constant *RV = getConstantEvolutionLoopExitValue(PN, - ICC->getValue()->getZExtValue(), + ICC->getValue()->getValue(), LI); if (RV) return SCEVUnknown::get(RV); } @@ -2015,7 +2086,8 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { } } } - return SCEVUnknown::get(ConstantFold(I, Operands)); + Constant *C =ConstantFoldInstOperands(I, &Operands[0], Operands.size()); + return SCEVUnknown::get(C); } } @@ -2096,65 +2168,53 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { static std::pair SolveQuadraticEquation(const SCEVAddRecExpr *AddRec) { assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); - SCEVConstant *L = dyn_cast(AddRec->getOperand(0)); - SCEVConstant *M = dyn_cast(AddRec->getOperand(1)); - SCEVConstant *N = dyn_cast(AddRec->getOperand(2)); + SCEVConstant *LC = dyn_cast(AddRec->getOperand(0)); + SCEVConstant *MC = dyn_cast(AddRec->getOperand(1)); + SCEVConstant *NC = dyn_cast(AddRec->getOperand(2)); // We currently can only solve this if the coefficients are constants. - if (!L || !M || !N) { - SCEV *CNC = new SCEVCouldNotCompute(); - return std::make_pair(CNC, CNC); - } - - Constant *C = L->getValue(); - Constant *Two = ConstantInt::get(C->getType(), 2); - - // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C - // The B coefficient is M-N/2 - Constant *B = ConstantExpr::getSub(M->getValue(), - ConstantExpr::getSDiv(N->getValue(), - Two)); - // The A coefficient is N/2 - Constant *A = ConstantExpr::getSDiv(N->getValue(), Two); - - // Compute the B^2-4ac term. - Constant *SqrtTerm = - ConstantExpr::getMul(ConstantInt::get(C->getType(), 4), - ConstantExpr::getMul(A, C)); - SqrtTerm = ConstantExpr::getSub(ConstantExpr::getMul(B, B), SqrtTerm); - - // Compute floor(sqrt(B^2-4ac)) - ConstantInt *SqrtVal = - cast(ConstantExpr::getBitCast(SqrtTerm, - SqrtTerm->getType()->getUnsignedVersion())); - uint64_t SqrtValV = SqrtVal->getZExtValue(); - uint64_t SqrtValV2 = (uint64_t)sqrt((double)SqrtValV); - // The square root might not be precise for arbitrary 64-bit integer - // values. Do some sanity checks to ensure it's correct. - if (SqrtValV2*SqrtValV2 > SqrtValV || - (SqrtValV2+1)*(SqrtValV2+1) <= SqrtValV) { + if (!LC || !MC || !NC) { SCEV *CNC = new SCEVCouldNotCompute(); return std::make_pair(CNC, CNC); } - SqrtVal = ConstantInt::get(Type::ULongTy, SqrtValV2); - SqrtTerm = ConstantExpr::getTruncOrBitCast(SqrtVal, SqrtTerm->getType()); - - Constant *NegB = ConstantExpr::getNeg(B); - Constant *TwoA = ConstantExpr::getMul(A, Two); - - // The divisions must be performed as signed divisions. - const Type *SignedTy = NegB->getType()->getSignedVersion(); - NegB = ConstantExpr::getBitCast(NegB, SignedTy); - TwoA = ConstantExpr::getBitCast(TwoA, SignedTy); - SqrtTerm = ConstantExpr::getBitCast(SqrtTerm, SignedTy); - - Constant *Solution1 = - ConstantExpr::getSDiv(ConstantExpr::getAdd(NegB, SqrtTerm), TwoA); - Constant *Solution2 = - ConstantExpr::getSDiv(ConstantExpr::getSub(NegB, SqrtTerm), TwoA); - return std::make_pair(SCEVUnknown::get(Solution1), - SCEVUnknown::get(Solution2)); + uint32_t BitWidth = LC->getValue()->getValue().getBitWidth(); + const APInt &L = LC->getValue()->getValue(); + const APInt &M = MC->getValue()->getValue(); + const APInt &N = NC->getValue()->getValue(); + APInt Two(BitWidth, 2); + APInt Four(BitWidth, 4); + + { + using namespace APIntOps; + const APInt& C = L; + // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C + // The B coefficient is M-N/2 + APInt B(M); + B -= sdiv(N,Two); + + // The A coefficient is N/2 + APInt A(N.sdiv(Two)); + + // Compute the B^2-4ac term. + APInt SqrtTerm(B); + SqrtTerm *= B; + SqrtTerm -= Four * (A * C); + + // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest + // integer value or else APInt::sqrt() will assert. + APInt SqrtVal(SqrtTerm.sqrt()); + + // Compute the two solutions for the quadratic formula. + // The divisions must be performed as signed divisions. + APInt NegB(-B); + APInt TwoA( A << 1 ); + ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA)); + ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA)); + + return std::make_pair(SCEVConstant::get(Solution1), + SCEVConstant::get(Solution2)); + } // end APIntOps namespace } /// HowFarToZero - Return the number of times a backedge comparing the specified @@ -2163,7 +2223,7 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) { // If the value is a constant if (SCEVConstant *C = dyn_cast(V)) { // If the value is already zero, the branch will execute zero times. - if (C->getValue()->isNullValue()) return C; + if (C->getValue()->isZero()) return C; return UnknownValue; // Otherwise it will loop infinitely. } @@ -2215,11 +2275,10 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) { << " sol#2: " << *R2 << "\n"; #endif // Pick the smallest positive root value. - assert(R1->getType()->isUnsigned()&&"Didn't canonicalize to unsigned?"); - if (ConstantBool *CB = - dyn_cast(ConstantExpr::getSetLT(R1->getValue(), - R2->getValue()))) { - if (CB->getValue() == false) + if (ConstantInt *CB = + dyn_cast(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, + R1->getValue(), R2->getValue()))) { + if (CB->getZExtValue() == false) std::swap(R1, R2); // R1 is the minimum root now. // We can only use this value if the chrec ends up with an exact zero @@ -2227,7 +2286,7 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) { // should not accept a root of 2. SCEVHandle Val = AddRec->evaluateAtIteration(R1); if (SCEVConstant *EvalVal = dyn_cast(Val)) - if (EvalVal->getValue()->isNullValue()) + if (EvalVal->getValue()->isZero()) return R1; // We found a quadratic root! } } @@ -2248,8 +2307,9 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) { // already. If so, the backedge will execute zero times. if (SCEVConstant *C = dyn_cast(V)) { Constant *Zero = Constant::getNullValue(C->getValue()->getType()); - Constant *NonZero = ConstantExpr::getSetNE(C->getValue(), Zero); - if (NonZero == ConstantBool::getTrue()) + Constant *NonZero = + ConstantExpr::getICmp(ICmpInst::ICMP_NE, C->getValue(), Zero); + if (NonZero == ConstantInt::getTrue()) return getSCEV(Zero); return UnknownValue; // Otherwise it will loop infinitely. } @@ -2263,7 +2323,7 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) { /// specified less-than comparison will execute. If not computable, return /// UnknownValue. SCEVHandle ScalarEvolutionsImpl:: -HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L) { +HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) { // Only handle: "ADDREC < LoopInvariant". if (!RHS->isLoopInvariant(L)) return UnknownValue; @@ -2309,40 +2369,52 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L) { // Now that we found a conditional branch that dominates the loop, check to // see if it is the comparison we are looking for. - SetCondInst *SCI =dyn_cast(LoopEntryPredicate->getCondition()); - if (!SCI) return UnknownValue; - Value *PreCondLHS = SCI->getOperand(0); - Value *PreCondRHS = SCI->getOperand(1); - Instruction::BinaryOps Cond; - if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest) - Cond = SCI->getOpcode(); - else - Cond = SCI->getInverseCondition(); + if (ICmpInst *ICI = dyn_cast(LoopEntryPredicate->getCondition())){ + Value *PreCondLHS = ICI->getOperand(0); + Value *PreCondRHS = ICI->getOperand(1); + ICmpInst::Predicate Cond; + if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest) + Cond = ICI->getPredicate(); + else + Cond = ICI->getInversePredicate(); - switch (Cond) { - case Instruction::SetGT: - std::swap(PreCondLHS, PreCondRHS); - Cond = Instruction::SetLT; - // Fall Through. - case Instruction::SetLT: - if (PreCondLHS->getType()->isInteger() && - PreCondLHS->getType()->isSigned()) { + switch (Cond) { + case ICmpInst::ICMP_UGT: + if (isSigned) return UnknownValue; + std::swap(PreCondLHS, PreCondRHS); + Cond = ICmpInst::ICMP_ULT; + break; + case ICmpInst::ICMP_SGT: + if (!isSigned) return UnknownValue; + std::swap(PreCondLHS, PreCondRHS); + Cond = ICmpInst::ICMP_SLT; + break; + case ICmpInst::ICMP_ULT: + if (isSigned) return UnknownValue; + break; + case ICmpInst::ICMP_SLT: + if (!isSigned) return UnknownValue; + break; + default: + return UnknownValue; + } + + if (PreCondLHS->getType()->isInteger()) { if (RHS != getSCEV(PreCondRHS)) return UnknownValue; // Not a comparison against 'm'. if (SCEV::getMinusSCEV(AddRec->getOperand(0), One) != getSCEV(PreCondLHS)) return UnknownValue; // Not a comparison against 'n-1'. - break; - } else { - return UnknownValue; } - default: break; - } + else return UnknownValue; - //cerr << "Computed Loop Trip Count as: " - // << *SCEV::getMinusSCEV(RHS, AddRec->getOperand(0)) << "\n"; - return SCEV::getMinusSCEV(RHS, AddRec->getOperand(0)); + // cerr << "Computed Loop Trip Count as: " + // << // *SCEV::getMinusSCEV(RHS, AddRec->getOperand(0)) << "\n"; + return SCEV::getMinusSCEV(RHS, AddRec->getOperand(0)); + } + else + return UnknownValue; } return UnknownValue; @@ -2359,13 +2431,13 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range) const { // If the start is a non-zero constant, shift the range to simplify things. if (SCEVConstant *SC = dyn_cast(getStart())) - if (!SC->getValue()->isNullValue()) { + if (!SC->getValue()->isZero()) { std::vector Operands(op_begin(), op_end()); Operands[0] = SCEVUnknown::getIntegerSCEV(0, SC->getType()); SCEVHandle Shifted = SCEVAddRecExpr::get(Operands, getLoop()); if (SCEVAddRecExpr *ShiftedAddRec = dyn_cast(Shifted)) return ShiftedAddRec->getNumIterationsInRange( - Range.subtract(SC->getValue())); + Range.subtract(SC->getValue()->getValue())); // This is strange and shouldn't happen. return new SCEVCouldNotCompute(); } @@ -2382,48 +2454,45 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range) const { // First check to see if the range contains zero. If not, the first // iteration exits. - ConstantInt *Zero = ConstantInt::get(getType(), 0); - if (!Range.contains(Zero)) return SCEVConstant::get(Zero); + if (!Range.contains(APInt(getBitWidth(),0))) + return SCEVConstant::get(ConstantInt::get(getType(),0)); if (isAffine()) { // If this is an affine expression then we have this situation: // Solve {0,+,A} in Range === Ax in Range - // Since we know that zero is in the range, we know that the upper value of - // the range must be the first possible exit value. Also note that we - // already checked for a full range. - ConstantInt *Upper = cast(Range.getUpper()); - ConstantInt *A = cast(getOperand(1))->getValue(); - ConstantInt *One = ConstantInt::get(getType(), 1); - - // The exit value should be (Upper+A-1)/A. - Constant *ExitValue = Upper; - if (A != One) { - ExitValue = ConstantExpr::getSub(ConstantExpr::getAdd(Upper, A), One); - ExitValue = ConstantExpr::getSDiv(ExitValue, A); - } - assert(isa(ExitValue) && - "Constant folding of integers not implemented?"); + // We know that zero is in the range. If A is positive then we know that + // the upper value of the range must be the first possible exit value. + // If A is negative then the lower of the range is the last possible loop + // value. Also note that we already checked for a full range. + APInt One(getBitWidth(),1); + APInt A = cast(getOperand(1))->getValue()->getValue(); + APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower(); + + // The exit value should be (End+A)/A. + APInt ExitVal = (End + A).udiv(A); + ConstantInt *ExitValue = ConstantInt::get(ExitVal); // Evaluate at the exit value. If we really did fall out of the valid // range, then we computed our trip count, otherwise wrap around or other // things must have happened. ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue); - if (Range.contains(Val)) + if (Range.contains(Val->getValue())) return new SCEVCouldNotCompute(); // Something strange happened // Ensure that the previous value is in the range. This is a sanity check. - assert(Range.contains(EvaluateConstantChrecAtConstant(this, - ConstantExpr::getSub(ExitValue, One))) && + assert(Range.contains( + EvaluateConstantChrecAtConstant(this, + ConstantInt::get(ExitVal - One))->getValue()) && "Linear scev computation is off in a bad way!"); - return SCEVConstant::get(cast(ExitValue)); + return SCEVConstant::get(ExitValue); } else if (isQuadratic()) { // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the // quadratic equation to solve it. To do this, we must frame our problem in // terms of figuring out when zero is crossed, instead of when // Range.getUpper() is crossed. std::vector NewOps(op_begin(), op_end()); - NewOps[0] = SCEV::getNegativeSCEV(SCEVUnknown::get(Range.getUpper())); + NewOps[0] = SCEV::getNegativeSCEV(SCEVConstant::get(Range.getUpper())); SCEVHandle NewAddRec = SCEVAddRecExpr::get(NewOps, getLoop()); // Next, solve the constructed addrec @@ -2433,11 +2502,10 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range) const { SCEVConstant *R2 = dyn_cast(Roots.second); if (R1) { // Pick the smallest positive root value. - assert(R1->getType()->isUnsigned() && "Didn't canonicalize to unsigned?"); - if (ConstantBool *CB = - dyn_cast(ConstantExpr::getSetLT(R1->getValue(), - R2->getValue()))) { - if (CB->getValue() == false) + if (ConstantInt *CB = + dyn_cast(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, + R1->getValue(), R2->getValue()))) { + if (CB->getZExtValue() == false) std::swap(R1, R2); // R1 is the minimum root now. // Make sure the root is not off by one. The returned iteration should @@ -2445,25 +2513,21 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range) const { // for "X*X < 5", for example, we should not return a root of 2. ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this, R1->getValue()); - if (Range.contains(R1Val)) { + if (Range.contains(R1Val->getValue())) { // The next iteration must be out of the range... - Constant *NextVal = - ConstantExpr::getAdd(R1->getValue(), - ConstantInt::get(R1->getType(), 1)); + ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()+1); R1Val = EvaluateConstantChrecAtConstant(this, NextVal); - if (!Range.contains(R1Val)) - return SCEVUnknown::get(NextVal); + if (!Range.contains(R1Val->getValue())) + return SCEVConstant::get(NextVal); return new SCEVCouldNotCompute(); // Something strange happened } // If R1 was not in the range, then it is a good return value. Make // sure that R1-1 WAS in the range though, just in case. - Constant *NextVal = - ConstantExpr::getSub(R1->getValue(), - ConstantInt::get(R1->getType(), 1)); + ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()-1); R1Val = EvaluateConstantChrecAtConstant(this, NextVal); - if (Range.contains(R1Val)) + if (Range.contains(R1Val->getValue())) return R1; return new SCEVCouldNotCompute(); // Something strange happened } @@ -2476,7 +2540,6 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range) const { // incredibly important, we will be able to simplify the exit test a lot, and // we are almost guaranteed to get a trip count in this case. ConstantInt *TestVal = ConstantInt::get(getType(), 0); - ConstantInt *One = ConstantInt::get(getType(), 1); ConstantInt *EndVal = TestVal; // Stop when we wrap around. do { ++NumBruteForceEvaluations; @@ -2485,11 +2548,11 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range) const { return new SCEVCouldNotCompute(); // Check to see if we found the value! - if (!Range.contains(cast(Val)->getValue())) + if (!Range.contains(cast(Val)->getValue()->getValue())) return SCEVConstant::get(TestVal); // Increment to test the next index. - TestVal = cast(ConstantExpr::getAdd(TestVal, One)); + TestVal = ConstantInt::get(TestVal->getValue()+1); } while (TestVal != EndVal); return new SCEVCouldNotCompute(); @@ -2546,8 +2609,8 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) const { return ((ScalarEvolutionsImpl*)Impl)->getSCEVAtScope(getSCEV(V), L); } -void ScalarEvolution::deleteInstructionFromRecords(Instruction *I) const { - return ((ScalarEvolutionsImpl*)Impl)->deleteInstructionFromRecords(I); +void ScalarEvolution::deleteValueFromRecords(Value *V) const { + return ((ScalarEvolutionsImpl*)Impl)->deleteValueFromRecords(V); } static void PrintLoopInfo(std::ostream &OS, const ScalarEvolution *SE, @@ -2558,7 +2621,7 @@ static void PrintLoopInfo(std::ostream &OS, const ScalarEvolution *SE, cerr << "Loop " << L->getHeader()->getName() << ": "; - std::vector ExitBlocks; + SmallVector ExitBlocks; L->getExitBlocks(ExitBlocks); if (ExitBlocks.size() != 1) cerr << " "; @@ -2585,7 +2648,7 @@ void ScalarEvolution::print(std::ostream &OS, const Module* ) const { SV->print(OS); OS << "\t\t"; - if ((*I).getType()->isIntegral()) { + if ((*I).getType()->isInteger()) { ConstantRange Bounds = SV->getValueRange(); if (!Bounds.isFullSet()) OS << "Bounds: " << Bounds << " ";