LoopVectorizer: Handle strided memory accesses by versioning
authorArnold Schwaighofer <aschwaighofer@apple.com>
Fri, 10 Jan 2014 18:20:32 +0000 (18:20 +0000)
committerArnold Schwaighofer <aschwaighofer@apple.com>
Fri, 10 Jan 2014 18:20:32 +0000 (18:20 +0000)
 for (i = 0; i < N; ++i)
   A[i * Stride1] += B[i * Stride2];

We take loops like this and check that the symbolic strides 'Strided1/2' are one
and drop to the scalar loop if they are not.

This is currently disabled by default and hidden behind the flag
'enable-mem-access-versioning'.

radar://13075509

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@198950 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Vectorize/LoopVectorize.cpp
test/Transforms/LoopVectorize/runtime-check-readonly.ll
test/Transforms/LoopVectorize/version-mem-access.ll [new file with mode: 0644]

index 70c18edf55a9d0ede07d2951a4efada9173b1d79..74285ec2457638f996dbed3478bd5dfea07566ad 100644 (file)
@@ -114,6 +114,21 @@ TinyTripCountVectorThreshold("vectorizer-min-trip-count", cl::init(16),
                                       "trip count that is smaller than this "
                                       "value."));
 
+/// This enables versioning on the strides of symbolically striding memory
+/// accesses in code like the following.
+///   for (i = 0; i < N; ++i)
+///     A[i * Stride1] += B[i * Stride2] ...
+///
+/// Will be roughly translated to
+///    if (Stride1 == 1 && Stride2 == 1) {
+///      for (i = 0; i < N; i+=4)
+///       A[i:i+3] += ...
+///    } else
+///      ...
+static cl::opt<bool> EnableMemAccessVersioning(
+    "enable-mem-access-versioning", cl::init(false), cl::Hidden,
+    cl::desc("Enable symblic stride memory access versioning"));
+
 /// We don't unroll loops with a known constant trip count below this number.
 static const unsigned TinyTripCountUnrollThreshold = 128;
 
@@ -158,15 +173,16 @@ public:
                       unsigned UnrollFactor)
       : OrigLoop(OrigLoop), SE(SE), LI(LI), DT(DT), DL(DL), TLI(TLI),
         VF(VecWidth), UF(UnrollFactor), Builder(SE->getContext()), Induction(0),
-        OldInduction(0), WidenMap(UnrollFactor) {}
+        OldInduction(0), WidenMap(UnrollFactor), Legal(0) {}
 
   // Perform the actual loop widening (vectorization).
-  void vectorize(LoopVectorizationLegality *Legal) {
+  void vectorize(LoopVectorizationLegality *L) {
+    Legal = L;
     // Create a new empty loop. Unlink the old loop and connect the new one.
-    createEmptyLoop(Legal);
+    createEmptyLoop();
     // Widen each instruction in the old loop to a new one in the new loop.
     // Use the Legality module to find the induction and reduction variables.
-    vectorizeLoop(Legal);
+    vectorizeLoop();
     // Register the new loop and update the analysis passes.
     updateAnalysis();
   }
@@ -186,14 +202,23 @@ protected:
   typedef DenseMap<std::pair<BasicBlock*, BasicBlock*>,
                    VectorParts> EdgeMaskCache;
 
-  /// Add code that checks at runtime if the accessed arrays overlap.
-  /// Returns the comparator value or NULL if no check is needed.
-  Instruction *addRuntimeCheck(LoopVectorizationLegality *Legal,
-                               Instruction *Loc);
+  /// \brief Add code that checks at runtime if the accessed arrays overlap.
+  ///
+  /// Returns a pair of instructions where the first element is the first
+  /// instruction generated in possibly a sequence of instructions and the
+  /// second value is the final comparator value or NULL if no check is needed.
+  std::pair<Instruction *, Instruction *> addRuntimeCheck(Instruction *Loc);
+
+  /// \brief Add checks for strides that where assumed to be 1.
+  ///
+  /// Returns the last check instruction and the first check instruction in the
+  /// pair as (first, last).
+  std::pair<Instruction *, Instruction *> addStrideCheck(Instruction *Loc);
+
   /// Create an empty loop, based on the loop ranges of the old loop.
-  void createEmptyLoop(LoopVectorizationLegality *Legal);
+  void createEmptyLoop();
   /// Copy and widen the instructions from the old loop.
-  virtual void vectorizeLoop(LoopVectorizationLegality *Legal);
+  virtual void vectorizeLoop();
 
   /// \brief The Loop exit block may have single value PHI nodes where the
   /// incoming value is 'Undef'. While vectorizing we only handled real values
@@ -210,14 +235,12 @@ protected:
   VectorParts createEdgeMask(BasicBlock *Src, BasicBlock *Dst);
 
   /// A helper function to vectorize a single BB within the innermost loop.
-  void vectorizeBlockInLoop(LoopVectorizationLegality *Legal, BasicBlock *BB,
-                            PhiVector *PV);
+  void vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV);
 
   /// Vectorize a single PHINode in a block. This method handles the induction
   /// variable canonicalization. It supports both VF = 1 for unrolled loops and
   /// arbitrary length vectors.
   void widenPHIInstruction(Instruction *PN, VectorParts &Entry,
-                           LoopVectorizationLegality *Legal,
                            unsigned UF, unsigned VF, PhiVector *PV);
 
   /// Insert the new loop to the loop hierarchy and pass manager
