[SCEV][LV] Add SCEV Predicates and use them to re-implement stride versioning
[oota-llvm.git] / lib / Transforms / Vectorize / LoopVectorize.cpp
index ae5ec8cb88a84411c7c22f81b98464daf431860c..b0cfbcc6777ec354a16c840bc6d58ade26f4234d 100644 (file)
@@ -222,6 +222,15 @@ static cl::opt<unsigned> PragmaVectorizeMemoryCheckThreshold(
     cl::desc("The maximum allowed number of runtime memory checks with a "
              "vectorize(enable) pragma."));
 
+static cl::opt<unsigned> VectorizeSCEVCheckThreshold(
+    "vectorize-scev-check-threshold", cl::init(16), cl::Hidden,
+    cl::desc("The maximum number of SCEV checks allowed."));
+
+static cl::opt<unsigned> PragmaVectorizeSCEVCheckThreshold(
+    "pragma-vectorize-scev-check-threshold", cl::init(128), cl::Hidden,
+    cl::desc("The maximum number of SCEV checks allowed with a "
+             "vectorize(enable) pragma"));
+
 namespace {
 
 // Forward declarations.
@@ -273,12 +282,12 @@ public:
   InnerLoopVectorizer(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI,
                       DominatorTree *DT, const TargetLibraryInfo *TLI,
                       const TargetTransformInfo *TTI, unsigned VecWidth,
-                      unsigned UnrollFactor)
+                      unsigned UnrollFactor, SCEVUnionPredicate &Preds)
       : OrigLoop(OrigLoop), SE(SE), LI(LI), DT(DT), TLI(TLI), TTI(TTI),
         VF(VecWidth), UF(UnrollFactor), Builder(SE->getContext()),
         Induction(nullptr), OldInduction(nullptr), WidenMap(UnrollFactor),
         TripCount(nullptr), VectorTripCount(nullptr), Legal(nullptr),
-        AddedSafetyChecks(false) {}
+        AddedSafetyChecks(false), Preds(Preds) {}
 
   // Perform the actual loop widening (vectorization).
   // MinimumBitWidths maps scalar integer values to the smallest bitwidth they
@@ -315,12 +324,6 @@ protected:
   typedef DenseMap<std::pair<BasicBlock*, BasicBlock*>,
                    VectorParts> EdgeMaskCache;
 
-  /// \brief Add checks for strides that were assumed to be 1.
-  ///
-  /// Returns the last check instruction and the first check instruction in the
-  /// pair as (first, last).
-  std::pair<Instruction *, Instruction *> addStrideCheck(Instruction *Loc);
-
   /// Create an empty loop, based on the loop ranges of the old loop.
   void createEmptyLoop();
   /// Create a new induction variable inside L.
@@ -404,11 +407,12 @@ protected:
   void emitMinimumIterationCountCheck(Loop *L, BasicBlock *Bypass);
   /// Emit a bypass check to see if the vector trip count is nonzero.
   void emitVectorLoopEnteredCheck(Loop *L, BasicBlock *Bypass);
-  /// Emit bypass checks to check if strides we've assumed to be one really are.
-  void emitStrideChecks(Loop *L, BasicBlock *Bypass);
+  /// Emit a bypass check to see if all of the SCEV assumptions we've
+  /// had to make are correct.
+  void emitSCEVChecks(Loop *L, BasicBlock *Bypass);
   /// Emit bypass checks to check any memory assumptions we may have made.
   void emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass);
-  
+
   /// This is a helper class that holds the vectorizer state. It maps scalar
   /// instructions to vector instructions. When the code is 'unrolled' then
   /// then a single scalar value is mapped to multiple vector parts. The parts
@@ -516,14 +520,23 @@ protected:
 
   // Record whether runtime check is added.
   bool AddedSafetyChecks;
