Allow LLE/LD and the loop versioning infrastructure to use SCEV predicates
authorSilviu Baranga <silviu.baranga@arm.com>
Mon, 9 Nov 2015 13:26:09 +0000 (13:26 +0000)
committerSilviu Baranga <silviu.baranga@arm.com>
Mon, 9 Nov 2015 13:26:09 +0000 (13:26 +0000)
Summary:
LAA currently generates a set of SCEV predicates that must be checked by users.
In the case of Loop Distribute/Loop Load Elimination, no such predicates could have
been emitted, since we don't allow stride versioning. However, in the future there
could be SCEV predicates that will need to be checked.

This change adds support for SCEV predicate versioning in the Loop Distribute, Loop
Load Eliminate and the loop versioning infrastructure.

Reviewers: anemet

Subscribers: mssimpso, sanjoy, llvm-commits

Differential Revision: http://reviews.llvm.org/D14240

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@252467 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolution.h
include/llvm/Transforms/Utils/LoopVersioning.h
lib/Transforms/Scalar/LoopDistribute.cpp
lib/Transforms/Scalar/LoopLoadElimination.cpp
lib/Transforms/Utils/LoopVersioning.cpp
test/Transforms/LoopDistribute/basic-with-memchecks.ll
test/Transforms/LoopLoadElim/forward.ll
test/Transforms/LoopLoadElim/memcheck.ll

index c180ce37e39ea791c0ccc1b20d793a0d55becc6a..1bd7fd0db55b28635f289cf997d2d802fbfbe277 100644 (file)
@@ -193,7 +193,7 @@ namespace llvm {
 
     /// \brief Returns the estimated complexity of this predicate.
     /// This is roughly measured in the number of run-time checks required.
-    virtual unsigned getComplexity() { return 1; }
+    virtual unsigned getComplexity() const { return 1; }
 
     /// \brief Returns true if the predicate is always true. This means that no
     /// assumptions were made and nothing needs to be checked at run-time.
@@ -303,7 +303,7 @@ namespace llvm {
 
     /// \brief We estimate the complexity of a union predicate as the size
     /// number of predicates in the union.
-    unsigned getComplexity() override { return Preds.size(); }
+    unsigned getComplexity() const override { return Preds.size(); }
 
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
     static inline bool classof(const SCEVPredicate *P) {
index 41eb50c766206405a9827632db00228540a0155f..3b70594e0b632d350dbe1e2d7ea056b3940be0b0 100644 (file)
@@ -17,6 +17,7 @@
 #define LLVM_TRANSFORMS_UTILS_LOOPVERSIONING_H
 
 #include "llvm/Analysis/LoopAccessAnalysis.h"
+#include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Transforms/Utils/ValueMapper.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
 
@@ -25,6 +26,7 @@ namespace llvm {
 class Loop;
 class LoopAccessInfo;
 class LoopInfo;
+class ScalarEvolution;
 
 /// \brief This class emits a version of the loop where run-time checks ensure
 /// that may-alias pointers can't overlap.
@@ -33,16 +35,13 @@ class LoopInfo;
 /// already has a preheader.
 class LoopVersioning {
 public:
-  /// \brief Expects MemCheck, LoopAccessInfo, Loop, LoopInfo, DominatorTree
-  /// as input. It uses runtime check provided by user.
-  LoopVersioning(SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks,
-                 const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
-                 DominatorTree *DT);
-
   /// \brief Expects LoopAccessInfo, Loop, LoopInfo, DominatorTree as input.
-  /// It uses default runtime check provided by LoopAccessInfo.
-  LoopVersioning(const LoopAccessInfo &LAInfo, Loop *L, LoopInfo *LI,
-                 DominatorTree *DT);
+  /// It uses runtime check provided by the user. If \p UseLAIChecks is true,
+  /// we will retain the default checks made by LAI. Otherwise, construct an
+  /// object having no checks and we expect the user to add them.
+  LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
+                 DominatorTree *DT, ScalarEvolution *SE,
+                 bool UseLAIChecks = true);
 
   /// \brief Performs the CFG manipulation part of versioning the loop including
   /// the DominatorTree and LoopInfo updates.
@@ -72,6 +71,13 @@ public:
   /// loop may alias (i.e. one of the memchecks failed).
   Loop *getNonVersionedLoop() { return NonVersionedLoop; }
 
+  /// \brief Sets the runtime alias checks for versioning the loop.
+  void setAliasChecks(
+      const SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks);
+
+  /// \brief Sets the runtime SCEV checks for versioning the loop.
+  void setSCEVChecks(SCEVUnionPredicate Check);
+
 private:
   /// \brief Adds the necessary PHI nodes for the versioned loops based on the
   /// loop-defined values used outside of the loop.
@@ -91,13 +97,17 @@ private:
   /// in NonVersionedLoop.
   ValueToValueMapTy VMap;
 
-  /// \brief The set of checks that we are versioning for.
-  SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks;
+  /// \brief The set of alias checks that we are versioning for.
+  SmallVector<RuntimePointerChecking::PointerCheck, 4> AliasChecks;
+
+  /// \brief The set of SCEV checks that we are versioning for.
+  SCEVUnionPredicate Preds;
 
   /// \brief Analyses used.
   const LoopAccessInfo &LAI;
   LoopInfo *LI;
   DominatorTree *DT;
+  ScalarEvolution *SE;
 };
 }
 
index 1584f0fa3ebac4b3d052f485e6383509dbde64d1..67ebd2532b1614f7702e2a0decc9e5b4b6819813 100644 (file)
@@ -55,6 +55,11 @@ static cl::opt<bool> DistributeNonIfConvertible(
              "if-convertible by the loop vectorizer"),
     cl::init(false));
 
+static cl::opt<unsigned> DistributeSCEVCheckThreshold(
+    "loop-distribute-scev-check-threshold", cl::init(8), cl::Hidden,
+    cl::desc("The maximum number of SCEV checks allowed for Loop "
+             "Distribution"));
+
 STATISTIC(NumLoopsDistributed, "Number of loops distributed");
 
 namespace {
@@ -577,6 +582,7 @@ public:
     LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
     LAA = &getAnalysis<LoopAccessAnalysis>();
     DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+    SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
 
     // Build up a worklist of inner-loops to vectorize. This is necessary as the
     // act of distributing a loop creates new loops and can invalidate iterators
@@ -599,6 +605,7 @@ public:
   }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<ScalarEvolutionWrapperPass>();
     AU.addRequired<LoopInfoWrapperPass>();
     AU.addPreserved<LoopInfoWrapperPass>();
     AU.addRequired<LoopAccessAnalysis>();
@@ -753,6 +760,13 @@ private:
         return false;
     }
 
