X-Git-Url: http://plrg.eecs.uci.edu/git/?p=oota-llvm.git;a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=a992e51e0fb13d80615b49e356feee4fd71c20a1;hp=349979843a5378dd05f85ef550a712bd13e444f6;hb=b83eb6447ba155342598f0fabe1f08f5baa9164a;hpb=05bd374b1f22b74baf1dc087c8c2d128c1e299aa diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 349979843a5..a992e51e0fb 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -64,21 +64,24 @@ #include "llvm/DerivedTypes.h" #include "llvm/GlobalVariable.h" #include "llvm/Instructions.h" +#include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Assembly/Writer.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Support/CFG.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/ConstantRange.h" #include "llvm/Support/InstIterator.h" -#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ManagedStatic.h" #include "llvm/ADT/Statistic.h" #include +#include #include using namespace llvm; namespace { - RegisterAnalysis + RegisterPass R("scalar-evolution", "Scalar Evolution Analysis"); Statistic<> @@ -100,7 +103,8 @@ namespace { cl::opt MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, - cl::desc("Maximum number of iterations SCEV will symbolically execute a constant derived loop"), + cl::desc("Maximum number of iterations SCEV will " + "symbolically execute a constant derived loop"), cl::init(100)); } @@ -162,21 +166,21 @@ bool SCEVCouldNotCompute::classof(const SCEV *S) { // SCEVConstants - Only allow the creation of one SCEVConstant for any // particular value. Don't use a SCEVHandle here, or else the object will // never be deleted! -static std::map SCEVConstants; +static ManagedStatic > SCEVConstants; SCEVConstant::~SCEVConstant() { - SCEVConstants.erase(V); + SCEVConstants->erase(V); } SCEVHandle SCEVConstant::get(ConstantInt *V) { // Make sure that SCEVConstant instances are all unsigned. if (V->getType()->isSigned()) { const Type *NewTy = V->getType()->getUnsignedVersion(); - V = cast(ConstantExpr::getCast(V, NewTy)); + V = cast(ConstantExpr::getCast(V, NewTy)); } - SCEVConstant *&R = SCEVConstants[V]; + SCEVConstant *&R = (*SCEVConstants)[V]; if (R == 0) R = new SCEVConstant(V); return R; } @@ -194,7 +198,8 @@ void SCEVConstant::print(std::ostream &OS) const { // SCEVTruncates - Only allow the creation of one SCEVTruncateExpr for any // particular input. Don't use a SCEVHandle here, or else the object will // never be deleted! -static std::map, SCEVTruncateExpr*> SCEVTruncates; +static ManagedStatic, + SCEVTruncateExpr*> > SCEVTruncates; SCEVTruncateExpr::SCEVTruncateExpr(const SCEVHandle &op, const Type *ty) : SCEV(scTruncate), Op(op), Ty(ty) { @@ -206,7 +211,7 @@ SCEVTruncateExpr::SCEVTruncateExpr(const SCEVHandle &op, const Type *ty) } SCEVTruncateExpr::~SCEVTruncateExpr() { - SCEVTruncates.erase(std::make_pair(Op, Ty)); + SCEVTruncates->erase(std::make_pair(Op, Ty)); } ConstantRange SCEVTruncateExpr::getValueRange() const { @@ -220,8 +225,8 @@ void SCEVTruncateExpr::print(std::ostream &OS) const { // SCEVZeroExtends - Only allow the creation of one SCEVZeroExtendExpr for any // particular input. Don't use a SCEVHandle here, or else the object will never // be deleted! -static std::map, - SCEVZeroExtendExpr*> SCEVZeroExtends; +static ManagedStatic, + SCEVZeroExtendExpr*> > SCEVZeroExtends; SCEVZeroExtendExpr::SCEVZeroExtendExpr(const SCEVHandle &op, const Type *ty) : SCEV(scTruncate), Op(op), Ty(ty) { @@ -233,7 +238,7 @@ SCEVZeroExtendExpr::SCEVZeroExtendExpr(const SCEVHandle &op, const Type *ty) } SCEVZeroExtendExpr::~SCEVZeroExtendExpr() { - SCEVZeroExtends.erase(std::make_pair(Op, Ty)); + SCEVZeroExtends->erase(std::make_pair(Op, Ty)); } ConstantRange SCEVZeroExtendExpr::getValueRange() const { @@ -247,13 +252,13 @@ void SCEVZeroExtendExpr::print(std::ostream &OS) const { // SCEVCommExprs - Only allow the creation of one SCEVCommutativeExpr for any // particular input. Don't use a SCEVHandle here, or else the object will never // be deleted! -static std::map >, - SCEVCommutativeExpr*> SCEVCommExprs; +static ManagedStatic >, + SCEVCommutativeExpr*> > SCEVCommExprs; SCEVCommutativeExpr::~SCEVCommutativeExpr() { - SCEVCommExprs.erase(std::make_pair(getSCEVType(), - std::vector(Operands.begin(), - Operands.end()))); + SCEVCommExprs->erase(std::make_pair(getSCEVType(), + std::vector(Operands.begin(), + Operands.end()))); } void SCEVCommutativeExpr::print(std::ostream &OS) const { @@ -292,35 +297,36 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, } -// SCEVUDivs - Only allow the creation of one SCEVUDivExpr for any particular +// SCEVSDivs - Only allow the creation of one SCEVSDivExpr for any particular // input. Don't use a SCEVHandle here, or else the object will never be // deleted! -static std::map, SCEVUDivExpr*> SCEVUDivs; +static ManagedStatic, + SCEVSDivExpr*> > SCEVSDivs; -SCEVUDivExpr::~SCEVUDivExpr() { - SCEVUDivs.erase(std::make_pair(LHS, RHS)); +SCEVSDivExpr::~SCEVSDivExpr() { + SCEVSDivs->erase(std::make_pair(LHS, RHS)); } -void SCEVUDivExpr::print(std::ostream &OS) const { - OS << "(" << *LHS << " /u " << *RHS << ")"; +void SCEVSDivExpr::print(std::ostream &OS) const { + OS << "(" << *LHS << " /s " << *RHS << ")"; } -const Type *SCEVUDivExpr::getType() const { +const Type *SCEVSDivExpr::getType() const { const Type *Ty = LHS->getType(); - if (Ty->isSigned()) Ty = Ty->getUnsignedVersion(); + if (Ty->isUnsigned()) Ty = Ty->getSignedVersion(); return Ty; } // SCEVAddRecExprs - Only allow the creation of one SCEVAddRecExpr for any // particular input. Don't use a SCEVHandle here, or else the object will never // be deleted! -static std::map >, - SCEVAddRecExpr*> SCEVAddRecExprs; +static ManagedStatic >, + SCEVAddRecExpr*> > SCEVAddRecExprs; SCEVAddRecExpr::~SCEVAddRecExpr() { - SCEVAddRecExprs.erase(std::make_pair(L, - std::vector(Operands.begin(), - Operands.end()))); + SCEVAddRecExprs->erase(std::make_pair(L, + std::vector(Operands.begin(), + Operands.end()))); } SCEVHandle SCEVAddRecExpr:: @@ -347,8 +353,9 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const { // This recurrence is invariant w.r.t to QueryLoop iff QueryLoop doesn't - // contain L. - return !QueryLoop->contains(L->getHeader()); + // contain L and if the start is invariant. + return !QueryLoop->contains(L->getHeader()) && + getOperand(0)->isLoopInvariant(QueryLoop); } @@ -362,9 +369,9 @@ void SCEVAddRecExpr::print(std::ostream &OS) const { // SCEVUnknowns - Only allow the creation of one SCEVUnknown for any particular // value. Don't use a SCEVHandle here, or else the object will never be // deleted! -static std::map SCEVUnknowns; +static ManagedStatic > SCEVUnknowns; -SCEVUnknown::~SCEVUnknown() { SCEVUnknowns.erase(V); } +SCEVUnknown::~SCEVUnknown() { SCEVUnknowns->erase(V); } bool SCEVUnknown::isLoopInvariant(const Loop *L) const { // All non-instruction values are loop invariant. All instructions are loop @@ -390,7 +397,7 @@ namespace { /// SCEVComplexityCompare - Return true if the complexity of the LHS is less /// than the complexity of the RHS. This comparator is used to canonicalize /// expressions. - struct SCEVComplexityCompare { + struct VISIBILITY_HIDDEN SCEVComplexityCompare { bool operator()(SCEV *LHS, SCEV *RHS) { return LHS->getSCEVType() < RHS->getSCEVType(); } @@ -456,9 +463,9 @@ SCEVHandle SCEVUnknown::getIntegerSCEV(int Val, const Type *Ty) { else if (Ty->isFloatingPoint()) C = ConstantFP::get(Ty, Val); else if (Ty->isSigned()) - C = ConstantSInt::get(Ty, Val); + C = ConstantInt::get(Ty, Val); else { - C = ConstantSInt::get(Ty->getSignedVersion(), Val); + C = ConstantInt::get(Ty->getSignedVersion(), Val); C = ConstantExpr::getCast(C, Ty); } return SCEVUnknown::get(C); @@ -500,11 +507,11 @@ static SCEVHandle PartialFact(SCEVHandle V, unsigned NumSteps) { // Handle this case efficiently, it is common to have constant iteration // counts while computing loop exit values. if (SCEVConstant *SC = dyn_cast(V)) { - uint64_t Val = SC->getValue()->getRawValue(); + uint64_t Val = SC->getValue()->getZExtValue(); uint64_t Result = 1; for (; NumSteps; --NumSteps) Result *= Val-(NumSteps-1); - Constant *Res = ConstantUInt::get(Type::ULongTy, Result); + Constant *Res = ConstantInt::get(Type::ULongTy, Result); return SCEVUnknown::get(ConstantExpr::getCast(Res, V->getType())); } @@ -537,7 +544,7 @@ SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It) const { for (unsigned i = 1, e = getNumOperands(); i != e; ++i) { SCEVHandle BC = PartialFact(It, i); Divisor *= i; - SCEVHandle Val = SCEVUDivExpr::get(SCEVMulExpr::get(BC, getOperand(i)), + SCEVHandle Val = SCEVSDivExpr::get(SCEVMulExpr::get(BC, getOperand(i)), SCEVUnknown::getIntegerSCEV(Divisor,Ty)); Result = SCEVAddExpr::get(Result, Val); } @@ -567,7 +574,7 @@ SCEVHandle SCEVTruncateExpr::get(const SCEVHandle &Op, const Type *Ty) { return SCEVAddRecExpr::get(Operands, AddRec->getLoop()); } - SCEVTruncateExpr *&Result = SCEVTruncates[std::make_pair(Op, Ty)]; + SCEVTruncateExpr *&Result = (*SCEVTruncates)[std::make_pair(Op, Ty)]; if (Result == 0) Result = new SCEVTruncateExpr(Op, Ty); return Result; } @@ -581,7 +588,7 @@ SCEVHandle SCEVZeroExtendExpr::get(const SCEVHandle &Op, const Type *Ty) { // operands (often constants). This would allow analysis of something like // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; } - SCEVZeroExtendExpr *&Result = SCEVZeroExtends[std::make_pair(Op, Ty)]; + SCEVZeroExtendExpr *&Result = (*SCEVZeroExtends)[std::make_pair(Op, Ty)]; if (Result == 0) Result = new SCEVZeroExtendExpr(Op, Ty); return Result; } @@ -809,8 +816,8 @@ SCEVHandle SCEVAddExpr::get(std::vector &Ops) { // Okay, it looks like we really DO need an add expr. Check to see if we // already have one, otherwise create a new one. std::vector SCEVOps(Ops.begin(), Ops.end()); - SCEVCommutativeExpr *&Result = SCEVCommExprs[std::make_pair(scAddExpr, - SCEVOps)]; + SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scAddExpr, + SCEVOps)]; if (Result == 0) Result = new SCEVAddExpr(Ops); return Result; } @@ -972,27 +979,27 @@ SCEVHandle SCEVMulExpr::get(std::vector &Ops) { // Okay, it looks like we really DO need an mul expr. Check to see if we // already have one, otherwise create a new one. std::vector SCEVOps(Ops.begin(), Ops.end()); - SCEVCommutativeExpr *&Result = SCEVCommExprs[std::make_pair(scMulExpr, - SCEVOps)]; + SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scMulExpr, + SCEVOps)]; if (Result == 0) Result = new SCEVMulExpr(Ops); return Result; } -SCEVHandle SCEVUDivExpr::get(const SCEVHandle &LHS, const SCEVHandle &RHS) { +SCEVHandle SCEVSDivExpr::get(const SCEVHandle &LHS, const SCEVHandle &RHS) { if (SCEVConstant *RHSC = dyn_cast(RHS)) { if (RHSC->getValue()->equalsInt(1)) - return LHS; // X /u 1 --> x + return LHS; // X /s 1 --> x if (RHSC->getValue()->isAllOnesValue()) - return SCEV::getNegativeSCEV(LHS); // X /u -1 --> -x + return SCEV::getNegativeSCEV(LHS); // X /s -1 --> -x if (SCEVConstant *LHSC = dyn_cast(LHS)) { Constant *LHSCV = LHSC->getValue(); Constant *RHSCV = RHSC->getValue(); - if (LHSCV->getType()->isSigned()) + if (LHSCV->getType()->isUnsigned()) LHSCV = ConstantExpr::getCast(LHSCV, - LHSCV->getType()->getUnsignedVersion()); - if (RHSCV->getType()->isSigned()) + LHSCV->getType()->getSignedVersion()); + if (RHSCV->getType()->isUnsigned()) RHSCV = ConstantExpr::getCast(RHSCV, LHSCV->getType()); return SCEVUnknown::get(ConstantExpr::getDiv(LHSCV, RHSCV)); } @@ -1000,8 +1007,8 @@ SCEVHandle SCEVUDivExpr::get(const SCEVHandle &LHS, const SCEVHandle &RHS) { // FIXME: implement folding of (X*4)/4 when we know X*4 doesn't overflow. - SCEVUDivExpr *&Result = SCEVUDivs[std::make_pair(LHS, RHS)]; - if (Result == 0) Result = new SCEVUDivExpr(LHS, RHS); + SCEVSDivExpr *&Result = (*SCEVSDivs)[std::make_pair(LHS, RHS)]; + if (Result == 0) Result = new SCEVSDivExpr(LHS, RHS); return Result; } @@ -1036,8 +1043,8 @@ SCEVHandle SCEVAddRecExpr::get(std::vector &Operands, } SCEVAddRecExpr *&Result = - SCEVAddRecExprs[std::make_pair(L, std::vector(Operands.begin(), - Operands.end()))]; + (*SCEVAddRecExprs)[std::make_pair(L, std::vector(Operands.begin(), + Operands.end()))]; if (Result == 0) Result = new SCEVAddRecExpr(Operands, L); return Result; } @@ -1045,7 +1052,7 @@ SCEVHandle SCEVAddRecExpr::get(std::vector &Operands, SCEVHandle SCEVUnknown::get(Value *V) { if (ConstantInt *CI = dyn_cast(V)) return SCEVConstant::get(CI); - SCEVUnknown *&Result = SCEVUnknowns[V]; + SCEVUnknown *&Result = (*SCEVUnknowns)[V]; if (Result == 0) Result = new SCEVUnknown(V); return Result; } @@ -1059,7 +1066,7 @@ SCEVHandle SCEVUnknown::get(Value *V) { /// evolution code. /// namespace { - struct ScalarEvolutionsImpl { + struct VISIBILITY_HIDDEN ScalarEvolutionsImpl { /// F - The function we are analyzing. /// Function &F; @@ -1168,14 +1175,19 @@ namespace { /// HowFarToZero - Return the number of times a backedge comparing the /// specified value to zero will execute. If not computable, return - /// UnknownValue + /// UnknownValue. SCEVHandle HowFarToZero(SCEV *V, const Loop *L); /// HowFarToNonZero - Return the number of times a backedge checking the /// specified value for nonzero will execute. If not computable, return - /// UnknownValue + /// UnknownValue. SCEVHandle HowFarToNonZero(SCEV *V, const Loop *L); + /// HowManyLessThans - Return the number of times a backedge containing the + /// specified less-than comparison will execute. If not computable, return + /// UnknownValue. + SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L); + /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is /// in the header of its containing loop, we know the loop executes a /// constant number of times, and the PHI node is just a recurrence @@ -1287,6 +1299,31 @@ SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN) { SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge)); SCEVHandle PHISCEV = SCEVAddRecExpr::get(StartVal, Accum, L); + // Okay, for the entire analysis of this edge we assumed the PHI + // to be symbolic. We now need to go back and update all of the + // entries for the scalars that use the PHI (except for the PHI + // itself) to use the new analyzed value instead of the "symbolic" + // value. + ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV); + return PHISCEV; + } + } + } else if (SCEVAddRecExpr *AddRec = dyn_cast(BEValue)) { + // Otherwise, this could be a loop like this: + // i = 0; for (j = 1; ..; ++j) { .... i = j; } + // In this case, j = {1,+,1} and BEValue is j. + // Because the other in-value of i (0) fits the evolution of BEValue + // i really is an addrec evolution. + if (AddRec->getLoop() == L && AddRec->isAffine()) { + SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge)); + + // If StartVal = j.start - j.stride, we can use StartVal as the + // initial step of the addrec evolution. + if (StartVal == SCEV::getMinusSCEV(AddRec->getOperand(0), + AddRec->getOperand(1))) { + SCEVHandle PHISCEV = + SCEVAddRecExpr::get(StartVal, AddRec->getOperand(1), L); + // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and update all of the // entries for the scalars that use the PHI (except for the PHI @@ -1348,8 +1385,8 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { return SCEVMulExpr::get(getSCEV(I->getOperand(0)), getSCEV(I->getOperand(1))); case Instruction::Div: - if (V->getType()->isInteger() && V->getType()->isUnsigned()) - return SCEVUDivExpr::get(getSCEV(I->getOperand(0)), + if (V->getType()->isInteger() && V->getType()->isSigned()) + return SCEVSDivExpr::get(getSCEV(I->getOperand(0)), getSCEV(I->getOperand(1))); break; @@ -1366,15 +1403,6 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { } break; - case Instruction::Shr: - if (ConstantUInt *SA = dyn_cast(I->getOperand(1))) - if (V->getType()->isUnsigned()) { - Constant *X = ConstantInt::get(V->getType(), 1); - X = ConstantExpr::getShl(X, SA); - return SCEVUDivExpr::get(getSCEV(I->getOperand(0)), getSCEV(X)); - } - break; - case Instruction::Cast: return createNodeForCast(cast(I)); @@ -1530,6 +1558,20 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { if (!isa(TC)) return TC; } break; + case Instruction::SetLT: + if (LHS->getType()->isInteger() && + ExitCond->getOperand(0)->getType()->isSigned()) { + SCEVHandle TC = HowManyLessThans(LHS, RHS, L); + if (!isa(TC)) return TC; + } + break; + case Instruction::SetGT: + if (LHS->getType()->isInteger() && + ExitCond->getOperand(0)->getType()->isSigned()) { + SCEVHandle TC = HowManyLessThans(RHS, LHS, L); + if (!isa(TC)) return TC; + } + break; default: #if 0 std::cerr << "ComputeIterationCount "; @@ -1563,7 +1605,7 @@ GetAddressedElementFromGlobal(GlobalVariable *GV, const std::vector &Indices) { Constant *Init = GV->getInitializer(); for (unsigned i = 0, e = Indices.size(); i != e; ++i) { - uint64_t Idx = Indices[i]->getRawValue(); + uint64_t Idx = Indices[i]->getZExtValue(); if (ConstantStruct *CS = dyn_cast(Init)) { assert(Idx < CS->getNumOperands() && "Bad struct index!"); Init = cast(CS->getOperand(Idx)); @@ -1637,8 +1679,8 @@ ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS, unsigned MaxSteps = MaxBruteForceIterations; for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) { - ConstantUInt *ItCst = - ConstantUInt::get(IdxExpr->getType()->getUnsignedVersion(), IterationNum); + ConstantInt *ItCst = + ConstantInt::get(IdxExpr->getType()->getUnsignedVersion(), IterationNum); ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst); // Form the GEP offset. @@ -1650,7 +1692,7 @@ ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS, // Evaluate the condition for this iteration. Result = ConstantExpr::get(SetCCOpcode, Result, RHS); if (!isa(Result)) break; // Couldn't decide for sure - if (Result == ConstantBool::False) { + if (cast(Result)->getValue() == false) { #if 0 std::cerr << "\n***\n*** Computed loop count " << *ItCst << "\n*** From global " << *GV << "*** BB: " << *L->getHeader() @@ -1854,7 +1896,7 @@ ComputeIterationCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) { if (CondVal->getValue() == ExitWhen) { ConstantEvolutionLoopExitValue[PN] = PHIVal; ++NumBruteForceTripCountsComputed; - return SCEVConstant::get(ConstantUInt::get(Type::UIntTy, IterationNum)); + return SCEVConstant::get(ConstantInt::get(Type::UIntTy, IterationNum)); } // Compute the value of the PHI node for the next iteration. @@ -1893,7 +1935,7 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { // this is a constant evolving PHI node, get the final value at // the specified iteration number. Constant *RV = getConstantEvolutionLoopExitValue(PN, - ICC->getValue()->getRawValue(), + ICC->getValue()->getZExtValue(), LI); if (RV) return SCEVUnknown::get(RV); } @@ -1960,14 +2002,14 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { return Comm; } - if (SCEVUDivExpr *UDiv = dyn_cast(V)) { - SCEVHandle LHS = getSCEVAtScope(UDiv->getLHS(), L); + if (SCEVSDivExpr *Div = dyn_cast(V)) { + SCEVHandle LHS = getSCEVAtScope(Div->getLHS(), L); if (LHS == UnknownValue) return LHS; - SCEVHandle RHS = getSCEVAtScope(UDiv->getRHS(), L); + SCEVHandle RHS = getSCEVAtScope(Div->getRHS(), L); if (RHS == UnknownValue) return RHS; - if (LHS == UDiv->getLHS() && RHS == UDiv->getRHS()) - return UDiv; // must be loop invariant - return SCEVUDivExpr::get(LHS, RHS); + if (LHS == Div->getLHS() && RHS == Div->getRHS()) + return Div; // must be loop invariant + return SCEVSDivExpr::get(LHS, RHS); } // If this is a loop recurrence for a loop that does not contain L, then we @@ -2034,10 +2076,10 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec) { SqrtTerm = ConstantExpr::getSub(ConstantExpr::getMul(B, B), SqrtTerm); // Compute floor(sqrt(B^2-4ac)) - ConstantUInt *SqrtVal = - cast(ConstantExpr::getCast(SqrtTerm, + ConstantInt *SqrtVal = + cast(ConstantExpr::getCast(SqrtTerm, SqrtTerm->getType()->getUnsignedVersion())); - uint64_t SqrtValV = SqrtVal->getValue(); + uint64_t SqrtValV = SqrtVal->getZExtValue(); uint64_t SqrtValV2 = (uint64_t)sqrt((double)SqrtValV); // The square root might not be precise for arbitrary 64-bit integer // values. Do some sanity checks to ensure it's correct. @@ -2047,7 +2089,7 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec) { return std::make_pair(CNC, CNC); } - SqrtVal = ConstantUInt::get(Type::ULongTy, SqrtValV2); + SqrtVal = ConstantInt::get(Type::ULongTy, SqrtValV2); SqrtTerm = ConstantExpr::getCast(SqrtVal, SqrtTerm->getType()); Constant *NegB = ConstantExpr::getNeg(B); @@ -2129,7 +2171,7 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) { if (ConstantBool *CB = dyn_cast(ConstantExpr::getSetLT(R1->getValue(), R2->getValue()))) { - if (CB != ConstantBool::True) + if (CB->getValue() == false) std::swap(R1, R2); // R1 is the minimum root now. // We can only use this value if the chrec ends up with an exact zero @@ -2159,7 +2201,7 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) { if (SCEVConstant *C = dyn_cast(V)) { Constant *Zero = Constant::getNullValue(C->getValue()->getType()); Constant *NonZero = ConstantExpr::getSetNE(C->getValue(), Zero); - if (NonZero == ConstantBool::True) + if (NonZero == ConstantBool::getTrue()) return getSCEV(Zero); return UnknownValue; // Otherwise it will loop infinitely. } @@ -2169,6 +2211,95 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) { return UnknownValue; } +/// HowManyLessThans - Return the number of times a backedge containing the +/// specified less-than comparison will execute. If not computable, return +/// UnknownValue. +SCEVHandle ScalarEvolutionsImpl:: +HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L) { + // Only handle: "ADDREC < LoopInvariant". + if (!RHS->isLoopInvariant(L)) return UnknownValue; + + SCEVAddRecExpr *AddRec = dyn_cast(LHS); + if (!AddRec || AddRec->getLoop() != L) + return UnknownValue; + + if (AddRec->isAffine()) { + // FORNOW: We only support unit strides. + SCEVHandle One = SCEVUnknown::getIntegerSCEV(1, RHS->getType()); + if (AddRec->getOperand(1) != One) + return UnknownValue; + + // The number of iterations for "[n,+,1] < m", is m-n. However, we don't + // know that m is >= n on input to the loop. If it is, the condition return + // true zero times. What we really should return, for full generality, is + // SMAX(0, m-n). Since we cannot check this, we will instead check for a + // canonical loop form: most do-loops will have a check that dominates the + // loop, that only enters the loop if [n-1]= n. + + // Search for the check. + BasicBlock *Preheader = L->getLoopPreheader(); + BasicBlock *PreheaderDest = L->getHeader(); + if (Preheader == 0) return UnknownValue; + + BranchInst *LoopEntryPredicate = + dyn_cast(Preheader->getTerminator()); + if (!LoopEntryPredicate) return UnknownValue; + + // This might be a critical edge broken out. If the loop preheader ends in + // an unconditional branch to the loop, check to see if the preheader has a + // single predecessor, and if so, look for its terminator. + while (LoopEntryPredicate->isUnconditional()) { + PreheaderDest = Preheader; + Preheader = Preheader->getSinglePredecessor(); + if (!Preheader) return UnknownValue; // Multiple preds. + + LoopEntryPredicate = + dyn_cast(Preheader->getTerminator()); + if (!LoopEntryPredicate) return UnknownValue; + } + + // Now that we found a conditional branch that dominates the loop, check to + // see if it is the comparison we are looking for. + SetCondInst *SCI =dyn_cast(LoopEntryPredicate->getCondition()); + if (!SCI) return UnknownValue; + Value *PreCondLHS = SCI->getOperand(0); + Value *PreCondRHS = SCI->getOperand(1); + Instruction::BinaryOps Cond; + if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest) + Cond = SCI->getOpcode(); + else + Cond = SCI->getInverseCondition(); + + switch (Cond) { + case Instruction::SetGT: + std::swap(PreCondLHS, PreCondRHS); + Cond = Instruction::SetLT; + // Fall Through. + case Instruction::SetLT: + if (PreCondLHS->getType()->isInteger() && + PreCondLHS->getType()->isSigned()) { + if (RHS != getSCEV(PreCondRHS)) + return UnknownValue; // Not a comparison against 'm'. + + if (SCEV::getMinusSCEV(AddRec->getOperand(0), One) + != getSCEV(PreCondLHS)) + return UnknownValue; // Not a comparison against 'n-1'. + break; + } else { + return UnknownValue; + } + default: break; + } + + //std::cerr << "Computed Loop Trip Count as: " << + // *SCEV::getMinusSCEV(RHS, AddRec->getOperand(0)) << "\n"; + return SCEV::getMinusSCEV(RHS, AddRec->getOperand(0)); + } + + return UnknownValue; +} + /// getNumIterationsInRange - Return the number of iterations of this loop that /// produce values in the specified constant range. Another way of looking at /// this is that it returns the first iteration number where the value is not in @@ -2258,7 +2389,7 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range) const { if (ConstantBool *CB = dyn_cast(ConstantExpr::getSetLT(R1->getValue(), R2->getValue()))) { - if (CB != ConstantBool::True) + if (CB->getValue() == false) std::swap(R1, R2); // R1 is the minimum root now. // Make sure the root is not off by one. The returned iteration should