+
+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  /// The predicate is used to simplify existing expressions in the
+  /// context of existing SCEV assumptions. Since legality checking is
+  /// not done here, we don't need to use this predicate to record
+  /// further assumptions.
+  SCEVUnionPredicate &Preds;
 };
 
 class InnerLoopUnroller : public InnerLoopVectorizer {
 public:
   InnerLoopUnroller(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI,
                     DominatorTree *DT, const TargetLibraryInfo *TLI,
-                    const TargetTransformInfo *TTI, unsigned UnrollFactor)
-      : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor) {}
+                    const TargetTransformInfo *TTI, unsigned UnrollFactor,
+                    SCEVUnionPredicate &Preds)
+      : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor,
+                            Preds) {}
 
 private:
   void scalarizeInstruction(Instruction *Instr,
@@ -744,8 +757,9 @@ private:
 /// between the member and the group in a map.
 class InterleavedAccessInfo {
 public:
-  InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT)
-      : SE(SE), TheLoop(L), DT(DT) {}
+  InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT,
+                        SCEVUnionPredicate &Preds)
+      : SE(SE), TheLoop(L), DT(DT), Preds(Preds) {}
 
   ~InterleavedAccessInfo() {
     SmallSet<InterleaveGroup *, 4> DelSet;
@@ -779,6 +793,13 @@ private:
   Loop *TheLoop;
   DominatorTree *DT;
 
+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  /// The predicate is used to simplify SCEV expressions in the
+  /// context of existing SCEV assumptions. The interleaved access
+  /// analysis can also add new predicates (for example by versioning
+  /// strides of pointers).
+  SCEVUnionPredicate &Preds;
+
   /// Holds the relationships between the members and the interleave group.
   DenseMap<Instruction *, InterleaveGroup *> InterleaveGroupMap;
 
@@ -1141,11 +1162,13 @@ public:
                             Function *F, const TargetTransformInfo *TTI,
                             LoopAccessAnalysis *LAA,
                             LoopVectorizationRequirements *R,
-                            const LoopVectorizeHints *H)
+                            const LoopVectorizeHints *H,
+                            SCEVUnionPredicate &Preds)
       : NumPredStores(0), TheLoop(L), SE(SE), TLI(TLI), TheFunction(F),
-        TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), InterleaveInfo(SE, L, DT),
-        Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false),
-        Requirements(R), Hints(H) {}
+        TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr),
+        InterleaveInfo(SE, L, DT, Preds), Induction(nullptr),
+        WidestIndTy(nullptr), HasFunNoNaNAttr(false), Requirements(R), Hints(H),
+        Preds(Preds) {}
 
   /// ReductionList contains the reduction descriptors for all
   /// of the reductions that were found in the loop.
@@ -1344,7 +1367,14 @@ private:
 
   /// While vectorizing these instructions we have to generate a
   /// call to the appropriate masked intrinsic
-  SmallPtrSet<const Instruction*, 8> MaskedOp;
+  SmallPtrSet<const Instruction *, 8> MaskedOp;
+
+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  /// The predicate is used to simplify SCEV expressions in the
+  /// context of existing SCEV assumptions. The analysis will also
+  /// add a minimal set of new predicates if this is required to
+  /// enable vectorization/unrolling.
+  SCEVUnionPredicate &Preds;
 };
 
 /// LoopVectorizationCostModel - estimates the expected speedups due to
@@ -1360,9 +1390,10 @@ public:
                              LoopVectorizationLegality *Legal,
                              const TargetTransformInfo &TTI,
                              const TargetLibraryInfo *TLI, DemandedBits *DB,
-                             AssumptionCache *AC,
-                             const Function *F, const LoopVectorizeHints *Hints,
-                             SmallPtrSetImpl<const Value *> &ValuesToIgnore)
+                             AssumptionCache *AC, const Function *F,
+                             const LoopVectorizeHints *Hints,
+                             SmallPtrSetImpl<const Value *> &ValuesToIgnore,
+                             SCEVUnionPredicate &Preds)
       : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB),
         TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {}
 
@@ -1690,10 +1721,12 @@ struct LoopVectorize : public FunctionPass {
       }
     }
 
