Merging r258184:
[oota-llvm.git] / include / llvm / Analysis / ScalarEvolution.h
index 3c28093d7052426372462d858fc6977eec337ae8..ef9305788849e7bd609aaa12edce7a89591283f1 100644 (file)
@@ -23,6 +23,7 @@
 
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/FoldingSet.h"
+#include "llvm/Analysis/LoopInfo.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
@@ -45,13 +46,16 @@ namespace llvm {
   class DataLayout;
   class TargetLibraryInfo;
   class LLVMContext;
-  class Loop;
-  class LoopInfo;
   class Operator;
-  class SCEVUnknown;
-  class SCEVAddRecExpr;
   class SCEV;
-  template<> struct FoldingSetTrait<SCEV>;
+  class SCEVAddRecExpr;
+  class SCEVConstant;
+  class SCEVExpander;
+  class SCEVPredicate;
+  class SCEVUnknown;
+
+  template <> struct FoldingSetTrait<SCEV>;
+  template <> struct FoldingSetTrait<SCEVPredicate>;
 
   /// This class represents an analyzed expression in the program.  These are
   /// opaque objects that the client is not allowed to do much with directly.
@@ -164,6 +168,149 @@ namespace llvm {
     static bool classof(const SCEV *S);
   };
 
+  /// SCEVPredicate - This class represents an assumption made using SCEV
+  /// expressions which can be checked at run-time.
+  class SCEVPredicate : public FoldingSetNode {
+    friend struct FoldingSetTrait<SCEVPredicate>;
+
+    /// A reference to an Interned FoldingSetNodeID for this node.  The
+    /// ScalarEvolution's BumpPtrAllocator holds the data.
+    FoldingSetNodeIDRef FastID;
+
+  public:
+    enum SCEVPredicateKind { P_Union, P_Equal };
+
+  protected:
+    SCEVPredicateKind Kind;
+    ~SCEVPredicate() = default;
+    SCEVPredicate(const SCEVPredicate&) = default;
+    SCEVPredicate &operator=(const SCEVPredicate&) = default;
+
+  public:
+    SCEVPredicate(const FoldingSetNodeIDRef ID, SCEVPredicateKind Kind);
+
+    SCEVPredicateKind getKind() const { return Kind; }
+
+    /// \brief Returns the estimated complexity of this predicate.
+    /// This is roughly measured in the number of run-time checks required.
+    virtual unsigned getComplexity() const { return 1; }
+
+    /// \brief Returns true if the predicate is always true. This means that no
+    /// assumptions were made and nothing needs to be checked at run-time.
+    virtual bool isAlwaysTrue() const = 0;
+
+    /// \brief Returns true if this predicate implies \p N.
+    virtual bool implies(const SCEVPredicate *N) const = 0;
+
+    /// \brief Prints a textual representation of this predicate with an
+    /// indentation of \p Depth.
+    virtual void print(raw_ostream &OS, unsigned Depth = 0) const = 0;
+
+    /// \brief Returns the SCEV to which this predicate applies, or nullptr
+    /// if this is a SCEVUnionPredicate.
+    virtual const SCEV *getExpr() const = 0;
+  };
+
+  inline raw_ostream &operator<<(raw_ostream &OS, const SCEVPredicate &P) {
+    P.print(OS);
+    return OS;
+  }
+
+  // Specialize FoldingSetTrait for SCEVPredicate to avoid needing to compute
+  // temporary FoldingSetNodeID values.
+  template <>
+  struct FoldingSetTrait<SCEVPredicate>
+      : DefaultFoldingSetTrait<SCEVPredicate> {
+
+    static void Profile(const SCEVPredicate &X, FoldingSetNodeID &ID) {
+      ID = X.FastID;
+    }
+
+    static bool Equals(const SCEVPredicate &X, const FoldingSetNodeID &ID,
+                       unsigned IDHash, FoldingSetNodeID &TempID) {
+      return ID == X.FastID;
+    }
+    static unsigned ComputeHash(const SCEVPredicate &X,
+                                FoldingSetNodeID &TempID) {
+      return X.FastID.ComputeHash();
+    }
+  };
+
+  /// SCEVEqualPredicate - This class represents an assumption that two SCEV
+  /// expressions are equal, and this can be checked at run-time. We assume
+  /// that the left hand side is a SCEVUnknown and the right hand side a
+  /// constant.
+  class SCEVEqualPredicate final : public SCEVPredicate {
+    /// We assume that LHS == RHS, where LHS is a SCEVUnknown and RHS a
+    /// constant.
+    const SCEVUnknown *LHS;
+    const SCEVConstant *RHS;
+
+  public:
+    SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEVUnknown *LHS,
+                       const SCEVConstant *RHS);
+
+    /// Implementation of the SCEVPredicate interface
+    bool implies(const SCEVPredicate *N) const override;
+    void print(raw_ostream &OS, unsigned Depth = 0) const override;
+    bool isAlwaysTrue() const override;
+    const SCEV *getExpr() const override;
+
+    /// \brief Returns the left hand side of the equality.
+    const SCEVUnknown *getLHS() const { return LHS; }
+
+    /// \brief Returns the right hand side of the equality.
+    const SCEVConstant *getRHS() const { return RHS; }
+
+    /// Methods for support type inquiry through isa, cast, and dyn_cast:
+    static inline bool classof(const SCEVPredicate *P) {
+      return P->getKind() == P_Equal;
+    }
+  };
+
+  /// SCEVUnionPredicate - This class represents a composition of other
+  /// SCEV predicates, and is the class that most clients will interact with.
+  /// This is equivalent to a logical "AND" of all the predicates in the union.
+  class SCEVUnionPredicate final : public SCEVPredicate {
+  private:
+    typedef DenseMap<const SCEV *, SmallVector<const SCEVPredicate *, 4>>
+        PredicateMap;
+
+    /// Vector with references to all predicates in this union.
+    SmallVector<const SCEVPredicate *, 16> Preds;
+    /// Maps SCEVs to predicates for quick look-ups.
+    PredicateMap SCEVToPreds;
+
+  public:
+    SCEVUnionPredicate();
+
+    const SmallVectorImpl<const SCEVPredicate *> &getPredicates() const {
+      return Preds;
+    }
+
+    /// \brief Adds a predicate to this union.
+    void add(const SCEVPredicate *N);
+
+    /// \brief Returns a reference to a vector containing all predicates
+    /// which apply to \p Expr.
+    ArrayRef<const SCEVPredicate *> getPredicatesForExpr(const SCEV *Expr);
+
+    /// Implementation of the SCEVPredicate interface
+    bool isAlwaysTrue() const override;
+    bool implies(const SCEVPredicate *N) const override;
+    void print(raw_ostream &OS, unsigned Depth) const override;
+    const SCEV *getExpr() const override;
+
+    /// \brief We estimate the complexity of a union predicate as the size
+    /// number of predicates in the union.
+    unsigned getComplexity() const override { return Preds.size(); }
+
+    /// Methods for support type inquiry through isa, cast, and dyn_cast:
+    static inline bool classof(const SCEVPredicate *P) {
+      return P->getKind() == P_Union;
+    }
+  };
+
   /// The main scalar evolution driver. Because client code (intentionally)
   /// can't do much with the SCEV objects directly, they must ask this class
   /// for services.
