Re-commit r255115, with the PredicatedScalarEvolution class moved to
authorSilviu Baranga <silviu.baranga@arm.com>
Wed, 9 Dec 2015 16:06:28 +0000 (16:06 +0000)
committerSilviu Baranga <silviu.baranga@arm.com>
Wed, 9 Dec 2015 16:06:28 +0000 (16:06 +0000)
ScalarEvolution.h, in order to avoid cyclic dependencies between the Transform
and Analysis modules:

[LV][LAA] Add a layer over SCEV to apply run-time checked knowledge on SCEV expressions

Summary:
This change creates a layer over ScalarEvolution for LAA and LV, and centralizes the
usage of SCEV predicates. The SCEVPredicatedLayer takes the statically deduced knowledge
by ScalarEvolution and applies the knowledge from the SCEV predicates. The end goal is
that both LAA and LV should use this interface everywhere.

This also solves a problem involving the result of SCEV expression rewritting when
the predicate changes. Suppose we have the expression (sext {a,+,b}) and two predicates
  P1: {a,+,b} has nsw
  P2: b = 1.

Applying P1 and then P2 gives us {a,+,1}, while applying P2 and the P1 gives us
sext({a,+,1}) (the AddRec expression was changed by P2 so P1 no longer applies).
The SCEVPredicatedLayer maintains the order of transformations by feeding back
the results of previous transformations into new transformations, and therefore
avoiding this issue.

The SCEVPredicatedLayer maintains a cache to remember the results of previous
SCEV rewritting results. This also has the benefit of reducing the overall number
of expression rewrites.

Reviewers: mzolotukhin, anemet

Subscribers: jmolloy, sanjoy, llvm-commits

Differential Revision: http://reviews.llvm.org/D14296

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@255122 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/LoopAccessAnalysis.h
include/llvm/Analysis/ScalarEvolution.h
lib/Analysis/LoopAccessAnalysis.cpp
lib/Analysis/ScalarEvolution.cpp
lib/Transforms/Scalar/LoopDistribute.cpp
lib/Transforms/Scalar/LoopLoadElimination.cpp
lib/Transforms/Utils/LoopVersioning.cpp
lib/Transforms/Vectorize/LoopVectorize.cpp

index 77d412a4f9271767201932f0d74e65c6f50e70c4..871d35e99b748617ef8e39d8f56198653e1d56cc 100644 (file)
@@ -193,11 +193,10 @@ public:
                const SmallVectorImpl<Instruction *> &Instrs) const;
   };
 
                const SmallVectorImpl<Instruction *> &Instrs) const;
   };
 
-  MemoryDepChecker(ScalarEvolution *Se, const Loop *L,
-                   SCEVUnionPredicate &Preds)
-      : SE(Se), InnermostLoop(L), AccessIdx(0),
+  MemoryDepChecker(PredicatedScalarEvolution &PSE, const Loop *L)
+      : PSE(PSE), InnermostLoop(L), AccessIdx(0),
         ShouldRetryWithRuntimeCheck(false), SafeForVectorization(true),
         ShouldRetryWithRuntimeCheck(false), SafeForVectorization(true),
-        RecordDependences(true), Preds(Preds) {}
+        RecordDependences(true) {}
 
   /// \brief Register the location (instructions are given increasing numbers)
   /// of a write access.
 
   /// \brief Register the location (instructions are given increasing numbers)
   /// of a write access.
@@ -266,7 +265,13 @@ public:
                                                          bool isWrite) const;
 
 private:
                                                          bool isWrite) const;
 
 private:
-  ScalarEvolution *SE;
+  /// A wrapper around ScalarEvolution, used to add runtime SCEV checks, and
+  /// applies dynamic knowledge to simplify SCEV expressions and convert them
+  /// to a more usable form. We need this in case assumptions about SCEV
+  /// expressions need to be made in order to avoid unknown dependences. For
+  /// example we might assume a unit stride for a pointer in order to prove
+  /// that a memory access is strided and doesn't wrap.
+  PredicatedScalarEvolution &PSE;
   const Loop *InnermostLoop;
 
   /// \brief Maps access locations (ptr, read/write) to program order.
   const Loop *InnermostLoop;
 
   /// \brief Maps access locations (ptr, read/write) to program order.
@@ -317,15 +322,6 @@ private:
   /// \brief Check whether the data dependence could prevent store-load
   /// forwarding.
   bool couldPreventStoreLoadForward(unsigned Distance, unsigned TypeByteSize);
   /// \brief Check whether the data dependence could prevent store-load
   /// forwarding.
   bool couldPreventStoreLoadForward(unsigned Distance, unsigned TypeByteSize);
-
-  /// The SCEV predicate containing all the SCEV-related assumptions.
-  /// The dependence checker needs this in order to convert SCEVs of pointers
-  /// to more accurate expressions in the context of existing assumptions.
-  /// We also need this in case assumptions about SCEV expressions need to
-  /// be made in order to avoid unknown dependences. For example we might
-  /// assume a unit stride for a pointer in order to prove that a memory access
-  /// is strided and doesn't wrap.
-  SCEVUnionPredicate &Preds;
 };
 
 /// \brief Holds information about the memory runtime legality checks to verify
 };
 
 /// \brief Holds information about the memory runtime legality checks to verify
@@ -373,7 +369,7 @@ public:
   /// and change \p Preds.
   void insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId,
               unsigned ASId, const ValueToValueMap &Strides,
   /// and change \p Preds.
   void insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId,
               unsigned ASId, const ValueToValueMap &Strides,
-              SCEVUnionPredicate &Preds);
+              PredicatedScalarEvolution &PSE);
 
   /// \brief No run-time memory checking is necessary.
   bool empty() const { return Pointers.empty(); }
 
   /// \brief No run-time memory checking is necessary.
   bool empty() const { return Pointers.empty(); }
@@ -508,8 +504,8 @@ private:
 /// ScalarEvolution, we will generate run-time checks by emitting a
 /// SCEVUnionPredicate.
 ///
 /// ScalarEvolution, we will generate run-time checks by emitting a
 /// SCEVUnionPredicate.
 ///
