Revert r254592 (virtual dtor in SCEVPredicate).
[oota-llvm.git] / include / llvm / Analysis / ScalarEvolution.h
index 62d66c246d717bedfa2bee84fc1b48b18c664209..012fafb67a402be49a0e871ccaea56724be366db 100644 (file)
@@ -48,10 +48,15 @@ namespace llvm {
   class Loop;
   class LoopInfo;
   class Operator;
-  class SCEVUnknown;
-  class SCEVAddRecExpr;
   class SCEV;
-  template<> struct FoldingSetTrait<SCEV>;
+  class SCEVAddRecExpr;
+  class SCEVConstant;
+  class SCEVExpander;
+  class SCEVPredicate;
+  class SCEVUnknown;
+
+  template <> struct FoldingSetTrait<SCEV>;
+  template <> struct FoldingSetTrait<SCEVPredicate>;
 
   /// This class represents an analyzed expression in the program.  These are
   /// opaque objects that the client is not allowed to do much with directly.
@@ -128,11 +133,9 @@ namespace llvm {
     /// stream.  This should really only be used for debugging purposes.
     void print(raw_ostream &OS) const;
 
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
     /// This method is used for debugging.
     ///
     void dump() const;
-#endif
   };
 
   // Specialize FoldingSetTrait for SCEV to avoid needing to compute
@@ -166,6 +169,149 @@ namespace llvm {
     static bool classof(const SCEV *S);
   };
 
+  /// SCEVPredicate - This class represents an assumption made using SCEV
+  /// expressions which can be checked at run-time.
+  class SCEVPredicate : public FoldingSetNode {
+    friend struct FoldingSetTrait<SCEVPredicate>;
+
+    /// A reference to an Interned FoldingSetNodeID for this node.  The
+    /// ScalarEvolution's BumpPtrAllocator holds the data.
+    FoldingSetNodeIDRef FastID;
+
+  public:
+    enum SCEVPredicateKind { P_Union, P_Equal };
+
+  protected:
+    SCEVPredicateKind Kind;
+    ~SCEVPredicate() = default;
+    SCEVPredicate(const SCEVPredicate&) = default;
+    SCEVPredicate &operator=(const SCEVPredicate&) = default;
+
+  public:
+    SCEVPredicate(const FoldingSetNodeIDRef ID, SCEVPredicateKind Kind);
+
+    SCEVPredicateKind getKind() const { return Kind; }
+
+    /// \brief Returns the estimated complexity of this predicate.
+    /// This is roughly measured in the number of run-time checks required.
+    virtual unsigned getComplexity() const { return 1; }
+
+    /// \brief Returns true if the predicate is always true. This means that no
+    /// assumptions were made and nothing needs to be checked at run-time.
+    virtual bool isAlwaysTrue() const = 0;
+
+    /// \brief Returns true if this predicate implies \p N.
+    virtual bool implies(const SCEVPredicate *N) const = 0;
+
+    /// \brief Prints a textual representation of this predicate with an
+    /// indentation of \p Depth.
+    virtual void print(raw_ostream &OS, unsigned Depth = 0) const = 0;
+
+    /// \brief Returns the SCEV to which this predicate applies, or nullptr
+    /// if this is a SCEVUnionPredicate.
+    virtual const SCEV *getExpr() const = 0;
+  };
+
+  inline raw_ostream &operator<<(raw_ostream &OS, const SCEVPredicate &P) {
+    P.print(OS);
+    return OS;
+  }
+
+  // Specialize FoldingSetTrait for SCEVPredicate to avoid needing to compute
+  // temporary FoldingSetNodeID values.
+  template <>
+  struct FoldingSetTrait<SCEVPredicate>
+      : DefaultFoldingSetTrait<SCEVPredicate> {
+
+    static void Profile(const SCEVPredicate &X, FoldingSetNodeID &ID) {
+      ID = X.FastID;
+    }
+
+    static bool Equals(const SCEVPredicate &X, const FoldingSetNodeID &ID,
+                       unsigned IDHash, FoldingSetNodeID &TempID) {
+      return ID == X.FastID;
+    }
+    static unsigned ComputeHash(const SCEVPredicate &X,
+                                FoldingSetNodeID &TempID) {
+      return X.FastID.ComputeHash();
+    }
+  };
+
+  /// SCEVEqualPredicate - This class represents an assumption that two SCEV
+  /// expressions are equal, and this can be checked at run-time. We assume
+  /// that the left hand side is a SCEVUnknown and the right hand side a
+  /// constant.
+  class SCEVEqualPredicate final : public SCEVPredicate {
+    /// We assume that LHS == RHS, where LHS is a SCEVUnknown and RHS a
+    /// constant.
+    const SCEVUnknown *LHS;
+    const SCEVConstant *RHS;
+
+  public:
+    SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEVUnknown *LHS,
+                       const SCEVConstant *RHS);
+
+    /// Implementation of the SCEVPredicate interface
+    bool implies(const SCEVPredicate *N) const override;
+    void print(raw_ostream &OS, unsigned Depth = 0) const override;
+    bool isAlwaysTrue() const override;
+    const SCEV *getExpr() const override;
+
+    /// \brief Returns the left hand side of the equality.
+    const SCEVUnknown *getLHS() const { return LHS; }
+
+    /// \brief Returns the right hand side of the equality.
+    const SCEVConstant *getRHS() const { return RHS; }
+
+    /// Methods for support type inquiry through isa, cast, and dyn_cast:
+    static inline bool classof(const SCEVPredicate *P) {
+      return P->getKind() == P_Equal;
+    }
+  };
+
+  /// SCEVUnionPredicate - This class represents a composition of other
+  /// SCEV predicates, and is the class that most clients will interact with.
+  /// This is equivalent to a logical "AND" of all the predicates in the union.
+  class SCEVUnionPredicate final : public SCEVPredicate {
+  private:
+    typedef DenseMap<const SCEV *, SmallVector<const SCEVPredicate *, 4>>
+        PredicateMap;
+
+    /// Vector with references to all predicates in this union.
+    SmallVector<const SCEVPredicate *, 16> Preds;
+    /// Maps SCEVs to predicates for quick look-ups.
+    PredicateMap SCEVToPreds;
+
+  public:
+    SCEVUnionPredicate();
+
+    const SmallVectorImpl<const SCEVPredicate *> &getPredicates() const {
+      return Preds;
+    }
+
+    /// \brief Adds a predicate to this union.
+    void add(const SCEVPredicate *N);
+
+    /// \brief Returns a reference to a vector containing all predicates
+    /// which apply to \p Expr.
+    ArrayRef<const SCEVPredicate *> getPredicatesForExpr(const SCEV *Expr);
+
+    /// Implementation of the SCEVPredicate interface
+    bool isAlwaysTrue() const override;
+    bool implies(const SCEVPredicate *N) const override;
+    void print(raw_ostream &OS, unsigned Depth) const override;
+    const SCEV *getExpr() const override;
+
+    /// \brief We estimate the complexity of a union predicate as the size
+    /// number of predicates in the union.
+    unsigned getComplexity() const override { return Preds.size(); }
+
+    /// Methods for support type inquiry through isa, cast, and dyn_cast:
+    static inline bool classof(const SCEVPredicate *P) {
+      return P->getKind() == P_Union;
+    }
+  };
+
   /// The main scalar evolution driver. Because client code (intentionally)
   /// can't do much with the SCEV objects directly, they must ask this class
   /// for services.