@@ -265,7 +412,11 @@ namespace llvm {
 
       /*implicit*/ ExitLimit(const SCEV *E) : Exact(E), Max(E) {}
 
-      ExitLimit(const SCEV *E, const SCEV *M) : Exact(E), Max(M) {}
+      ExitLimit(const SCEV *E, const SCEV *M) : Exact(E), Max(M) {
+        assert((isa<SCEVCouldNotCompute>(Exact) ||
+                !isa<SCEVCouldNotCompute>(Max)) &&
+               "Exact is not allowed to be less precise than Max");
+      }
 
       /// Test whether this ExitLimit contains any computed information, or
       /// whether it's all SCEVCouldNotCompute values.
@@ -482,6 +633,17 @@ namespace llvm {
                                                   const Loop *L,
                                                   ICmpInst::Predicate p);
 
+    /// Compute the exit limit of a loop that is controlled by a
+    /// "(IV >> 1) != 0" type comparison.  We cannot compute the exact trip
+    /// count in these cases (since SCEV has no way of expressing them), but we
+    /// can still sometimes compute an upper bound.
+    ///
+    /// Return an ExitLimit for a loop whose backedge is guarded by `LHS Pred
+    /// RHS`.
+    ExitLimit computeShiftCompareExitLimit(Value *LHS, Value *RHS,
+                                           const Loop *L,
+                                           ICmpInst::Predicate Pred);
+
     /// If the loop is known to execute a constant number of times (the
     /// condition evolves only from constants), try to evaluate a few iterations
     /// of the loop until we get the exit condition gets a value of ExitWhen
@@ -674,35 +836,24 @@ namespace llvm {
                            SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap);
     const SCEV *getAddExpr(const SCEV *LHS, const SCEV *RHS,
                            SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
-      SmallVector<const SCEV *, 2> Ops;
-      Ops.push_back(LHS);
-      Ops.push_back(RHS);
+      SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
       return getAddExpr(Ops, Flags);
     }
     const SCEV *getAddExpr(const SCEV *Op0, const SCEV *Op1, const SCEV *Op2,
                            SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
-      SmallVector<const SCEV *, 3> Ops;
-      Ops.push_back(Op0);
-      Ops.push_back(Op1);
-      Ops.push_back(Op2);
+      SmallVector<const SCEV *, 3> Ops = {Op0, Op1, Op2};
       return getAddExpr(Ops, Flags);
     }
     const SCEV *getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
                            SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap);
     const SCEV *getMulExpr(const SCEV *LHS, const SCEV *RHS,
-                           SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap)
-    {
-      SmallVector<const SCEV *, 2> Ops;
-      Ops.push_back(LHS);
-      Ops.push_back(RHS);
+                           SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
+      SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
       return getMulExpr(Ops, Flags);
     }
     const SCEV *getMulExpr(const SCEV *Op0, const SCEV *Op1, const SCEV *Op2,
                            SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
-      SmallVector<const SCEV *, 3> Ops;
-      Ops.push_back(Op0);
-      Ops.push_back(Op1);
-      Ops.push_back(Op2);
+      SmallVector<const SCEV *, 3> Ops = {Op0, Op1, Op2};
       return getMulExpr(Ops, Flags);
     }
     const SCEV *getUDivExpr(const SCEV *LHS, const SCEV *RHS);