-/// Checks for both memory dependences and SCEV predicates must be emitted in
-/// order for the results of this analysis to be valid.
+/// Checks for both memory dependences and the SCEV predicates contained in the
+/// PSE must be emitted in order for the results of this analysis to be valid.
 class LoopAccessInfo {
 public:
   LoopAccessInfo(Loop *L, ScalarEvolution *SE, const DataLayout &DL,
 class LoopAccessInfo {
 public:
   LoopAccessInfo(Loop *L, ScalarEvolution *SE, const DataLayout &DL,
@@ -591,14 +587,12 @@ public:
     return StoreToLoopInvariantAddress;
   }
 
     return StoreToLoopInvariantAddress;
   }
 
-  /// The SCEV predicate contains all the SCEV-related assumptions.
-  /// The is used to keep track of the minimal set of assumptions on SCEV
-  /// expressions that the analysis needs to make in order to return a
-  /// meaningful result. All SCEV expressions during the analysis should be
-  /// re-written (and therefore simplified) according to Preds.
+  /// Used to add runtime SCEV checks. Simplifies SCEV expressions and converts
+  /// them to a more usable form.  All SCEV expressions during the analysis
+  /// should be re-written (and therefore simplified) according to PSE.
   /// A user of LoopAccessAnalysis will need to emit the runtime checks
   /// associated with this predicate.
   /// A user of LoopAccessAnalysis will need to emit the runtime checks
   /// associated with this predicate.
-  SCEVUnionPredicate Preds;
+  PredicatedScalarEvolution PSE;
 
 private:
   /// \brief Analyze the loop.  Substitute symbolic strides using Strides.
 
 private:
   /// \brief Analyze the loop.  Substitute symbolic strides using Strides.
@@ -619,7 +613,6 @@ private:
   MemoryDepChecker DepChecker;
 
   Loop *TheLoop;
   MemoryDepChecker DepChecker;
 
   Loop *TheLoop;
-  ScalarEvolution *SE;
   const DataLayout &DL;
   const TargetLibraryInfo *TLI;
   AliasAnalysis *AA;
   const DataLayout &DL;
   const TargetLibraryInfo *TLI;
   AliasAnalysis *AA;
@@ -654,18 +647,17 @@ Value *stripIntegerCast(Value *V);
 /// If \p OrigPtr is not null, use it to look up the stride value instead of \p
 /// Ptr.  \p PtrToStride provides the mapping between the pointer value and its
 /// stride as collected by LoopVectorizationLegality::collectStridedAccess.
 /// If \p OrigPtr is not null, use it to look up the stride value instead of \p
 /// Ptr.  \p PtrToStride provides the mapping between the pointer value and its
 /// stride as collected by LoopVectorizationLegality::collectStridedAccess.
-const SCEV *replaceSymbolicStrideSCEV(ScalarEvolution *SE,
+const SCEV *replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
                                       const ValueToValueMap &PtrToStride,
                                       const ValueToValueMap &PtrToStride,
-                                      SCEVUnionPredicate &Preds, Value *Ptr,
-                                      Value *OrigPtr = nullptr);
+                                      Value *Ptr, Value *OrigPtr = nullptr);
 
 /// \brief Check the stride of the pointer and ensure that it does not wrap in
 /// the address space, assuming \p Preds is true.
 ///
 /// If necessary this method will version the stride of the pointer according
 /// to \p PtrToStride and therefore add a new predicate to \p Preds.
 
 /// \brief Check the stride of the pointer and ensure that it does not wrap in
 /// the address space, assuming \p Preds is true.
 ///
 /// If necessary this method will version the stride of the pointer according
 /// to \p PtrToStride and therefore add a new predicate to \p Preds.
-int isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp,
-                 const ValueToValueMap &StridesMap, SCEVUnionPredicate &Preds);
+int isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr, const Loop *Lp,
+                 const ValueToValueMap &StridesMap);
 
 /// \brief This analysis provides dependence information for the memory accesses
 /// of a loop.
 
 /// \brief This analysis provides dependence information for the memory accesses
 /// of a loop.
index f674cc7ee56f57ccaeef2a5cd2d7bc92ee897703..15565daf6d3b447310b6a3c5028270031cb2af9e 100644 (file)
@@ -1324,6 +1324,59 @@ namespace llvm {
     void print(raw_ostream &OS, const Module * = nullptr) const override;
     void verifyAnalysis() const override;
   };
     void print(raw_ostream &OS, const Module * = nullptr) const override;
     void verifyAnalysis() const override;
   };
+
+  /// An interface layer with SCEV used to manage how we see SCEV expressions
+  /// for values in the context of existing predicates. We can add new
+  /// predicates, but we cannot remove them.
+  ///
+  /// This layer has multiple purposes:
+  ///   - provides a simple interface for SCEV versioning.
+  ///   - guarantees that the order of transformations applied on a SCEV
+  ///     expression for a single Value is consistent across two different
+  ///     getSCEV calls. This means that, for example, once we've obtained
+  ///     an AddRec expression for a certain value through expression
+  ///     rewriting, we will continue to get an AddRec expression for that
+  ///     Value.
+  ///   - lowers the number of expression rewrites.
+  class PredicatedScalarEvolution {
+  public:
+    PredicatedScalarEvolution(ScalarEvolution &SE);
+    const SCEVUnionPredicate &getUnionPredicate() const;
+    /// \brief Returns the SCEV expression of V, in the context of the current
+    /// SCEV predicate.
+    /// The order of transformations applied on the expression of V returned
+    /// by ScalarEvolution is guaranteed to be preserved, even when adding new
+    /// predicates.
+    const SCEV *getSCEV(Value *V);
+    /// \brief Adds a new predicate.
+    void addPredicate(const SCEVPredicate &Pred);
+    /// \brief Returns the ScalarEvolution analysis used.
+    ScalarEvolution *getSE() const { return &SE; }
+
+  private:
+    /// \brief Increments the version number of the predicate.
+    /// This needs to be called every time the SCEV predicate changes.
+    void updateGeneration();
+    /// Holds a SCEV and the version number of the SCEV predicate used to
+    /// perform the rewrite of the expression.
+    typedef std::pair<unsigned, const SCEV *> RewriteEntry;
+    /// Maps a SCEV to the rewrite result of that SCEV at a certain version
+    /// number. If this number doesn't match the current Generation, we will
+    /// need to do a rewrite. To preserve the transformation order of previous
+    /// rewrites, we will rewrite the previous result instead of the original
+    /// SCEV.
+    DenseMap<const SCEV *, RewriteEntry> RewriteMap;
+    /// The ScalarEvolution analysis.
+    ScalarEvolution &SE;
+    /// The SCEVPredicate that forms our context. We will rewrite all
+    /// expressions assuming that this predicate true.
+    SCEVUnionPredicate Preds;
+    /// Marks the version of the SCEV predicate used. When rewriting a SCEV
+    /// expression we mark it with the version of the predicate. We use this to
+    /// figure out if the predicate has changed from the last rewrite of the
+    /// SCEV. If so, we need to perform a new rewrite.
+    unsigned Generation;
+  };
 }
 
 #endif
 }
 
 #endif
index b2670bf48dd84fb9d7dfa468a0e0893bdb5e9494..ce6a5ab5656d96d093dd6e5cc1450576a877c3a8 100644 (file)
@@ -87,11 +87,10 @@ Value *llvm::stripIntegerCast(Value *V) {
   return V;
 }
 
   return V;
 }
 
-const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE,
+const SCEV *llvm::replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
                                             const ValueToValueMap &PtrToStride,
                                             const ValueToValueMap &PtrToStride,
-                                            SCEVUnionPredicate &Preds,
                                             Value *Ptr, Value *OrigPtr) {
                                             Value *Ptr, Value *OrigPtr) {
-  const SCEV *OrigSCEV = SE->getSCEV(Ptr);
+  const SCEV *OrigSCEV = PSE.getSCEV(Ptr);
 
   // If there is an entry in the map return the SCEV of the pointer with the
   // symbolic stride replaced by one.
 
   // If there is an entry in the map return the SCEV of the pointer with the
   // symbolic stride replaced by one.
@@ -108,16 +107,17 @@ const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE,
     ValueToValueMap RewriteMap;
     RewriteMap[StrideVal] = One;
 
     ValueToValueMap RewriteMap;
     RewriteMap[StrideVal] = One;
 
+    ScalarEvolution *SE = PSE.getSE();
     const auto *U = cast<SCEVUnknown>(SE->getSCEV(StrideVal));
     const auto *CT =
         static_cast<const SCEVConstant *>(SE->getOne(StrideVal->getType()));
 
     const auto *U = cast<SCEVUnknown>(SE->getSCEV(StrideVal));
     const auto *CT =
         static_cast<const SCEVConstant *>(SE->getOne(StrideVal->getType()));
 
-    Preds.add(SE->getEqualPredicate(U, CT));
+    PSE.addPredicate(*SE->getEqualPredicate(U, CT));
+    auto *Expr = PSE.getSCEV(Ptr);
 
 
-    const SCEV *ByOne = SE->rewriteUsingPredicate(OrigSCEV, Preds);
-    DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV << " by: " << *ByOne
+    DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV << " by: " << *Expr
                  << "\n");
                  << "\n");
-    return ByOne;
+    return Expr;
   }
 
   // Otherwise, just return the SCEV of the original pointer.
   }
 
   // Otherwise, just return the SCEV of the original pointer.
@@ -127,11 +127,12 @@ const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE,
 void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr,
                                     unsigned DepSetId, unsigned ASId,
                                     const ValueToValueMap &Strides,
 void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr,
                                     unsigned DepSetId, unsigned ASId,
                                     const ValueToValueMap &Strides,
-                                    SCEVUnionPredicate &Preds) {
+                                    PredicatedScalarEvolution &PSE) {
   // Get the stride replaced scev.
   // Get the stride replaced scev.
-  const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr);
+  const SCEV *Sc = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
   assert(AR && "Invalid addrec expression");
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
   assert(AR && "Invalid addrec expression");
