Allow LLE/LD and the loop versioning infrastructure to use SCEV predicates
[oota-llvm.git] / lib / Transforms / Utils / LoopVersioning.cpp
index bf7ed73ff01b78db659cb25d71297523e73c28d8..a77c3642a56cae06e85409af68f8eec74937d990 100644 (file)
 
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 
 using namespace llvm;
 
-LoopVersioning::LoopVersioning(
-    SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks,
-    const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, DominatorTree *DT)
-    : VersionedLoop(L), NonVersionedLoop(nullptr), Checks(std::move(Checks)),
-      LAI(LAI), LI(LI), DT(DT) {
+LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
+                               DominatorTree *DT, ScalarEvolution *SE,
+                               bool UseLAIChecks)
+    : VersionedLoop(L), NonVersionedLoop(nullptr), LAI(LAI), LI(LI), DT(DT),
+      SE(SE) {
   assert(L->getExitBlock() && "No single exit block");
   assert(L->getLoopPreheader() && "No preheader");
+  if (UseLAIChecks) {
+    setAliasChecks(LAI.getRuntimePointerChecking()->getChecks());
+    setSCEVChecks(LAI.Preds);
+  }
 }
 
-LoopVersioning::LoopVersioning(const LoopAccessInfo &LAInfo, Loop *L,
-                               LoopInfo *LI, DominatorTree *DT)
-    : VersionedLoop(L), NonVersionedLoop(nullptr),
-      Checks(LAInfo.getRuntimePointerChecking()->getChecks()), LAI(LAInfo),
-      LI(LI), DT(DT) {
-  assert(L->getExitBlock() && "No single exit block");
-  assert(L->getLoopPreheader() && "No preheader");
+void LoopVersioning::setAliasChecks(
+    const SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks) {
+  AliasChecks = std::move(Checks);
+}
+
+void LoopVersioning::setSCEVChecks(SCEVUnionPredicate Check) {
+  Preds = std::move(Check);
 }
 
 void LoopVersioning::versionLoop(
     const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
   Instruction *FirstCheckInst;
   Instruction *MemRuntimeCheck;
+  Value *SCEVRuntimeCheck;
+  Value *RuntimeCheck = nullptr;
+
   // Add the memcheck in the original preheader (this is empty initially).
-  BasicBlock *MemCheckBB = VersionedLoop->getLoopPreheader();
+  BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
   std::tie(FirstCheckInst, MemRuntimeCheck) =
-      LAI.addRuntimeChecks(MemCheckBB->getTerminator(), Checks);
+      LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks);
   assert(MemRuntimeCheck && "called even though needsAnyChecking = false");
 
+  const SCEVUnionPredicate &Pred = LAI.Preds;
+  SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
+                   "scev.check");
+  SCEVRuntimeCheck =
+      Exp.expandCodeForPredicate(&Pred, RuntimeCheckBB->getTerminator());
+  auto *CI = dyn_cast<ConstantInt>(SCEVRuntimeCheck);
+
+  // Discard the SCEV runtime check if it is always true.
+  if (CI && CI->isZero())
+    SCEVRuntimeCheck = nullptr;
+
+  if (MemRuntimeCheck && SCEVRuntimeCheck) {
+    RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck,
+                                          SCEVRuntimeCheck, "ldist.safe");
+    if (auto *I = dyn_cast<Instruction>(RuntimeCheck))
+      I->insertBefore(RuntimeCheckBB->getTerminator());
+  } else
+    RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
+
+  assert(RuntimeCheck && "called even though we don't need "
+                         "any runtime checks");
+
   // Rename the block to make the IR more readable.
-  MemCheckBB->setName(VersionedLoop->getHeader()->getName() + ".lver.memcheck");
+  RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() +
+                          ".lver.check");
 
   // Create empty preheader for the loop (and after cloning for the
   // non-versioned loop).
-  BasicBlock *PH = SplitBlock(MemCheckBB, MemCheckBB->getTerminator(), DT, LI);
+  BasicBlock *PH =
+      SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI);
   PH->setName(VersionedLoop->getHeader()->getName() + ".ph");
 
   // Clone the loop including the preheader.
@@ -65,20 +97,19 @@ void LoopVersioning::versionLoop(
   // block is a join between the two loops.
   SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks;
   NonVersionedLoop =
-      cloneLoopWithPreheader(PH, MemCheckBB, VersionedLoop, VMap, ".lver.orig",
-                             LI, DT, NonVersionedLoopBlocks);
+      cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap,
+                             ".lver.orig", LI, DT, NonVersionedLoopBlocks);
   remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap);
 
   // Insert the conditional branch based on the result of the memchecks.
-  Instruction *OrigTerm = MemCheckBB->getTerminator();
+  Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
   BranchInst::Create(NonVersionedLoop->getLoopPreheader(),
-                     VersionedLoop->getLoopPreheader(), MemRuntimeCheck,
-                     OrigTerm);
+                     VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm);
   OrigTerm->eraseFromParent();
 
   // The loops merge in the original exit block.  This is now dominated by the
   // memchecking block.
-  DT->changeImmediateDominator(VersionedLoop->getExitBlock(), MemCheckBB);
+  DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB);
 
   // Adds the necessary PHI nodes for the versioned loops based on the
   // loop-defined values used outside of the loop.