[LAA] Begin moving the logic of generating checks out of addRuntimeCheck
authorAdam Nemet <anemet@apple.com>
Sun, 26 Jul 2015 05:32:14 +0000 (05:32 +0000)
committerAdam Nemet <anemet@apple.com>
Sun, 26 Jul 2015 05:32:14 +0000 (05:32 +0000)
Summary:
The goal is to start moving us closer to the model where
RuntimePointerChecking will compute and store the checks.  Then a client
can filter the check according to its requirements and then use the
filtered list of checks with addRuntimeCheck.

Before the patch, this is all done in addRuntimeCheck.  So the patch
starts to split up addRuntimeCheck while providing the old API under
what's more or less a wrapper now.

The new underlying addRuntimeCheck takes a collection of checks now,
expands the code for the bounds then generates the code for the checks.

I am not completely happy with making expandBounds static because now it
needs so many explicit arguments but I don't want to make the type
PointerBounds part of LAI.  This should get fixed when addRuntimeCheck
is moved to LoopVersioning where it really belongs, IMO.

Audited the assembly diff of the testsuite (including externals).  There
is a tiny bit of assembly churn that is due to the different order the
code for the bounds is expanded now
(MultiSource/Benchmarks/Prolangs-C/bison/conflicts.s and with LoopDist
on 456.hmmer/fast_algorithms.s).

Reviewers: hfinkel

Subscribers: klimek, llvm-commits

Differential Revision: http://reviews.llvm.org/D11205

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@243239 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/LoopAccessAnalysis.h
lib/Analysis/LoopAccessAnalysis.cpp

index 0361d78b88fbd3f8a285b3ec5516f47a8a8fad5a..5656dca79541dddffcaff2ac5247941dfff288b4 100644 (file)
@@ -368,6 +368,15 @@ public:
     SmallVector<unsigned, 2> Members;
   };
 
+  /// \brief A memcheck which made up of a pair of grouped pointers.
+  ///
+  /// These *have* to be const for now, since checks are generated from
+  /// CheckingPtrGroups in LAI::addRuntimeCheck which is a const member
+  /// function.  FIXME: once check-generation is moved inside this class (after
+  /// the PtrPartition hack is removed), we could drop const.
+  typedef std::pair<const CheckingPtrGroup *, const CheckingPtrGroup *>
+      PointerCheck;
+
   /// \brief Groups pointers such that a single memcheck is required
   /// between two different groups. This will clear the CheckingGroups vector
   /// and re-compute it. We will only group dependecies if \p UseDependencies
@@ -488,6 +497,16 @@ public:
   addRuntimeCheck(Instruction *Loc,
                   const SmallVectorImpl<int> *PtrPartition = nullptr) const;
 
+  /// \brief Generete the instructions for the checks in \p PointerChecks.
+  ///
+  /// Returns a pair of instructions where the first element is the first
+  /// instruction generated in possibly a sequence of instructions and the
+  /// second value is the final comparator value or NULL if no check is needed.
+  std::pair<Instruction *, Instruction *>
+  addRuntimeCheck(Instruction *Loc,
+                  const SmallVectorImpl<RuntimePointerChecking::PointerCheck>
+                      &PointerChecks) const;
+
   /// \brief The diagnostics report generated for the analysis.  E.g. why we
   /// couldn't analyze the loop.
   const Optional<LoopAccessReport> &getReport() const { return Report; }
index 07ca4f513f63e50435d800770879753da5d2d50c..99b5aebaeabcdad016918fe39e40fd29133cd82f 100644 (file)
@@ -1586,86 +1586,107 @@ static Instruction *getFirstInst(Instruction *FirstInst, Value *V,
   return nullptr;
 }
 