@@ -253,6 +399,10 @@ namespace llvm {
     /// conditions dominating the backedge of a loop.
     bool WalkingBEDominatingConds;
 
+    /// Set to true by isKnownPredicateViaSplitting when we're trying to prove a
+    /// predicate by splitting it into a set of independent predicates.
+    bool ProvingSplitPredicate;
+
     /// Information about the number of loop iterations for which a loop exit's
     /// branch condition evaluates to the not-taken path.  This is a temporary
     /// pair of exact and max expressions that are eventually summarized in
@@ -411,6 +561,19 @@ namespace llvm {
     /// Provide the special handling we need to analyze PHI SCEVs.
     const SCEV *createNodeForPHI(PHINode *PN);
 
+    /// Helper function called from createNodeForPHI.
+    const SCEV *createAddRecFromPHI(PHINode *PN);
+
+    /// Helper function called from createNodeForPHI.
+    const SCEV *createNodeFromSelectLikePHI(PHINode *PN);
+
+    /// Provide special handling for a select-like instruction (currently this
+    /// is either a select instruction or a phi node).  \p I is the instruction
+    /// being processed, and it is assumed equivalent to "Cond ? TrueVal :
+    /// FalseVal".
+    const SCEV *createNodeForSelectOrPHI(Instruction *I, Value *Cond,
+                                         Value *TrueVal, Value *FalseVal);
+
     /// Provide the special handling we need to analyze GEP SCEVs.
     const SCEV *createNodeForGEP(GEPOperator *GEP);
 
@@ -429,16 +592,16 @@ namespace llvm {
     const BackedgeTakenInfo &getBackedgeTakenInfo(const Loop *L);
 
     /// Compute the number of times the specified loop will iterate.
-    BackedgeTakenInfo ComputeBackedgeTakenCount(const Loop *L);
+    BackedgeTakenInfo computeBackedgeTakenCount(const Loop *L);
 
     /// Compute the number of times the backedge of the specified loop will
     /// execute if it exits via the specified block.
-    ExitLimit ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock);
+    ExitLimit computeExitLimit(const Loop *L, BasicBlock *ExitingBlock);
 
     /// Compute the number of times the backedge of the specified loop will
     /// execute if its exit condition were a conditional branch of ExitCond,
     /// TBB, and FBB.
-    ExitLimit ComputeExitLimitFromCond(const Loop *L,
+    ExitLimit computeExitLimitFromCond(const Loop *L,
                                        Value *ExitCond,
                                        BasicBlock *TBB,
                                        BasicBlock *FBB,
@@ -447,7 +610,7 @@ namespace llvm {
     /// Compute the number of times the backedge of the specified loop will
     /// execute if its exit condition were a conditional branch of the ICmpInst
     /// ExitCond, TBB, and FBB.
-    ExitLimit ComputeExitLimitFromICmp(const Loop *L,
+    ExitLimit computeExitLimitFromICmp(const Loop *L,
                                        ICmpInst *ExitCond,
                                        BasicBlock *TBB,
                                        BasicBlock *FBB,
@@ -457,22 +620,33 @@ namespace llvm {
     /// execute if its exit condition were a switch with a single exiting case
     /// to ExitingBB.
     ExitLimit
-    ComputeExitLimitFromSingleExitSwitch(const Loop *L, SwitchInst *Switch,
+    computeExitLimitFromSingleExitSwitch(const Loop *L, SwitchInst *Switch,
                                BasicBlock *ExitingBB, bool IsSubExpr);
 
     /// Given an exit condition of 'icmp op load X, cst', try to see if we can
     /// compute the backedge-taken count.
-    ExitLimit ComputeLoadConstantCompareExitLimit(LoadInst *LI,
+    ExitLimit computeLoadConstantCompareExitLimit(LoadInst *LI,
                                                   Constant *RHS,
                                                   const Loop *L,
                                                   ICmpInst::Predicate p);
 
+    /// Compute the exit limit of a loop that is controlled by a
+    /// "(IV >> 1) != 0" type comparison.  We cannot compute the exact trip
+    /// count in these cases (since SCEV has no way of expressing them), but we
+    /// can still sometimes compute an upper bound.
+    ///
+    /// Return an ExitLimit for a loop whose backedge is guarded by `LHS Pred
+    /// RHS`.
+    ExitLimit computeShiftCompareExitLimit(Value *LHS, Value *RHS,
+                                           const Loop *L,
+                                           ICmpInst::Predicate Pred);
+
     /// If the loop is known to execute a constant number of times (the
     /// condition evolves only from constants), try to evaluate a few iterations
     /// of the loop until we get the exit condition gets a value of ExitWhen
     /// (true or false).  If we cannot evaluate the exit count of the loop,
     /// return CouldNotCompute.
-    const SCEV *ComputeExitCountExhaustively(const Loop *L,
+    const SCEV *computeExitCountExhaustively(const Loop *L,
                                              Value *Cond,
                                              bool ExitWhen);
 
@@ -559,6 +733,28 @@ namespace llvm {
     bool isKnownPredicateWithRanges(ICmpInst::Predicate Pred,
                                     const SCEV *LHS, const SCEV *RHS);
 
+    /// Try to prove the condition described by "LHS Pred RHS" by ruling out
+    /// integer overflow.
+    ///
+    /// For instance, this will return true for "A s< (A + C)<nsw>" if C is
+    /// positive.
+    bool isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
+                                       const SCEV *LHS, const SCEV *RHS);
+
+    /// Try to split Pred LHS RHS into logical conjunctions (and's) and try to
+    /// prove them individually.
+    bool isKnownPredicateViaSplitting(ICmpInst::Predicate Pred, const SCEV *LHS,
+                                      const SCEV *RHS);
+
+    /// Try to match the Expr as "(L + R)<Flags>".
+    bool splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R,
+                        SCEV::NoWrapFlags &Flags);
+
+    /// Return true if More == (Less + C), where C is a constant.  This is
+    /// intended to be used as a cheaper substitute for full SCEV subtraction.
+    bool computeConstantDifference(const SCEV *Less, const SCEV *More,
+                                   APInt &C);
+
     /// Drop memoized information computed for S.
     void forgetMemoizedResults(const SCEV *S);
 
@@ -637,35 +833,24 @@ namespace llvm {
                            SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap);
     const SCEV *getAddExpr(const SCEV *LHS, const SCEV *RHS,
                            SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
-      SmallVector<const SCEV *, 2> Ops;
-      Ops.push_back(LHS);
-      Ops.push_back(RHS);
+      SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
       return getAddExpr(Ops, Flags);
     }
     const SCEV *getAddExpr(const SCEV *Op0, const SCEV *Op1, const SCEV *Op2,
                            SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
-      SmallVector<const SCEV *, 3> Ops;
-      Ops.push_back(Op0);
-      Ops.push_back(Op1);
-      Ops.push_back(Op2);
+      SmallVector<const SCEV *, 3> Ops = {Op0, Op1, Op2};
       return getAddExpr(Ops, Flags);
     }
     const SCEV *getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
                            SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap);
     const SCEV *getMulExpr(const SCEV *LHS, const SCEV *RHS,
-                           SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap)
-    {
-      SmallVector<const SCEV *, 2> Ops;
-      Ops.push_back(LHS);
-      Ops.push_back(RHS);
+                           SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
+      SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
       return getMulExpr(Ops, Flags);
     }
     const SCEV *getMulExpr(const SCEV *Op0, const SCEV *Op1, const SCEV *Op2,
                            SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
-      SmallVector<const SCEV *, 3> Ops;
-      Ops.push_back(Op0);
-      Ops.push_back(Op1);
-      Ops.push_back(Op2);
+      SmallVector<const SCEV *, 3> Ops = {Op0, Op1, Op2};
       return getMulExpr(Ops, Flags);
     }
     const SCEV *getUDivExpr(const SCEV *LHS, const SCEV *RHS);
@@ -1054,6 +1239,18 @@ namespace llvm {
                      SmallVectorImpl<const SCEV *> &Sizes,
                      const SCEV *ElementSize);
 
+    /// Return the DataLayout associated with the module this SCEV instance is
+    /// operating on.
+    const DataLayout &getDataLayout() const {
+      return F.getParent()->getDataLayout();
+    }
+
+    const SCEVPredicate *getEqualPredicate(const SCEVUnknown *LHS,
+                                           const SCEVConstant *RHS);
+
+    /// Re-writes the SCEV according to the Predicates in \p Preds.
+    const SCEV *rewriteUsingPredicate(const SCEV *Scev, SCEVUnionPredicate &A);
+
   private:
     /// Compute the backedge taken count knowing the interval difference, the
     /// stride and presence of the equality in the comparison.
@@ -1074,6 +1271,7 @@ namespace llvm {
 
   private:
     FoldingSet<SCEV> UniqueSCEVs;
+    FoldingSet<SCEVPredicate> UniquePreds;
     BumpPtrAllocator SCEVAllocator;
 
     /// The head of a linked list of all SCEVUnknown values that have been
@@ -1126,6 +1324,59 @@ namespace llvm {
     void print(raw_ostream &OS, const Module * = nullptr) const override;
     void verifyAnalysis() const override;
   };
+
+  /// An interface layer with SCEV used to manage how we see SCEV expressions
+  /// for values in the context of existing predicates. We can add new
+  /// predicates, but we cannot remove them.
+  ///
+  /// This layer has multiple purposes:
+  ///   - provides a simple interface for SCEV versioning.
+  ///   - guarantees that the order of transformations applied on a SCEV
+  ///     expression for a single Value is consistent across two different
+  ///     getSCEV calls. This means that, for example, once we've obtained
+  ///     an AddRec expression for a certain value through expression
+  ///     rewriting, we will continue to get an AddRec expression for that
+  ///     Value.
+  ///   - lowers the number of expression rewrites.
+  class PredicatedScalarEvolution {
+  public:
+    PredicatedScalarEvolution(ScalarEvolution &SE);
+    const SCEVUnionPredicate &getUnionPredicate() const;
+    /// \brief Returns the SCEV expression of V, in the context of the current
+    /// SCEV predicate.
+    /// The order of transformations applied on the expression of V returned
+    /// by ScalarEvolution is guaranteed to be preserved, even when adding new
+    /// predicates.
+    const SCEV *getSCEV(Value *V);
+    /// \brief Adds a new predicate.
+    void addPredicate(const SCEVPredicate &Pred);
+    /// \brief Returns the ScalarEvolution analysis used.
+    ScalarEvolution *getSE() const { return &SE; }
+
+  private:
+    /// \brief Increments the version number of the predicate.
+    /// This needs to be called every time the SCEV predicate changes.
+    void updateGeneration();
+    /// Holds a SCEV and the version number of the SCEV predicate used to
+    /// perform the rewrite of the expression.
+    typedef std::pair<unsigned, const SCEV *> RewriteEntry;
+    /// Maps a SCEV to the rewrite result of that SCEV at a certain version
+    /// number. If this number doesn't match the current Generation, we will
+    /// need to do a rewrite. To preserve the transformation order of previous
+    /// rewrites, we will rewrite the previous result instead of the original
+    /// SCEV.
+    DenseMap<const SCEV *, RewriteEntry> RewriteMap;
+    /// The ScalarEvolution analysis.
+    ScalarEvolution &SE;
+    /// The SCEVPredicate that forms our context. We will rewrite all
+    /// expressions assuming that this predicate true.
+    SCEVUnionPredicate Preds;
+    /// Marks the version of the SCEV predicate used. When rewriting a SCEV
+    /// expression we mark it with the version of the predicate. We use this to
+    /// figure out if the predicate has changed from the last rewrite of the
+    /// SCEV. If so, we need to perform a new rewrite.
+    unsigned Generation;
+  };
 }
 
 #endif