X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=include%2Fllvm%2FAnalysis%2FScalarEvolutionExpressions.h;h=eac91131ad535c53ec2fcaf42199c313abfc6ea6;hb=b075ed3b90fa2a520aeb15802fddf3460d865f91;hp=cf15f73a7511771de5c9cbdcfcaa0ee853c452bb;hpb=8b7036b0f4ae3f76ad24a6b9bc2d874620406306;p=oota-llvm.git diff --git a/include/llvm/Analysis/ScalarEvolutionExpressions.h b/include/llvm/Analysis/ScalarEvolutionExpressions.h index cf15f73a751..eac91131ad5 100644 --- a/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ b/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -11,9 +11,10 @@ // //===----------------------------------------------------------------------===// -#ifndef LLVM_ANALYSIS_SCALAREVOLUTION_EXPRESSIONS_H -#define LLVM_ANALYSIS_SCALAREVOLUTION_EXPRESSIONS_H +#ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H +#define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Support/ErrorHandling.h" @@ -45,7 +46,6 @@ namespace llvm { Type *getType() const { return V->getType(); } /// Methods for support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const SCEVConstant *S) { return true; } static inline bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; } @@ -67,7 +67,6 @@ namespace llvm { Type *getType() const { return Ty; } /// Methods for support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const SCEVCastExpr *S) { return true; } static inline bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend || @@ -87,7 +86,6 @@ namespace llvm { public: /// Methods for support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const SCEVTruncateExpr *S) { return true; } static inline bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; } @@ -105,7 +103,6 @@ namespace llvm { public: /// Methods for support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const SCEVZeroExtendExpr *S) { return true; } static inline bool classof(const SCEV *S) { return S->getSCEVType() == scZeroExtend; } @@ -123,7 +120,6 @@ namespace llvm { public: /// Methods for support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const SCEVSignExtendExpr *S) { return true; } static inline bool classof(const SCEV *S) { return S->getSCEVType() == scSignExtend; } @@ -165,7 +161,6 @@ namespace llvm { } /// Methods for support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const SCEVNAryExpr *S) { return true; } static inline bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || @@ -187,7 +182,6 @@ namespace llvm { public: /// Methods for support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const SCEVCommutativeExpr *S) { return true; } static inline bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || @@ -222,7 +216,6 @@ namespace llvm { } /// Methods for support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const SCEVAddExpr *S) { return true; } static inline bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; } @@ -241,7 +234,6 @@ namespace llvm { public: /// Methods for support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const SCEVMulExpr *S) { return true; } static inline bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; } @@ -273,7 +265,6 @@ namespace llvm { } /// Methods for support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const SCEVUDivExpr *S) { return true; } static inline bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; } @@ -357,7 +348,6 @@ namespace llvm { } /// Methods for support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const SCEVAddRecExpr *S) { return true; } static inline bool classof(const SCEV *S) { return S->getSCEVType() == scAddRecExpr; } @@ -379,7 +369,6 @@ namespace llvm { public: /// Methods for support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const SCEVSMaxExpr *S) { return true; } static inline bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; } @@ -401,7 +390,6 @@ namespace llvm { public: /// Methods for support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const SCEVUMaxExpr *S) { return true; } static inline bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; } @@ -448,7 +436,6 @@ namespace llvm { Type *getType() const { return getValPtr()->getType(); } /// Methods for support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const SCEVUnknown *S) { return true; } static inline bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; } @@ -505,9 +492,10 @@ namespace llvm { class SCEVTraversal { SV &Visitor; SmallVector Worklist; + SmallPtrSet Visited; void push(const SCEV *S) { - if (Visitor.follow(S)) + if (Visited.insert(S) && Visitor.follow(S)) Worklist.push_back(S); } public: @@ -560,6 +548,151 @@ namespace llvm { SCEVTraversal T(Visitor); T.visitAll(Root); } + + /// The SCEVRewriter takes a scalar evolution expression and copies all its + /// components. The result after a rewrite is an identical SCEV. + struct SCEVRewriter + : public SCEVVisitor { + public: + SCEVRewriter(ScalarEvolution &S) : SE(S) {} + + virtual ~SCEVRewriter() {} + + virtual const SCEV *visitConstant(const SCEVConstant *Constant) { + return Constant; + } + + virtual const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + return SE.getTruncateExpr(Operand, Expr->getType()); + } + + virtual const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + return SE.getZeroExtendExpr(Operand, Expr->getType()); + } + + virtual const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + return SE.getSignExtendExpr(Operand, Expr->getType()); + } + + virtual const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { + SmallVector Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getAddExpr(Operands); + } + + virtual const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { + SmallVector Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getMulExpr(Operands); + } + + virtual const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { + return SE.getUDivExpr(visit(Expr->getLHS()), visit(Expr->getRHS())); + } + + virtual const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + SmallVector Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getAddRecExpr(Operands, Expr->getLoop(), + Expr->getNoWrapFlags()); + } + + virtual const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { + SmallVector Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getSMaxExpr(Operands); + } + + virtual const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { + SmallVector Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + return SE.getUMaxExpr(Operands); + } + + virtual const SCEV *visitUnknown(const SCEVUnknown *Expr) { + return Expr; + } + + virtual const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { + return Expr; + } + + protected: + ScalarEvolution &SE; + }; + + typedef DenseMap ValueToValueMap; + + /// The SCEVParameterRewriter takes a scalar evolution expression and updates + /// the SCEVUnknown components following the Map (Value -> Value). + struct SCEVParameterRewriter: public SCEVRewriter { + public: + static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, + ValueToValueMap &Map) { + SCEVParameterRewriter Rewriter(SE, Map); + return Rewriter.visit(Scev); + } + SCEVParameterRewriter(ScalarEvolution &S, ValueToValueMap &M) + : SCEVRewriter(S), Map(M) {} + + virtual const SCEV *visitUnknown(const SCEVUnknown *Expr) { + Value *V = Expr->getValue(); + if (Map.count(V)) + return SE.getUnknown(Map[V]); + return Expr; + } + + private: + ValueToValueMap ⤅ + }; + + typedef DenseMap LoopToScevMapT; + + /// The SCEVApplyRewriter takes a scalar evolution expression and applies + /// the Map (Loop -> SCEV) to all AddRecExprs. + struct SCEVApplyRewriter: public SCEVRewriter { + public: + static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map, + ScalarEvolution &SE) { + SCEVApplyRewriter Rewriter(SE, Map); + return Rewriter.visit(Scev); + } + SCEVApplyRewriter(ScalarEvolution &S, LoopToScevMapT &M) + : SCEVRewriter(S), Map(M) {} + + virtual const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + SmallVector Operands; + for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) + Operands.push_back(visit(Expr->getOperand(i))); + + const Loop *L = Expr->getLoop(); + const SCEV *Res = SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags()); + + if (0 == Map.count(L)) + return Res; + + const SCEVAddRecExpr *Rec = (const SCEVAddRecExpr *) Res; + return Rec->evaluateAtIteration(Map[L], SE); + } + + private: + LoopToScevMapT ⤅ + }; + +/// Applies the Map (Loop -> SCEV) to the given Scev. +static inline const SCEV *apply(const SCEV *Scev, LoopToScevMapT &Map, + ScalarEvolution &SE) { + return SCEVApplyRewriter::rewrite(Scev, Map, SE); +} + } #endif