implement the "no aliasing accesses in loop" safety check. This pass
[oota-llvm.git] / lib / Transforms / Scalar / LoopIdiomRecognize.cpp
index f8748efeceaa2114188771659411355aee5106f8..56a35112aa1c2eaeeb60c4c26922325f0574ef2b 100644 (file)
@@ -15,6 +15,7 @@
 
 #define DEBUG_TYPE "loop-idiom"
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/ScalarEvolutionExpander.h"
@@ -59,6 +60,8 @@ namespace {
       AU.addPreservedID(LoopSimplifyID);
       AU.addRequiredID(LCSSAID);
       AU.addPreservedID(LCSSAID);
+      AU.addRequired<AliasAnalysis>();
+      AU.addPreserved<AliasAnalysis>();
       AU.addRequired<ScalarEvolution>();
       AU.addPreserved<ScalarEvolution>();
       AU.addPreserved<DominatorTree>();
@@ -73,6 +76,7 @@ INITIALIZE_PASS_DEPENDENCY(LoopInfo)
 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
 INITIALIZE_PASS_DEPENDENCY(LCSSA)
 INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
+INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
 INITIALIZE_PASS_END(LoopIdiomRecognize, "loop-idiom", "Recognize loop idioms",
                     false, false)
 
@@ -141,13 +145,15 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L, LPPassManager &LPM) {
     StoreInst *SI = dyn_cast<StoreInst>(I++);
     if (SI == 0 || SI->isVolatile()) continue;
     
-    WeakVH InstPtr;
-    if (processLoopStore(SI, BECount)) {
-      // If processing the store invalidated our iterator, start over from the
-      // head of the loop.
-      if (InstPtr == 0)
-        I = BB->begin();
-    }
+    WeakVH InstPtr(SI);
+    if (!processLoopStore(SI, BECount)) continue;
+    
+    MadeChange = true;
+    
+    // If processing the store invalidated our iterator, start over from the
+    // head of the loop.
+    if (InstPtr == 0)
+      I = BB->begin();
   }
   
   return MadeChange;
@@ -158,12 +164,9 @@ bool LoopIdiomRecognize::processLoopStore(StoreInst *SI, const SCEV *BECount) {
   Value *StoredVal = SI->getValueOperand();
   Value *StorePtr = SI->getPointerOperand();
   
-  // Check to see if the store updates all bits in memory.  We don't want to
-  // process things like a store of i3.  We also require that the store be a
-  // multiple of a byte.
+  // Reject stores that are so large that they overflow an unsigned.
   uint64_t SizeInBits = TD->getTypeSizeInBits(StoredVal->getType());
-  if ((SizeInBits & 7) || (SizeInBits >> 32) != 0 ||
-      SizeInBits != TD->getTypeStoreSizeInBits(StoredVal->getType()))
+  if ((SizeInBits & 7) || (SizeInBits >> 32) != 0)
     return false;
   
   // See if the pointer expression is an AddRec like {base,+,1} on the current
@@ -177,6 +180,9 @@ bool LoopIdiomRecognize::processLoopStore(StoreInst *SI, const SCEV *BECount) {
   // know that every byte is touched in the loop.
   unsigned StoreSize = (unsigned)SizeInBits >> 3; 
   const SCEVConstant *Stride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
+  
+  // TODO: Could also handle negative stride here someday, that will require the
+  // validity check in mayLoopModRefLocation to be updated though.
   if (Stride == 0 || StoreSize != Stride->getValue()->getValue())
     return false;
   
@@ -193,18 +199,46 @@ bool LoopIdiomRecognize::processLoopStore(StoreInst *SI, const SCEV *BECount) {
   return false;
 }
 
+/// mayLoopModRefLocation - Return true if the specified loop might do a load or
+/// store to the same location that the specified store could store to, which is
+/// a loop-strided access. 
+static bool mayLoopModRefLocation(StoreInst *SI, Loop *L, AliasAnalysis &AA) {
+  // Get the location that may be stored across the loop.  Since the access is
+  // strided positively through memory, we say that the modified location starts
+  // at the pointer and has infinite size.
+  // TODO: Could improve this for constant trip-count loops.
+  AliasAnalysis::Location StoreLoc =
+    AliasAnalysis::Location(SI->getPointerOperand());
+
+  for (Loop::block_iterator BI = L->block_begin(), E = L->block_end(); BI != E;
+       ++BI)
+    for (BasicBlock::iterator I = (*BI)->begin(), E = (*BI)->end(); I != E; ++I)
+      if (AA.getModRefInfo(I, StoreLoc) != AliasAnalysis::NoModRef)
+        return true;
+
+  return false;
+}
+
 /// processLoopStoreOfSplatValue - We see a strided store of a memsetable value.
 /// If we can transform this into a memset in the loop preheader, do so.
 bool LoopIdiomRecognize::
 processLoopStoreOfSplatValue(StoreInst *SI, unsigned StoreSize,
                              Value *SplatValue,
                              const SCEVAddRecExpr *Ev, const SCEV *BECount) {
+  // Temporarily remove the store from the loop, to avoid the mod/ref query from
+  // seeing it.
+  Instruction *InstAfterStore = ++BasicBlock::iterator(SI);
+  SI->removeFromParent();
+  
   // Okay, we have a strided store "p[i]" of a splattable value.  We can turn
   // this into a memset in the loop preheader now if we want.  However, this
   // would be unsafe to do if there is anything else in the loop that may read
   // or write to the aliased location.  Check for an alias.
+  bool Unsafe=mayLoopModRefLocation(SI, CurLoop, getAnalysis<AliasAnalysis>());
+
+  SI->insertBefore(InstAfterStore);
   
-  // FIXME: TODO safety check.
+  if (Unsafe) return false;
   
   // Okay, everything looks good, insert the memset.
   BasicBlock *Preheader = CurLoop->getLoopPreheader();
@@ -244,6 +278,7 @@ processLoopStoreOfSplatValue(StoreInst *SI, unsigned StoreSize,
   
   DEBUG(dbgs() << "  Formed memset: " << *NewCall << "\n"
                << "    from store to: " << *Ev << " at: " << *SI << "\n");
+  (void)NewCall;
   
   // Okay, the memset has been formed.  Zap the original store and anything that
   // feeds into it.