Merging r258184:
[oota-llvm.git] / include / llvm / Analysis / ScalarEvolution.h
index 1bd7fd0db55b28635f289cf997d2d802fbfbe277..ef9305788849e7bd609aaa12edce7a89591283f1 100644 (file)
@@ -23,6 +23,7 @@
 
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/FoldingSet.h"
+#include "llvm/Analysis/LoopInfo.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
@@ -45,8 +46,6 @@ namespace llvm {
   class DataLayout;
   class TargetLibraryInfo;
   class LLVMContext;
-  class Loop;
-  class LoopInfo;
   class Operator;
   class SCEV;
   class SCEVAddRecExpr;
@@ -183,12 +182,13 @@ namespace llvm {
 
   protected:
     SCEVPredicateKind Kind;
+    ~SCEVPredicate() = default;
+    SCEVPredicate(const SCEVPredicate&) = default;
+    SCEVPredicate &operator=(const SCEVPredicate&) = default;
 
   public:
     SCEVPredicate(const FoldingSetNodeIDRef ID, SCEVPredicateKind Kind);
 
-    virtual ~SCEVPredicate() {}
-
     SCEVPredicateKind getKind() const { return Kind; }
 
     /// \brief Returns the estimated complexity of this predicate.
@@ -240,7 +240,7 @@ namespace llvm {
   /// expressions are equal, and this can be checked at run-time. We assume
   /// that the left hand side is a SCEVUnknown and the right hand side a
   /// constant.
-  class SCEVEqualPredicate : public SCEVPredicate {
+  class SCEVEqualPredicate final : public SCEVPredicate {
     /// We assume that LHS == RHS, where LHS is a SCEVUnknown and RHS a
     /// constant.
     const SCEVUnknown *LHS;
@@ -271,7 +271,7 @@ namespace llvm {
   /// SCEVUnionPredicate - This class represents a composition of other
   /// SCEV predicates, and is the class that most clients will interact with.
   /// This is equivalent to a logical "AND" of all the predicates in the union.
-  class SCEVUnionPredicate : public SCEVPredicate {
+  class SCEVUnionPredicate final : public SCEVPredicate {
   private:
     typedef DenseMap<const SCEV *, SmallVector<const SCEVPredicate *, 4>>
         PredicateMap;
@@ -412,7 +412,11 @@ namespace llvm {
 
       /*implicit*/ ExitLimit(const SCEV *E) : Exact(E), Max(E) {}
 
-      ExitLimit(const SCEV *E, const SCEV *M) : Exact(E), Max(M) {}
+      ExitLimit(const SCEV *E, const SCEV *M) : Exact(E), Max(M) {
+        assert((isa<SCEVCouldNotCompute>(Exact) ||
+                !isa<SCEVCouldNotCompute>(Max)) &&
+               "Exact is not allowed to be less precise than Max");
+      }
 
       /// Test whether this ExitLimit contains any computed information, or
       /// whether it's all SCEVCouldNotCompute values.
@@ -832,35 +836,24 @@ namespace llvm {
                            SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap);
     const SCEV *getAddExpr(const SCEV *LHS, const SCEV *RHS,
                            SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
-      SmallVector<const SCEV *, 2> Ops;
-      Ops.push_back(LHS);
-      Ops.push_back(RHS);
+      SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
       return getAddExpr(Ops, Flags);
     }
     const SCEV *getAddExpr(const SCEV *Op0, const SCEV *Op1, const SCEV *Op2,
                            SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
-      SmallVector<const SCEV *, 3> Ops;
-      Ops.push_back(Op0);
-      Ops.push_back(Op1);
-      Ops.push_back(Op2);
+      SmallVector<const SCEV *, 3> Ops = {Op0, Op1, Op2};
       return getAddExpr(Ops, Flags);
     }
     const SCEV *getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
                            SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap);
     const SCEV *getMulExpr(const SCEV *LHS, const SCEV *RHS,
-                           SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap)
-    {
-      SmallVector<const SCEV *, 2> Ops;
-      Ops.push_back(LHS);
-      Ops.push_back(RHS);
+                           SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
+      SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
       return getMulExpr(Ops, Flags);
     }
     const SCEV *getMulExpr(const SCEV *Op0, const SCEV *Op1, const SCEV *Op2,
                            SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
-      SmallVector<const SCEV *, 3> Ops;
-      Ops.push_back(Op0);
-      Ops.push_back(Op1);
-      Ops.push_back(Op2);
+      SmallVector<const SCEV *, 3> Ops = {Op0, Op1, Op2};
       return getMulExpr(Ops, Flags);
     }
     const SCEV *getUDivExpr(const SCEV *LHS, const SCEV *RHS);
@@ -1334,6 +1327,59 @@ namespace llvm {
     void print(raw_ostream &OS, const Module * = nullptr) const override;
     void verifyAnalysis() const override;
   };
+
+  /// An interface layer with SCEV used to manage how we see SCEV expressions
+  /// for values in the context of existing predicates. We can add new
+  /// predicates, but we cannot remove them.
+  ///
+  /// This layer has multiple purposes:
+  ///   - provides a simple interface for SCEV versioning.
+  ///   - guarantees that the order of transformations applied on a SCEV
+  ///     expression for a single Value is consistent across two different
+  ///     getSCEV calls. This means that, for example, once we've obtained
+  ///     an AddRec expression for a certain value through expression
+  ///     rewriting, we will continue to get an AddRec expression for that
+  ///     Value.
+  ///   - lowers the number of expression rewrites.
+  class PredicatedScalarEvolution {
+  public:
+    PredicatedScalarEvolution(ScalarEvolution &SE);
+    const SCEVUnionPredicate &getUnionPredicate() const;
+    /// \brief Returns the SCEV expression of V, in the context of the current
+    /// SCEV predicate.
+    /// The order of transformations applied on the expression of V returned
+    /// by ScalarEvolution is guaranteed to be preserved, even when adding new
+    /// predicates.
+    const SCEV *getSCEV(Value *V);
+    /// \brief Adds a new predicate.
+    void addPredicate(const SCEVPredicate &Pred);
+    /// \brief Returns the ScalarEvolution analysis used.
+    ScalarEvolution *getSE() const { return &SE; }
+
+  private:
+    /// \brief Increments the version number of the predicate.
+    /// This needs to be called every time the SCEV predicate changes.
+    void updateGeneration();
+    /// Holds a SCEV and the version number of the SCEV predicate used to
+    /// perform the rewrite of the expression.
+    typedef std::pair<unsigned, const SCEV *> RewriteEntry;
+    /// Maps a SCEV to the rewrite result of that SCEV at a certain version
+    /// number. If this number doesn't match the current Generation, we will
+    /// need to do a rewrite. To preserve the transformation order of previous
+    /// rewrites, we will rewrite the previous result instead of the original
+    /// SCEV.
+    DenseMap<const SCEV *, RewriteEntry> RewriteMap;
+    /// The ScalarEvolution analysis.
+    ScalarEvolution &SE;
+    /// The SCEVPredicate that forms our context. We will rewrite all
+    /// expressions assuming that this predicate true.
+    SCEVUnionPredicate Preds;
+    /// Marks the version of the SCEV predicate used. When rewriting a SCEV
+    /// expression we mark it with the version of the predicate. We use this to
+    /// figure out if the predicate has changed from the last rewrite of the
+    /// SCEV. If so, we need to perform a new rewrite.
+    unsigned Generation;
+  };
 }
 
 #endif