Reformat.
[oota-llvm.git] / lib / Analysis / LoopAccessAnalysis.cpp
index b8b14a3a67721a208ab3da4d4c61f4b34b12c3ac..521b4e87fe5c66363e29b667286accf6a94c8c5b 100644 (file)
@@ -73,12 +73,11 @@ const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE,
   return SE->getSCEV(Ptr);
 }
 
-void LoopAccessAnalysis::RuntimePointerCheck::insert(ScalarEvolution *SE,
-                                                     Loop *Lp, Value *Ptr,
-                                                     bool WritePtr,
-                                                     unsigned DepSetId,
-                                                     unsigned ASId,
-                                                     ValueToValueMap &Strides) {
+void LoopAccessInfo::RuntimePointerCheck::insert(ScalarEvolution *SE, Loop *Lp,
+                                                 Value *Ptr, bool WritePtr,
+                                                 unsigned DepSetId,
+                                                 unsigned ASId,
+                                                 ValueToValueMap &Strides) {
   // Get the stride replaced scev.
   const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
@@ -93,6 +92,23 @@ void LoopAccessAnalysis::RuntimePointerCheck::insert(ScalarEvolution *SE,
   AliasSetId.push_back(ASId);
 }
 
+bool LoopAccessInfo::RuntimePointerCheck::needsChecking(unsigned I,
+                                                        unsigned J) const {
+  // No need to check if two readonly pointers intersect.
+  if (!IsWritePtr[I] && !IsWritePtr[J])
+    return false;
+
+  // Only need to check pointers between two different dependency sets.
+  if (DependencySetId[I] == DependencySetId[J])
+    return false;
+
+  // Only need to check pointers in the same alias set.
+  if (AliasSetId[I] != AliasSetId[J])
+    return false;
+
+  return true;
+}
+
 namespace {
 /// \brief Analyses memory accesses in a loop.
 ///
@@ -128,7 +144,7 @@ public:
 
   /// \brief Check whether we can check the pointers at runtime for
   /// non-intersection.
-  bool canCheckPtrAtRT(LoopAccessAnalysis::RuntimePointerCheck &RtCheck,
+  bool canCheckPtrAtRT(LoopAccessInfo::RuntimePointerCheck &RtCheck,
                        unsigned &NumComparisons,
                        ScalarEvolution *SE, Loop *TheLoop,
                        ValueToValueMap &Strides,
@@ -196,7 +212,7 @@ static int isStridedPtr(ScalarEvolution *SE, const DataLayout *DL, Value *Ptr,
                         const Loop *Lp, ValueToValueMap &StridesMap);
 
 bool AccessAnalysis::canCheckPtrAtRT(
-    LoopAccessAnalysis::RuntimePointerCheck &RtCheck,
+    LoopAccessInfo::RuntimePointerCheck &RtCheck,
     unsigned &NumComparisons, ScalarEvolution *SE, Loop *TheLoop,
     ValueToValueMap &StridesMap, bool ShouldCheckStride) {
   // Find pointers with computable bounds. We are going to use this information
@@ -286,7 +302,7 @@ bool AccessAnalysis::canCheckPtrAtRT(
       unsigned ASj = PtrJ->getType()->getPointerAddressSpace();
       if (ASi != ASj) {
         DEBUG(dbgs() << "LV: Runtime check would require comparison between"
-                       " different address spaces\n");
+                        " different address spaces\n");
         return false;
       }
     }
@@ -439,7 +455,7 @@ public:
   typedef SmallPtrSet<MemAccessInfo, 8> MemAccessInfoSet;
 
   MemoryDepChecker(ScalarEvolution *Se, const DataLayout *Dl, const Loop *L,
-                   const LoopAccessAnalysis::VectorizerParams &VectParams)
+                   const LoopAccessInfo::VectorizerParams &VectParams)
       : SE(Se), DL(Dl), InnermostLoop(L), AccessIdx(0),
         ShouldRetryWithRuntimeCheck(false), VectParams(VectParams) {}
 
@@ -497,7 +513,7 @@ private:
   bool ShouldRetryWithRuntimeCheck;
 
   /// \brief Vectorizer parameters used by the analysis.
-  LoopAccessAnalysis::VectorizerParams VectParams;
+  LoopAccessInfo::VectorizerParams VectParams;
 
   /// \brief Check whether there is a plausible dependence between the two
   /// accesses.
@@ -537,8 +553,8 @@ static int isStridedPtr(ScalarEvolution *SE, const DataLayout *DL, Value *Ptr,
   // Make sure that the pointer does not point to aggregate types.
   const PointerType *PtrTy = cast<PointerType>(Ty);
   if (PtrTy->getElementType()->isAggregateType()) {
-    DEBUG(dbgs() << "LV: Bad stride - Not a pointer to a scalar type" << *Ptr <<
-          "\n");
+    DEBUG(dbgs() << "LV: Bad stride - Not a pointer to a scalar type" << *Ptr
+                 << "\n");
     return 0;
   }
 
@@ -546,15 +562,15 @@ static int isStridedPtr(ScalarEvolution *SE, const DataLayout *DL, Value *Ptr,
 
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
   if (!AR) {
-    DEBUG(dbgs() << "LV: Bad stride - Not an AddRecExpr pointer "
-          << *Ptr << " SCEV: " << *PtrScev << "\n");
+    DEBUG(dbgs() << "LV: Bad stride - Not an AddRecExpr pointer " << *Ptr
+                 << " SCEV: " << *PtrScev << "\n");
     return 0;
   }
 
   // The accesss function must stride over the innermost loop.
   if (Lp != AR->getLoop()) {
-    DEBUG(dbgs() << "LV: Bad stride - Not striding over innermost loop " <<
-          *Ptr << " SCEV: " << *PtrScev << "\n");
+    DEBUG(dbgs() << "LV: Bad stride - Not striding over innermost loop " << *Ptr
+                 << " SCEV: " << *PtrScev << "\n");
   }
 
   // The address calculation must not wrap. Otherwise, a dependence could be
@@ -569,7 +585,7 @@ static int isStridedPtr(ScalarEvolution *SE, const DataLayout *DL, Value *Ptr,
   bool IsInAddressSpaceZero = PtrTy->getAddressSpace() == 0;
   if (!IsNoWrapAddRec && !IsInBoundsGEP && !IsInAddressSpaceZero) {
     DEBUG(dbgs() << "LV: Bad stride - Pointer may wrap in the address space "
-          << *Ptr << " SCEV: " << *PtrScev << "\n");
+                 << *Ptr << " SCEV: " << *PtrScev << "\n");
     return 0;
   }
 
@@ -579,8 +595,8 @@ static int isStridedPtr(ScalarEvolution *SE, const DataLayout *DL, Value *Ptr,
   // Calculate the pointer stride and check if it is consecutive.
   const SCEVConstant *C = dyn_cast<SCEVConstant>(Step);
   if (!C) {
-    DEBUG(dbgs() << "LV: Bad stride - Not a constant strided " << *Ptr <<
-          " SCEV: " << *PtrScev << "\n");
+    DEBUG(dbgs() << "LV: Bad stride - Not a constant strided " << *Ptr
+                 << " SCEV: " << *PtrScev << "\n");
     return 0;
   }
 
@@ -622,8 +638,9 @@ bool MemoryDepChecker::couldPreventStoreLoadForward(unsigned Distance,
   // Store-load forwarding distance.
   const unsigned NumCyclesForStoreLoadThroughMemory = 8*TypeByteSize;
   // Maximum vector factor.
-  unsigned MaxVFWithoutSLForwardIssues = VectParams.MaxVectorWidth*TypeByteSize;
-  if(MaxSafeDepDistBytes < MaxVFWithoutSLForwardIssues)
+  unsigned MaxVFWithoutSLForwardIssues =
+      VectParams.MaxVectorWidth * TypeByteSize;
+  if (MaxSafeDepDistBytes < MaxVFWithoutSLForwardIssues)
     MaxVFWithoutSLForwardIssues = MaxSafeDepDistBytes;
 
   for (unsigned vf = 2*TypeByteSize; vf <= MaxVFWithoutSLForwardIssues;
@@ -634,14 +651,14 @@ bool MemoryDepChecker::couldPreventStoreLoadForward(unsigned Distance,
     }
   }
 
-  if (MaxVFWithoutSLForwardIssues< 2*TypeByteSize) {
-    DEBUG(dbgs() << "LV: Distance " << Distance <<
-          " that could cause a store-load forwarding conflict\n");
+  if (MaxVFWithoutSLForwardIssues < 2 * TypeByteSize) {
+    DEBUG(dbgs() << "LV: Distance " << Distance
+                 << " that could cause a store-load forwarding conflict\n");
     return true;
   }
 
   if (MaxVFWithoutSLForwardIssues < MaxSafeDepDistBytes &&
-      MaxVFWithoutSLForwardIssues != VectParams.MaxVectorWidth*TypeByteSize)
+      MaxVFWithoutSLForwardIssues != VectParams.MaxVectorWidth * TypeByteSize)
     MaxSafeDepDistBytes = MaxVFWithoutSLForwardIssues;
   return false;
 }
@@ -689,9 +706,9 @@ bool MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
   const SCEV *Dist = SE->getMinusSCEV(Sink, Src);
 
   DEBUG(dbgs() << "LV: Src Scev: " << *Src << "Sink Scev: " << *Sink
-        << "(Induction step: " << StrideAPtr <<  ")\n");
+               << "(Induction step: " << StrideAPtr << ")\n");
   DEBUG(dbgs() << "LV: Distance for " << *InstMap[AIdx] << " to "
-        << *InstMap[BIdx] << ": " << *Dist << "\n");
+               << *InstMap[BIdx] << ": " << *Dist << "\n");
 
   // Need consecutive accesses. We don't want to vectorize
   // "A[B[i]] += ..." and similar code or pointer arithmetic that could wrap in
@@ -738,18 +755,19 @@ bool MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
 
   // Positive distance bigger than max vectorization factor.
   if (ATy != BTy) {
-    DEBUG(dbgs() <<
-          "LV: ReadWrite-Write positive dependency with different types\n");
+    DEBUG(dbgs()
+          << "LV: ReadWrite-Write positive dependency with different types\n");
     return false;
   }
 
   unsigned Distance = (unsigned) Val.getZExtValue();
 
   // Bail out early if passed-in parameters make vectorization not feasible.
-  unsigned ForcedFactor = (VectParams.VectorizationFactor ?
-                           VectParams.VectorizationFactor : 1);
-  unsigned ForcedUnroll = (VectParams.VectorizationInterleave ?
-                           VectParams.VectorizationInterleave : 1);
+  unsigned ForcedFactor =
+      (VectParams.VectorizationFactor ? VectParams.VectorizationFactor : 1);
+  unsigned ForcedUnroll =
+      (VectParams.VectorizationInterleave ? VectParams.VectorizationInterleave
+                                          : 1);
 
   // The distance must be bigger than the size needed for a vectorized version
   // of the operation and the size of the vectorized operation must not be
@@ -758,7 +776,7 @@ bool MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
       2*TypeByteSize > MaxSafeDepDistBytes ||
       Distance < TypeByteSize * ForcedUnroll * ForcedFactor) {
     DEBUG(dbgs() << "LV: Failure because of Positive distance "
-        << Val.getSExtValue() << '\n');
+                 << Val.getSExtValue() << '\n');
     return true;
   }
 
@@ -770,8 +788,9 @@ bool MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
       couldPreventStoreLoadForward(Distance, TypeByteSize))
      return true;
 
-  DEBUG(dbgs() << "LV: Positive distance " << Val.getSExtValue() <<
-        " with max VF = " << MaxSafeDepDistBytes / TypeByteSize << '\n');
+  DEBUG(dbgs() << "LV: Positive distance " << Val.getSExtValue()
+               << " with max VF = " << MaxSafeDepDistBytes / TypeByteSize
+               << '\n');
 
   return false;
 }
@@ -815,7 +834,7 @@ bool MemoryDepChecker::areDepsSafe(AccessAnalysis::DepCandidates &AccessSets,
   return true;
 }
 
-bool LoopAccessAnalysis::canVectorizeMemory(ValueToValueMap &Strides) {
+bool LoopAccessInfo::canVectorizeMemory(ValueToValueMap &Strides) {
 
   typedef SmallVector<Value*, 16> ValueVector;
   typedef SmallPtrSet<Value*, 16> ValueSet;
@@ -870,8 +889,8 @@ bool LoopAccessAnalysis::canVectorizeMemory(ValueToValueMap &Strides) {
       if (it->mayWriteToMemory()) {
         StoreInst *St = dyn_cast<StoreInst>(it);
         if (!St) {
-          emitAnalysis(VectorizationReport(it) <<
-                       "instruction cannot be vectorized");
+          emitAnalysis(VectorizationReport(it)
+                       << "instruction cannot be vectorized");
           return false;
         }
         if (!St->isSimple() && !IsAnnotatedParallel) {
@@ -929,7 +948,7 @@ bool LoopAccessAnalysis::canVectorizeMemory(ValueToValueMap &Strides) {
       // The TBAA metadata could have a control dependency on the predication
       // condition, so we cannot rely on it when determining whether or not we
       // need runtime pointer checks.
-      if (blockNeedsPredication(ST->getParent()))
+      if (blockNeedsPredication(ST->getParent(), TheLoop, DT))
         Loc.AATags.TBAA = nullptr;
 
       Accesses.addStore(Loc);
@@ -937,9 +956,8 @@ bool LoopAccessAnalysis::canVectorizeMemory(ValueToValueMap &Strides) {
   }
 
   if (IsAnnotatedParallel) {
-    DEBUG(dbgs()
-          << "LV: A loop annotated parallel, ignore memory dependency "
-          << "checks.\n");
+    DEBUG(dbgs() << "LV: A loop annotated parallel, ignore memory dependency "
+                 << "checks.\n");
     return true;
   }
 
@@ -965,7 +983,7 @@ bool LoopAccessAnalysis::canVectorizeMemory(ValueToValueMap &Strides) {
     // The TBAA metadata could have a control dependency on the predication
     // condition, so we cannot rely on it when determining whether or not we
     // need runtime pointer checks.
-    if (blockNeedsPredication(LD->getParent()))
+    if (blockNeedsPredication(LD->getParent(), TheLoop, DT))
       Loc.AATags.TBAA = nullptr;
 
     Accesses.addLoad(Loc, IsReadOnlyPtr);
@@ -991,8 +1009,8 @@ bool LoopAccessAnalysis::canVectorizeMemory(ValueToValueMap &Strides) {
     CanDoRT = Accesses.canCheckPtrAtRT(PtrRtCheck, NumComparisons, SE, TheLoop,
                                        Strides);
 
-  DEBUG(dbgs() << "LV: We need to do " << NumComparisons <<
-        " pointer comparisons.\n");
+  DEBUG(dbgs() << "LV: We need to do " << NumComparisons
+               << " pointer comparisons.\n");
 
   // If we only have one set of dependences to check pointers among we don't
   // need a runtime check.
@@ -1012,8 +1030,8 @@ bool LoopAccessAnalysis::canVectorizeMemory(ValueToValueMap &Strides) {
 
   if (NeedRTCheck && !CanDoRT) {
     emitAnalysis(VectorizationReport() << "cannot identify array bounds");
-    DEBUG(dbgs() << "LV: We can't vectorize because we can't find " <<
-          "the array bounds.\n");
+    DEBUG(dbgs() << "LV: We can't vectorize because we can't find "
+                 << "the array bounds.\n");
     PtrRtCheck.reset();
     return false;
   }
@@ -1060,16 +1078,17 @@ bool LoopAccessAnalysis::canVectorizeMemory(ValueToValueMap &Strides) {
   }
 
   if (!CanVecMem)
-    emitAnalysis(VectorizationReport() <<
-                 "unsafe dependent memory operations in loop");
+    emitAnalysis(VectorizationReport()
+                 << "unsafe dependent memory operations in loop");
 
-  DEBUG(dbgs() << "LV: We" << (NeedRTCheck ? "" : " don't") <<
-        " need a runtime memory check.\n");
+  DEBUG(dbgs() << "LV: We" << (NeedRTCheck ? "" : " don't")
+               << " need a runtime memory check.\n");
 
   return CanVecMem;
 }
 
-bool LoopAccessAnalysis::blockNeedsPredication(BasicBlock *BB)  {
+bool LoopAccessInfo::blockNeedsPredication(BasicBlock *BB, Loop *TheLoop,
+                                           DominatorTree *DT)  {
   assert(TheLoop->contains(BB) && "Unknown block used");
 
   // Blocks that do not dominate the latch need predication.
@@ -1077,11 +1096,11 @@ bool LoopAccessAnalysis::blockNeedsPredication(BasicBlock *BB)  {
   return !DT->dominates(BB, Latch);
 }
 
-void LoopAccessAnalysis::emitAnalysis(VectorizationReport &Message) {
+void LoopAccessInfo::emitAnalysis(VectorizationReport &Message) {
   VectorizationReport::emitAnalysis(Message, TheFunction, TheLoop);
 }
 
-bool LoopAccessAnalysis::isUniform(Value *V) {
+bool LoopAccessInfo::isUniform(Value *V) {
   return (SE->isLoopInvariant(SE->getSCEV(V), TheLoop));
 }
 
@@ -1097,7 +1116,7 @@ static Instruction *getFirstInst(Instruction *FirstInst, Value *V,
 }
 
 std::pair<Instruction *, Instruction *>
-LoopAccessAnalysis::addRuntimeCheck(Instruction *Loc) {
+LoopAccessInfo::addRuntimeCheck(Instruction *Loc) {
   Instruction *tnullptr = nullptr;
   if (!PtrRtCheck.Need)
     return std::pair<Instruction *, Instruction *>(tnullptr, tnullptr);
@@ -1115,8 +1134,8 @@ LoopAccessAnalysis::addRuntimeCheck(Instruction *Loc) {
     const SCEV *Sc = SE->getSCEV(Ptr);
 
     if (SE->isLoopInvariant(Sc, TheLoop)) {
-      DEBUG(dbgs() << "LV: Adding RT check for a loop invariant ptr:" <<
-            *Ptr <<"\n");
+      DEBUG(dbgs() << "LV: Adding RT check for a loop invariant ptr:" << *Ptr
+                   << "\n");
       Starts.push_back(Ptr);
       Ends.push_back(Ptr);
     } else {
@@ -1138,15 +1157,7 @@ LoopAccessAnalysis::addRuntimeCheck(Instruction *Loc) {
   Value *MemoryRuntimeCheck = nullptr;
   for (unsigned i = 0; i < NumPointers; ++i) {
     for (unsigned j = i+1; j < NumPointers; ++j) {
-      // No need to check if two readonly pointers intersect.
-      if (!PtrRtCheck.IsWritePtr[i] && !PtrRtCheck.IsWritePtr[j])
-        continue;
-
-      // Only need to check pointers between two different dependency sets.
-      if (PtrRtCheck.DependencySetId[i] == PtrRtCheck.DependencySetId[j])
-       continue;
-      // Only need to check pointers in the same alias set.
-      if (PtrRtCheck.AliasSetId[i] != PtrRtCheck.AliasSetId[j])
+      if (!PtrRtCheck.needsChecking(i, j))
         continue;
 
       unsigned AS0 = Starts[i]->getType()->getPointerAddressSpace();