From 85b05a2e60e0e696739167b52cc7cc3e7cf390c0 Mon Sep 17 00:00:00 2001 From: Dan Gohman Date: Mon, 13 Jul 2009 21:35:55 +0000 Subject: [PATCH] Reapply 75252, with a fix to avoid the infinite recursion case. The check for avoiding re-analyzing a widening cast needed to happen earlier, as getSCEV itself may result in a isLoopGuardedByCond query. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@75511 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/Analysis/ScalarEvolution.h | 62 ++- lib/Analysis/ScalarEvolution.cpp | 641 ++++++++++++++++++---- test/Transforms/IndVarSimplify/iv-sext.ll | 1 - 3 files changed, 580 insertions(+), 124 deletions(-) diff --git a/include/llvm/Analysis/ScalarEvolution.h b/include/llvm/Analysis/ScalarEvolution.h index 0224c00ba1c..e31d63c5a82 100644 --- a/include/llvm/Analysis/ScalarEvolution.h +++ b/include/llvm/Analysis/ScalarEvolution.h @@ -26,6 +26,7 @@ #include "llvm/Support/DataTypes.h" #include "llvm/Support/ValueHandle.h" #include "llvm/Support/Allocator.h" +#include "llvm/Support/ConstantRange.h" #include "llvm/ADT/FoldingSet.h" #include "llvm/ADT/DenseMap.h" #include @@ -330,12 +331,20 @@ namespace llvm { /// found. BasicBlock* getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB); - /// isNecessaryCond - Test whether the given CondValue value is a condition - /// which is at least as strict as the one described by Pred, LHS, and RHS. + /// isNecessaryCond - Test whether the condition described by Pred, LHS, + /// and RHS is a necessary condition for the given Cond value to evaluate + /// to true. bool isNecessaryCond(Value *Cond, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, bool Inverse); + /// isNecessaryCondOperands - Test whether the condition described by Pred, + /// LHS, and RHS is a necessary condition for the condition described by + /// Pred, FoundLHS, and FoundRHS to evaluate to true. + bool isNecessaryCondOperands(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + const SCEV *FoundLHS, const SCEV *FoundRHS); + /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is /// in the header of its containing loop, we know the loop executes a /// constant number of times, and the PHI node is just a recurrence @@ -495,10 +504,16 @@ namespace llvm { /// isLoopGuardedByCond - Test whether entry to the loop is protected by /// a conditional between LHS and RHS. This is used to help avoid max - /// expressions in loop trip counts. + /// expressions in loop trip counts, and to eliminate casts. bool isLoopGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS); + /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is + /// protected by a conditional between LHS and RHS. This is used to + /// to eliminate casts. + bool isLoopBackedgeGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS); + /// getBackedgeTakenCount - If the specified loop has a predictable /// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute /// object. The backedge-taken count is the number of times the loop header @@ -534,13 +549,42 @@ namespace llvm { /// bitwidth of S. uint32_t GetMinTrailingZeros(const SCEV *S); - /// GetMinLeadingZeros - Determine the minimum number of zero bits that S is - /// guaranteed to begin with (at every loop iteration). - uint32_t GetMinLeadingZeros(const SCEV *S); + /// getUnsignedRange - Determine the unsigned range for a particular SCEV. + /// + ConstantRange getUnsignedRange(const SCEV *S); + + /// getSignedRange - Determine the signed range for a particular SCEV. + /// + ConstantRange getSignedRange(const SCEV *S); + + /// isKnownNegative - Test if the given expression is known to be negative. + /// + bool isKnownNegative(const SCEV *S); + + /// isKnownPositive - Test if the given expression is known to be positive. + /// + bool isKnownPositive(const SCEV *S); + + /// isKnownNonNegative - Test if the given expression is known to be + /// non-negative. + /// + bool isKnownNonNegative(const SCEV *S); + + /// isKnownNonPositive - Test if the given expression is known to be + /// non-positive. + /// + bool isKnownNonPositive(const SCEV *S); + + /// isKnownNonZero - Test if the given expression is known to be + /// non-zero. + /// + bool isKnownNonZero(const SCEV *S); - /// GetMinSignBits - Determine the minimum number of sign bits that S is - /// guaranteed to begin with. - uint32_t GetMinSignBits(const SCEV *S); + /// isKnownNonZero - Test if the given expression is known to satisfy + /// the condition described by Pred, LHS, and RHS. + /// + bool isKnownPredicate(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS); virtual bool runOnFunction(Function &F); virtual void releaseMemory(); diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 2db39c4e0c4..2dbc3485d0f 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -787,6 +787,11 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; } if (const SCEVAddRecExpr *AR = dyn_cast(Op)) if (AR->isAffine()) { + const SCEV *Start = AR->getStart(); + const SCEV *Step = AR->getStepRecurrence(*this); + unsigned BitWidth = getTypeSizeInBits(AR->getType()); + const Loop *L = AR->getLoop(); + // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are // simply not analyzable, and it covers the case where this code is @@ -795,12 +800,10 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // in infinite recursion. In the later case, the analysis code will // cope with a conservative value, and it will take care to purge // that value once it has finished. - const SCEV *MaxBECount = getMaxBackedgeTakenCount(AR->getLoop()); + const SCEV *MaxBECount = getMaxBackedgeTakenCount(L); if (!isa(MaxBECount)) { // Manually compute the final value for AR, checking for // overflow. - const SCEV *Start = AR->getStart(); - const SCEV *Step = AR->getStepRecurrence(*this); // Check whether the backedge-taken count can be losslessly casted to // the addrec's type. The count is always unsigned. @@ -809,8 +812,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); if (MaxBECount == RecastedMaxBECount) { - const Type *WideTy = - IntegerType::get(getTypeSizeInBits(Start->getType()) * 2); + const Type *WideTy = IntegerType::get(BitWidth * 2); // Check whether Start+Step*MaxBECount has no unsigned overflow. const SCEV *ZMul = getMulExpr(CastedMaxBECount, @@ -824,7 +826,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // Return the expression with the addrec on the outside. return getAddRecExpr(getZeroExtendExpr(Start, Ty), getZeroExtendExpr(Step, Ty), - AR->getLoop()); + L); // Similar to above, only this time treat the step value as signed. // This covers loops that count down. @@ -840,7 +842,35 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // Return the expression with the addrec on the outside. return getAddRecExpr(getZeroExtendExpr(Start, Ty), getSignExtendExpr(Step, Ty), - AR->getLoop()); + L); + } + + // If the backedge is guarded by a comparison with the pre-inc value + // the addrec is safe. Also, if the entry is guarded by a comparison + // with the start value and the backedge is guarded by a comparison + // with the post-inc value, the addrec is safe. + if (isKnownPositive(Step)) { + const SCEV *N = getConstant(APInt::getMinValue(BitWidth) - + getUnsignedRange(Step).getUnsignedMax()); + if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) || + (isLoopGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) && + isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, + AR->getPostIncExpr(*this), N))) + // Return the expression with the addrec on the outside. + return getAddRecExpr(getZeroExtendExpr(Start, Ty), + getZeroExtendExpr(Step, Ty), + L); + } else if (isKnownNegative(Step)) { + const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - + getSignedRange(Step).getSignedMin()); + if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) && + (isLoopGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) || + isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, + AR->getPostIncExpr(*this), N))) + // Return the expression with the addrec on the outside. + return getAddRecExpr(getZeroExtendExpr(Start, Ty), + getSignExtendExpr(Step, Ty), + L); } } } @@ -889,6 +919,11 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // this: for (signed char X = 0; X < 100; ++X) { int Y = X; } if (const SCEVAddRecExpr *AR = dyn_cast(Op)) if (AR->isAffine()) { + const SCEV *Start = AR->getStart(); + const SCEV *Step = AR->getStepRecurrence(*this); + unsigned BitWidth = getTypeSizeInBits(AR->getType()); + const Loop *L = AR->getLoop(); + // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are // simply not analyzable, and it covers the case where this code is @@ -897,12 +932,10 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // in infinite recursion. In the later case, the analysis code will // cope with a conservative value, and it will take care to purge // that value once it has finished. - const SCEV *MaxBECount = getMaxBackedgeTakenCount(AR->getLoop()); + const SCEV *MaxBECount = getMaxBackedgeTakenCount(L); if (!isa(MaxBECount)) { // Manually compute the final value for AR, checking for // overflow. - const SCEV *Start = AR->getStart(); - const SCEV *Step = AR->getStepRecurrence(*this); // Check whether the backedge-taken count can be losslessly casted to // the addrec's type. The count is always unsigned. @@ -911,8 +944,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); if (MaxBECount == RecastedMaxBECount) { - const Type *WideTy = - IntegerType::get(getTypeSizeInBits(Start->getType()) * 2); + const Type *WideTy = IntegerType::get(BitWidth * 2); // Check whether Start+Step*MaxBECount has no signed overflow. const SCEV *SMul = getMulExpr(CastedMaxBECount, @@ -926,7 +958,35 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // Return the expression with the addrec on the outside. return getAddRecExpr(getSignExtendExpr(Start, Ty), getSignExtendExpr(Step, Ty), - AR->getLoop()); + L); + } + + // If the backedge is guarded by a comparison with the pre-inc value + // the addrec is safe. Also, if the entry is guarded by a comparison + // with the start value and the backedge is guarded by a comparison + // with the post-inc value, the addrec is safe. + if (isKnownPositive(Step)) { + const SCEV *N = getConstant(APInt::getSignedMinValue(BitWidth) - + getSignedRange(Step).getSignedMax()); + if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SLT, AR, N) || + (isLoopGuardedByCond(L, ICmpInst::ICMP_SLT, Start, N) && + isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SLT, + AR->getPostIncExpr(*this), N))) + // Return the expression with the addrec on the outside. + return getAddRecExpr(getSignExtendExpr(Start, Ty), + getSignExtendExpr(Step, Ty), + L); + } else if (isKnownNegative(Step)) { + const SCEV *N = getConstant(APInt::getSignedMaxValue(BitWidth) - + getSignedRange(Step).getSignedMin()); + if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SGT, AR, N) || + (isLoopGuardedByCond(L, ICmpInst::ICMP_SGT, Start, N) && + isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SGT, + AR->getPostIncExpr(*this), N))) + // Return the expression with the addrec on the outside. + return getAddRecExpr(getSignExtendExpr(Start, Ty), + getSignExtendExpr(Step, Ty), + L); } } } @@ -2368,19 +2428,16 @@ const SCEV *ScalarEvolution::createNodeForGEP(User *GEP) { const StructLayout &SL = *TD->getStructLayout(STy); unsigned FieldNo = cast(Index)->getZExtValue(); uint64_t Offset = SL.getElementOffset(FieldNo); - TotalOffset = getAddExpr(TotalOffset, - getIntegerSCEV(Offset, IntPtrTy)); + TotalOffset = getAddExpr(TotalOffset, getIntegerSCEV(Offset, IntPtrTy)); } else { // For an array, add the element offset, explicitly scaled. const SCEV *LocalOffset = getSCEV(Index); if (!isa(LocalOffset->getType())) // Getelementptr indicies are signed. - LocalOffset = getTruncateOrSignExtend(LocalOffset, - IntPtrTy); + LocalOffset = getTruncateOrSignExtend(LocalOffset, IntPtrTy); LocalOffset = getMulExpr(LocalOffset, - getIntegerSCEV(TD->getTypeAllocSize(*GTI), - IntPtrTy)); + getIntegerSCEV(TD->getTypeAllocSize(*GTI), IntPtrTy)); TotalOffset = getAddExpr(TotalOffset, LocalOffset); } } @@ -2468,18 +2525,95 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { return 0; } -uint32_t -ScalarEvolution::GetMinLeadingZeros(const SCEV *S) { - // TODO: Handle other SCEV expression types here. +/// getUnsignedRange - Determine the unsigned range for a particular SCEV. +/// +ConstantRange +ScalarEvolution::getUnsignedRange(const SCEV *S) { if (const SCEVConstant *C = dyn_cast(S)) - return C->getValue()->getValue().countLeadingZeros(); + return ConstantRange(C->getValue()->getValue()); + + if (const SCEVAddExpr *Add = dyn_cast(S)) { + ConstantRange X = getUnsignedRange(Add->getOperand(0)); + for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) + X = X.add(getUnsignedRange(Add->getOperand(i))); + return X; + } + + if (const SCEVMulExpr *Mul = dyn_cast(S)) { + ConstantRange X = getUnsignedRange(Mul->getOperand(0)); + for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) + X = X.multiply(getUnsignedRange(Mul->getOperand(i))); + return X; + } - if (const SCEVZeroExtendExpr *C = dyn_cast(S)) { - // A zero-extension cast adds zero bits. - return GetMinLeadingZeros(C->getOperand()) + - (getTypeSizeInBits(C->getType()) - - getTypeSizeInBits(C->getOperand()->getType())); + if (const SCEVSMaxExpr *SMax = dyn_cast(S)) { + ConstantRange X = getUnsignedRange(SMax->getOperand(0)); + for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) + X = X.smax(getUnsignedRange(SMax->getOperand(i))); + return X; + } + + if (const SCEVUMaxExpr *UMax = dyn_cast(S)) { + ConstantRange X = getUnsignedRange(UMax->getOperand(0)); + for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) + X = X.umax(getUnsignedRange(UMax->getOperand(i))); + return X; + } + + if (const SCEVUDivExpr *UDiv = dyn_cast(S)) { + ConstantRange X = getUnsignedRange(UDiv->getLHS()); + ConstantRange Y = getUnsignedRange(UDiv->getRHS()); + return X.udiv(Y); + } + + if (const SCEVZeroExtendExpr *ZExt = dyn_cast(S)) { + ConstantRange X = getUnsignedRange(ZExt->getOperand()); + return X.zeroExtend(cast(ZExt->getType())->getBitWidth()); + } + + if (const SCEVSignExtendExpr *SExt = dyn_cast(S)) { + ConstantRange X = getUnsignedRange(SExt->getOperand()); + return X.signExtend(cast(SExt->getType())->getBitWidth()); + } + + if (const SCEVTruncateExpr *Trunc = dyn_cast(S)) { + ConstantRange X = getUnsignedRange(Trunc->getOperand()); + return X.truncate(cast(Trunc->getType())->getBitWidth()); + } + + ConstantRange FullSet(getTypeSizeInBits(S->getType()), true); + + if (const SCEVAddRecExpr *AddRec = dyn_cast(S)) { + const SCEV *T = getBackedgeTakenCount(AddRec->getLoop()); + const SCEVConstant *Trip = dyn_cast(T); + if (!Trip) return FullSet; + + // TODO: non-affine addrec + if (AddRec->isAffine()) { + const Type *Ty = AddRec->getType(); + const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); + if (getTypeSizeInBits(MaxBECount->getType()) <= getTypeSizeInBits(Ty)) { + MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty); + + const SCEV *Start = AddRec->getStart(); + const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this); + + // Check for overflow. + if (!isKnownPredicate(ICmpInst::ICMP_ULE, Start, End)) + return FullSet; + + ConstantRange StartRange = getUnsignedRange(Start); + ConstantRange EndRange = getUnsignedRange(End); + APInt Min = APIntOps::umin(StartRange.getUnsignedMin(), + EndRange.getUnsignedMin()); + APInt Max = APIntOps::umax(StartRange.getUnsignedMax(), + EndRange.getUnsignedMax()); + if (Min.isMinValue() && Max.isMaxValue()) + return ConstantRange(Min.getBitWidth(), /*isFullSet=*/true); + return ConstantRange(Min, Max+1); + } + } } if (const SCEVUnknown *U = dyn_cast(S)) { @@ -2488,67 +2622,119 @@ ScalarEvolution::GetMinLeadingZeros(const SCEV *S) { APInt Mask = APInt::getAllOnesValue(BitWidth); APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD); - return Zeros.countLeadingOnes(); + return ConstantRange(Ones, ~Zeros); } - return 1; + return FullSet; } -uint32_t -ScalarEvolution::GetMinSignBits(const SCEV *S) { - // TODO: Handle other SCEV expression types here. +/// getSignedRange - Determine the signed range for a particular SCEV. +/// +ConstantRange +ScalarEvolution::getSignedRange(const SCEV *S) { - if (const SCEVConstant *C = dyn_cast(S)) { - const APInt &A = C->getValue()->getValue(); - return A.isNegative() ? A.countLeadingOnes() : - A.countLeadingZeros(); + if (const SCEVConstant *C = dyn_cast(S)) + return ConstantRange(C->getValue()->getValue()); + + if (const SCEVAddExpr *Add = dyn_cast(S)) { + ConstantRange X = getSignedRange(Add->getOperand(0)); + for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) + X = X.add(getSignedRange(Add->getOperand(i))); + return X; } - if (const SCEVSignExtendExpr *C = dyn_cast(S)) { - // A sign-extension cast adds sign bits. - return GetMinSignBits(C->getOperand()) + - (getTypeSizeInBits(C->getType()) - - getTypeSizeInBits(C->getOperand()->getType())); + if (const SCEVMulExpr *Mul = dyn_cast(S)) { + ConstantRange X = getSignedRange(Mul->getOperand(0)); + for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) + X = X.multiply(getSignedRange(Mul->getOperand(i))); + return X; } - if (const SCEVAddExpr *A = dyn_cast(S)) { - unsigned BitWidth = getTypeSizeInBits(A->getType()); - - // Special case decrementing a value (ADD X, -1): - if (const SCEVConstant *CRHS = dyn_cast(A->getOperand(0))) - if (CRHS->isAllOnesValue()) { - SmallVector OtherOps(A->op_begin() + 1, A->op_end()); - const SCEV *OtherOpsAdd = getAddExpr(OtherOps); - unsigned LZ = GetMinLeadingZeros(OtherOpsAdd); - - // If the input is known to be 0 or 1, the output is 0/-1, which is all - // sign bits set. - if (LZ == BitWidth - 1) - return BitWidth; - - // If we are subtracting one from a positive number, there is no carry - // out of the result. - if (LZ > 0) - return GetMinSignBits(OtherOpsAdd); - } + if (const SCEVSMaxExpr *SMax = dyn_cast(S)) { + ConstantRange X = getSignedRange(SMax->getOperand(0)); + for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) + X = X.smax(getSignedRange(SMax->getOperand(i))); + return X; + } - // Add can have at most one carry bit. Thus we know that the output - // is, at worst, one more bit than the inputs. - unsigned Min = BitWidth; - for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) { - unsigned N = GetMinSignBits(A->getOperand(i)); - Min = std::min(Min, N) - 1; - if (Min == 0) return 1; + if (const SCEVUMaxExpr *UMax = dyn_cast(S)) { + ConstantRange X = getSignedRange(UMax->getOperand(0)); + for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) + X = X.umax(getSignedRange(UMax->getOperand(i))); + return X; + } + + if (const SCEVUDivExpr *UDiv = dyn_cast(S)) { + ConstantRange X = getSignedRange(UDiv->getLHS()); + ConstantRange Y = getSignedRange(UDiv->getRHS()); + return X.udiv(Y); + } + + if (const SCEVZeroExtendExpr *ZExt = dyn_cast(S)) { + ConstantRange X = getSignedRange(ZExt->getOperand()); + return X.zeroExtend(cast(ZExt->getType())->getBitWidth()); + } + + if (const SCEVSignExtendExpr *SExt = dyn_cast(S)) { + ConstantRange X = getSignedRange(SExt->getOperand()); + return X.signExtend(cast(SExt->getType())->getBitWidth()); + } + + if (const SCEVTruncateExpr *Trunc = dyn_cast(S)) { + ConstantRange X = getSignedRange(Trunc->getOperand()); + return X.truncate(cast(Trunc->getType())->getBitWidth()); + } + + ConstantRange FullSet(getTypeSizeInBits(S->getType()), true); + + if (const SCEVAddRecExpr *AddRec = dyn_cast(S)) { + const SCEV *T = getBackedgeTakenCount(AddRec->getLoop()); + const SCEVConstant *Trip = dyn_cast(T); + if (!Trip) return FullSet; + + // TODO: non-affine addrec + if (AddRec->isAffine()) { + const Type *Ty = AddRec->getType(); + const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); + if (getTypeSizeInBits(MaxBECount->getType()) <= getTypeSizeInBits(Ty)) { + MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty); + + const SCEV *Start = AddRec->getStart(); + const SCEV *Step = AddRec->getStepRecurrence(*this); + const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this); + + // Check for overflow. + if (!(isKnownPositive(Step) && + isKnownPredicate(ICmpInst::ICMP_SLT, Start, End)) && + !(isKnownNegative(Step) && + isKnownPredicate(ICmpInst::ICMP_SGT, Start, End))) + return FullSet; + + ConstantRange StartRange = getSignedRange(Start); + ConstantRange EndRange = getSignedRange(End); + APInt Min = APIntOps::smin(StartRange.getSignedMin(), + EndRange.getSignedMin()); + APInt Max = APIntOps::smax(StartRange.getSignedMax(), + EndRange.getSignedMax()); + if (Min.isMinSignedValue() && Max.isMaxSignedValue()) + return ConstantRange(Min.getBitWidth(), /*isFullSet=*/true); + return ConstantRange(Min, Max+1); + } } - return 1; } if (const SCEVUnknown *U = dyn_cast(S)) { // For a SCEVUnknown, ask ValueTracking. - return ComputeNumSignBits(U->getValue(), TD); + unsigned BitWidth = getTypeSizeInBits(U->getType()); + unsigned NS = ComputeNumSignBits(U->getValue(), TD); + if (NS == 1) + return FullSet; + return + ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1), + APInt::getSignedMaxValue(BitWidth).ashr(NS - 1)+1); } - return 1; + return FullSet; } /// createSCEV - We know that there is no SCEV for the specified value. @@ -3628,7 +3814,7 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { if (!isSCEVable(Op->getType())) return V; - const SCEV *OpV = getSCEVAtScope(getSCEV(Op), L); + const SCEV* OpV = getSCEVAtScope(Op, L); if (const SCEVConstant *SC = dyn_cast(OpV)) { Constant *C = SC->getValue(); if (C->getType() != Op->getType()) @@ -4029,12 +4215,176 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) { return false; } -/// isLoopGuardedByCond - Test whether entry to the loop is protected by -/// a conditional between LHS and RHS. This is used to help avoid max -/// expressions in loop trip counts. -bool ScalarEvolution::isLoopGuardedByCond(const Loop *L, - ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS) { +bool ScalarEvolution::isKnownNegative(const SCEV *S) { + return getSignedRange(S).getSignedMax().isNegative(); +} + +bool ScalarEvolution::isKnownPositive(const SCEV *S) { + return getSignedRange(S).getSignedMin().isStrictlyPositive(); +} + +bool ScalarEvolution::isKnownNonNegative(const SCEV *S) { + return !getSignedRange(S).getSignedMin().isNegative(); +} + +bool ScalarEvolution::isKnownNonPositive(const SCEV *S) { + return !getSignedRange(S).getSignedMax().isStrictlyPositive(); +} + +bool ScalarEvolution::isKnownNonZero(const SCEV *S) { + return isKnownNegative(S) || isKnownPositive(S); +} + +bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { + + if (HasSameValue(LHS, RHS)) + return ICmpInst::isTrueWhenEqual(Pred); + + switch (Pred) { + default: + assert(0 && "Unexpected ICmpInst::Predicate value!"); + break; + case ICmpInst::ICMP_SGT: + Pred = ICmpInst::ICMP_SLT; + std::swap(LHS, RHS); + case ICmpInst::ICMP_SLT: { + ConstantRange LHSRange = getSignedRange(LHS); + ConstantRange RHSRange = getSignedRange(RHS); + if (LHSRange.getSignedMax().slt(RHSRange.getSignedMin())) + return true; + if (LHSRange.getSignedMin().sge(RHSRange.getSignedMax())) + return false; + + const SCEV *Diff = getMinusSCEV(LHS, RHS); + ConstantRange DiffRange = getUnsignedRange(Diff); + if (isKnownNegative(Diff)) { + if (DiffRange.getUnsignedMax().ult(LHSRange.getUnsignedMin())) + return true; + if (DiffRange.getUnsignedMin().uge(LHSRange.getUnsignedMax())) + return false; + } else if (isKnownPositive(Diff)) { + if (LHSRange.getUnsignedMax().ult(DiffRange.getUnsignedMin())) + return true; + if (LHSRange.getUnsignedMin().uge(DiffRange.getUnsignedMax())) + return false; + } + break; + } + case ICmpInst::ICMP_SGE: + Pred = ICmpInst::ICMP_SLE; + std::swap(LHS, RHS); + case ICmpInst::ICMP_SLE: { + ConstantRange LHSRange = getSignedRange(LHS); + ConstantRange RHSRange = getSignedRange(RHS); + if (LHSRange.getSignedMax().sle(RHSRange.getSignedMin())) + return true; + if (LHSRange.getSignedMin().sgt(RHSRange.getSignedMax())) + return false; + + const SCEV *Diff = getMinusSCEV(LHS, RHS); + ConstantRange DiffRange = getUnsignedRange(Diff); + if (isKnownNonPositive(Diff)) { + if (DiffRange.getUnsignedMax().ule(LHSRange.getUnsignedMin())) + return true; + if (DiffRange.getUnsignedMin().ugt(LHSRange.getUnsignedMax())) + return false; + } else if (isKnownNonNegative(Diff)) { + if (LHSRange.getUnsignedMax().ule(DiffRange.getUnsignedMin())) + return true; + if (LHSRange.getUnsignedMin().ugt(DiffRange.getUnsignedMax())) + return false; + } + break; + } + case ICmpInst::ICMP_UGT: + Pred = ICmpInst::ICMP_ULT; + std::swap(LHS, RHS); + case ICmpInst::ICMP_ULT: { + ConstantRange LHSRange = getUnsignedRange(LHS); + ConstantRange RHSRange = getUnsignedRange(RHS); + if (LHSRange.getUnsignedMax().ult(RHSRange.getUnsignedMin())) + return true; + if (LHSRange.getUnsignedMin().uge(RHSRange.getUnsignedMax())) + return false; + + const SCEV *Diff = getMinusSCEV(LHS, RHS); + ConstantRange DiffRange = getUnsignedRange(Diff); + if (LHSRange.getUnsignedMax().ult(DiffRange.getUnsignedMin())) + return true; + if (LHSRange.getUnsignedMin().uge(DiffRange.getUnsignedMax())) + return false; + break; + } + case ICmpInst::ICMP_UGE: + Pred = ICmpInst::ICMP_ULE; + std::swap(LHS, RHS); + case ICmpInst::ICMP_ULE: { + ConstantRange LHSRange = getUnsignedRange(LHS); + ConstantRange RHSRange = getUnsignedRange(RHS); + if (LHSRange.getUnsignedMax().ule(RHSRange.getUnsignedMin())) + return true; + if (LHSRange.getUnsignedMin().ugt(RHSRange.getUnsignedMax())) + return false; + + const SCEV *Diff = getMinusSCEV(LHS, RHS); + ConstantRange DiffRange = getUnsignedRange(Diff); + if (LHSRange.getUnsignedMax().ule(DiffRange.getUnsignedMin())) + return true; + if (LHSRange.getUnsignedMin().ugt(DiffRange.getUnsignedMax())) + return false; + break; + } + case ICmpInst::ICMP_NE: { + if (getUnsignedRange(LHS).intersectWith(getUnsignedRange(RHS)).isEmptySet()) + return true; + if (getSignedRange(LHS).intersectWith(getSignedRange(RHS)).isEmptySet()) + return true; + + const SCEV *Diff = getMinusSCEV(LHS, RHS); + if (isKnownNonZero(Diff)) + return true; + break; + } + case ICmpInst::ICMP_EQ: + break; + } + return false; +} + +/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is +/// protected by a conditional between LHS and RHS. This is used to +/// to eliminate casts. +bool +ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, + ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { + // Interpret a null as meaning no loop, where there is obviously no guard + // (interprocedural conditions notwithstanding). + if (!L) return true; + + BasicBlock *Latch = L->getLoopLatch(); + if (!Latch) + return false; + + BranchInst *LoopContinuePredicate = + dyn_cast(Latch->getTerminator()); + if (!LoopContinuePredicate || + LoopContinuePredicate->isUnconditional()) + return false; + + return + isNecessaryCond(LoopContinuePredicate->getCondition(), Pred, LHS, RHS, + LoopContinuePredicate->getSuccessor(0) != L->getHeader()); +} + +/// isLoopGuardedByCond - Test whether entry to the loop is protected +/// by a conditional between LHS and RHS. This is used to help avoid max +/// expressions in loop trip counts, and to eliminate casts. +bool +ScalarEvolution::isLoopGuardedByCond(const Loop *L, + ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { // Interpret a null as meaning no loop, where there is obviously no guard // (interprocedural conditions notwithstanding). if (!L) return false; @@ -4063,8 +4413,9 @@ bool ScalarEvolution::isLoopGuardedByCond(const Loop *L, return false; } -/// isNecessaryCond - Test whether the given CondValue value is a condition -/// which is at least as strict as the one described by Pred, LHS, and RHS. +/// isNecessaryCond - Test whether the condition described by Pred, LHS, +/// and RHS is a necessary condition for the given Cond value to evaluate +/// to true. bool ScalarEvolution::isNecessaryCond(Value *CondValue, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, @@ -4089,30 +4440,35 @@ bool ScalarEvolution::isNecessaryCond(Value *CondValue, // see if it is the comparison we are looking for. Value *PreCondLHS = ICI->getOperand(0); Value *PreCondRHS = ICI->getOperand(1); - ICmpInst::Predicate Cond; + ICmpInst::Predicate FoundPred; if (Inverse) - Cond = ICI->getInversePredicate(); + FoundPred = ICI->getInversePredicate(); else - Cond = ICI->getPredicate(); + FoundPred = ICI->getPredicate(); - if (Cond == Pred) + if (FoundPred == Pred) ; // An exact match. - else if (!ICmpInst::isTrueWhenEqual(Cond) && Pred == ICmpInst::ICMP_NE) - ; // The actual condition is beyond sufficient. - else + else if (!ICmpInst::isTrueWhenEqual(FoundPred) && Pred == ICmpInst::ICMP_NE) { + // The actual condition is beyond sufficient. + FoundPred = ICmpInst::ICMP_NE; + // NE is symmetric but the original comparison may not be. Swap + // the operands if necessary so that they match below. + if (isa(LHS)) + std::swap(PreCondLHS, PreCondRHS); + } else // Check a few special cases. - switch (Cond) { + switch (FoundPred) { case ICmpInst::ICMP_UGT: if (Pred == ICmpInst::ICMP_ULT) { std::swap(PreCondLHS, PreCondRHS); - Cond = ICmpInst::ICMP_ULT; + FoundPred = ICmpInst::ICMP_ULT; break; } return false; case ICmpInst::ICMP_SGT: if (Pred == ICmpInst::ICMP_SLT) { std::swap(PreCondLHS, PreCondRHS); - Cond = ICmpInst::ICMP_SLT; + FoundPred = ICmpInst::ICMP_SLT; break; } return false; @@ -4121,8 +4477,8 @@ bool ScalarEvolution::isNecessaryCond(Value *CondValue, // so check for this case by checking if the NE is comparing against // a minimum or maximum constant. if (!ICmpInst::isTrueWhenEqual(Pred)) - if (ConstantInt *CI = dyn_cast(PreCondRHS)) { - const APInt &A = CI->getValue(); + if (const SCEVConstant *C = dyn_cast(RHS)) { + const APInt &A = C->getValue()->getValue(); switch (Pred) { case ICmpInst::ICMP_SLT: if (A.isMaxSignedValue()) break; @@ -4139,7 +4495,7 @@ bool ScalarEvolution::isNecessaryCond(Value *CondValue, default: return false; } - Cond = ICmpInst::ICMP_NE; + FoundPred = Pred; // NE is symmetric but the original comparison may not be. Swap // the operands if necessary so that they match below. if (isa(LHS)) @@ -4152,14 +4508,73 @@ bool ScalarEvolution::isNecessaryCond(Value *CondValue, return false; } - if (!PreCondLHS->getType()->isInteger()) return false; + assert(Pred == FoundPred && "Conditions were not reconciled!"); + + // Bail if the ICmp's operands' types are wider than the needed type + // before attempting to call getSCEV on them. This avoids infinite + // recursion, since the analysis of widening casts can require loop + // exit condition information for overflow checking, which would + // lead back here. + if (getTypeSizeInBits(LHS->getType()) < + getTypeSizeInBits(PreCondLHS->getType())) + return false; - const SCEV *PreCondLHSSCEV = getSCEV(PreCondLHS); - const SCEV *PreCondRHSSCEV = getSCEV(PreCondRHS); - return (HasSameValue(LHS, PreCondLHSSCEV) && - HasSameValue(RHS, PreCondRHSSCEV)) || - (HasSameValue(LHS, getNotSCEV(PreCondRHSSCEV)) && - HasSameValue(RHS, getNotSCEV(PreCondLHSSCEV))); + const SCEV *FoundLHS = getSCEV(PreCondLHS); + const SCEV *FoundRHS = getSCEV(PreCondRHS); + + // Balance the types. The case where FoundLHS' type is wider than + // LHS' type is checked for above. + if (getTypeSizeInBits(LHS->getType()) > + getTypeSizeInBits(FoundLHS->getType())) { + if (CmpInst::isSigned(Pred)) { + FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType()); + FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType()); + } else { + FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType()); + FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType()); + } + } + + return isNecessaryCondOperands(Pred, LHS, RHS, + FoundLHS, FoundRHS) || + // ~x < ~y --> x > y + isNecessaryCondOperands(Pred, LHS, RHS, + getNotSCEV(FoundRHS), getNotSCEV(FoundLHS)); +} + +/// isNecessaryCondOperands - Test whether the condition described by Pred, +/// LHS, and RHS is a necessary condition for the condition described by +/// Pred, FoundLHS, and FoundRHS to evaluate to true. +bool +ScalarEvolution::isNecessaryCondOperands(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + const SCEV *FoundLHS, + const SCEV *FoundRHS) { + switch (Pred) { + default: break; + case ICmpInst::ICMP_SLT: + if (isKnownPredicate(ICmpInst::ICMP_SLE, LHS, FoundLHS) && + isKnownPredicate(ICmpInst::ICMP_SGE, RHS, FoundRHS)) + return true; + break; + case ICmpInst::ICMP_SGT: + if (isKnownPredicate(ICmpInst::ICMP_SGE, LHS, FoundLHS) && + isKnownPredicate(ICmpInst::ICMP_SLE, RHS, FoundRHS)) + return true; + break; + case ICmpInst::ICMP_ULT: + if (isKnownPredicate(ICmpInst::ICMP_ULE, LHS, FoundLHS) && + isKnownPredicate(ICmpInst::ICMP_UGE, RHS, FoundRHS)) + return true; + break; + case ICmpInst::ICMP_UGT: + if (isKnownPredicate(ICmpInst::ICMP_UGE, LHS, FoundLHS) && + isKnownPredicate(ICmpInst::ICMP_ULE, RHS, FoundRHS)) + return true; + break; + } + + return false; } /// getBECount - Subtract the end and start values and divide by the step, @@ -4180,9 +4595,9 @@ const SCEV *ScalarEvolution::getBECount(const SCEV *Start, // Check Add for unsigned overflow. // TODO: More sophisticated things could be done here. const Type *WideTy = Context->getIntegerType(getTypeSizeInBits(Ty) + 1); - const SCEV *OperandExtendedAdd = - getAddExpr(getZeroExtendExpr(Diff, WideTy), - getZeroExtendExpr(RoundUp, WideTy)); + const SCEV *EDiff = getZeroExtendExpr(Diff, WideTy); + const SCEV *ERoundUp = getZeroExtendExpr(RoundUp, WideTy); + const SCEV *OperandExtendedAdd = getAddExpr(EDiff, ERoundUp); if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd) return getCouldNotCompute(); @@ -4244,9 +4659,9 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, const SCEV *Start = AddRec->getOperand(0); // Determine the minimum constant start value. - const SCEV *MinStart = isa(Start) ? Start : - getConstant(isSigned ? APInt::getSignedMinValue(BitWidth) : - APInt::getMinValue(BitWidth)); + const SCEV *MinStart = getConstant(isSigned ? + getSignedRange(Start).getSignedMin() : + getUnsignedRange(Start).getUnsignedMin()); // If we know that the condition is true in order to enter the loop, // then we know that it will run exactly (m-n)/s times. Otherwise, we @@ -4254,18 +4669,16 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, // the division must round up. const SCEV *End = RHS; if (!isLoopGuardedByCond(L, - isSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, + isSigned ? ICmpInst::ICMP_SLT : + ICmpInst::ICMP_ULT, getMinusSCEV(Start, Step), RHS)) End = isSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start); // Determine the maximum constant end value. - const SCEV *MaxEnd = - isa(End) ? End : - getConstant(isSigned ? APInt::getSignedMaxValue(BitWidth) - .ashr(GetMinSignBits(End) - 1) : - APInt::getMaxValue(BitWidth) - .lshr(GetMinLeadingZeros(End))); + const SCEV *MaxEnd = getConstant(isSigned ? + getSignedRange(End).getSignedMax() : + getUnsignedRange(End).getUnsignedMax()); // Finally, we subtract these two values and divide, rounding up, to get // the number of times the backedge is executed. diff --git a/test/Transforms/IndVarSimplify/iv-sext.ll b/test/Transforms/IndVarSimplify/iv-sext.ll index ae97208b15f..120acb23c83 100644 --- a/test/Transforms/IndVarSimplify/iv-sext.ll +++ b/test/Transforms/IndVarSimplify/iv-sext.ll @@ -1,7 +1,6 @@ ; RUN: llvm-as < %s | opt -indvars | llvm-dis > %t ; RUN: grep {= sext} %t | count 4 ; RUN: grep {phi i64} %t | count 2 -; XFAIL: * ; Indvars should be able to promote the hiPart induction variable in the ; inner loop to i64. -- 2.34.1