+  ScalarEvolution *SE = PSE.getSE();
   const SCEV *Ex = SE->getBackedgeTakenCount(Lp);
 
   const SCEV *ScStart = AR->getStart();
   const SCEV *Ex = SE->getBackedgeTakenCount(Lp);
 
   const SCEV *ScStart = AR->getStart();
@@ -423,9 +424,10 @@ public:
   typedef SmallPtrSet<MemAccessInfo, 8> MemAccessInfoSet;
 
   AccessAnalysis(const DataLayout &Dl, AliasAnalysis *AA, LoopInfo *LI,
   typedef SmallPtrSet<MemAccessInfo, 8> MemAccessInfoSet;
 
   AccessAnalysis(const DataLayout &Dl, AliasAnalysis *AA, LoopInfo *LI,
-                 MemoryDepChecker::DepCandidates &DA, SCEVUnionPredicate &Preds)
+                 MemoryDepChecker::DepCandidates &DA,
+                 PredicatedScalarEvolution &PSE)
       : DL(Dl), AST(*AA), LI(LI), DepCands(DA), IsRTCheckAnalysisNeeded(false),
       : DL(Dl), AST(*AA), LI(LI), DepCands(DA), IsRTCheckAnalysisNeeded(false),
-        Preds(Preds) {}
+        PSE(PSE) {}
 
   /// \brief Register a load  and whether it is only read from.
   void addLoad(MemoryLocation &Loc, bool IsReadOnly) {
 
   /// \brief Register a load  and whether it is only read from.
   void addLoad(MemoryLocation &Loc, bool IsReadOnly) {
@@ -512,16 +514,16 @@ private:
   bool IsRTCheckAnalysisNeeded;
 
   /// The SCEV predicate containing all the SCEV-related assumptions.
   bool IsRTCheckAnalysisNeeded;
 
   /// The SCEV predicate containing all the SCEV-related assumptions.
-  SCEVUnionPredicate &Preds;
+  PredicatedScalarEvolution &PSE;
 };
 
 } // end anonymous namespace
 
 /// \brief Check whether a pointer can participate in a runtime bounds check.
 };
 
 } // end anonymous namespace
 
 /// \brief Check whether a pointer can participate in a runtime bounds check.
-static bool hasComputableBounds(ScalarEvolution *SE,
+static bool hasComputableBounds(PredicatedScalarEvolution &PSE,
                                 const ValueToValueMap &Strides, Value *Ptr,
                                 const ValueToValueMap &Strides, Value *Ptr,
-                                Loop *L, SCEVUnionPredicate &Preds) {
-  const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr);
+                                Loop *L) {
+  const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
   if (!AR)
     return false;
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
   if (!AR)
     return false;
@@ -564,11 +566,11 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,
       else
         ++NumReadPtrChecks;
 
       else
         ++NumReadPtrChecks;
 
-      if (hasComputableBounds(SE, StridesMap, Ptr, TheLoop, Preds) &&
+      if (hasComputableBounds(PSE, StridesMap, Ptr, TheLoop) &&
           // When we run after a failing dependency check we have to make sure
           // we don't have wrapping pointers.
           (!ShouldCheckStride ||
           // When we run after a failing dependency check we have to make sure
           // we don't have wrapping pointers.
           (!ShouldCheckStride ||
-           isStridedPtr(SE, Ptr, TheLoop, StridesMap, Preds) == 1)) {
+           isStridedPtr(PSE, Ptr, TheLoop, StridesMap) == 1)) {
         // The id of the dependence set.
         unsigned DepId;
 
         // The id of the dependence set.
         unsigned DepId;
 
@@ -582,7 +584,7 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,
           // Each access has its own dependence set.
           DepId = RunningDepId++;
 
           // Each access has its own dependence set.
           DepId = RunningDepId++;
 
-        RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, Preds);
+        RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, PSE);
 
         DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
       } else {
 
         DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
       } else {
@@ -817,9 +819,8 @@ static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR,
 }
 
 /// \brief Check whether the access through \p Ptr has a constant stride.
 }
 
 /// \brief Check whether the access through \p Ptr has a constant stride.
-int llvm::isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp,
-                       const ValueToValueMap &StridesMap,
-                       SCEVUnionPredicate &Preds) {
+int llvm::isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr,
+                       const Loop *Lp, const ValueToValueMap &StridesMap) {
   Type *Ty = Ptr->getType();
   assert(Ty->isPointerTy() && "Unexpected non-ptr");
 
   Type *Ty = Ptr->getType();
   assert(Ty->isPointerTy() && "Unexpected non-ptr");
 
@@ -831,7 +832,7 @@ int llvm::isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp,
     return 0;
   }
 
     return 0;
   }
 
-  const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Preds, Ptr);
+  const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr);
 
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
   if (!AR) {
 
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
   if (!AR) {
@@ -854,16 +855,16 @@ int llvm::isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp,
   // to access the pointer value "0" which is undefined behavior in address
   // space 0, therefore we can also vectorize this case.
   bool IsInBoundsGEP = isInBoundsGep(Ptr);
   // to access the pointer value "0" which is undefined behavior in address
   // space 0, therefore we can also vectorize this case.
   bool IsInBoundsGEP = isInBoundsGep(Ptr);
-  bool IsNoWrapAddRec = isNoWrapAddRec(Ptr, AR, SE, Lp);
+  bool IsNoWrapAddRec = isNoWrapAddRec(Ptr, AR, PSE.getSE(), Lp);
   bool IsInAddressSpaceZero = PtrTy->getAddressSpace() == 0;
   if (!IsNoWrapAddRec && !IsInBoundsGEP && !IsInAddressSpaceZero) {
     DEBUG(dbgs() << "LAA: Bad stride - Pointer may wrap in the address space "
   bool IsInAddressSpaceZero = PtrTy->getAddressSpace() == 0;
   if (!IsNoWrapAddRec && !IsInBoundsGEP && !IsInAddressSpaceZero) {
     DEBUG(dbgs() << "LAA: Bad stride - Pointer may wrap in the address space "
-          << *Ptr << " SCEV: " << *PtrScev << "\n");
+                 << *Ptr << " SCEV: " << *PtrScev << "\n");
     return 0;
   }
 
   // Check the step is constant.
     return 0;
   }
 
   // Check the step is constant.
-  const SCEV *Step = AR->getStepRecurrence(*SE);
+  const SCEV *Step = AR->getStepRecurrence(*PSE.getSE());
 
   // Calculate the pointer stride and check if it is constant.
   const SCEVConstant *C = dyn_cast<SCEVConstant>(Step);
 
   // Calculate the pointer stride and check if it is constant.
   const SCEVConstant *C = dyn_cast<SCEVConstant>(Step);
@@ -1046,11 +1047,11 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
       BPtr->getType()->getPointerAddressSpace())
     return Dependence::Unknown;
 
       BPtr->getType()->getPointerAddressSpace())
     return Dependence::Unknown;
 
-  const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, APtr);
-  const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, BPtr);
+  const SCEV *AScev = replaceSymbolicStrideSCEV(PSE, Strides, APtr);
+  const SCEV *BScev = replaceSymbolicStrideSCEV(PSE, Strides, BPtr);
 
 
-  int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides, Preds);
-  int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides, Preds);
+  int StrideAPtr = isStridedPtr(PSE, APtr, InnermostLoop, Strides);
+  int StrideBPtr = isStridedPtr(PSE, BPtr, InnermostLoop, Strides);
 
   const SCEV *Src = AScev;
   const SCEV *Sink = BScev;
 
   const SCEV *Src = AScev;
   const SCEV *Sink = BScev;
@@ -1067,12 +1068,12 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
     std::swap(StrideAPtr, StrideBPtr);
   }
 
     std::swap(StrideAPtr, StrideBPtr);
   }
 