@@ -1091,6 +1242,18 @@ namespace llvm {
                      SmallVectorImpl<const SCEV *> &Sizes,
                      const SCEV *ElementSize);
 
+    /// Return the DataLayout associated with the module this SCEV instance is
+    /// operating on.
+    const DataLayout &getDataLayout() const {
+      return F.getParent()->getDataLayout();
+    }
+
+    const SCEVPredicate *getEqualPredicate(const SCEVUnknown *LHS,
+                                           const SCEVConstant *RHS);
+
+    /// Re-writes the SCEV according to the Predicates in \p Preds.
+    const SCEV *rewriteUsingPredicate(const SCEV *Scev, SCEVUnionPredicate &A);
+
   private:
     /// Compute the backedge taken count knowing the interval difference, the
     /// stride and presence of the equality in the comparison.
@@ -1111,6 +1274,7 @@ namespace llvm {
 
   private:
     FoldingSet<SCEV> UniqueSCEVs;
+    FoldingSet<SCEVPredicate> UniquePreds;
     BumpPtrAllocator SCEVAllocator;
 
     /// The head of a linked list of all SCEVUnknown values that have been
@@ -1163,6 +1327,59 @@ namespace llvm {
     void print(raw_ostream &OS, const Module * = nullptr) const override;
     void verifyAnalysis() const override;
   };
+
+  /// An interface layer with SCEV used to manage how we see SCEV expressions
+  /// for values in the context of existing predicates. We can add new
+  /// predicates, but we cannot remove them.
+  ///
+  /// This layer has multiple purposes:
+  ///   - provides a simple interface for SCEV versioning.
+  ///   - guarantees that the order of transformations applied on a SCEV
+  ///     expression for a single Value is consistent across two different
+  ///     getSCEV calls. This means that, for example, once we've obtained
+  ///     an AddRec expression for a certain value through expression
+  ///     rewriting, we will continue to get an AddRec expression for that
+  ///     Value.
+  ///   - lowers the number of expression rewrites.
+  class PredicatedScalarEvolution {
+  public:
+    PredicatedScalarEvolution(ScalarEvolution &SE);
+    const SCEVUnionPredicate &getUnionPredicate() const;
+    /// \brief Returns the SCEV expression of V, in the context of the current
+    /// SCEV predicate.
+    /// The order of transformations applied on the expression of V returned
+    /// by ScalarEvolution is guaranteed to be preserved, even when adding new
+    /// predicates.
+    const SCEV *getSCEV(Value *V);
+    /// \brief Adds a new predicate.
+    void addPredicate(const SCEVPredicate &Pred);
+    /// \brief Returns the ScalarEvolution analysis used.
+    ScalarEvolution *getSE() const { return &SE; }
+
+  private:
+    /// \brief Increments the version number of the predicate.
+    /// This needs to be called every time the SCEV predicate changes.
+    void updateGeneration();
+    /// Holds a SCEV and the version number of the SCEV predicate used to
+    /// perform the rewrite of the expression.
+    typedef std::pair<unsigned, const SCEV *> RewriteEntry;
+    /// Maps a SCEV to the rewrite result of that SCEV at a certain version
+    /// number. If this number doesn't match the current Generation, we will
+    /// need to do a rewrite. To preserve the transformation order of previous
+    /// rewrites, we will rewrite the previous result instead of the original
+    /// SCEV.
+    DenseMap<const SCEV *, RewriteEntry> RewriteMap;
+    /// The ScalarEvolution analysis.
+    ScalarEvolution &SE;
+    /// The SCEVPredicate that forms our context. We will rewrite all
+    /// expressions assuming that this predicate true.
+    SCEVUnionPredicate Preds;
+    /// Marks the version of the SCEV predicate used. When rewriting a SCEV
+    /// expression we mark it with the version of the predicate. We use this to
+    /// figure out if the predicate has changed from the last rewrite of the
+    /// SCEV. If so, we need to perform a new rewrite.
+    unsigned Generation;
+  };
 }
 
 #endif