+/// \brief IR Values for the lower and upper bounds of a pointer evolution.
+struct PointerBounds {
+  Value *Start;
+  Value *End;
+};
+
+/// \brief Expand code for the lower and upper bound of the pointer group \p CG
+/// in \p TheLoop.  \return the values for the bounds.
+static PointerBounds
+expandBounds(const RuntimePointerChecking::CheckingPtrGroup *CG, Loop *TheLoop,
+             Instruction *Loc, SCEVExpander &Exp, ScalarEvolution *SE,
+             const RuntimePointerChecking &PtrRtChecking) {
+  Value *Ptr = PtrRtChecking.Pointers[CG->Members[0]].PointerValue;
+  const SCEV *Sc = SE->getSCEV(Ptr);
+
+  if (SE->isLoopInvariant(Sc, TheLoop)) {
+    DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:" << *Ptr
+                 << "\n");
+    return {Ptr, Ptr};
+  } else {
+    unsigned AS = Ptr->getType()->getPointerAddressSpace();
+    LLVMContext &Ctx = Loc->getContext();
+
+    // Use this type for pointer arithmetic.
+    Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS);
+    Value *Start = nullptr, *End = nullptr;
+
+    DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
+    Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc);
+    End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc);
+    DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High << "\n");
+    return {Start, End};
+  }
+}
+
+/// \brief Turns a collection of checks into a collection of expanded upper and
+/// lower bounds for both pointers in the check.
+static SmallVector<std::pair<PointerBounds, PointerBounds>, 4> expandBounds(
+    const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks,
+    Loop *L, Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp,
+    const RuntimePointerChecking &PtrRtChecking) {
+  SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds;
+
+  // Here we're relying on the SCEV Expander's cache to only emit code for the
+  // same bounds once.
+  std::transform(
+      PointerChecks.begin(), PointerChecks.end(),
+      std::back_inserter(ChecksWithBounds),
+      [&](const RuntimePointerChecking::PointerCheck &Check) {
+        return std::make_pair(
+            expandBounds(Check.first, L, Loc, Exp, SE, PtrRtChecking),
+            expandBounds(Check.second, L, Loc, Exp, SE, PtrRtChecking));
+      });
+
+  return ChecksWithBounds;
+}
+
 std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeCheck(
-    Instruction *Loc, const SmallVectorImpl<int> *PtrPartition) const {
-  if (!PtrRtChecking.Need)
-    return std::make_pair(nullptr, nullptr);
+    Instruction *Loc,
+    const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks)
+    const {
 
-  SmallVector<TrackingVH<Value>, 2> Starts;
-  SmallVector<TrackingVH<Value>, 2> Ends;
+  SCEVExpander Exp(*SE, DL, "induction");
+  auto ExpandedChecks =
+      expandBounds(PointerChecks, TheLoop, Loc, SE, Exp, PtrRtChecking);
 
   LLVMContext &Ctx = Loc->getContext();
-  SCEVExpander Exp(*SE, DL, "induction");
   Instruction *FirstInst = nullptr;
-
-  for (unsigned i = 0; i < PtrRtChecking.CheckingGroups.size(); ++i) {
-    const RuntimePointerChecking::CheckingPtrGroup &CG =
-        PtrRtChecking.CheckingGroups[i];
-    Value *Ptr = PtrRtChecking.Pointers[CG.Members[0]].PointerValue;
-    const SCEV *Sc = SE->getSCEV(Ptr);
-
-    if (SE->isLoopInvariant(Sc, TheLoop)) {
-      DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:" << *Ptr
-                   << "\n");
-      Starts.push_back(Ptr);
-      Ends.push_back(Ptr);
-    } else {
-      unsigned AS = Ptr->getType()->getPointerAddressSpace();
-
-      // Use this type for pointer arithmetic.
-      Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS);
-      Value *Start = nullptr, *End = nullptr;
-
-      DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
-      Start = Exp.expandCodeFor(CG.Low, PtrArithTy, Loc);
-      End = Exp.expandCodeFor(CG.High, PtrArithTy, Loc);
-      DEBUG(dbgs() << "Start: " << *CG.Low << " End: " << *CG.High << "\n");
-      Starts.push_back(Start);
-      Ends.push_back(End);
-    }
-  }
-
   IRBuilder<> ChkBuilder(Loc);
   // Our instructions might fold to a constant.
   Value *MemoryRuntimeCheck = nullptr;