-  const SCEV *Dist = SE->getMinusSCEV(Sink, Src);
+  const SCEV *Dist = PSE.getSE()->getMinusSCEV(Sink, Src);
 
   DEBUG(dbgs() << "LAA: Src Scev: " << *Src << "Sink Scev: " << *Sink
 
   DEBUG(dbgs() << "LAA: Src Scev: " << *Src << "Sink Scev: " << *Sink
-        << "(Induction step: " << StrideAPtr <<  ")\n");
+               << "(Induction step: " << StrideAPtr << ")\n");
   DEBUG(dbgs() << "LAA: Distance for " << *InstMap[AIdx] << " to "
   DEBUG(dbgs() << "LAA: Distance for " << *InstMap[AIdx] << " to "
-        << *InstMap[BIdx] << ": " << *Dist << "\n");
+               << *InstMap[BIdx] << ": " << *Dist << "\n");
 
   // Need accesses with constant stride. We don't want to vectorize
   // "A[B[i]] += ..." and similar code or pointer arithmetic that could wrap in
 
   // Need accesses with constant stride. We don't want to vectorize
   // "A[B[i]] += ..." and similar code or pointer arithmetic that could wrap in
@@ -1343,10 +1344,10 @@ bool LoopAccessInfo::canAnalyzeLoop() {
   }
 
   // ScalarEvolution needs to be able to find the exit count.
   }
 
   // ScalarEvolution needs to be able to find the exit count.
-  const SCEV *ExitCount = SE->getBackedgeTakenCount(TheLoop);
-  if (ExitCount == SE->getCouldNotCompute()) {
-    emitAnalysis(LoopAccessReport() <<
-                 "could not determine number of loop iterations");
+  const SCEV *ExitCount = PSE.getSE()->getBackedgeTakenCount(TheLoop);
+  if (ExitCount == PSE.getSE()->getCouldNotCompute()) {
+    emitAnalysis(LoopAccessReport()
+                 << "could not determine number of loop iterations");
     DEBUG(dbgs() << "LAA: SCEV could not compute the loop exit count.\n");
     return false;
   }
     DEBUG(dbgs() << "LAA: SCEV could not compute the loop exit count.\n");
     return false;
   }
@@ -1447,7 +1448,7 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) {
 
   MemoryDepChecker::DepCandidates DependentAccesses;
   AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(),
 
   MemoryDepChecker::DepCandidates DependentAccesses;
   AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(),
-                          AA, LI, DependentAccesses, Preds);
+                          AA, LI, DependentAccesses, PSE);
 
   // Holds the analyzed pointers. We don't want to call GetUnderlyingObjects
   // multiple times on the same object. If the ptr is accessed twice, once
 
   // Holds the analyzed pointers. We don't want to call GetUnderlyingObjects
   // multiple times on the same object. If the ptr is accessed twice, once
@@ -1498,8 +1499,7 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) {
     // read a few words, modify, and write a few words, and some of the
     // words may be written to the same address.
     bool IsReadOnlyPtr = false;
     // read a few words, modify, and write a few words, and some of the
     // words may be written to the same address.
     bool IsReadOnlyPtr = false;
-    if (Seen.insert(Ptr).second ||
-        !isStridedPtr(SE, Ptr, TheLoop, Strides, Preds)) {
+    if (Seen.insert(Ptr).second || !isStridedPtr(PSE, Ptr, TheLoop, Strides)) {
       ++NumReads;
       IsReadOnlyPtr = true;
     }
       ++NumReads;
       IsReadOnlyPtr = true;
     }
@@ -1529,7 +1529,7 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) {
   // Find pointers with computable bounds. We are going to use this information
   // to place a runtime bound check.
   bool CanDoRTIfNeeded =
   // Find pointers with computable bounds. We are going to use this information
   // to place a runtime bound check.
   bool CanDoRTIfNeeded =
-      Accesses.canCheckPtrAtRT(PtrRtChecking, SE, TheLoop, Strides);
+      Accesses.canCheckPtrAtRT(PtrRtChecking, PSE.getSE(), TheLoop, Strides);
   if (!CanDoRTIfNeeded) {
     emitAnalysis(LoopAccessReport() << "cannot identify array bounds");
     DEBUG(dbgs() << "LAA: We can't vectorize because we can't find "
   if (!CanDoRTIfNeeded) {
     emitAnalysis(LoopAccessReport() << "cannot identify array bounds");
     DEBUG(dbgs() << "LAA: We can't vectorize because we can't find "
@@ -1556,6 +1556,7 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) {
       PtrRtChecking.reset();
       PtrRtChecking.Need = true;
 
       PtrRtChecking.reset();
       PtrRtChecking.Need = true;
 
+      auto *SE = PSE.getSE();
       CanDoRTIfNeeded =
           Accesses.canCheckPtrAtRT(PtrRtChecking, SE, TheLoop, Strides, true);
 
       CanDoRTIfNeeded =
           Accesses.canCheckPtrAtRT(PtrRtChecking, SE, TheLoop, Strides, true);
 
@@ -1598,7 +1599,7 @@ void LoopAccessInfo::emitAnalysis(LoopAccessReport &Message) {
 }
 
 bool LoopAccessInfo::isUniform(Value *V) const {
 }
 
 bool LoopAccessInfo::isUniform(Value *V) const {
-  return (SE->isLoopInvariant(SE->getSCEV(V), TheLoop));
+  return (PSE.getSE()->isLoopInvariant(PSE.getSE()->getSCEV(V), TheLoop));
 }
 
 // FIXME: this function is currently a duplicate of the one in
 }
 
 // FIXME: this function is currently a duplicate of the one in
@@ -1679,7 +1680,7 @@ std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeChecks(
     Instruction *Loc,
     const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks)
     const {
     Instruction *Loc,
     const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks)
     const {
-
+  auto *SE = PSE.getSE();
   SCEVExpander Exp(*SE, DL, "induction");
   auto ExpandedChecks =
       expandBounds(PointerChecks, TheLoop, Loc, SE, Exp, PtrRtChecking);
   SCEVExpander Exp(*SE, DL, "induction");
   auto ExpandedChecks =
       expandBounds(PointerChecks, TheLoop, Loc, SE, Exp, PtrRtChecking);
@@ -1749,7 +1750,7 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
                                const TargetLibraryInfo *TLI, AliasAnalysis *AA,
                                DominatorTree *DT, LoopInfo *LI,
                                const ValueToValueMap &Strides)
                                const TargetLibraryInfo *TLI, AliasAnalysis *AA,
                                DominatorTree *DT, LoopInfo *LI,
                                const ValueToValueMap &Strides)
-    : PtrRtChecking(SE), DepChecker(SE, L, Preds), TheLoop(L), SE(SE), DL(DL),
+    : PSE(*SE), PtrRtChecking(SE), DepChecker(PSE, L), TheLoop(L), DL(DL),
       TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0),
       MaxSafeDepDistBytes(-1U), CanVecMem(false),
       StoreToLoopInvariantAddress(false) {
       TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0),
       MaxSafeDepDistBytes(-1U), CanVecMem(false),
       StoreToLoopInvariantAddress(false) {
@@ -1786,7 +1787,7 @@ void LoopAccessInfo::print(raw_ostream &OS, unsigned Depth) const {
                    << "found in loop.\n";
 
   OS.indent(Depth) << "SCEV assumptions:\n";
                    << "found in loop.\n";
 
   OS.indent(Depth) << "SCEV assumptions:\n";
-  Preds.print(OS, Depth);
+  PSE.getUnionPredicate().print(OS, Depth);
 }
 
 const LoopAccessInfo &
 }
 
 const LoopAccessInfo &
index f57997e146e09e55dd05c05f01f6394331eaa979..1c2fb3d1ed02e490dacc3952699c1b3ff6b758b9 100644 (file)
@@ -9707,3 +9707,46 @@ void SCEVUnionPredicate::add(const SCEVPredicate *N) {
   SCEVToPreds[Key].push_back(N);
   Preds.push_back(N);
 }
   SCEVToPreds[Key].push_back(N);
   Preds.push_back(N);
 }
+
+PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE)
+    : SE(SE), Generation(0) {}
+
+const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
+  const SCEV *Expr = SE.getSCEV(V);
+  RewriteEntry &Entry = RewriteMap[Expr];
+
+  // If we already have an entry and the version matches, return it.
+  if (Entry.second && Generation == Entry.first)
+    return Entry.second;
+
+  // We found an entry but it's stale. Rewrite the stale entry
+  // acording to the current predicate.
+  if (Entry.second)
+    Expr = Entry.second;
+
+  const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, Preds);
+  Entry = {Generation, NewSCEV};
+
+  return NewSCEV;
+}
+
+void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
+  if (Preds.implies(&Pred))
+    return;
+  Preds.add(&Pred);
+  updateGeneration();
+}
+
+const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const {
+  return Preds;
+}
+
+void PredicatedScalarEvolution::updateGeneration() {
+  // If the generation number wrapped recompute everything.
+  if (++Generation == 0) {
+    for (auto &II : RewriteMap) {
+      const SCEV *Rewritten = II.second.second;
+      II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, Preds)};
+    }
+  }
+}
index 67ebd2532b1614f7702e2a0decc9e5b4b6819813..fce063ab40a00e439d9565edbd208922512c109e 100644 (file)
@@ -761,7 +761,7 @@ private:
     }
 
     // Don't distribute the loop if we need too many SCEV run-time checks.
     }
 
     // Don't distribute the loop if we need too many SCEV run-time checks.
