From c54c561c9f7270c055dd7ba75a3a003b771a42d9 Mon Sep 17 00:00:00 2001 From: Nick Lewycky Date: Sun, 25 Nov 2007 22:41:31 +0000 Subject: [PATCH] Add new SCEV, SCEVSMax. This allows LLVM to analyze do-while loops. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@44319 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/Analysis/ScalarEvolution.h | 2 + .../llvm/Analysis/ScalarEvolutionExpander.h | 2 + .../Analysis/ScalarEvolutionExpressions.h | 29 ++- lib/Analysis/ScalarEvolution.cpp | 210 +++++++++++------- lib/Analysis/ScalarEvolutionExpander.cpp | 10 + test/Analysis/ScalarEvolution/do-loop.ll | 18 ++ test/Analysis/ScalarEvolution/smax.ll | 12 + .../IndVarsSimplify/loop_evaluate_2.ll | 4 +- 8 files changed, 203 insertions(+), 84 deletions(-) create mode 100644 test/Analysis/ScalarEvolution/do-loop.ll create mode 100644 test/Analysis/ScalarEvolution/smax.ll diff --git a/include/llvm/Analysis/ScalarEvolution.h b/include/llvm/Analysis/ScalarEvolution.h index a52f273b2c7..d0886e8c64c 100644 --- a/include/llvm/Analysis/ScalarEvolution.h +++ b/include/llvm/Analysis/ScalarEvolution.h @@ -235,6 +235,8 @@ namespace llvm { std::vector NewOp(Operands); return getAddRecExpr(NewOp, L); } + SCEVHandle getSMaxExpr(const SCEVHandle &LHS, const SCEVHandle &RHS); + SCEVHandle getSMaxExpr(std::vector Operands); SCEVHandle getUnknown(Value *V); /// getNegativeSCEV - Return the SCEV object corresponding to -V. diff --git a/include/llvm/Analysis/ScalarEvolutionExpander.h b/include/llvm/Analysis/ScalarEvolutionExpander.h index 8582067e108..6529902ec81 100644 --- a/include/llvm/Analysis/ScalarEvolutionExpander.h +++ b/include/llvm/Analysis/ScalarEvolutionExpander.h @@ -134,6 +134,8 @@ namespace llvm { Value *visitAddRecExpr(SCEVAddRecExpr *S); + Value *visitSMaxExpr(SCEVSMaxExpr *S); + Value *visitUnknown(SCEVUnknown *S) { return S->getValue(); } diff --git a/include/llvm/Analysis/ScalarEvolutionExpressions.h b/include/llvm/Analysis/ScalarEvolutionExpressions.h index fb69a90fe74..f6392433a7a 100644 --- a/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ b/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -25,7 +25,7 @@ namespace llvm { // These should be ordered in terms of increasing complexity to make the // folders simpler. scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr, - scSDivExpr, scAddRecExpr, scUnknown, scCouldNotCompute + scSDivExpr, scAddRecExpr, scSMaxExpr, scUnknown, scCouldNotCompute }; //===--------------------------------------------------------------------===// @@ -274,7 +274,8 @@ namespace llvm { static inline bool classof(const SCEVCommutativeExpr *S) { return true; } static inline bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr || - S->getSCEVType() == scMulExpr; + S->getSCEVType() == scMulExpr || + S->getSCEVType() == scSMaxExpr; } }; @@ -459,6 +460,28 @@ namespace llvm { } }; + + //===--------------------------------------------------------------------===// + /// SCEVSMaxExpr - This class represents a signed maximum selection. + /// + class SCEVSMaxExpr : public SCEVCommutativeExpr { + friend class ScalarEvolution; + + explicit SCEVSMaxExpr(const std::vector &ops) + : SCEVCommutativeExpr(scSMaxExpr, ops) { + } + + public: + virtual const char *getOperationStr() const { return " smax "; } + + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const SCEVSMaxExpr *S) { return true; } + static inline bool classof(const SCEV *S) { + return S->getSCEVType() == scSMaxExpr; + } + }; + + //===--------------------------------------------------------------------===// /// SCEVUnknown - This means that we are dealing with an entirely unknown SCEV /// value, and only represent it as it's LLVM Value. This is the "bottom" @@ -521,6 +544,8 @@ namespace llvm { return ((SC*)this)->visitSDivExpr((SCEVSDivExpr*)S); case scAddRecExpr: return ((SC*)this)->visitAddRecExpr((SCEVAddRecExpr*)S); + case scSMaxExpr: + return ((SC*)this)->visitSMaxExpr((SCEVSMaxExpr*)S); case scUnknown: return ((SC*)this)->visitUnknown((SCEVUnknown*)S); case scCouldNotCompute: diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 27158e5ddde..558da230d19 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -318,6 +318,8 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, return SE.getAddExpr(NewOps); else if (isa(this)) return SE.getMulExpr(NewOps); + else if (isa(this)) + return SE.getSMaxExpr(NewOps); else assert(0 && "Unknown commutative expr!"); } @@ -1095,6 +1097,93 @@ SCEVHandle ScalarEvolution::getAddRecExpr(std::vector &Operands, return Result; } +SCEVHandle ScalarEvolution::getSMaxExpr(const SCEVHandle &LHS, + const SCEVHandle &RHS) { + std::vector Ops; + Ops.push_back(LHS); + Ops.push_back(RHS); + return getSMaxExpr(Ops); +} + +SCEVHandle ScalarEvolution::getSMaxExpr(std::vector Ops) { + assert(!Ops.empty() && "Cannot get empty smax!"); + if (Ops.size() == 1) return Ops[0]; + + // Sort by complexity, this groups all similar expression types together. + GroupByComplexity(Ops); + + // If there are any constants, fold them together. + unsigned Idx = 0; + if (SCEVConstant *LHSC = dyn_cast(Ops[0])) { + ++Idx; + assert(Idx < Ops.size()); + while (SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { + // We found two constants, fold them together! + Constant *Fold = ConstantInt::get( + APIntOps::smax(LHSC->getValue()->getValue(), + RHSC->getValue()->getValue())); + if (ConstantInt *CI = dyn_cast(Fold)) { + Ops[0] = getConstant(CI); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast(Ops[0]); + } else { + // If we couldn't fold the expression, move to the next constant. Note + // that this is impossible to happen in practice because we always + // constant fold constant ints to constant ints. + ++Idx; + } + } + + // If we are left with a constant -inf, strip it off. + if (cast(Ops[0])->getValue()->isMinValue(true)) { + Ops.erase(Ops.begin()); + --Idx; + } + } + + if (Ops.size() == 1) return Ops[0]; + + // Find the first SMax + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr) + ++Idx; + + // Check to see if one of the operands is an SMax. If so, expand its operands + // onto our operand list, and recurse to simplify. + if (Idx < Ops.size()) { + bool DeletedSMax = false; + while (SCEVSMaxExpr *SMax = dyn_cast(Ops[Idx])) { + Ops.insert(Ops.end(), SMax->op_begin(), SMax->op_end()); + Ops.erase(Ops.begin()+Idx); + DeletedSMax = true; + } + + if (DeletedSMax) + return getSMaxExpr(Ops); + } + + // Okay, check to see if the same value occurs in the operand list twice. If + // so, delete one. Since we sorted the list, these values are required to + // be adjacent. + for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) + if (Ops[i] == Ops[i+1]) { // X smax Y smax Y --> X smax Y + Ops.erase(Ops.begin()+i, Ops.begin()+i+1); + --i; --e; + } + + if (Ops.size() == 1) return Ops[0]; + + assert(!Ops.empty() && "Reduced smax down to nothing!"); + + // Okay, it looks like we really DO need an add expr. Check to see if we + // already have one, otherwise create a new one. + std::vector SCEVOps(Ops.begin(), Ops.end()); + SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr, + SCEVOps)]; + if (Result == 0) Result = new SCEVSMaxExpr(Ops); + return Result; +} + SCEVHandle ScalarEvolution::getUnknown(Value *V) { if (ConstantInt *CI = dyn_cast(V)) return getConstant(CI); @@ -1458,6 +1547,14 @@ static uint32_t GetMinTrailingZeros(SCEVHandle S) { return MinOpRes; } + if (SCEVSMaxExpr *M = dyn_cast(S)) { + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); + for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); + return MinOpRes; + } + // SCEVSDivExpr, SCEVUnknown return 0; } @@ -1537,6 +1634,25 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { case Instruction::PHI: return createNodeForPHI(cast(I)); + case Instruction::Select: + // This could be an SCEVSMax that was lowered earlier. Try to recover it. + if (ICmpInst *ICI = dyn_cast(I->getOperand(0))) { + Value *LHS = ICI->getOperand(0); + Value *RHS = ICI->getOperand(1); + switch (ICI->getPredicate()) { + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + if (LHS == I->getOperand(1) && RHS == I->getOperand(2)) + return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS)); + default: + break; + } + } + default: // We cannot analyze this expression. break; } @@ -2125,8 +2241,11 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { } if (isa(Comm)) return SE.getAddExpr(NewOps); - assert(isa(Comm) && "Only know about add and mul!"); - return SE.getMulExpr(NewOps); + if (isa(Comm)) + return SE.getMulExpr(NewOps); + if (isa(Comm)) + return SE.getSMaxExpr(NewOps); + assert(0 && "Unknown commutative SCEV type!"); } } // If we got here, all operands are loop invariant. @@ -2343,90 +2462,21 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) { return UnknownValue; if (AddRec->isAffine()) { + // The number of iterations for "{n,+,1} < m", is m-n. However, we don't + // know that m is >= n on input to the loop. If it is, the condition + // returns true zero times. To handle both cases, we return SMAX(0, m-n). + // FORNOW: We only support unit strides. - SCEVHandle Zero = SE.getIntegerSCEV(0, RHS->getType()); SCEVHandle One = SE.getIntegerSCEV(1, RHS->getType()); if (AddRec->getOperand(1) != One) return UnknownValue; - // The number of iterations for "{n,+,1} < m", is m-n. However, we don't - // know that m is >= n on input to the loop. If it is, the condition return - // true zero times. What we really should return, for full generality, is - // SMAX(0, m-n). Since we cannot check this, we will instead check for a - // canonical loop form: most do-loops will have a check that dominates the - // loop, that only enters the loop if (n-1)= n. - - // Search for the check. - BasicBlock *Preheader = L->getLoopPreheader(); - BasicBlock *PreheaderDest = L->getHeader(); - if (Preheader == 0) return UnknownValue; - - BranchInst *LoopEntryPredicate = - dyn_cast(Preheader->getTerminator()); - if (!LoopEntryPredicate) return UnknownValue; - - // This might be a critical edge broken out. If the loop preheader ends in - // an unconditional branch to the loop, check to see if the preheader has a - // single predecessor, and if so, look for its terminator. - while (LoopEntryPredicate->isUnconditional()) { - PreheaderDest = Preheader; - Preheader = Preheader->getSinglePredecessor(); - if (!Preheader) return UnknownValue; // Multiple preds. - - LoopEntryPredicate = - dyn_cast(Preheader->getTerminator()); - if (!LoopEntryPredicate) return UnknownValue; - } - - // Now that we found a conditional branch that dominates the loop, check to - // see if it is the comparison we are looking for. - if (ICmpInst *ICI = dyn_cast(LoopEntryPredicate->getCondition())){ - Value *PreCondLHS = ICI->getOperand(0); - Value *PreCondRHS = ICI->getOperand(1); - ICmpInst::Predicate Cond; - if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest) - Cond = ICI->getPredicate(); - else - Cond = ICI->getInversePredicate(); - - switch (Cond) { - case ICmpInst::ICMP_UGT: - if (isSigned) return UnknownValue; - std::swap(PreCondLHS, PreCondRHS); - Cond = ICmpInst::ICMP_ULT; - break; - case ICmpInst::ICMP_SGT: - if (!isSigned) return UnknownValue; - std::swap(PreCondLHS, PreCondRHS); - Cond = ICmpInst::ICMP_SLT; - break; - case ICmpInst::ICMP_ULT: - if (isSigned) return UnknownValue; - break; - case ICmpInst::ICMP_SLT: - if (!isSigned) return UnknownValue; - break; - default: - return UnknownValue; - } - - if (PreCondLHS->getType()->isInteger()) { - if (RHS != getSCEV(PreCondRHS)) - return UnknownValue; // Not a comparison against 'm'. + SCEVHandle Iters = SE.getMinusSCEV(RHS, AddRec->getOperand(0)); - if (SE.getMinusSCEV(AddRec->getOperand(0), One) - != getSCEV(PreCondLHS)) - return UnknownValue; // Not a comparison against 'n-1'. - } - else return UnknownValue; - - // cerr << "Computed Loop Trip Count as: " - // << // *SE.getMinusSCEV(RHS, AddRec->getOperand(0)) << "\n"; - return SE.getMinusSCEV(RHS, AddRec->getOperand(0)); - } - else - return UnknownValue; + if (isSigned) + return SE.getSMaxExpr(SE.getIntegerSCEV(0, RHS->getType()), Iters); + else + return Iters; } return UnknownValue; diff --git a/lib/Analysis/ScalarEvolutionExpander.cpp b/lib/Analysis/ScalarEvolutionExpander.cpp index 3bac3024c45..88fd0aaf8fb 100644 --- a/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/lib/Analysis/ScalarEvolutionExpander.cpp @@ -208,6 +208,16 @@ Value *SCEVExpander::visitAddRecExpr(SCEVAddRecExpr *S) { return expand(V); } +Value *SCEVExpander::visitSMaxExpr(SCEVSMaxExpr *S) { + Value *LHS = expand(S->getOperand(0)); + for (unsigned i = 1; i < S->getNumOperands(); ++i) { + Value *RHS = expand(S->getOperand(i)); + Value *ICmp = new ICmpInst(ICmpInst::ICMP_SGT, LHS, RHS, "tmp", InsertPt); + LHS = new SelectInst(ICmp, LHS, RHS, "smax", InsertPt); + } + return LHS; +} + Value *SCEVExpander::expand(SCEV *S) { // Check to see if we already expanded this. std::map::iterator I = InsertedExpressions.find(S); diff --git a/test/Analysis/ScalarEvolution/do-loop.ll b/test/Analysis/ScalarEvolution/do-loop.ll new file mode 100644 index 00000000000..c6b3298638b --- /dev/null +++ b/test/Analysis/ScalarEvolution/do-loop.ll @@ -0,0 +1,18 @@ +; RUN: llvm-as < %s | opt -analyze -scalar-evolution | grep smax +; PR1614 + +define i32 @f(i32 %x, i32 %y) { +entry: + br label %bb + +bb: ; preds = %bb, %entry + %indvar = phi i32 [ 0, %entry ], [ %indvar.next, %bb ] ; [#uses=2] + %x_addr.0 = add i32 %indvar, %x ; [#uses=1] + %tmp2 = add i32 %x_addr.0, 1 ; [#uses=2] + %tmp5 = icmp slt i32 %tmp2, %y ; [#uses=1] + %indvar.next = add i32 %indvar, 1 ; [#uses=1] + br i1 %tmp5, label %bb, label %bb7 + +bb7: ; preds = %bb + ret i32 %tmp2 +} diff --git a/test/Analysis/ScalarEvolution/smax.ll b/test/Analysis/ScalarEvolution/smax.ll new file mode 100644 index 00000000000..157d54f3e63 --- /dev/null +++ b/test/Analysis/ScalarEvolution/smax.ll @@ -0,0 +1,12 @@ +; RUN: llvm-as < %s | opt -analyze -scalar-evolution | grep smax | count 2 +; RUN: llvm-as < %s | opt -analyze -scalar-evolution | grep \ +; RUN: "%. smax %. smax %." +; PR1614 + +define i32 @x(i32 %a, i32 %b, i32 %c) { + %A = icmp sgt i32 %a, %b + %B = select i1 %A, i32 %a, i32 %b + %C = icmp sle i32 %c, %B + %D = select i1 %C, i32 %B, i32 %c + ret i32 %D +} diff --git a/test/Transforms/IndVarsSimplify/loop_evaluate_2.ll b/test/Transforms/IndVarsSimplify/loop_evaluate_2.ll index c7426918462..635950aed66 100644 --- a/test/Transforms/IndVarsSimplify/loop_evaluate_2.ll +++ b/test/Transforms/IndVarsSimplify/loop_evaluate_2.ll @@ -1,5 +1,5 @@ -; RUN: llvm-as < %s | opt -indvars -adce -simplifycfg | llvm-dis | grep select -; RUN: llvm-as < %s | opt -indvars -adce -simplifycfg | llvm-dis | not grep br +; RUN: llvm-as < %s | opt -indvars -adce -simplifycfg | opt \ +; RUN: -analyze -loops | not grep "^Loop Containing" ; PR1179 define i32 @ltst(i32 %x) { -- 2.34.1