#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/FoldingSet.h"
+#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
class DataLayout;
class TargetLibraryInfo;
class LLVMContext;
- class Loop;
- class LoopInfo;
class Operator;
- class SCEVUnknown;
- class SCEVAddRecExpr;
class SCEV;
- template<> struct FoldingSetTrait<SCEV>;
+ class SCEVAddRecExpr;
+ class SCEVConstant;
+ class SCEVExpander;
+ class SCEVPredicate;
+ class SCEVUnknown;
+
+ template <> struct FoldingSetTrait<SCEV>;
+ template <> struct FoldingSetTrait<SCEVPredicate>;
/// This class represents an analyzed expression in the program. These are
/// opaque objects that the client is not allowed to do much with directly.
/// stream. This should really only be used for debugging purposes.
void print(raw_ostream &OS) const;
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// This method is used for debugging.
///
void dump() const;
-#endif
};
// Specialize FoldingSetTrait for SCEV to avoid needing to compute
static bool classof(const SCEV *S);
};
+ /// SCEVPredicate - This class represents an assumption made using SCEV
+ /// expressions which can be checked at run-time.
+ class SCEVPredicate : public FoldingSetNode {
+ friend struct FoldingSetTrait<SCEVPredicate>;
+
+ /// A reference to an Interned FoldingSetNodeID for this node. The
+ /// ScalarEvolution's BumpPtrAllocator holds the data.
+ FoldingSetNodeIDRef FastID;
+
+ public:
+ enum SCEVPredicateKind { P_Union, P_Equal };
+
+ protected:
+ SCEVPredicateKind Kind;
+ ~SCEVPredicate() = default;
+ SCEVPredicate(const SCEVPredicate&) = default;
+ SCEVPredicate &operator=(const SCEVPredicate&) = default;
+
+ public:
+ SCEVPredicate(const FoldingSetNodeIDRef ID, SCEVPredicateKind Kind);
+
+ SCEVPredicateKind getKind() const { return Kind; }
+
+ /// \brief Returns the estimated complexity of this predicate.
+ /// This is roughly measured in the number of run-time checks required.
+ virtual unsigned getComplexity() const { return 1; }
+
+ /// \brief Returns true if the predicate is always true. This means that no
+ /// assumptions were made and nothing needs to be checked at run-time.
+ virtual bool isAlwaysTrue() const = 0;
+
+ /// \brief Returns true if this predicate implies \p N.
+ virtual bool implies(const SCEVPredicate *N) const = 0;
+
+ /// \brief Prints a textual representation of this predicate with an
+ /// indentation of \p Depth.
+ virtual void print(raw_ostream &OS, unsigned Depth = 0) const = 0;
+
+ /// \brief Returns the SCEV to which this predicate applies, or nullptr
+ /// if this is a SCEVUnionPredicate.
+ virtual const SCEV *getExpr() const = 0;
+ };
+
+ inline raw_ostream &operator<<(raw_ostream &OS, const SCEVPredicate &P) {
+ P.print(OS);
+ return OS;
+ }
+
+ // Specialize FoldingSetTrait for SCEVPredicate to avoid needing to compute
+ // temporary FoldingSetNodeID values.
+ template <>
+ struct FoldingSetTrait<SCEVPredicate>
+ : DefaultFoldingSetTrait<SCEVPredicate> {
+
+ static void Profile(const SCEVPredicate &X, FoldingSetNodeID &ID) {
+ ID = X.FastID;
+ }
+
+ static bool Equals(const SCEVPredicate &X, const FoldingSetNodeID &ID,
+ unsigned IDHash, FoldingSetNodeID &TempID) {
+ return ID == X.FastID;
+ }
+ static unsigned ComputeHash(const SCEVPredicate &X,
+ FoldingSetNodeID &TempID) {
+ return X.FastID.ComputeHash();
+ }
+ };
+
+ /// SCEVEqualPredicate - This class represents an assumption that two SCEV
+ /// expressions are equal, and this can be checked at run-time. We assume
+ /// that the left hand side is a SCEVUnknown and the right hand side a
+ /// constant.
+ class SCEVEqualPredicate final : public SCEVPredicate {
+ /// We assume that LHS == RHS, where LHS is a SCEVUnknown and RHS a
+ /// constant.
+ const SCEVUnknown *LHS;
+ const SCEVConstant *RHS;
+
+ public:
+ SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEVUnknown *LHS,
+ const SCEVConstant *RHS);
+
+ /// Implementation of the SCEVPredicate interface
+ bool implies(const SCEVPredicate *N) const override;
+ void print(raw_ostream &OS, unsigned Depth = 0) const override;
+ bool isAlwaysTrue() const override;
+ const SCEV *getExpr() const override;
+
+ /// \brief Returns the left hand side of the equality.
+ const SCEVUnknown *getLHS() const { return LHS; }
+
+ /// \brief Returns the right hand side of the equality.
+ const SCEVConstant *getRHS() const { return RHS; }
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast:
+ static inline bool classof(const SCEVPredicate *P) {
+ return P->getKind() == P_Equal;
+ }
+ };
+
+ /// SCEVUnionPredicate - This class represents a composition of other
+ /// SCEV predicates, and is the class that most clients will interact with.
+ /// This is equivalent to a logical "AND" of all the predicates in the union.
+ class SCEVUnionPredicate final : public SCEVPredicate {
+ private:
+ typedef DenseMap<const SCEV *, SmallVector<const SCEVPredicate *, 4>>
+ PredicateMap;
+
+ /// Vector with references to all predicates in this union.
+ SmallVector<const SCEVPredicate *, 16> Preds;
+ /// Maps SCEVs to predicates for quick look-ups.
+ PredicateMap SCEVToPreds;
+
+ public:
+ SCEVUnionPredicate();
+
+ const SmallVectorImpl<const SCEVPredicate *> &getPredicates() const {
+ return Preds;
+ }
+
+ /// \brief Adds a predicate to this union.
+ void add(const SCEVPredicate *N);
+
+ /// \brief Returns a reference to a vector containing all predicates
+ /// which apply to \p Expr.
+ ArrayRef<const SCEVPredicate *> getPredicatesForExpr(const SCEV *Expr);
+
+ /// Implementation of the SCEVPredicate interface
+ bool isAlwaysTrue() const override;
+ bool implies(const SCEVPredicate *N) const override;
+ void print(raw_ostream &OS, unsigned Depth) const override;
+ const SCEV *getExpr() const override;
+
+ /// \brief We estimate the complexity of a union predicate as the size
+ /// number of predicates in the union.
+ unsigned getComplexity() const override { return Preds.size(); }
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast:
+ static inline bool classof(const SCEVPredicate *P) {
+ return P->getKind() == P_Union;
+ }
+ };
+
/// The main scalar evolution driver. Because client code (intentionally)
/// can't do much with the SCEV objects directly, they must ask this class
/// for services.
const BackedgeTakenInfo &getBackedgeTakenInfo(const Loop *L);
/// Compute the number of times the specified loop will iterate.
- BackedgeTakenInfo ComputeBackedgeTakenCount(const Loop *L);
+ BackedgeTakenInfo computeBackedgeTakenCount(const Loop *L);
/// Compute the number of times the backedge of the specified loop will
/// execute if it exits via the specified block.
- ExitLimit ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock);
+ ExitLimit computeExitLimit(const Loop *L, BasicBlock *ExitingBlock);
/// Compute the number of times the backedge of the specified loop will
/// execute if its exit condition were a conditional branch of ExitCond,
/// TBB, and FBB.
- ExitLimit ComputeExitLimitFromCond(const Loop *L,
+ ExitLimit computeExitLimitFromCond(const Loop *L,
Value *ExitCond,
BasicBlock *TBB,
BasicBlock *FBB,
/// Compute the number of times the backedge of the specified loop will
/// execute if its exit condition were a conditional branch of the ICmpInst
/// ExitCond, TBB, and FBB.
- ExitLimit ComputeExitLimitFromICmp(const Loop *L,
+ ExitLimit computeExitLimitFromICmp(const Loop *L,
ICmpInst *ExitCond,
BasicBlock *TBB,
BasicBlock *FBB,
/// execute if its exit condition were a switch with a single exiting case
/// to ExitingBB.
ExitLimit
- ComputeExitLimitFromSingleExitSwitch(const Loop *L, SwitchInst *Switch,
+ computeExitLimitFromSingleExitSwitch(const Loop *L, SwitchInst *Switch,
BasicBlock *ExitingBB, bool IsSubExpr);
/// Given an exit condition of 'icmp op load X, cst', try to see if we can
/// compute the backedge-taken count.
- ExitLimit ComputeLoadConstantCompareExitLimit(LoadInst *LI,
+ ExitLimit computeLoadConstantCompareExitLimit(LoadInst *LI,
Constant *RHS,
const Loop *L,
ICmpInst::Predicate p);
+ /// Compute the exit limit of a loop that is controlled by a
+ /// "(IV >> 1) != 0" type comparison. We cannot compute the exact trip
+ /// count in these cases (since SCEV has no way of expressing them), but we
+ /// can still sometimes compute an upper bound.
+ ///
+ /// Return an ExitLimit for a loop whose backedge is guarded by `LHS Pred
+ /// RHS`.
+ ExitLimit computeShiftCompareExitLimit(Value *LHS, Value *RHS,
+ const Loop *L,
+ ICmpInst::Predicate Pred);
+
/// If the loop is known to execute a constant number of times (the
/// condition evolves only from constants), try to evaluate a few iterations
/// of the loop until we get the exit condition gets a value of ExitWhen
/// (true or false). If we cannot evaluate the exit count of the loop,
/// return CouldNotCompute.
- const SCEV *ComputeExitCountExhaustively(const Loop *L,
+ const SCEV *computeExitCountExhaustively(const Loop *L,
Value *Cond,
bool ExitWhen);
bool isKnownPredicateWithRanges(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS);
+ /// Try to prove the condition described by "LHS Pred RHS" by ruling out
+ /// integer overflow.
+ ///
+ /// For instance, this will return true for "A s< (A + C)<nsw>" if C is
+ /// positive.
+ bool isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS);
+
/// Try to split Pred LHS RHS into logical conjunctions (and's) and try to
/// prove them individually.
bool isKnownPredicateViaSplitting(ICmpInst::Predicate Pred, const SCEV *LHS,
const SCEV *RHS);
+ /// Try to match the Expr as "(L + R)<Flags>".
+ bool splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R,
+ SCEV::NoWrapFlags &Flags);
+
+ /// Return true if More == (Less + C), where C is a constant. This is
+ /// intended to be used as a cheaper substitute for full SCEV subtraction.
+ bool computeConstantDifference(const SCEV *Less, const SCEV *More,
+ APInt &C);
+
/// Drop memoized information computed for S.
void forgetMemoizedResults(const SCEV *S);
SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap);
const SCEV *getAddExpr(const SCEV *LHS, const SCEV *RHS,
SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
- SmallVector<const SCEV *, 2> Ops;
- Ops.push_back(LHS);
- Ops.push_back(RHS);
+ SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
return getAddExpr(Ops, Flags);
}
const SCEV *getAddExpr(const SCEV *Op0, const SCEV *Op1, const SCEV *Op2,
SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
- SmallVector<const SCEV *, 3> Ops;
- Ops.push_back(Op0);
- Ops.push_back(Op1);
- Ops.push_back(Op2);
+ SmallVector<const SCEV *, 3> Ops = {Op0, Op1, Op2};
return getAddExpr(Ops, Flags);
}
const SCEV *getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap);
const SCEV *getMulExpr(const SCEV *LHS, const SCEV *RHS,
- SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap)
- {
- SmallVector<const SCEV *, 2> Ops;
- Ops.push_back(LHS);
- Ops.push_back(RHS);
+ SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
+ SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
return getMulExpr(Ops, Flags);
}
const SCEV *getMulExpr(const SCEV *Op0, const SCEV *Op1, const SCEV *Op2,
SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
- SmallVector<const SCEV *, 3> Ops;
- Ops.push_back(Op0);
- Ops.push_back(Op1);
- Ops.push_back(Op2);
+ SmallVector<const SCEV *, 3> Ops = {Op0, Op1, Op2};
return getMulExpr(Ops, Flags);
}
const SCEV *getUDivExpr(const SCEV *LHS, const SCEV *RHS);
SmallVectorImpl<const SCEV *> &Sizes,
const SCEV *ElementSize);
+ /// Return the DataLayout associated with the module this SCEV instance is
+ /// operating on.
+ const DataLayout &getDataLayout() const {
+ return F.getParent()->getDataLayout();
+ }
+
+ const SCEVPredicate *getEqualPredicate(const SCEVUnknown *LHS,
+ const SCEVConstant *RHS);
+
+ /// Re-writes the SCEV according to the Predicates in \p Preds.
+ const SCEV *rewriteUsingPredicate(const SCEV *Scev, SCEVUnionPredicate &A);
+
private:
/// Compute the backedge taken count knowing the interval difference, the
/// stride and presence of the equality in the comparison.
private:
FoldingSet<SCEV> UniqueSCEVs;
+ FoldingSet<SCEVPredicate> UniquePreds;
BumpPtrAllocator SCEVAllocator;
/// The head of a linked list of all SCEVUnknown values that have been
void print(raw_ostream &OS, const Module * = nullptr) const override;
void verifyAnalysis() const override;
};
+
+ /// An interface layer with SCEV used to manage how we see SCEV expressions
+ /// for values in the context of existing predicates. We can add new
+ /// predicates, but we cannot remove them.
+ ///
+ /// This layer has multiple purposes:
+ /// - provides a simple interface for SCEV versioning.
+ /// - guarantees that the order of transformations applied on a SCEV
+ /// expression for a single Value is consistent across two different
+ /// getSCEV calls. This means that, for example, once we've obtained
+ /// an AddRec expression for a certain value through expression
+ /// rewriting, we will continue to get an AddRec expression for that
+ /// Value.
+ /// - lowers the number of expression rewrites.
+ class PredicatedScalarEvolution {
+ public:
+ PredicatedScalarEvolution(ScalarEvolution &SE);
+ const SCEVUnionPredicate &getUnionPredicate() const;
+ /// \brief Returns the SCEV expression of V, in the context of the current
+ /// SCEV predicate.
+ /// The order of transformations applied on the expression of V returned
+ /// by ScalarEvolution is guaranteed to be preserved, even when adding new
+ /// predicates.
+ const SCEV *getSCEV(Value *V);
+ /// \brief Adds a new predicate.
+ void addPredicate(const SCEVPredicate &Pred);
+ /// \brief Returns the ScalarEvolution analysis used.
+ ScalarEvolution *getSE() const { return &SE; }
+
+ private:
+ /// \brief Increments the version number of the predicate.
+ /// This needs to be called every time the SCEV predicate changes.
+ void updateGeneration();
+ /// Holds a SCEV and the version number of the SCEV predicate used to
+ /// perform the rewrite of the expression.
+ typedef std::pair<unsigned, const SCEV *> RewriteEntry;
+ /// Maps a SCEV to the rewrite result of that SCEV at a certain version
+ /// number. If this number doesn't match the current Generation, we will
+ /// need to do a rewrite. To preserve the transformation order of previous
+ /// rewrites, we will rewrite the previous result instead of the original
+ /// SCEV.
+ DenseMap<const SCEV *, RewriteEntry> RewriteMap;
+ /// The ScalarEvolution analysis.
+ ScalarEvolution &SE;
+ /// The SCEVPredicate that forms our context. We will rewrite all
+ /// expressions assuming that this predicate true.
+ SCEVUnionPredicate Preds;
+ /// Marks the version of the SCEV predicate used. When rewriting a SCEV
+ /// expression we mark it with the version of the predicate. We use this to
+ /// figure out if the predicate has changed from the last rewrite of the
+ /// SCEV. If so, we need to perform a new rewrite.
+ unsigned Generation;
+ };
}
#endif