-    const SCEVUnionPredicate &Pred = LAI.Preds;
+    const SCEVUnionPredicate &Pred = LAI.PSE.getUnionPredicate();
     if (Pred.getComplexity() > DistributeSCEVCheckThreshold) {
       DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n");
       return false;
     if (Pred.getComplexity() > DistributeSCEVCheckThreshold) {
       DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n");
       return false;
@@ -790,7 +790,7 @@ private:
       DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks));
       LoopVersioning LVer(LAI, L, LI, DT, SE, false);
       LVer.setAliasChecks(std::move(Checks));
       DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks));
       LoopVersioning LVer(LAI, L, LI, DT, SE, false);
       LVer.setAliasChecks(std::move(Checks));
-      LVer.setSCEVChecks(LAI.Preds);
+      LVer.setSCEVChecks(LAI.PSE.getUnionPredicate());
       LVer.versionLoop(DefsUsedOutside);
     }
 
       LVer.versionLoop(DefsUsedOutside);
     }
 
index 7c7bf64ba79c81ee7a9e6fc06a80817c78bf2d38..09d022b3013b59d9bdefccb805e2606359d1250d 100644 (file)
@@ -459,17 +459,18 @@ public:
       return false;
     }
 
       return false;
     }
 
-    if (LAI.Preds.getComplexity() > LoadElimSCEVCheckThreshold) {
+    if (LAI.PSE.getUnionPredicate().getComplexity() >
+        LoadElimSCEVCheckThreshold) {
       DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n");
       return false;
     }
 
     // Point of no-return, start the transformation.  First, version the loop if
     // necessary.
       DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n");
       return false;
     }
 
     // Point of no-return, start the transformation.  First, version the loop if
     // necessary.
-    if (!Checks.empty() || !LAI.Preds.isAlwaysTrue()) {
+    if (!Checks.empty() || !LAI.PSE.getUnionPredicate().isAlwaysTrue()) {
       LoopVersioning LV(LAI, L, LI, DT, SE, false);
       LV.setAliasChecks(std::move(Checks));
       LoopVersioning LV(LAI, L, LI, DT, SE, false);
       LV.setAliasChecks(std::move(Checks));
-      LV.setSCEVChecks(LAI.Preds);
+      LV.setSCEVChecks(LAI.PSE.getUnionPredicate());
       LV.versionLoop();
     }
 
       LV.versionLoop();
     }
 
index cc3ff5d80d42a037ec38df9f4df93a59a464b88c..9a2a06cf689154835fd88d5c400aaa0d9a1fa20c 100644 (file)
@@ -32,7 +32,7 @@ LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
   assert(L->getLoopPreheader() && "No preheader");
   if (UseLAIChecks) {
     setAliasChecks(LAI.getRuntimePointerChecking()->getChecks());
   assert(L->getLoopPreheader() && "No preheader");
   if (UseLAIChecks) {
     setAliasChecks(LAI.getRuntimePointerChecking()->getChecks());
-    setSCEVChecks(LAI.Preds);
+    setSCEVChecks(LAI.PSE.getUnionPredicate());
   }
 }
 
   }
 }
 
@@ -58,7 +58,7 @@ void LoopVersioning::versionLoop(
       LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks);
   assert(MemRuntimeCheck && "called even though needsAnyChecking = false");
 
       LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks);
   assert(MemRuntimeCheck && "called even though needsAnyChecking = false");
 
-  const SCEVUnionPredicate &Pred = LAI.Preds;
+  const SCEVUnionPredicate &Pred = LAI.PSE.getUnionPredicate();
   SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
                    "scev.check");
   SCEVRuntimeCheck =
   SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
                    "scev.check");
   SCEVRuntimeCheck =
