X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=069f6ec714cc54a01d21df51ad5885a0b0387927;hb=9a2f93121b31bf6345d1552bdc43037f89714d86;hp=473eadcd87864bac583aefbffb628813789757f4;hpb=3e35c8d15e6981ab759820e84ffcb945bfcef71b;p=oota-llvm.git diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 473eadcd878..069f6ec714c 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -105,6 +105,7 @@ namespace { RegisterPass R("scalar-evolution", "Scalar Evolution Analysis"); } +char ScalarEvolution::ID = 0; //===----------------------------------------------------------------------===// // SCEV class definitions @@ -182,6 +183,10 @@ SCEVHandle SCEVConstant::get(ConstantInt *V) { return R; } +SCEVHandle SCEVConstant::get(const APInt& Val) { + return get(ConstantInt::get(Val)); +} + ConstantRange SCEVConstant::getValueRange() const { return ConstantRange(V->getValue()); } @@ -244,6 +249,32 @@ void SCEVZeroExtendExpr::print(std::ostream &OS) const { OS << "(zeroextend " << *Op << " to " << *Ty << ")"; } +// SCEVSignExtends - Only allow the creation of one SCEVSignExtendExpr for any +// particular input. Don't use a SCEVHandle here, or else the object will never +// be deleted! +static ManagedStatic, + SCEVSignExtendExpr*> > SCEVSignExtends; + +SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty) + : SCEV(scSignExtend), Op(op), Ty(ty) { + assert(Op->getType()->isInteger() && Ty->isInteger() && + "Cannot sign extend non-integer value!"); + assert(Op->getType()->getPrimitiveSizeInBits() < Ty->getPrimitiveSizeInBits() + && "This is not an extending conversion!"); +} + +SCEVSignExtendExpr::~SCEVSignExtendExpr() { + SCEVSignExtends->erase(std::make_pair(Op, Ty)); +} + +ConstantRange SCEVSignExtendExpr::getValueRange() const { + return getOperand()->getValueRange().signExtend(getBitWidth()); +} + +void SCEVSignExtendExpr::print(std::ostream &OS) const { + OS << "(signextend " << *Op << " to " << *Ty << ")"; +} + // SCEVCommExprs - Only allow the creation of one SCEVCommutativeExpr for any // particular input. Don't use a SCEVHandle here, or else the object will never // be deleted! @@ -454,16 +485,13 @@ SCEVHandle SCEVUnknown::getIntegerSCEV(int Val, const Type *Ty) { if (Val == 0) C = Constant::getNullValue(Ty); else if (Ty->isFloatingPoint()) - C = ConstantFP::get(Ty, Val); + C = ConstantFP::get(Ty, APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle : + APFloat::IEEEdouble, Val)); else C = ConstantInt::get(Ty, Val); return SCEVUnknown::get(C); } -SCEVHandle SCEVUnknown::getIntegerSCEV(const APInt& Val) { - return SCEVUnknown::get(ConstantInt::get(Val)); -} - /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the /// input value to the specified type. If the type must be extended, it is zero /// extended. @@ -504,7 +532,7 @@ static SCEVHandle PartialFact(SCEVHandle V, unsigned NumSteps) { APInt Result(Val.getBitWidth(), 1); for (; NumSteps; --NumSteps) Result *= Val-(NumSteps-1); - return SCEVUnknown::get(ConstantInt::get(Result)); + return SCEVConstant::get(Result); } const Type *Ty = V->getType(); @@ -587,6 +615,21 @@ SCEVHandle SCEVZeroExtendExpr::get(const SCEVHandle &Op, const Type *Ty) { return Result; } +SCEVHandle SCEVSignExtendExpr::get(const SCEVHandle &Op, const Type *Ty) { + if (SCEVConstant *SC = dyn_cast(Op)) + return SCEVUnknown::get( + ConstantExpr::getSExt(SC->getValue(), Ty)); + + // FIXME: If the input value is a chrec scev, and we can prove that the value + // did not overflow the old, smaller, value, we can sign extend all of the + // operands (often constants). This would allow analysis of something like + // this: for (signed char X = 0; X < 100; ++X) { int Y = X; } + + SCEVSignExtendExpr *&Result = (*SCEVSignExtends)[std::make_pair(Op, Ty)]; + if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty); + return Result; +} + // get - Get a canonical add expression, or something simpler if possible. SCEVHandle SCEVAddExpr::get(std::vector &Ops) { assert(!Ops.empty() && "Cannot get empty add!"); @@ -643,8 +686,11 @@ SCEVHandle SCEVAddExpr::get(std::vector &Ops) { return SCEVAddExpr::get(Ops); } - // Okay, now we know the first non-constant operand. If there are add - // operands they would be next. + // Now we know the first non-constant operand. Skip past any cast SCEVs. + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr) + ++Idx; + + // If there are add operands they would be next. if (Idx < Ops.size()) { bool DeletedAdd = false; while (SCEVAddExpr *Add = dyn_cast(Ops[Idx])) { @@ -1122,10 +1168,10 @@ namespace { /// loop without a loop-invariant iteration count. SCEVHandle getIterationCount(const Loop *L); - /// deleteInstructionFromRecords - This method should be called by the - /// client before it removes an instruction from the program, to make sure + /// deleteValueFromRecords - This method should be called by the + /// client before it removes a value from the program, to make sure /// that no dangling references are left around. - void deleteInstructionFromRecords(Instruction *I); + void deleteValueFromRecords(Value *V); private: /// createSCEV - We know that there is no SCEV for the specified value. @@ -1175,8 +1221,9 @@ namespace { /// HowManyLessThans - Return the number of times a backedge containing the /// specified less-than comparison will execute. If not computable, return - /// UnknownValue. - SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L); + /// UnknownValue. isSigned specifies whether the less-than is signed. + SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, + bool isSigned); /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is /// in the header of its containing loop, we know the loop executes a @@ -1191,13 +1238,32 @@ namespace { // Basic SCEV Analysis and PHI Idiom Recognition Code // -/// deleteInstructionFromRecords - This method should be called by the +/// deleteValueFromRecords - This method should be called by the /// client before it removes an instruction from the program, to make sure /// that no dangling references are left around. -void ScalarEvolutionsImpl::deleteInstructionFromRecords(Instruction *I) { - Scalars.erase(I); - if (PHINode *PN = dyn_cast(I)) - ConstantEvolutionLoopExitValue.erase(PN); +void ScalarEvolutionsImpl::deleteValueFromRecords(Value *V) { + SmallVector Worklist; + + if (Scalars.erase(V)) { + if (PHINode *PN = dyn_cast(V)) + ConstantEvolutionLoopExitValue.erase(PN); + Worklist.push_back(V); + } + + while (!Worklist.empty()) { + Value *VV = Worklist.back(); + Worklist.pop_back(); + + for (Instruction::use_iterator UI = VV->use_begin(), UE = VV->use_end(); + UI != UE; ++UI) { + Instruction *Inst = cast(*UI); + if (Scalars.erase(Inst)) { + if (PHINode *PN = dyn_cast(VV)) + ConstantEvolutionLoopExitValue.erase(PN); + Worklist.push_back(Inst); + } + } + } } @@ -1350,6 +1416,9 @@ static APInt GetConstantFactor(SCEVHandle S) { if (SCEVZeroExtendExpr *E = dyn_cast(S)) return GetConstantFactor(E->getOperand()).zext( cast(E->getType())->getBitWidth()); + if (SCEVSignExtendExpr *E = dyn_cast(S)) + return GetConstantFactor(E->getOperand()).sext( + cast(E->getType())->getBitWidth()); if (SCEVAddExpr *A = dyn_cast(S)) { // The result is the min of all operands. @@ -1450,6 +1519,9 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { case Instruction::ZExt: return SCEVZeroExtendExpr::get(getSCEV(I->getOperand(0)), I->getType()); + case Instruction::SExt: + return SCEVSignExtendExpr::get(getSCEV(I->getOperand(0)), I->getType()); + case Instruction::BitCast: // BitCasts are no-op casts so we just eliminate the cast. if (I->getType()->isInteger() && @@ -1498,7 +1570,7 @@ SCEVHandle ScalarEvolutionsImpl::getIterationCount(const Loop *L) { /// will iterate. SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { // If the loop has a non-one exit block count, we can't analyze it. - std::vector ExitBlocks; + SmallVector ExitBlocks; L->getExitBlocks(ExitBlocks); if (ExitBlocks.size() != 1) return UnknownValue; @@ -1601,8 +1673,7 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { ConstantRange CompRange( ICmpInst::makeConstantRange(Cond, CompVal->getValue())); - SCEVHandle Ret = AddRec->getNumIterationsInRange(CompRange, - false /*Always treat as unsigned range*/); + SCEVHandle Ret = AddRec->getNumIterationsInRange(CompRange); if (!isa(Ret)) return Ret; } } @@ -1621,12 +1692,24 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { break; } case ICmpInst::ICMP_SLT: { - SCEVHandle TC = HowManyLessThans(LHS, RHS, L); + SCEVHandle TC = HowManyLessThans(LHS, RHS, L, true); if (!isa(TC)) return TC; break; } case ICmpInst::ICMP_SGT: { - SCEVHandle TC = HowManyLessThans(RHS, LHS, L); + SCEVHandle TC = HowManyLessThans(SCEV::getNegativeSCEV(LHS), + SCEV::getNegativeSCEV(RHS), L, true); + if (!isa(TC)) return TC; + break; + } + case ICmpInst::ICMP_ULT: { + SCEVHandle TC = HowManyLessThans(LHS, RHS, L, false); + if (!isa(TC)) return TC; + break; + } + case ICmpInst::ICMP_UGT: { + SCEVHandle TC = HowManyLessThans(SCEV::getNegativeSCEV(LHS), + SCEV::getNegativeSCEV(RHS), L, false); if (!isa(TC)) return TC; break; } @@ -1646,8 +1729,8 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { } static ConstantInt * -EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, Constant *C) { - SCEVHandle InVal = SCEVConstant::get(cast(C)); +EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C) { + SCEVHandle InVal = SCEVConstant::get(C); SCEVHandle Val = AddRec->evaluateAtIteration(InVal); assert(isa(Val) && "Evaluation of SCEV at constant didn't fold correctly?"); @@ -2129,8 +2212,8 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec) { ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA)); ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA)); - return std::make_pair(SCEVUnknown::get(Solution1), - SCEVUnknown::get(Solution2)); + return std::make_pair(SCEVConstant::get(Solution1), + SCEVConstant::get(Solution2)); } // end APIntOps namespace } @@ -2240,7 +2323,7 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) { /// specified less-than comparison will execute. If not computable, return /// UnknownValue. SCEVHandle ScalarEvolutionsImpl:: -HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L) { +HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) { // Only handle: "ADDREC < LoopInvariant". if (!RHS->isLoopInvariant(L)) return UnknownValue; @@ -2297,28 +2380,34 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L) { switch (Cond) { case ICmpInst::ICMP_UGT: + if (isSigned) return UnknownValue; std::swap(PreCondLHS, PreCondRHS); Cond = ICmpInst::ICMP_ULT; break; case ICmpInst::ICMP_SGT: + if (!isSigned) return UnknownValue; std::swap(PreCondLHS, PreCondRHS); Cond = ICmpInst::ICMP_SLT; break; - default: break; + case ICmpInst::ICMP_ULT: + if (isSigned) return UnknownValue; + break; + case ICmpInst::ICMP_SLT: + if (!isSigned) return UnknownValue; + break; + default: + return UnknownValue; } - if (Cond == ICmpInst::ICMP_SLT) { - if (PreCondLHS->getType()->isInteger()) { - if (RHS != getSCEV(PreCondRHS)) - return UnknownValue; // Not a comparison against 'm'. + if (PreCondLHS->getType()->isInteger()) { + if (RHS != getSCEV(PreCondRHS)) + return UnknownValue; // Not a comparison against 'm'. - if (SCEV::getMinusSCEV(AddRec->getOperand(0), One) - != getSCEV(PreCondLHS)) - return UnknownValue; // Not a comparison against 'n-1'. - } - else return UnknownValue; - } else if (Cond == ICmpInst::ICMP_ULT) - return UnknownValue; + if (SCEV::getMinusSCEV(AddRec->getOperand(0), One) + != getSCEV(PreCondLHS)) + return UnknownValue; // Not a comparison against 'n-1'. + } + else return UnknownValue; // cerr << "Computed Loop Trip Count as: " // << // *SCEV::getMinusSCEV(RHS, AddRec->getOperand(0)) << "\n"; @@ -2336,8 +2425,7 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L) { /// this is that it returns the first iteration number where the value is not in /// the condition, thus computing the exit count. If the iteration count can't /// be computed, an instance of SCEVCouldNotCompute is returned. -SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, - bool isSigned) const { +SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range) const { if (Range.isFullSet()) // Infinite loop. return new SCEVCouldNotCompute(); @@ -2349,7 +2437,7 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, SCEVHandle Shifted = SCEVAddRecExpr::get(Operands, getLoop()); if (SCEVAddRecExpr *ShiftedAddRec = dyn_cast(Shifted)) return ShiftedAddRec->getNumIterationsInRange( - Range.subtract(SC->getValue()->getValue()),isSigned); + Range.subtract(SC->getValue()->getValue())); // This is strange and shouldn't happen. return new SCEVCouldNotCompute(); } @@ -2373,17 +2461,16 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // If this is an affine expression then we have this situation: // Solve {0,+,A} in Range === Ax in Range - // Since we know that zero is in the range, we know that the upper value of - // the range must be the first possible exit value. Also note that we - // already checked for a full range. - const APInt &Upper = Range.getUpper(); - APInt A = cast(getOperand(1))->getValue()->getValue(); + // We know that zero is in the range. If A is positive then we know that + // the upper value of the range must be the first possible exit value. + // If A is negative then the lower of the range is the last possible loop + // value. Also note that we already checked for a full range. APInt One(getBitWidth(),1); + APInt A = cast(getOperand(1))->getValue()->getValue(); + APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower(); - // The exit value should be (Upper+A-1)/A. - APInt ExitVal(Upper); - if (A != One) - ExitVal = (Upper + A - One).sdiv(A); + // The exit value should be (End+A)/A. + APInt ExitVal = (End + A).udiv(A); ConstantInt *ExitValue = ConstantInt::get(ExitVal); // Evaluate at the exit value. If we really did fall out of the valid @@ -2398,15 +2485,14 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, EvaluateConstantChrecAtConstant(this, ConstantInt::get(ExitVal - One))->getValue()) && "Linear scev computation is off in a bad way!"); - return SCEVConstant::get(cast(ExitValue)); + return SCEVConstant::get(ExitValue); } else if (isQuadratic()) { // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the // quadratic equation to solve it. To do this, we must frame our problem in // terms of figuring out when zero is crossed, instead of when // Range.getUpper() is crossed. std::vector NewOps(op_begin(), op_end()); - NewOps[0] = SCEV::getNegativeSCEV(SCEVUnknown::get( - ConstantInt::get(Range.getUpper()))); + NewOps[0] = SCEV::getNegativeSCEV(SCEVConstant::get(Range.getUpper())); SCEVHandle NewAddRec = SCEVAddRecExpr::get(NewOps, getLoop()); // Next, solve the constructed addrec @@ -2429,17 +2515,17 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, R1->getValue()); if (Range.contains(R1Val->getValue())) { // The next iteration must be out of the range... - Constant *NextVal = ConstantInt::get(R1->getValue()->getValue()+1); + ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()+1); R1Val = EvaluateConstantChrecAtConstant(this, NextVal); if (!Range.contains(R1Val->getValue())) - return SCEVUnknown::get(NextVal); + return SCEVConstant::get(NextVal); return new SCEVCouldNotCompute(); // Something strange happened } // If R1 was not in the range, then it is a good return value. Make // sure that R1-1 WAS in the range though, just in case. - Constant *NextVal = ConstantInt::get(R1->getValue()->getValue()-1); + ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()-1); R1Val = EvaluateConstantChrecAtConstant(this, NextVal); if (Range.contains(R1Val->getValue())) return R1; @@ -2523,8 +2609,8 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) const { return ((ScalarEvolutionsImpl*)Impl)->getSCEVAtScope(getSCEV(V), L); } -void ScalarEvolution::deleteInstructionFromRecords(Instruction *I) const { - return ((ScalarEvolutionsImpl*)Impl)->deleteInstructionFromRecords(I); +void ScalarEvolution::deleteValueFromRecords(Value *V) const { + return ((ScalarEvolutionsImpl*)Impl)->deleteValueFromRecords(V); } static void PrintLoopInfo(std::ostream &OS, const ScalarEvolution *SE, @@ -2535,7 +2621,7 @@ static void PrintLoopInfo(std::ostream &OS, const ScalarEvolution *SE, cerr << "Loop " << L->getHeader()->getName() << ": "; - std::vector ExitBlocks; + SmallVector ExitBlocks; L->getExitBlocks(ExitBlocks); if (ExitBlocks.size() != 1) cerr << " ";