+    SCEVUnionPredicate Preds;
+
     // Check if it is legal to vectorize the loop.
     LoopVectorizationRequirements Requirements;
     LoopVectorizationLegality LVL(L, SE, DT, TLI, AA, F, TTI, LAA,
-                                  &Requirements, &Hints);
+                                  &Requirements, &Hints, Preds);
     if (!LVL.canVectorize()) {
       DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n");
       emitMissedWarning(F, L, Hints);
@@ -1712,7 +1745,7 @@ struct LoopVectorize : public FunctionPass {
 
     // Use the cost model.
     LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, DB, AC, F, &Hints,
-                                  ValuesToIgnore);
+                                  ValuesToIgnore, Preds);
 
     // Check the function attributes to find out if this function should be
     // optimized for size.
@@ -1823,7 +1856,7 @@ struct LoopVectorize : public FunctionPass {
       assert(IC > 1 && "interleave count should not be 1 or 0");
       // If we decided that it is not legal to vectorize the loop then
       // interleave it.
-      InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC);
+      InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC, Preds);
       Unroller.vectorize(&LVL, CM.MinBWs);
 
       emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(),
@@ -1831,7 +1864,7 @@ struct LoopVectorize : public FunctionPass {
                                  Twine(IC) + ")");
     } else {
       // If we decided that it is *legal* to vectorize the loop then do it.
-      InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC);
+      InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC, Preds);
       LB.vectorize(&LVL, CM.MinBWs);
       ++LoopsVectorized;
 
@@ -1992,7 +2025,7 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) {
     //  %idxprom = zext i32 %mul to i64  << Safe cast.
     //  %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom
     //
-    Last = replaceSymbolicStrideSCEV(SE, Strides,
+    Last = replaceSymbolicStrideSCEV(SE, Strides, Preds,
                                      Gep->getOperand(InductionOperand), Gep);
     if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(Last))
       Last =
@@ -2551,56 +2584,8 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, bool IfPredic
   }
 }
 
-static Instruction *getFirstInst(Instruction *FirstInst, Value *V,
-                                 Instruction *Loc) {
-  if (FirstInst)
-    return FirstInst;
-  if (Instruction *I = dyn_cast<Instruction>(V))
-    return I->getParent() == Loc->getParent() ? I : nullptr;
-  return nullptr;
-}
-
-std::pair<Instruction *, Instruction *>
-InnerLoopVectorizer::addStrideCheck(Instruction *Loc) {
-  Instruction *tnullptr = nullptr;
-  if (!Legal->mustCheckStrides())
-    return std::pair<Instruction *, Instruction *>(tnullptr, tnullptr);
-
-  IRBuilder<> ChkBuilder(Loc);
-
-  // Emit checks.
-  Value *Check = nullptr;
-  Instruction *FirstInst = nullptr;
-  for (SmallPtrSet<Value *, 8>::iterator SI = Legal->strides_begin(),
-                                         SE = Legal->strides_end();
-       SI != SE; ++SI) {
-    Value *Ptr = stripIntegerCast(*SI);
-    Value *C = ChkBuilder.CreateICmpNE(Ptr, ConstantInt::get(Ptr->getType(), 1),
-                                       "stride.chk");
-    // Store the first instruction we create.
-    FirstInst = getFirstInst(FirstInst, C, Loc);
-    if (Check)
-      Check = ChkBuilder.CreateOr(Check, C);
-    else
-      Check = C;
-  }
-
-  // We have to do this trickery because the IRBuilder might fold the check to a
-  // constant expression in which case there is no Instruction anchored in a
-  // the block.
-  LLVMContext &Ctx = Loc->getContext();
-  Instruction *TheCheck =
-      BinaryOperator::CreateAnd(Check, ConstantInt::getTrue(Ctx));
-  ChkBuilder.Insert(TheCheck, "stride.not.one");
-  FirstInst = getFirstInst(FirstInst, TheCheck, Loc);
-
-  return std::make_pair(FirstInst, TheCheck);
-}
-
-PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L,
-                                                      Value *Start,
-                                                      Value *End,
-                                                      Value *Step,
+PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, Value *Start,
+                                                      Value *End, Value *Step,
                                                       Instruction *DL) {
   BasicBlock *Header = L->getHeader();
   BasicBlock *Latch = L->getLoopLatch();
@@ -2735,26 +2720,26 @@ void InnerLoopVectorizer::emitVectorLoopEnteredCheck(Loop *L,
   LoopBypassBlocks.push_back(BB);
 }
 
-void InnerLoopVectorizer::emitStrideChecks(Loop *L,
-                                           BasicBlock *Bypass) {
+void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) {
   BasicBlock *BB = L->getLoopPreheader();
-  
-  // Generate the code to check that the strides we assumed to be one are really
-  // one. We want the new basic block to start at the first instruction in a
+
+  // Generate the code to check that the SCEV assumptions that we made.
+  // We want the new basic block to start at the first instruction in a
   // sequence of instructions that form a check.
-  Instruction *StrideCheck;
-  Instruction *FirstCheckInst;
-  std::tie(FirstCheckInst, StrideCheck) = addStrideCheck(BB->getTerminator());
-  if (!StrideCheck)
-    return;
+  SCEVExpander Exp(*SE, Bypass->getModule()->getDataLayout(), "scev.check");
+  Value *SCEVCheck = Exp.expandCodeForPredicate(&Preds, BB->getTerminator());
+
+  if (auto *C = dyn_cast<ConstantInt>(SCEVCheck))
+    if (C->isZero())
+      return;
 
   // Create a new block containing the stride check.
-  BB->setName("vector.stridecheck");
+  BB->setName("vector.scevcheck");
   auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph");
   if (L->getParentLoop())
     L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI);
   ReplaceInstWithInst(BB->getTerminator(),
-                      BranchInst::Create(Bypass, NewBB, StrideCheck));
+                      BranchInst::Create(Bypass, NewBB, SCEVCheck));
   LoopBypassBlocks.push_back(BB);
   AddedSafetyChecks = true;
 }