@@ -229,8 +252,7 @@ protected:
   virtual void scalarizeInstruction(Instruction *Instr);
 
   /// Vectorize Load and Store instructions,
-  virtual void vectorizeMemoryInstruction(Instruction *Instr,
-                                  LoopVectorizationLegality *Legal);
+  virtual void vectorizeMemoryInstruction(Instruction *Instr);
 
   /// Create a broadcast instruction. This method generates a broadcast
   /// instruction (shuffle) for loop invariant values and for the induction
@@ -345,6 +367,8 @@ protected:
   /// Maps scalars to widened vectors.
   ValueMap WidenMap;
   EdgeMaskCache MaskCache;
+
+  LoopVectorizationLegality *Legal;
 };
 
 class InnerLoopUnroller : public InnerLoopVectorizer {
@@ -356,8 +380,7 @@ public:
 
 private:
   virtual void scalarizeInstruction(Instruction *Instr);
-  virtual void vectorizeMemoryInstruction(Instruction *Instr,
-                                          LoopVectorizationLegality *Legal);
+  virtual void vectorizeMemoryInstruction(Instruction *Instr);
   virtual Value *getBroadcastInstrs(Value *V);
   virtual Value *getConsecutiveVector(Value* Val, int StartIdx, bool Negate);
   virtual Value *reverseVector(Value *Vec);
@@ -500,7 +523,7 @@ public:
 
     /// Insert a pointer and calculate the start and end SCEVs.
     void insert(ScalarEvolution *SE, Loop *Lp, Value *Ptr, bool WritePtr,
-                unsigned DepSetId);
+                unsigned DepSetId, ValueToValueMap &Strides);
 
     /// This flag indicates if we need to add the runtime check.
     bool Need;
@@ -584,6 +607,13 @@ public:
 
   unsigned getMaxSafeDepDistBytes() { return MaxSafeDepDistBytes; }
 
+  bool hasStride(Value *V) { return StrideSet.count(V); }
+  bool mustCheckStrides() { return !StrideSet.empty(); }
+  SmallPtrSet<Value *, 8>::iterator strides_begin() {
+    return StrideSet.begin();
+  }
+  SmallPtrSet<Value *, 8>::iterator strides_end() { return StrideSet.end(); }
+
 private:
   /// Check if a single basic block loop is vectorizable.
   /// At this point we know that this is a loop with a constant trip count
@@ -626,6 +656,12 @@ private:
   /// if the PHI is not an induction variable.
   InductionKind isInductionVariable(PHINode *Phi);
 
+  /// \brief Collect memory access with loop invariant strides.
+  ///
+  /// Looks for accesses like "a[i * StrideA]" where "StrideA" is loop
+  /// invariant.
+  void collectStridedAcccess(Value *LoadOrStoreInst);
+
   /// The loop that we evaluate.
   Loop *TheLoop;
   /// Scev analysis.
@@ -664,6 +700,9 @@ private:
   bool HasFunNoNaNAttr;
 
   unsigned MaxSafeDepDistBytes;
+
+  ValueToValueMap Strides;
+  SmallPtrSet<Value *, 8> StrideSet;
 };
 
 /// LoopVectorizationCostModel - estimates the expected speedups due to
