X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=fb5acb91aee0a21d172a8618cdc62c26ca890697;hb=fcb4356dee96563def584fe38eeafb3eb63c5cd8;hp=0ef5d84be5abe572f653a60665e577675c83e6cc;hpb=9f93d30a26ae9928f886ef7271efeafcea2a00a6;p=oota-llvm.git diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 0ef5d84be5a..e5c26404245 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -69,6 +69,7 @@ #include "llvm/Operator.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Assembly/Writer.h" @@ -103,8 +104,12 @@ MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, "derived loop"), cl::init(100)); -static RegisterPass -R("scalar-evolution", "Scalar Evolution Analysis", false, true); +INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution", + "Scalar Evolution Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(LoopInfo) +INITIALIZE_PASS_DEPENDENCY(DominatorTree) +INITIALIZE_PASS_END(ScalarEvolution, "scalar-evolution", + "Scalar Evolution Analysis", false, true) char ScalarEvolution::ID = 0; //===----------------------------------------------------------------------===// @@ -115,13 +120,142 @@ char ScalarEvolution::ID = 0; // Implementation of the SCEV class. // -SCEV::~SCEV() {} - void SCEV::dump() const { print(dbgs()); dbgs() << '\n'; } +void SCEV::print(raw_ostream &OS) const { + switch (getSCEVType()) { + case scConstant: + WriteAsOperand(OS, cast(this)->getValue(), false); + return; + case scTruncate: { + const SCEVTruncateExpr *Trunc = cast(this); + const SCEV *Op = Trunc->getOperand(); + OS << "(trunc " << *Op->getType() << " " << *Op << " to " + << *Trunc->getType() << ")"; + return; + } + case scZeroExtend: { + const SCEVZeroExtendExpr *ZExt = cast(this); + const SCEV *Op = ZExt->getOperand(); + OS << "(zext " << *Op->getType() << " " << *Op << " to " + << *ZExt->getType() << ")"; + return; + } + case scSignExtend: { + const SCEVSignExtendExpr *SExt = cast(this); + const SCEV *Op = SExt->getOperand(); + OS << "(sext " << *Op->getType() << " " << *Op << " to " + << *SExt->getType() << ")"; + return; + } + case scAddRecExpr: { + const SCEVAddRecExpr *AR = cast(this); + OS << "{" << *AR->getOperand(0); + for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i) + OS << ",+," << *AR->getOperand(i); + OS << "}<"; + if (AR->getNoWrapFlags(FlagNUW)) + OS << "nuw><"; + if (AR->getNoWrapFlags(FlagNSW)) + OS << "nsw><"; + if (AR->getNoWrapFlags(FlagNW) && + !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW))) + OS << "nw><"; + WriteAsOperand(OS, AR->getLoop()->getHeader(), /*PrintType=*/false); + OS << ">"; + return; + } + case scAddExpr: + case scMulExpr: + case scUMaxExpr: + case scSMaxExpr: { + const SCEVNAryExpr *NAry = cast(this); + const char *OpStr = 0; + switch (NAry->getSCEVType()) { + case scAddExpr: OpStr = " + "; break; + case scMulExpr: OpStr = " * "; break; + case scUMaxExpr: OpStr = " umax "; break; + case scSMaxExpr: OpStr = " smax "; break; + } + OS << "("; + for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); + I != E; ++I) { + OS << **I; + if (llvm::next(I) != E) + OS << OpStr; + } + OS << ")"; + return; + } + case scUDivExpr: { + const SCEVUDivExpr *UDiv = cast(this); + OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")"; + return; + } + case scUnknown: { + const SCEVUnknown *U = cast(this); + Type *AllocTy; + if (U->isSizeOf(AllocTy)) { + OS << "sizeof(" << *AllocTy << ")"; + return; + } + if (U->isAlignOf(AllocTy)) { + OS << "alignof(" << *AllocTy << ")"; + return; + } + + Type *CTy; + Constant *FieldNo; + if (U->isOffsetOf(CTy, FieldNo)) { + OS << "offsetof(" << *CTy << ", "; + WriteAsOperand(OS, FieldNo, false); + OS << ")"; + return; + } + + // Otherwise just print it normally. + WriteAsOperand(OS, U->getValue(), false); + return; + } + case scCouldNotCompute: + OS << "***COULDNOTCOMPUTE***"; + return; + default: break; + } + llvm_unreachable("Unknown SCEV kind!"); +} + +Type *SCEV::getType() const { + switch (getSCEVType()) { + case scConstant: + return cast(this)->getType(); + case scTruncate: + case scZeroExtend: + case scSignExtend: + return cast(this)->getType(); + case scAddRecExpr: + case scMulExpr: + case scUMaxExpr: + case scSMaxExpr: + return cast(this)->getType(); + case scAddExpr: + return cast(this)->getType(); + case scUDivExpr: + return cast(this)->getType(); + case scUnknown: + return cast(this)->getType(); + case scCouldNotCompute: + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); + return 0; + default: break; + } + llvm_unreachable("Unknown SCEV kind!"); + return 0; +} + bool SCEV::isZero() const { if (const SCEVConstant *SC = dyn_cast(this)) return SC->getValue()->isZero(); @@ -143,30 +277,6 @@ bool SCEV::isAllOnesValue() const { SCEVCouldNotCompute::SCEVCouldNotCompute() : SCEV(FoldingSetNodeIDRef(), scCouldNotCompute) {} -bool SCEVCouldNotCompute::isLoopInvariant(const Loop *L) const { - llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); - return false; -} - -const Type *SCEVCouldNotCompute::getType() const { - llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); - return 0; -} - -bool SCEVCouldNotCompute::hasComputableLoopEvolution(const Loop *L) const { - llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); - return false; -} - -bool SCEVCouldNotCompute::hasOperand(const SCEV *) const { - llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); - return false; -} - -void SCEVCouldNotCompute::print(raw_ostream &OS) const { - OS << "***COULDNOTCOMPUTE***"; -} - bool SCEVCouldNotCompute::classof(const SCEV *S) { return S->getSCEVType() == scCouldNotCompute; } @@ -187,184 +297,65 @@ const SCEV *ScalarEvolution::getConstant(const APInt& Val) { } const SCEV * -ScalarEvolution::getConstant(const Type *Ty, uint64_t V, bool isSigned) { - const IntegerType *ITy = cast(getEffectiveSCEVType(Ty)); +ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) { + IntegerType *ITy = cast(getEffectiveSCEVType(Ty)); return getConstant(ConstantInt::get(ITy, V, isSigned)); } -const Type *SCEVConstant::getType() const { return V->getType(); } - -void SCEVConstant::print(raw_ostream &OS) const { - WriteAsOperand(OS, V, false); -} - SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, - unsigned SCEVTy, const SCEV *op, const Type *ty) + unsigned SCEVTy, const SCEV *op, Type *ty) : SCEV(ID, SCEVTy), Op(op), Ty(ty) {} -bool SCEVCastExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { - return Op->dominates(BB, DT); -} - -bool SCEVCastExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const { - return Op->properlyDominates(BB, DT); -} - SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, - const SCEV *op, const Type *ty) + const SCEV *op, Type *ty) : SCEVCastExpr(ID, scTruncate, op, ty) { assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot truncate non-integer value!"); } -void SCEVTruncateExpr::print(raw_ostream &OS) const { - OS << "(trunc " << *Op->getType() << " " << *Op << " to " << *Ty << ")"; -} - SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, - const SCEV *op, const Type *ty) + const SCEV *op, Type *ty) : SCEVCastExpr(ID, scZeroExtend, op, ty) { assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot zero extend non-integer value!"); } -void SCEVZeroExtendExpr::print(raw_ostream &OS) const { - OS << "(zext " << *Op->getType() << " " << *Op << " to " << *Ty << ")"; -} - SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, - const SCEV *op, const Type *ty) + const SCEV *op, Type *ty) : SCEVCastExpr(ID, scSignExtend, op, ty) { assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot sign extend non-integer value!"); } -void SCEVSignExtendExpr::print(raw_ostream &OS) const { - OS << "(sext " << *Op->getType() << " " << *Op << " to " << *Ty << ")"; -} +void SCEVUnknown::deleted() { + // Clear this SCEVUnknown from various maps. + SE->forgetMemoizedResults(this); -void SCEVCommutativeExpr::print(raw_ostream &OS) const { - const char *OpStr = getOperationStr(); - OS << "("; - for (op_iterator I = op_begin(), E = op_end(); I != E; ++I) { - OS << **I; - if (next(I) != E) - OS << OpStr; - } - OS << ")"; -} + // Remove this SCEVUnknown from the uniquing map. + SE->UniqueSCEVs.RemoveNode(this); -bool SCEVNAryExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { - for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { - if (!getOperand(i)->dominates(BB, DT)) - return false; - } - return true; + // Release the value. + setValPtr(0); } -bool SCEVNAryExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const { - for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { - if (!getOperand(i)->properlyDominates(BB, DT)) - return false; - } - return true; -} - -bool SCEVUDivExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { - return LHS->dominates(BB, DT) && RHS->dominates(BB, DT); -} - -bool SCEVUDivExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const { - return LHS->properlyDominates(BB, DT) && RHS->properlyDominates(BB, DT); -} - -void SCEVUDivExpr::print(raw_ostream &OS) const { - OS << "(" << *LHS << " /u " << *RHS << ")"; -} - -const Type *SCEVUDivExpr::getType() const { - // In most cases the types of LHS and RHS will be the same, but in some - // crazy cases one or the other may be a pointer. ScalarEvolution doesn't - // depend on the type for correctness, but handling types carefully can - // avoid extra casts in the SCEVExpander. The LHS is more likely to be - // a pointer type than the RHS, so use the RHS' type here. - return RHS->getType(); -} - -bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const { - // Add recurrences are never invariant in the function-body (null loop). - if (!QueryLoop) - return false; - - // This recurrence is variant w.r.t. QueryLoop if QueryLoop contains L. - if (QueryLoop->contains(L)) - return false; - - // This recurrence is variant w.r.t. QueryLoop if any of its operands - // are variant. - for (unsigned i = 0, e = getNumOperands(); i != e; ++i) - if (!getOperand(i)->isLoopInvariant(QueryLoop)) - return false; - - // Otherwise it's loop-invariant. - return true; -} - -bool -SCEVAddRecExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { - return DT->dominates(L->getHeader(), BB) && - SCEVNAryExpr::dominates(BB, DT); -} - -bool -SCEVAddRecExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const { - // This uses a "dominates" query instead of "properly dominates" query because - // the instruction which produces the addrec's value is a PHI, and a PHI - // effectively properly dominates its entire containing block. - return DT->dominates(L->getHeader(), BB) && - SCEVNAryExpr::properlyDominates(BB, DT); -} - -void SCEVAddRecExpr::print(raw_ostream &OS) const { - OS << "{" << *Operands[0]; - for (unsigned i = 1, e = NumOperands; i != e; ++i) - OS << ",+," << *Operands[i]; - OS << "}<"; - WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false); - OS << ">"; -} +void SCEVUnknown::allUsesReplacedWith(Value *New) { + // Clear this SCEVUnknown from various maps. + SE->forgetMemoizedResults(this); -bool SCEVUnknown::isLoopInvariant(const Loop *L) const { - // All non-instruction values are loop invariant. All instructions are loop - // invariant if they are not contained in the specified loop. - // Instructions are never considered invariant in the function body - // (null loop) because they are defined within the "loop". - if (Instruction *I = dyn_cast(V)) - return L && !L->contains(I); - return true; -} - -bool SCEVUnknown::dominates(BasicBlock *BB, DominatorTree *DT) const { - if (Instruction *I = dyn_cast(getValue())) - return DT->dominates(I->getParent(), BB); - return true; -} - -bool SCEVUnknown::properlyDominates(BasicBlock *BB, DominatorTree *DT) const { - if (Instruction *I = dyn_cast(getValue())) - return DT->properlyDominates(I->getParent(), BB); - return true; -} + // Remove this SCEVUnknown from the uniquing map. + SE->UniqueSCEVs.RemoveNode(this); -const Type *SCEVUnknown::getType() const { - return V->getType(); + // Update this SCEVUnknown to point to the new value. This is needed + // because there may still be outstanding SCEVs which still point to + // this SCEVUnknown. + setValPtr(New); } -bool SCEVUnknown::isSizeOf(const Type *&AllocTy) const { - if (ConstantExpr *VCE = dyn_cast(V)) +bool SCEVUnknown::isSizeOf(Type *&AllocTy) const { + if (ConstantExpr *VCE = dyn_cast(getValue())) if (VCE->getOpcode() == Instruction::PtrToInt) if (ConstantExpr *CE = dyn_cast(VCE->getOperand(0))) if (CE->getOpcode() == Instruction::GetElementPtr && @@ -380,15 +371,15 @@ bool SCEVUnknown::isSizeOf(const Type *&AllocTy) const { return false; } -bool SCEVUnknown::isAlignOf(const Type *&AllocTy) const { - if (ConstantExpr *VCE = dyn_cast(V)) +bool SCEVUnknown::isAlignOf(Type *&AllocTy) const { + if (ConstantExpr *VCE = dyn_cast(getValue())) if (VCE->getOpcode() == Instruction::PtrToInt) if (ConstantExpr *CE = dyn_cast(VCE->getOperand(0))) if (CE->getOpcode() == Instruction::GetElementPtr && CE->getOperand(0)->isNullValue()) { - const Type *Ty = + Type *Ty = cast(CE->getOperand(0)->getType())->getElementType(); - if (const StructType *STy = dyn_cast(Ty)) + if (StructType *STy = dyn_cast(Ty)) if (!STy->isPacked() && CE->getNumOperands() == 3 && CE->getOperand(1)->isNullValue()) { @@ -405,15 +396,15 @@ bool SCEVUnknown::isAlignOf(const Type *&AllocTy) const { return false; } -bool SCEVUnknown::isOffsetOf(const Type *&CTy, Constant *&FieldNo) const { - if (ConstantExpr *VCE = dyn_cast(V)) +bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const { + if (ConstantExpr *VCE = dyn_cast(getValue())) if (VCE->getOpcode() == Instruction::PtrToInt) if (ConstantExpr *CE = dyn_cast(VCE->getOperand(0))) if (CE->getOpcode() == Instruction::GetElementPtr && CE->getNumOperands() == 3 && CE->getOperand(0)->isNullValue() && CE->getOperand(1)->isNullValue()) { - const Type *Ty = + Type *Ty = cast(CE->getOperand(0)->getType())->getElementType(); // Ignore vector types here so that ScalarEvolutionExpander doesn't // emit getelementptrs that index into vectors. @@ -427,187 +418,180 @@ bool SCEVUnknown::isOffsetOf(const Type *&CTy, Constant *&FieldNo) const { return false; } -void SCEVUnknown::print(raw_ostream &OS) const { - const Type *AllocTy; - if (isSizeOf(AllocTy)) { - OS << "sizeof(" << *AllocTy << ")"; - return; - } - if (isAlignOf(AllocTy)) { - OS << "alignof(" << *AllocTy << ")"; - return; - } - - const Type *CTy; - Constant *FieldNo; - if (isOffsetOf(CTy, FieldNo)) { - OS << "offsetof(" << *CTy << ", "; - WriteAsOperand(OS, FieldNo, false); - OS << ")"; - return; - } - - // Otherwise just print it normally. - WriteAsOperand(OS, V, false); -} - //===----------------------------------------------------------------------===// // SCEV Utilities //===----------------------------------------------------------------------===// -static bool CompareTypes(const Type *A, const Type *B) { - if (A->getTypeID() != B->getTypeID()) - return A->getTypeID() < B->getTypeID(); - if (const IntegerType *AI = dyn_cast(A)) { - const IntegerType *BI = cast(B); - return AI->getBitWidth() < BI->getBitWidth(); - } - if (const PointerType *AI = dyn_cast(A)) { - const PointerType *BI = cast(B); - return CompareTypes(AI->getElementType(), BI->getElementType()); - } - if (const ArrayType *AI = dyn_cast(A)) { - const ArrayType *BI = cast(B); - if (AI->getNumElements() != BI->getNumElements()) - return AI->getNumElements() < BI->getNumElements(); - return CompareTypes(AI->getElementType(), BI->getElementType()); - } - if (const VectorType *AI = dyn_cast(A)) { - const VectorType *BI = cast(B); - if (AI->getNumElements() != BI->getNumElements()) - return AI->getNumElements() < BI->getNumElements(); - return CompareTypes(AI->getElementType(), BI->getElementType()); - } - if (const StructType *AI = dyn_cast(A)) { - const StructType *BI = cast(B); - if (AI->getNumElements() != BI->getNumElements()) - return AI->getNumElements() < BI->getNumElements(); - for (unsigned i = 0, e = AI->getNumElements(); i != e; ++i) - if (CompareTypes(AI->getElementType(i), BI->getElementType(i)) || - CompareTypes(BI->getElementType(i), AI->getElementType(i))) - return CompareTypes(AI->getElementType(i), BI->getElementType(i)); - } - return false; -} - namespace { /// SCEVComplexityCompare - Return true if the complexity of the LHS is less /// than the complexity of the RHS. This comparator is used to canonicalize /// expressions. class SCEVComplexityCompare { - LoopInfo *LI; + const LoopInfo *const LI; public: - explicit SCEVComplexityCompare(LoopInfo *li) : LI(li) {} + explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {} + // Return true or false if LHS is less than, or at least RHS, respectively. bool operator()(const SCEV *LHS, const SCEV *RHS) const { + return compare(LHS, RHS) < 0; + } + + // Return negative, zero, or positive, if LHS is less than, equal to, or + // greater than RHS, respectively. A three-way result allows recursive + // comparisons to be more efficient. + int compare(const SCEV *LHS, const SCEV *RHS) const { // Fast-path: SCEVs are uniqued so we can do a quick equality check. if (LHS == RHS) - return false; + return 0; // Primarily, sort the SCEVs by their getSCEVType(). - if (LHS->getSCEVType() != RHS->getSCEVType()) - return LHS->getSCEVType() < RHS->getSCEVType(); + unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType(); + if (LType != RType) + return (int)LType - (int)RType; // Aside from the getSCEVType() ordering, the particular ordering // isn't very important except that it's beneficial to be consistent, // so that (a + b) and (b + a) don't end up as different expressions. - - // Sort SCEVUnknown values with some loose heuristics. TODO: This is - // not as complete as it could be. - if (const SCEVUnknown *LU = dyn_cast(LHS)) { + switch (LType) { + case scUnknown: { + const SCEVUnknown *LU = cast(LHS); const SCEVUnknown *RU = cast(RHS); + // Sort SCEVUnknown values with some loose heuristics. TODO: This is + // not as complete as it could be. + const Value *LV = LU->getValue(), *RV = RU->getValue(); + // Order pointer values after integer values. This helps SCEVExpander // form GEPs. - if (LU->getType()->isPointerTy() && !RU->getType()->isPointerTy()) - return false; - if (RU->getType()->isPointerTy() && !LU->getType()->isPointerTy()) - return true; + bool LIsPointer = LV->getType()->isPointerTy(), + RIsPointer = RV->getType()->isPointerTy(); + if (LIsPointer != RIsPointer) + return (int)LIsPointer - (int)RIsPointer; // Compare getValueID values. - if (LU->getValue()->getValueID() != RU->getValue()->getValueID()) - return LU->getValue()->getValueID() < RU->getValue()->getValueID(); + unsigned LID = LV->getValueID(), + RID = RV->getValueID(); + if (LID != RID) + return (int)LID - (int)RID; // Sort arguments by their position. - if (const Argument *LA = dyn_cast(LU->getValue())) { - const Argument *RA = cast(RU->getValue()); - return LA->getArgNo() < RA->getArgNo(); + if (const Argument *LA = dyn_cast(LV)) { + const Argument *RA = cast(RV); + unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo(); + return (int)LArgNo - (int)RArgNo; } - // For instructions, compare their loop depth, and their opcode. - // This is pretty loose. - if (Instruction *LV = dyn_cast(LU->getValue())) { - Instruction *RV = cast(RU->getValue()); + // For instructions, compare their loop depth, and their operand + // count. This is pretty loose. + if (const Instruction *LInst = dyn_cast(LV)) { + const Instruction *RInst = cast(RV); // Compare loop depths. - if (LI->getLoopDepth(LV->getParent()) != - LI->getLoopDepth(RV->getParent())) - return LI->getLoopDepth(LV->getParent()) < - LI->getLoopDepth(RV->getParent()); - - // Compare opcodes. - if (LV->getOpcode() != RV->getOpcode()) - return LV->getOpcode() < RV->getOpcode(); + const BasicBlock *LParent = LInst->getParent(), + *RParent = RInst->getParent(); + if (LParent != RParent) { + unsigned LDepth = LI->getLoopDepth(LParent), + RDepth = LI->getLoopDepth(RParent); + if (LDepth != RDepth) + return (int)LDepth - (int)RDepth; + } // Compare the number of operands. - if (LV->getNumOperands() != RV->getNumOperands()) - return LV->getNumOperands() < RV->getNumOperands(); + unsigned LNumOps = LInst->getNumOperands(), + RNumOps = RInst->getNumOperands(); + return (int)LNumOps - (int)RNumOps; } - return false; + return 0; } - // Compare constant values. - if (const SCEVConstant *LC = dyn_cast(LHS)) { + case scConstant: { + const SCEVConstant *LC = cast(LHS); const SCEVConstant *RC = cast(RHS); - if (LC->getValue()->getBitWidth() != RC->getValue()->getBitWidth()) - return LC->getValue()->getBitWidth() < RC->getValue()->getBitWidth(); - return LC->getValue()->getValue().ult(RC->getValue()->getValue()); + + // Compare constant values. + const APInt &LA = LC->getValue()->getValue(); + const APInt &RA = RC->getValue()->getValue(); + unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth(); + if (LBitWidth != RBitWidth) + return (int)LBitWidth - (int)RBitWidth; + return LA.ult(RA) ? -1 : 1; } - // Compare addrec loop depths. - if (const SCEVAddRecExpr *LA = dyn_cast(LHS)) { + case scAddRecExpr: { + const SCEVAddRecExpr *LA = cast(LHS); const SCEVAddRecExpr *RA = cast(RHS); - if (LA->getLoop()->getLoopDepth() != RA->getLoop()->getLoopDepth()) - return LA->getLoop()->getLoopDepth() < RA->getLoop()->getLoopDepth(); + + // Compare addrec loop depths. + const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop(); + if (LLoop != RLoop) { + unsigned LDepth = LLoop->getLoopDepth(), + RDepth = RLoop->getLoopDepth(); + if (LDepth != RDepth) + return (int)LDepth - (int)RDepth; + } + + // Addrec complexity grows with operand count. + unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands(); + if (LNumOps != RNumOps) + return (int)LNumOps - (int)RNumOps; + + // Lexicographically compare. + for (unsigned i = 0; i != LNumOps; ++i) { + long X = compare(LA->getOperand(i), RA->getOperand(i)); + if (X != 0) + return X; + } + + return 0; } - // Lexicographically compare n-ary expressions. - if (const SCEVNAryExpr *LC = dyn_cast(LHS)) { + case scAddExpr: + case scMulExpr: + case scSMaxExpr: + case scUMaxExpr: { + const SCEVNAryExpr *LC = cast(LHS); const SCEVNAryExpr *RC = cast(RHS); - for (unsigned i = 0, e = LC->getNumOperands(); i != e; ++i) { - if (i >= RC->getNumOperands()) - return false; - if (operator()(LC->getOperand(i), RC->getOperand(i))) - return true; - if (operator()(RC->getOperand(i), LC->getOperand(i))) - return false; + + // Lexicographically compare n-ary expressions. + unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands(); + for (unsigned i = 0; i != LNumOps; ++i) { + if (i >= RNumOps) + return 1; + long X = compare(LC->getOperand(i), RC->getOperand(i)); + if (X != 0) + return X; } - return LC->getNumOperands() < RC->getNumOperands(); + return (int)LNumOps - (int)RNumOps; } - // Lexicographically compare udiv expressions. - if (const SCEVUDivExpr *LC = dyn_cast(LHS)) { + case scUDivExpr: { + const SCEVUDivExpr *LC = cast(LHS); const SCEVUDivExpr *RC = cast(RHS); - if (operator()(LC->getLHS(), RC->getLHS())) - return true; - if (operator()(RC->getLHS(), LC->getLHS())) - return false; - if (operator()(LC->getRHS(), RC->getRHS())) - return true; - if (operator()(RC->getRHS(), LC->getRHS())) - return false; - return false; + + // Lexicographically compare udiv expressions. + long X = compare(LC->getLHS(), RC->getLHS()); + if (X != 0) + return X; + return compare(LC->getRHS(), RC->getRHS()); } - // Compare cast expressions by operand. - if (const SCEVCastExpr *LC = dyn_cast(LHS)) { + case scTruncate: + case scZeroExtend: + case scSignExtend: { + const SCEVCastExpr *LC = cast(LHS); const SCEVCastExpr *RC = cast(RHS); - return operator()(LC->getOperand(), RC->getOperand()); + + // Compare cast expressions by operand. + return compare(LC->getOperand(), RC->getOperand()); + } + + default: + break; } llvm_unreachable("Unknown SCEV kind!"); - return false; + return 0; } }; } @@ -628,8 +612,9 @@ static void GroupByComplexity(SmallVectorImpl &Ops, if (Ops.size() == 2) { // This is the common case, which also happens to be trivially simple. // Special case it. - if (SCEVComplexityCompare(LI)(Ops[1], Ops[0])) - std::swap(Ops[0], Ops[1]); + const SCEV *&LHS = Ops[0], *&RHS = Ops[1]; + if (SCEVComplexityCompare(LI)(RHS, LHS)) + std::swap(LHS, RHS); return; } @@ -667,7 +652,7 @@ static void GroupByComplexity(SmallVectorImpl &Ops, /// Assume, K > 0. static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, - const Type* ResultTy) { + Type* ResultTy) { // Handle the simplest case efficiently. if (K == 1) return SE.getTruncateOrZeroExtend(It, ResultTy); @@ -757,11 +742,11 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, MultiplyFactor = MultiplyFactor.trunc(W); // Calculate the product, at width T+W - const IntegerType *CalculationTy = IntegerType::get(SE.getContext(), + IntegerType *CalculationTy = IntegerType::get(SE.getContext(), CalculationBits); const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy); for (unsigned i = 1; i != K; ++i) { - const SCEV *S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType())); + const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i)); Dividend = SE.getMulExpr(Dividend, SE.getTruncateOrZeroExtend(S, CalculationTy)); } @@ -805,7 +790,7 @@ const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It, //===----------------------------------------------------------------------===// const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, - const Type *Ty) { + Type *Ty) { assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) && "This is not a truncating conversion!"); assert(isSCEVable(Ty) && @@ -822,7 +807,8 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, // Fold if the operand is constant. if (const SCEVConstant *SC = dyn_cast(Op)) return getConstant( - cast(ConstantExpr::getTrunc(SC->getValue(), Ty))); + cast(ConstantExpr::getTrunc(SC->getValue(), + getEffectiveSCEVType(Ty)))); // trunc(trunc(x)) --> trunc(x) if (const SCEVTruncateExpr *ST = dyn_cast(Op)) @@ -836,17 +822,54 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, if (const SCEVZeroExtendExpr *SZ = dyn_cast(Op)) return getTruncateOrZeroExtend(SZ->getOperand(), Ty); + // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can + // eliminate all the truncates. + if (const SCEVAddExpr *SA = dyn_cast(Op)) { + SmallVector Operands; + bool hasTrunc = false; + for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) { + const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty); + hasTrunc = isa(S); + Operands.push_back(S); + } + if (!hasTrunc) + return getAddExpr(Operands); + UniqueSCEVs.FindNodeOrInsertPos(ID, IP); // Mutates IP, returns NULL. + } + + // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can + // eliminate all the truncates. + if (const SCEVMulExpr *SM = dyn_cast(Op)) { + SmallVector Operands; + bool hasTrunc = false; + for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) { + const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty); + hasTrunc = isa(S); + Operands.push_back(S); + } + if (!hasTrunc) + return getMulExpr(Operands); + UniqueSCEVs.FindNodeOrInsertPos(ID, IP); // Mutates IP, returns NULL. + } + // If the input value is a chrec scev, truncate the chrec's operands. if (const SCEVAddRecExpr *AddRec = dyn_cast(Op)) { SmallVector Operands; for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty)); - return getAddRecExpr(Operands, AddRec->getLoop()); + return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap); } - // The cast wasn't folded; create an explicit cast node. - // Recompute the insert position, as it may have been invalidated. - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + // As a special case, fold trunc(undef) to undef. We don't want to + // know too much about SCEVUnknowns, but this special case is handy + // and harmless. + if (const SCEVUnknown *U = dyn_cast(Op)) + if (isa(U->getValue())) + return getSCEV(UndefValue::get(Ty)); + + // The cast wasn't folded; create an explicit cast node. We can reuse + // the existing insert position since if we get here, we won't have + // made any changes which would invalidate it. SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); @@ -854,7 +877,7 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, } const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, - const Type *Ty) { + Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -862,12 +885,10 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Ty = getEffectiveSCEVType(Ty); // Fold if the operand is constant. - if (const SCEVConstant *SC = dyn_cast(Op)) { - const Type *IntTy = getEffectiveSCEVType(Ty); - Constant *C = ConstantExpr::getZExt(SC->getValue(), IntTy); - if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty); - return getConstant(cast(C)); - } + if (const SCEVConstant *SC = dyn_cast(Op)) + return getConstant( + cast(ConstantExpr::getZExt(SC->getValue(), + getEffectiveSCEVType(Ty)))); // zext(zext(x)) --> zext(x) if (const SCEVZeroExtendExpr *SZ = dyn_cast(Op)) @@ -882,6 +903,19 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, void *IP = 0; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + // zext(trunc(x)) --> zext(x) or x or trunc(x) + if (const SCEVTruncateExpr *ST = dyn_cast(Op)) { + // It's possible the bits taken off by the truncate were all zero bits. If + // so, we should be able to simplify this further. + const SCEV *X = ST->getOperand(); + ConstantRange CR = getUnsignedRange(X); + unsigned TruncBits = getTypeSizeInBits(ST->getType()); + unsigned NewBits = getTypeSizeInBits(Ty); + if (CR.truncate(TruncBits).zeroExtend(NewBits).contains( + CR.zextOrTrunc(NewBits))) + return getTruncateOrZeroExtend(X, Ty); + } + // If the input value is a chrec scev, and we can prove that the value // did not overflow the old, smaller, value, we can zero extend all of the // operands (often constants). This allows analysis of something like @@ -895,10 +929,10 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. - if (AR->hasNoUnsignedWrap()) + if (AR->getNoWrapFlags(SCEV::FlagNUW)) return getAddRecExpr(getZeroExtendExpr(Start, Ty), getZeroExtendExpr(Step, Ty), - L); + L, AR->getNoWrapFlags()); // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are @@ -920,7 +954,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); if (MaxBECount == RecastedMaxBECount) { - const Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); + Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no unsigned overflow. const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step); const SCEV *Add = getAddExpr(Start, ZMul); @@ -928,12 +962,14 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, getAddExpr(getZeroExtendExpr(Start, WideTy), getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy), getZeroExtendExpr(Step, WideTy))); - if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd) + if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd) { + // Cache knowledge of AR NUW, which is propagated to this AddRec. + const_cast(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. return getAddRecExpr(getZeroExtendExpr(Start, Ty), getZeroExtendExpr(Step, Ty), - L); - + L, AR->getNoWrapFlags()); + } // Similar to above, only this time treat the step value as signed. // This covers loops that count down. const SCEV *SMul = getMulExpr(CastedMaxBECount, Step); @@ -942,11 +978,15 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, getAddExpr(getZeroExtendExpr(Start, WideTy), getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy), getSignExtendExpr(Step, WideTy))); - if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd) + if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd) { + // Cache knowledge of AR NW, which is propagated to this AddRec. + // Negative step causes unsigned wrap, but it still can't self-wrap. + const_cast(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. return getAddRecExpr(getZeroExtendExpr(Start, Ty), getSignExtendExpr(Step, Ty), - L); + L, AR->getNoWrapFlags()); + } } // If the backedge is guarded by a comparison with the pre-inc value @@ -959,22 +999,29 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) || (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) && isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, - AR->getPostIncExpr(*this), N))) + AR->getPostIncExpr(*this), N))) { + // Cache knowledge of AR NUW, which is propagated to this AddRec. + const_cast(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. return getAddRecExpr(getZeroExtendExpr(Start, Ty), getZeroExtendExpr(Step, Ty), - L); + L, AR->getNoWrapFlags()); + } } else if (isKnownNegative(Step)) { const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - getSignedRange(Step).getSignedMin()); - if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) && - (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) || + if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) || + (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) && isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, - AR->getPostIncExpr(*this), N))) + AR->getPostIncExpr(*this), N))) { + // Cache knowledge of AR NW, which is propagated to this AddRec. + // Negative step causes unsigned wrap, but it still can't self-wrap. + const_cast(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. return getAddRecExpr(getZeroExtendExpr(Start, Ty), getSignExtendExpr(Step, Ty), - L); + L, AR->getNoWrapFlags()); + } } } } @@ -988,8 +1035,95 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, return S; } +// Get the limit of a recurrence such that incrementing by Step cannot cause +// signed overflow as long as the value of the recurrence within the loop does +// not exceed this limit before incrementing. +static const SCEV *getOverflowLimitForStep(const SCEV *Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { + unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); + if (SE->isKnownPositive(Step)) { + *Pred = ICmpInst::ICMP_SLT; + return SE->getConstant(APInt::getSignedMinValue(BitWidth) - + SE->getSignedRange(Step).getSignedMax()); + } + if (SE->isKnownNegative(Step)) { + *Pred = ICmpInst::ICMP_SGT; + return SE->getConstant(APInt::getSignedMaxValue(BitWidth) - + SE->getSignedRange(Step).getSignedMin()); + } + return 0; +} + +// The recurrence AR has been shown to have no signed wrap. Typically, if we can +// prove NSW for AR, then we can just as easily prove NSW for its preincrement +// or postincrement sibling. This allows normalizing a sign extended AddRec as +// such: {sext(Step + Start),+,Step} => {(Step + sext(Start),+,Step} As a +// result, the expression "Step + sext(PreIncAR)" is congruent with +// "sext(PostIncAR)" +static const SCEV *getPreStartForSignExtend(const SCEVAddRecExpr *AR, + Type *Ty, + ScalarEvolution *SE) { + const Loop *L = AR->getLoop(); + const SCEV *Start = AR->getStart(); + const SCEV *Step = AR->getStepRecurrence(*SE); + + // Check for a simple looking step prior to loop entry. + const SCEVAddExpr *SA = dyn_cast(Start); + if (!SA || SA->getNumOperands() != 2 || SA->getOperand(0) != Step) + return 0; + + // This is a postinc AR. Check for overflow on the preinc recurrence using the + // same three conditions that getSignExtendedExpr checks. + + // 1. NSW flags on the step increment. + const SCEV *PreStart = SA->getOperand(1); + const SCEVAddRecExpr *PreAR = dyn_cast( + SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap)); + + if (PreAR && PreAR->getNoWrapFlags(SCEV::FlagNSW)) + return PreStart; + + // 2. Direct overflow check on the step operation's expression. + unsigned BitWidth = SE->getTypeSizeInBits(AR->getType()); + Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2); + const SCEV *OperandExtendedStart = + SE->getAddExpr(SE->getSignExtendExpr(PreStart, WideTy), + SE->getSignExtendExpr(Step, WideTy)); + if (SE->getSignExtendExpr(Start, WideTy) == OperandExtendedStart) { + // Cache knowledge of PreAR NSW. + if (PreAR) + const_cast(PreAR)->setNoWrapFlags(SCEV::FlagNSW); + // FIXME: this optimization needs a unit test + DEBUG(dbgs() << "SCEV: untested prestart overflow check\n"); + return PreStart; + } + + // 3. Loop precondition. + ICmpInst::Predicate Pred; + const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, SE); + + if (OverflowLimit && + SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) { + return PreStart; + } + return 0; +} + +// Get the normalized sign-extended expression for this AddRec's Start. +static const SCEV *getSignExtendAddRecStart(const SCEVAddRecExpr *AR, + Type *Ty, + ScalarEvolution *SE) { + const SCEV *PreStart = getPreStartForSignExtend(AR, Ty, SE); + if (!PreStart) + return SE->getSignExtendExpr(AR->getStart(), Ty); + + return SE->getAddExpr(SE->getSignExtendExpr(AR->getStepRecurrence(*SE), Ty), + SE->getSignExtendExpr(PreStart, Ty)); +} + const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, - const Type *Ty) { + Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -997,17 +1131,19 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, Ty = getEffectiveSCEVType(Ty); // Fold if the operand is constant. - if (const SCEVConstant *SC = dyn_cast(Op)) { - const Type *IntTy = getEffectiveSCEVType(Ty); - Constant *C = ConstantExpr::getSExt(SC->getValue(), IntTy); - if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty); - return getConstant(cast(C)); - } + if (const SCEVConstant *SC = dyn_cast(Op)) + return getConstant( + cast(ConstantExpr::getSExt(SC->getValue(), + getEffectiveSCEVType(Ty)))); // sext(sext(x)) --> sext(x) if (const SCEVSignExtendExpr *SS = dyn_cast(Op)) return getSignExtendExpr(SS->getOperand(), Ty); + // sext(zext(x)) --> zext(x) + if (const SCEVZeroExtendExpr *SZ = dyn_cast(Op)) + return getZeroExtendExpr(SZ->getOperand(), Ty); + // Before doing any expensive analysis, check to see if we've already // computed a SCEV for this Op and Ty. FoldingSetNodeID ID; @@ -1017,6 +1153,23 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, void *IP = 0; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + // If the input value is provably positive, build a zext instead. + if (isKnownNonNegative(Op)) + return getZeroExtendExpr(Op, Ty); + + // sext(trunc(x)) --> sext(x) or x or trunc(x) + if (const SCEVTruncateExpr *ST = dyn_cast(Op)) { + // It's possible the bits taken off by the truncate were all sign bits. If + // so, we should be able to simplify this further. + const SCEV *X = ST->getOperand(); + ConstantRange CR = getSignedRange(X); + unsigned TruncBits = getTypeSizeInBits(ST->getType()); + unsigned NewBits = getTypeSizeInBits(Ty); + if (CR.truncate(TruncBits).signExtend(NewBits).contains( + CR.sextOrTrunc(NewBits))) + return getTruncateOrSignExtend(X, Ty); + } + // If the input value is a chrec scev, and we can prove that the value // did not overflow the old, smaller, value, we can sign extend all of the // operands (often constants). This allows analysis of something like @@ -1030,10 +1183,10 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. - if (AR->hasNoSignedWrap()) - return getAddRecExpr(getSignExtendExpr(Start, Ty), + if (AR->getNoWrapFlags(SCEV::FlagNSW)) + return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this), getSignExtendExpr(Step, Ty), - L); + L, SCEV::FlagNSW); // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are @@ -1055,7 +1208,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); if (MaxBECount == RecastedMaxBECount) { - const Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); + Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no signed overflow. const SCEV *SMul = getMulExpr(CastedMaxBECount, Step); const SCEV *Add = getAddExpr(Start, SMul); @@ -1063,12 +1216,14 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, getAddExpr(getSignExtendExpr(Start, WideTy), getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy), getSignExtendExpr(Step, WideTy))); - if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd) + if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd) { + // Cache knowledge of AR NSW, which is propagated to this AddRec. + const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getSignExtendExpr(Start, Ty), + return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this), getSignExtendExpr(Step, Ty), - L); - + L, AR->getNoWrapFlags()); + } // Similar to above, only this time treat the step value as unsigned. // This covers loops that count up with an unsigned step. const SCEV *UMul = getMulExpr(CastedMaxBECount, Step); @@ -1077,39 +1232,32 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, getAddExpr(getSignExtendExpr(Start, WideTy), getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy), getZeroExtendExpr(Step, WideTy))); - if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd) + if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd) { + // Cache knowledge of AR NSW, which is propagated to this AddRec. + const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); // Return the expression with the addrec on the outside. - return getAddRecExpr(getSignExtendExpr(Start, Ty), + return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this), getZeroExtendExpr(Step, Ty), - L); + L, AR->getNoWrapFlags()); + } } // If the backedge is guarded by a comparison with the pre-inc value // the addrec is safe. Also, if the entry is guarded by a comparison // with the start value and the backedge is guarded by a comparison // with the post-inc value, the addrec is safe. - if (isKnownPositive(Step)) { - const SCEV *N = getConstant(APInt::getSignedMinValue(BitWidth) - - getSignedRange(Step).getSignedMax()); - if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SLT, AR, N) || - (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SLT, Start, N) && - isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SLT, - AR->getPostIncExpr(*this), N))) - // Return the expression with the addrec on the outside. - return getAddRecExpr(getSignExtendExpr(Start, Ty), - getSignExtendExpr(Step, Ty), - L); - } else if (isKnownNegative(Step)) { - const SCEV *N = getConstant(APInt::getSignedMaxValue(BitWidth) - - getSignedRange(Step).getSignedMin()); - if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SGT, AR, N) || - (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SGT, Start, N) && - isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SGT, - AR->getPostIncExpr(*this), N))) - // Return the expression with the addrec on the outside. - return getAddRecExpr(getSignExtendExpr(Start, Ty), - getSignExtendExpr(Step, Ty), - L); + ICmpInst::Predicate Pred; + const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, this); + if (OverflowLimit && + (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) || + (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) && + isLoopBackedgeGuardedByCond(L, Pred, AR->getPostIncExpr(*this), + OverflowLimit)))) { + // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec. + const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); + return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this), + getSignExtendExpr(Step, Ty), + L, AR->getNoWrapFlags()); } } } @@ -1127,7 +1275,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, /// unspecified bits out to the given type. /// const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, - const Type *Ty) { + Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -1163,9 +1311,16 @@ const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end(); I != E; ++I) Ops.push_back(getAnyExtendExpr(*I, Ty)); - return getAddRecExpr(Ops, AR->getLoop()); + return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW); } + // As a special case, fold anyext(undef) to undef. We don't want to + // know too much about SCEVUnknowns, but this special case is handy + // and harmless. + if (const SCEVUnknown *U = dyn_cast(Op)) + if (isa(U->getValue())) + return getSCEV(UndefValue::get(Ty)); + // If the expression is obviously signed, use the sext cast value. if (isa(Op)) return SExt; @@ -1208,8 +1363,19 @@ CollectAddOperandsWithScales(DenseMap &M, ScalarEvolution &SE) { bool Interesting = false; - // Iterate over the add operands. - for (unsigned i = 0, e = NumOperands; i != e; ++i) { + // Iterate over the add operands. They are sorted, with constants first. + unsigned i = 0; + while (const SCEVConstant *C = dyn_cast(Ops[i])) { + ++i; + // Pull a buried constant out to the outside. + if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero()) + Interesting = true; + AccumulatedConstant += Scale * C->getValue()->getValue(); + } + + // Next comes everything else. We're especially interested in multiplies + // here, but they're in the middle, so just visit the rest with one loop. + for (; i != NumOperands; ++i) { const SCEVMulExpr *Mul = dyn_cast(Ops[i]); if (Mul && isa(Mul->getOperand(0))) { APInt NewScale = @@ -1237,11 +1403,6 @@ CollectAddOperandsWithScales(DenseMap &M, Interesting = true; } } - } else if (const SCEVConstant *C = dyn_cast(Ops[i])) { - // Pull a buried constant out to the outside. - if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero()) - Interesting = true; - AccumulatedConstant += Scale * C->getValue()->getValue(); } else { // An ordinary operand. Update the map. std::pair::iterator, bool> Pair = @@ -1271,25 +1432,31 @@ namespace { /// getAddExpr - Get a canonical add expression, or something simpler if /// possible. const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, - bool HasNUW, bool HasNSW) { + SCEV::NoWrapFlags Flags) { + assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) && + "only nuw or nsw allowed"); assert(!Ops.empty() && "Cannot get empty add!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG + Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); for (unsigned i = 1, e = Ops.size(); i != e; ++i) - assert(getEffectiveSCEVType(Ops[i]->getType()) == - getEffectiveSCEVType(Ops[0]->getType()) && + assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && "SCEVAddExpr operand types don't match!"); #endif - // If HasNSW is true and all the operands are non-negative, infer HasNUW. - if (!HasNUW && HasNSW) { + // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. + // And vice-versa. + int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; + SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask); + if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) { bool All = true; - for (unsigned i = 0, e = Ops.size(); i != e; ++i) - if (!isKnownNonNegative(Ops[i])) { + for (SmallVectorImpl::const_iterator I = Ops.begin(), + E = Ops.end(); I != E; ++I) + if (!isKnownNonNegative(*I)) { All = false; break; } - if (All) HasNUW = true; + if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); } // Sort by complexity, this groups all similar expression types together. @@ -1318,22 +1485,29 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, if (Ops.size() == 1) return Ops[0]; } - // Okay, check to see if the same value occurs in the operand list twice. If - // so, merge them together into an multiply expression. Since we sorted the - // list, these values are required to be adjacent. - const Type *Ty = Ops[0]->getType(); - for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) + // Okay, check to see if the same value occurs in the operand list more than + // once. If so, merge them together into an multiply expression. Since we + // sorted the list, these values are required to be adjacent. + Type *Ty = Ops[0]->getType(); + bool FoundMatch = false; + for (unsigned i = 0, e = Ops.size(); i != e-1; ++i) if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2 - // Found a match, merge the two values into a multiply, and add any - // remaining values to the result. - const SCEV *Two = getIntegerSCEV(2, Ty); - const SCEV *Mul = getMulExpr(Ops[i], Two); - if (Ops.size() == 2) + // Scan ahead to count how many equal operands there are. + unsigned Count = 2; + while (i+Count != e && Ops[i+Count] == Ops[i]) + ++Count; + // Merge the values into a multiply. + const SCEV *Scale = getConstant(Ty, Count); + const SCEV *Mul = getMulExpr(Scale, Ops[i]); + if (Ops.size() == Count) return Mul; - Ops.erase(Ops.begin()+i, Ops.begin()+i+2); - Ops.push_back(Mul); - return getAddExpr(Ops, HasNUW, HasNSW); + Ops[i] = Mul; + Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count); + --i; e -= Count - 1; + FoundMatch = true; } + if (FoundMatch) + return getAddExpr(Ops, Flags); // Check for truncates. If all the operands are truncated from the same // type, see if factoring out the truncate would permit the result to be @@ -1341,8 +1515,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, // if the contents of the resulting outer trunc fold to something simple. for (; Idx < Ops.size() && isa(Ops[Idx]); ++Idx) { const SCEVTruncateExpr *Trunc = cast(Ops[Idx]); - const Type *DstType = Trunc->getType(); - const Type *SrcType = Trunc->getOperand()->getType(); + Type *DstType = Trunc->getType(); + Type *SrcType = Trunc->getOperand()->getType(); SmallVector LargeOps; bool Ok = true; // Check all the operands to see if they can be represented in the @@ -1383,7 +1557,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, } if (Ok) { // Evaluate the expression in the larger type. - const SCEV *Fold = getAddExpr(LargeOps, HasNUW, HasNSW); + const SCEV *Fold = getAddExpr(LargeOps, Flags); // If it folds to something simple, use it. Otherwise, don't. if (isa(Fold) || isa(Fold)) return getTruncateExpr(Fold, DstType); @@ -1400,8 +1574,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, while (const SCEVAddExpr *Add = dyn_cast(Ops[Idx])) { // If we have an add, expand the add operands onto the end of the operands // list. - Ops.insert(Ops.end(), Add->op_begin(), Add->op_end()); Ops.erase(Ops.begin()+Idx); + Ops.append(Add->op_begin(), Add->op_end()); DeletedAdd = true; } @@ -1430,7 +1604,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, // re-generate the operands list. Group the operands by constant scale, // to avoid multiplying by the same constant scale multiple times. std::map, APIntCompare> MulOpLists; - for (SmallVector::iterator I = NewOps.begin(), + for (SmallVector::const_iterator I = NewOps.begin(), E = NewOps.end(); I != E; ++I) MulOpLists[M.find(*I)->second].push_back(*I); // Re-generate the operands list. @@ -1443,7 +1617,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, Ops.push_back(getMulExpr(getConstant(I->first), getAddExpr(I->second))); if (Ops.empty()) - return getIntegerSCEV(0, Ty); + return getConstant(Ty, 0); if (Ops.size() == 1) return Ops[0]; return getAddExpr(Ops); @@ -1457,20 +1631,23 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, const SCEVMulExpr *Mul = cast(Ops[Idx]); for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) { const SCEV *MulOpSCEV = Mul->getOperand(MulOp); + if (isa(MulOpSCEV)) + continue; for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp) - if (MulOpSCEV == Ops[AddOp] && !isa(Ops[AddOp])) { + if (MulOpSCEV == Ops[AddOp]) { // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1)) const SCEV *InnerMul = Mul->getOperand(MulOp == 0); if (Mul->getNumOperands() != 2) { // If the multiply has more than two operands, we must get the // Y*Z term. - SmallVector MulOps(Mul->op_begin(), Mul->op_end()); - MulOps.erase(MulOps.begin()+MulOp); + SmallVector MulOps(Mul->op_begin(), + Mul->op_begin()+MulOp); + MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); InnerMul = getMulExpr(MulOps); } - const SCEV *One = getIntegerSCEV(1, Ty); - const SCEV *AddOne = getAddExpr(InnerMul, One); - const SCEV *OuterMul = getMulExpr(AddOne, Ops[AddOp]); + const SCEV *One = getConstant(Ty, 1); + const SCEV *AddOne = getAddExpr(One, InnerMul); + const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV); if (Ops.size() == 2) return OuterMul; if (AddOp < Idx) { Ops.erase(Ops.begin()+AddOp); @@ -1497,15 +1674,15 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0); if (Mul->getNumOperands() != 2) { SmallVector MulOps(Mul->op_begin(), - Mul->op_end()); - MulOps.erase(MulOps.begin()+MulOp); + Mul->op_begin()+MulOp); + MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); InnerMul1 = getMulExpr(MulOps); } const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0); if (OtherMul->getNumOperands() != 2) { SmallVector MulOps(OtherMul->op_begin(), - OtherMul->op_end()); - MulOps.erase(MulOps.begin()+OMulOp); + OtherMul->op_begin()+OMulOp); + MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end()); InnerMul2 = getMulExpr(MulOps); } const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2); @@ -1534,7 +1711,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, const SCEVAddRecExpr *AddRec = cast(Ops[Idx]); const Loop *AddRecLoop = AddRec->getLoop(); for (unsigned i = 0, e = Ops.size(); i != e; ++i) - if (Ops[i]->isLoopInvariant(AddRecLoop)) { + if (isLoopInvariant(Ops[i], AddRecLoop)) { LIOps.push_back(Ops[i]); Ops.erase(Ops.begin()+i); --i; --e; @@ -1549,9 +1726,11 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, AddRec->op_end()); AddRecOps[0] = getAddExpr(LIOps); - // It's tempting to propagate NUW/NSW flags here, but nuw/nsw addition - // is not associative so this isn't necessarily safe. - const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop); + // Build the new addrec. Propagate the NUW and NSW flags if both the + // outer add and the inner addrec are guaranteed to have no overflow. + // Always propagate NW. + Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW)); + const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags); // If all of the other operands were loop invariant, we are done. if (Ops.size() == 1) return NewRec; @@ -1569,30 +1748,32 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, // there are multiple AddRec's with the same loop induction variable being // added together. If so, we can fold them. for (unsigned OtherIdx = Idx+1; - OtherIdx < Ops.size() && isa(Ops[OtherIdx]);++OtherIdx) - if (OtherIdx != Idx) { - const SCEVAddRecExpr *OtherAddRec = cast(Ops[OtherIdx]); - if (AddRecLoop == OtherAddRec->getLoop()) { - // Other + {A,+,B} + {C,+,D} --> Other + {A+C,+,B+D} - SmallVector NewOps(AddRec->op_begin(), - AddRec->op_end()); - for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) { - if (i >= NewOps.size()) { - NewOps.insert(NewOps.end(), OtherAddRec->op_begin()+i, - OtherAddRec->op_end()); - break; + OtherIdx < Ops.size() && isa(Ops[OtherIdx]); + ++OtherIdx) + if (AddRecLoop == cast(Ops[OtherIdx])->getLoop()) { + // Other + {A,+,B} + {C,+,D} --> Other + {A+C,+,B+D} + SmallVector AddRecOps(AddRec->op_begin(), + AddRec->op_end()); + for (; OtherIdx != Ops.size() && isa(Ops[OtherIdx]); + ++OtherIdx) + if (const SCEVAddRecExpr *OtherAddRec = + dyn_cast(Ops[OtherIdx])) + if (OtherAddRec->getLoop() == AddRecLoop) { + for (unsigned i = 0, e = OtherAddRec->getNumOperands(); + i != e; ++i) { + if (i >= AddRecOps.size()) { + AddRecOps.append(OtherAddRec->op_begin()+i, + OtherAddRec->op_end()); + break; + } + AddRecOps[i] = getAddExpr(AddRecOps[i], + OtherAddRec->getOperand(i)); + } + Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; } - NewOps[i] = getAddExpr(NewOps[i], OtherAddRec->getOperand(i)); - } - const SCEV *NewAddRec = getAddRecExpr(NewOps, AddRecLoop); - - if (Ops.size() == 2) return NewAddRec; - - Ops.erase(Ops.begin()+Idx); - Ops.erase(Ops.begin()+OtherIdx-1); - Ops.push_back(NewAddRec); - return getAddExpr(Ops); - } + // Step size has changed, so we cannot guarantee no self-wraparound. + Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap); + return getAddExpr(Ops); } // Otherwise couldn't fold anything into this recurrence. Move onto the @@ -1603,7 +1784,6 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, // already have one, otherwise create a new one. FoldingSetNodeID ID; ID.AddInteger(scAddExpr); - ID.AddInteger(Ops.size()); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); void *IP = 0; @@ -1616,33 +1796,38 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); } - if (HasNUW) S->setHasNoUnsignedWrap(true); - if (HasNSW) S->setHasNoSignedWrap(true); + S->setNoWrapFlags(Flags); return S; } /// getMulExpr - Get a canonical multiply expression, or something simpler if /// possible. const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, - bool HasNUW, bool HasNSW) { + SCEV::NoWrapFlags Flags) { + assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) && + "only nuw or nsw allowed"); assert(!Ops.empty() && "Cannot get empty mul!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG + Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); for (unsigned i = 1, e = Ops.size(); i != e; ++i) - assert(getEffectiveSCEVType(Ops[i]->getType()) == - getEffectiveSCEVType(Ops[0]->getType()) && + assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && "SCEVMulExpr operand types don't match!"); #endif - // If HasNSW is true and all the operands are non-negative, infer HasNUW. - if (!HasNUW && HasNSW) { + // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. + // And vice-versa. + int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; + SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask); + if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) { bool All = true; - for (unsigned i = 0, e = Ops.size(); i != e; ++i) - if (!isKnownNonNegative(Ops[i])) { + for (SmallVectorImpl::const_iterator I = Ops.begin(), + E = Ops.end(); I != E; ++I) + if (!isKnownNonNegative(*I)) { All = false; break; } - if (All) HasNUW = true; + if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); } // Sort by complexity, this groups all similar expression types together. @@ -1682,12 +1867,12 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, } else if (Ops[0]->isAllOnesValue()) { // If we have a mul by -1 of an add, try distributing the -1 among the // add operands. - if (Ops.size() == 2) + if (Ops.size() == 2) { if (const SCEVAddExpr *Add = dyn_cast(Ops[1])) { SmallVector NewOps; bool AnyFolded = false; - for (SCEVAddRecExpr::op_iterator I = Add->op_begin(), E = Add->op_end(); - I != E; ++I) { + for (SCEVAddRecExpr::op_iterator I = Add->op_begin(), + E = Add->op_end(); I != E; ++I) { const SCEV *Mul = getMulExpr(Ops[0], *I); if (!isa(Mul)) AnyFolded = true; NewOps.push_back(Mul); @@ -1695,6 +1880,18 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, if (AnyFolded) return getAddExpr(NewOps); } + else if (const SCEVAddRecExpr * + AddRec = dyn_cast(Ops[1])) { + // Negation preserves a recurrence's no self-wrap property. + SmallVector Operands; + for (SCEVAddRecExpr::op_iterator I = AddRec->op_begin(), + E = AddRec->op_end(); I != E; ++I) { + Operands.push_back(getMulExpr(Ops[0], *I)); + } + return getAddRecExpr(Operands, AddRec->getLoop(), + AddRec->getNoWrapFlags(SCEV::FlagNW)); + } + } } if (Ops.size() == 1) @@ -1711,8 +1908,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, while (const SCEVMulExpr *Mul = dyn_cast(Ops[Idx])) { // If we have an mul, expand the mul operands onto the end of the operands // list. - Ops.insert(Ops.end(), Mul->op_begin(), Mul->op_end()); Ops.erase(Ops.begin()+Idx); + Ops.append(Mul->op_begin(), Mul->op_end()); DeletedMul = true; } @@ -1735,8 +1932,9 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, // they are loop invariant w.r.t. the recurrence. SmallVector LIOps; const SCEVAddRecExpr *AddRec = cast(Ops[Idx]); + const Loop *AddRecLoop = AddRec->getLoop(); for (unsigned i = 0, e = Ops.size(); i != e; ++i) - if (Ops[i]->isLoopInvariant(AddRec->getLoop())) { + if (isLoopInvariant(Ops[i], AddRecLoop)) { LIOps.push_back(Ops[i]); Ops.erase(Ops.begin()+i); --i; --e; @@ -1747,23 +1945,17 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step} SmallVector NewOps; NewOps.reserve(AddRec->getNumOperands()); - if (LIOps.size() == 1) { - const SCEV *Scale = LIOps[0]; - for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) - NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i))); - } else { - for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { - SmallVector MulOps(LIOps.begin(), LIOps.end()); - MulOps.push_back(AddRec->getOperand(i)); - NewOps.push_back(getMulExpr(MulOps)); - } - } - - // It's tempting to propagate the NSW flag here, but nsw multiplication - // is not associative so this isn't necessarily safe. - const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), - HasNUW && AddRec->hasNoUnsignedWrap(), - /*HasNSW=*/false); + const SCEV *Scale = getMulExpr(LIOps); + for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) + NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i))); + + // Build the new addrec. Propagate the NUW and NSW flags if both the + // outer mul and the inner addrec are guaranteed to have no overflow. + // + // No self-wrap cannot be guaranteed after changing the step size, but + // will be inferred if either NUW or NSW is true. + Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW)); + const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags); // If all of the other operands were loop invariant, we are done. if (Ops.size() == 1) return NewRec; @@ -1781,28 +1973,31 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, // there are multiple AddRec's with the same loop induction variable being // multiplied together. If so, we can fold them. for (unsigned OtherIdx = Idx+1; - OtherIdx < Ops.size() && isa(Ops[OtherIdx]);++OtherIdx) - if (OtherIdx != Idx) { - const SCEVAddRecExpr *OtherAddRec = cast(Ops[OtherIdx]); - if (AddRec->getLoop() == OtherAddRec->getLoop()) { - // F * G --> {A,+,B} * {C,+,D} --> {A*C,+,F*D + G*B + B*D} - const SCEVAddRecExpr *F = AddRec, *G = OtherAddRec; - const SCEV *NewStart = getMulExpr(F->getStart(), - G->getStart()); - const SCEV *B = F->getStepRecurrence(*this); - const SCEV *D = G->getStepRecurrence(*this); - const SCEV *NewStep = getAddExpr(getMulExpr(F, D), - getMulExpr(G, B), - getMulExpr(B, D)); - const SCEV *NewAddRec = getAddRecExpr(NewStart, NewStep, - F->getLoop()); - if (Ops.size() == 2) return NewAddRec; - - Ops.erase(Ops.begin()+Idx); - Ops.erase(Ops.begin()+OtherIdx-1); - Ops.push_back(NewAddRec); - return getMulExpr(Ops); - } + OtherIdx < Ops.size() && isa(Ops[OtherIdx]); + ++OtherIdx) + if (AddRecLoop == cast(Ops[OtherIdx])->getLoop()) { + // F * G, where F = {A,+,B} and G = {C,+,D} --> + // {A*C,+,F*D + G*B + B*D} + for (; OtherIdx != Ops.size() && isa(Ops[OtherIdx]); + ++OtherIdx) + if (const SCEVAddRecExpr *OtherAddRec = + dyn_cast(Ops[OtherIdx])) + if (OtherAddRec->getLoop() == AddRecLoop) { + const SCEVAddRecExpr *F = AddRec, *G = OtherAddRec; + const SCEV *NewStart = getMulExpr(F->getStart(), G->getStart()); + const SCEV *B = F->getStepRecurrence(*this); + const SCEV *D = G->getStepRecurrence(*this); + const SCEV *NewStep = getAddExpr(getMulExpr(F, D), + getMulExpr(G, B), + getMulExpr(B, D)); + const SCEV *NewAddRec = getAddRecExpr(NewStart, NewStep, + F->getLoop(), + SCEV::FlagAnyWrap); + if (Ops.size() == 2) return NewAddRec; + Ops[Idx] = AddRec = cast(NewAddRec); + Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; + } + return getMulExpr(Ops); } // Otherwise couldn't fold anything into this recurrence. Move onto the @@ -1813,7 +2008,6 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, // already have one, otherwise create a new one. FoldingSetNodeID ID; ID.AddInteger(scMulExpr); - ID.AddInteger(Ops.size()); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); void *IP = 0; @@ -1826,8 +2020,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); } - if (HasNUW) S->setHasNoUnsignedWrap(true); - if (HasNSW) S->setHasNoSignedWrap(true); + S->setNoWrapFlags(Flags); return S; } @@ -1849,14 +2042,14 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, // Determine if the division can be folded into the operands of // its operands. // TODO: Generalize this to non-constants by using known-bits information. - const Type *Ty = LHS->getType(); + Type *Ty = LHS->getType(); unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros(); - unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ; + unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1; // For non-power-of-two values, effectively round the value up to the // nearest power of two. if (!RHSC->getValue()->getValue().isPowerOf2()) ++MaxShiftAmt; - const IntegerType *ExtTy = + IntegerType *ExtTy = IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt); // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded. if (const SCEVAddRecExpr *AR = dyn_cast(LHS)) @@ -1867,11 +2060,12 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, getZeroExtendExpr(AR, ExtTy) == getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), getZeroExtendExpr(Step, ExtTy), - AR->getLoop())) { + AR->getLoop(), SCEV::FlagAnyWrap)) { SmallVector Operands; for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i) Operands.push_back(getUDivExpr(AR->getOperand(i), RHS)); - return getAddRecExpr(Operands, AR->getLoop()); + return getAddRecExpr(Operands, AR->getLoop(), + SCEV::FlagNW); } // (A*B)/C --> A*(B/C) if safe and B/C can be folded. if (const SCEVMulExpr *M = dyn_cast(LHS)) { @@ -1892,7 +2086,7 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, } } // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded. - if (const SCEVAddRecExpr *A = dyn_cast(LHS)) { + if (const SCEVAddExpr *A = dyn_cast(LHS)) { SmallVector Operands; for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy)); @@ -1935,39 +2129,40 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, /// getAddRecExpr - Get an add recurrence expression for the specified loop. /// Simplify the expression as much as possible. -const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, - const SCEV *Step, const Loop *L, - bool HasNUW, bool HasNSW) { +const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step, + const Loop *L, + SCEV::NoWrapFlags Flags) { SmallVector Operands; Operands.push_back(Start); if (const SCEVAddRecExpr *StepChrec = dyn_cast(Step)) if (StepChrec->getLoop() == L) { - Operands.insert(Operands.end(), StepChrec->op_begin(), - StepChrec->op_end()); - return getAddRecExpr(Operands, L); + Operands.append(StepChrec->op_begin(), StepChrec->op_end()); + return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW)); } Operands.push_back(Step); - return getAddRecExpr(Operands, L, HasNUW, HasNSW); + return getAddRecExpr(Operands, L, Flags); } /// getAddRecExpr - Get an add recurrence expression for the specified loop. /// Simplify the expression as much as possible. const SCEV * ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, - const Loop *L, - bool HasNUW, bool HasNSW) { + const Loop *L, SCEV::NoWrapFlags Flags) { if (Operands.size() == 1) return Operands[0]; #ifndef NDEBUG + Type *ETy = getEffectiveSCEVType(Operands[0]->getType()); for (unsigned i = 1, e = Operands.size(); i != e; ++i) - assert(getEffectiveSCEVType(Operands[i]->getType()) == - getEffectiveSCEVType(Operands[0]->getType()) && + assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy && "SCEVAddRecExpr operand types don't match!"); + for (unsigned i = 0, e = Operands.size(); i != e; ++i) + assert(isLoopInvariant(Operands[i], L) && + "SCEVAddRecExpr operand is not loop-invariant!"); #endif if (Operands.back()->isZero()) { Operands.pop_back(); - return getAddRecExpr(Operands, L, HasNUW, HasNSW); // {X,+,0} --> X + return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X } // It's tempting to want to call getMaxBackedgeTakenCount count here and @@ -1976,23 +2171,27 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, // meaningful BE count at this point (and if we don't, we'd be stuck // with a SCEVCouldNotCompute as the cached BE count). - // If HasNSW is true and all the operands are non-negative, infer HasNUW. - if (!HasNUW && HasNSW) { + // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. + // And vice-versa. + int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; + SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask); + if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) { bool All = true; - for (unsigned i = 0, e = Operands.size(); i != e; ++i) - if (!isKnownNonNegative(Operands[i])) { + for (SmallVectorImpl::const_iterator I = Operands.begin(), + E = Operands.end(); I != E; ++I) + if (!isKnownNonNegative(*I)) { All = false; break; } - if (All) HasNUW = true; + if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); } // Canonicalize nested AddRecs in by nesting them in order of loop depth. if (const SCEVAddRecExpr *NestedAR = dyn_cast(Operands[0])) { const Loop *NestedLoop = NestedAR->getLoop(); - if (L->contains(NestedLoop->getHeader()) ? + if (L->contains(NestedLoop) ? (L->getLoopDepth() < NestedLoop->getLoopDepth()) : - (!NestedLoop->contains(L->getHeader()) && + (!NestedLoop->contains(L) && DT->dominates(L->getHeader(), NestedLoop->getHeader()))) { SmallVector NestedOperands(NestedAR->op_begin(), NestedAR->op_end()); @@ -2002,21 +2201,34 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, // requirement. bool AllInvariant = true; for (unsigned i = 0, e = Operands.size(); i != e; ++i) - if (!Operands[i]->isLoopInvariant(L)) { + if (!isLoopInvariant(Operands[i], L)) { AllInvariant = false; break; } if (AllInvariant) { - NestedOperands[0] = getAddRecExpr(Operands, L); + // Create a recurrence for the outer loop with the same step size. + // + // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the + // inner recurrence has the same property. + SCEV::NoWrapFlags OuterFlags = + maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags()); + + NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags); AllInvariant = true; for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i) - if (!NestedOperands[i]->isLoopInvariant(NestedLoop)) { + if (!isLoopInvariant(NestedOperands[i], NestedLoop)) { AllInvariant = false; break; } - if (AllInvariant) + if (AllInvariant) { // Ok, both add recurrences are valid after the transformation. - return getAddRecExpr(NestedOperands, NestedLoop, HasNUW, HasNSW); + // + // The inner recurrence keeps its NW flag but only keeps NUW/NSW if + // the outer recurrence has the same property. + SCEV::NoWrapFlags InnerFlags = + maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags); + return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags); + } } // Reset Operands to its original state. Operands[0] = NestedAR; @@ -2027,7 +2239,6 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, // already have one, otherwise create a new one. FoldingSetNodeID ID; ID.AddInteger(scAddRecExpr); - ID.AddInteger(Operands.size()); for (unsigned i = 0, e = Operands.size(); i != e; ++i) ID.AddPointer(Operands[i]); ID.AddPointer(L); @@ -2041,8 +2252,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, O, Operands.size(), L); UniqueSCEVs.InsertNode(S, IP); } - if (HasNUW) S->setHasNoUnsignedWrap(true); - if (HasNSW) S->setHasNoSignedWrap(true); + S->setNoWrapFlags(Flags); return S; } @@ -2059,9 +2269,9 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { assert(!Ops.empty() && "Cannot get empty smax!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG + Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); for (unsigned i = 1, e = Ops.size(); i != e; ++i) - assert(getEffectiveSCEVType(Ops[i]->getType()) == - getEffectiveSCEVType(Ops[0]->getType()) && + assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && "SCEVSMaxExpr operand types don't match!"); #endif @@ -2106,8 +2316,8 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { if (Idx < Ops.size()) { bool DeletedSMax = false; while (const SCEVSMaxExpr *SMax = dyn_cast(Ops[Idx])) { - Ops.insert(Ops.end(), SMax->op_begin(), SMax->op_end()); Ops.erase(Ops.begin()+Idx); + Ops.append(SMax->op_begin(), SMax->op_end()); DeletedSMax = true; } @@ -2138,7 +2348,6 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { // already have one, otherwise create a new one. FoldingSetNodeID ID; ID.AddInteger(scSMaxExpr); - ID.AddInteger(Ops.size()); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); void *IP = 0; @@ -2164,9 +2373,9 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { assert(!Ops.empty() && "Cannot get empty umax!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG + Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); for (unsigned i = 1, e = Ops.size(); i != e; ++i) - assert(getEffectiveSCEVType(Ops[i]->getType()) == - getEffectiveSCEVType(Ops[0]->getType()) && + assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && "SCEVUMaxExpr operand types don't match!"); #endif @@ -2211,8 +2420,8 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { if (Idx < Ops.size()) { bool DeletedUMax = false; while (const SCEVUMaxExpr *UMax = dyn_cast(Ops[Idx])) { - Ops.insert(Ops.end(), UMax->op_begin(), UMax->op_end()); Ops.erase(Ops.begin()+Idx); + Ops.append(UMax->op_begin(), UMax->op_end()); DeletedUMax = true; } @@ -2243,7 +2452,6 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { // already have one, otherwise create a new one. FoldingSetNodeID ID; ID.AddInteger(scUMaxExpr); - ID.AddInteger(Ops.size()); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); void *IP = 0; @@ -2268,7 +2476,7 @@ const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS))); } -const SCEV *ScalarEvolution::getSizeOfExpr(const Type *AllocTy) { +const SCEV *ScalarEvolution::getSizeOfExpr(Type *AllocTy) { // If we have TargetData, we can bypass creating a target-independent // constant expression and then folding it back into a ConstantInt. // This is just a compile-time optimization. @@ -2278,20 +2486,22 @@ const SCEV *ScalarEvolution::getSizeOfExpr(const Type *AllocTy) { Constant *C = ConstantExpr::getSizeOf(AllocTy); if (ConstantExpr *CE = dyn_cast(C)) - C = ConstantFoldConstantExpression(CE, TD); - const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy)); + if (Constant *Folded = ConstantFoldConstantExpression(CE, TD)) + C = Folded; + Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy)); return getTruncateOrZeroExtend(getSCEV(C), Ty); } -const SCEV *ScalarEvolution::getAlignOfExpr(const Type *AllocTy) { +const SCEV *ScalarEvolution::getAlignOfExpr(Type *AllocTy) { Constant *C = ConstantExpr::getAlignOf(AllocTy); if (ConstantExpr *CE = dyn_cast(C)) - C = ConstantFoldConstantExpression(CE, TD); - const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy)); + if (Constant *Folded = ConstantFoldConstantExpression(CE, TD)) + C = Folded; + Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy)); return getTruncateOrZeroExtend(getSCEV(C), Ty); } -const SCEV *ScalarEvolution::getOffsetOfExpr(const StructType *STy, +const SCEV *ScalarEvolution::getOffsetOfExpr(StructType *STy, unsigned FieldNo) { // If we have TargetData, we can bypass creating a target-independent // constant expression and then folding it back into a ConstantInt. @@ -2302,17 +2512,19 @@ const SCEV *ScalarEvolution::getOffsetOfExpr(const StructType *STy, Constant *C = ConstantExpr::getOffsetOf(STy, FieldNo); if (ConstantExpr *CE = dyn_cast(C)) - C = ConstantFoldConstantExpression(CE, TD); - const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(STy)); + if (Constant *Folded = ConstantFoldConstantExpression(CE, TD)) + C = Folded; + Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(STy)); return getTruncateOrZeroExtend(getSCEV(C), Ty); } -const SCEV *ScalarEvolution::getOffsetOfExpr(const Type *CTy, +const SCEV *ScalarEvolution::getOffsetOfExpr(Type *CTy, Constant *FieldNo) { Constant *C = ConstantExpr::getOffsetOf(CTy, FieldNo); if (ConstantExpr *CE = dyn_cast(C)) - C = ConstantFoldConstantExpression(CE, TD); - const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(CTy)); + if (Constant *Folded = ConstantFoldConstantExpression(CE, TD)) + C = Folded; + Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(CTy)); return getTruncateOrZeroExtend(getSCEV(C), Ty); } @@ -2326,8 +2538,14 @@ const SCEV *ScalarEvolution::getUnknown(Value *V) { ID.AddInteger(scUnknown); ID.AddPointer(V); void *IP = 0; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; - SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V); + if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) { + assert(cast(S)->getValue() == V && + "Stale SCEVUnknown in uniquing map!"); + return S; + } + SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this, + FirstUnknown); + FirstUnknown = cast(S); UniqueSCEVs.InsertNode(S, IP); return S; } @@ -2340,14 +2558,14 @@ const SCEV *ScalarEvolution::getUnknown(Value *V) { /// the SCEV framework. This primarily includes integer types, and it /// can optionally include pointer types if the ScalarEvolution class /// has access to target-specific information. -bool ScalarEvolution::isSCEVable(const Type *Ty) const { +bool ScalarEvolution::isSCEVable(Type *Ty) const { // Integers and pointers are always SCEVable. return Ty->isIntegerTy() || Ty->isPointerTy(); } /// getTypeSizeInBits - Return the size in bits of the specified type, /// for which isSCEVable must return true. -uint64_t ScalarEvolution::getTypeSizeInBits(const Type *Ty) const { +uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const { assert(isSCEVable(Ty) && "Type is not SCEVable!"); // If we have a TargetData, use it! @@ -2368,7 +2586,7 @@ uint64_t ScalarEvolution::getTypeSizeInBits(const Type *Ty) const { /// the given type and which represents how SCEV will treat the given /// type, for which isSCEVable must return true. For pointer types, /// this is the pointer-sized integer type. -const Type *ScalarEvolution::getEffectiveSCEVType(const Type *Ty) const { +Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const { assert(isSCEVable(Ty) && "Type is not SCEVable!"); if (Ty->isIntegerTy()) @@ -2391,18 +2609,16 @@ const SCEV *ScalarEvolution::getCouldNotCompute() { const SCEV *ScalarEvolution::getSCEV(Value *V) { assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); - std::map::iterator I = Scalars.find(V); - if (I != Scalars.end()) return I->second; + ValueExprMapType::const_iterator I = ValueExprMap.find(V); + if (I != ValueExprMap.end()) return I->second; const SCEV *S = createSCEV(V); - Scalars.insert(std::make_pair(SCEVCallbackVH(V, this), S)); - return S; -} -/// getIntegerSCEV - Given a SCEVable type, create a constant for the -/// specified signed integer value and return a SCEV for the constant. -const SCEV *ScalarEvolution::getIntegerSCEV(int64_t Val, const Type *Ty) { - const IntegerType *ITy = cast(getEffectiveSCEVType(Ty)); - return getConstant(ConstantInt::get(ITy, Val)); + // The process of creating a SCEV for V may have caused other SCEVs + // to have been created, so it's necessary to insert the new entry + // from scratch, rather than trying to remember the insert position + // above. + ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S)); + return S; } /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V @@ -2412,7 +2628,7 @@ const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V) { return getConstant( cast(ConstantExpr::getNeg(VC->getValue()))); - const Type *Ty = V->getType(); + Type *Ty = V->getType(); Ty = getEffectiveSCEVType(Ty); return getMulExpr(V, getConstant(cast(Constant::getAllOnesValue(Ty)))); @@ -2424,28 +2640,32 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { return getConstant( cast(ConstantExpr::getNot(VC->getValue()))); - const Type *Ty = V->getType(); + Type *Ty = V->getType(); Ty = getEffectiveSCEVType(Ty); const SCEV *AllOnes = getConstant(cast(Constant::getAllOnesValue(Ty))); return getMinusSCEV(AllOnes, V); } -/// getMinusSCEV - Return a SCEV corresponding to LHS - RHS. -/// -const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, - const SCEV *RHS) { +/// getMinusSCEV - Return LHS-RHS. Minus is represented in SCEV as A+B*-1. +const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, + SCEV::NoWrapFlags Flags) { + assert(!maskFlags(Flags, SCEV::FlagNUW) && "subtraction does not have NUW"); + + // Fast path: X - X --> 0. + if (LHS == RHS) + return getConstant(LHS->getType(), 0); + // X - Y --> X + -Y - return getAddExpr(LHS, getNegativeSCEV(RHS)); + return getAddExpr(LHS, getNegativeSCEV(RHS), Flags); } /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the /// input value to the specified type. If the type must be extended, it is zero /// extended. const SCEV * -ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, - const Type *Ty) { - const Type *SrcTy = V->getType(); +ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) { + Type *SrcTy = V->getType(); assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot truncate or zero extend with non-integer arguments!"); @@ -2461,8 +2681,8 @@ ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, /// extended. const SCEV * ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, - const Type *Ty) { - const Type *SrcTy = V->getType(); + Type *Ty) { + Type *SrcTy = V->getType(); assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot truncate or zero extend with non-integer arguments!"); @@ -2477,8 +2697,8 @@ ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, /// input value to the specified type. If the type must be extended, it is zero /// extended. The conversion must not be narrowing. const SCEV * -ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, const Type *Ty) { - const Type *SrcTy = V->getType(); +ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) { + Type *SrcTy = V->getType(); assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot noop or zero extend with non-integer arguments!"); @@ -2493,8 +2713,8 @@ ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, const Type *Ty) { /// input value to the specified type. If the type must be extended, it is sign /// extended. The conversion must not be narrowing. const SCEV * -ScalarEvolution::getNoopOrSignExtend(const SCEV *V, const Type *Ty) { - const Type *SrcTy = V->getType(); +ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) { + Type *SrcTy = V->getType(); assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot noop or sign extend with non-integer arguments!"); @@ -2510,8 +2730,8 @@ ScalarEvolution::getNoopOrSignExtend(const SCEV *V, const Type *Ty) { /// it is extended with unspecified bits. The conversion must not be /// narrowing. const SCEV * -ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, const Type *Ty) { - const Type *SrcTy = V->getType(); +ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) { + Type *SrcTy = V->getType(); assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot noop or any extend with non-integer arguments!"); @@ -2525,8 +2745,8 @@ ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, const Type *Ty) { /// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the /// input value to the specified type. The conversion must not be widening. const SCEV * -ScalarEvolution::getTruncateOrNoop(const SCEV *V, const Type *Ty) { - const Type *SrcTy = V->getType(); +ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) { + Type *SrcTy = V->getType(); assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot truncate or noop with non-integer arguments!"); @@ -2569,6 +2789,36 @@ const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS, return getUMinExpr(PromotedLHS, PromotedRHS); } +/// getPointerBase - Transitively follow the chain of pointer-type operands +/// until reaching a SCEV that does not have a single pointer operand. This +/// returns a SCEVUnknown pointer for well-formed pointer-type expressions, +/// but corner cases do exist. +const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { + // A pointer operand may evaluate to a nonpointer expression, such as null. + if (!V->getType()->isPointerTy()) + return V; + + if (const SCEVCastExpr *Cast = dyn_cast(V)) { + return getPointerBase(Cast->getOperand()); + } + else if (const SCEVNAryExpr *NAry = dyn_cast(V)) { + const SCEV *PtrOp = 0; + for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); + I != E; ++I) { + if ((*I)->getType()->isPointerTy()) { + // Cannot find the base of an expression with multiple pointer operands. + if (PtrOp) + return V; + PtrOp = *I; + } + } + if (!PtrOp) + return V; + return getPointerBase(PtrOp); + } + return V; +} + /// PushDefUseChildren - Push users of the given Instruction /// onto the given Worklist. static void @@ -2577,12 +2827,12 @@ PushDefUseChildren(Instruction *I, // Push the def-use children onto the Worklist stack. for (Value::use_iterator UI = I->use_begin(), UE = I->use_end(); UI != UE; ++UI) - Worklist.push_back(cast(UI)); + Worklist.push_back(cast(*UI)); } /// ForgetSymbolicValue - This looks up computed SCEV values for all /// instructions that depend on the given instruction and removes them from -/// the Scalars map if they reference SymName. This is used during PHI +/// the ValueExprMapType map if they reference SymName. This is used during PHI /// resolution. void ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { @@ -2595,12 +2845,14 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { Instruction *I = Worklist.pop_back_val(); if (!Visited.insert(I)) continue; - std::map::iterator It = - Scalars.find(static_cast(I)); - if (It != Scalars.end()) { + ValueExprMapType::iterator It = + ValueExprMap.find(static_cast(I)); + if (It != ValueExprMap.end()) { + const SCEV *Old = It->second; + // Short-circuit the def-use traversal if the symbolic name // ceases to appear in expressions. - if (It->second != SymName && !It->second->hasOperand(SymName)) + if (Old != SymName && !hasOperand(Old, SymName)) continue; // SCEVUnknown for a PHI either means that it has an unrecognized @@ -2611,10 +2863,10 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { // updates on its own when it gets to that point. In the third, we do // want to forget the SCEVUnknown. if (!isa(I) || - !isa(It->second) || - (I != PN && It->second == SymName)) { - ValuesAtScopes.erase(It->second); - Scalars.erase(It); + !isa(Old) || + (I != PN && Old == SymName)) { + forgetMemoizedResults(Old); + ValueExprMap.erase(It); } } @@ -2651,9 +2903,9 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { if (BEValueV && StartValueV) { // While we are analyzing this PHI node, handle its value symbolically. const SCEV *SymbolicName = getUnknown(PN); - assert(Scalars.find(PN) == Scalars.end() && + assert(ValueExprMap.find(PN) == ValueExprMap.end() && "PHI node already processed?"); - Scalars.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName)); + ValueExprMap.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName)); // Using this symbolic name for the PHI, analyze the value coming around // the back-edge. @@ -2685,36 +2937,43 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // This is not a valid addrec if the step amount is varying each // loop iteration, but is not itself an addrec in this loop. - if (Accum->isLoopInvariant(L) || + if (isLoopInvariant(Accum, L) || (isa(Accum) && cast(Accum)->getLoop() == L)) { - bool HasNUW = false; - bool HasNSW = false; + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; // If the increment doesn't overflow, then neither the addrec nor // the post-increment will overflow. if (const AddOperator *OBO = dyn_cast(BEValueV)) { if (OBO->hasNoUnsignedWrap()) - HasNUW = true; + Flags = setFlags(Flags, SCEV::FlagNUW); if (OBO->hasNoSignedWrap()) - HasNSW = true; + Flags = setFlags(Flags, SCEV::FlagNSW); + } else if (const GEPOperator *GEP = + dyn_cast(BEValueV)) { + // If the increment is an inbounds GEP, then we know the address + // space cannot be wrapped around. We cannot make any guarantee + // about signed or unsigned overflow because pointers are + // unsigned but we may have a negative index from the base + // pointer. + if (GEP->isInBounds()) + Flags = setFlags(Flags, SCEV::FlagNW); } const SCEV *StartVal = getSCEV(StartValueV); - const SCEV *PHISCEV = - getAddRecExpr(StartVal, Accum, L, HasNUW, HasNSW); + const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); // Since the no-wrap flags are on the increment, they apply to the // post-incremented value as well. - if (Accum->isLoopInvariant(L)) + if (isLoopInvariant(Accum, L)) (void)getAddRecExpr(getAddExpr(StartVal, Accum), - Accum, L, HasNUW, HasNSW); + Accum, L, Flags); // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and purge all of the // entries for the scalars that use the symbolic expression. ForgetSymbolicName(PN, SymbolicName); - Scalars[SCEVCallbackVH(PN, this)] = PHISCEV; + ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; return PHISCEV; } } @@ -2732,14 +2991,17 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // initial step of the addrec evolution. if (StartVal == getMinusSCEV(AddRec->getOperand(0), AddRec->getOperand(1))) { + // FIXME: For constant StartVal, we should be able to infer + // no-wrap flags. const SCEV *PHISCEV = - getAddRecExpr(StartVal, AddRec->getOperand(1), L); + getAddRecExpr(StartVal, AddRec->getOperand(1), L, + SCEV::FlagAnyWrap); // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and purge all of the // entries for the scalars that use the symbolic expression. ForgetSymbolicName(PN, SymbolicName); - Scalars[SCEVCallbackVH(PN, this)] = PHISCEV; + ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; return PHISCEV; } } @@ -2751,17 +3013,9 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // PHI's incoming blocks are in a different loop, in which case doing so // risks breaking LCSSA form. Instcombine would normally zap these, but // it doesn't have DominatorTree information, so it may miss cases. - if (Value *V = PN->hasConstantValue(DT)) { - bool AllSameLoop = true; - Loop *PNLoop = LI->getLoopFor(PN->getParent()); - for (size_t i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (LI->getLoopFor(PN->getIncomingBlock(i)) != PNLoop) { - AllSameLoop = false; - break; - } - if (AllSameLoop) + if (Value *V = SimplifyInstruction(PN, TD, DT)) + if (LI->replacementPreservesLCSSAForm(PN, V)) return getSCEV(V); - } // If it's not a loop phi, we can't handle it yet. return getUnknown(PN); @@ -2772,39 +3026,54 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { /// const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { - bool InBounds = GEP->isInBounds(); - const Type *IntPtrTy = getEffectiveSCEVType(GEP->getType()); + // Don't blindly transfer the inbounds flag from the GEP instruction to the + // Add expression, because the Instruction may be guarded by control flow + // and the no-overflow bits may not be valid for the expression in any + // context. + bool isInBounds = GEP->isInBounds(); + + Type *IntPtrTy = getEffectiveSCEVType(GEP->getType()); Value *Base = GEP->getOperand(0); // Don't attempt to analyze GEPs over unsized objects. if (!cast(Base->getType())->getElementType()->isSized()) return getUnknown(GEP); - const SCEV *TotalOffset = getIntegerSCEV(0, IntPtrTy); + const SCEV *TotalOffset = getConstant(IntPtrTy, 0); gep_type_iterator GTI = gep_type_begin(GEP); - for (GetElementPtrInst::op_iterator I = next(GEP->op_begin()), + for (GetElementPtrInst::op_iterator I = llvm::next(GEP->op_begin()), E = GEP->op_end(); I != E; ++I) { Value *Index = *I; // Compute the (potentially symbolic) offset in bytes for this index. - if (const StructType *STy = dyn_cast(*GTI++)) { + if (StructType *STy = dyn_cast(*GTI++)) { // For a struct, add the member offset. unsigned FieldNo = cast(Index)->getZExtValue(); - TotalOffset = getAddExpr(TotalOffset, - getOffsetOfExpr(STy, FieldNo), - /*HasNUW=*/false, /*HasNSW=*/InBounds); + const SCEV *FieldOffset = getOffsetOfExpr(STy, FieldNo); + + // Add the field offset to the running total offset. + TotalOffset = getAddExpr(TotalOffset, FieldOffset); } else { // For an array, add the element offset, explicitly scaled. - const SCEV *LocalOffset = getSCEV(Index); + const SCEV *ElementSize = getSizeOfExpr(*GTI); + const SCEV *IndexS = getSCEV(Index); // Getelementptr indices are signed. - LocalOffset = getTruncateOrSignExtend(LocalOffset, IntPtrTy); - // Lower "inbounds" GEPs to NSW arithmetic. - LocalOffset = getMulExpr(LocalOffset, getSizeOfExpr(*GTI), - /*HasNUW=*/false, /*HasNSW=*/InBounds); - TotalOffset = getAddExpr(TotalOffset, LocalOffset, - /*HasNUW=*/false, /*HasNSW=*/InBounds); + IndexS = getTruncateOrSignExtend(IndexS, IntPtrTy); + + // Multiply the index by the element size to compute the element offset. + const SCEV *LocalOffset = getMulExpr(IndexS, ElementSize, + isInBounds ? SCEV::FlagNSW : + SCEV::FlagAnyWrap); + + // Add the element offset to the running total offset. + TotalOffset = getAddExpr(TotalOffset, LocalOffset); } } - return getAddExpr(getSCEV(Base), TotalOffset, - /*HasNUW=*/false, /*HasNSW=*/InBounds); + + // Get the SCEV for the GEP base. + const SCEV *BaseS = getSCEV(Base); + + // Add the total offset from all the GEP indices to the base. + return getAddExpr(BaseS, TotalOffset, + isInBounds ? SCEV::FlagNSW : SCEV::FlagAnyWrap); } /// GetMinTrailingZeros - Determine the minimum number of zero bits that S is @@ -2892,9 +3161,13 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { /// ConstantRange ScalarEvolution::getUnsignedRange(const SCEV *S) { + // See if we've computed this range already. + DenseMap::iterator I = UnsignedRanges.find(S); + if (I != UnsignedRanges.end()) + return I->second; if (const SCEVConstant *C = dyn_cast(S)) - return ConstantRange(C->getValue()->getValue()); + return setUnsignedRange(C, ConstantRange(C->getValue()->getValue())); unsigned BitWidth = getTypeSizeInBits(S->getType()); ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); @@ -2911,63 +3184,67 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) { ConstantRange X = getUnsignedRange(Add->getOperand(0)); for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) X = X.add(getUnsignedRange(Add->getOperand(i))); - return ConservativeResult.intersectWith(X); + return setUnsignedRange(Add, ConservativeResult.intersectWith(X)); } if (const SCEVMulExpr *Mul = dyn_cast(S)) { ConstantRange X = getUnsignedRange(Mul->getOperand(0)); for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) X = X.multiply(getUnsignedRange(Mul->getOperand(i))); - return ConservativeResult.intersectWith(X); + return setUnsignedRange(Mul, ConservativeResult.intersectWith(X)); } if (const SCEVSMaxExpr *SMax = dyn_cast(S)) { ConstantRange X = getUnsignedRange(SMax->getOperand(0)); for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) X = X.smax(getUnsignedRange(SMax->getOperand(i))); - return ConservativeResult.intersectWith(X); + return setUnsignedRange(SMax, ConservativeResult.intersectWith(X)); } if (const SCEVUMaxExpr *UMax = dyn_cast(S)) { ConstantRange X = getUnsignedRange(UMax->getOperand(0)); for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) X = X.umax(getUnsignedRange(UMax->getOperand(i))); - return ConservativeResult.intersectWith(X); + return setUnsignedRange(UMax, ConservativeResult.intersectWith(X)); } if (const SCEVUDivExpr *UDiv = dyn_cast(S)) { ConstantRange X = getUnsignedRange(UDiv->getLHS()); ConstantRange Y = getUnsignedRange(UDiv->getRHS()); - return ConservativeResult.intersectWith(X.udiv(Y)); + return setUnsignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y))); } if (const SCEVZeroExtendExpr *ZExt = dyn_cast(S)) { ConstantRange X = getUnsignedRange(ZExt->getOperand()); - return ConservativeResult.intersectWith(X.zeroExtend(BitWidth)); + return setUnsignedRange(ZExt, + ConservativeResult.intersectWith(X.zeroExtend(BitWidth))); } if (const SCEVSignExtendExpr *SExt = dyn_cast(S)) { ConstantRange X = getUnsignedRange(SExt->getOperand()); - return ConservativeResult.intersectWith(X.signExtend(BitWidth)); + return setUnsignedRange(SExt, + ConservativeResult.intersectWith(X.signExtend(BitWidth))); } if (const SCEVTruncateExpr *Trunc = dyn_cast(S)) { ConstantRange X = getUnsignedRange(Trunc->getOperand()); - return ConservativeResult.intersectWith(X.truncate(BitWidth)); + return setUnsignedRange(Trunc, + ConservativeResult.intersectWith(X.truncate(BitWidth))); } if (const SCEVAddRecExpr *AddRec = dyn_cast(S)) { // If there's no unsigned wrap, the value will never be less than its // initial value. - if (AddRec->hasNoUnsignedWrap()) + if (AddRec->getNoWrapFlags(SCEV::FlagNUW)) if (const SCEVConstant *C = dyn_cast(AddRec->getStart())) if (!C->getValue()->isZero()) ConservativeResult = - ConstantRange(C->getValue()->getValue(), APInt(BitWidth, 0)); + ConservativeResult.intersectWith( + ConstantRange(C->getValue()->getValue(), APInt(BitWidth, 0))); // TODO: non-affine addrec if (AddRec->isAffine()) { - const Type *Ty = AddRec->getType(); + Type *Ty = AddRec->getType(); const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); if (!isa(MaxBECount) && getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { @@ -2992,19 +3269,20 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) { ConstantRange ExtEndRange = EndRange.zextOrTrunc(BitWidth*2+1); if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) != ExtEndRange) - return ConservativeResult; + return setUnsignedRange(AddRec, ConservativeResult); APInt Min = APIntOps::umin(StartRange.getUnsignedMin(), EndRange.getUnsignedMin()); APInt Max = APIntOps::umax(StartRange.getUnsignedMax(), EndRange.getUnsignedMax()); if (Min.isMinValue() && Max.isMaxValue()) - return ConservativeResult; - return ConservativeResult.intersectWith(ConstantRange(Min, Max+1)); + return setUnsignedRange(AddRec, ConservativeResult); + return setUnsignedRange(AddRec, + ConservativeResult.intersectWith(ConstantRange(Min, Max+1))); } } - return ConservativeResult; + return setUnsignedRange(AddRec, ConservativeResult); } if (const SCEVUnknown *U = dyn_cast(S)) { @@ -3013,20 +3291,25 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) { APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD); if (Ones == ~Zeros + 1) - return ConservativeResult; - return ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1)); + return setUnsignedRange(U, ConservativeResult); + return setUnsignedRange(U, + ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1))); } - return ConservativeResult; + return setUnsignedRange(S, ConservativeResult); } /// getSignedRange - Determine the signed range for a particular SCEV. /// ConstantRange ScalarEvolution::getSignedRange(const SCEV *S) { + // See if we've computed this range already. + DenseMap::iterator I = SignedRanges.find(S); + if (I != SignedRanges.end()) + return I->second; if (const SCEVConstant *C = dyn_cast(S)) - return ConstantRange(C->getValue()->getValue()); + return setSignedRange(C, ConstantRange(C->getValue()->getValue())); unsigned BitWidth = getTypeSizeInBits(S->getType()); ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); @@ -3043,55 +3326,58 @@ ScalarEvolution::getSignedRange(const SCEV *S) { ConstantRange X = getSignedRange(Add->getOperand(0)); for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) X = X.add(getSignedRange(Add->getOperand(i))); - return ConservativeResult.intersectWith(X); + return setSignedRange(Add, ConservativeResult.intersectWith(X)); } if (const SCEVMulExpr *Mul = dyn_cast(S)) { ConstantRange X = getSignedRange(Mul->getOperand(0)); for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) X = X.multiply(getSignedRange(Mul->getOperand(i))); - return ConservativeResult.intersectWith(X); + return setSignedRange(Mul, ConservativeResult.intersectWith(X)); } if (const SCEVSMaxExpr *SMax = dyn_cast(S)) { ConstantRange X = getSignedRange(SMax->getOperand(0)); for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) X = X.smax(getSignedRange(SMax->getOperand(i))); - return ConservativeResult.intersectWith(X); + return setSignedRange(SMax, ConservativeResult.intersectWith(X)); } if (const SCEVUMaxExpr *UMax = dyn_cast(S)) { ConstantRange X = getSignedRange(UMax->getOperand(0)); for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) X = X.umax(getSignedRange(UMax->getOperand(i))); - return ConservativeResult.intersectWith(X); + return setSignedRange(UMax, ConservativeResult.intersectWith(X)); } if (const SCEVUDivExpr *UDiv = dyn_cast(S)) { ConstantRange X = getSignedRange(UDiv->getLHS()); ConstantRange Y = getSignedRange(UDiv->getRHS()); - return ConservativeResult.intersectWith(X.udiv(Y)); + return setSignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y))); } if (const SCEVZeroExtendExpr *ZExt = dyn_cast(S)) { ConstantRange X = getSignedRange(ZExt->getOperand()); - return ConservativeResult.intersectWith(X.zeroExtend(BitWidth)); + return setSignedRange(ZExt, + ConservativeResult.intersectWith(X.zeroExtend(BitWidth))); } if (const SCEVSignExtendExpr *SExt = dyn_cast(S)) { ConstantRange X = getSignedRange(SExt->getOperand()); - return ConservativeResult.intersectWith(X.signExtend(BitWidth)); + return setSignedRange(SExt, + ConservativeResult.intersectWith(X.signExtend(BitWidth))); } if (const SCEVTruncateExpr *Trunc = dyn_cast(S)) { ConstantRange X = getSignedRange(Trunc->getOperand()); - return ConservativeResult.intersectWith(X.truncate(BitWidth)); + return setSignedRange(Trunc, + ConservativeResult.intersectWith(X.truncate(BitWidth))); } if (const SCEVAddRecExpr *AddRec = dyn_cast(S)) { // If there's no signed wrap, and all the operands have the same sign or // zero, the value won't ever change sign. - if (AddRec->hasNoSignedWrap()) { + if (AddRec->getNoWrapFlags(SCEV::FlagNSW)) { bool AllNonNeg = true; bool AllNonPos = true; for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { @@ -3110,7 +3396,7 @@ ScalarEvolution::getSignedRange(const SCEV *S) { // TODO: non-affine addrec if (AddRec->isAffine()) { - const Type *Ty = AddRec->getType(); + Type *Ty = AddRec->getType(); const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); if (!isa(MaxBECount) && getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { @@ -3135,34 +3421,35 @@ ScalarEvolution::getSignedRange(const SCEV *S) { ConstantRange ExtEndRange = EndRange.sextOrTrunc(BitWidth*2+1); if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) != ExtEndRange) - return ConservativeResult; + return setSignedRange(AddRec, ConservativeResult); APInt Min = APIntOps::smin(StartRange.getSignedMin(), EndRange.getSignedMin()); APInt Max = APIntOps::smax(StartRange.getSignedMax(), EndRange.getSignedMax()); if (Min.isMinSignedValue() && Max.isMaxSignedValue()) - return ConservativeResult; - return ConservativeResult.intersectWith(ConstantRange(Min, Max+1)); + return setSignedRange(AddRec, ConservativeResult); + return setSignedRange(AddRec, + ConservativeResult.intersectWith(ConstantRange(Min, Max+1))); } } - return ConservativeResult; + return setSignedRange(AddRec, ConservativeResult); } if (const SCEVUnknown *U = dyn_cast(S)) { // For a SCEVUnknown, ask ValueTracking. if (!U->getValue()->getType()->isIntegerTy() && !TD) - return ConservativeResult; + return setSignedRange(U, ConservativeResult); unsigned NS = ComputeNumSignBits(U->getValue(), TD); if (NS == 1) - return ConservativeResult; - return ConservativeResult.intersectWith( + return setSignedRange(U, ConservativeResult); + return setSignedRange(U, ConservativeResult.intersectWith( ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1), - APInt::getSignedMaxValue(BitWidth).ashr(NS - 1)+1)); + APInt::getSignedMaxValue(BitWidth).ashr(NS - 1)+1))); } - return ConservativeResult; + return setSignedRange(S, ConservativeResult); } /// createSCEV - We know that there is no SCEV for the specified value. @@ -3187,7 +3474,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { else if (ConstantInt *CI = dyn_cast(V)) return getConstant(CI); else if (isa(V)) - return getIntegerSCEV(0, V->getType()); + return getConstant(V->getType(), 0); else if (GlobalAlias *GA = dyn_cast(V)) return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee()); else @@ -3195,18 +3482,42 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { Operator *U = cast(V); switch (Opcode) { - case Instruction::Add: - // Don't transfer the NSW and NUW bits from the Add instruction to the - // Add expression, because the Instruction may be guarded by control - // flow and the no-overflow bits may not be valid for the expression in - // any context. - return getAddExpr(getSCEV(U->getOperand(0)), - getSCEV(U->getOperand(1))); - case Instruction::Mul: - // Don't transfer the NSW and NUW bits from the Mul instruction to the - // Mul expression, as with Add. - return getMulExpr(getSCEV(U->getOperand(0)), - getSCEV(U->getOperand(1))); + case Instruction::Add: { + // The simple thing to do would be to just call getSCEV on both operands + // and call getAddExpr with the result. However if we're looking at a + // bunch of things all added together, this can be quite inefficient, + // because it leads to N-1 getAddExpr calls for N ultimate operands. + // Instead, gather up all the operands and make a single getAddExpr call. + // LLVM IR canonical form means we need only traverse the left operands. + SmallVector AddOps; + AddOps.push_back(getSCEV(U->getOperand(1))); + for (Value *Op = U->getOperand(0); ; Op = U->getOperand(0)) { + unsigned Opcode = Op->getValueID() - Value::InstructionVal; + if (Opcode != Instruction::Add && Opcode != Instruction::Sub) + break; + U = cast(Op); + const SCEV *Op1 = getSCEV(U->getOperand(1)); + if (Opcode == Instruction::Sub) + AddOps.push_back(getNegativeSCEV(Op1)); + else + AddOps.push_back(Op1); + } + AddOps.push_back(getSCEV(U->getOperand(0))); + return getAddExpr(AddOps); + } + case Instruction::Mul: { + // See the Add code above. + SmallVector MulOps; + MulOps.push_back(getSCEV(U->getOperand(1))); + for (Value *Op = U->getOperand(0); + Op->getValueID() == Instruction::Mul + Value::InstructionVal; + Op = U->getOperand(0)) { + U = cast(Op); + MulOps.push_back(getSCEV(U->getOperand(1))); + } + MulOps.push_back(getSCEV(U->getOperand(0))); + return getMulExpr(MulOps); + } case Instruction::UDiv: return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1))); @@ -3261,10 +3572,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // transfer the no-wrap flags, since an or won't introduce a wrap. if (const SCEVAddRecExpr *NewAR = dyn_cast(S)) { const SCEVAddRecExpr *OldAR = cast(LHS); - if (OldAR->hasNoUnsignedWrap()) - const_cast(NewAR)->setHasNoUnsignedWrap(true); - if (OldAR->hasNoSignedWrap()) - const_cast(NewAR)->setHasNoSignedWrap(true); + const_cast(NewAR)->setNoWrapFlags( + OldAR->getNoWrapFlags()); } return S; } @@ -3292,9 +3601,9 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { LCI->getValue() == CI->getValue()) if (const SCEVZeroExtendExpr *Z = dyn_cast(getSCEV(U->getOperand(0)))) { - const Type *UTy = U->getType(); + Type *UTy = U->getType(); const SCEV *Z0 = Z->getOperand(); - const Type *Z0Ty = Z0->getType(); + Type *Z0Ty = Z0->getType(); unsigned Z0TySize = getTypeSizeInBits(Z0Ty); // If C is a low-bits mask, the zero extend is serving to @@ -3306,8 +3615,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // If C is a single bit, it may be in the sign-bit position // before the zero-extend. In this case, represent the xor // using an add, which is equivalent, and re-apply the zext. - APInt Trunc = APInt(CI->getValue()).trunc(Z0TySize); - if (APInt(Trunc).zext(getTypeSizeInBits(UTy)) == CI->getValue() && + APInt Trunc = CI->getValue().trunc(Z0TySize); + if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() && Trunc.isSignBit()) return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)), UTy); @@ -3468,7 +3777,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { const SCEV *LDiff = getMinusSCEV(LA, LS); const SCEV *RDiff = getMinusSCEV(RA, One); if (LDiff == RDiff) - return getAddExpr(getUMaxExpr(LS, One), LDiff); + return getAddExpr(getUMaxExpr(One, LS), LDiff); } break; case ICmpInst::ICMP_EQ: @@ -3483,7 +3792,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { const SCEV *LDiff = getMinusSCEV(LA, One); const SCEV *RDiff = getMinusSCEV(RA, LS); if (LDiff == RDiff) - return getAddExpr(getUMaxExpr(LS, One), LDiff); + return getAddExpr(getUMaxExpr(One, LS), LDiff); } break; default: @@ -3504,6 +3813,13 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // Iteration Count Computation Code // +// getExitCount - Get the expression for the number of loop iterations for which +// this loop is guaranteed not to exit via ExitintBlock. Otherwise return +// SCEVCouldNotCompute. +const SCEV *ScalarEvolution::getExitCount(Loop *L, BasicBlock *ExitingBlock) { + return getBackedgeTakenInfo(L).getExact(ExitingBlock, this); +} + /// getBackedgeTakenCount - If the specified loop has a predictable /// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute /// object. The backedge-taken count is the number of times the loop header @@ -3516,14 +3832,14 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { /// hasLoopInvariantBackedgeTakenCount). /// const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) { - return getBackedgeTakenInfo(L).Exact; + return getBackedgeTakenInfo(L).getExact(this); } /// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except /// return the least SCEV value that is known never to be less than the /// actual backedge taken count. const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) { - return getBackedgeTakenInfo(L).Max; + return getBackedgeTakenInfo(L).getMax(this); } /// PushLoopPHIs - Push PHI nodes in the header of the given loop @@ -3540,68 +3856,76 @@ PushLoopPHIs(const Loop *L, SmallVectorImpl &Worklist) { const ScalarEvolution::BackedgeTakenInfo & ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { - // Initially insert a CouldNotCompute for this loop. If the insertion + // Initially insert an invalid entry for this loop. If the insertion // succeeds, proceed to actually compute a backedge-taken count and // update the value. The temporary CouldNotCompute value tells SCEV // code elsewhere that it shouldn't attempt to request a new // backedge-taken count, which could result in infinite recursion. - std::pair::iterator, bool> Pair = - BackedgeTakenCounts.insert(std::make_pair(L, getCouldNotCompute())); - if (Pair.second) { - BackedgeTakenInfo BECount = ComputeBackedgeTakenCount(L); - if (BECount.Exact != getCouldNotCompute()) { - assert(BECount.Exact->isLoopInvariant(L) && - BECount.Max->isLoopInvariant(L) && - "Computed backedge-taken count isn't loop invariant for loop!"); - ++NumTripCountsComputed; - - // Update the value in the map. - Pair.first->second = BECount; - } else { - if (BECount.Max != getCouldNotCompute()) - // Update the value in the map. - Pair.first->second = BECount; - if (isa(L->getHeader()->begin())) - // Only count loops that have phi nodes as not being computable. - ++NumTripCountsNotComputed; - } - - // Now that we know more about the trip count for this loop, forget any - // existing SCEV values for PHI nodes in this loop since they are only - // conservative estimates made without the benefit of trip count - // information. This is similar to the code in forgetLoop, except that - // it handles SCEVUnknown PHI nodes specially. - if (BECount.hasAnyInfo()) { - SmallVector Worklist; - PushLoopPHIs(L, Worklist); - - SmallPtrSet Visited; - while (!Worklist.empty()) { - Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; - - std::map::iterator It = - Scalars.find(static_cast(I)); - if (It != Scalars.end()) { - // SCEVUnknown for a PHI either means that it has an unrecognized - // structure, or it's a PHI that's in the progress of being computed - // by createNodeForPHI. In the former case, additional loop trip - // count information isn't going to change anything. In the later - // case, createNodeForPHI will perform the necessary updates on its - // own when it gets to that point. - if (!isa(I) || !isa(It->second)) { - ValuesAtScopes.erase(It->second); - Scalars.erase(It); - } - if (PHINode *PN = dyn_cast(I)) - ConstantEvolutionLoopExitValue.erase(PN); + std::pair::iterator, bool> Pair = + BackedgeTakenCounts.insert(std::make_pair(L, BackedgeTakenInfo())); + if (!Pair.second) + return Pair.first->second; + + // ComputeBackedgeTakenCount may allocate memory for its result. Inserting it + // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result + // must be cleared in this scope. + BackedgeTakenInfo Result = ComputeBackedgeTakenCount(L); + + if (Result.getExact(this) != getCouldNotCompute()) { + assert(isLoopInvariant(Result.getExact(this), L) && + isLoopInvariant(Result.getMax(this), L) && + "Computed backedge-taken count isn't loop invariant for loop!"); + ++NumTripCountsComputed; + } + else if (Result.getMax(this) == getCouldNotCompute() && + isa(L->getHeader()->begin())) { + // Only count loops that have phi nodes as not being computable. + ++NumTripCountsNotComputed; + } + + // Now that we know more about the trip count for this loop, forget any + // existing SCEV values for PHI nodes in this loop since they are only + // conservative estimates made without the benefit of trip count + // information. This is similar to the code in forgetLoop, except that + // it handles SCEVUnknown PHI nodes specially. + if (Result.hasAnyInfo()) { + SmallVector Worklist; + PushLoopPHIs(L, Worklist); + + SmallPtrSet Visited; + while (!Worklist.empty()) { + Instruction *I = Worklist.pop_back_val(); + if (!Visited.insert(I)) continue; + + ValueExprMapType::iterator It = + ValueExprMap.find(static_cast(I)); + if (It != ValueExprMap.end()) { + const SCEV *Old = It->second; + + // SCEVUnknown for a PHI either means that it has an unrecognized + // structure, or it's a PHI that's in the progress of being computed + // by createNodeForPHI. In the former case, additional loop trip + // count information isn't going to change anything. In the later + // case, createNodeForPHI will perform the necessary updates on its + // own when it gets to that point. + if (!isa(I) || !isa(Old)) { + forgetMemoizedResults(Old); + ValueExprMap.erase(It); } - - PushDefUseChildren(I, Worklist); + if (PHINode *PN = dyn_cast(I)) + ConstantEvolutionLoopExitValue.erase(PN); } + + PushDefUseChildren(I, Worklist); } } - return Pair.first->second; + + // Re-lookup the insert position, since the call to + // ComputeBackedgeTakenCount above could result in a + // recusive call to getBackedgeTakenInfo (on a different + // loop), which would invalidate the iterator computed + // earlier. + return BackedgeTakenCounts.find(L)->second = Result; } /// forgetLoop - This method should be called by the client when it has @@ -3609,7 +3933,12 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { /// compute a trip count, or if the loop is deleted. void ScalarEvolution::forgetLoop(const Loop *L) { // Drop any stored trip count value. - BackedgeTakenCounts.erase(L); + DenseMap::iterator BTCPos = + BackedgeTakenCounts.find(L); + if (BTCPos != BackedgeTakenCounts.end()) { + BTCPos->second.clear(); + BackedgeTakenCounts.erase(BTCPos); + } // Drop information about expressions based on loop-header PHIs. SmallVector Worklist; @@ -3620,17 +3949,21 @@ void ScalarEvolution::forgetLoop(const Loop *L) { Instruction *I = Worklist.pop_back_val(); if (!Visited.insert(I)) continue; - std::map::iterator It = - Scalars.find(static_cast(I)); - if (It != Scalars.end()) { - ValuesAtScopes.erase(It->second); - Scalars.erase(It); + ValueExprMapType::iterator It = ValueExprMap.find(static_cast(I)); + if (It != ValueExprMap.end()) { + forgetMemoizedResults(It->second); + ValueExprMap.erase(It); if (PHINode *PN = dyn_cast(I)) ConstantEvolutionLoopExitValue.erase(PN); } PushDefUseChildren(I, Worklist); } + + // Forget all contained loops too, to avoid dangling entries in the + // ValuesAtScopes map. + for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I) + forgetLoop(*I); } /// forgetValue - This method should be called by the client when it has @@ -3649,11 +3982,10 @@ void ScalarEvolution::forgetValue(Value *V) { I = Worklist.pop_back_val(); if (!Visited.insert(I)) continue; - std::map::iterator It = - Scalars.find(static_cast(I)); - if (It != Scalars.end()) { - ValuesAtScopes.erase(It->second); - Scalars.erase(It); + ValueExprMapType::iterator It = ValueExprMap.find(static_cast(I)); + if (It != ValueExprMap.end()) { + forgetMemoizedResults(It->second); + ValueExprMap.erase(It); if (PHINode *PN = dyn_cast(I)) ConstantEvolutionLoopExitValue.erase(PN); } @@ -3662,6 +3994,84 @@ void ScalarEvolution::forgetValue(Value *V) { } } +/// getExact - Get the exact loop backedge taken count considering all loop +/// exits. If all exits are computable, this is the minimum computed count. +const SCEV * +ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const { + // If any exits were not computable, the loop is not computable. + if (!ExitNotTaken.isCompleteList()) return SE->getCouldNotCompute(); + + // We need at least one computable exit. + if (!ExitNotTaken.ExitingBlock) return SE->getCouldNotCompute(); + assert(ExitNotTaken.ExactNotTaken && "uninitialized not-taken info"); + + const SCEV *BECount = 0; + for (const ExitNotTakenInfo *ENT = &ExitNotTaken; + ENT != 0; ENT = ENT->getNextExit()) { + + assert(ENT->ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV"); + + if (!BECount) + BECount = ENT->ExactNotTaken; + else + BECount = SE->getUMinFromMismatchedTypes(BECount, ENT->ExactNotTaken); + } + return BECount; +} + +/// getExact - Get the exact not taken count for this loop exit. +const SCEV * +ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock, + ScalarEvolution *SE) const { + for (const ExitNotTakenInfo *ENT = &ExitNotTaken; + ENT != 0; ENT = ENT->getNextExit()) { + + if (ENT->ExitingBlock == ExitingBlock) + return ENT->ExactNotTaken; + } + return SE->getCouldNotCompute(); +} + +/// getMax - Get the max backedge taken count for the loop. +const SCEV * +ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const { + return Max ? Max : SE->getCouldNotCompute(); +} + +/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each +/// computable exit into a persistent ExitNotTakenInfo array. +ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( + SmallVectorImpl< std::pair > &ExitCounts, + bool Complete, const SCEV *MaxCount) : Max(MaxCount) { + + if (!Complete) + ExitNotTaken.setIncomplete(); + + unsigned NumExits = ExitCounts.size(); + if (NumExits == 0) return; + + ExitNotTaken.ExitingBlock = ExitCounts[0].first; + ExitNotTaken.ExactNotTaken = ExitCounts[0].second; + if (NumExits == 1) return; + + // Handle the rare case of multiple computable exits. + ExitNotTakenInfo *ENT = new ExitNotTakenInfo[NumExits-1]; + + ExitNotTakenInfo *PrevENT = &ExitNotTaken; + for (unsigned i = 1; i < NumExits; ++i, PrevENT = ENT, ++ENT) { + PrevENT->setNextExit(ENT); + ENT->ExitingBlock = ExitCounts[i].first; + ENT->ExactNotTaken = ExitCounts[i].second; + } +} + +/// clear - Invalidate this result and free the ExitNotTakenInfo array. +void ScalarEvolution::BackedgeTakenInfo::clear() { + ExitNotTaken.ExitingBlock = 0; + ExitNotTaken.ExactNotTaken = 0; + delete[] ExitNotTaken.getNextExit(); +} + /// ComputeBackedgeTakenCount - Compute the number of times the backedge /// of the specified loop will execute. ScalarEvolution::BackedgeTakenInfo @@ -3670,38 +4080,31 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { L->getExitingBlocks(ExitingBlocks); // Examine all exits and pick the most conservative values. - const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); - bool CouldNotComputeBECount = false; + bool CouldComputeBECount = true; + SmallVector, 4> ExitCounts; for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { - BackedgeTakenInfo NewBTI = - ComputeBackedgeTakenCountFromExit(L, ExitingBlocks[i]); - - if (NewBTI.Exact == getCouldNotCompute()) { + ExitLimit EL = ComputeExitLimit(L, ExitingBlocks[i]); + if (EL.Exact == getCouldNotCompute()) // We couldn't compute an exact value for this exit, so // we won't be able to compute an exact value for the loop. - CouldNotComputeBECount = true; - BECount = getCouldNotCompute(); - } else if (!CouldNotComputeBECount) { - if (BECount == getCouldNotCompute()) - BECount = NewBTI.Exact; - else - BECount = getUMinFromMismatchedTypes(BECount, NewBTI.Exact); - } + CouldComputeBECount = false; + else + ExitCounts.push_back(std::make_pair(ExitingBlocks[i], EL.Exact)); + if (MaxBECount == getCouldNotCompute()) - MaxBECount = NewBTI.Max; - else if (NewBTI.Max != getCouldNotCompute()) - MaxBECount = getUMinFromMismatchedTypes(MaxBECount, NewBTI.Max); + MaxBECount = EL.Max; + else if (EL.Max != getCouldNotCompute()) + MaxBECount = getUMinFromMismatchedTypes(MaxBECount, EL.Max); } - return BackedgeTakenInfo(BECount, MaxBECount); + return BackedgeTakenInfo(ExitCounts, CouldComputeBECount, MaxBECount); } -/// ComputeBackedgeTakenCountFromExit - Compute the number of times the backedge -/// of the specified loop will execute if it exits via the specified block. -ScalarEvolution::BackedgeTakenInfo -ScalarEvolution::ComputeBackedgeTakenCountFromExit(const Loop *L, - BasicBlock *ExitingBlock) { +/// ComputeExitLimit - Compute the number of times the backedge of the specified +/// loop will execute if it exits via the specified block. +ScalarEvolution::ExitLimit +ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { // Okay, we've chosen an exiting block. See what condition causes us to // exit at this block. @@ -3759,97 +4162,91 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExit(const Loop *L, } // Proceed to the next level to examine the exit condition expression. - return ComputeBackedgeTakenCountFromExitCond(L, ExitBr->getCondition(), - ExitBr->getSuccessor(0), - ExitBr->getSuccessor(1)); + return ComputeExitLimitFromCond(L, ExitBr->getCondition(), + ExitBr->getSuccessor(0), + ExitBr->getSuccessor(1)); } -/// ComputeBackedgeTakenCountFromExitCond - Compute the number of times the +/// ComputeExitLimitFromCond - Compute the number of times the /// backedge of the specified loop will execute if its exit condition /// were a conditional branch of ExitCond, TBB, and FBB. -ScalarEvolution::BackedgeTakenInfo -ScalarEvolution::ComputeBackedgeTakenCountFromExitCond(const Loop *L, - Value *ExitCond, - BasicBlock *TBB, - BasicBlock *FBB) { +ScalarEvolution::ExitLimit +ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, + Value *ExitCond, + BasicBlock *TBB, + BasicBlock *FBB) { // Check if the controlling expression for this loop is an And or Or. if (BinaryOperator *BO = dyn_cast(ExitCond)) { if (BO->getOpcode() == Instruction::And) { // Recurse on the operands of the and. - BackedgeTakenInfo BTI0 = - ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB); - BackedgeTakenInfo BTI1 = - ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB); + ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB); + ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); if (L->contains(TBB)) { // Both conditions must be true for the loop to continue executing. // Choose the less conservative count. - if (BTI0.Exact == getCouldNotCompute() || - BTI1.Exact == getCouldNotCompute()) + if (EL0.Exact == getCouldNotCompute() || + EL1.Exact == getCouldNotCompute()) BECount = getCouldNotCompute(); else - BECount = getUMinFromMismatchedTypes(BTI0.Exact, BTI1.Exact); - if (BTI0.Max == getCouldNotCompute()) - MaxBECount = BTI1.Max; - else if (BTI1.Max == getCouldNotCompute()) - MaxBECount = BTI0.Max; + BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact); + if (EL0.Max == getCouldNotCompute()) + MaxBECount = EL1.Max; + else if (EL1.Max == getCouldNotCompute()) + MaxBECount = EL0.Max; else - MaxBECount = getUMinFromMismatchedTypes(BTI0.Max, BTI1.Max); + MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); } else { - // Both conditions must be true for the loop to exit. + // Both conditions must be true at the same time for the loop to exit. + // For now, be conservative. assert(L->contains(FBB) && "Loop block has no successor in loop!"); - if (BTI0.Exact != getCouldNotCompute() && - BTI1.Exact != getCouldNotCompute()) - BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact); - if (BTI0.Max != getCouldNotCompute() && - BTI1.Max != getCouldNotCompute()) - MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max); + if (EL0.Max == EL1.Max) + MaxBECount = EL0.Max; + if (EL0.Exact == EL1.Exact) + BECount = EL0.Exact; } - return BackedgeTakenInfo(BECount, MaxBECount); + return ExitLimit(BECount, MaxBECount); } if (BO->getOpcode() == Instruction::Or) { // Recurse on the operands of the or. - BackedgeTakenInfo BTI0 = - ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB); - BackedgeTakenInfo BTI1 = - ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB); + ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB); + ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); if (L->contains(FBB)) { // Both conditions must be false for the loop to continue executing. // Choose the less conservative count. - if (BTI0.Exact == getCouldNotCompute() || - BTI1.Exact == getCouldNotCompute()) + if (EL0.Exact == getCouldNotCompute() || + EL1.Exact == getCouldNotCompute()) BECount = getCouldNotCompute(); else - BECount = getUMinFromMismatchedTypes(BTI0.Exact, BTI1.Exact); - if (BTI0.Max == getCouldNotCompute()) - MaxBECount = BTI1.Max; - else if (BTI1.Max == getCouldNotCompute()) - MaxBECount = BTI0.Max; + BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact); + if (EL0.Max == getCouldNotCompute()) + MaxBECount = EL1.Max; + else if (EL1.Max == getCouldNotCompute()) + MaxBECount = EL0.Max; else - MaxBECount = getUMinFromMismatchedTypes(BTI0.Max, BTI1.Max); + MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); } else { - // Both conditions must be false for the loop to exit. + // Both conditions must be false at the same time for the loop to exit. + // For now, be conservative. assert(L->contains(TBB) && "Loop block has no successor in loop!"); - if (BTI0.Exact != getCouldNotCompute() && - BTI1.Exact != getCouldNotCompute()) - BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact); - if (BTI0.Max != getCouldNotCompute() && - BTI1.Max != getCouldNotCompute()) - MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max); + if (EL0.Max == EL1.Max) + MaxBECount = EL0.Max; + if (EL0.Exact == EL1.Exact) + BECount = EL0.Exact; } - return BackedgeTakenInfo(BECount, MaxBECount); + return ExitLimit(BECount, MaxBECount); } } // With an icmp, it may be feasible to compute an exact backedge-taken count. // Proceed to the next level to examine the icmp. if (ICmpInst *ExitCondICmp = dyn_cast(ExitCond)) - return ComputeBackedgeTakenCountFromExitCondICmp(L, ExitCondICmp, TBB, FBB); + return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB); // Check for a constant condition. These are normally stripped out by // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to @@ -3861,21 +4258,21 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCond(const Loop *L, return getCouldNotCompute(); else // The backedge is never taken. - return getIntegerSCEV(0, CI->getType()); + return getConstant(CI->getType(), 0); } // If it's not an integer or pointer comparison then compute it the hard way. - return ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB)); + return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); } -/// ComputeBackedgeTakenCountFromExitCondICmp - Compute the number of times the +/// ComputeExitLimitFromICmp - Compute the number of times the /// backedge of the specified loop will execute if its exit condition /// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB. -ScalarEvolution::BackedgeTakenInfo -ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L, - ICmpInst *ExitCond, - BasicBlock *TBB, - BasicBlock *FBB) { +ScalarEvolution::ExitLimit +ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, + ICmpInst *ExitCond, + BasicBlock *TBB, + BasicBlock *FBB) { // If the condition was exit on true, convert the condition to exit on false ICmpInst::Predicate Cond; @@ -3887,8 +4284,8 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L, // Handle common loops like: for (X = "string"; *X; ++X) if (LoadInst *LI = dyn_cast(ExitCond->getOperand(0))) if (Constant *RHS = dyn_cast(ExitCond->getOperand(1))) { - BackedgeTakenInfo ItCnt = - ComputeLoadConstantCompareBackedgeTakenCount(LI, RHS, L, Cond); + ExitLimit ItCnt = + ComputeLoadConstantCompareExitLimit(LI, RHS, L, Cond); if (ItCnt.hasAnyInfo()) return ItCnt; } @@ -3902,12 +4299,15 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L, // At this point, we would like to compute how many iterations of the // loop the predicate will return true for these inputs. - if (LHS->isLoopInvariant(L) && !RHS->isLoopInvariant(L)) { + if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) { // If there is a loop-invariant, force it into the RHS. std::swap(LHS, RHS); Cond = ICmpInst::getSwappedPredicate(Cond); } + // Simplify the operands before analyzing them. + (void)SimplifyICmpOperands(Cond, LHS, RHS); + // If we have a comparison of a chrec against a constant, try to use value // ranges to answer this query. if (const SCEVConstant *RHSC = dyn_cast(RHS)) @@ -3921,90 +4321,39 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L, if (!isa(Ret)) return Ret; } - // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by - // adding or subtracting 1 from one of the operands. - switch (Cond) { - case ICmpInst::ICMP_SLE: - if (!getSignedRange(RHS).getSignedMax().isMaxSignedValue()) { - RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, - /*HasNUW=*/false, /*HasNSW=*/true); - Cond = ICmpInst::ICMP_SLT; - } else if (!getSignedRange(LHS).getSignedMin().isMinSignedValue()) { - LHS = getAddExpr(getConstant(RHS->getType(), -1, true), LHS, - /*HasNUW=*/false, /*HasNSW=*/true); - Cond = ICmpInst::ICMP_SLT; - } - break; - case ICmpInst::ICMP_SGE: - if (!getSignedRange(RHS).getSignedMin().isMinSignedValue()) { - RHS = getAddExpr(getConstant(RHS->getType(), -1, true), RHS, - /*HasNUW=*/false, /*HasNSW=*/true); - Cond = ICmpInst::ICMP_SGT; - } else if (!getSignedRange(LHS).getSignedMax().isMaxSignedValue()) { - LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, - /*HasNUW=*/false, /*HasNSW=*/true); - Cond = ICmpInst::ICMP_SGT; - } - break; - case ICmpInst::ICMP_ULE: - if (!getUnsignedRange(RHS).getUnsignedMax().isMaxValue()) { - RHS = getAddExpr(getConstant(RHS->getType(), 1, false), RHS, - /*HasNUW=*/true, /*HasNSW=*/false); - Cond = ICmpInst::ICMP_ULT; - } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) { - LHS = getAddExpr(getConstant(RHS->getType(), -1, false), LHS, - /*HasNUW=*/true, /*HasNSW=*/false); - Cond = ICmpInst::ICMP_ULT; - } - break; - case ICmpInst::ICMP_UGE: - if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) { - RHS = getAddExpr(getConstant(RHS->getType(), -1, false), RHS, - /*HasNUW=*/true, /*HasNSW=*/false); - Cond = ICmpInst::ICMP_UGT; - } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) { - LHS = getAddExpr(getConstant(RHS->getType(), 1, false), LHS, - /*HasNUW=*/true, /*HasNSW=*/false); - Cond = ICmpInst::ICMP_UGT; - } - break; - default: - break; - } - switch (Cond) { case ICmpInst::ICMP_NE: { // while (X != Y) // Convert to: while (X-Y != 0) - BackedgeTakenInfo BTI = HowFarToZero(getMinusSCEV(LHS, RHS), L); - if (BTI.hasAnyInfo()) return BTI; + ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L); + if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_EQ: { // while (X == Y) // Convert to: while (X-Y == 0) - BackedgeTakenInfo BTI = HowFarToNonZero(getMinusSCEV(LHS, RHS), L); - if (BTI.hasAnyInfo()) return BTI; + ExitLimit EL = HowFarToNonZero(getMinusSCEV(LHS, RHS), L); + if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_SLT: { - BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, true); - if (BTI.hasAnyInfo()) return BTI; + ExitLimit EL = HowManyLessThans(LHS, RHS, L, true); + if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_SGT: { - BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS), + ExitLimit EL = HowManyLessThans(getNotSCEV(LHS), getNotSCEV(RHS), L, true); - if (BTI.hasAnyInfo()) return BTI; + if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_ULT: { - BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, false); - if (BTI.hasAnyInfo()) return BTI; + ExitLimit EL = HowManyLessThans(LHS, RHS, L, false); + if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_UGT: { - BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS), + ExitLimit EL = HowManyLessThans(getNotSCEV(LHS), getNotSCEV(RHS), L, false); - if (BTI.hasAnyInfo()) return BTI; + if (EL.hasAnyInfo()) return EL; break; } default: @@ -4018,8 +4367,7 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L, #endif break; } - return - ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB)); + return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); } static ConstantInt * @@ -4049,10 +4397,10 @@ GetAddressedElementFromGlobal(GlobalVariable *GV, if (Idx >= CA->getNumOperands()) return 0; // Bogus program Init = cast(CA->getOperand(Idx)); } else if (isa(Init)) { - if (const StructType *STy = dyn_cast(Init->getType())) { + if (StructType *STy = dyn_cast(Init->getType())) { assert(Idx < STy->getNumElements() && "Bad struct index!"); Init = Constant::getNullValue(STy->getElementType(Idx)); - } else if (const ArrayType *ATy = dyn_cast(Init->getType())) { + } else if (ArrayType *ATy = dyn_cast(Init->getType())) { if (Idx >= ATy->getNumElements()) return 0; // Bogus program Init = Constant::getNullValue(ATy->getElementType()); } else { @@ -4066,15 +4414,16 @@ GetAddressedElementFromGlobal(GlobalVariable *GV, return Init; } -/// ComputeLoadConstantCompareBackedgeTakenCount - Given an exit condition of +/// ComputeLoadConstantCompareExitLimit - Given an exit condition of /// 'icmp op load X, cst', try to see if we can compute the backedge /// execution count. -ScalarEvolution::BackedgeTakenInfo -ScalarEvolution::ComputeLoadConstantCompareBackedgeTakenCount( - LoadInst *LI, - Constant *RHS, - const Loop *L, - ICmpInst::Predicate predicate) { +ScalarEvolution::ExitLimit +ScalarEvolution::ComputeLoadConstantCompareExitLimit( + LoadInst *LI, + Constant *RHS, + const Loop *L, + ICmpInst::Predicate predicate) { + if (LI->isVolatile()) return getCouldNotCompute(); // Check to see if the loaded pointer is a getelementptr of a global. @@ -4112,7 +4461,7 @@ ScalarEvolution::ComputeLoadConstantCompareBackedgeTakenCount( // We can only recognize very limited forms of loop index expressions, in // particular, only affine AddRec's like {C1,+,C2}. const SCEVAddRecExpr *IdxExpr = dyn_cast(Idx); - if (!IdxExpr || !IdxExpr->isAffine() || IdxExpr->isLoopInvariant(L) || + if (!IdxExpr || !IdxExpr->isAffine() || isLoopInvariant(IdxExpr, L) || !isa(IdxExpr->getOperand(0)) || !isa(IdxExpr->getOperand(1))) return getCouldNotCompute(); @@ -4187,8 +4536,7 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { // constant or derived from a PHI node themselves. PHINode *PHI = 0; for (unsigned Op = 0, e = I->getNumOperands(); Op != e; ++Op) - if (!(isa(I->getOperand(Op)) || - isa(I->getOperand(Op)))) { + if (!isa(I->getOperand(Op))) { PHINode *P = getConstantEvolvingPHI(I->getOperand(Op), L); if (P == 0) return 0; // Not evolving from PHI if (PHI == 0) @@ -4209,11 +4557,9 @@ static Constant *EvaluateExpression(Value *V, Constant *PHIVal, const TargetData *TD) { if (isa(V)) return PHIVal; if (Constant *C = dyn_cast(V)) return C; - if (GlobalValue *GV = dyn_cast(V)) return GV; Instruction *I = cast(V); - std::vector Operands; - Operands.resize(I->getNumOperands()); + std::vector Operands(I->getNumOperands()); for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { Operands[i] = EvaluateExpression(I->getOperand(i), PHIVal, TD); @@ -4223,8 +4569,7 @@ static Constant *EvaluateExpression(Value *V, Constant *PHIVal, if (const CmpInst *CI = dyn_cast(I)) return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], Operands[1], TD); - return ConstantFoldInstOperands(I->getOpcode(), I->getType(), - &Operands[0], Operands.size(), TD); + return ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands, TD); } /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is @@ -4235,7 +4580,7 @@ Constant * ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, const APInt &BEs, const Loop *L) { - std::map::iterator I = + DenseMap::const_iterator I = ConstantEvolutionLoopExitValue.find(PN); if (I != ConstantEvolutionLoopExitValue.end()) return I->second; @@ -4255,8 +4600,8 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, return RetVal = 0; // Must be a constant. Value *BEValue = PN->getIncomingValue(SecondIsBackedge); - PHINode *PN2 = getConstantEvolvingPHI(BEValue, L); - if (PN2 != PN) + if (getConstantEvolvingPHI(BEValue, L) != PN && + !isa(BEValue)) return RetVal = 0; // Not derived from same PHI. // Execute the loop symbolically to determine the exit value. @@ -4279,20 +4624,22 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, } } -/// ComputeBackedgeTakenCountExhaustively - If the loop is known to execute a +/// ComputeExitCountExhaustively - If the loop is known to execute a /// constant number of times (the condition evolves only from constants), /// try to evaluate a few iterations of the loop until we get the exit /// condition gets a value of ExitWhen (true or false). If we cannot /// evaluate the trip count of the loop, return getCouldNotCompute(). -const SCEV * -ScalarEvolution::ComputeBackedgeTakenCountExhaustively(const Loop *L, - Value *Cond, - bool ExitWhen) { +const SCEV * ScalarEvolution::ComputeExitCountExhaustively(const Loop *L, + Value *Cond, + bool ExitWhen) { PHINode *PN = getConstantEvolvingPHI(Cond, L); if (PN == 0) return getCouldNotCompute(); - // Since the loop is canonicalized, the PHI node must have two entries. One - // entry must be a constant (coming in from outside of the loop), and the + // If the loop is canonicalized, the PHI will have exactly two entries. + // That's the only form we support here. + if (PN->getNumIncomingValues() != 2) return getCouldNotCompute(); + + // One entry must be a constant (coming in from outside of the loop), and the // second must be derived from the same PHI. bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1)); Constant *StartCST = @@ -4300,8 +4647,9 @@ ScalarEvolution::ComputeBackedgeTakenCountExhaustively(const Loop *L, if (StartCST == 0) return getCouldNotCompute(); // Must be a constant. Value *BEValue = PN->getIncomingValue(SecondIsBackedge); - PHINode *PN2 = getConstantEvolvingPHI(BEValue, L); - if (PN2 != PN) return getCouldNotCompute(); // Not derived from same PHI. + if (getConstantEvolvingPHI(BEValue, L) != PN && + !isa(BEValue)) + return getCouldNotCompute(); // Not derived from same PHI. // Okay, we find a PHI node that defines the trip count of this loop. Execute // the loop symbolically to determine when the condition gets a value of @@ -4389,54 +4737,51 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { // the arguments into constants, and if so, try to constant propagate the // result. This is particularly useful for computing loop exit values. if (CanConstantFold(I)) { - std::vector Operands; - Operands.reserve(I->getNumOperands()); + SmallVector Operands; + bool MadeImprovement = false; for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { Value *Op = I->getOperand(i); if (Constant *C = dyn_cast(Op)) { Operands.push_back(C); - } else { - // If any of the operands is non-constant and if they are - // non-integer and non-pointer, don't even try to analyze them - // with scev techniques. - if (!isSCEVable(Op->getType())) - return V; - - const SCEV *OpV = getSCEVAtScope(Op, L); - if (const SCEVConstant *SC = dyn_cast(OpV)) { - Constant *C = SC->getValue(); - if (C->getType() != Op->getType()) - C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false, - Op->getType(), - false), - C, Op->getType()); - Operands.push_back(C); - } else if (const SCEVUnknown *SU = dyn_cast(OpV)) { - if (Constant *C = dyn_cast(SU->getValue())) { - if (C->getType() != Op->getType()) - C = - ConstantExpr::getCast(CastInst::getCastOpcode(C, false, - Op->getType(), - false), - C, Op->getType()); - Operands.push_back(C); - } else - return V; - } else { - return V; - } + continue; } + + // If any of the operands is non-constant and if they are + // non-integer and non-pointer, don't even try to analyze them + // with scev techniques. + if (!isSCEVable(Op->getType())) + return V; + + const SCEV *OrigV = getSCEV(Op); + const SCEV *OpV = getSCEVAtScope(OrigV, L); + MadeImprovement |= OrigV != OpV; + + Constant *C = 0; + if (const SCEVConstant *SC = dyn_cast(OpV)) + C = SC->getValue(); + if (const SCEVUnknown *SU = dyn_cast(OpV)) + C = dyn_cast(SU->getValue()); + if (!C) return V; + if (C->getType() != Op->getType()) + C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false, + Op->getType(), + false), + C, Op->getType()); + Operands.push_back(C); } - Constant *C = 0; - if (const CmpInst *CI = dyn_cast(I)) - C = ConstantFoldCompareInstOperands(CI->getPredicate(), - Operands[0], Operands[1], TD); - else - C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), - &Operands[0], Operands.size(), TD); - if (C) + // Check to see if getSCEVAtScope actually made an improvement. + if (MadeImprovement) { + Constant *C = 0; + if (const CmpInst *CI = dyn_cast(I)) + C = ConstantFoldCompareInstOperands(CI->getPredicate(), + Operands[0], Operands[1], TD); + else + C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), + Operands, TD); + if (!C) return V; return getSCEV(C); + } } } @@ -4486,7 +4831,37 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { // If this is a loop recurrence for a loop that does not contain L, then we // are dealing with the final value computed by the loop. if (const SCEVAddRecExpr *AddRec = dyn_cast(V)) { - if (!L || !AddRec->getLoop()->contains(L)) { + // First, attempt to evaluate each operand. + // Avoid performing the look-up in the common case where the specified + // expression has no loop-variant portions. + for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { + const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L); + if (OpAtScope == AddRec->getOperand(i)) + continue; + + // Okay, at least one of these operands is loop variant but might be + // foldable. Build a new instance of the folded commutative expression. + SmallVector NewOps(AddRec->op_begin(), + AddRec->op_begin()+i); + NewOps.push_back(OpAtScope); + for (++i; i != e; ++i) + NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L)); + + const SCEV *FoldedRec = + getAddRecExpr(NewOps, AddRec->getLoop(), + AddRec->getNoWrapFlags(SCEV::FlagNW)); + AddRec = dyn_cast(FoldedRec); + // The addrec may be folded to a nonrecurrence, for example, if the + // induction variable is multiplied by zero after constant folding. Go + // ahead and return the folded value. + if (!AddRec) + return FoldedRec; + break; + } + + // If the scope is outside the addrec's loop, evaluate it by using the + // loop exit value of the addrec. + if (!AddRec->getLoop()->contains(L)) { // To evaluate this recurrence, we need to know how many times the AddRec // loop iterates. Compute this now. const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); @@ -4495,6 +4870,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { // Then, evaluate the AddRec. return AddRec->evaluateAtIteration(BackedgeTakenCount, *this); } + return AddRec; } @@ -4565,7 +4941,7 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B, // bit width during computations. APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D APInt Mod(BW + 1, 0); - Mod.set(BW - Mult2); // Mod = N / D + Mod.setBit(BW - Mult2); // Mod = N / D APInt I = AD.multiplicativeInverse(Mod); // 4. Compute the minimum unsigned root of the equation: @@ -4644,7 +5020,12 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { /// HowFarToZero - Return the number of times a backedge comparing the specified /// value to zero will execute. If not computable, return CouldNotCompute. -ScalarEvolution::BackedgeTakenInfo +/// +/// This is only used for loops with a "x != y" exit test. The exit condition is +/// now expressed as a single expression, V = x-y. So the exit test is +/// effectively V != 0. We know and take advantage of the fact that this +/// expression only being used in a comparison by zero context. +ScalarEvolution::ExitLimit ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { // If the value is a constant if (const SCEVConstant *C = dyn_cast(V)) { @@ -4657,55 +5038,23 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { if (!AddRec || AddRec->getLoop() != L) return getCouldNotCompute(); - if (AddRec->isAffine()) { - // If this is an affine expression, the execution count of this branch is - // the minimum unsigned root of the following equation: - // - // Start + Step*N = 0 (mod 2^BW) - // - // equivalent to: - // - // Step*N = -Start (mod 2^BW) - // - // where BW is the common bit width of Start and Step. - - // Get the initial value for the loop. - const SCEV *Start = getSCEVAtScope(AddRec->getStart(), - L->getParentLoop()); - const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), - L->getParentLoop()); - - if (const SCEVConstant *StepC = dyn_cast(Step)) { - // For now we handle only constant steps. - - // First, handle unitary steps. - if (StepC->getValue()->equalsInt(1)) // 1*N = -Start (mod 2^BW), so: - return getNegativeSCEV(Start); // N = -Start (as unsigned) - if (StepC->getValue()->isAllOnesValue()) // -1*N = -Start (mod 2^BW), so: - return Start; // N = Start (as unsigned) - - // Then, try to solve the above equation provided that Start is constant. - if (const SCEVConstant *StartC = dyn_cast(Start)) - return SolveLinEquationWithOverflow(StepC->getValue()->getValue(), - -StartC->getValue()->getValue(), - *this); - } - } else if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) { - // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of - // the quadratic equation to solve it. - std::pair Roots = SolveQuadraticEquation(AddRec, - *this); + // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of + // the quadratic equation to solve it. + if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) { + std::pair Roots = + SolveQuadraticEquation(AddRec, *this); const SCEVConstant *R1 = dyn_cast(Roots.first); const SCEVConstant *R2 = dyn_cast(Roots.second); - if (R1) { + if (R1 && R2) { #if 0 dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1 << " sol#2: " << *R2 << "\n"; #endif // Pick the smallest positive root value. if (ConstantInt *CB = - dyn_cast(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, - R1->getValue(), R2->getValue()))) { + dyn_cast(ConstantExpr::getICmp(CmpInst::ICMP_ULT, + R1->getValue(), + R2->getValue()))) { if (CB->getZExtValue() == false) std::swap(R1, R2); // R1 is the minimum root now. @@ -4717,15 +5066,78 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { return R1; // We found a quadratic root! } } + return getCouldNotCompute(); } + // Otherwise we can only handle this if it is affine. + if (!AddRec->isAffine()) + return getCouldNotCompute(); + + // If this is an affine expression, the execution count of this branch is + // the minimum unsigned root of the following equation: + // + // Start + Step*N = 0 (mod 2^BW) + // + // equivalent to: + // + // Step*N = -Start (mod 2^BW) + // + // where BW is the common bit width of Start and Step. + + // Get the initial value for the loop. + const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop()); + const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop()); + + // For now we handle only constant steps. + // + // TODO: Handle a nonconstant Step given AddRec. If the + // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap + // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step. + // We have not yet seen any such cases. + const SCEVConstant *StepC = dyn_cast(Step); + if (StepC == 0) + return getCouldNotCompute(); + + // For positive steps (counting up until unsigned overflow): + // N = -Start/Step (as unsigned) + // For negative steps (counting down to zero): + // N = Start/-Step + // First compute the unsigned distance from zero in the direction of Step. + bool CountDown = StepC->getValue()->getValue().isNegative(); + const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start); + + // Handle unitary steps, which cannot wraparound. + // 1*N = -Start; -1*N = Start (mod 2^BW), so: + // N = Distance (as unsigned) + if (StepC->getValue()->equalsInt(1) || StepC->getValue()->isAllOnesValue()) + return Distance; + + // If the recurrence is known not to wraparound, unsigned divide computes the + // back edge count. We know that the value will either become zero (and thus + // the loop terminates), that the loop will terminate through some other exit + // condition first, or that the loop has undefined behavior. This means + // we can't "miss" the exit value, even with nonunit stride. + // + // FIXME: Prove that loops always exhibits *acceptable* undefined + // behavior. Loops must exhibit defined behavior until a wrapped value is + // actually used. So the trip count computed by udiv could be smaller than the + // number of well-defined iterations. + if (AddRec->getNoWrapFlags(SCEV::FlagNW)) + // FIXME: We really want an "isexact" bit for udiv. + return getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); + + // Then, try to solve the above equation provided that Start is constant. + if (const SCEVConstant *StartC = dyn_cast(Start)) + return SolveLinEquationWithOverflow(StepC->getValue()->getValue(), + -StartC->getValue()->getValue(), + *this); return getCouldNotCompute(); } /// HowFarToNonZero - Return the number of times a backedge checking the /// specified value for nonzero will execute. If not computable, return /// CouldNotCompute -ScalarEvolution::BackedgeTakenInfo +ScalarEvolution::ExitLimit ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { // Loops that look like: while (X == 0) are very strange indeed. We don't // handle them yet except for the trivial case. This could be expanded in the @@ -4735,7 +5147,7 @@ ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { // already. If so, the backedge will execute zero times. if (const SCEVConstant *C = dyn_cast(V)) { if (!C->getValue()->isNullValue()) - return getIntegerSCEV(0, C->getType()); + return getConstant(C->getType(), 0); return getCouldNotCompute(); // Otherwise it will loop infinitely. } @@ -4744,23 +5156,6 @@ ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { return getCouldNotCompute(); } -/// getLoopPredecessor - If the given loop's header has exactly one unique -/// predecessor outside the loop, return it. Otherwise return null. -/// This is less strict that the loop "preheader" concept, which requires -/// the predecessor to have only one single successor. -/// -BasicBlock *ScalarEvolution::getLoopPredecessor(const Loop *L) { - BasicBlock *Header = L->getHeader(); - BasicBlock *Pred = 0; - for (pred_iterator PI = pred_begin(Header), E = pred_end(Header); - PI != E; ++PI) - if (!L->contains(*PI)) { - if (Pred && Pred != *PI) return 0; // Multiple predecessors. - Pred = *PI; - } - return Pred; -} - /// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB /// (which may not be an immediate predecessor) which has exactly one /// successor from which BB is reachable, or null if no such block is @@ -4778,7 +5173,7 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { // If the header has a unique predecessor outside the loop, it must be // a block that has exactly one successor that can reach the loop. if (Loop *L = LI->getLoopFor(BB)) - return std::make_pair(getLoopPredecessor(L), L->getHeader()); + return std::make_pair(L->getLoopPredecessor(), L->getHeader()); return std::pair(); } @@ -4831,13 +5226,16 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, } // If we're comparing an addrec with a value which is loop-invariant in the - // addrec's loop, put the addrec on the left. - if (const SCEVAddRecExpr *AR = dyn_cast(RHS)) - if (LHS->isLoopInvariant(AR->getLoop())) { + // addrec's loop, put the addrec on the left. Also make a dominance check, + // as both operands could be addrecs loop-invariant in each other's loop. + if (const SCEVAddRecExpr *AR = dyn_cast(RHS)) { + const Loop *L = AR->getLoop(); + if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) { std::swap(LHS, RHS); Pred = ICmpInst::getSwappedPredicate(Pred); Changed = true; } + } // If there's a constant operand, canonicalize comparisons with boundary // cases, and canonicalize *-or-equal comparisons to regular comparisons. @@ -4987,19 +5385,78 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, goto trivially_false; } + // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by + // adding or subtracting 1 from one of the operands. + switch (Pred) { + case ICmpInst::ICMP_SLE: + if (!getSignedRange(RHS).getSignedMax().isMaxSignedValue()) { + RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, + SCEV::FlagNSW); + Pred = ICmpInst::ICMP_SLT; + Changed = true; + } else if (!getSignedRange(LHS).getSignedMin().isMinSignedValue()) { + LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS, + SCEV::FlagNSW); + Pred = ICmpInst::ICMP_SLT; + Changed = true; + } + break; + case ICmpInst::ICMP_SGE: + if (!getSignedRange(RHS).getSignedMin().isMinSignedValue()) { + RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS, + SCEV::FlagNSW); + Pred = ICmpInst::ICMP_SGT; + Changed = true; + } else if (!getSignedRange(LHS).getSignedMax().isMaxSignedValue()) { + LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, + SCEV::FlagNSW); + Pred = ICmpInst::ICMP_SGT; + Changed = true; + } + break; + case ICmpInst::ICMP_ULE: + if (!getUnsignedRange(RHS).getUnsignedMax().isMaxValue()) { + RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, + SCEV::FlagNUW); + Pred = ICmpInst::ICMP_ULT; + Changed = true; + } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) { + LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS, + SCEV::FlagNUW); + Pred = ICmpInst::ICMP_ULT; + Changed = true; + } + break; + case ICmpInst::ICMP_UGE: + if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) { + RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS, + SCEV::FlagNUW); + Pred = ICmpInst::ICMP_UGT; + Changed = true; + } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) { + LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, + SCEV::FlagNUW); + Pred = ICmpInst::ICMP_UGT; + Changed = true; + } + break; + default: + break; + } + // TODO: More simplifications are possible here. return Changed; trivially_true: // Return 0 == 0. - LHS = RHS = getConstant(Type::getInt1Ty(getContext()), 0); + LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); Pred = ICmpInst::ICMP_EQ; return true; trivially_false: // Return 0 != 0. - LHS = RHS = getConstant(Type::getInt1Ty(getContext()), 0); + LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); Pred = ICmpInst::ICMP_NE; return true; } @@ -5035,15 +5492,13 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, if (isLoopEntryGuardedByCond( AR->getLoop(), Pred, AR->getStart(), RHS) && isLoopBackedgeGuardedByCond( - AR->getLoop(), Pred, - getAddExpr(AR, AR->getStepRecurrence(*this)), RHS)) + AR->getLoop(), Pred, AR->getPostIncExpr(*this), RHS)) return true; if (const SCEVAddRecExpr *AR = dyn_cast(RHS)) if (isLoopEntryGuardedByCond( AR->getLoop(), Pred, LHS, AR->getStart()) && isLoopBackedgeGuardedByCond( - AR->getLoop(), Pred, - LHS, getAddExpr(AR, AR->getStepRecurrence(*this)))) + AR->getLoop(), Pred, LHS, AR->getPostIncExpr(*this))) return true; // Otherwise see what can be done with known constant ranges. @@ -5150,7 +5605,8 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, LoopContinuePredicate->isUnconditional()) return false; - return isImpliedCond(LoopContinuePredicate->getCondition(), Pred, LHS, RHS, + return isImpliedCond(Pred, LHS, RHS, + LoopContinuePredicate->getCondition(), LoopContinuePredicate->getSuccessor(0) != L->getHeader()); } @@ -5169,7 +5625,7 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, // as there are predecessors that can be found that have unique successors // leading to the original header. for (std::pair - Pair(getLoopPredecessor(L), L->getHeader()); + Pair(L->getLoopPredecessor(), L->getHeader()); Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) { @@ -5179,7 +5635,8 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, LoopEntryPredicate->isUnconditional()) continue; - if (isImpliedCond(LoopEntryPredicate->getCondition(), Pred, LHS, RHS, + if (isImpliedCond(Pred, LHS, RHS, + LoopEntryPredicate->getCondition(), LoopEntryPredicate->getSuccessor(0) != Pair.second)) return true; } @@ -5189,24 +5646,24 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, /// isImpliedCond - Test whether the condition described by Pred, LHS, /// and RHS is true whenever the given Cond value evaluates to true. -bool ScalarEvolution::isImpliedCond(Value *CondValue, - ICmpInst::Predicate Pred, +bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, + Value *FoundCondValue, bool Inverse) { // Recursively handle And and Or conditions. - if (BinaryOperator *BO = dyn_cast(CondValue)) { + if (BinaryOperator *BO = dyn_cast(FoundCondValue)) { if (BO->getOpcode() == Instruction::And) { if (!Inverse) - return isImpliedCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) || - isImpliedCond(BO->getOperand(1), Pred, LHS, RHS, Inverse); + return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) || + isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse); } else if (BO->getOpcode() == Instruction::Or) { if (Inverse) - return isImpliedCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) || - isImpliedCond(BO->getOperand(1), Pred, LHS, RHS, Inverse); + return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) || + isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse); } } - ICmpInst *ICI = dyn_cast(CondValue); + ICmpInst *ICI = dyn_cast(FoundCondValue); if (!ICI) return false; // Bail if the ICmp's operands' types are wider than the needed type @@ -5246,10 +5703,10 @@ bool ScalarEvolution::isImpliedCond(Value *CondValue, // canonicalized the comparison. if (SimplifyICmpOperands(Pred, LHS, RHS)) if (LHS == RHS) - return Pred == ICmpInst::ICMP_EQ; + return CmpInst::isTrueWhenEqual(Pred); if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS)) if (FoundLHS == FoundRHS) - return Pred == ICmpInst::ICMP_NE; + return CmpInst::isFalseWhenEqual(Pred); // Check to see if we can make the LHS or RHS match. if (LHS == FoundRHS || RHS == FoundLHS) { @@ -5359,8 +5816,15 @@ const SCEV *ScalarEvolution::getBECount(const SCEV *Start, assert(!isKnownNegative(Step) && "This code doesn't handle negative strides yet!"); - const Type *Ty = Start->getType(); - const SCEV *NegOne = getIntegerSCEV(-1, Ty); + Type *Ty = Start->getType(); + + // When Start == End, we have an exact BECount == 0. Short-circuit this case + // here because SCEV may not be able to determine that the unsigned division + // after rounding is zero. + if (Start == End) + return getConstant(Ty, 0); + + const SCEV *NegOne = getConstant(Ty, (uint64_t)-1); const SCEV *Diff = getMinusSCEV(End, Start); const SCEV *RoundUp = getAddExpr(Step, NegOne); @@ -5371,7 +5835,7 @@ const SCEV *ScalarEvolution::getBECount(const SCEV *Start, if (!NoWrap) { // Check Add for unsigned overflow. // TODO: More sophisticated things could be done here. - const Type *WideTy = IntegerType::get(getContext(), + Type *WideTy = IntegerType::get(getContext(), getTypeSizeInBits(Ty) + 1); const SCEV *EDiff = getZeroExtendExpr(Diff, WideTy); const SCEV *ERoundUp = getZeroExtendExpr(RoundUp, WideTy); @@ -5386,19 +5850,19 @@ const SCEV *ScalarEvolution::getBECount(const SCEV *Start, /// HowManyLessThans - Return the number of times a backedge containing the /// specified less-than comparison will execute. If not computable, return /// CouldNotCompute. -ScalarEvolution::BackedgeTakenInfo +ScalarEvolution::ExitLimit ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool isSigned) { // Only handle: "ADDREC < LoopInvariant". - if (!RHS->isLoopInvariant(L)) return getCouldNotCompute(); + if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); const SCEVAddRecExpr *AddRec = dyn_cast(LHS); if (!AddRec || AddRec->getLoop() != L) return getCouldNotCompute(); // Check to see if we have a flag which makes analysis easy. - bool NoWrap = isSigned ? AddRec->hasNoSignedWrap() : - AddRec->hasNoUnsignedWrap(); + bool NoWrap = isSigned ? AddRec->getNoWrapFlags(SCEV::FlagNSW) : + AddRec->getNoWrapFlags(SCEV::FlagNUW); if (AddRec->isAffine()) { unsigned BitWidth = getTypeSizeInBits(AddRec->getType()); @@ -5416,7 +5880,7 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, // behavior, so if wrap does occur, the loop could either terminate or // loop infinitely, but in either case, the loop is guaranteed to // iterate at least until the iteration where the wrapping occurs. - const SCEV *One = getIntegerSCEV(1, Step->getType()); + const SCEV *One = getConstant(Step->getType(), 1); if (isSigned) { APInt Max = APInt::getSignedMaxValue(BitWidth); if ((Max - getSignedRange(getMinusSCEV(Step, One)).getSignedMax()) @@ -5467,7 +5931,7 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, // This allows the subsequent ceiling division of (N+(step-1))/step to // compute the correct value. const SCEV *StepMinusOne = getMinusSCEV(Step, - getIntegerSCEV(1, Step->getType())); + getConstant(Step->getType(), 1)); MaxEnd = isSigned ? getSMinExpr(MaxEnd, getMinusSCEV(getConstant(APInt::getSignedMaxValue(BitWidth)), @@ -5482,9 +5946,18 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, // The maximum backedge count is similar, except using the minimum start // value and the maximum end value. - const SCEV *MaxBECount = getBECount(MinStart, MaxEnd, Step, NoWrap); + // If we already have an exact constant BECount, use it instead. + const SCEV *MaxBECount = isa(BECount) ? BECount + : getBECount(MinStart, MaxEnd, Step, NoWrap); + + // If the stride is nonconstant, and NoWrap == true, then + // getBECount(MinStart, MaxEnd) may not compute. This would result in an + // exact BECount and invalid MaxBECount, which should be avoided to catch + // more optimization opportunities. + if (isa(MaxBECount)) + MaxBECount = BECount; - return BackedgeTakenInfo(BECount, MaxBECount); + return ExitLimit(BECount, MaxBECount); } return getCouldNotCompute(); @@ -5504,8 +5977,9 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, if (const SCEVConstant *SC = dyn_cast(getStart())) if (!SC->getValue()->isZero()) { SmallVector Operands(op_begin(), op_end()); - Operands[0] = SE.getIntegerSCEV(0, SC->getType()); - const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop()); + Operands[0] = SE.getConstant(SC->getType(), 0); + const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(), + getNoWrapFlags(FlagNW)); if (const SCEVAddRecExpr *ShiftedAddRec = dyn_cast(Shifted)) return ShiftedAddRec->getNumIterationsInRange( @@ -5528,7 +6002,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // iteration exits. unsigned BitWidth = SE.getTypeSizeInBits(getType()); if (!Range.contains(APInt(BitWidth, 0))) - return SE.getIntegerSCEV(0, getType()); + return SE.getConstant(getType(), 0); if (isAffine()) { // If this is an affine expression then we have this situation: @@ -5566,7 +6040,9 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // Range.getUpper() is crossed. SmallVector NewOps(op_begin(), op_end()); NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper())); - const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop()); + const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(), + // getNoWrapFlags(FlagNW) + FlagAnyWrap); // Next, solve the constructed addrec std::pair Roots = @@ -5623,20 +6099,19 @@ void ScalarEvolution::SCEVCallbackVH::deleted() { assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!"); if (PHINode *PN = dyn_cast(getValPtr())) SE->ConstantEvolutionLoopExitValue.erase(PN); - SE->Scalars.erase(getValPtr()); + SE->ValueExprMap.erase(getValPtr()); // this now dangles! } -void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *) { +void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) { assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!"); // Forget all the expressions associated with users of the old value, // so that future queries will recompute the expressions using the new // value. + Value *Old = getValPtr(); SmallVector Worklist; SmallPtrSet Visited; - Value *Old = getValPtr(); - bool DeleteOld = false; for (Value::use_iterator UI = Old->use_begin(), UE = Old->use_end(); UI != UE; ++UI) Worklist.push_back(*UI); @@ -5644,27 +6119,22 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *) { User *U = Worklist.pop_back_val(); // Deleting the Old value will cause this to dangle. Postpone // that until everything else is done. - if (U == Old) { - DeleteOld = true; + if (U == Old) continue; - } if (!Visited.insert(U)) continue; if (PHINode *PN = dyn_cast(U)) SE->ConstantEvolutionLoopExitValue.erase(PN); - SE->Scalars.erase(U); + SE->ValueExprMap.erase(U); for (Value::use_iterator UI = U->use_begin(), UE = U->use_end(); UI != UE; ++UI) Worklist.push_back(*UI); } - // Delete the Old value if it (indirectly) references itself. - if (DeleteOld) { - if (PHINode *PN = dyn_cast(Old)) - SE->ConstantEvolutionLoopExitValue.erase(PN); - SE->Scalars.erase(Old); - // this now dangles! - } - // this may dangle! + // Delete the Old value. + if (PHINode *PN = dyn_cast(Old)) + SE->ConstantEvolutionLoopExitValue.erase(PN); + SE->ValueExprMap.erase(Old); + // this now dangles! } ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se) @@ -5675,7 +6145,8 @@ ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se) //===----------------------------------------------------------------------===// ScalarEvolution::ScalarEvolution() - : FunctionPass(&ID) { + : FunctionPass(ID), FirstUnknown(0) { + initializeScalarEvolutionPass(*PassRegistry::getPassRegistry()); } bool ScalarEvolution::runOnFunction(Function &F) { @@ -5687,10 +6158,29 @@ bool ScalarEvolution::runOnFunction(Function &F) { } void ScalarEvolution::releaseMemory() { - Scalars.clear(); + // Iterate through all the SCEVUnknown instances and call their + // destructors, so that they release their references to their values. + for (SCEVUnknown *U = FirstUnknown; U; U = U->Next) + U->~SCEVUnknown(); + FirstUnknown = 0; + + ValueExprMap.clear(); + + // Free any extra memory created for ExitNotTakenInfo in the unlikely event + // that a loop had multiple computable exits. + for (DenseMap::iterator I = + BackedgeTakenCounts.begin(), E = BackedgeTakenCounts.end(); + I != E; ++I) { + I->second.clear(); + } + BackedgeTakenCounts.clear(); ConstantEvolutionLoopExitValue.clear(); ValuesAtScopes.clear(); + LoopDispositions.clear(); + BlockDispositions.clear(); + UnsignedRanges.clear(); + SignedRanges.clear(); UniqueSCEVs.clear(); SCEVAllocator.Reset(); } @@ -5753,7 +6243,7 @@ void ScalarEvolution::print(raw_ostream &OS, const Module *) const { WriteAsOperand(OS, F, /*PrintType=*/false); OS << "\n"; for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) - if (isSCEVable(I->getType())) { + if (isSCEVable(I->getType()) && !isa(*I)) { OS << *I << '\n'; OS << " --> "; const SCEV *SV = SE.getSCEV(&*I); @@ -5770,7 +6260,7 @@ void ScalarEvolution::print(raw_ostream &OS, const Module *) const { if (L) { OS << "\t\t" "Exits: "; const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop()); - if (!ExitValue->isLoopInvariant(L)) { + if (!SE.isLoopInvariant(ExitValue, L)) { OS << "<>"; } else { OS << *ExitValue; @@ -5787,3 +6277,240 @@ void ScalarEvolution::print(raw_ostream &OS, const Module *) const { PrintLoopInfo(OS, &SE, *I); } +ScalarEvolution::LoopDisposition +ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) { + std::map &Values = LoopDispositions[S]; + std::pair::iterator, bool> Pair = + Values.insert(std::make_pair(L, LoopVariant)); + if (!Pair.second) + return Pair.first->second; + + LoopDisposition D = computeLoopDisposition(S, L); + return LoopDispositions[S][L] = D; +} + +ScalarEvolution::LoopDisposition +ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { + switch (S->getSCEVType()) { + case scConstant: + return LoopInvariant; + case scTruncate: + case scZeroExtend: + case scSignExtend: + return getLoopDisposition(cast(S)->getOperand(), L); + case scAddRecExpr: { + const SCEVAddRecExpr *AR = cast(S); + + // If L is the addrec's loop, it's computable. + if (AR->getLoop() == L) + return LoopComputable; + + // Add recurrences are never invariant in the function-body (null loop). + if (!L) + return LoopVariant; + + // This recurrence is variant w.r.t. L if L contains AR's loop. + if (L->contains(AR->getLoop())) + return LoopVariant; + + // This recurrence is invariant w.r.t. L if AR's loop contains L. + if (AR->getLoop()->contains(L)) + return LoopInvariant; + + // This recurrence is variant w.r.t. L if any of its operands + // are variant. + for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end(); + I != E; ++I) + if (!isLoopInvariant(*I, L)) + return LoopVariant; + + // Otherwise it's loop-invariant. + return LoopInvariant; + } + case scAddExpr: + case scMulExpr: + case scUMaxExpr: + case scSMaxExpr: { + const SCEVNAryExpr *NAry = cast(S); + bool HasVarying = false; + for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); + I != E; ++I) { + LoopDisposition D = getLoopDisposition(*I, L); + if (D == LoopVariant) + return LoopVariant; + if (D == LoopComputable) + HasVarying = true; + } + return HasVarying ? LoopComputable : LoopInvariant; + } + case scUDivExpr: { + const SCEVUDivExpr *UDiv = cast(S); + LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L); + if (LD == LoopVariant) + return LoopVariant; + LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L); + if (RD == LoopVariant) + return LoopVariant; + return (LD == LoopInvariant && RD == LoopInvariant) ? + LoopInvariant : LoopComputable; + } + case scUnknown: + // All non-instruction values are loop invariant. All instructions are loop + // invariant if they are not contained in the specified loop. + // Instructions are never considered invariant in the function body + // (null loop) because they are defined within the "loop". + if (Instruction *I = dyn_cast(cast(S)->getValue())) + return (L && !L->contains(I)) ? LoopInvariant : LoopVariant; + return LoopInvariant; + case scCouldNotCompute: + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); + return LoopVariant; + default: break; + } + llvm_unreachable("Unknown SCEV kind!"); + return LoopVariant; +} + +bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) { + return getLoopDisposition(S, L) == LoopInvariant; +} + +bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) { + return getLoopDisposition(S, L) == LoopComputable; +} + +ScalarEvolution::BlockDisposition +ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) { + std::map &Values = BlockDispositions[S]; + std::pair::iterator, bool> + Pair = Values.insert(std::make_pair(BB, DoesNotDominateBlock)); + if (!Pair.second) + return Pair.first->second; + + BlockDisposition D = computeBlockDisposition(S, BB); + return BlockDispositions[S][BB] = D; +} + +ScalarEvolution::BlockDisposition +ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { + switch (S->getSCEVType()) { + case scConstant: + return ProperlyDominatesBlock; + case scTruncate: + case scZeroExtend: + case scSignExtend: + return getBlockDisposition(cast(S)->getOperand(), BB); + case scAddRecExpr: { + // This uses a "dominates" query instead of "properly dominates" query + // to test for proper dominance too, because the instruction which + // produces the addrec's value is a PHI, and a PHI effectively properly + // dominates its entire containing block. + const SCEVAddRecExpr *AR = cast(S); + if (!DT->dominates(AR->getLoop()->getHeader(), BB)) + return DoesNotDominateBlock; + } + // FALL THROUGH into SCEVNAryExpr handling. + case scAddExpr: + case scMulExpr: + case scUMaxExpr: + case scSMaxExpr: { + const SCEVNAryExpr *NAry = cast(S); + bool Proper = true; + for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); + I != E; ++I) { + BlockDisposition D = getBlockDisposition(*I, BB); + if (D == DoesNotDominateBlock) + return DoesNotDominateBlock; + if (D == DominatesBlock) + Proper = false; + } + return Proper ? ProperlyDominatesBlock : DominatesBlock; + } + case scUDivExpr: { + const SCEVUDivExpr *UDiv = cast(S); + const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS(); + BlockDisposition LD = getBlockDisposition(LHS, BB); + if (LD == DoesNotDominateBlock) + return DoesNotDominateBlock; + BlockDisposition RD = getBlockDisposition(RHS, BB); + if (RD == DoesNotDominateBlock) + return DoesNotDominateBlock; + return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ? + ProperlyDominatesBlock : DominatesBlock; + } + case scUnknown: + if (Instruction *I = + dyn_cast(cast(S)->getValue())) { + if (I->getParent() == BB) + return DominatesBlock; + if (DT->properlyDominates(I->getParent(), BB)) + return ProperlyDominatesBlock; + return DoesNotDominateBlock; + } + return ProperlyDominatesBlock; + case scCouldNotCompute: + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); + return DoesNotDominateBlock; + default: break; + } + llvm_unreachable("Unknown SCEV kind!"); + return DoesNotDominateBlock; +} + +bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) { + return getBlockDisposition(S, BB) >= DominatesBlock; +} + +bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) { + return getBlockDisposition(S, BB) == ProperlyDominatesBlock; +} + +bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const { + switch (S->getSCEVType()) { + case scConstant: + return false; + case scTruncate: + case scZeroExtend: + case scSignExtend: { + const SCEVCastExpr *Cast = cast(S); + const SCEV *CastOp = Cast->getOperand(); + return Op == CastOp || hasOperand(CastOp, Op); + } + case scAddRecExpr: + case scAddExpr: + case scMulExpr: + case scUMaxExpr: + case scSMaxExpr: { + const SCEVNAryExpr *NAry = cast(S); + for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); + I != E; ++I) { + const SCEV *NAryOp = *I; + if (NAryOp == Op || hasOperand(NAryOp, Op)) + return true; + } + return false; + } + case scUDivExpr: { + const SCEVUDivExpr *UDiv = cast(S); + const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS(); + return LHS == Op || hasOperand(LHS, Op) || + RHS == Op || hasOperand(RHS, Op); + } + case scUnknown: + return false; + case scCouldNotCompute: + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); + return false; + default: break; + } + llvm_unreachable("Unknown SCEV kind!"); + return false; +} + +void ScalarEvolution::forgetMemoizedResults(const SCEV *S) { + ValuesAtScopes.erase(S); + LoopDispositions.erase(S); + BlockDispositions.erase(S); + UnsignedRanges.erase(S); + SignedRanges.erase(S); +}