+    // Don't distribute the loop if we need too many SCEV run-time checks.
+    const SCEVUnionPredicate &Pred = LAI.Preds;
+    if (Pred.getComplexity() > DistributeSCEVCheckThreshold) {
+      DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n");
+      return false;
+    }
+
     DEBUG(dbgs() << "\nDistributing loop: " << *L << "\n");
     // We're done forming the partitions set up the reverse mapping from
     // instructions to partitions.
@@ -764,17 +778,19 @@ private:
     if (!PH->getSinglePredecessor() || &*PH->begin() != PH->getTerminator())
       SplitBlock(PH, PH->getTerminator(), DT, LI);
 
-    // If we need run-time checks to disambiguate pointers are run-time, version
-    // the loop now.
+    // If we need run-time checks, version the loop now.
     auto PtrToPartition = Partitions.computePartitionSetForPointers(LAI);
     const auto *RtPtrChecking = LAI.getRuntimePointerChecking();
     const auto &AllChecks = RtPtrChecking->getChecks();
     auto Checks = includeOnlyCrossPartitionChecks(AllChecks, PtrToPartition,
                                                   RtPtrChecking);
-    if (!Checks.empty()) {
+
+    if (!Pred.isAlwaysTrue() || !Checks.empty()) {
       DEBUG(dbgs() << "\nPointers:\n");
       DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks));
-      LoopVersioning LVer(std::move(Checks), LAI, L, LI, DT);
+      LoopVersioning LVer(LAI, L, LI, DT, SE, false);
+      LVer.setAliasChecks(std::move(Checks));
+      LVer.setSCEVChecks(LAI.Preds);
       LVer.versionLoop(DefsUsedOutside);
     }
 
@@ -801,6 +817,7 @@ private:
   LoopInfo *LI;
   LoopAccessAnalysis *LAA;
   DominatorTree *DT;
+  ScalarEvolution *SE;
 };
 } // anonymous namespace
 
