X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=82be9cd5c4e3f45e8105990c705abf8f4921f08a;hb=f451cb870efcf9e0302d25ed05f4cac6bb494e42;hp=51437bcd0d3b8a627fc937290f8c7d55616eb060;hpb=f5074ec9634d51472bc6e2114deea0afb6677dd8;p=oota-llvm.git diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 51437bcd0d3..82be9cd5c4e 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -14,9 +14,8 @@ // There are several aspects to this library. First is the representation of // scalar expressions, which are represented as subclasses of the SCEV class. // These classes are used to represent certain types of subexpressions that we -// can handle. These classes are reference counted, managed by the const SCEV * -// class. We only create one SCEV of a particular shape, so pointer-comparisons -// for equality are legal. +// can handle. We only create one SCEV of a particular shape, so +// pointer-comparisons for equality are legal. // // One important aspect of the SCEV objects is that they are never cyclic, even // if there is a cycle in the dataflow for an expression (ie, a PHI node). If @@ -64,8 +63,10 @@ #include "llvm/Constants.h" #include "llvm/DerivedTypes.h" #include "llvm/GlobalVariable.h" +#include "llvm/GlobalAlias.h" #include "llvm/Instructions.h" #include "llvm/LLVMContext.h" +#include "llvm/Operator.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/Dominators.h" #include "llvm/Analysis/LoopInfo.h" @@ -73,8 +74,8 @@ #include "llvm/Assembly/Writer.h" #include "llvm/Target/TargetData.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Compiler.h" #include "llvm/Support/ConstantRange.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/GetElementPtrTypeIterator.h" #include "llvm/Support/InstIterator.h" @@ -117,13 +118,8 @@ char ScalarEvolution::ID = 0; SCEV::~SCEV() {} void SCEV::dump() const { - print(errs()); - errs() << '\n'; -} - -void SCEV::print(std::ostream &o) const { - raw_os_ostream OS(o); - print(OS); + print(dbgs()); + dbgs() << '\n'; } bool SCEV::isZero() const { @@ -148,26 +144,23 @@ SCEVCouldNotCompute::SCEVCouldNotCompute() : SCEV(FoldingSetNodeID(), scCouldNotCompute) {} bool SCEVCouldNotCompute::isLoopInvariant(const Loop *L) const { - LLVM_UNREACHABLE("Attempt to use a SCEVCouldNotCompute object!"); + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); return false; } const Type *SCEVCouldNotCompute::getType() const { - LLVM_UNREACHABLE("Attempt to use a SCEVCouldNotCompute object!"); + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); return 0; } bool SCEVCouldNotCompute::hasComputableLoopEvolution(const Loop *L) const { - LLVM_UNREACHABLE("Attempt to use a SCEVCouldNotCompute object!"); + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); return false; } -const SCEV * -SCEVCouldNotCompute::replaceSymbolicValuesWithConcrete( - const SCEV *Sym, - const SCEV *Conc, - ScalarEvolution &SE) const { - return this; +bool SCEVCouldNotCompute::hasOperand(const SCEV *) const { + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); + return false; } void SCEVCouldNotCompute::print(raw_ostream &OS) const { @@ -191,12 +184,13 @@ const SCEV *ScalarEvolution::getConstant(ConstantInt *V) { } const SCEV *ScalarEvolution::getConstant(const APInt& Val) { - return getConstant(ConstantInt::get(Val)); + return getConstant(ConstantInt::get(getContext(), Val)); } const SCEV * ScalarEvolution::getConstant(const Type *Ty, uint64_t V, bool isSigned) { - return getConstant(ConstantInt::get(cast(Ty), V, isSigned)); + return getConstant( + ConstantInt::get(cast(Ty), V, isSigned)); } const Type *SCEVConstant::getType() const { return V->getType(); } @@ -213,6 +207,10 @@ bool SCEVCastExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { return Op->dominates(BB, DT); } +bool SCEVCastExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const { + return Op->properlyDominates(BB, DT); +} + SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeID &ID, const SCEV *op, const Type *ty) : SCEVCastExpr(ID, scTruncate, op, ty) { @@ -258,42 +256,17 @@ void SCEVCommutativeExpr::print(raw_ostream &OS) const { OS << ")"; } -const SCEV * -SCEVCommutativeExpr::replaceSymbolicValuesWithConcrete( - const SCEV *Sym, - const SCEV *Conc, - ScalarEvolution &SE) const { +bool SCEVNAryExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { - const SCEV *H = - getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE); - if (H != getOperand(i)) { - SmallVector NewOps; - NewOps.reserve(getNumOperands()); - for (unsigned j = 0; j != i; ++j) - NewOps.push_back(getOperand(j)); - NewOps.push_back(H); - for (++i; i != e; ++i) - NewOps.push_back(getOperand(i)-> - replaceSymbolicValuesWithConcrete(Sym, Conc, SE)); - - if (isa(this)) - return SE.getAddExpr(NewOps); - else if (isa(this)) - return SE.getMulExpr(NewOps); - else if (isa(this)) - return SE.getSMaxExpr(NewOps); - else if (isa(this)) - return SE.getUMaxExpr(NewOps); - else - LLVM_UNREACHABLE("Unknown commutative expr!"); - } + if (!getOperand(i)->dominates(BB, DT)) + return false; } - return this; + return true; } -bool SCEVNAryExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { +bool SCEVNAryExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const { for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { - if (!getOperand(i)->dominates(BB, DT)) + if (!getOperand(i)->properlyDominates(BB, DT)) return false; } return true; @@ -303,6 +276,10 @@ bool SCEVUDivExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { return LHS->dominates(BB, DT) && RHS->dominates(BB, DT); } +bool SCEVUDivExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const { + return LHS->properlyDominates(BB, DT) && RHS->properlyDominates(BB, DT); +} + void SCEVUDivExpr::print(raw_ostream &OS) const { OS << "(" << *LHS << " /u " << *RHS << ")"; } @@ -316,37 +293,13 @@ const Type *SCEVUDivExpr::getType() const { return RHS->getType(); } -const SCEV * -SCEVAddRecExpr::replaceSymbolicValuesWithConcrete(const SCEV *Sym, - const SCEV *Conc, - ScalarEvolution &SE) const { - for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { - const SCEV *H = - getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE); - if (H != getOperand(i)) { - SmallVector NewOps; - NewOps.reserve(getNumOperands()); - for (unsigned j = 0; j != i; ++j) - NewOps.push_back(getOperand(j)); - NewOps.push_back(H); - for (++i; i != e; ++i) - NewOps.push_back(getOperand(i)-> - replaceSymbolicValuesWithConcrete(Sym, Conc, SE)); - - return SE.getAddRecExpr(NewOps, L); - } - } - return this; -} - - bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const { // Add recurrences are never invariant in the function-body (null loop). if (!QueryLoop) return false; // This recurrence is variant w.r.t. QueryLoop if QueryLoop contains L. - if (QueryLoop->contains(L->getHeader())) + if (QueryLoop->contains(L)) return false; // This recurrence is variant w.r.t. QueryLoop if any of its operands @@ -363,7 +316,9 @@ void SCEVAddRecExpr::print(raw_ostream &OS) const { OS << "{" << *Operands[0]; for (unsigned i = 1, e = Operands.size(); i != e; ++i) OS << ",+," << *Operands[i]; - OS << "}<" << L->getHeader()->getName() + ">"; + OS << "}<"; + WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false); + OS << ">"; } bool SCEVUnknown::isLoopInvariant(const Loop *L) const { @@ -372,7 +327,7 @@ bool SCEVUnknown::isLoopInvariant(const Loop *L) const { // Instructions are never considered invariant in the function body // (null loop) because they are defined within the "loop". if (Instruction *I = dyn_cast(V)) - return L && !L->contains(I->getParent()); + return L && !L->contains(I); return true; } @@ -382,11 +337,101 @@ bool SCEVUnknown::dominates(BasicBlock *BB, DominatorTree *DT) const { return true; } +bool SCEVUnknown::properlyDominates(BasicBlock *BB, DominatorTree *DT) const { + if (Instruction *I = dyn_cast(getValue())) + return DT->properlyDominates(I->getParent(), BB); + return true; +} + const Type *SCEVUnknown::getType() const { return V->getType(); } +bool SCEVUnknown::isSizeOf(const Type *&AllocTy) const { + if (ConstantExpr *VCE = dyn_cast(V)) + if (VCE->getOpcode() == Instruction::PtrToInt) + if (ConstantExpr *CE = dyn_cast(VCE->getOperand(0))) + if (CE->getOpcode() == Instruction::GetElementPtr && + CE->getOperand(0)->isNullValue() && + CE->getNumOperands() == 2) + if (ConstantInt *CI = dyn_cast(CE->getOperand(1))) + if (CI->isOne()) { + AllocTy = cast(CE->getOperand(0)->getType()) + ->getElementType(); + return true; + } + + return false; +} + +bool SCEVUnknown::isAlignOf(const Type *&AllocTy) const { + if (ConstantExpr *VCE = dyn_cast(V)) + if (VCE->getOpcode() == Instruction::PtrToInt) + if (ConstantExpr *CE = dyn_cast(VCE->getOperand(0))) + if (CE->getOpcode() == Instruction::GetElementPtr && + CE->getOperand(0)->isNullValue()) { + const Type *Ty = + cast(CE->getOperand(0)->getType())->getElementType(); + if (const StructType *STy = dyn_cast(Ty)) + if (!STy->isPacked() && + CE->getNumOperands() == 3 && + CE->getOperand(1)->isNullValue()) { + if (ConstantInt *CI = dyn_cast(CE->getOperand(2))) + if (CI->isOne() && + STy->getNumElements() == 2 && + STy->getElementType(0)->isInteger(1)) { + AllocTy = STy->getElementType(1); + return true; + } + } + } + + return false; +} + +bool SCEVUnknown::isOffsetOf(const Type *&CTy, Constant *&FieldNo) const { + if (ConstantExpr *VCE = dyn_cast(V)) + if (VCE->getOpcode() == Instruction::PtrToInt) + if (ConstantExpr *CE = dyn_cast(VCE->getOperand(0))) + if (CE->getOpcode() == Instruction::GetElementPtr && + CE->getNumOperands() == 3 && + CE->getOperand(0)->isNullValue() && + CE->getOperand(1)->isNullValue()) { + const Type *Ty = + cast(CE->getOperand(0)->getType())->getElementType(); + // Ignore vector types here so that ScalarEvolutionExpander doesn't + // emit getelementptrs that index into vectors. + if (isa(Ty) || isa(Ty)) { + CTy = Ty; + FieldNo = CE->getOperand(2); + return true; + } + } + + return false; +} + void SCEVUnknown::print(raw_ostream &OS) const { + const Type *AllocTy; + if (isSizeOf(AllocTy)) { + OS << "sizeof(" << *AllocTy << ")"; + return; + } + if (isAlignOf(AllocTy)) { + OS << "alignof(" << *AllocTy << ")"; + return; + } + + const Type *CTy; + Constant *FieldNo; + if (isOffsetOf(CTy, FieldNo)) { + OS << "offsetof(" << *CTy << ", "; + WriteAsOperand(OS, FieldNo, false); + OS << ")"; + return; + } + + // Otherwise just print it normally. WriteAsOperand(OS, V, false); } @@ -394,16 +439,55 @@ void SCEVUnknown::print(raw_ostream &OS) const { // SCEV Utilities //===----------------------------------------------------------------------===// +static bool CompareTypes(const Type *A, const Type *B) { + if (A->getTypeID() != B->getTypeID()) + return A->getTypeID() < B->getTypeID(); + if (const IntegerType *AI = dyn_cast(A)) { + const IntegerType *BI = cast(B); + return AI->getBitWidth() < BI->getBitWidth(); + } + if (const PointerType *AI = dyn_cast(A)) { + const PointerType *BI = cast(B); + return CompareTypes(AI->getElementType(), BI->getElementType()); + } + if (const ArrayType *AI = dyn_cast(A)) { + const ArrayType *BI = cast(B); + if (AI->getNumElements() != BI->getNumElements()) + return AI->getNumElements() < BI->getNumElements(); + return CompareTypes(AI->getElementType(), BI->getElementType()); + } + if (const VectorType *AI = dyn_cast(A)) { + const VectorType *BI = cast(B); + if (AI->getNumElements() != BI->getNumElements()) + return AI->getNumElements() < BI->getNumElements(); + return CompareTypes(AI->getElementType(), BI->getElementType()); + } + if (const StructType *AI = dyn_cast(A)) { + const StructType *BI = cast(B); + if (AI->getNumElements() != BI->getNumElements()) + return AI->getNumElements() < BI->getNumElements(); + for (unsigned i = 0, e = AI->getNumElements(); i != e; ++i) + if (CompareTypes(AI->getElementType(i), BI->getElementType(i)) || + CompareTypes(BI->getElementType(i), AI->getElementType(i))) + return CompareTypes(AI->getElementType(i), BI->getElementType(i)); + } + return false; +} + namespace { /// SCEVComplexityCompare - Return true if the complexity of the LHS is less /// than the complexity of the RHS. This comparator is used to canonicalize /// expressions. - class VISIBILITY_HIDDEN SCEVComplexityCompare { + class SCEVComplexityCompare { LoopInfo *LI; public: explicit SCEVComplexityCompare(LoopInfo *li) : LI(li) {} bool operator()(const SCEV *LHS, const SCEV *RHS) const { + // Fast-path: SCEVs are uniqued so we can do a quick equality check. + if (LHS == RHS) + return false; + // Primarily, sort the SCEVs by their getSCEVType(). if (LHS->getSCEVType() != RHS->getSCEVType()) return LHS->getSCEVType() < RHS->getSCEVType(); @@ -506,7 +590,7 @@ namespace { return operator()(LC->getOperand(), RC->getOperand()); } - LLVM_UNREACHABLE("Unknown SCEV kind!"); + llvm_unreachable("Unknown SCEV kind!"); return false; } }; @@ -566,8 +650,8 @@ static void GroupByComplexity(SmallVectorImpl &Ops, /// BinomialCoefficient - Compute BC(It, K). The result has width W. /// Assume, K > 0. static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, - ScalarEvolution &SE, - const Type* ResultTy) { + ScalarEvolution &SE, + const Type* ResultTy) { // Handle the simplest case efficiently. if (K == 1) return SE.getTruncateOrZeroExtend(It, ResultTy); @@ -657,7 +741,8 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, MultiplyFactor = MultiplyFactor.trunc(W); // Calculate the product, at width T+W - const IntegerType *CalculationTy = IntegerType::get(CalculationBits); + const IntegerType *CalculationTy = IntegerType::get(SE.getContext(), + CalculationBits); const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy); for (unsigned i = 1; i != K; ++i) { const SCEV *S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType())); @@ -684,7 +769,7 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, /// where BC(It, k) stands for binomial coefficient. /// const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It, - ScalarEvolution &SE) const { + ScalarEvolution &SE) const { const SCEV *Result = getStart(); for (unsigned i = 1, e = getNumOperands(); i != e; ++i) { // The computation is correct in the face of overflow provided that the @@ -792,6 +877,13 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, unsigned BitWidth = getTypeSizeInBits(AR->getType()); const Loop *L = AR->getLoop(); + // If we have special knowledge that this addrec won't overflow, + // we don't need to do any further analysis. + if (AR->hasNoUnsignedWrap()) + return getAddRecExpr(getZeroExtendExpr(Start, Ty), + getZeroExtendExpr(Step, Ty), + L); + // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are // simply not analyzable, and it covers the case where this code is @@ -812,7 +904,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); if (MaxBECount == RecastedMaxBECount) { - const Type *WideTy = IntegerType::get(BitWidth * 2); + const Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no unsigned overflow. const SCEV *ZMul = getMulExpr(CastedMaxBECount, @@ -924,6 +1016,13 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, unsigned BitWidth = getTypeSizeInBits(AR->getType()); const Loop *L = AR->getLoop(); + // If we have special knowledge that this addrec won't overflow, + // we don't need to do any further analysis. + if (AR->hasNoSignedWrap()) + return getAddRecExpr(getSignExtendExpr(Start, Ty), + getSignExtendExpr(Step, Ty), + L); + // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are // simply not analyzable, and it covers the case where this code is @@ -944,7 +1043,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); if (MaxBECount == RecastedMaxBECount) { - const Type *WideTy = IntegerType::get(BitWidth * 2); + const Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no signed overflow. const SCEV *SMul = getMulExpr(CastedMaxBECount, @@ -959,6 +1058,22 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, return getAddRecExpr(getSignExtendExpr(Start, Ty), getSignExtendExpr(Step, Ty), L); + + // Similar to above, only this time treat the step value as unsigned. + // This covers loops that count up with an unsigned step. + const SCEV *UMul = + getMulExpr(CastedMaxBECount, + getTruncateOrZeroExtend(Step, Start->getType())); + Add = getAddExpr(Start, UMul); + OperandExtendedAdd = + getAddExpr(getSignExtendExpr(Start, WideTy), + getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy), + getZeroExtendExpr(Step, WideTy))); + if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd) + // Return the expression with the addrec on the outside. + return getAddRecExpr(getSignExtendExpr(Start, Ty), + getZeroExtendExpr(Step, Ty), + L); } // If the backedge is guarded by a comparison with the pre-inc value @@ -1004,7 +1119,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, /// unspecified bits out to the given type. /// const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, - const Type *Ty) { + const Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -1034,6 +1149,15 @@ const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, if (!isa(SExt)) return SExt; + // Force the cast to be folded into the operands of an addrec. + if (const SCEVAddRecExpr *AR = dyn_cast(Op)) { + SmallVector Ops; + for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end(); + I != E; ++I) + Ops.push_back(getAnyExtendExpr(*I, Ty)); + return getAddRecExpr(Ops, AR->getLoop()); + } + // If the expression is obviously signed, use the sext cast value. if (isa(Op)) return SExt; @@ -1138,7 +1262,8 @@ namespace { /// getAddExpr - Get a canonical add expression, or something simpler if /// possible. -const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { +const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, + bool HasNUW, bool HasNSW) { assert(!Ops.empty() && "Cannot get empty add!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG @@ -1148,6 +1273,17 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { "SCEVAddExpr operand types don't match!"); #endif + // If HasNSW is true and all the operands are non-negative, infer HasNUW. + if (!HasNUW && HasNSW) { + bool All = true; + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + if (!isKnownNonNegative(Ops[i])) { + All = false; + break; + } + if (All) HasNUW = true; + } + // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, LI); @@ -1188,7 +1324,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { return Mul; Ops.erase(Ops.begin()+i, Ops.begin()+i+2); Ops.push_back(Mul); - return getAddExpr(Ops); + return getAddExpr(Ops, HasNUW, HasNSW); } // Check for truncates. If all the operands are truncated from the same @@ -1243,7 +1379,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { } if (Ok) { // Evaluate the expression in the larger type. - const SCEV *Fold = getAddExpr(LargeOps); + const SCEV *Fold = getAddExpr(LargeOps, HasNUW, HasNSW); // If it folds to something simple, use it. Otherwise, don't. if (isa(Fold) || isa(Fold)) return getTruncateExpr(Fold, DstType); @@ -1404,10 +1540,13 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { LIOps.push_back(AddRec->getStart()); SmallVector AddRecOps(AddRec->op_begin(), - AddRec->op_end()); + AddRec->op_end()); AddRecOps[0] = getAddExpr(LIOps); + // It's tempting to propagate NUW/NSW flags here, but nuw/nsw addition + // is not associative so this isn't necessarily safe. const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRec->getLoop()); + // If all of the other operands were loop invariant, we are done. if (Ops.size() == 1) return NewRec; @@ -1462,18 +1601,24 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); void *IP = 0; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; - SCEV *S = SCEVAllocator.Allocate(); - new (S) SCEVAddExpr(ID, Ops); - UniqueSCEVs.InsertNode(S, IP); + SCEVAddExpr *S = + static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + if (!S) { + S = SCEVAllocator.Allocate(); + new (S) SCEVAddExpr(ID, Ops); + UniqueSCEVs.InsertNode(S, IP); + } + if (HasNUW) S->setHasNoUnsignedWrap(true); + if (HasNSW) S->setHasNoSignedWrap(true); return S; } - /// getMulExpr - Get a canonical multiply expression, or something simpler if /// possible. -const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { +const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, + bool HasNUW, bool HasNSW) { assert(!Ops.empty() && "Cannot get empty mul!"); + if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG for (unsigned i = 1, e = Ops.size(); i != e; ++i) assert(getEffectiveSCEVType(Ops[i]->getType()) == @@ -1481,6 +1626,17 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { "SCEVMulExpr operand types don't match!"); #endif + // If HasNSW is true and all the operands are non-negative, infer HasNUW. + if (!HasNUW && HasNSW) { + bool All = true; + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + if (!isKnownNonNegative(Ops[i])) { + All = false; + break; + } + if (All) HasNUW = true; + } + // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, LI); @@ -1496,11 +1652,11 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)), getMulExpr(LHSC, Add->getOperand(1))); - ++Idx; while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() * + ConstantInt *Fold = ConstantInt::get(getContext(), + LHSC->getValue()->getValue() * RHSC->getValue()->getValue()); Ops[0] = getConstant(Fold); Ops.erase(Ops.begin()+1); // Erase the folded element @@ -1515,6 +1671,22 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { } else if (cast(Ops[0])->getValue()->isZero()) { // If we have a multiply of zero, it will always be zero. return Ops[0]; + } else if (Ops[0]->isAllOnesValue()) { + // If we have a mul by -1 of an add, try distributing the -1 among the + // add operands. + if (Ops.size() == 2) + if (const SCEVAddExpr *Add = dyn_cast(Ops[1])) { + SmallVector NewOps; + bool AnyFolded = false; + for (SCEVAddRecExpr::op_iterator I = Add->op_begin(), E = Add->op_end(); + I != E; ++I) { + const SCEV *Mul = getMulExpr(Ops[0], *I); + if (!isa(Mul)) AnyFolded = true; + NewOps.push_back(Mul); + } + if (AnyFolded) + return getAddExpr(NewOps); + } } } @@ -1579,7 +1751,11 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { } } - const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop()); + // It's tempting to propagate the NSW flag here, but nsw multiplication + // is not associative so this isn't necessarily safe. + const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), + HasNUW && AddRec->hasNoUnsignedWrap(), + /*HasNSW=*/false); // If all of the other operands were loop invariant, we are done. if (Ops.size() == 1) return NewRec; @@ -1633,15 +1809,20 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); void *IP = 0; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; - SCEV *S = SCEVAllocator.Allocate(); - new (S) SCEVMulExpr(ID, Ops); - UniqueSCEVs.InsertNode(S, IP); + SCEVMulExpr *S = + static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + if (!S) { + S = SCEVAllocator.Allocate(); + new (S) SCEVMulExpr(ID, Ops); + UniqueSCEVs.InsertNode(S, IP); + } + if (HasNUW) S->setHasNoUnsignedWrap(true); + if (HasNSW) S->setHasNoSignedWrap(true); return S; } -/// getUDivExpr - Get a canonical multiply expression, or something simpler if -/// possible. +/// getUDivExpr - Get a canonical unsigned division expression, or something +/// simpler if possible. const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, const SCEV *RHS) { assert(getEffectiveSCEVType(LHS->getType()) == @@ -1650,7 +1831,7 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, if (const SCEVConstant *RHSC = dyn_cast(RHS)) { if (RHSC->getValue()->equalsInt(1)) - return LHS; // X udiv 1 --> x + return LHS; // X udiv 1 --> x if (RHSC->isZero()) return getIntegerSCEV(0, LHS->getType()); // value is undefined @@ -1665,7 +1846,7 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, if (!RHSC->getValue()->getValue().isPowerOf2()) ++MaxShiftAmt; const IntegerType *ExtTy = - IntegerType::get(getTypeSizeInBits(Ty) + MaxShiftAmt); + IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt); // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded. if (const SCEVAddRecExpr *AR = dyn_cast(LHS)) if (const SCEVConstant *Step = @@ -1743,7 +1924,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, /// getAddRecExpr - Get an add recurrence expression for the specified loop. /// Simplify the expression as much as possible. const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, - const SCEV *Step, const Loop *L) { + const SCEV *Step, const Loop *L, + bool HasNUW, bool HasNSW) { SmallVector Operands; Operands.push_back(Start); if (const SCEVAddRecExpr *StepChrec = dyn_cast(Step)) @@ -1754,14 +1936,15 @@ const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, } Operands.push_back(Step); - return getAddRecExpr(Operands, L); + return getAddRecExpr(Operands, L, HasNUW, HasNSW); } /// getAddRecExpr - Get an add recurrence expression for the specified loop. /// Simplify the expression as much as possible. const SCEV * ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, - const Loop *L) { + const Loop *L, + bool HasNUW, bool HasNSW) { if (Operands.size() == 1) return Operands[0]; #ifndef NDEBUG for (unsigned i = 1, e = Operands.size(); i != e; ++i) @@ -1772,15 +1955,29 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, if (Operands.back()->isZero()) { Operands.pop_back(); - return getAddRecExpr(Operands, L); // {X,+,0} --> X + return getAddRecExpr(Operands, L, HasNUW, HasNSW); // {X,+,0} --> X + } + + // If HasNSW is true and all the operands are non-negative, infer HasNUW. + if (!HasNUW && HasNSW) { + bool All = true; + for (unsigned i = 0, e = Operands.size(); i != e; ++i) + if (!isKnownNonNegative(Operands[i])) { + All = false; + break; + } + if (All) HasNUW = true; } // Canonicalize nested AddRecs in by nesting them in order of loop depth. if (const SCEVAddRecExpr *NestedAR = dyn_cast(Operands[0])) { - const Loop* NestedLoop = NestedAR->getLoop(); - if (L->getLoopDepth() < NestedLoop->getLoopDepth()) { + const Loop *NestedLoop = NestedAR->getLoop(); + if (L->contains(NestedLoop->getHeader()) ? + (L->getLoopDepth() < NestedLoop->getLoopDepth()) : + (!NestedLoop->contains(L->getHeader()) && + DT->dominates(L->getHeader(), NestedLoop->getHeader()))) { SmallVector NestedOperands(NestedAR->op_begin(), - NestedAR->op_end()); + NestedAR->op_end()); Operands[0] = NestedAR->getStart(); // AddRecs require their operands be loop-invariant with respect to their // loops. Don't perform this transformation if it would break this @@ -1801,13 +1998,15 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, } if (AllInvariant) // Ok, both add recurrences are valid after the transformation. - return getAddRecExpr(NestedOperands, NestedLoop); + return getAddRecExpr(NestedOperands, NestedLoop, HasNUW, HasNSW); } // Reset Operands to its original state. Operands[0] = NestedAR; } } + // Okay, it looks like we really DO need an addrec expr. Check to see if we + // already have one, otherwise create a new one. FoldingSetNodeID ID; ID.AddInteger(scAddRecExpr); ID.AddInteger(Operands.size()); @@ -1815,10 +2014,15 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, ID.AddPointer(Operands[i]); ID.AddPointer(L); void *IP = 0; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; - SCEV *S = SCEVAllocator.Allocate(); - new (S) SCEVAddRecExpr(ID, Operands, L); - UniqueSCEVs.InsertNode(S, IP); + SCEVAddRecExpr *S = + static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + if (!S) { + S = SCEVAllocator.Allocate(); + new (S) SCEVAddRecExpr(ID, Operands, L); + UniqueSCEVs.InsertNode(S, IP); + } + if (HasNUW) S->setHasNoUnsignedWrap(true); + if (HasNSW) S->setHasNoSignedWrap(true); return S; } @@ -1851,7 +2055,7 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { assert(Idx < Ops.size()); while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - ConstantInt *Fold = ConstantInt::get( + ConstantInt *Fold = ConstantInt::get(getContext(), APIntOps::smax(LHSC->getValue()->getValue(), RHSC->getValue()->getValue())); Ops[0] = getConstant(Fold); @@ -1948,7 +2152,7 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { assert(Idx < Ops.size()); while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - ConstantInt *Fold = ConstantInt::get( + ConstantInt *Fold = ConstantInt::get(getContext(), APIntOps::umax(LHSC->getValue()->getValue(), RHSC->getValue()->getValue())); Ops[0] = getConstant(Fold); @@ -2028,6 +2232,40 @@ const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS))); } +const SCEV *ScalarEvolution::getSizeOfExpr(const Type *AllocTy) { + Constant *C = ConstantExpr::getSizeOf(AllocTy); + if (ConstantExpr *CE = dyn_cast(C)) + C = ConstantFoldConstantExpression(CE, TD); + const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy)); + return getTruncateOrZeroExtend(getSCEV(C), Ty); +} + +const SCEV *ScalarEvolution::getAlignOfExpr(const Type *AllocTy) { + Constant *C = ConstantExpr::getAlignOf(AllocTy); + if (ConstantExpr *CE = dyn_cast(C)) + C = ConstantFoldConstantExpression(CE, TD); + const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy)); + return getTruncateOrZeroExtend(getSCEV(C), Ty); +} + +const SCEV *ScalarEvolution::getOffsetOfExpr(const StructType *STy, + unsigned FieldNo) { + Constant *C = ConstantExpr::getOffsetOf(STy, FieldNo); + if (ConstantExpr *CE = dyn_cast(C)) + C = ConstantFoldConstantExpression(CE, TD); + const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(STy)); + return getTruncateOrZeroExtend(getSCEV(C), Ty); +} + +const SCEV *ScalarEvolution::getOffsetOfExpr(const Type *CTy, + Constant *FieldNo) { + Constant *C = ConstantExpr::getOffsetOf(CTy, FieldNo); + if (ConstantExpr *CE = dyn_cast(C)) + C = ConstantFoldConstantExpression(CE, TD); + const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(CTy)); + return getTruncateOrZeroExtend(getSCEV(C), Ty); +} + const SCEV *ScalarEvolution::getUnknown(Value *V) { // Don't attempt to do anything other than create a SCEVUnknown object // here. createSCEV only calls getUnknown after checking for all other @@ -2054,17 +2292,8 @@ const SCEV *ScalarEvolution::getUnknown(Value *V) { /// can optionally include pointer types if the ScalarEvolution class /// has access to target-specific information. bool ScalarEvolution::isSCEVable(const Type *Ty) const { - // Integers are always SCEVable. - if (Ty->isInteger()) - return true; - - // Pointers are SCEVable if TargetData information is available - // to provide pointer size information. - if (isa(Ty)) - return TD != NULL; - - // Otherwise it's not SCEVable. - return false; + // Integers and pointers are always SCEVable. + return Ty->isInteger() || isa(Ty); } /// getTypeSizeInBits - Return the size in bits of the specified type, @@ -2076,9 +2305,14 @@ uint64_t ScalarEvolution::getTypeSizeInBits(const Type *Ty) const { if (TD) return TD->getTypeSizeInBits(Ty); - // Otherwise, we support only integer types. - assert(Ty->isInteger() && "isSCEVable permitted a non-SCEVable type!"); - return Ty->getPrimitiveSizeInBits(); + // Integer types have fixed sizes. + if (Ty->isInteger()) + return Ty->getPrimitiveSizeInBits(); + + // The only other support type is pointer. Without TargetData, conservatively + // assume pointers are 64-bit. + assert(isa(Ty) && "isSCEVable permitted a non-SCEVable type!"); + return 64; } /// getEffectiveSCEVType - Return a type with the same bitwidth as @@ -2091,8 +2325,12 @@ const Type *ScalarEvolution::getEffectiveSCEVType(const Type *Ty) const { if (Ty->isInteger()) return Ty; + // The only other support type is pointer. assert(isa(Ty) && "Unexpected non-pointer non-integer type!"); - return TD->getIntPtrType(); + if (TD) return TD->getIntPtrType(getContext()); + + // Without TargetData, conservatively assume pointers are 64-bit. + return Type::getInt64Ty(getContext()); } const SCEV *ScalarEvolution::getCouldNotCompute() { @@ -2113,7 +2351,7 @@ const SCEV *ScalarEvolution::getSCEV(Value *V) { /// getIntegerSCEV - Given a SCEVable type, create a constant for the /// specified signed integer value and return a SCEV for the constant. -const SCEV *ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) { +const SCEV *ScalarEvolution::getIntegerSCEV(int64_t Val, const Type *Ty) { const IntegerType *ITy = cast(getEffectiveSCEVType(Ty)); return getConstant(ConstantInt::get(ITy, Val)); } @@ -2123,24 +2361,24 @@ const SCEV *ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) { const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V) { if (const SCEVConstant *VC = dyn_cast(V)) return getConstant( - cast(Context->getConstantExprNeg(VC->getValue()))); + cast(ConstantExpr::getNeg(VC->getValue()))); const Type *Ty = V->getType(); Ty = getEffectiveSCEVType(Ty); return getMulExpr(V, - getConstant(cast(Context->getAllOnesValue(Ty)))); + getConstant(cast(Constant::getAllOnesValue(Ty)))); } /// getNotSCEV - Return a SCEV corresponding to ~V = -1-V const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { if (const SCEVConstant *VC = dyn_cast(V)) return getConstant( - cast(Context->getConstantExprNot(VC->getValue()))); + cast(ConstantExpr::getNot(VC->getValue()))); const Type *Ty = V->getType(); Ty = getEffectiveSCEVType(Ty); const SCEV *AllOnes = - getConstant(cast(Context->getAllOnesValue(Ty))); + getConstant(cast(Constant::getAllOnesValue(Ty))); return getMinusSCEV(AllOnes, V); } @@ -2159,8 +2397,8 @@ const SCEV * ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, const Type *Ty) { const Type *SrcTy = V->getType(); - assert((SrcTy->isInteger() || (TD && isa(SrcTy))) && - (Ty->isInteger() || (TD && isa(Ty))) && + assert((SrcTy->isInteger() || isa(SrcTy)) && + (Ty->isInteger() || isa(Ty)) && "Cannot truncate or zero extend with non-integer arguments!"); if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) return V; // No conversion @@ -2176,8 +2414,8 @@ const SCEV * ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, const Type *Ty) { const Type *SrcTy = V->getType(); - assert((SrcTy->isInteger() || (TD && isa(SrcTy))) && - (Ty->isInteger() || (TD && isa(Ty))) && + assert((SrcTy->isInteger() || isa(SrcTy)) && + (Ty->isInteger() || isa(Ty)) && "Cannot truncate or zero extend with non-integer arguments!"); if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) return V; // No conversion @@ -2192,8 +2430,8 @@ ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, const SCEV * ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, const Type *Ty) { const Type *SrcTy = V->getType(); - assert((SrcTy->isInteger() || (TD && isa(SrcTy))) && - (Ty->isInteger() || (TD && isa(Ty))) && + assert((SrcTy->isInteger() || isa(SrcTy)) && + (Ty->isInteger() || isa(Ty)) && "Cannot noop or zero extend with non-integer arguments!"); assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) && "getNoopOrZeroExtend cannot truncate!"); @@ -2208,8 +2446,8 @@ ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, const Type *Ty) { const SCEV * ScalarEvolution::getNoopOrSignExtend(const SCEV *V, const Type *Ty) { const Type *SrcTy = V->getType(); - assert((SrcTy->isInteger() || (TD && isa(SrcTy))) && - (Ty->isInteger() || (TD && isa(Ty))) && + assert((SrcTy->isInteger() || isa(SrcTy)) && + (Ty->isInteger() || isa(Ty)) && "Cannot noop or sign extend with non-integer arguments!"); assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) && "getNoopOrSignExtend cannot truncate!"); @@ -2225,8 +2463,8 @@ ScalarEvolution::getNoopOrSignExtend(const SCEV *V, const Type *Ty) { const SCEV * ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, const Type *Ty) { const Type *SrcTy = V->getType(); - assert((SrcTy->isInteger() || (TD && isa(SrcTy))) && - (Ty->isInteger() || (TD && isa(Ty))) && + assert((SrcTy->isInteger() || isa(SrcTy)) && + (Ty->isInteger() || isa(Ty)) && "Cannot noop or any extend with non-integer arguments!"); assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) && "getNoopOrAnyExtend cannot truncate!"); @@ -2240,8 +2478,8 @@ ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, const Type *Ty) { const SCEV * ScalarEvolution::getTruncateOrNoop(const SCEV *V, const Type *Ty) { const Type *SrcTy = V->getType(); - assert((SrcTy->isInteger() || (TD && isa(SrcTy))) && - (Ty->isInteger() || (TD && isa(Ty))) && + assert((SrcTy->isInteger() || isa(SrcTy)) && + (Ty->isInteger() || isa(Ty)) && "Cannot truncate or noop with non-integer arguments!"); assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) && "getTruncateOrNoop cannot extend!"); @@ -2282,28 +2520,54 @@ const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS, return getUMinExpr(PromotedLHS, PromotedRHS); } -/// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value for -/// the specified instruction and replaces any references to the symbolic value -/// SymName with the specified value. This is used during PHI resolution. +/// PushDefUseChildren - Push users of the given Instruction +/// onto the given Worklist. +static void +PushDefUseChildren(Instruction *I, + SmallVectorImpl &Worklist) { + // Push the def-use children onto the Worklist stack. + for (Value::use_iterator UI = I->use_begin(), UE = I->use_end(); + UI != UE; ++UI) + Worklist.push_back(cast(UI)); +} + +/// ForgetSymbolicValue - This looks up computed SCEV values for all +/// instructions that depend on the given instruction and removes them from +/// the Scalars map if they reference SymName. This is used during PHI +/// resolution. void -ScalarEvolution::ReplaceSymbolicValueWithConcrete(Instruction *I, - const SCEV *SymName, - const SCEV *NewVal) { - std::map::iterator SI = - Scalars.find(SCEVCallbackVH(I, this)); - if (SI == Scalars.end()) return; +ScalarEvolution::ForgetSymbolicName(Instruction *I, const SCEV *SymName) { + SmallVector Worklist; + PushDefUseChildren(I, Worklist); - const SCEV *NV = - SI->second->replaceSymbolicValuesWithConcrete(SymName, NewVal, *this); - if (NV == SI->second) return; // No change. + SmallPtrSet Visited; + Visited.insert(I); + while (!Worklist.empty()) { + Instruction *I = Worklist.pop_back_val(); + if (!Visited.insert(I)) continue; - SI->second = NV; // Update the scalars map! + std::map::iterator It = + Scalars.find(static_cast(I)); + if (It != Scalars.end()) { + // Short-circuit the def-use traversal if the symbolic name + // ceases to appear in expressions. + if (!It->second->hasOperand(SymName)) + continue; + + // SCEVUnknown for a PHI either means that it has an unrecognized + // structure, or it's a PHI that's in the progress of being computed + // by createNodeForPHI. In the former case, additional loop trip + // count information isn't going to change anything. In the later + // case, createNodeForPHI will perform the necessary updates on its + // own when it gets to that point. + if (!isa(I) || !isa(It->second)) { + ValuesAtScopes.erase(It->second); + Scalars.erase(It); + } + } - // Any instruction values that use this instruction might also need to be - // updated! - for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); - UI != E; ++UI) - ReplaceSymbolicValueWithConcrete(cast(*UI), SymName, NewVal); + PushDefUseChildren(I, Worklist); + } } /// createNodeForPHI - PHI nodes have two cases. Either the PHI node exists in @@ -2326,7 +2590,8 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // Using this symbolic name for the PHI, analyze the value coming around // the back-edge. - const SCEV *BEValue = getSCEV(PN->getIncomingValue(BackEdge)); + Value *BEValueV = PN->getIncomingValue(BackEdge); + const SCEV *BEValue = getSCEV(BEValueV); // NOTE: If BEValue is loop invariant, we know that the PHI node just // has a special value for the first iteration of the loop. @@ -2357,17 +2622,34 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { if (Accum->isLoopInvariant(L) || (isa(Accum) && cast(Accum)->getLoop() == L)) { + bool HasNUW = false; + bool HasNSW = false; + + // If the increment doesn't overflow, then neither the addrec nor + // the post-increment will overflow. + if (const AddOperator *OBO = dyn_cast(BEValueV)) { + if (OBO->hasNoUnsignedWrap()) + HasNUW = true; + if (OBO->hasNoSignedWrap()) + HasNSW = true; + } + const SCEV *StartVal = getSCEV(PN->getIncomingValue(IncomingEdge)); const SCEV *PHISCEV = - getAddRecExpr(StartVal, Accum, L); + getAddRecExpr(StartVal, Accum, L, HasNUW, HasNSW); + + // Since the no-wrap flags are on the increment, they apply to the + // post-incremented value as well. + if (Accum->isLoopInvariant(L)) + (void)getAddRecExpr(getAddExpr(StartVal, Accum), + Accum, L, HasNUW, HasNSW); // Okay, for the entire analysis of this edge we assumed the PHI - // to be symbolic. We now need to go back and update all of the - // entries for the scalars that use the PHI (except for the PHI - // itself) to use the new analyzed value instead of the "symbolic" - // value. - ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV); + // to be symbolic. We now need to go back and purge all of the + // entries for the scalars that use the symbolic expression. + ForgetSymbolicName(PN, SymbolicName); + Scalars[SCEVCallbackVH(PN, this)] = PHISCEV; return PHISCEV; } } @@ -2389,11 +2671,10 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { getAddRecExpr(StartVal, AddRec->getOperand(1), L); // Okay, for the entire analysis of this edge we assumed the PHI - // to be symbolic. We now need to go back and update all of the - // entries for the scalars that use the PHI (except for the PHI - // itself) to use the new analyzed value instead of the "symbolic" - // value. - ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV); + // to be symbolic. We now need to go back and purge all of the + // entries for the scalars that use the symbolic expression. + ForgetSymbolicName(PN, SymbolicName); + Scalars[SCEVCallbackVH(PN, this)] = PHISCEV; return PHISCEV; } } @@ -2402,6 +2683,10 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { return SymbolicName; } + // It's tempting to recognize PHIs with a unique incoming value, however + // this leads passes like indvars to break LCSSA form. Fortunately, such + // PHIs are rare, as instcombine zaps them. + // If it's not a loop phi, we can't handle it yet. return getUnknown(PN); } @@ -2409,9 +2694,10 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { /// createNodeForGEP - Expand GEP instructions into add and multiply /// operations. This allows them to be analyzed by regular SCEV code. /// -const SCEV *ScalarEvolution::createNodeForGEP(User *GEP) { +const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { - const Type *IntPtrTy = TD->getIntPtrType(); + bool InBounds = GEP->isInBounds(); + const Type *IntPtrTy = getEffectiveSCEVType(GEP->getType()); Value *Base = GEP->getOperand(0); // Don't attempt to analyze GEPs over unsized objects. if (!cast(Base->getType())->getElementType()->isSized()) @@ -2425,23 +2711,24 @@ const SCEV *ScalarEvolution::createNodeForGEP(User *GEP) { // Compute the (potentially symbolic) offset in bytes for this index. if (const StructType *STy = dyn_cast(*GTI++)) { // For a struct, add the member offset. - const StructLayout &SL = *TD->getStructLayout(STy); unsigned FieldNo = cast(Index)->getZExtValue(); - uint64_t Offset = SL.getElementOffset(FieldNo); - TotalOffset = getAddExpr(TotalOffset, getIntegerSCEV(Offset, IntPtrTy)); + TotalOffset = getAddExpr(TotalOffset, + getOffsetOfExpr(STy, FieldNo), + /*HasNUW=*/false, /*HasNSW=*/InBounds); } else { // For an array, add the element offset, explicitly scaled. const SCEV *LocalOffset = getSCEV(Index); - if (!isa(LocalOffset->getType())) - // Getelementptr indicies are signed. - LocalOffset = getTruncateOrSignExtend(LocalOffset, IntPtrTy); - LocalOffset = - getMulExpr(LocalOffset, - getIntegerSCEV(TD->getTypeAllocSize(*GTI), IntPtrTy)); - TotalOffset = getAddExpr(TotalOffset, LocalOffset); + // Getelementptr indicies are signed. + LocalOffset = getTruncateOrSignExtend(LocalOffset, IntPtrTy); + // Lower "inbounds" GEPs to NSW arithmetic. + LocalOffset = getMulExpr(LocalOffset, getSizeOfExpr(*GTI), + /*HasNUW=*/false, /*HasNSW=*/InBounds); + TotalOffset = getAddExpr(TotalOffset, LocalOffset, + /*HasNUW=*/false, /*HasNSW=*/InBounds); } } - return getAddExpr(getSCEV(Base), TotalOffset); + return getAddExpr(getSCEV(Base), TotalOffset, + /*HasNUW=*/false, /*HasNSW=*/InBounds); } /// GetMinTrailingZeros - Determine the minimum number of zero bits that S is @@ -2533,75 +2820,89 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) { if (const SCEVConstant *C = dyn_cast(S)) return ConstantRange(C->getValue()->getValue()); + unsigned BitWidth = getTypeSizeInBits(S->getType()); + ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); + + // If the value has known zeros, the maximum unsigned value will have those + // known zeros as well. + uint32_t TZ = GetMinTrailingZeros(S); + if (TZ != 0) + ConservativeResult = + ConstantRange(APInt::getMinValue(BitWidth), + APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1); + if (const SCEVAddExpr *Add = dyn_cast(S)) { ConstantRange X = getUnsignedRange(Add->getOperand(0)); for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) X = X.add(getUnsignedRange(Add->getOperand(i))); - return X; + return ConservativeResult.intersectWith(X); } if (const SCEVMulExpr *Mul = dyn_cast(S)) { ConstantRange X = getUnsignedRange(Mul->getOperand(0)); for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) X = X.multiply(getUnsignedRange(Mul->getOperand(i))); - return X; + return ConservativeResult.intersectWith(X); } if (const SCEVSMaxExpr *SMax = dyn_cast(S)) { ConstantRange X = getUnsignedRange(SMax->getOperand(0)); for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) X = X.smax(getUnsignedRange(SMax->getOperand(i))); - return X; + return ConservativeResult.intersectWith(X); } if (const SCEVUMaxExpr *UMax = dyn_cast(S)) { ConstantRange X = getUnsignedRange(UMax->getOperand(0)); for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) X = X.umax(getUnsignedRange(UMax->getOperand(i))); - return X; + return ConservativeResult.intersectWith(X); } if (const SCEVUDivExpr *UDiv = dyn_cast(S)) { ConstantRange X = getUnsignedRange(UDiv->getLHS()); ConstantRange Y = getUnsignedRange(UDiv->getRHS()); - return X.udiv(Y); + return ConservativeResult.intersectWith(X.udiv(Y)); } if (const SCEVZeroExtendExpr *ZExt = dyn_cast(S)) { ConstantRange X = getUnsignedRange(ZExt->getOperand()); - return X.zeroExtend(cast(ZExt->getType())->getBitWidth()); + return ConservativeResult.intersectWith(X.zeroExtend(BitWidth)); } if (const SCEVSignExtendExpr *SExt = dyn_cast(S)) { ConstantRange X = getUnsignedRange(SExt->getOperand()); - return X.signExtend(cast(SExt->getType())->getBitWidth()); + return ConservativeResult.intersectWith(X.signExtend(BitWidth)); } if (const SCEVTruncateExpr *Trunc = dyn_cast(S)) { ConstantRange X = getUnsignedRange(Trunc->getOperand()); - return X.truncate(cast(Trunc->getType())->getBitWidth()); + return ConservativeResult.intersectWith(X.truncate(BitWidth)); } - ConstantRange FullSet(getTypeSizeInBits(S->getType()), true); - if (const SCEVAddRecExpr *AddRec = dyn_cast(S)) { - const SCEV *T = getBackedgeTakenCount(AddRec->getLoop()); - const SCEVConstant *Trip = dyn_cast(T); - if (!Trip) return FullSet; + // If there's no unsigned wrap, the value will never be less than its + // initial value. + if (AddRec->hasNoUnsignedWrap()) + if (const SCEVConstant *C = dyn_cast(AddRec->getStart())) + ConservativeResult = + ConstantRange(C->getValue()->getValue(), + APInt(getTypeSizeInBits(C->getType()), 0)); // TODO: non-affine addrec if (AddRec->isAffine()) { const Type *Ty = AddRec->getType(); const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); - if (getTypeSizeInBits(MaxBECount->getType()) <= getTypeSizeInBits(Ty)) { + if (!isa(MaxBECount) && + getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty); const SCEV *Start = AddRec->getStart(); const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this); // Check for overflow. - if (!isKnownPredicate(ICmpInst::ICMP_ULE, Start, End)) - return FullSet; + if (!AddRec->hasNoUnsignedWrap()) + return ConservativeResult; ConstantRange StartRange = getUnsignedRange(Start); ConstantRange EndRange = getUnsignedRange(End); @@ -2610,10 +2911,12 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) { APInt Max = APIntOps::umax(StartRange.getUnsignedMax(), EndRange.getUnsignedMax()); if (Min.isMinValue() && Max.isMaxValue()) - return ConstantRange(Min.getBitWidth(), /*isFullSet=*/true); - return ConstantRange(Min, Max+1); + return ConservativeResult; + return ConservativeResult.intersectWith(ConstantRange(Min, Max+1)); } } + + return ConservativeResult; } if (const SCEVUnknown *U = dyn_cast(S)) { @@ -2622,10 +2925,12 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) { APInt Mask = APInt::getAllOnesValue(BitWidth); APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD); - return ConstantRange(Ones, ~Zeros); + if (Ones == ~Zeros + 1) + return ConservativeResult; + return ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1)); } - return FullSet; + return ConservativeResult; } /// getSignedRange - Determine the signed range for a particular SCEV. @@ -2636,79 +2941,100 @@ ScalarEvolution::getSignedRange(const SCEV *S) { if (const SCEVConstant *C = dyn_cast(S)) return ConstantRange(C->getValue()->getValue()); + unsigned BitWidth = getTypeSizeInBits(S->getType()); + ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); + + // If the value has known zeros, the maximum signed value will have those + // known zeros as well. + uint32_t TZ = GetMinTrailingZeros(S); + if (TZ != 0) + ConservativeResult = + ConstantRange(APInt::getSignedMinValue(BitWidth), + APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1); + if (const SCEVAddExpr *Add = dyn_cast(S)) { ConstantRange X = getSignedRange(Add->getOperand(0)); for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) X = X.add(getSignedRange(Add->getOperand(i))); - return X; + return ConservativeResult.intersectWith(X); } if (const SCEVMulExpr *Mul = dyn_cast(S)) { ConstantRange X = getSignedRange(Mul->getOperand(0)); for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) X = X.multiply(getSignedRange(Mul->getOperand(i))); - return X; + return ConservativeResult.intersectWith(X); } if (const SCEVSMaxExpr *SMax = dyn_cast(S)) { ConstantRange X = getSignedRange(SMax->getOperand(0)); for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) X = X.smax(getSignedRange(SMax->getOperand(i))); - return X; + return ConservativeResult.intersectWith(X); } if (const SCEVUMaxExpr *UMax = dyn_cast(S)) { ConstantRange X = getSignedRange(UMax->getOperand(0)); for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) X = X.umax(getSignedRange(UMax->getOperand(i))); - return X; + return ConservativeResult.intersectWith(X); } if (const SCEVUDivExpr *UDiv = dyn_cast(S)) { ConstantRange X = getSignedRange(UDiv->getLHS()); ConstantRange Y = getSignedRange(UDiv->getRHS()); - return X.udiv(Y); + return ConservativeResult.intersectWith(X.udiv(Y)); } if (const SCEVZeroExtendExpr *ZExt = dyn_cast(S)) { ConstantRange X = getSignedRange(ZExt->getOperand()); - return X.zeroExtend(cast(ZExt->getType())->getBitWidth()); + return ConservativeResult.intersectWith(X.zeroExtend(BitWidth)); } if (const SCEVSignExtendExpr *SExt = dyn_cast(S)) { ConstantRange X = getSignedRange(SExt->getOperand()); - return X.signExtend(cast(SExt->getType())->getBitWidth()); + return ConservativeResult.intersectWith(X.signExtend(BitWidth)); } if (const SCEVTruncateExpr *Trunc = dyn_cast(S)) { ConstantRange X = getSignedRange(Trunc->getOperand()); - return X.truncate(cast(Trunc->getType())->getBitWidth()); + return ConservativeResult.intersectWith(X.truncate(BitWidth)); } - ConstantRange FullSet(getTypeSizeInBits(S->getType()), true); - if (const SCEVAddRecExpr *AddRec = dyn_cast(S)) { - const SCEV *T = getBackedgeTakenCount(AddRec->getLoop()); - const SCEVConstant *Trip = dyn_cast(T); - if (!Trip) return FullSet; + // If there's no signed wrap, and all the operands have the same sign or + // zero, the value won't ever change sign. + if (AddRec->hasNoSignedWrap()) { + bool AllNonNeg = true; + bool AllNonPos = true; + for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { + if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false; + if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false; + } + if (AllNonNeg) + ConservativeResult = ConservativeResult.intersectWith( + ConstantRange(APInt(BitWidth, 0), + APInt::getSignedMinValue(BitWidth))); + else if (AllNonPos) + ConservativeResult = ConservativeResult.intersectWith( + ConstantRange(APInt::getSignedMinValue(BitWidth), + APInt(BitWidth, 1))); + } // TODO: non-affine addrec if (AddRec->isAffine()) { const Type *Ty = AddRec->getType(); const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); - if (getTypeSizeInBits(MaxBECount->getType()) <= getTypeSizeInBits(Ty)) { + if (!isa(MaxBECount) && + getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty); const SCEV *Start = AddRec->getStart(); - const SCEV *Step = AddRec->getStepRecurrence(*this); const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this); // Check for overflow. - if (!(isKnownPositive(Step) && - isKnownPredicate(ICmpInst::ICMP_SLT, Start, End)) && - !(isKnownNegative(Step) && - isKnownPredicate(ICmpInst::ICMP_SGT, Start, End))) - return FullSet; + if (!AddRec->hasNoSignedWrap()) + return ConservativeResult; ConstantRange StartRange = getSignedRange(Start); ConstantRange EndRange = getSignedRange(End); @@ -2717,24 +3043,27 @@ ScalarEvolution::getSignedRange(const SCEV *S) { APInt Max = APIntOps::smax(StartRange.getSignedMax(), EndRange.getSignedMax()); if (Min.isMinSignedValue() && Max.isMaxSignedValue()) - return ConstantRange(Min.getBitWidth(), /*isFullSet=*/true); - return ConstantRange(Min, Max+1); + return ConservativeResult; + return ConservativeResult.intersectWith(ConstantRange(Min, Max+1)); } } + + return ConservativeResult; } if (const SCEVUnknown *U = dyn_cast(S)) { // For a SCEVUnknown, ask ValueTracking. - unsigned BitWidth = getTypeSizeInBits(U->getType()); + if (!U->getValue()->getType()->isInteger() && !TD) + return ConservativeResult; unsigned NS = ComputeNumSignBits(U->getValue(), TD); if (NS == 1) - return FullSet; - return + return ConservativeResult; + return ConservativeResult.intersectWith( ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1), - APInt::getSignedMaxValue(BitWidth).ashr(NS - 1)+1); + APInt::getSignedMaxValue(BitWidth).ashr(NS - 1)+1)); } - return FullSet; + return ConservativeResult; } /// createSCEV - We know that there is no SCEV for the specified value. @@ -2755,15 +3084,23 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { return getIntegerSCEV(0, V->getType()); else if (isa(V)) return getIntegerSCEV(0, V->getType()); + else if (GlobalAlias *GA = dyn_cast(V)) + return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee()); else return getUnknown(V); - User *U = cast(V); + Operator *U = cast(V); switch (Opcode) { case Instruction::Add: + // Don't transfer the NSW and NUW bits from the Add instruction to the + // Add expression, because the Instruction may be guarded by control + // flow and the no-overflow bits may not be valid for the expression in + // any context. return getAddExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1))); case Instruction::Mul: + // Don't transfer the NSW and NUW bits from the Mul instruction to the + // Mul expression, as with Add. return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1))); case Instruction::UDiv: @@ -2797,7 +3134,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { if (LZ != 0 && !((~A & ~KnownZero) & EffectiveMask)) return getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)), - IntegerType::get(BitWidth - LZ)), + IntegerType::get(getContext(), BitWidth - LZ)), U->getType()); } break; @@ -2813,8 +3150,20 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { const SCEV *LHS = getSCEV(U->getOperand(0)); const APInt &CIVal = CI->getValue(); if (GetMinTrailingZeros(LHS) >= - (CIVal.getBitWidth() - CIVal.countLeadingZeros())) - return getAddExpr(LHS, getSCEV(U->getOperand(1))); + (CIVal.getBitWidth() - CIVal.countLeadingZeros())) { + // Build a plain add SCEV. + const SCEV *S = getAddExpr(LHS, getSCEV(CI)); + // If the LHS of the add was an addrec and it has no-wrap flags, + // transfer the no-wrap flags, since an or won't introduce a wrap. + if (const SCEVAddRecExpr *NewAR = dyn_cast(S)) { + const SCEVAddRecExpr *OldAR = cast(LHS); + if (OldAR->hasNoUnsignedWrap()) + const_cast(NewAR)->setHasNoUnsignedWrap(true); + if (OldAR->hasNoSignedWrap()) + const_cast(NewAR)->setHasNoSignedWrap(true); + } + return S; + } } break; case Instruction::Xor: @@ -2865,8 +3214,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { case Instruction::Shl: // Turn shift left of a constant amount into a multiply. if (ConstantInt *SA = dyn_cast(U->getOperand(1))) { - uint32_t BitWidth = cast(V->getType())->getBitWidth(); - Constant *X = ConstantInt::get( + uint32_t BitWidth = cast(U->getType())->getBitWidth(); + Constant *X = ConstantInt::get(getContext(), APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth))); return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X)); } @@ -2875,8 +3224,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { case Instruction::LShr: // Turn logical shift right of a constant into a unsigned divide. if (ConstantInt *SA = dyn_cast(U->getOperand(1))) { - uint32_t BitWidth = cast(V->getType())->getBitWidth(); - Constant *X = ConstantInt::get( + uint32_t BitWidth = cast(U->getType())->getBitWidth(); + Constant *X = ConstantInt::get(getContext(), APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth))); return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X)); } @@ -2896,7 +3245,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { return getIntegerSCEV(0, U->getType()); // value is undefined return getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)), - IntegerType::get(Amt)), + IntegerType::get(getContext(), Amt)), U->getType()); } break; @@ -2916,19 +3265,13 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { return getSCEV(U->getOperand(0)); break; - case Instruction::IntToPtr: - if (!TD) break; // Without TD we can't analyze pointers. - return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)), - TD->getIntPtrType()); - - case Instruction::PtrToInt: - if (!TD) break; // Without TD we can't analyze pointers. - return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)), - U->getType()); + // It's tempting to handle inttoptr and ptrtoint as no-ops, however this can + // lead to pointer expressions which cannot safely be expanded to GEPs, + // because ScalarEvolution doesn't respect the GEP aliasing rules when + // simplifying integer expressions. case Instruction::GetElementPtr: - if (!TD) break; // Without TD we can't analyze pointers. - return createNodeForGEP(U); + return createNodeForGEP(cast(U)); case Instruction::PHI: return createNodeForPHI(cast(U)); @@ -3032,17 +3375,6 @@ PushLoopPHIs(const Loop *L, SmallVectorImpl &Worklist) { Worklist.push_back(PN); } -/// PushDefUseChildren - Push users of the given Instruction -/// onto the given Worklist. -static void -PushDefUseChildren(Instruction *I, - SmallVectorImpl &Worklist) { - // Push the def-use children onto the Worklist stack. - for (Value::use_iterator UI = I->use_begin(), UE = I->use_end(); - UI != UE; ++UI) - Worklist.push_back(cast(UI)); -} - const ScalarEvolution::BackedgeTakenInfo & ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // Initially insert a CouldNotCompute for this loop. If the insertion @@ -3050,22 +3382,22 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // update the value. The temporary CouldNotCompute value tells SCEV // code elsewhere that it shouldn't attempt to request a new // backedge-taken count, which could result in infinite recursion. - std::pair::iterator, bool> Pair = + std::pair::iterator, bool> Pair = BackedgeTakenCounts.insert(std::make_pair(L, getCouldNotCompute())); if (Pair.second) { - BackedgeTakenInfo ItCount = ComputeBackedgeTakenCount(L); - if (ItCount.Exact != getCouldNotCompute()) { - assert(ItCount.Exact->isLoopInvariant(L) && - ItCount.Max->isLoopInvariant(L) && - "Computed trip count isn't loop invariant for loop!"); + BackedgeTakenInfo BECount = ComputeBackedgeTakenCount(L); + if (BECount.Exact != getCouldNotCompute()) { + assert(BECount.Exact->isLoopInvariant(L) && + BECount.Max->isLoopInvariant(L) && + "Computed backedge-taken count isn't loop invariant for loop!"); ++NumTripCountsComputed; // Update the value in the map. - Pair.first->second = ItCount; + Pair.first->second = BECount; } else { - if (ItCount.Max != getCouldNotCompute()) + if (BECount.Max != getCouldNotCompute()) // Update the value in the map. - Pair.first->second = ItCount; + Pair.first->second = BECount; if (isa(L->getHeader()->begin())) // Only count loops that have phi nodes as not being computable. ++NumTripCountsNotComputed; @@ -3074,10 +3406,9 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // Now that we know more about the trip count for this loop, forget any // existing SCEV values for PHI nodes in this loop since they are only // conservative estimates made without the benefit of trip count - // information. This is similar to the code in - // forgetLoopBackedgeTakenCount, except that it handles SCEVUnknown PHI - // nodes specially. - if (ItCount.hasAnyInfo()) { + // information. This is similar to the code in forgetLoop, except that + // it handles SCEVUnknown PHI nodes specially. + if (BECount.hasAnyInfo()) { SmallVector Worklist; PushLoopPHIs(L, Worklist); @@ -3086,7 +3417,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { Instruction *I = Worklist.pop_back_val(); if (!Visited.insert(I)) continue; - std::map::iterator It = + std::map::iterator It = Scalars.find(static_cast(I)); if (It != Scalars.end()) { // SCEVUnknown for a PHI either means that it has an unrecognized @@ -3095,9 +3426,10 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // count information isn't going to change anything. In the later // case, createNodeForPHI will perform the necessary updates on its // own when it gets to that point. - if (!isa(I) || !isa(It->second)) + if (!isa(I) || !isa(It->second)) { + ValuesAtScopes.erase(It->second); Scalars.erase(It); - ValuesAtScopes.erase(I); + } if (PHINode *PN = dyn_cast(I)) ConstantEvolutionLoopExitValue.erase(PN); } @@ -3109,13 +3441,14 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { return Pair.first->second; } -/// forgetLoopBackedgeTakenCount - This method should be called by the -/// client when it has changed a loop in a way that may effect -/// ScalarEvolution's ability to compute a trip count, or if the loop -/// is deleted. -void ScalarEvolution::forgetLoopBackedgeTakenCount(const Loop *L) { +/// forgetLoop - This method should be called by the client when it has +/// changed a loop in a way that may effect ScalarEvolution's ability to +/// compute a trip count, or if the loop is deleted. +void ScalarEvolution::forgetLoop(const Loop *L) { + // Drop any stored trip count value. BackedgeTakenCounts.erase(L); + // Drop information about expressions based on loop-header PHIs. SmallVector Worklist; PushLoopPHIs(L, Worklist); @@ -3124,11 +3457,11 @@ void ScalarEvolution::forgetLoopBackedgeTakenCount(const Loop *L) { Instruction *I = Worklist.pop_back_val(); if (!Visited.insert(I)) continue; - std::map::iterator It = + std::map::iterator It = Scalars.find(static_cast(I)); if (It != Scalars.end()) { + ValuesAtScopes.erase(It->second); Scalars.erase(It); - ValuesAtScopes.erase(I); if (PHINode *PN = dyn_cast(I)) ConstantEvolutionLoopExitValue.erase(PN); } @@ -3141,7 +3474,7 @@ void ScalarEvolution::forgetLoopBackedgeTakenCount(const Loop *L) { /// of the specified loop will execute. ScalarEvolution::BackedgeTakenInfo ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { - SmallVector ExitingBlocks; + SmallVector ExitingBlocks; L->getExitingBlocks(ExitingBlocks); // Examine all exits and pick the most conservative values. @@ -3394,8 +3727,8 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L, if (!isa(TC)) return TC; break; } - case ICmpInst::ICMP_EQ: { - // Convert to: while (X-Y == 0) // while (X == Y) + case ICmpInst::ICMP_EQ: { // while (X == Y) + // Convert to: while (X-Y == 0) const SCEV *TC = HowFarToNonZero(getMinusSCEV(LHS, RHS), L); if (!isa(TC)) return TC; break; @@ -3424,10 +3757,10 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L, } default: #if 0 - errs() << "ComputeBackedgeTakenCount "; + dbgs() << "ComputeBackedgeTakenCount "; if (ExitCond->getOperand(0)->getType()->isUnsigned()) - errs() << "[unsigned] "; - errs() << *LHS << " " + dbgs() << "[unsigned] "; + dbgs() << *LHS << " " << Instruction::getOpcodeName(Instruction::ICmp) << " " << *RHS << "\n"; #endif @@ -3452,7 +3785,7 @@ EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, /// the addressed element of the initializer or null if the index expression is /// invalid. static Constant * -GetAddressedElementFromGlobal(LLVMContext *Context, GlobalVariable *GV, +GetAddressedElementFromGlobal(GlobalVariable *GV, const std::vector &Indices) { Constant *Init = GV->getInitializer(); for (unsigned i = 0, e = Indices.size(); i != e; ++i) { @@ -3466,12 +3799,12 @@ GetAddressedElementFromGlobal(LLVMContext *Context, GlobalVariable *GV, } else if (isa(Init)) { if (const StructType *STy = dyn_cast(Init->getType())) { assert(Idx < STy->getNumElements() && "Bad struct index!"); - Init = Context->getNullValue(STy->getElementType(Idx)); + Init = Constant::getNullValue(STy->getElementType(Idx)); } else if (const ArrayType *ATy = dyn_cast(Init->getType())) { if (Idx >= ATy->getNumElements()) return 0; // Bogus program - Init = Context->getNullValue(ATy->getElementType()); + Init = Constant::getNullValue(ATy->getElementType()); } else { - LLVM_UNREACHABLE("Unknown constant aggregate type!"); + llvm_unreachable("Unknown constant aggregate type!"); } return 0; } else { @@ -3499,7 +3832,7 @@ ScalarEvolution::ComputeLoadConstantCompareBackedgeTakenCount( // Make sure that it is really a constant global we are gepping, with an // initializer, and make sure the first IDX is really 0. GlobalVariable *GV = dyn_cast(GEP->getOperand(0)); - if (!GV || !GV->isConstant() || !GV->hasInitializer() || + if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() || GEP->getNumOperands() < 3 || !isa(GEP->getOperand(1)) || !cast(GEP->getOperand(1))->isNullValue()) return getCouldNotCompute(); @@ -3533,14 +3866,14 @@ ScalarEvolution::ComputeLoadConstantCompareBackedgeTakenCount( unsigned MaxSteps = MaxBruteForceIterations; for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) { - ConstantInt *ItCst = - ConstantInt::get(cast(IdxExpr->getType()), IterationNum); + ConstantInt *ItCst = ConstantInt::get( + cast(IdxExpr->getType()), IterationNum); ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this); // Form the GEP offset. Indexes[VarIdxNum] = Val; - Constant *Result = GetAddressedElementFromGlobal(Context, GV, Indexes); + Constant *Result = GetAddressedElementFromGlobal(GV, Indexes); if (Result == 0) break; // Cannot compute! // Evaluate the condition for this iteration. @@ -3548,7 +3881,7 @@ ScalarEvolution::ComputeLoadConstantCompareBackedgeTakenCount( if (!isa(Result)) break; // Couldn't decide for sure if (cast(Result)->getValue().isMinValue()) { #if 0 - errs() << "\n***\n*** Computed loop count " << *ItCst + dbgs() << "\n***\n*** Computed loop count " << *ItCst << "\n*** From global " << *GV << "*** BB: " << *L->getHeader() << "***\n"; #endif @@ -3582,7 +3915,7 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { // If this is not an instruction, or if this is an instruction outside of the // loop, it can't be derived from a loop PHI. Instruction *I = dyn_cast(V); - if (I == 0 || !L->contains(I->getParent())) return 0; + if (I == 0 || !L->contains(I)) return 0; if (PHINode *PN = dyn_cast(I)) { if (L->getHeader() == I->getParent()) @@ -3619,29 +3952,26 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { /// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node /// in the loop has the value PHIVal. If we can't fold this expression for some /// reason, return null. -static Constant *EvaluateExpression(Value *V, Constant *PHIVal) { +static Constant *EvaluateExpression(Value *V, Constant *PHIVal, + const TargetData *TD) { if (isa(V)) return PHIVal; if (Constant *C = dyn_cast(V)) return C; if (GlobalValue *GV = dyn_cast(V)) return GV; Instruction *I = cast(V); - LLVMContext *Context = I->getParent()->getContext(); std::vector Operands; Operands.resize(I->getNumOperands()); for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { - Operands[i] = EvaluateExpression(I->getOperand(i), PHIVal); + Operands[i] = EvaluateExpression(I->getOperand(i), PHIVal, TD); if (Operands[i] == 0) return 0; } if (const CmpInst *CI = dyn_cast(I)) - return ConstantFoldCompareInstOperands(CI->getPredicate(), - &Operands[0], Operands.size(), - Context); - else - return ConstantFoldInstOperands(I->getOpcode(), I->getType(), - &Operands[0], Operands.size(), - Context); + return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], + Operands[1], TD); + return ConstantFoldInstOperands(I->getOpcode(), I->getType(), + &Operands[0], Operands.size(), TD); } /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is @@ -3650,7 +3980,7 @@ static Constant *EvaluateExpression(Value *V, Constant *PHIVal) { /// involving constants, fold it. Constant * ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, - const APInt& BEs, + const APInt &BEs, const Loop *L) { std::map::iterator I = ConstantEvolutionLoopExitValue.find(PN); @@ -3687,7 +4017,7 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, return RetVal = PHIVal; // Got exit value! // Compute the value of the PHI node for the next iteration. - Constant *NextPHI = EvaluateExpression(BEValue, PHIVal); + Constant *NextPHI = EvaluateExpression(BEValue, PHIVal, TD); if (NextPHI == PHIVal) return RetVal = NextPHI; // Stopped evolving! if (NextPHI == 0) @@ -3696,7 +4026,7 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, } } -/// ComputeBackedgeTakenCountExhaustively - If the trip is known to execute a +/// ComputeBackedgeTakenCountExhaustively - If the loop is known to execute a /// constant number of times (the condition evolves only from constants), /// try to evaluate a few iterations of the loop until we get the exit /// condition gets a value of ExitWhen (true or false). If we cannot @@ -3728,18 +4058,18 @@ ScalarEvolution::ComputeBackedgeTakenCountExhaustively(const Loop *L, for (Constant *PHIVal = StartCST; IterationNum != MaxIterations; ++IterationNum) { ConstantInt *CondVal = - dyn_cast_or_null(EvaluateExpression(Cond, PHIVal)); + dyn_cast_or_null(EvaluateExpression(Cond, PHIVal, TD)); // Couldn't symbolically evaluate. if (!CondVal) return getCouldNotCompute(); if (CondVal->getValue() == uint64_t(ExitWhen)) { ++NumBruteForceTripCountsComputed; - return getConstant(Type::Int32Ty, IterationNum); + return getConstant(Type::getInt32Ty(getContext()), IterationNum); } // Compute the value of the PHI node for the next iteration. - Constant *NextPHI = EvaluateExpression(BEValue, PHIVal); + Constant *NextPHI = EvaluateExpression(BEValue, PHIVal, TD); if (NextPHI == 0 || NextPHI == PHIVal) return getCouldNotCompute();// Couldn't evaluate or not making progress... PHIVal = NextPHI; @@ -3749,7 +4079,7 @@ ScalarEvolution::ComputeBackedgeTakenCountExhaustively(const Loop *L, return getCouldNotCompute(); } -/// getSCEVAtScope - Return a SCEV expression handle for the specified value +/// getSCEVAtScope - Return a SCEV expression for the specified value /// at the specified scope in the program. The L value specifies a loop /// nest to evaluate the expression at, where null is the top-level or a /// specified loop is immediately inside of the loop. @@ -3760,8 +4090,20 @@ ScalarEvolution::ComputeBackedgeTakenCountExhaustively(const Loop *L, /// In the case that a relevant loop exit value cannot be computed, the /// original value V is returned. const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { - // FIXME: this should be turned into a virtual method on SCEV! + // Check to see if we've folded this expression at this loop before. + std::map &Values = ValuesAtScopes[V]; + std::pair::iterator, bool> Pair = + Values.insert(std::make_pair(L, static_cast(0))); + if (!Pair.second) + return Pair.first->second ? Pair.first->second : V; + // Otherwise compute it. + const SCEV *C = computeSCEVAtScope(V, L); + ValuesAtScopes[V][L] = C; + return C; +} + +const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { if (isa(V)) return V; // If this instruction is evolved from a constant-evolving PHI, compute the @@ -3794,13 +4136,6 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { // the arguments into constants, and if so, try to constant propagate the // result. This is particularly useful for computing loop exit values. if (CanConstantFold(I)) { - // Check to see if we've folded this instruction at this loop before. - std::map &Values = ValuesAtScopes[I]; - std::pair::iterator, bool> Pair = - Values.insert(std::make_pair(L, static_cast(0))); - if (!Pair.second) - return Pair.first->second ? &*getSCEV(Pair.first->second) : V; - std::vector Operands; Operands.reserve(I->getNumOperands()); for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { @@ -3814,7 +4149,7 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { if (!isSCEVable(Op->getType())) return V; - const SCEV* OpV = getSCEVAtScope(Op, L); + const SCEV *OpV = getSCEVAtScope(Op, L); if (const SCEVConstant *SC = dyn_cast(OpV)) { Constant *C = SC->getValue(); if (C->getType() != Op->getType()) @@ -3843,12 +4178,10 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { Constant *C; if (const CmpInst *CI = dyn_cast(I)) C = ConstantFoldCompareInstOperands(CI->getPredicate(), - &Operands[0], Operands.size(), - Context); + Operands[0], Operands[1], TD); else C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), - &Operands[0], Operands.size(), Context); - Pair.first->second = C; + &Operands[0], Operands.size(), TD); return getSCEV(C); } } @@ -3881,7 +4214,7 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { return getSMaxExpr(NewOps); if (isa(Comm)) return getUMaxExpr(NewOps); - LLVM_UNREACHABLE("Unknown commutative SCEV type!"); + llvm_unreachable("Unknown commutative SCEV type!"); } } // If we got here, all operands are loop invariant. @@ -3899,7 +4232,7 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { // If this is a loop recurrence for a loop that does not contain L, then we // are dealing with the final value computed by the loop. if (const SCEVAddRecExpr *AddRec = dyn_cast(V)) { - if (!L || !AddRec->getLoop()->contains(L->getHeader())) { + if (!L || !AddRec->getLoop()->contains(L)) { // To evaluate this recurrence, we need to know how many times the AddRec // loop iterates. Compute this now. const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); @@ -3932,7 +4265,7 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { return getTruncateExpr(Op, Cast->getType()); } - LLVM_UNREACHABLE("Unknown SCEV type!"); + llvm_unreachable("Unknown SCEV type!"); return 0; } @@ -4043,12 +4376,12 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { return std::make_pair(CNC, CNC); } - LLVMContext *Context = SE.getContext(); + LLVMContext &Context = SE.getContext(); ConstantInt *Solution1 = - Context->getConstantInt((NegB + SqrtVal).sdiv(TwoA)); + ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA)); ConstantInt *Solution2 = - Context->getConstantInt((NegB - SqrtVal).sdiv(TwoA)); + ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA)); return std::make_pair(SE.getConstant(Solution1), SE.getConstant(Solution2)); @@ -4092,7 +4425,7 @@ const SCEV *ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { // First, handle unitary steps. if (StepC->getValue()->equalsInt(1)) // 1*N = -Start (mod 2^BW), so: - return getNegativeSCEV(Start); // N = -Start (as unsigned) + return getNegativeSCEV(Start); // N = -Start (as unsigned) if (StepC->getValue()->isAllOnesValue()) // -1*N = -Start (mod 2^BW), so: return Start; // N = Start (as unsigned) @@ -4111,12 +4444,12 @@ const SCEV *ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { const SCEVConstant *R2 = dyn_cast(Roots.second); if (R1) { #if 0 - errs() << "HFTZ: " << *V << " - sol#1: " << *R1 + dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1 << " sol#2: " << *R2 << "\n"; #endif // Pick the smallest positive root value. if (ConstantInt *CB = - dyn_cast(Context->getConstantExprICmp(ICmpInst::ICMP_ULT, + dyn_cast(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { if (CB->getZExtValue() == false) std::swap(R1, R2); // R1 is the minimum root now. @@ -4208,7 +4541,7 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) { if (const SCEVUnknown *BU = dyn_cast(B)) if (const Instruction *AI = dyn_cast(AU->getValue())) if (const Instruction *BI = dyn_cast(BU->getValue())) - if (AI->isIdenticalTo(BI)) + if (AI->isIdenticalTo(BI) && !AI->mayReadFromMemory()) return true; // Otherwise assume they may have a different value. @@ -4243,7 +4576,7 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, switch (Pred) { default: - assert(0 && "Unexpected ICmpInst::Predicate value!"); + llvm_unreachable("Unexpected ICmpInst::Predicate value!"); break; case ICmpInst::ICMP_SGT: Pred = ICmpInst::ICMP_SLT; @@ -4255,20 +4588,6 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, return true; if (LHSRange.getSignedMin().sge(RHSRange.getSignedMax())) return false; - - const SCEV *Diff = getMinusSCEV(LHS, RHS); - ConstantRange DiffRange = getUnsignedRange(Diff); - if (isKnownNegative(Diff)) { - if (DiffRange.getUnsignedMax().ult(LHSRange.getUnsignedMin())) - return true; - if (DiffRange.getUnsignedMin().uge(LHSRange.getUnsignedMax())) - return false; - } else if (isKnownPositive(Diff)) { - if (LHSRange.getUnsignedMax().ult(DiffRange.getUnsignedMin())) - return true; - if (LHSRange.getUnsignedMin().uge(DiffRange.getUnsignedMax())) - return false; - } break; } case ICmpInst::ICMP_SGE: @@ -4281,20 +4600,6 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, return true; if (LHSRange.getSignedMin().sgt(RHSRange.getSignedMax())) return false; - - const SCEV *Diff = getMinusSCEV(LHS, RHS); - ConstantRange DiffRange = getUnsignedRange(Diff); - if (isKnownNonPositive(Diff)) { - if (DiffRange.getUnsignedMax().ule(LHSRange.getUnsignedMin())) - return true; - if (DiffRange.getUnsignedMin().ugt(LHSRange.getUnsignedMax())) - return false; - } else if (isKnownNonNegative(Diff)) { - if (LHSRange.getUnsignedMax().ule(DiffRange.getUnsignedMin())) - return true; - if (LHSRange.getUnsignedMin().ugt(DiffRange.getUnsignedMax())) - return false; - } break; } case ICmpInst::ICMP_UGT: @@ -4307,13 +4612,6 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, return true; if (LHSRange.getUnsignedMin().uge(RHSRange.getUnsignedMax())) return false; - - const SCEV *Diff = getMinusSCEV(LHS, RHS); - ConstantRange DiffRange = getUnsignedRange(Diff); - if (LHSRange.getUnsignedMax().ult(DiffRange.getUnsignedMin())) - return true; - if (LHSRange.getUnsignedMin().uge(DiffRange.getUnsignedMax())) - return false; break; } case ICmpInst::ICMP_UGE: @@ -4326,13 +4624,6 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, return true; if (LHSRange.getUnsignedMin().ugt(RHSRange.getUnsignedMax())) return false; - - const SCEV *Diff = getMinusSCEV(LHS, RHS); - ConstantRange DiffRange = getUnsignedRange(Diff); - if (LHSRange.getUnsignedMax().ule(DiffRange.getUnsignedMin())) - return true; - if (LHSRange.getUnsignedMin().ugt(DiffRange.getUnsignedMax())) - return false; break; } case ICmpInst::ICMP_NE: { @@ -4347,6 +4638,8 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, break; } case ICmpInst::ICMP_EQ: + // The check at the top of the function catches the case where + // the values are known to be equal. break; } return false; @@ -4373,9 +4666,8 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, LoopContinuePredicate->isUnconditional()) return false; - return - isNecessaryCond(LoopContinuePredicate->getCondition(), Pred, LHS, RHS, - LoopContinuePredicate->getSuccessor(0) != L->getHeader()); + return isImpliedCond(LoopContinuePredicate->getCondition(), Pred, LHS, RHS, + LoopContinuePredicate->getSuccessor(0) != L->getHeader()); } /// isLoopGuardedByCond - Test whether entry to the loop is protected @@ -4405,122 +4697,55 @@ ScalarEvolution::isLoopGuardedByCond(const Loop *L, LoopEntryPredicate->isUnconditional()) continue; - if (isNecessaryCond(LoopEntryPredicate->getCondition(), Pred, LHS, RHS, - LoopEntryPredicate->getSuccessor(0) != PredecessorDest)) + if (isImpliedCond(LoopEntryPredicate->getCondition(), Pred, LHS, RHS, + LoopEntryPredicate->getSuccessor(0) != PredecessorDest)) return true; } return false; } -/// isNecessaryCond - Test whether the condition described by Pred, LHS, -/// and RHS is a necessary condition for the given Cond value to evaluate -/// to true. -bool ScalarEvolution::isNecessaryCond(Value *CondValue, - ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - bool Inverse) { +/// isImpliedCond - Test whether the condition described by Pred, LHS, +/// and RHS is true whenever the given Cond value evaluates to true. +bool ScalarEvolution::isImpliedCond(Value *CondValue, + ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + bool Inverse) { // Recursivly handle And and Or conditions. if (BinaryOperator *BO = dyn_cast(CondValue)) { if (BO->getOpcode() == Instruction::And) { if (!Inverse) - return isNecessaryCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) || - isNecessaryCond(BO->getOperand(1), Pred, LHS, RHS, Inverse); + return isImpliedCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) || + isImpliedCond(BO->getOperand(1), Pred, LHS, RHS, Inverse); } else if (BO->getOpcode() == Instruction::Or) { if (Inverse) - return isNecessaryCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) || - isNecessaryCond(BO->getOperand(1), Pred, LHS, RHS, Inverse); + return isImpliedCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) || + isImpliedCond(BO->getOperand(1), Pred, LHS, RHS, Inverse); } } ICmpInst *ICI = dyn_cast(CondValue); if (!ICI) return false; - // Now that we found a conditional branch that dominates the loop, check to - // see if it is the comparison we are looking for. - Value *PreCondLHS = ICI->getOperand(0); - Value *PreCondRHS = ICI->getOperand(1); - ICmpInst::Predicate FoundPred; - if (Inverse) - FoundPred = ICI->getInversePredicate(); - else - FoundPred = ICI->getPredicate(); - - if (FoundPred == Pred) - ; // An exact match. - else if (!ICmpInst::isTrueWhenEqual(FoundPred) && Pred == ICmpInst::ICMP_NE) { - // The actual condition is beyond sufficient. - FoundPred = ICmpInst::ICMP_NE; - // NE is symmetric but the original comparison may not be. Swap - // the operands if necessary so that they match below. - if (isa(LHS)) - std::swap(PreCondLHS, PreCondRHS); - } else - // Check a few special cases. - switch (FoundPred) { - case ICmpInst::ICMP_UGT: - if (Pred == ICmpInst::ICMP_ULT) { - std::swap(PreCondLHS, PreCondRHS); - FoundPred = ICmpInst::ICMP_ULT; - break; - } - return false; - case ICmpInst::ICMP_SGT: - if (Pred == ICmpInst::ICMP_SLT) { - std::swap(PreCondLHS, PreCondRHS); - FoundPred = ICmpInst::ICMP_SLT; - break; - } - return false; - case ICmpInst::ICMP_NE: - // Expressions like (x >u 0) are often canonicalized to (x != 0), - // so check for this case by checking if the NE is comparing against - // a minimum or maximum constant. - if (!ICmpInst::isTrueWhenEqual(Pred)) - if (const SCEVConstant *C = dyn_cast(RHS)) { - const APInt &A = C->getValue()->getValue(); - switch (Pred) { - case ICmpInst::ICMP_SLT: - if (A.isMaxSignedValue()) break; - return false; - case ICmpInst::ICMP_SGT: - if (A.isMinSignedValue()) break; - return false; - case ICmpInst::ICMP_ULT: - if (A.isMaxValue()) break; - return false; - case ICmpInst::ICMP_UGT: - if (A.isMinValue()) break; - return false; - default: - return false; - } - FoundPred = Pred; - // NE is symmetric but the original comparison may not be. Swap - // the operands if necessary so that they match below. - if (isa(LHS)) - std::swap(PreCondLHS, PreCondRHS); - break; - } - return false; - default: - // We weren't able to reconcile the condition. - return false; - } - - assert(Pred == FoundPred && "Conditions were not reconciled!"); - // Bail if the ICmp's operands' types are wider than the needed type // before attempting to call getSCEV on them. This avoids infinite // recursion, since the analysis of widening casts can require loop // exit condition information for overflow checking, which would // lead back here. if (getTypeSizeInBits(LHS->getType()) < - getTypeSizeInBits(PreCondLHS->getType())) + getTypeSizeInBits(ICI->getOperand(0)->getType())) return false; - const SCEV *FoundLHS = getSCEV(PreCondLHS); - const SCEV *FoundRHS = getSCEV(PreCondRHS); + // Now that we found a conditional branch that dominates the loop, check to + // see if it is the comparison we are looking for. + ICmpInst::Predicate FoundPred; + if (Inverse) + FoundPred = ICI->getInversePredicate(); + else + FoundPred = ICI->getPredicate(); + + const SCEV *FoundLHS = getSCEV(ICI->getOperand(0)); + const SCEV *FoundRHS = getSCEV(ICI->getOperand(1)); // Balance the types. The case where FoundLHS' type is wider than // LHS' type is checked for above. @@ -4535,39 +4760,209 @@ bool ScalarEvolution::isNecessaryCond(Value *CondValue, } } - return isNecessaryCondOperands(Pred, LHS, RHS, - FoundLHS, FoundRHS) || + // Canonicalize the query to match the way instcombine will have + // canonicalized the comparison. + // First, put a constant operand on the right. + if (isa(LHS)) { + std::swap(LHS, RHS); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + // Then, canonicalize comparisons with boundary cases. + if (const SCEVConstant *RC = dyn_cast(RHS)) { + const APInt &RA = RC->getValue()->getValue(); + switch (Pred) { + default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_NE: + break; + case ICmpInst::ICMP_UGE: + if ((RA - 1).isMinValue()) { + Pred = ICmpInst::ICMP_NE; + RHS = getConstant(RA - 1); + break; + } + if (RA.isMaxValue()) { + Pred = ICmpInst::ICMP_EQ; + break; + } + if (RA.isMinValue()) return true; + break; + case ICmpInst::ICMP_ULE: + if ((RA + 1).isMaxValue()) { + Pred = ICmpInst::ICMP_NE; + RHS = getConstant(RA + 1); + break; + } + if (RA.isMinValue()) { + Pred = ICmpInst::ICMP_EQ; + break; + } + if (RA.isMaxValue()) return true; + break; + case ICmpInst::ICMP_SGE: + if ((RA - 1).isMinSignedValue()) { + Pred = ICmpInst::ICMP_NE; + RHS = getConstant(RA - 1); + break; + } + if (RA.isMaxSignedValue()) { + Pred = ICmpInst::ICMP_EQ; + break; + } + if (RA.isMinSignedValue()) return true; + break; + case ICmpInst::ICMP_SLE: + if ((RA + 1).isMaxSignedValue()) { + Pred = ICmpInst::ICMP_NE; + RHS = getConstant(RA + 1); + break; + } + if (RA.isMinSignedValue()) { + Pred = ICmpInst::ICMP_EQ; + break; + } + if (RA.isMaxSignedValue()) return true; + break; + case ICmpInst::ICMP_UGT: + if (RA.isMinValue()) { + Pred = ICmpInst::ICMP_NE; + break; + } + if ((RA + 1).isMaxValue()) { + Pred = ICmpInst::ICMP_EQ; + RHS = getConstant(RA + 1); + break; + } + if (RA.isMaxValue()) return false; + break; + case ICmpInst::ICMP_ULT: + if (RA.isMaxValue()) { + Pred = ICmpInst::ICMP_NE; + break; + } + if ((RA - 1).isMinValue()) { + Pred = ICmpInst::ICMP_EQ; + RHS = getConstant(RA - 1); + break; + } + if (RA.isMinValue()) return false; + break; + case ICmpInst::ICMP_SGT: + if (RA.isMinSignedValue()) { + Pred = ICmpInst::ICMP_NE; + break; + } + if ((RA + 1).isMaxSignedValue()) { + Pred = ICmpInst::ICMP_EQ; + RHS = getConstant(RA + 1); + break; + } + if (RA.isMaxSignedValue()) return false; + break; + case ICmpInst::ICMP_SLT: + if (RA.isMaxSignedValue()) { + Pred = ICmpInst::ICMP_NE; + break; + } + if ((RA - 1).isMinSignedValue()) { + Pred = ICmpInst::ICMP_EQ; + RHS = getConstant(RA - 1); + break; + } + if (RA.isMinSignedValue()) return false; + break; + } + } + + // Check to see if we can make the LHS or RHS match. + if (LHS == FoundRHS || RHS == FoundLHS) { + if (isa(RHS)) { + std::swap(FoundLHS, FoundRHS); + FoundPred = ICmpInst::getSwappedPredicate(FoundPred); + } else { + std::swap(LHS, RHS); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + } + + // Check whether the found predicate is the same as the desired predicate. + if (FoundPred == Pred) + return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS); + + // Check whether swapping the found predicate makes it the same as the + // desired predicate. + if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) { + if (isa(RHS)) + return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS); + else + return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred), + RHS, LHS, FoundLHS, FoundRHS); + } + + // Check whether the actual condition is beyond sufficient. + if (FoundPred == ICmpInst::ICMP_EQ) + if (ICmpInst::isTrueWhenEqual(Pred)) + if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS)) + return true; + if (Pred == ICmpInst::ICMP_NE) + if (!ICmpInst::isTrueWhenEqual(FoundPred)) + if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS)) + return true; + + // Otherwise assume the worst. + return false; +} + +/// isImpliedCondOperands - Test whether the condition described by Pred, +/// LHS, and RHS is true whenever the condition desribed by Pred, FoundLHS, +/// and FoundRHS is true. +bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + const SCEV *FoundLHS, + const SCEV *FoundRHS) { + return isImpliedCondOperandsHelper(Pred, LHS, RHS, + FoundLHS, FoundRHS) || // ~x < ~y --> x > y - isNecessaryCondOperands(Pred, LHS, RHS, - getNotSCEV(FoundRHS), getNotSCEV(FoundLHS)); + isImpliedCondOperandsHelper(Pred, LHS, RHS, + getNotSCEV(FoundRHS), + getNotSCEV(FoundLHS)); } -/// isNecessaryCondOperands - Test whether the condition described by Pred, -/// LHS, and RHS is a necessary condition for the condition described by -/// Pred, FoundLHS, and FoundRHS to evaluate to true. +/// isImpliedCondOperandsHelper - Test whether the condition described by +/// Pred, LHS, and RHS is true whenever the condition desribed by Pred, +/// FoundLHS, and FoundRHS is true. bool -ScalarEvolution::isNecessaryCondOperands(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - const SCEV *FoundLHS, - const SCEV *FoundRHS) { +ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + const SCEV *FoundLHS, + const SCEV *FoundRHS) { switch (Pred) { - default: break; + default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_NE: + if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS)) + return true; + break; case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: if (isKnownPredicate(ICmpInst::ICMP_SLE, LHS, FoundLHS) && isKnownPredicate(ICmpInst::ICMP_SGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: if (isKnownPredicate(ICmpInst::ICMP_SGE, LHS, FoundLHS) && isKnownPredicate(ICmpInst::ICMP_SLE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: if (isKnownPredicate(ICmpInst::ICMP_ULE, LHS, FoundLHS) && isKnownPredicate(ICmpInst::ICMP_UGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: if (isKnownPredicate(ICmpInst::ICMP_UGE, LHS, FoundLHS) && isKnownPredicate(ICmpInst::ICMP_ULE, RHS, FoundRHS)) return true; @@ -4582,7 +4977,11 @@ ScalarEvolution::isNecessaryCondOperands(ICmpInst::Predicate Pred, /// CouldNotCompute if an intermediate computation overflows. const SCEV *ScalarEvolution::getBECount(const SCEV *Start, const SCEV *End, - const SCEV *Step) { + const SCEV *Step, + bool NoWrap) { + assert(!isKnownNegative(Step) && + "This code doesn't handle negative strides yet!"); + const Type *Ty = Start->getType(); const SCEV *NegOne = getIntegerSCEV(-1, Ty); const SCEV *Diff = getMinusSCEV(End, Start); @@ -4592,14 +4991,17 @@ const SCEV *ScalarEvolution::getBECount(const SCEV *Start, // the division will effectively round up. const SCEV *Add = getAddExpr(Diff, RoundUp); - // Check Add for unsigned overflow. - // TODO: More sophisticated things could be done here. - const Type *WideTy = Context->getIntegerType(getTypeSizeInBits(Ty) + 1); - const SCEV *EDiff = getZeroExtendExpr(Diff, WideTy); - const SCEV *ERoundUp = getZeroExtendExpr(RoundUp, WideTy); - const SCEV *OperandExtendedAdd = getAddExpr(EDiff, ERoundUp); - if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd) - return getCouldNotCompute(); + if (!NoWrap) { + // Check Add for unsigned overflow. + // TODO: More sophisticated things could be done here. + const Type *WideTy = IntegerType::get(getContext(), + getTypeSizeInBits(Ty) + 1); + const SCEV *EDiff = getZeroExtendExpr(Diff, WideTy); + const SCEV *ERoundUp = getZeroExtendExpr(RoundUp, WideTy); + const SCEV *OperandExtendedAdd = getAddExpr(EDiff, ERoundUp); + if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd) + return getCouldNotCompute(); + } return getUDivExpr(Add, Step); } @@ -4617,37 +5019,40 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, if (!AddRec || AddRec->getLoop() != L) return getCouldNotCompute(); + // Check to see if we have a flag which makes analysis easy. + bool NoWrap = isSigned ? AddRec->hasNoSignedWrap() : + AddRec->hasNoUnsignedWrap(); + if (AddRec->isAffine()) { - // FORNOW: We only support unit strides. unsigned BitWidth = getTypeSizeInBits(AddRec->getType()); const SCEV *Step = AddRec->getStepRecurrence(*this); - // TODO: handle non-constant strides. - const SCEVConstant *CStep = dyn_cast(Step); - if (!CStep || CStep->isZero()) + if (Step->isZero()) return getCouldNotCompute(); - if (CStep->isOne()) { + if (Step->isOne()) { // With unit stride, the iteration never steps past the limit value. - } else if (CStep->getValue()->getValue().isStrictlyPositive()) { - if (const SCEVConstant *CLimit = dyn_cast(RHS)) { - // Test whether a positive iteration iteration can step past the limit - // value and past the maximum value for its type in a single step. - if (isSigned) { - APInt Max = APInt::getSignedMaxValue(BitWidth); - if ((Max - CStep->getValue()->getValue()) - .slt(CLimit->getValue()->getValue())) - return getCouldNotCompute(); - } else { - APInt Max = APInt::getMaxValue(BitWidth); - if ((Max - CStep->getValue()->getValue()) - .ult(CLimit->getValue()->getValue())) - return getCouldNotCompute(); - } - } else - // TODO: handle non-constant limit values below. - return getCouldNotCompute(); + } else if (isKnownPositive(Step)) { + // Test whether a positive iteration can step past the limit + // value and past the maximum value for its type in a single step. + // Note that it's not sufficient to check NoWrap here, because even + // though the value after a wrap is undefined, it's not undefined + // behavior, so if wrap does occur, the loop could either terminate or + // loop infinitely, but in either case, the loop is guaranteed to + // iterate at least until the iteration where the wrapping occurs. + const SCEV *One = getIntegerSCEV(1, Step->getType()); + if (isSigned) { + APInt Max = APInt::getSignedMaxValue(BitWidth); + if ((Max - getSignedRange(getMinusSCEV(Step, One)).getSignedMax()) + .slt(getSignedRange(RHS).getSignedMax())) + return getCouldNotCompute(); + } else { + APInt Max = APInt::getMaxValue(BitWidth); + if ((Max - getUnsignedRange(getMinusSCEV(Step, One)).getUnsignedMax()) + .ult(getUnsignedRange(RHS).getUnsignedMax())) + return getCouldNotCompute(); + } } else - // TODO: handle negative strides below. + // TODO: Handle negative strides here and below. return getCouldNotCompute(); // We know the LHS is of the form {n,+,s} and the RHS is some loop-invariant @@ -4680,13 +5085,27 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, getSignedRange(End).getSignedMax() : getUnsignedRange(End).getUnsignedMax()); + // If MaxEnd is within a step of the maximum integer value in its type, + // adjust it down to the minimum value which would produce the same effect. + // This allows the subsequent ceiling divison of (N+(step-1))/step to + // compute the correct value. + const SCEV *StepMinusOne = getMinusSCEV(Step, + getIntegerSCEV(1, Step->getType())); + MaxEnd = isSigned ? + getSMinExpr(MaxEnd, + getMinusSCEV(getConstant(APInt::getSignedMaxValue(BitWidth)), + StepMinusOne)) : + getUMinExpr(MaxEnd, + getMinusSCEV(getConstant(APInt::getMaxValue(BitWidth)), + StepMinusOne)); + // Finally, we subtract these two values and divide, rounding up, to get // the number of times the backedge is executed. - const SCEV *BECount = getBECount(Start, End, Step); + const SCEV *BECount = getBECount(Start, End, Step, NoWrap); // The maximum backedge count is similar, except using the minimum start // value and the maximum end value. - const SCEV *MaxBECount = getBECount(MinStart, MaxEnd, Step); + const SCEV *MaxBECount = getBECount(MinStart, MaxEnd, Step, NoWrap); return BackedgeTakenInfo(BECount, MaxBECount); } @@ -4748,7 +5167,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // The exit value should be (End+A)/A. APInt ExitVal = (End + A).udiv(A); - ConstantInt *ExitValue = SE.getContext()->getConstantInt(ExitVal); + ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal); // Evaluate at the exit value. If we really did fall out of the valid // range, then we computed our trip count, otherwise wrap around or other @@ -4760,7 +5179,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // Ensure that the previous value is in the range. This is a sanity check. assert(Range.contains( EvaluateConstantChrecAtConstant(this, - SE.getContext()->getConstantInt(ExitVal - One), SE)->getValue()) && + ConstantInt::get(SE.getContext(), ExitVal - One), SE)->getValue()) && "Linear scev computation is off in a bad way!"); return SE.getConstant(ExitValue); } else if (isQuadratic()) { @@ -4780,8 +5199,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, if (R1) { // Pick the smallest positive root value. if (ConstantInt *CB = - dyn_cast( - SE.getContext()->getConstantExprICmp(ICmpInst::ICMP_ULT, + dyn_cast(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { if (CB->getZExtValue() == false) std::swap(R1, R2); // R1 is the minimum root now. @@ -4795,7 +5213,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, if (Range.contains(R1Val->getValue())) { // The next iteration must be out of the range... ConstantInt *NextVal = - SE.getContext()->getConstantInt(R1->getValue()->getValue()+1); + ConstantInt::get(SE.getContext(), R1->getValue()->getValue()+1); R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); if (!Range.contains(R1Val->getValue())) @@ -4806,7 +5224,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // If R1 was not in the range, then it is a good return value. Make // sure that R1-1 WAS in the range though, just in case. ConstantInt *NextVal = - SE.getContext()->getConstantInt(R1->getValue()->getValue()-1); + ConstantInt::get(SE.getContext(), R1->getValue()->getValue()-1); R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); if (Range.contains(R1Val->getValue())) return R1; @@ -4825,22 +5243,21 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, //===----------------------------------------------------------------------===// void ScalarEvolution::SCEVCallbackVH::deleted() { - assert(SE && "SCEVCallbackVH called with a non-null ScalarEvolution!"); + assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!"); if (PHINode *PN = dyn_cast(getValPtr())) SE->ConstantEvolutionLoopExitValue.erase(PN); - if (Instruction *I = dyn_cast(getValPtr())) - SE->ValuesAtScopes.erase(I); SE->Scalars.erase(getValPtr()); // this now dangles! } void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *) { - assert(SE && "SCEVCallbackVH called with a non-null ScalarEvolution!"); + assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!"); // Forget all the expressions associated with users of the old value, // so that future queries will recompute the expressions using the new // value. SmallVector Worklist; + SmallPtrSet Visited; Value *Old = getValPtr(); bool DeleteOld = false; for (Value::use_iterator UI = Old->use_begin(), UE = Old->use_end(); @@ -4854,20 +5271,19 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *) { DeleteOld = true; continue; } + if (!Visited.insert(U)) + continue; if (PHINode *PN = dyn_cast(U)) SE->ConstantEvolutionLoopExitValue.erase(PN); - if (Instruction *I = dyn_cast(U)) - SE->ValuesAtScopes.erase(I); - if (SE->Scalars.erase(U)) - for (Value::use_iterator UI = U->use_begin(), UE = U->use_end(); - UI != UE; ++UI) - Worklist.push_back(*UI); + SE->Scalars.erase(U); + for (Value::use_iterator UI = U->use_begin(), UE = U->use_end(); + UI != UE; ++UI) + Worklist.push_back(*UI); } + // Delete the Old value if it (indirectly) references itself. if (DeleteOld) { if (PHINode *PN = dyn_cast(Old)) SE->ConstantEvolutionLoopExitValue.erase(PN); - if (Instruction *I = dyn_cast(Old)) - SE->ValuesAtScopes.erase(I); SE->Scalars.erase(Old); // this now dangles! } @@ -4888,6 +5304,7 @@ ScalarEvolution::ScalarEvolution() bool ScalarEvolution::runOnFunction(Function &F) { this->F = &F; LI = &getAnalysis(); + DT = &getAnalysis(); TD = getAnalysisIfAvailable(); return false; } @@ -4904,6 +5321,7 @@ void ScalarEvolution::releaseMemory() { void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequiredTransitive(); + AU.addRequiredTransitive(); } bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) { @@ -4916,9 +5334,11 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I) PrintLoopInfo(OS, SE, *I); - OS << "Loop " << L->getHeader()->getName() << ": "; + OS << "Loop "; + WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false); + OS << ": "; - SmallVector ExitBlocks; + SmallVector ExitBlocks; L->getExitBlocks(ExitBlocks); if (ExitBlocks.size() != 1) OS << " "; @@ -4929,8 +5349,10 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, OS << "Unpredictable backedge-taken count. "; } - OS << "\n"; - OS << "Loop " << L->getHeader()->getName() << ": "; + OS << "\n" + "Loop "; + WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false); + OS << ": "; if (!isa(SE->getMaxBackedgeTakenCount(L))) { OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L); @@ -4941,19 +5363,21 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, OS << "\n"; } -void ScalarEvolution::print(raw_ostream &OS, const Module* ) const { +void ScalarEvolution::print(raw_ostream &OS, const Module *) const { // ScalarEvolution's implementaiton of the print method is to print // out SCEV values of all instructions that are interesting. Doing // this potentially causes it to create new SCEV objects though, // which technically conflicts with the const qualifier. This isn't // observable from outside the class though, so casting away the // const isn't dangerous. - ScalarEvolution &SE = *const_cast(this); + ScalarEvolution &SE = *const_cast(this); - OS << "Classifying expressions for: " << F->getName() << "\n"; + OS << "Classifying expressions for: "; + WriteAsOperand(OS, F, /*PrintType=*/false); + OS << "\n"; for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) if (isSCEVable(I->getType())) { - OS << *I; + OS << *I << '\n'; OS << " --> "; const SCEV *SV = SE.getSCEV(&*I); SV->print(OS); @@ -4979,12 +5403,10 @@ void ScalarEvolution::print(raw_ostream &OS, const Module* ) const { OS << "\n"; } - OS << "Determining loop execution counts for: " << F->getName() << "\n"; + OS << "Determining loop execution counts for: "; + WriteAsOperand(OS, F, /*PrintType=*/false); + OS << "\n"; for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) PrintLoopInfo(OS, &SE, *I); } -void ScalarEvolution::print(std::ostream &o, const Module *M) const { - raw_os_ostream OS(o); - print(OS, M); -}