index 917f2d55f6cb18e4e553ae26f1390d262424c657..9adc80c8bd0fae5e29fecc4661477c320a48b331 100644 (file)
@@ -310,15 +310,16 @@ static GetElementPtrInst *getGEPInstruction(Value *Ptr) {
 /// and reduction variables that were found to a given vectorization factor.
 class InnerLoopVectorizer {
 public:
 /// and reduction variables that were found to a given vectorization factor.
 class InnerLoopVectorizer {
 public:
-  InnerLoopVectorizer(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI,
-                      DominatorTree *DT, const TargetLibraryInfo *TLI,
+  InnerLoopVectorizer(Loop *OrigLoop, PredicatedScalarEvolution &PSE,
+                      LoopInfo *LI, DominatorTree *DT,
+                      const TargetLibraryInfo *TLI,
                       const TargetTransformInfo *TTI, unsigned VecWidth,
                       const TargetTransformInfo *TTI, unsigned VecWidth,
-                      unsigned UnrollFactor, SCEVUnionPredicate &Preds)
-      : OrigLoop(OrigLoop), SE(SE), LI(LI), DT(DT), TLI(TLI), TTI(TTI),
-        VF(VecWidth), UF(UnrollFactor), Builder(SE->getContext()),
+                      unsigned UnrollFactor)
+      : OrigLoop(OrigLoop), PSE(PSE), LI(LI), DT(DT), TLI(TLI), TTI(TTI),
+        VF(VecWidth), UF(UnrollFactor), Builder(PSE.getSE()->getContext()),
         Induction(nullptr), OldInduction(nullptr), WidenMap(UnrollFactor),
         TripCount(nullptr), VectorTripCount(nullptr), Legal(nullptr),
         Induction(nullptr), OldInduction(nullptr), WidenMap(UnrollFactor),
         TripCount(nullptr), VectorTripCount(nullptr), Legal(nullptr),
-        AddedSafetyChecks(false), Preds(Preds) {}
+        AddedSafetyChecks(false) {}
 
   // Perform the actual loop widening (vectorization).
   // MinimumBitWidths maps scalar integer values to the smallest bitwidth they
 
   // Perform the actual loop widening (vectorization).
   // MinimumBitWidths maps scalar integer values to the smallest bitwidth they
@@ -486,8 +487,10 @@ protected:
 
   /// The original loop.
   Loop *OrigLoop;
 
   /// The original loop.
   Loop *OrigLoop;
-  /// Scev analysis to use.
-  ScalarEvolution *SE;
+  /// A wrapper around ScalarEvolution used to add runtime SCEV checks. Applies
+  /// dynamic knowledge to simplify SCEV expressions and converts them to a
+  /// more usable form.
+  PredicatedScalarEvolution &PSE;
   /// Loop Info.
   LoopInfo *LI;
   /// Dominator Tree.
   /// Loop Info.
   LoopInfo *LI;
   /// Dominator Tree.
@@ -551,23 +554,15 @@ protected:
 
   // Record whether runtime check is added.
   bool AddedSafetyChecks;
 
   // Record whether runtime check is added.
   bool AddedSafetyChecks;
-
-  /// The SCEV predicate containing all the SCEV-related assumptions.
-  /// The predicate is used to simplify existing expressions in the
-  /// context of existing SCEV assumptions. Since legality checking is
-  /// not done here, we don't need to use this predicate to record
-  /// further assumptions.
-  SCEVUnionPredicate &Preds;
 };
 
 class InnerLoopUnroller : public InnerLoopVectorizer {
 public:
 };
 
 class InnerLoopUnroller : public InnerLoopVectorizer {
 public:
-  InnerLoopUnroller(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI,
-                    DominatorTree *DT, const TargetLibraryInfo *TLI,
-                    const TargetTransformInfo *TTI, unsigned UnrollFactor,
-                    SCEVUnionPredicate &Preds)
-      : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor,
-                            Preds) {}
+  InnerLoopUnroller(Loop *OrigLoop, PredicatedScalarEvolution &PSE,
+                    LoopInfo *LI, DominatorTree *DT,
+                    const TargetLibraryInfo *TLI,
+                    const TargetTransformInfo *TTI, unsigned UnrollFactor)
+      : InnerLoopVectorizer(OrigLoop, PSE, LI, DT, TLI, TTI, 1, UnrollFactor) {}
 
 private:
   void scalarizeInstruction(Instruction *Instr,
 
 private:
   void scalarizeInstruction(Instruction *Instr,
@@ -789,9 +784,9 @@ private:
 /// between the member and the group in a map.
 class InterleavedAccessInfo {
 public:
 /// between the member and the group in a map.
 class InterleavedAccessInfo {
 public:
-  InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT,
-                        SCEVUnionPredicate &Preds)
-      : SE(SE), TheLoop(L), DT(DT), Preds(Preds) {}
+  InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
+                        DominatorTree *DT)
+      : PSE(PSE), TheLoop(L), DT(DT) {}
 
   ~InterleavedAccessInfo() {
     SmallSet<InterleaveGroup *, 4> DelSet;
 
   ~InterleavedAccessInfo() {
     SmallSet<InterleaveGroup *, 4> DelSet;
@@ -821,17 +816,14 @@ public:
   }
 
 private:
   }
 
 private:
-  ScalarEvolution *SE;
+  /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
+  /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
+  /// The interleaved access analysis can also add new predicates (for example
+  /// by versioning strides of pointers).
+  PredicatedScalarEvolution &PSE;
   Loop *TheLoop;
   DominatorTree *DT;
 
   Loop *TheLoop;
   DominatorTree *DT;
 
-  /// The SCEV predicate containing all the SCEV-related assumptions.
-  /// The predicate is used to simplify SCEV expressions in the
-  /// context of existing SCEV assumptions. The interleaved access
-  /// analysis can also add new predicates (for example by versioning
-  /// strides of pointers).
-  SCEVUnionPredicate &Preds;
-
   /// Holds the relationships between the members and the interleave group.
   DenseMap<Instruction *, InterleaveGroup *> InterleaveGroupMap;
 
   /// Holds the relationships between the members and the interleave group.
   DenseMap<Instruction *, InterleaveGroup *> InterleaveGroupMap;
 
@@ -1189,18 +1181,17 @@ static void emitMissedWarning(Function *F, Loop *L,
 /// induction variable and the different reduction variables.
 class LoopVectorizationLegality {
 public:
 /// induction variable and the different reduction variables.
 class LoopVectorizationLegality {
 public:
-  LoopVectorizationLegality(Loop *L, ScalarEvolution *SE, DominatorTree *DT,
-                            TargetLibraryInfo *TLI, AliasAnalysis *AA,
-                            Function *F, const TargetTransformInfo *TTI,
+  LoopVectorizationLegality(Loop *L, PredicatedScalarEvolution &PSE,
+                            DominatorTree *DT, TargetLibraryInfo *TLI,
+                            AliasAnalysis *AA, Function *F,
+                            const TargetTransformInfo *TTI,
                             LoopAccessAnalysis *LAA,
                             LoopVectorizationRequirements *R,
                             LoopAccessAnalysis *LAA,
                             LoopVectorizationRequirements *R,
-                            const LoopVectorizeHints *H,
-                            SCEVUnionPredicate &Preds)
-      : NumPredStores(0), TheLoop(L), SE(SE), TLI(TLI), TheFunction(F),
-        TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr),
-        InterleaveInfo(SE, L, DT, Preds), Induction(nullptr),
-        WidestIndTy(nullptr), HasFunNoNaNAttr(false), Requirements(R), Hints(H),
-        Preds(Preds) {}
+                            const LoopVectorizeHints *H)
+      : NumPredStores(0), TheLoop(L), PSE(PSE), TLI(TLI), TheFunction(F),
+        TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), InterleaveInfo(PSE, L, DT),
+        Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false),
+        Requirements(R), Hints(H) {}
 
   /// ReductionList contains the reduction descriptors for all
   /// of the reductions that were found in the loop.
 
   /// ReductionList contains the reduction descriptors for all
   /// of the reductions that were found in the loop.
@@ -1347,8 +1338,12 @@ private:
 
   /// The loop that we evaluate.
   Loop *TheLoop;
 
   /// The loop that we evaluate.
   Loop *TheLoop;
-  /// Scev analysis.
-  ScalarEvolution *SE;
+  /// A wrapper around ScalarEvolution used to add runtime SCEV checks.
+  /// Applies dynamic knowledge to simplify SCEV expressions in the context
+  /// of existing SCEV assumptions. The analysis will also add a minimal set
+  /// of new predicates if this is required to enable vectorization and
+  /// unrolling.
+  PredicatedScalarEvolution &PSE;
   /// Target Library Info.
   TargetLibraryInfo *TLI;
   /// Parent function
   /// Target Library Info.
   TargetLibraryInfo *TLI;
   /// Parent function
@@ -1403,13 +1398,6 @@ private:
   /// While vectorizing these instructions we have to generate a
   /// call to the appropriate masked intrinsic
   SmallPtrSet<const Instruction *, 8> MaskedOp;
   /// While vectorizing these instructions we have to generate a
   /// call to the appropriate masked intrinsic
   SmallPtrSet<const Instruction *, 8> MaskedOp;
-
-  /// The SCEV predicate containing all the SCEV-related assumptions.
-  /// The predicate is used to simplify SCEV expressions in the
-  /// context of existing SCEV assumptions. The analysis will also
-  /// add a minimal set of new predicates if this is required to
-  /// enable vectorization/unrolling.
-  SCEVUnionPredicate &Preds;
 };
 
 /// LoopVectorizationCostModel - estimates the expected speedups due to
 };
 
 /// LoopVectorizationCostModel - estimates the expected speedups due to
@@ -1427,8 +1415,7 @@ public:
                              const TargetLibraryInfo *TLI, DemandedBits *DB,
                              AssumptionCache *AC, const Function *F,
                              const LoopVectorizeHints *Hints,
                              const TargetLibraryInfo *TLI, DemandedBits *DB,
                              AssumptionCache *AC, const Function *F,
                              const LoopVectorizeHints *Hints,
-                             SmallPtrSetImpl<const Value *> &ValuesToIgnore,
-                             SCEVUnionPredicate &Preds)
+                             SmallPtrSetImpl<const Value *> &ValuesToIgnore)
       : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB),
         TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {}
 
       : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB),
         TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {}
 
@@ -1758,12 +1745,12 @@ struct LoopVectorize : public FunctionPass {
       }
     }
 
       }
     }
 
-    SCEVUnionPredicate Preds;
+    PredicatedScalarEvolution PSE(*SE);
 
     // Check if it is legal to vectorize the loop.
     LoopVectorizationRequirements Requirements;
 
     // Check if it is legal to vectorize the loop.
     LoopVectorizationRequirements Requirements;