@@ -811,6 +828,7 @@ INITIALIZE_PASS_BEGIN(LoopDistribute, LDIST_NAME, ldist_name, false, false)
 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(LoopAccessAnalysis)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
 INITIALIZE_PASS_END(LoopDistribute, LDIST_NAME, ldist_name, false, false)
 
 namespace llvm {
index e0456a2110dd067b5a24604d8086cd758d32e77e..7c7bf64ba79c81ee7a9e6fc06a80817c78bf2d38 100644 (file)
@@ -41,6 +41,12 @@ static cl::opt<unsigned> CheckPerElim(
     cl::desc("Max number of memchecks allowed per eliminated load on average"),
     cl::init(1));
 
+static cl::opt<unsigned> LoadElimSCEVCheckThreshold(
+    "loop-load-elimination-scev-check-threshold", cl::init(8), cl::Hidden,
+    cl::desc("The maximum number of SCEV checks allowed for Loop "
+             "Load Elimination"));
+
+
 STATISTIC(NumLoopLoadEliminted, "Number of loads eliminated by LLE");
 
 namespace {
@@ -453,10 +459,17 @@ public:
       return false;
     }
 
+    if (LAI.Preds.getComplexity() > LoadElimSCEVCheckThreshold) {
+      DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n");
+      return false;
+    }
+
     // Point of no-return, start the transformation.  First, version the loop if
     // necessary.
-    if (!Checks.empty()) {
-      LoopVersioning LV(std::move(Checks), LAI, L, LI, DT);
+    if (!Checks.empty() || !LAI.Preds.isAlwaysTrue()) {
+      LoopVersioning LV(LAI, L, LI, DT, SE, false);
+      LV.setAliasChecks(std::move(Checks));
+      LV.setSCEVChecks(LAI.Preds);
       LV.versionLoop();
     }
 
index bf7ed73ff01b78db659cb25d71297523e73c28d8..a77c3642a56cae06e85409af68f8eec74937d990 100644 (file)
 
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 
 using namespace llvm;
 
-LoopVersioning::LoopVersioning(
-    SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks,
-    const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, DominatorTree *DT)
-    : VersionedLoop(L), NonVersionedLoop(nullptr), Checks(std::move(Checks)),
-      LAI(LAI), LI(LI), DT(DT) {
+LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
+                               DominatorTree *DT, ScalarEvolution *SE,
+                               bool UseLAIChecks)
+    : VersionedLoop(L), NonVersionedLoop(nullptr), LAI(LAI), LI(LI), DT(DT),
+      SE(SE) {
   assert(L->getExitBlock() && "No single exit block");
   assert(L->getLoopPreheader() && "No preheader");
+  if (UseLAIChecks) {
+    setAliasChecks(LAI.getRuntimePointerChecking()->getChecks());
+    setSCEVChecks(LAI.Preds);
+  }
 }
 
-LoopVersioning::LoopVersioning(const LoopAccessInfo &LAInfo, Loop *L,
-                               LoopInfo *LI, DominatorTree *DT)
-    : VersionedLoop(L), NonVersionedLoop(nullptr),
-      Checks(LAInfo.getRuntimePointerChecking()->getChecks()), LAI(LAInfo),
-      LI(LI), DT(DT) {
-  assert(L->getExitBlock() && "No single exit block");
-  assert(L->getLoopPreheader() && "No preheader");
+void LoopVersioning::setAliasChecks(
+    const SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks) {
+  AliasChecks = std::move(Checks);
+}
+
+void LoopVersioning::setSCEVChecks(SCEVUnionPredicate Check) {
+  Preds = std::move(Check);
 }
 
 void LoopVersioning::versionLoop(
     const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
   Instruction *FirstCheckInst;
   Instruction *MemRuntimeCheck;
+  Value *SCEVRuntimeCheck;
+  Value *RuntimeCheck = nullptr;
+
   // Add the memcheck in the original preheader (this is empty initially).
-  BasicBlock *MemCheckBB = VersionedLoop->getLoopPreheader();
+  BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
   std::tie(FirstCheckInst, MemRuntimeCheck) =
-      LAI.addRuntimeChecks(MemCheckBB->getTerminator(), Checks);
+      LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks);
   assert(MemRuntimeCheck && "called even though needsAnyChecking = false");
 
+  const SCEVUnionPredicate &Pred = LAI.Preds;
+  SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
+                   "scev.check");
+  SCEVRuntimeCheck =
+      Exp.expandCodeForPredicate(&Pred, RuntimeCheckBB->getTerminator());
+  auto *CI = dyn_cast<ConstantInt>(SCEVRuntimeCheck);
+
+  // Discard the SCEV runtime check if it is always true.
+  if (CI && CI->isZero())
+    SCEVRuntimeCheck = nullptr;
+
+  if (MemRuntimeCheck && SCEVRuntimeCheck) {
+    RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck,
+                                          SCEVRuntimeCheck, "ldist.safe");
+    if (auto *I = dyn_cast<Instruction>(RuntimeCheck))
+      I->insertBefore(RuntimeCheckBB->getTerminator());
+  } else
+    RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
+
+  assert(RuntimeCheck && "called even though we don't need "
+                         "any runtime checks");
+
   // Rename the block to make the IR more readable.
-  MemCheckBB->setName(VersionedLoop->getHeader()->getName() + ".lver.memcheck");
+  RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() +
+                          ".lver.check");
 
   // Create empty preheader for the loop (and after cloning for the
   // non-versioned loop).