-  for (unsigned i = 0; i < PtrRtChecking.CheckingGroups.size(); ++i) {
-    for (unsigned j = i + 1; j < PtrRtChecking.CheckingGroups.size(); ++j) {
-      const RuntimePointerChecking::CheckingPtrGroup &CGI =
-          PtrRtChecking.CheckingGroups[i];
-      const RuntimePointerChecking::CheckingPtrGroup &CGJ =
-          PtrRtChecking.CheckingGroups[j];
 
-      if (!PtrRtChecking.needsChecking(CGI, CGJ, PtrPartition))
-        continue;
-
-      unsigned AS0 = Starts[i]->getType()->getPointerAddressSpace();
-      unsigned AS1 = Starts[j]->getType()->getPointerAddressSpace();
-
-      assert((AS0 == Ends[j]->getType()->getPointerAddressSpace()) &&
-             (AS1 == Ends[i]->getType()->getPointerAddressSpace()) &&
-             "Trying to bounds check pointers with different address spaces");
-
-      Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0);
-      Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1);
-
-      Value *Start0 = ChkBuilder.CreateBitCast(Starts[i], PtrArithTy0, "bc");
-      Value *Start1 = ChkBuilder.CreateBitCast(Starts[j], PtrArithTy1, "bc");
-      Value *End0 =   ChkBuilder.CreateBitCast(Ends[i],   PtrArithTy1, "bc");
-      Value *End1 =   ChkBuilder.CreateBitCast(Ends[j],   PtrArithTy0, "bc");
-
-      Value *Cmp0 = ChkBuilder.CreateICmpULE(Start0, End1, "bound0");
-      FirstInst = getFirstInst(FirstInst, Cmp0, Loc);
-      Value *Cmp1 = ChkBuilder.CreateICmpULE(Start1, End0, "bound1");
-      FirstInst = getFirstInst(FirstInst, Cmp1, Loc);
-      Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
+  for (const auto &Check : ExpandedChecks) {
+    const PointerBounds &A = Check.first, &B = Check.second;
+    unsigned AS0 = A.Start->getType()->getPointerAddressSpace();
+    unsigned AS1 = B.Start->getType()->getPointerAddressSpace();
+
+    assert((AS0 == B.End->getType()->getPointerAddressSpace()) &&
+           (AS1 == A.End->getType()->getPointerAddressSpace()) &&
+           "Trying to bounds check pointers with different address spaces");
+
+    Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0);
+    Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1);
+
+    Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc");
+    Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc");
+    Value *End0 =   ChkBuilder.CreateBitCast(A.End,   PtrArithTy1, "bc");
+    Value *End1 =   ChkBuilder.CreateBitCast(B.End,   PtrArithTy0, "bc");
+
+    Value *Cmp0 = ChkBuilder.CreateICmpULE(Start0, End1, "bound0");
+    FirstInst = getFirstInst(FirstInst, Cmp0, Loc);
+    Value *Cmp1 = ChkBuilder.CreateICmpULE(Start1, End0, "bound1");
+    FirstInst = getFirstInst(FirstInst, Cmp1, Loc);
+    Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
+    FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
+    if (MemoryRuntimeCheck) {
+      IsConflict =
+          ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");
       FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
-      if (MemoryRuntimeCheck) {
-        IsConflict = ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict,
-                                         "conflict.rdx");
-        FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
-      }
-      MemoryRuntimeCheck = IsConflict;
     }
+    MemoryRuntimeCheck = IsConflict;
   }
 
   if (!MemoryRuntimeCheck)
@@ -1681,6 +1702,27 @@ std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeCheck(
   return std::make_pair(FirstInst, Check);
 }
 
+std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeCheck(
+    Instruction *Loc, const SmallVectorImpl<int> *PtrPartition) const {
+  if (!PtrRtChecking.Need)
+    return std::make_pair(nullptr, nullptr);
+
+  SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks;
+  for (unsigned i = 0; i < PtrRtChecking.CheckingGroups.size(); ++i) {
+    for (unsigned j = i + 1; j < PtrRtChecking.CheckingGroups.size(); ++j) {
+      const RuntimePointerChecking::CheckingPtrGroup &CGI =
+          PtrRtChecking.CheckingGroups[i];
+      const RuntimePointerChecking::CheckingPtrGroup &CGJ =
+          PtrRtChecking.CheckingGroups[j];
+
+      if (PtrRtChecking.needsChecking(CGI, CGJ, PtrPartition))
+        Checks.push_back(std::make_pair(&CGI, &CGJ));
+    }
+  }
+
+  return addRuntimeCheck(Loc, Checks);
+}
+
 LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
                                const DataLayout &DL,
                                const TargetLibraryInfo *TLI, AliasAnalysis *AA,