-    LoopVectorizationLegality LVL(L, SE, DT, TLI, AA, F, TTI, LAA,
-                                  &Requirements, &Hints, Preds);
+    LoopVectorizationLegality LVL(L, PSE, DT, TLI, AA, F, TTI, LAA,
+                                  &Requirements, &Hints);
     if (!LVL.canVectorize()) {
       DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n");
       emitMissedWarning(F, L, Hints);
     if (!LVL.canVectorize()) {
       DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n");
       emitMissedWarning(F, L, Hints);
@@ -1781,8 +1768,8 @@ struct LoopVectorize : public FunctionPass {
     }
 
     // Use the cost model.
     }
 
     // Use the cost model.
-    LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, DB, AC, F, &Hints,
-                                  ValuesToIgnore, Preds);
+    LoopVectorizationCostModel CM(L, PSE.getSE(), LI, &LVL, *TTI, TLI, DB, AC,
+                                  F, &Hints, ValuesToIgnore);
 
     // Check the function attributes to find out if this function should be
     // optimized for size.
 
     // Check the function attributes to find out if this function should be
     // optimized for size.
@@ -1893,7 +1880,7 @@ struct LoopVectorize : public FunctionPass {
       assert(IC > 1 && "interleave count should not be 1 or 0");
       // If we decided that it is not legal to vectorize the loop then
       // interleave it.
       assert(IC > 1 && "interleave count should not be 1 or 0");
       // If we decided that it is not legal to vectorize the loop then
       // interleave it.
-      InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC, Preds);
+      InnerLoopUnroller Unroller(L, PSE, LI, DT, TLI, TTI, IC);
       Unroller.vectorize(&LVL, CM.MinBWs);
 
       emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(),
       Unroller.vectorize(&LVL, CM.MinBWs);
 
       emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(),
@@ -1901,7 +1888,7 @@ struct LoopVectorize : public FunctionPass {
                                  Twine(IC) + ")");
     } else {
       // If we decided that it is *legal* to vectorize the loop then do it.
                                  Twine(IC) + ")");
     } else {
       // If we decided that it is *legal* to vectorize the loop then do it.
-      InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC, Preds);
+      InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, VF.Width, IC);
       LB.vectorize(&LVL, CM.MinBWs);
       ++LoopsVectorized;
 
       LB.vectorize(&LVL, CM.MinBWs);
       ++LoopsVectorized;
 