@@ -1033,12 +1072,52 @@ struct LoopVectorize : public LoopPass {
 // LoopVectorizationCostModel.
 //===----------------------------------------------------------------------===//
 
-void
-LoopVectorizationLegality::RuntimePointerCheck::insert(ScalarEvolution *SE,
-                                                       Loop *Lp, Value *Ptr,
-                                                       bool WritePtr,
-                                                       unsigned DepSetId) {
-  const SCEV *Sc = SE->getSCEV(Ptr);
+static Value *stripCast(Value *V) {
+  if (CastInst *CI = dyn_cast<CastInst>(V))
+    return CI->getOperand(0);
+  return V;
+}
+
+///\brief Replaces the symbolic stride in a pointer SCEV expression by one.
+///
+/// If \p OrigPtr is not null, use it to look up the stride value instead of
+/// \p Ptr.
+static const SCEV *replaceSymbolicStrideSCEV(ScalarEvolution *SE,
+                                             ValueToValueMap &PtrToStride,
+                                             Value *Ptr, Value *OrigPtr = 0) {
+
+  const SCEV *OrigSCEV = SE->getSCEV(Ptr);
+
+  // If there is an entry in the map return the SCEV of the pointer with the
+  // symbolic stride replaced by one.
+  ValueToValueMap::iterator SI = PtrToStride.find(OrigPtr ? OrigPtr : Ptr);
+  if (SI != PtrToStride.end()) {
+    Value *StrideVal = SI->second;
+
+    // Strip casts.
+    StrideVal = stripCast(StrideVal);
+
+    // Replace symbolic stride by one.
+    Value *One = ConstantInt::get(StrideVal->getType(), 1);
+    ValueToValueMap RewriteMap;
+    RewriteMap[StrideVal] = One;
+
+    const SCEV *ByOne =
+        SCEVParameterRewriter::rewrite(OrigSCEV, *SE, RewriteMap, true);
+    DEBUG(dbgs() << "LV: Replacing SCEV: " << *OrigSCEV << " by: " << *ByOne
+                 << "\n");
+    return ByOne;
+  }
+
+  // Otherwise, just return the SCEV of the original pointer.
+  return SE->getSCEV(Ptr);
+}
+
+void LoopVectorizationLegality::RuntimePointerCheck::insert(
+    ScalarEvolution *SE, Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId,
+    ValueToValueMap &Strides) {
+  // Get the stride replaced scev.
+  const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
   assert(AR && "Invalid addrec expression");
   const SCEV *Ex = SE->getBackedgeTakenCount(Lp);
@@ -1170,7 +1249,27 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) {
 
   // We can emit wide load/stores only if the last non-zero index is the
   // induction variable.
-  const SCEV *Last = SE->getSCEV(Gep->getOperand(InductionOperand));
+  const SCEV *Last = 0;
+  if (!Strides.count(Gep))
+    Last = SE->getSCEV(Gep->getOperand(InductionOperand));
+  else {
+    // Because of the multiplication by a stride we can have a s/zext cast.
+    // We are going to replace this stride by 1 so the cast is safe to ignore.
+    //
+    //  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+    //  %0 = trunc i64 %indvars.iv to i32
+    //  %mul = mul i32 %0, %Stride1
+    //  %idxprom = zext i32 %mul to i64  << Safe cast.
+    //  %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom
+    //
+    Last = replaceSymbolicStrideSCEV(SE, Strides,
+                                     Gep->getOperand(InductionOperand), Gep);
+    if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(Last))
+      Last =
+          (C->getSCEVType() == scSignExtend || C->getSCEVType() == scZeroExtend)
+              ? C->getOperand()
+              : Last;
+  }
   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Last)) {
     const SCEV *Step = AR->getStepRecurrence(*SE);
 
@@ -1194,6 +1293,10 @@ InnerLoopVectorizer::getVectorValue(Value *V) {
   assert(V != Induction && "The new induction variable should not be used.");
   assert(!V->getType()->isVectorTy() && "Can't widen a vector");
 
+  // If we have a stride that is replaced by one, do it here.
+  if (Legal->hasStride(V))
+    V = ConstantInt::get(V->getType(), 1);
+
   // If we have this scalar in the map, return it.
   if (WidenMap.has(V))
     return WidenMap.get(V);
@@ -1215,9 +1318,7 @@ Value *InnerLoopVectorizer::reverseVector(Value *Vec) {
                                      "reverse");
 }
 
-
-void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr,
-                                             LoopVectorizationLegality *Legal) {
+void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) {
   // Attempt to issue a wide load.
   LoadInst *LI = dyn_cast<LoadInst>(Instr);
   StoreInst *SI = dyn_cast<StoreInst>(Instr);
@@ -1427,14 +1528,58 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr) {
   }
 }
 
