Pacify -Wnon-virtual-dtor
[oota-llvm.git] / include / llvm / Analysis / ScalarEvolutionExpressions.h
index 47b37102918618212f21c8c8233c1adc7936eb6b..eac91131ad535c53ec2fcaf42199c313abfc6ea6 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef LLVM_ANALYSIS_SCALAREVOLUTION_EXPRESSIONS_H
-#define LLVM_ANALYSIS_SCALAREVOLUTION_EXPRESSIONS_H
+#ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
+#define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
 
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Support/ErrorHandling.h"
 
@@ -45,7 +46,6 @@ namespace llvm {
     Type *getType() const { return V->getType(); }
 
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
-    static inline bool classof(const SCEVConstant *S) { return true; }
     static inline bool classof(const SCEV *S) {
       return S->getSCEVType() == scConstant;
     }
@@ -67,7 +67,6 @@ namespace llvm {
     Type *getType() const { return Ty; }
 
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
-    static inline bool classof(const SCEVCastExpr *S) { return true; }
     static inline bool classof(const SCEV *S) {
       return S->getSCEVType() == scTruncate ||
              S->getSCEVType() == scZeroExtend ||
@@ -87,7 +86,6 @@ namespace llvm {
 
   public:
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
-    static inline bool classof(const SCEVTruncateExpr *S) { return true; }
     static inline bool classof(const SCEV *S) {
       return S->getSCEVType() == scTruncate;
     }
@@ -105,7 +103,6 @@ namespace llvm {
 
   public:
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
-    static inline bool classof(const SCEVZeroExtendExpr *S) { return true; }
     static inline bool classof(const SCEV *S) {
       return S->getSCEVType() == scZeroExtend;
     }
@@ -123,7 +120,6 @@ namespace llvm {
 
   public:
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
-    static inline bool classof(const SCEVSignExtendExpr *S) { return true; }
     static inline bool classof(const SCEV *S) {
       return S->getSCEVType() == scSignExtend;
     }
@@ -165,7 +161,6 @@ namespace llvm {
     }
 
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
-    static inline bool classof(const SCEVNAryExpr *S) { return true; }
     static inline bool classof(const SCEV *S) {
       return S->getSCEVType() == scAddExpr ||
              S->getSCEVType() == scMulExpr ||
@@ -187,7 +182,6 @@ namespace llvm {
 
   public:
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
-    static inline bool classof(const SCEVCommutativeExpr *S) { return true; }
     static inline bool classof(const SCEV *S) {
       return S->getSCEVType() == scAddExpr ||
              S->getSCEVType() == scMulExpr ||
@@ -222,7 +216,6 @@ namespace llvm {
     }
 
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
-    static inline bool classof(const SCEVAddExpr *S) { return true; }
     static inline bool classof(const SCEV *S) {
       return S->getSCEVType() == scAddExpr;
     }
@@ -241,7 +234,6 @@ namespace llvm {
 
   public:
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
-    static inline bool classof(const SCEVMulExpr *S) { return true; }
     static inline bool classof(const SCEV *S) {
       return S->getSCEVType() == scMulExpr;
     }
@@ -273,7 +265,6 @@ namespace llvm {
     }
 
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
-    static inline bool classof(const SCEVUDivExpr *S) { return true; }
     static inline bool classof(const SCEV *S) {
       return S->getSCEVType() == scUDivExpr;
     }
@@ -357,7 +348,6 @@ namespace llvm {
     }
 
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
-    static inline bool classof(const SCEVAddRecExpr *S) { return true; }
     static inline bool classof(const SCEV *S) {
       return S->getSCEVType() == scAddRecExpr;
     }
@@ -379,7 +369,6 @@ namespace llvm {
 
   public:
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
-    static inline bool classof(const SCEVSMaxExpr *S) { return true; }
     static inline bool classof(const SCEV *S) {
       return S->getSCEVType() == scSMaxExpr;
     }
@@ -401,7 +390,6 @@ namespace llvm {
 
   public:
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
-    static inline bool classof(const SCEVUMaxExpr *S) { return true; }
     static inline bool classof(const SCEV *S) {
       return S->getSCEVType() == scUMaxExpr;
     }
@@ -448,7 +436,6 @@ namespace llvm {
     Type *getType() const { return getValPtr()->getType(); }
 
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
-    static inline bool classof(const SCEVUnknown *S) { return true; }
     static inline bool classof(const SCEV *S) {
       return S->getSCEVType() == scUnknown;
     }
@@ -493,6 +480,219 @@ namespace llvm {
       llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
     }
   };
+
+  /// Visit all nodes in the expression tree using worklist traversal.
+  ///
+  /// Visitor implements:
+  ///   // return true to follow this node.
+  ///   bool follow(const SCEV *S);
+  ///   // return true to terminate the search.
+  ///   bool isDone();
+  template<typename SV>
+  class SCEVTraversal {
+    SV &Visitor;
+    SmallVector<const SCEV *, 8> Worklist;
+    SmallPtrSet<const SCEV *, 8> Visited;
+
+    void push(const SCEV *S) {
+      if (Visited.insert(S) && Visitor.follow(S))
+        Worklist.push_back(S);
+    }
+  public:
+    SCEVTraversal(SV& V): Visitor(V) {}
+
+    void visitAll(const SCEV *Root) {
+      push(Root);
+      while (!Worklist.empty() && !Visitor.isDone()) {
+        const SCEV *S = Worklist.pop_back_val();
+
+        switch (S->getSCEVType()) {
+        case scConstant:
+        case scUnknown:
+          break;
+        case scTruncate:
+        case scZeroExtend:
+        case scSignExtend:
+          push(cast<SCEVCastExpr>(S)->getOperand());
+          break;
+        case scAddExpr:
+        case scMulExpr:
+        case scSMaxExpr:
+        case scUMaxExpr:
+        case scAddRecExpr: {
+          const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
+          for (SCEVNAryExpr::op_iterator I = NAry->op_begin(),
+                 E = NAry->op_end(); I != E; ++I) {
+            push(*I);
+          }
+          break;
+        }
+        case scUDivExpr: {
+          const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
+          push(UDiv->getLHS());
+          push(UDiv->getRHS());
+          break;
+        }
+        case scCouldNotCompute:
+          llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
+        default:
+          llvm_unreachable("Unknown SCEV kind!");
+        }
+      }
+    }
+  };
+
+  /// Use SCEVTraversal to visit all nodes in the givien expression tree.
+  template<typename SV>
+  void visitAll(const SCEV *Root, SV& Visitor) {
+    SCEVTraversal<SV> T(Visitor);
+    T.visitAll(Root);
+  }
+
+  /// The SCEVRewriter takes a scalar evolution expression and copies all its
+  /// components. The result after a rewrite is an identical SCEV.
+  struct SCEVRewriter
+    : public SCEVVisitor<SCEVRewriter, const SCEV*> {
+  public:
+    SCEVRewriter(ScalarEvolution &S) : SE(S) {}
+
+    virtual ~SCEVRewriter() {}
+
+    virtual const SCEV *visitConstant(const SCEVConstant *Constant) {
+      return Constant;
+    }
+
+    virtual const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
+      const SCEV *Operand = visit(Expr->getOperand());
+      return SE.getTruncateExpr(Operand, Expr->getType());
+    }
+
+    virtual const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+      const SCEV *Operand = visit(Expr->getOperand());
+      return SE.getZeroExtendExpr(Operand, Expr->getType());
+    }
+
+    virtual const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
+      const SCEV *Operand = visit(Expr->getOperand());
+      return SE.getSignExtendExpr(Operand, Expr->getType());
+    }
+
+    virtual const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
+      SmallVector<const SCEV *, 2> Operands;
+      for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
+        Operands.push_back(visit(Expr->getOperand(i)));
+      return SE.getAddExpr(Operands);
+    }
+
+    virtual const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
+      SmallVector<const SCEV *, 2> Operands;
+      for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
+        Operands.push_back(visit(Expr->getOperand(i)));
+      return SE.getMulExpr(Operands);
+    }
+
+    virtual const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
+      return SE.getUDivExpr(visit(Expr->getLHS()), visit(Expr->getRHS()));
+    }
+
+    virtual const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+      SmallVector<const SCEV *, 2> Operands;
+      for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
+        Operands.push_back(visit(Expr->getOperand(i)));
+      return SE.getAddRecExpr(Operands, Expr->getLoop(),
+                              Expr->getNoWrapFlags());
+    }
+
+    virtual const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
+      SmallVector<const SCEV *, 2> Operands;
+      for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
+        Operands.push_back(visit(Expr->getOperand(i)));
+      return SE.getSMaxExpr(Operands);
+    }
+
+    virtual const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
+      SmallVector<const SCEV *, 2> Operands;
+      for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
+        Operands.push_back(visit(Expr->getOperand(i)));
+      return SE.getUMaxExpr(Operands);
+    }
+
+    virtual const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+      return Expr;
+    }
+
+    virtual const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
+      return Expr;
+    }
+
+  protected:
+    ScalarEvolution &SE;
+  };
+
+  typedef DenseMap<const Value*, Value*> ValueToValueMap;
+
+  /// The SCEVParameterRewriter takes a scalar evolution expression and updates
+  /// the SCEVUnknown components following the Map (Value -> Value).
+  struct SCEVParameterRewriter: public SCEVRewriter {
+  public:
+    static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
+                               ValueToValueMap &Map) {
+      SCEVParameterRewriter Rewriter(SE, Map);
+      return Rewriter.visit(Scev);
+    }
+    SCEVParameterRewriter(ScalarEvolution &S, ValueToValueMap &M)
+      : SCEVRewriter(S), Map(M) {}
+
+    virtual const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+      Value *V = Expr->getValue();
+      if (Map.count(V))
+        return SE.getUnknown(Map[V]);
+      return Expr;
+    }
+
+  private:
+    ValueToValueMap &Map;
+  };
+
+  typedef DenseMap<const Loop*, const SCEV*> LoopToScevMapT;
+
+  /// The SCEVApplyRewriter takes a scalar evolution expression and applies
+  /// the Map (Loop -> SCEV) to all AddRecExprs.
+  struct SCEVApplyRewriter: public SCEVRewriter {
+  public:
+    static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
+                               ScalarEvolution &SE) {
+      SCEVApplyRewriter Rewriter(SE, Map);
+      return Rewriter.visit(Scev);
+    }
+    SCEVApplyRewriter(ScalarEvolution &S, LoopToScevMapT &M)
+      : SCEVRewriter(S), Map(M) {}
+
+    virtual const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+      SmallVector<const SCEV *, 2> Operands;
+      for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
+        Operands.push_back(visit(Expr->getOperand(i)));
+
+      const Loop *L = Expr->getLoop();
+      const SCEV *Res = SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
+
+      if (0 == Map.count(L))
+        return Res;
+
+      const SCEVAddRecExpr *Rec = (const SCEVAddRecExpr *) Res;
+      return Rec->evaluateAtIteration(Map[L], SE);
+    }
+
+  private:
+    LoopToScevMapT &Map;
+  };
+
+/// Applies the Map (Loop -> SCEV) to the given Scev.
+static inline const SCEV *apply(const SCEV *Scev, LoopToScevMapT &Map,
+                                ScalarEvolution &SE) {
+  return SCEVApplyRewriter::rewrite(Scev, Map, SE);
+}
+
 }
 
 #endif