@@ -2002,6 +1989,7 @@ Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx,
 
 int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) {
   assert(Ptr->getType()->isPointerTy() && "Unexpected non-ptr");
 
 int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) {
   assert(Ptr->getType()->isPointerTy() && "Unexpected non-ptr");
+  auto *SE = PSE.getSE();
   // Make sure that the pointer does not point to structs.
   if (Ptr->getType()->getPointerElementType()->isAggregateType())
     return 0;
   // Make sure that the pointer does not point to structs.
   if (Ptr->getType()->getPointerElementType()->isAggregateType())
     return 0;
@@ -2031,7 +2019,7 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) {
 
     // Make sure that all of the index operands are loop invariant.
     for (unsigned i = 1; i < NumOperands; ++i)
 
     // Make sure that all of the index operands are loop invariant.
     for (unsigned i = 1; i < NumOperands; ++i)
-      if (!SE->isLoopInvariant(SE->getSCEV(Gep->getOperand(i)), TheLoop))
+      if (!SE->isLoopInvariant(PSE.getSCEV(Gep->getOperand(i)), TheLoop))
         return 0;
 
     InductionDescriptor II = Inductions[Phi];
         return 0;
 
     InductionDescriptor II = Inductions[Phi];
@@ -2044,14 +2032,14 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) {
   // operand.
   for (unsigned i = 0; i != NumOperands; ++i)
     if (i != InductionOperand &&
   // operand.
   for (unsigned i = 0; i != NumOperands; ++i)
     if (i != InductionOperand &&
-        !SE->isLoopInvariant(SE->getSCEV(Gep->getOperand(i)), TheLoop))
+        !SE->isLoopInvariant(PSE.getSCEV(Gep->getOperand(i)), TheLoop))
       return 0;
 
   // We can emit wide load/stores only if the last non-zero index is the
   // induction variable.
   const SCEV *Last = nullptr;
   if (!Strides.count(Gep))
       return 0;
 
   // We can emit wide load/stores only if the last non-zero index is the
   // induction variable.
   const SCEV *Last = nullptr;
   if (!Strides.count(Gep))
-    Last = SE->getSCEV(Gep->getOperand(InductionOperand));
+    Last = PSE.getSCEV(Gep->getOperand(InductionOperand));
   else {
     // Because of the multiplication by a stride we can have a s/zext cast.
     // We are going to replace this stride by 1 so the cast is safe to ignore.
   else {
     // Because of the multiplication by a stride we can have a s/zext cast.
     // We are going to replace this stride by 1 so the cast is safe to ignore.
@@ -2062,7 +2050,7 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) {
     //  %idxprom = zext i32 %mul to i64  << Safe cast.
     //  %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom
     //
     //  %idxprom = zext i32 %mul to i64  << Safe cast.
     //  %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom
     //
-    Last = replaceSymbolicStrideSCEV(SE, Strides, Preds,
+    Last = replaceSymbolicStrideSCEV(PSE, Strides,
                                      Gep->getOperand(InductionOperand), Gep);
     if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(Last))
       Last =
                                      Gep->getOperand(InductionOperand), Gep);
     if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(Last))
       Last =
@@ -2420,8 +2408,9 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) {
     Ptr = Builder.Insert(Gep2);
   } else if (Gep) {
     setDebugLocFromInst(Builder, Gep);
     Ptr = Builder.Insert(Gep2);
   } else if (Gep) {
     setDebugLocFromInst(Builder, Gep);
-    assert(SE->isLoopInvariant(SE->getSCEV(Gep->getPointerOperand()),
-                               OrigLoop) && "Base ptr must be invariant");
+    assert(PSE.getSE()->isLoopInvariant(PSE.getSCEV(Gep->getPointerOperand()),
+                                        OrigLoop) &&
+           "Base ptr must be invariant");
 
     // The last index does not have to be the induction. It can be
     // consecutive and be a function of the index. For example A[I+1];
 
     // The last index does not have to be the induction. It can be
     // consecutive and be a function of the index. For example A[I+1];
@@ -2438,7 +2427,8 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) {
       if (i == InductionOperand ||
           (GepOperandInst && OrigLoop->contains(GepOperandInst))) {
         assert((i == InductionOperand ||
       if (i == InductionOperand ||
           (GepOperandInst && OrigLoop->contains(GepOperandInst))) {
         assert((i == InductionOperand ||
-               SE->isLoopInvariant(SE->getSCEV(GepOperandInst), OrigLoop)) &&
+                PSE.getSE()->isLoopInvariant(PSE.getSCEV(GepOperandInst),
+                                             OrigLoop)) &&
                "Must be last index or loop invariant");
 
         VectorParts &GEPParts = getVectorValue(GepOperand);
                "Must be last index or loop invariant");
 
         VectorParts &GEPParts = getVectorValue(GepOperand);
@@ -2658,6 +2648,7 @@ Value *InnerLoopVectorizer::getOrCreateTripCount(Loop *L) {
 
   IRBuilder<> Builder(L->getLoopPreheader()->getTerminator());
   // Find the loop boundaries.
 
   IRBuilder<> Builder(L->getLoopPreheader()->getTerminator());
   // Find the loop boundaries.
+  ScalarEvolution *SE = PSE.getSE();
   const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(OrigLoop);
   assert(BackedgeTakenCount != SE->getCouldNotCompute() &&
          "Invalid loop count");
   const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(OrigLoop);
   assert(BackedgeTakenCount != SE->getCouldNotCompute() &&
          "Invalid loop count");
@@ -2765,8 +2756,10 @@ void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) {
   // Generate the code to check that the SCEV assumptions that we made.
   // We want the new basic block to start at the first instruction in a
   // sequence of instructions that form a check.
   // Generate the code to check that the SCEV assumptions that we made.
   // We want the new basic block to start at the first instruction in a
   // sequence of instructions that form a check.
-  SCEVExpander Exp(*SE, Bypass->getModule()->getDataLayout(), "scev.check");
-  Value *SCEVCheck = Exp.expandCodeForPredicate(&Preds, BB->getTerminator());
+  SCEVExpander Exp(*PSE.getSE(), Bypass->getModule()->getDataLayout(),
+                   "scev.check");
+  Value *SCEVCheck =
+      Exp.expandCodeForPredicate(&PSE.getUnionPredicate(), BB->getTerminator());
 
   if (auto *C = dyn_cast<ConstantInt>(SCEVCheck))
     if (C->isZero())
 
   if (auto *C = dyn_cast<ConstantInt>(SCEVCheck))
     if (C->isZero())
@@ -3785,8 +3778,9 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) {
       // Widen selects.
       // If the selector is loop invariant we can create a select
       // instruction with a scalar condition. Otherwise, use vector-select.
       // Widen selects.
       // If the selector is loop invariant we can create a select
       // instruction with a scalar condition. Otherwise, use vector-select.
-      bool InvariantCond = SE->isLoopInvariant(SE->getSCEV(it->getOperand(0)),
-                                               OrigLoop);
+      auto *SE = PSE.getSE();
+      bool InvariantCond =
+          SE->isLoopInvariant(PSE.getSCEV(it->getOperand(0)), OrigLoop);
       setDebugLocFromInst(Builder, &*it);
 
       // The condition can be loop invariant  but still defined inside the
       setDebugLocFromInst(Builder, &*it);
 
       // The condition can be loop invariant  but still defined inside the
@@ -3967,7 +3961,7 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) {
 
 void InnerLoopVectorizer::updateAnalysis() {
   // Forget the original basic block.
 
 void InnerLoopVectorizer::updateAnalysis() {
   // Forget the original basic block.
-  SE->forgetLoop(OrigLoop);
+  PSE.getSE()->forgetLoop(OrigLoop);
 
   // Update the dominator tree information.
   assert(DT->properlyDominates(LoopBypassBlocks.front(), LoopExitBlock) &&
 
   // Update the dominator tree information.
   assert(DT->properlyDominates(LoopBypassBlocks.front(), LoopExitBlock) &&
@@ -4119,10 +4113,10 @@ bool LoopVectorizationLegality::canVectorize() {
   }
 
   // ScalarEvolution needs to be able to find the exit count.
   }
 
   // ScalarEvolution needs to be able to find the exit count.
-  const SCEV *ExitCount = SE->getBackedgeTakenCount(TheLoop);
-  if (ExitCount == SE->getCouldNotCompute()) {
-    emitAnalysis(VectorizationReport() <<
-                 "could not determine number of loop iterations");
+  const SCEV *ExitCount = PSE.getSE()->getBackedgeTakenCount(TheLoop);
+  if (ExitCount == PSE.getSE()->getCouldNotCompute()) {
+    emitAnalysis(VectorizationReport()
+                 << "could not determine number of loop iterations");
     DEBUG(dbgs() << "LV: SCEV could not compute the loop exit count.\n");
     return false;
   }
     DEBUG(dbgs() << "LV: SCEV could not compute the loop exit count.\n");
     return false;
   }
@@ -4162,7 +4156,7 @@ bool LoopVectorizationLegality::canVectorize() {
   if (Hints->getForce() == LoopVectorizeHints::FK_Enabled)
     SCEVThreshold = PragmaVectorizeSCEVCheckThreshold;
 
   if (Hints->getForce() == LoopVectorizeHints::FK_Enabled)
     SCEVThreshold = PragmaVectorizeSCEVCheckThreshold;
 
-  if (Preds.getComplexity() > SCEVThreshold) {
+  if (PSE.getUnionPredicate().getComplexity() > SCEVThreshold) {
     emitAnalysis(VectorizationReport()
                  << "Too many SCEV assumptions need to be made and checked "
                  << "at runtime");
     emitAnalysis(VectorizationReport()
                  << "Too many SCEV assumptions need to be made and checked "
                  << "at runtime");
@@ -4268,7 +4262,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
         }
 
         InductionDescriptor ID;
         }
 
         InductionDescriptor ID;
-        if (InductionDescriptor::isInductionPHI(Phi, SE, ID)) {
+        if (InductionDescriptor::isInductionPHI(Phi, PSE.getSE(), ID)) {
           Inductions[Phi] = ID;
           // Get the widest type.
           if (!WidestIndTy)
           Inductions[Phi] = ID;
           // Get the widest type.
           if (!WidestIndTy)
@@ -4337,7 +4331,8 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
       // second argument is the same (i.e. loop invariant)
       if (CI &&
           hasVectorInstrinsicScalarOpd(getIntrinsicIDForCall(CI, TLI), 1)) {
       // second argument is the same (i.e. loop invariant)
       if (CI &&
           hasVectorInstrinsicScalarOpd(getIntrinsicIDForCall(CI, TLI), 1)) {
-        if (!SE->isLoopInvariant(SE->getSCEV(CI->getOperand(1)), TheLoop)) {
+        auto *SE = PSE.getSE();
+        if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(1)), TheLoop)) {
           emitAnalysis(VectorizationReport(&*it)
                        << "intrinsic instruction cannot be vectorized");
           DEBUG(dbgs() << "LV: Found unvectorizable intrinsic " << *CI << "\n");
           emitAnalysis(VectorizationReport(&*it)
                        << "intrinsic instruction cannot be vectorized");
           DEBUG(dbgs() << "LV: Found unvectorizable intrinsic " << *CI << "\n");
@@ -4410,7 +4405,7 @@ void LoopVectorizationLegality::collectStridedAccess(Value *MemAccess) {
   else
     return;
 
   else
     return;
 
-  Value *Stride = getStrideFromPointer(Ptr, SE, TheLoop);
+  Value *Stride = getStrideFromPointer(Ptr, PSE.getSE(), TheLoop);
   if (!Stride)
     return;
 
   if (!Stride)
     return;
 
@@ -4474,7 +4469,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() {
   }
 
   Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks());
   }
 
   Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks());
-  Preds.add(&LAI->Preds);
+  PSE.addPredicate(LAI->PSE.getUnionPredicate());
 
   return true;
 }
 
   return true;
 }
@@ -4589,7 +4584,7 @@ void InterleavedAccessInfo::collectConstStridedAccesses(
     StoreInst *SI = dyn_cast<StoreInst>(I);
 
     Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand();
     StoreInst *SI = dyn_cast<StoreInst>(I);
 
     Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand();
-    int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides, Preds);
+    int Stride = isStridedPtr(PSE, Ptr, TheLoop, Strides);
 
     // The factor of the corresponding interleave group.
     unsigned Factor = std::abs(Stride);
 
     // The factor of the corresponding interleave group.
     unsigned Factor = std::abs(Stride);
@@ -4598,7 +4593,7 @@ void InterleavedAccessInfo::collectConstStridedAccesses(
     if (Factor < 2 || Factor > MaxInterleaveGroupFactor)
       continue;
 
     if (Factor < 2 || Factor > MaxInterleaveGroupFactor)
       continue;
 
-    const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr);
+    const SCEV *Scev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
     PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType());
     unsigned Size = DL.getTypeAllocSize(PtrTy->getElementType());
 
     PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType());
     unsigned Size = DL.getTypeAllocSize(PtrTy->getElementType());
 
@@ -4685,8 +4680,8 @@ void InterleavedAccessInfo::analyzeInterleaving(
         continue;
 
       // Calculate the distance and prepare for the rule 3.
         continue;
 
       // Calculate the distance and prepare for the rule 3.
-      const SCEVConstant *DistToA =
-          dyn_cast<SCEVConstant>(SE->getMinusSCEV(DesB.Scev, DesA.Scev));
+      const SCEVConstant *DistToA = dyn_cast<SCEVConstant>(
+          PSE.getSE()->getMinusSCEV(DesB.Scev, DesA.Scev));
       if (!DistToA)
         continue;
 
       if (!DistToA)
         continue;