-Instruction *
-InnerLoopVectorizer::addRuntimeCheck(LoopVectorizationLegality *Legal,
-                                     Instruction *Loc) {
+static Instruction *getFirstInst(Instruction *FirstInst, Value *V,
+                                 Instruction *Loc) {
+  if (FirstInst)
+    return FirstInst;
+  if (Instruction *I = dyn_cast<Instruction>(V))
+    return I->getParent() == Loc->getParent() ? I : 0;
+  return 0;
+}
+
+std::pair<Instruction *, Instruction *>
+InnerLoopVectorizer::addStrideCheck(Instruction *Loc) {
+  if (!Legal->mustCheckStrides())
+    return std::pair<Instruction *, Instruction *>(0, 0);
+
+  IRBuilder<> ChkBuilder(Loc);
+
+  // Emit checks.
+  Value *Check = 0;
+  Instruction *FirstInst = 0;
+  for (SmallPtrSet<Value *, 8>::iterator SI = Legal->strides_begin(),
+                                         SE = Legal->strides_end();
+       SI != SE; ++SI) {
+    Value *Ptr = stripCast(*SI);
+    Value *C = ChkBuilder.CreateICmpNE(Ptr, ConstantInt::get(Ptr->getType(), 1),
+                                       "stride.chk");
+    // Store the first instruction we create.
+    FirstInst = getFirstInst(FirstInst, C, Loc);
+    if (Check)
+      Check = ChkBuilder.CreateOr(Check, C);
+    else
+      Check = C;
+  }
+
+  // We have to do this trickery because the IRBuilder might fold the check to a
+  // constant expression in which case there is no Instruction anchored in a
+  // the block.
+  LLVMContext &Ctx = Loc->getContext();
+  Instruction *TheCheck =
+      BinaryOperator::CreateAnd(Check, ConstantInt::getTrue(Ctx));
+  ChkBuilder.Insert(TheCheck, "stride.not.one");
+  FirstInst = getFirstInst(FirstInst, TheCheck, Loc);
+
+  return std::make_pair(FirstInst, TheCheck);
+}
+
+std::pair<Instruction *, Instruction *>
+InnerLoopVectorizer::addRuntimeCheck(Instruction *Loc) {
   LoopVectorizationLegality::RuntimePointerCheck *PtrRtCheck =
   Legal->getRuntimePointerCheck();
 
   if (!PtrRtCheck->Need)
-    return NULL;
+    return std::pair<Instruction *, Instruction *>(0, 0);
 
   unsigned NumPointers = PtrRtCheck->Pointers.size();
   SmallVector<TrackingVH<Value> , 2> Starts;
@@ -1442,6 +1587,7 @@ InnerLoopVectorizer::addRuntimeCheck(LoopVectorizationLegality *Legal,
 
   LLVMContext &Ctx = Loc->getContext();
   SCEVExpander Exp(*SE, "induction");
+  Instruction *FirstInst = 0;
 
   for (unsigned i = 0; i < NumPointers; ++i) {
     Value *Ptr = PtrRtCheck->Pointers[i];
@@ -1495,11 +1641,16 @@ InnerLoopVectorizer::addRuntimeCheck(LoopVectorizationLegality *Legal,
       Value *End1 =   ChkBuilder.CreateBitCast(Ends[j],   PtrArithTy0, "bc");
 
       Value *Cmp0 = ChkBuilder.CreateICmpULE(Start0, End1, "bound0");
+      FirstInst = getFirstInst(FirstInst, Cmp0, Loc);
       Value *Cmp1 = ChkBuilder.CreateICmpULE(Start1, End0, "bound1");
+      FirstInst = getFirstInst(FirstInst, Cmp1, Loc);
       Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
-      if (MemoryRuntimeCheck)
+      FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
+      if (MemoryRuntimeCheck) {
         IsConflict = ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict,
                                          "conflict.rdx");
+        FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
+      }
       MemoryRuntimeCheck = IsConflict;
     }
   }
@@ -1510,11 +1661,11 @@ InnerLoopVectorizer::addRuntimeCheck(LoopVectorizationLegality *Legal,
   Instruction *Check = BinaryOperator::CreateAnd(MemoryRuntimeCheck,
                                                  ConstantInt::getTrue(Ctx));
   ChkBuilder.Insert(Check, "memcheck.conflict");
-  return Check;
+  FirstInst = getFirstInst(FirstInst, Check, Loc);
+  return std::make_pair(FirstInst, Check);
 }
 
-void
-InnerLoopVectorizer::createEmptyLoop(LoopVectorizationLegality *Legal) {
+void InnerLoopVectorizer::createEmptyLoop() {
   /*
    In this function we generate a new loop. The new loop will contain
    the vectorized instructions while the old loop will continue to run the
@@ -1665,22 +1816,48 @@ InnerLoopVectorizer::createEmptyLoop(LoopVectorizationLegality *Legal) {
 
   BasicBlock *LastBypassBlock = BypassBlock;
 
+  // Generate the code to check that the strides we assumed to be one are really
+  // one. We want the new basic block to start at the first instruction in a
+  // sequence of instructions that form a check.
+  Instruction *StrideCheck;
+  Instruction *FirstCheckInst;
+  tie(FirstCheckInst, StrideCheck) =
+      addStrideCheck(BypassBlock->getTerminator());
+  if (StrideCheck) {
+    // Create a new block containing the stride check.
+    BasicBlock *CheckBlock =
+        BypassBlock->splitBasicBlock(FirstCheckInst, "vector.stridecheck");
+    if (ParentLoop)
+      ParentLoop->addBasicBlockToLoop(CheckBlock, LI->getBase());
+    LoopBypassBlocks.push_back(CheckBlock);
+
+    // Replace the branch into the memory check block with a conditional branch
+    // for the "few elements case".
+    Instruction *OldTerm = BypassBlock->getTerminator();
+    BranchInst::Create(MiddleBlock, CheckBlock, Cmp, OldTerm);
+    OldTerm->eraseFromParent();
+
+    Cmp = StrideCheck;
+    LastBypassBlock = CheckBlock;
+  }
+
   // Generate the code that checks in runtime if arrays overlap. We put the
   // checks into a separate block to make the more common case of few elements
   // faster.
-  Instruction *MemRuntimeCheck = addRuntimeCheck(Legal,
-                                                 BypassBlock->getTerminator());
+  Instruction *MemRuntimeCheck;
+  tie(FirstCheckInst, MemRuntimeCheck) =
+      addRuntimeCheck(LastBypassBlock->getTerminator());
   if (MemRuntimeCheck) {
     // Create a new block containing the memory check.
-    BasicBlock *CheckBlock = BypassBlock->splitBasicBlock(MemRuntimeCheck,
-                                                          "vector.memcheck");
+    BasicBlock *CheckBlock =
+        LastBypassBlock->splitBasicBlock(MemRuntimeCheck, "vector.memcheck");
     if (ParentLoop)
       ParentLoop->addBasicBlockToLoop(CheckBlock, LI->getBase());
     LoopBypassBlocks.push_back(CheckBlock);
 
     // Replace the branch into the memory check block with a conditional branch
     // for the "few elements case".
-    Instruction *OldTerm = BypassBlock->getTerminator();
+    Instruction *OldTerm = LastBypassBlock->getTerminator();
     BranchInst::Create(MiddleBlock, CheckBlock, Cmp, OldTerm);
     OldTerm->eraseFromParent();
 
@@ -2138,8 +2315,7 @@ static void cse(BasicBlock *BB) {
   }
 }
 
-void
-InnerLoopVectorizer::vectorizeLoop(LoopVectorizationLegality *Legal) {
+void InnerLoopVectorizer::vectorizeLoop() {
   //===------------------------------------------------===//
   //
   // Notice: any optimization or new instruction that go
@@ -2167,7 +2343,7 @@ InnerLoopVectorizer::vectorizeLoop(LoopVectorizationLegality *Legal) {
   // Vectorize all of the blocks in the original loop.
   for (LoopBlocksDFS::RPOIterator bb = DFS.beginRPO(),
        be = DFS.endRPO(); bb != be; ++bb)
-    vectorizeBlockInLoop(Legal, *bb, &RdxPHIsToFix);
+    vectorizeBlockInLoop(*bb, &RdxPHIsToFix);
 
   // At this point every instruction in the original loop is widened to
   // a vector form. We are almost done. Now, we need to fix the PHI nodes
@@ -2434,7 +2610,6 @@ InnerLoopVectorizer::createBlockInMask(BasicBlock *BB) {
 
 void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN,
                                               InnerLoopVectorizer::VectorParts &Entry,
-                                              LoopVectorizationLegality *Legal,
                                               unsigned UF, unsigned VF, PhiVector *PV) {
   PHINode* P = cast<PHINode>(PN);
   // Handle reduction variables:
@@ -2596,9 +2771,7 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN,
   }
 }
 
-void
-InnerLoopVectorizer::vectorizeBlockInLoop(LoopVectorizationLegality *Legal,
-                                          BasicBlock *BB, PhiVector *PV) {
+void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) {
   // For each instruction in the old loop.
   for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
     VectorParts &Entry = WidenMap.get(it);
@@ -2609,7 +2782,7 @@ InnerLoopVectorizer::vectorizeBlockInLoop(LoopVectorizationLegality *Legal,
       continue;
     case Instruction::PHI:{
       // Vectorize PHINodes.
-      widenPHIInstruction(it, Entry, Legal, UF, VF, PV);
+      widenPHIInstruction(it, Entry, UF, VF, PV);
       continue;
     }// End of PHI.
 
@@ -2703,7 +2876,7 @@ InnerLoopVectorizer::vectorizeBlockInLoop(LoopVectorizationLegality *Legal,
 
     case Instruction::Store:
     case Instruction::Load:
-        vectorizeMemoryInstruction(it, Legal);
+      vectorizeMemoryInstruction(it);
         break;
     case Instruction::ZExt:
     case Instruction::SExt:
@@ -3120,8 +3293,14 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
         Type *T = ST->getValueOperand()->getType();
         if (!VectorType::isValidElementType(T))
           return false;
+        if (EnableMemAccessVersioning)
+          collectStridedAcccess(ST);
       }
 
+      if (EnableMemAccessVersioning)
+        if (LoadInst *LI = dyn_cast<LoadInst>(it))
+          collectStridedAcccess(LI);
+
       // Reduction instructions are allowed to have exit users.
       // All other instructions must not have external users.
       if (hasOutsideLoopUser(TheLoop, it, AllowedExit))
@@ -3140,6 +3319,139 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
   return true;
 }
 
+///\brief Remove GEPs whose indices but the last one are loop invariant and
+/// return the induction operand of the gep pointer.
+static Value *stripGetElementPtr(Value *Ptr, ScalarEvolution *SE,
+                                 DataLayout *DL, Loop *Lp) {
+  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!GEP)
+    return Ptr;
+
+  unsigned InductionOperand = getGEPInductionOperand(DL, GEP);
+
+  // Check that all of the gep indices are uniform except for our induction
+  // operand.
+  for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i)
+    if (i != InductionOperand &&
+        !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(i)), Lp))
+      return Ptr;
+  return GEP->getOperand(InductionOperand);
+}
+
+///\brief Look for a cast use of the passed value.
+static Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) {
+  Value *UniqueCast = 0;
+  for (Value::use_iterator UI = Ptr->use_begin(), UE = Ptr->use_end(); UI != UE;
+       ++UI) {
+    CastInst *CI = dyn_cast<CastInst>(*UI);
+    if (CI && CI->getType() == Ty) {
+      if (!UniqueCast)
+        UniqueCast = CI;
+      else
+        return 0;
+    }
+  }
+  return UniqueCast;
+}
+
+///\brief Get the stride of a pointer access in a loop.
+/// Looks for symbolic strides "a[i*stride]". Returns the symbolic stride as a
+/// pointer to the Value, or null otherwise.
+static Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE,
+                                   DataLayout *DL, Loop *Lp) {
+  const PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType());
+  if (!PtrTy || PtrTy->isAggregateType())
+    return 0;
+
+  // Try to remove a gep instruction to make the pointer (actually index at this
+  // point) easier analyzable. If OrigPtr is equal to Ptr we are analzying the
+  // pointer, otherwise, we are analyzing the index.
+  Value *OrigPtr = Ptr;
+
+  // The size of the pointer access.
+  int64_t PtrAccessSize = 1;
+
+  Ptr = stripGetElementPtr(Ptr, SE, DL, Lp);
+  const SCEV *V = SE->getSCEV(Ptr);
+
+  if (Ptr != OrigPtr)
+    // Strip off casts.
+    while (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V))
+      V = C->getOperand();
+
+  const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V);
+  if (!S)
+    return 0;
+
+  V = S->getStepRecurrence(*SE);
+  if (!V)
+    return 0;
+
+  // Strip off the size of access multiplication if we are still analyzing the
+  // pointer.
+  if (OrigPtr == Ptr) {
+    DL->getTypeAllocSize(PtrTy->getElementType());
+    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) {
+      if (M->getOperand(0)->getSCEVType() != scConstant)
+        return 0;
+
+      const APInt &APStepVal =
+          cast<SCEVConstant>(M->getOperand(0))->getValue()->getValue();
+
+      // Huge step value - give up.
+      if (APStepVal.getBitWidth() > 64)
+        return 0;
+
+      int64_t StepVal = APStepVal.getSExtValue();
+      if (PtrAccessSize != StepVal)
+        return 0;
+      V = M->getOperand(1);
+    }
+  }
+
+  // Strip off casts.
+  Type *StripedOffRecurrenceCast = 0;
+  if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) {
+    StripedOffRecurrenceCast = C->getType();
+    V = C->getOperand();
+  }
+
+  // Look for the loop invariant symbolic value.
+  const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V);
+  if (!U)
+    return 0;
+
+  Value *Stride = U->getValue();
+  if (!Lp->isLoopInvariant(Stride))
+    return 0;
+
+  // If we have stripped off the recurrence cast we have to make sure that we
+  // return the value that is used in this loop so that we can replace it later.
+  if (StripedOffRecurrenceCast)
+    Stride = getUniqueCastUse(Stride, Lp, StripedOffRecurrenceCast);
+
+  return Stride;
+}
+
+void LoopVectorizationLegality::collectStridedAcccess(Value *MemAccess) {
+  Value *Ptr = 0;
+  if (LoadInst *LI = dyn_cast<LoadInst>(MemAccess))
+    Ptr = LI->getPointerOperand();
+  else if (StoreInst *SI = dyn_cast<StoreInst>(MemAccess))
+    Ptr = SI->getPointerOperand();
+  else
+    return;
+
+  Value *Stride = getStrideFromPointer(Ptr, SE, DL, TheLoop);
+  if (!Stride)
+    return;
+
+  DEBUG(dbgs() << "LV: Found a strided access that we can version");
+  DEBUG(dbgs() << "  Ptr: " << *Ptr << " Stride: " << *Stride << "\n");
+  Strides[Ptr] = Stride;
+  StrideSet.insert(Stride);
+}
+
 void LoopVectorizationLegality::collectLoopUniforms() {
   // We now know that the loop is vectorizable!
   // Collect variables that will remain uniform after vectorization.
@@ -3201,7 +3513,8 @@ public:
   /// non-intersection.
   bool canCheckPtrAtRT(LoopVectorizationLegality::RuntimePointerCheck &RtCheck,
                        unsigned &NumComparisons, ScalarEvolution *SE,
-                       Loop *TheLoop, bool ShouldCheckStride = false);
+                       Loop *TheLoop, ValueToValueMap &Strides,
+                       bool ShouldCheckStride = false);
 
   /// \brief Goes over all memory accesses, checks whether a RT check is needed
   /// and builds sets of dependent accesses.
@@ -3261,8 +3574,9 @@ private:
 } // end anonymous namespace
 
 /// \brief Check whether a pointer can participate in a runtime bounds check.
-static bool hasComputableBounds(ScalarEvolution *SE, Value *Ptr) {
-  const SCEV *PtrScev = SE->getSCEV(Ptr);
+static bool hasComputableBounds(ScalarEvolution *SE, ValueToValueMap &Strides,
+                                Value *Ptr) {
+  const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
   if (!AR)
     return false;
@@ -3273,12 +3587,12 @@ static bool hasComputableBounds(ScalarEvolution *SE, Value *Ptr) {
 /// \brief Check the stride of the pointer and ensure that it does not wrap in
 /// the address space.
 static int isStridedPtr(ScalarEvolution *SE, DataLayout *DL, Value *Ptr,
-                        const Loop *Lp);
+                        const Loop *Lp, ValueToValueMap &StridesMap);
 
 bool AccessAnalysis::canCheckPtrAtRT(
-                       LoopVectorizationLegality::RuntimePointerCheck &RtCheck,
-                        unsigned &NumComparisons, ScalarEvolution *SE,
-                        Loop *TheLoop, bool ShouldCheckStride) {
+    LoopVectorizationLegality::RuntimePointerCheck &RtCheck,
+    unsigned &NumComparisons, ScalarEvolution *SE, Loop *TheLoop,
+    ValueToValueMap &StridesMap, bool ShouldCheckStride) {
   // Find pointers with computable bounds. We are going to use this information
   // to place a runtime bound check.
   unsigned NumReadPtrChecks = 0;
@@ -3306,10 +3620,11 @@ bool AccessAnalysis::canCheckPtrAtRT(
     else
       ++NumReadPtrChecks;
 
-    if (hasComputableBounds(SE, Ptr) &&
+    if (hasComputableBounds(SE, StridesMap, Ptr) &&
         // When we run after a failing dependency check we have to make sure we
         // don't have wrapping pointers.
-        (!ShouldCheckStride || isStridedPtr(SE, DL, Ptr, TheLoop) == 1)) {
+        (!ShouldCheckStride ||
+         isStridedPtr(SE, DL, Ptr, TheLoop, StridesMap) == 1)) {
       // The id of the dependence set.
       unsigned DepId;
 
@@ -3323,7 +3638,7 @@ bool AccessAnalysis::canCheckPtrAtRT(
         // Each access has its own dependence set.
         DepId = RunningDepId++;
 
-      RtCheck.insert(SE, TheLoop, Ptr, IsWrite, DepId);
+      RtCheck.insert(SE, TheLoop, Ptr, IsWrite, DepId, StridesMap);
 
       DEBUG(dbgs() << "LV: Found a runtime check ptr:" << *Ptr << '\n');
     } else {
@@ -3517,7 +3832,7 @@ public:
   ///
   /// Only checks sets with elements in \p CheckDeps.
   bool areDepsSafe(AccessAnalysis::DepCandidates &AccessSets,
-                   MemAccessInfoSet &CheckDeps);
+                   MemAccessInfoSet &CheckDeps, ValueToValueMap &Strides);
 
   /// \brief The maximum number of bytes of a vector register we can vectorize
   /// the accesses safely with.
@@ -3561,7 +3876,8 @@ private:
   /// distance is smaller than any other distance encountered so far).
   /// Otherwise, this function returns true signaling a possible dependence.
   bool isDependent(const MemAccessInfo &A, unsigned AIdx,
-                   const MemAccessInfo &B, unsigned BIdx);
+                   const MemAccessInfo &B, unsigned BIdx,
+                   ValueToValueMap &Strides);
 
   /// \brief Check whether the data dependence could prevent store-load
   /// forwarding.
@@ -3578,7 +3894,7 @@ static bool isInBoundsGep(Value *Ptr) {
 
 /// \brief Check whether the access through \p Ptr has a constant stride.
 static int isStridedPtr(ScalarEvolution *SE, DataLayout *DL, Value *Ptr,
-                        const Loop *Lp) {
+                        const Loop *Lp, ValueToValueMap &StridesMap) {
   const Type *Ty = Ptr->getType();
   assert(Ty->isPointerTy() && "Unexpected non-ptr");
 
@@ -3590,7 +3906,8 @@ static int isStridedPtr(ScalarEvolution *SE, DataLayout *DL, Value *Ptr,
     return 0;
   }
 
-  const SCEV *PtrScev = SE->getSCEV(Ptr);
+  const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Ptr);
+
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
   if (!AR) {
     DEBUG(dbgs() << "LV: Bad stride - Not an AddRecExpr pointer "
@@ -3694,7 +4011,8 @@ bool MemoryDepChecker::couldPreventStoreLoadForward(unsigned Distance,
 }
 
 bool MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
-                                   const MemAccessInfo &B, unsigned BIdx) {
+                                   const MemAccessInfo &B, unsigned BIdx,
+                                   ValueToValueMap &Strides) {
   assert (AIdx < BIdx && "Must pass arguments in program order");
 
   Value *APtr = A.getPointer();
@@ -3706,11 +4024,11 @@ bool MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
   if (!AIsWrite && !BIsWrite)
     return false;
 
-  const SCEV *AScev = SE->getSCEV(APtr);
-  const SCEV *BScev = SE->getSCEV(BPtr);
+  const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, APtr);
+  const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, BPtr);
 
-  int StrideAPtr = isStridedPtr(SE, DL, APtr, InnermostLoop);
-  int StrideBPtr = isStridedPtr(SE, DL, BPtr, InnermostLoop);
+  int StrideAPtr = isStridedPtr(SE, DL, APtr, InnermostLoop, Strides);
+  int StrideBPtr = isStridedPtr(SE, DL, BPtr, InnermostLoop, Strides);
 
   const SCEV *Src = AScev;
   const SCEV *Sink = BScev;
@@ -3815,9 +4133,9 @@ bool MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
   return false;
 }
 
-bool
-MemoryDepChecker::areDepsSafe(AccessAnalysis::DepCandidates &AccessSets,
-                              MemAccessInfoSet &CheckDeps) {
+bool MemoryDepChecker::areDepsSafe(AccessAnalysis::DepCandidates &AccessSets,
+                                   MemAccessInfoSet &CheckDeps,
+                                   ValueToValueMap &Strides) {
 
   MaxSafeDepDistBytes = -1U;
   while (!CheckDeps.empty()) {
@@ -3841,9 +4159,9 @@ MemoryDepChecker::areDepsSafe(AccessAnalysis::DepCandidates &AccessSets,
              I1E = Accesses[*AI].end(); I1 != I1E; ++I1)
           for (std::vector<unsigned>::iterator I2 = Accesses[*OI].begin(),
                I2E = Accesses[*OI].end(); I2 != I2E; ++I2) {
-            if (*I1 < *I2 && isDependent(*AI, *I1, *OI, *I2))
+            if (*I1 < *I2 && isDependent(*AI, *I1, *OI, *I2, Strides))
               return false;
-            if (*I2 < *I1 && isDependent(*OI, *I2, *AI, *I1))
+            if (*I2 < *I1 && isDependent(*OI, *I2, *AI, *I1, Strides))
               return false;
           }
         ++OI;
@@ -3974,7 +4292,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() {
     // read a few words, modify, and write a few words, and some of the
     // words may be written to the same address.
     bool IsReadOnlyPtr = false;
-    if (Seen.insert(Ptr) || !isStridedPtr(SE, DL, Ptr, TheLoop)) {
+    if (Seen.insert(Ptr) || !isStridedPtr(SE, DL, Ptr, TheLoop, Strides)) {
       ++NumReads;
       IsReadOnlyPtr = true;
     }
@@ -3998,8 +4316,8 @@ bool LoopVectorizationLegality::canVectorizeMemory() {
   unsigned NumComparisons = 0;
   bool CanDoRT = false;
   if (NeedRTCheck)
-    CanDoRT = Accesses.canCheckPtrAtRT(PtrRtCheck, NumComparisons, SE, TheLoop);
-
+    CanDoRT = Accesses.canCheckPtrAtRT(PtrRtCheck, NumComparisons, SE, TheLoop,
+                                       Strides);
 
   DEBUG(dbgs() << "LV: We need to do " << NumComparisons <<
         " pointer comparisons.\n");
@@ -4032,8 +4350,8 @@ bool LoopVectorizationLegality::canVectorizeMemory() {
   bool CanVecMem = true;
   if (Accesses.isDependencyCheckNeeded()) {
     DEBUG(dbgs() << "LV: Checking memory dependencies\n");
-    CanVecMem = DepChecker.areDepsSafe(DependentAccesses,
-                                       Accesses.getDependenciesToCheck());
+    CanVecMem = DepChecker.areDepsSafe(
+        DependentAccesses, Accesses.getDependenciesToCheck(), Strides);
     MaxSafeDepDistBytes = DepChecker.getMaxSafeDepDistBytes();
 
     if (!CanVecMem && DepChecker.shouldRetryWithRuntimeCheck()) {
@@ -4047,7 +4365,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() {
       PtrRtCheck.Need = true;
 
       CanDoRT = Accesses.canCheckPtrAtRT(PtrRtCheck, NumComparisons, SE,
-                                         TheLoop, true);
+                                         TheLoop, Strides, true);
       // Check that we did not collect too many pointers or found an unsizeable
       // pointer.
       if (!CanDoRT || NumComparisons > RuntimeMemoryCheckThreshold) {
@@ -4867,6 +5185,12 @@ static bool isLikelyComplexAddressComputation(Value *Ptr,
   return StepVal > MaxMergeDistance;
 }
 
+static bool isStrideMul(Instruction *I, LoopVectorizationLegality *Legal) {
+  if (Legal->hasStride(I->getOperand(0)) || Legal->hasStride(I->getOperand(1)))
+    return true;
+  return false;
+}
+
 unsigned
 LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) {
   // If we know that this instruction will remain uniform, check the cost of
@@ -4909,6 +5233,9 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) {
   case Instruction::And:
   case Instruction::Or:
   case Instruction::Xor: {
+    // Since we will replace the stride by 1 the multiplication should go away.
+    if (I->getOpcode() == Instruction::Mul && isStrideMul(I, Legal))
+      return 0;
     // Certain instructions can be cheaper to vectorize if they have a constant
     // second vector operand. One example of this are shifts on x86.
     TargetTransformInfo::OperandValueKind Op1VK =
@@ -5155,9 +5482,7 @@ void InnerLoopUnroller::scalarizeInstruction(Instruction *Instr) {
   }
 }
 
-void
-InnerLoopUnroller::vectorizeMemoryInstruction(Instruction *Instr,
-                                              LoopVectorizationLegality*) {
+void InnerLoopUnroller::vectorizeMemoryInstruction(Instruction *Instr) {
   return scalarizeInstruction(Instr);
 }
 
index a2b9ad94c83766e000fd9e4f7fddbadeb2e8b2d7..e7b1e2a6b72cc0ff97d154d0ca6d1d6cd64589b2 100644 (file)
@@ -7,11 +7,13 @@ target triple = "x86_64-apple-macosx10.8.0"
 ;CHECK: br
 ;CHECK: getelementptr
 ;CHECK-NEXT: getelementptr
-;CHECK-NEXT: icmp uge
-;CHECK-NEXT: icmp uge
-;CHECK-NEXT: icmp uge
-;CHECK-NEXT: icmp uge
-;CHECK-NEXT: and
+;CHECK-DAG: icmp uge
+;CHECK-DAG: icmp uge
+;CHECK-DAG: icmp uge
+;CHECK-DAG: icmp uge
+;CHECK-DAG: and
+;CHECK-DAG: and
+;CHECK: br
 ;CHECK: ret
 define void @add_ints(i32* nocapture %A, i32* nocapture %B, i32* nocapture %C) {
 entry:
diff --git a/test/Transforms/LoopVectorize/version-mem-access.ll b/test/Transforms/LoopVectorize/version-mem-access.ll
new file mode 100644 (file)
index 0000000..e712728
--- /dev/null
@@ -0,0 +1,50 @@
+; RUN: opt -basicaa -loop-vectorize -enable-mem-access-versioning -force-vector-width=2 -force-vector-unroll=1 < %s -S | FileCheck %s
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
+
+; CHECK-LABEL: test
+define void @test(i32* noalias %A, i64 %AStride,
+                  i32* noalias %B, i32 %BStride,
+                  i32* noalias %C, i64 %CStride, i32 %N) {
+entry:
+  %cmp13 = icmp eq i32 %N, 0
+  br i1 %cmp13, label %for.end, label %for.body.preheader
+
+; CHECK-DAG: icmp ne i64 %AStride, 1
+; CHECK-DAG: icmp ne i32 %BStride, 1
+; CHECK-DAG: icmp ne i64 %CStride, 1
+; CHECK: or
+; CHECK: or
+; CHECK: br
+
+; CHECK: vector.body
+; CHECK: load <2 x i32>
+
+for.body.preheader:
+  br label %for.body
+
+for.body:
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+  %iv.trunc = trunc i64 %indvars.iv to i32
+  %mul = mul i32 %iv.trunc, %BStride
+  %mul64 = zext i32 %mul to i64
+  %arrayidx = getelementptr inbounds i32* %B, i64 %mul64
+  %0 = load i32* %arrayidx, align 4
+  %mul2 = mul nsw i64 %indvars.iv, %CStride
+  %arrayidx3 = getelementptr inbounds i32* %C, i64 %mul2
+  %1 = load i32* %arrayidx3, align 4
+  %mul4 = mul nsw i32 %1, %0
+  %mul3 = mul nsw i64 %indvars.iv, %AStride
+  %arrayidx7 = getelementptr inbounds i32* %A, i64 %mul3
+  store i32 %mul4, i32* %arrayidx7, align 4
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+  %exitcond = icmp eq i32 %lftr.wideiv, %N
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}