-  BasicBlock *PH = SplitBlock(MemCheckBB, MemCheckBB->getTerminator(), DT, LI);
+  BasicBlock *PH =
+      SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI);
   PH->setName(VersionedLoop->getHeader()->getName() + ".ph");
 
   // Clone the loop including the preheader.
@@ -65,20 +97,19 @@ void LoopVersioning::versionLoop(
   // block is a join between the two loops.
   SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks;
   NonVersionedLoop =
-      cloneLoopWithPreheader(PH, MemCheckBB, VersionedLoop, VMap, ".lver.orig",
-                             LI, DT, NonVersionedLoopBlocks);
+      cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap,
+                             ".lver.orig", LI, DT, NonVersionedLoopBlocks);
   remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap);
 
   // Insert the conditional branch based on the result of the memchecks.
-  Instruction *OrigTerm = MemCheckBB->getTerminator();
+  Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
   BranchInst::Create(NonVersionedLoop->getLoopPreheader(),
-                     VersionedLoop->getLoopPreheader(), MemRuntimeCheck,
-                     OrigTerm);
+                     VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm);
   OrigTerm->eraseFromParent();
 
   // The loops merge in the original exit block.  This is now dominated by the
   // memchecking block.
-  DT->changeImmediateDominator(VersionedLoop->getExitBlock(), MemCheckBB);
+  DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB);
 
   // Adds the necessary PHI nodes for the versioned loops based on the
   // loop-defined values used outside of the loop.
index 3aced48504111637a7d644fbdb73e3cb91371a75..dce5698595ac6a622853cae218ebb4889a6af987 100644 (file)
@@ -36,7 +36,7 @@ entry:
 ; Since the checks to A and A + 4 get merged, this will give us a
 ; total of 8 compares.
 ;
-; CHECK: for.body.lver.memcheck:
+; CHECK: for.body.lver.check:
 ; CHECK:     = icmp
 ; CHECK:     = icmp
 
index 1a77297a064198cf368987dc11f67e60830931d2..c2b1816530c10d762c588d4879b4663970d4e065 100644 (file)
@@ -11,7 +11,7 @@ target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
 
 define void @f(i32* %A, i32* %B, i32* %C, i64 %N) {
 
-; CHECK:   for.body.lver.memcheck:
+; CHECK:   for.body.lver.check:
 ; CHECK:     %found.conflict{{.*}} =
 ; CHECK-NOT: %found.conflict{{.*}} =
 
index ebb5282575407d338b4219c7c413bde0da6bce92..8eadd437a5ac3929668f4998fd52552cd703ab89 100644 (file)
@@ -16,7 +16,7 @@ define void @f(i32*  %A, i32*  %B, i32*  %C, i64 %N, i32* %D) {
 entry:
   br label %for.body
 
-; AGGRESSIVE: for.body.lver.memcheck:
+; AGGRESSIVE: for.body.lver.check:
 ; AGGRESSIVE: %found.conflict{{.*}} =
 ; AGGRESSIVE: %found.conflict{{.*}} =
 ; AGGRESSIVE-NOT: %found.conflict{{.*}} =