@@ -2874,10 +2859,10 @@ void InnerLoopVectorizer::createEmptyLoop() {
   // Now, compare the new count to zero. If it is zero skip the vector loop and
   // jump to the scalar loop.
   emitVectorLoopEnteredCheck(Lp, ScalarPH);
-  // Generate the code to check that the strides we assumed to be one are really
-  // one. We want the new basic block to start at the first instruction in a
-  // sequence of instructions that form a check.
-  emitStrideChecks(Lp, ScalarPH);
+  // Generate the code to check any assumptions that we've made for SCEV
+  // expressions.
+  emitSCEVChecks(Lp, ScalarPH);
+
   // Generate the code that checks in runtime if arrays overlap. We put the
   // checks into a separate block to make the more common case of few elements
   // faster.
@@ -4130,7 +4115,19 @@ bool LoopVectorizationLegality::canVectorize() {
 
   // Analyze interleaved memory accesses.
   if (UseInterleaved)
-     InterleaveInfo.analyzeInterleaving(Strides);
+    InterleaveInfo.analyzeInterleaving(Strides);
+
+  unsigned SCEVThreshold = VectorizeSCEVCheckThreshold;
+  if (Hints->getForce() == LoopVectorizeHints::FK_Enabled)
+    SCEVThreshold = PragmaVectorizeSCEVCheckThreshold;
+
+  if (Preds.getComplexity() > SCEVThreshold) {
+    emitAnalysis(VectorizationReport()
+                 << "Too many SCEV assumptions need to be made and checked "
+                 << "at runtime");
+    DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n");
+    return false;
+  }
 
   // Okay! We can vectorize. At this point we don't have any other mem analysis
   // which may limit our maximum vectorization factor, so just return true with
@@ -4436,6 +4433,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() {
   }
 
   Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks());
+  Preds.add(&LAI->Preds);
 
   return true;
 }
@@ -4550,7 +4548,7 @@ void InterleavedAccessInfo::collectConstStridedAccesses(
     StoreInst *SI = dyn_cast<StoreInst>(I);
 
     Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand();
-    int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides);
+    int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides, Preds);
 
     // The factor of the corresponding interleave group.
     unsigned Factor = std::abs(Stride);
@@ -4559,7 +4557,7 @@ void InterleavedAccessInfo::collectConstStridedAccesses(
     if (Factor < 2 || Factor > MaxInterleaveGroupFactor)
       continue;
 
-    const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
+    const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr);
     PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType());
     unsigned Size = DL.getTypeAllocSize(PtrTy->getElementType());