[LoopReroll] Alter the data structures used during reroll validation.
[oota-llvm.git] / lib / Transforms / Scalar / LoopRerollPass.cpp
index d105354f5b392aa3e78310da2803b2ba7bcfd75c..ad308caff0f8673012b9130d41fcb93665cb9cfb 100644 (file)
@@ -12,7 +12,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AliasAnalysis.h"
@@ -119,6 +121,16 @@ MaxInc("max-reroll-increment", cl::init(2048), cl::Hidden,
 // br %cmp, header, exit
 
 namespace {
+  enum IterationLimits {
+    /// The maximum number of iterations that we'll try and reroll. This
+    /// has to be less than 25 in order to fit into a SmallBitVector.
+    IL_MaxRerollIterations = 16,
+    /// The bitvector index used by loop induction variables and other
+    /// instructions that belong to no one particular iteration.
+    IL_LoopIncIdx,
+    IL_End
+  };
+
   class LoopReroll : public LoopPass {
   public:
     static char ID; // Pass ID, replacement for typeid
@@ -138,7 +150,7 @@ namespace {
       AU.addRequired<TargetLibraryInfoWrapperPass>();
     }
 
-protected:
+  protected:
     AliasAnalysis *AA;
     LoopInfo *LI;
     ScalarEvolution *SE;
@@ -331,9 +343,12 @@ protected:
       void replace(const SCEV *IterCount);
 
     protected:
+      typedef MapVector<Instruction*, SmallBitVector> UsesTy;
+
       bool findScaleFromMul();
       bool collectAllRoots();
 
+      bool collectUsedInstructions(SmallInstructionSet &PossibleRedSet);
       void collectInLoopUserSet(const SmallInstructionVector &Roots,
                                 const SmallInstructionSet &Exclude,
                                 const SmallInstructionSet &Final,
@@ -343,6 +358,8 @@ protected:
                                 const SmallInstructionSet &Final,
                                 DenseSet<Instruction *> &Users);
 
+      UsesTy::iterator nextInstr(int Val, UsesTy &In, UsesTy::iterator I);
+
       LoopReroll *Parent;
 
       // Members of Parent, replicated here for brevity.
@@ -366,12 +383,10 @@ protected:
       SmallInstructionVector Roots;
       // All increment instructions for IV.
       SmallInstructionVector LoopIncs;
-      // All instructions transitively used by any root.
-      DenseSet<Instruction *> AllRootUses;
-      // All instructions transitively used by the base.
-      DenseSet<Instruction *> BaseUseSet;
-      // All instructions transitively used by the increments.
-      DenseSet<Instruction *> LoopIncUseSet;
+      // Map of all instructions in the loop (in order) to the iterations
+      // they are used in (or specially, IL_LoopIncIdx for instructions
+      // used in the loop increment mechanism).
+      UsesTy Uses;
     };
 
     void collectPossibleIVs(Loop *L, SmallInstructionVector &PossibleIVs);
@@ -399,10 +414,10 @@ Pass *llvm::createLoopRerollPass() {
 // This operates like Instruction::isUsedOutsideOfBlock, but considers PHIs in
 // non-loop blocks to be outside the loop.
 static bool hasUsesOutsideLoop(Instruction *I, Loop *L) {
-  for (User *U : I->users())
+  for (User *U : I->users()) {
     if (!L->contains(cast<Instruction>(U)))
       return true;
-
+  }
   return false;
 }
 
@@ -470,11 +485,12 @@ void LoopReroll::SimpleLoopReduction::add(Loop *L) {
     return;
 
   // C is now the (potential) last instruction in the reduction chain.
-  for (User *U : C->users())
+  for (User *U : C->users()) {
     // The only in-loop user can be the initial PHI.
     if (L->contains(cast<Instruction>(U)))
       if (cast<Instruction>(U) != Instructions.front())
         return;
+  }
 
   Instructions.push_back(C);
   Valid = true;
@@ -592,6 +608,13 @@ bool LoopReroll::DAGRootTracker::findRoots() {
   if (!collectAllRoots())
     return false;
 
+  if (Roots.size() > IL_MaxRerollIterations) {
+    DEBUG(dbgs() << "LRR: Aborting - too many iterations found. "
+          << "#Found=" << Roots.size() << ", #Max=" << IL_MaxRerollIterations
+          << "\n");
+    return false;
+  }
+
   return true;
 }
 
@@ -715,9 +738,65 @@ bool LoopReroll::DAGRootTracker::collectAllRoots() {
   return true;
 }
 
-bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) {
-  BasicBlock *Header = L->getHeader();
+bool LoopReroll::DAGRootTracker::collectUsedInstructions(SmallInstructionSet &PossibleRedSet) {
+  // Populate the MapVector with all instructions in the block, in order first,
+  // so we can iterate over the contents later in perfect order.
+  for (auto &I : *L->getHeader()) {
+    Uses[&I].resize(IL_End);
+  }
+
+  SmallInstructionSet Exclude;
+  Exclude.insert(Roots.begin(), Roots.end());
+  Exclude.insert(LoopIncs.begin(), LoopIncs.end());
 
+  DenseSet<Instruction*> VBase;
+  collectInLoopUserSet(IV, Exclude, PossibleRedSet, VBase);
+  for (auto *I : VBase) {
+    Uses[I].set(0);
+  }
+
+  unsigned Idx = 1;
+  for (auto *Root : Roots) {
+    DenseSet<Instruction*> V;
+    collectInLoopUserSet(Root, Exclude, PossibleRedSet, V);
+
+    // While we're here, check the use sets are the same size.
+    if (V.size() != VBase.size()) {
+      DEBUG(dbgs() << "LRR: Aborting - use sets are different sizes\n");
+      return false;
+    }
+
+    for (auto *I : V) {
+      Uses[I].set(Idx);
+    }
+    ++Idx;
+  }
+
+  // Make sure the loop increments are also accounted for.
+  Exclude.clear();
+  Exclude.insert(Roots.begin(), Roots.end());
+
+  DenseSet<Instruction*> V;
+  collectInLoopUserSet(LoopIncs, Exclude, PossibleRedSet, V);
+  for (auto *I : V) {
+    Uses[I].set(IL_LoopIncIdx);
+  }
+  if (IV != RealIV)
+    Uses[RealIV].set(IL_LoopIncIdx);
+
+  return true;
+
+}
+
+LoopReroll::DAGRootTracker::UsesTy::iterator
+LoopReroll::DAGRootTracker::nextInstr(int Val, UsesTy &In,
+                                      UsesTy::iterator I) {
+  while (I != In.end() && I->second.test(Val) == 0)
+    ++I;
+  return I;
+}
+
+bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) {
   // We now need to check for equivalence of the use graph of each root with
   // that of the primary induction variable (excluding the roots). Our goal
   // here is not to solve the full graph isomorphism problem, but rather to
@@ -726,9 +805,6 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) {
   // is the same (although we will not make an assumption about how the
   // different iterations are intermixed). Note that while the order must be
   // the same, the instructions may not be in the same basic block.
-  SmallInstructionSet Exclude;
-  Exclude.insert(Roots.begin(), Roots.end());
-  Exclude.insert(LoopIncs.begin(), LoopIncs.end());
 
   // An array of just the possible reductions for this scale factor. When we
   // collect the set of all users of some root instructions, these reduction
@@ -740,116 +816,123 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) {
   SmallInstructionSet PossibleRedPHISet;
   Reductions.restrictToScale(Scale, PossibleRedSet,
                              PossibleRedPHISet, PossibleRedLastSet);
-                                          
-
-  collectInLoopUserSet(IV, Exclude, PossibleRedSet, BaseUseSet);
 
-  std::vector<DenseSet<Instruction *> > RootUseSets(Scale-1);
+  // Populate "Uses" with where each instruction is used.
+  if (!collectUsedInstructions(PossibleRedSet))
+    return false;
 
-  bool MatchFailed = false;
-  for (unsigned i = 0; i < Scale-1 && !MatchFailed; ++i) {
-    DenseSet<Instruction *> &RootUseSet = RootUseSets[i];
-    collectInLoopUserSet(Roots[i], SmallInstructionSet(),
-                         PossibleRedSet, RootUseSet);
+  // Make sure we mark the reduction PHIs as used in all iterations.
+  for (auto *I : PossibleRedPHISet) {
+    Uses[I].set(IL_LoopIncIdx);
+  }
 
-    DEBUG(dbgs() << "LRR: base use set size: " << BaseUseSet.size() <<
-                    " vs. iteration increment " << (i+1) <<
-                    " use set size: " << RootUseSet.size() << "\n");
+  // Make sure all instructions in the loop are in one and only one
+  // set.
+  for (auto &KV : Uses) {
+    if (KV.second.count() != 1) {
+      DEBUG(dbgs() << "LRR: Aborting - instruction is not used in 1 iteration: "
+            << *KV.first << " (#uses=" << KV.second.count() << ")\n");
+      return false;
+    }
+  }
 
-    if (BaseUseSet.size() != RootUseSet.size()) {
-      MatchFailed = true;
-      break;
+  DEBUG(
+    for (auto &KV : Uses) {
+      dbgs() << "LRR: " << KV.second.find_first() << "\t" << *KV.first << "\n";
     }
+    );
 
+  for (unsigned Iter = 1; Iter < Scale; ++Iter) {
     // In addition to regular aliasing information, we need to look for
     // instructions from later (future) iterations that have side effects
     // preventing us from reordering them past other instructions with side
     // effects.
     bool FutureSideEffects = false;
     AliasSetTracker AST(*AA);
-
     // The map between instructions in f(%iv.(i+1)) and f(%iv).
     DenseMap<Value *, Value *> BaseMap;
 
-    assert(L->getNumBlocks() == 1 && "Cannot handle multi-block loops");
-    for (BasicBlock::iterator J1 = Header->begin(), J2 = Header->begin(),
-         JE = Header->end(); J1 != JE && !MatchFailed; ++J1) {
-      if (cast<Instruction>(J1) == RealIV)
-        continue;
-      if (cast<Instruction>(J1) == IV)
-        continue;
-      if (!BaseUseSet.count(J1))
-        continue;
-      if (PossibleRedPHISet.count(J1)) // Skip reduction PHIs.
-        continue;
+    // Compare iteration Iter to the base.
+    auto BaseIt = nextInstr(0, Uses, Uses.begin());
+    auto RootIt = nextInstr(Iter, Uses, Uses.begin());
+    auto LastRootIt = Uses.begin();
 
-      while (J2 != JE && (!RootUseSet.count(J2) || Roots[i] ==  J2)) {
-        // As we iterate through the instructions, instructions that don't
-        // belong to previous iterations (or the base case), must belong to
-        // future iterations. We want to track the alias set of writes from
-        // previous iterations.
-        if (!isa<PHINode>(J2) && !BaseUseSet.count(J2) &&
-            !AllRootUses.count(J2)) {
-          if (J2->mayWriteToMemory())
-            AST.add(J2);
-
-          // Note: This is specifically guarded by a check on isa<PHINode>,
-          // which while a valid (somewhat arbitrary) micro-optimization, is
-          // needed because otherwise isSafeToSpeculativelyExecute returns
-          // false on PHI nodes.
-          if (!isSimpleLoadStore(J2) && !isSafeToSpeculativelyExecute(J2, DL))
-            FutureSideEffects = true;
-        }
+    while (BaseIt != Uses.end() && RootIt != Uses.end()) {
+      Instruction *BaseInst = BaseIt->first;
+      Instruction *RootInst = RootIt->first;
 
-        ++J2;
+      // Skip over the IV or root instructions; only match their users.
+      bool Continue = false;
+      if (BaseInst == RealIV || BaseInst == IV) {
+        BaseIt = nextInstr(0, Uses, ++BaseIt);
+        Continue = true;
+      }
+      if (std::find(Roots.begin(), Roots.end(), RootInst) != Roots.end()) {
+        LastRootIt = RootIt;
+        RootIt = nextInstr(Iter, Uses, ++RootIt);
+        Continue = true;
+      }
+      if (Continue) continue;
+
+      // All instructions between the last root and this root
+      // belong to some other iteration. If they belong to a 
+      // future iteration, then they're dangerous to alias with.
+      for (; LastRootIt != RootIt; ++LastRootIt) {
+        Instruction *I = LastRootIt->first;
+        if (LastRootIt->second.find_first() < (int)Iter)
+          continue;
+        if (I->mayWriteToMemory())
+          AST.add(I);
+        // Note: This is specifically guarded by a check on isa<PHINode>,
+        // which while a valid (somewhat arbitrary) micro-optimization, is
+        // needed because otherwise isSafeToSpeculativelyExecute returns
+        // false on PHI nodes.
+        if (!isa<PHINode>(I) && !isSimpleLoadStore(I) &&
+            !isSafeToSpeculativelyExecute(I, DL))
+          // Intervening instructions cause side effects.
+          FutureSideEffects = true;
       }
 
-      if (!J1->isSameOperationAs(J2)) {
-        DEBUG(dbgs() << "LRR: iteration root match failed at " << *J1 <<
-                        " vs. " << *J2 << "\n");
-        MatchFailed = true;
-        break;
+      if (!BaseInst->isSameOperationAs(RootInst)) {
+        DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst <<
+              " vs. " << *RootInst << "\n");
+        return false;
       }
 
       // Make sure that this instruction, which is in the use set of this
       // root instruction, does not also belong to the base set or the set of
-      // some previous root instruction.
-      if (BaseUseSet.count(J2) || AllRootUses.count(J2)) {
-        DEBUG(dbgs() << "LRR: iteration root match failed at " << *J1 <<
-                        " vs. " << *J2 << " (prev. case overlap)\n");
-        MatchFailed = true;
-        break;
+      // some other root instruction.
+      if (RootIt->second.count() > 1) {
+        DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst <<
+                        " vs. " << *RootInst << " (prev. case overlap)\n");
+        return false;
       }
 
       // Make sure that we don't alias with any instruction in the alias set
       // tracker. If we do, then we depend on a future iteration, and we
       // can't reroll.
-      if (J2->mayReadFromMemory()) {
-        for (AliasSetTracker::iterator K = AST.begin(), KE = AST.end();
-             K != KE && !MatchFailed; ++K) {
-          if (K->aliasesUnknownInst(J2, *AA)) {
-            DEBUG(dbgs() << "LRR: iteration root match failed at " << *J1 <<
-                            " vs. " << *J2 << " (depends on future store)\n");
-            MatchFailed = true;
-            break;
+      if (RootInst->mayReadFromMemory())
+        for (auto &K : AST) {
+          if (K.aliasesUnknownInst(RootInst, *AA)) {
+            DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst <<
+                            " vs. " << *RootInst << " (depends on future store)\n");
+            return false;
           }
         }
-      }
 
       // If we've past an instruction from a future iteration that may have
       // side effects, and this instruction might also, then we can't reorder
       // them, and this matching fails. As an exception, we allow the alias
       // set tracker to handle regular (simple) load/store dependencies.
       if (FutureSideEffects &&
-            ((!isSimpleLoadStore(J1) &&
-              !isSafeToSpeculativelyExecute(J1, DL)) ||
-             (!isSimpleLoadStore(J2) &&
-              !isSafeToSpeculativelyExecute(J2, DL)))) {
-        DEBUG(dbgs() << "LRR: iteration root match failed at " << *J1 <<
-                        " vs. " << *J2 <<
+            ((!isSimpleLoadStore(BaseInst) &&
+              !isSafeToSpeculativelyExecute(BaseInst, DL)) ||
+             (!isSimpleLoadStore(RootInst) &&
+              !isSafeToSpeculativelyExecute(RootInst, DL)))) {
+        DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst <<
+                        " vs. " << *RootInst <<
                         " (side effects prevent reordering)\n");
-        MatchFailed = true;
-        break;
+        return false;
       }
 
       // For instructions that are part of a reduction, if the operation is
@@ -862,41 +945,40 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) {
       //   x += a[i]; x += b[i];
       //   x += a[i+1]; x += b[i+1];
       //   x += b[i+2]; x += a[i+2];
-      bool InReduction = Reductions.isPairInSame(J1, J2);
+      bool InReduction = Reductions.isPairInSame(BaseInst, RootInst);
 
-      if (!(InReduction && J1->isAssociative())) {
+      if (!(InReduction && BaseInst->isAssociative())) {
         bool Swapped = false, SomeOpMatched = false;
-        for (unsigned j = 0; j < J1->getNumOperands() && !MatchFailed; ++j) {
-          Value *Op2 = J2->getOperand(j);
+        for (unsigned j = 0; j < BaseInst->getNumOperands(); ++j) {
+          Value *Op2 = RootInst->getOperand(j);
 
           // If this is part of a reduction (and the operation is not
           // associatve), then we match all operands, but not those that are
           // part of the reduction.
           if (InReduction)
             if (Instruction *Op2I = dyn_cast<Instruction>(Op2))
-              if (Reductions.isPairInSame(J2, Op2I))
+              if (Reductions.isPairInSame(RootInst, Op2I))
                 continue;
 
           DenseMap<Value *, Value *>::iterator BMI = BaseMap.find(Op2);
           if (BMI != BaseMap.end())
             Op2 = BMI->second;
-          else if (Roots[i] == (Instruction*) Op2)
+          else if (Roots[Iter-1] == (Instruction*) Op2)
             Op2 = IV;
 
-          if (J1->getOperand(Swapped ? unsigned(!j) : j) != Op2) {
+          if (BaseInst->getOperand(Swapped ? unsigned(!j) : j) != Op2) {
             // If we've not already decided to swap the matched operands, and
             // we've not already matched our first operand (note that we could
             // have skipped matching the first operand because it is part of a
             // reduction above), and the instruction is commutative, then try
             // the swapped match.
-            if (!Swapped && J1->isCommutative() && !SomeOpMatched &&
-                J1->getOperand(!j) == Op2) {
+            if (!Swapped && BaseInst->isCommutative() && !SomeOpMatched &&
+                BaseInst->getOperand(!j) == Op2) {
               Swapped = true;
             } else {
-              DEBUG(dbgs() << "LRR: iteration root match failed at " << *J1 <<
-                              " vs. " << *J2 << " (operand " << j << ")\n");
-              MatchFailed = true;
-              break;
+              DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst
+                    << " vs. " << *RootInst << " (operand " << j << ")\n");
+              return false;
             }
           }
 
@@ -904,67 +986,29 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) {
         }
       }
 
-      if ((!PossibleRedLastSet.count(J1) && hasUsesOutsideLoop(J1, L)) ||
-          (!PossibleRedLastSet.count(J2) && hasUsesOutsideLoop(J2, L))) {
-        DEBUG(dbgs() << "LRR: iteration root match failed at " << *J1 <<
-                        " vs. " << *J2 << " (uses outside loop)\n");
-        MatchFailed = true;
-        break;
+      if ((!PossibleRedLastSet.count(BaseInst) &&
+           hasUsesOutsideLoop(BaseInst, L)) ||
+          (!PossibleRedLastSet.count(RootInst) &&
+           hasUsesOutsideLoop(RootInst, L))) {
+        DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst <<
+                        " vs. " << *RootInst << " (uses outside loop)\n");
+        return false;
       }
 
-      if (!MatchFailed)
-        BaseMap.insert(std::pair<Value *, Value *>(J2, J1));
-
-      AllRootUses.insert(J2);
-      Reductions.recordPair(J1, J2, i+1);
+      Reductions.recordPair(BaseInst, RootInst, Iter);
+      BaseMap.insert(std::make_pair(RootInst, BaseInst));
 
-      ++J2;
+      LastRootIt = RootIt;
+      BaseIt = nextInstr(0, Uses, ++BaseIt);
+      RootIt = nextInstr(Iter, Uses, ++RootIt);
     }
+    assert (BaseIt == Uses.end() && RootIt == Uses.end() &&
+            "Mismatched set sizes!");
   }
 
-  if (MatchFailed)
-    return false;
-
   DEBUG(dbgs() << "LRR: Matched all iteration increments for " <<
                   *RealIV << "\n");
 
-  collectInLoopUserSet(LoopIncs, SmallInstructionSet(),
-                       SmallInstructionSet(), LoopIncUseSet);
-  DEBUG(dbgs() << "LRR: Loop increment set size: " <<
-                  LoopIncUseSet.size() << "\n");
-
-  // Make sure that all instructions in the loop have been included in some
-  // use set.
-  for (BasicBlock::iterator J = Header->begin(), JE = Header->end();
-       J != JE; ++J) {
-    if (isa<DbgInfoIntrinsic>(J))
-      continue;
-    if (cast<Instruction>(J) == RealIV)
-      continue;
-    if (cast<Instruction>(J) == IV)
-      continue;
-    if (BaseUseSet.count(J) || AllRootUses.count(J) ||
-        (LoopIncUseSet.count(J) && (J->isTerminator() ||
-                                    isSafeToSpeculativelyExecute(J, DL))))
-      continue;
-
-    if (std::find(Roots.begin(), Roots.end(), J) != Roots.end())
-      continue;
-
-    if (Reductions.isSelectedPHI(J))
-      continue;
-
-    DEBUG(dbgs() << "LRR: aborting reroll based on " << *RealIV <<
-                    " unprocessed instruction found: " << *J << "\n");
-    MatchFailed = true;
-    break;
-  }
-
-  if (MatchFailed)
-    return false;
-
-  DEBUG(dbgs() << "LRR: all instructions processed from " <<
-                  *RealIV << "\n");
   return true;
 }
 
@@ -973,7 +1017,8 @@ void LoopReroll::DAGRootTracker::replace(const SCEV *IterCount) {
   // Remove instructions associated with non-base iterations.
   for (BasicBlock::reverse_iterator J = Header->rbegin();
        J != Header->rend();) {
-    if (AllRootUses.count(&*J)) {
+    unsigned I = Uses[&*J].find_first();
+    if (I > 0 && I < IL_LoopIncIdx) {
       Instruction *D = &*J;
       DEBUG(dbgs() << "LRR: removing: " << *D << "\n");
       D->eraseFromParent();
@@ -997,12 +1042,14 @@ void LoopReroll::DAGRootTracker::replace(const SCEV *IterCount) {
     SCEVExpander Expander(*SE, "reroll");
     Value *NewIV = Expander.expandCodeFor(H, IV->getType(), Header->begin());
 
-    for (DenseSet<Instruction *>::iterator J = BaseUseSet.begin(),
-         JE = BaseUseSet.end(); J != JE; ++J)
-      (*J)->replaceUsesOfWith(IV, NewIV);
+    for (auto &KV : Uses) {
+      if (KV.second.find_first() == 0)
+        KV.first->replaceUsesOfWith(IV, NewIV);
+    }
 
     if (BranchInst *BI = dyn_cast<BranchInst>(Header->getTerminator())) {
-      if (LoopIncUseSet.count(BI)) {
+      // FIXME: Why do we need this check?
+      if (Uses[BI].find_first() == IL_LoopIncIdx) {
         const SCEV *ICSCEV = RealIVSCEV->evaluateAtIteration(IterCount, *SE);
         if (Inc == 1)
           ICSCEV =
@@ -1100,8 +1147,9 @@ void LoopReroll::ReductionTracker::replaceSelected() {
 
     // Replace users with the new end-of-chain value.
     SmallInstructionVector Users;
-    for (User *U : PossibleReds[i].getReducedValue()->users())
+    for (User *U : PossibleReds[i].getReducedValue()->users()) {
       Users.push_back(cast<Instruction>(U));
+    }
 
     for (SmallInstructionVector::iterator J = Users.begin(),
          JE = Users.end(); J != JE; ++J)