Removed #include <iostream> and replaced with llvm_* streams.
[oota-llvm.git] / lib / Transforms / Scalar / LoopUnswitch.cpp
index c583eea14d62ea9598069afb96342cb58b351cba..8b2f6cfc5eb8be1a580c3eff67ec18cb90714002 100644 (file)
@@ -40,7 +40,6 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/CommandLine.h"
 #include <algorithm>
-#include <iostream>
 #include <set>
 using namespace llvm;
 
@@ -73,6 +72,8 @@ namespace {
       AU.addPreservedID(LoopSimplifyID);
       AU.addRequired<LoopInfo>();
       AU.addPreserved<LoopInfo>();
+      AU.addRequiredID(LCSSAID);
+      AU.addPreservedID(LCSSAID);
     }
 
   private:
@@ -101,7 +102,7 @@ namespace {
                            std::vector<Instruction*> &Worklist);
     void RemoveLoopFromHierarchy(Loop *L);
   };
-  RegisterOpt<LoopUnswitch> X("loop-unswitch", "Unswitch loops");
+  RegisterPass<LoopUnswitch> X("loop-unswitch", "Unswitch loops");
 }
 
 FunctionPass *llvm::createLoopUnswitchPass() { return new LoopUnswitch(); }
@@ -154,6 +155,8 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) {
 }
 
 bool LoopUnswitch::visitLoop(Loop *L) {
+  assert(L->isLCSSAForm());
+  
   bool Changed = false;
   
   // Loop over all of the basic blocks in the loop.  If we find an interior
@@ -169,7 +172,8 @@ bool LoopUnswitch::visitLoop(Loop *L) {
         // See if this, or some part of it, is loop invariant.  If so, we can
         // unswitch on it if we desire.
         Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), L, Changed);
-        if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantBool::True, L)) {
+        if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantBool::getTrue(),
+                                             L)) {
           ++NumBranches;
           return true;
         }
@@ -192,39 +196,17 @@ bool LoopUnswitch::visitLoop(Loop *L) {
          BBI != E; ++BBI)
       if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) {
         Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), L, Changed);
-        if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantBool::True, L)) {
+        if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantBool::getTrue(),
+                                             L)) {
           ++NumSelects;
           return true;
         }
       }
   }
-    
-  return Changed;
-}
-
-
-/// LoopValuesUsedOutsideLoop - Return true if there are any values defined in
-/// the loop that are used by instructions outside of it.
-static bool LoopValuesUsedOutsideLoop(Loop *L) {
-  // We will be doing lots of "loop contains block" queries.  Loop::contains is
-  // linear time, use a set to speed this up.
-  std::set<BasicBlock*> LoopBlocks;
-
-  for (Loop::block_iterator BB = L->block_begin(), E = L->block_end();
-       BB != E; ++BB)
-    LoopBlocks.insert(*BB);
   
-  for (Loop::block_iterator BB = L->block_begin(), E = L->block_end();
-       BB != E; ++BB) {
-    for (BasicBlock::iterator I = (*BB)->begin(), E = (*BB)->end(); I != E; ++I)
-      for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); UI != E;
-           ++UI) {
-        BasicBlock *UserBB = cast<Instruction>(*UI)->getParent();
-        if (!LoopBlocks.count(UserBB))
-          return true;
-      }
-  }
-  return false;
+  assert(L->isLCSSAForm());
+  
+  return Changed;
 }
 
 /// isTrivialLoopExitBlock - Check to see if all paths from BB either:
@@ -305,9 +287,9 @@ static bool IsTrivialUnswitchCondition(Loop *L, Value *Cond, Constant **Val = 0,
     // side-effects.  If so, determine the value of Cond that causes it to do
     // this.
     if ((LoopExitBB = isTrivialLoopExitBlock(L, BI->getSuccessor(0)))) {
-      if (Val) *Val = ConstantBool::True;
+      if (Val) *Val = ConstantBool::getTrue();
     } else if ((LoopExitBB = isTrivialLoopExitBlock(L, BI->getSuccessor(1)))) {
-      if (Val) *Val = ConstantBool::False;
+      if (Val) *Val = ConstantBool::getFalse();
     }
   } else if (SwitchInst *SI = dyn_cast<SwitchInst>(HeaderTerm)) {
     // If this isn't a switch on Cond, we can't handle it.
@@ -352,6 +334,12 @@ unsigned LoopUnswitch::getLoopUnswitchCost(Loop *L, Value *LIC) {
   if (IsTrivialUnswitchCondition(L, LIC))
     return 0;
   
+  // FIXME: This is really overly conservative.  However, more liberal 
+  // estimations have thus far resulted in excessive unswitching, which is bad
+  // both in compile time and in code size.  This should be replaced once
+  // someone figures out how a good estimation.
+  return L->getBlocks().size();
+  
   unsigned Cost = 0;
   // FIXME: this is brain dead.  It should take into consideration code
   // shrinkage.
@@ -380,22 +368,12 @@ bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val,Loop *L){
     // FIXME: this should estimate growth by the amount of code shared by the
     // resultant unswitched loops.
     //
-    DEBUG(std::cerr << "NOT unswitching loop %"
-                    << L->getHeader()->getName() << ", cost too high: "
-                    << L->getBlocks().size() << "\n");
-    return false;
-  }
-    
-  // If this loop has live-out values, we can't unswitch it. We need something
-  // like loop-closed SSA form in order to know how to insert PHI nodes for
-  // these values.
-  if (LoopValuesUsedOutsideLoop(L)) {
-    DEBUG(std::cerr << "NOT unswitching loop %" << L->getHeader()->getName()
-                    << ", a loop value is used outside loop!  Cost: "
-                    << Cost << "\n");
+    DOUT << "NOT unswitching loop %"
+         << L->getHeader()->getName() << ", cost too high: "
+         << L->getBlocks().size() << "\n";
     return false;
   }
-      
+  
   // If this is a trivial condition to unswitch (which results in no code
   // duplication), do it now.
   Constant *CondVal;
@@ -511,7 +489,7 @@ static void EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val,
   Value *BranchVal = LIC;
   if (!isa<ConstantBool>(Val)) {
     BranchVal = BinaryOperator::createSetEQ(LIC, Val, "tmp", InsertPt);
-  } else if (Val != ConstantBool::True) {
+  } else if (Val != ConstantBool::getTrue()) {
     // We want to enter the new loop when the condition is true.
     std::swap(TrueDest, FalseDest);
   }
@@ -529,10 +507,10 @@ static void EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val,
 void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, 
                                             Constant *Val, 
                                             BasicBlock *ExitBlock) {
-  DEBUG(std::cerr << "loop-unswitch: Trivial-Unswitch loop %"
-        << L->getHeader()->getName() << " [" << L->getBlocks().size()
-        << " blocks] in Function " << L->getHeader()->getParent()->getName()
-        << " on cond: " << *Val << " == " << *Cond << "\n");
+  DOUT << "loop-unswitch: Trivial-Unswitch loop %"
+       << L->getHeader()->getName() << " [" << L->getBlocks().size()
+       << " blocks] in Function " << L->getHeader()->getParent()->getName()
+       << " on cond: " << *Val << " == " << *Cond << "\n";
   
   // First step, split the preheader, so that we know that there is a safe place
   // to insert the conditional branch.  We will change 'OrigPH' to have a
@@ -574,10 +552,10 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond,
 void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, 
                                                Loop *L) {
   Function *F = L->getHeader()->getParent();
-  DEBUG(std::cerr << "loop-unswitch: Unswitching loop %"
-                  << L->getHeader()->getName() << " [" << L->getBlocks().size()
-                  << " blocks] in Function " << F->getName()
-                  << " when '" << *Val << "' == " << *LIC << "\n");
+  DOUT << "loop-unswitch: Unswitching loop %"
+       << L->getHeader()->getName() << " [" << L->getBlocks().size()
+       << " blocks] in Function " << F->getName()
+       << " when '" << *Val << "' == " << *LIC << "\n";
 
   // LoopBlocks contains all of the basic blocks of the loop, including the
   // preheader of the loop, the body of the loop, and the exit blocks of the 
@@ -593,15 +571,10 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
   LoopBlocks.insert(LoopBlocks.end(), L->block_begin(), L->block_end());
 
   std::vector<BasicBlock*> ExitBlocks;
-  L->getExitBlocks(ExitBlocks);
-  std::sort(ExitBlocks.begin(), ExitBlocks.end());
-  ExitBlocks.erase(std::unique(ExitBlocks.begin(), ExitBlocks.end()),
-                   ExitBlocks.end());
-  
-  // Split all of the edges from inside the loop to their exit blocks.  This
-  // unswitching trivial: no phi nodes to update.
-  unsigned NumBlocks = L->getBlocks().size();
-  
+  L->getUniqueExitBlocks(ExitBlocks);
+
+  // Split all of the edges from inside the loop to their exit blocks.  Update
+  // the appropriate Phi nodes as we do so.
   for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) {
     BasicBlock *ExitBlock = ExitBlocks[i];
     std::vector<BasicBlock*> Preds(pred_begin(ExitBlock), pred_end(ExitBlock));
@@ -609,17 +582,48 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
     for (unsigned j = 0, e = Preds.size(); j != e; ++j) {
       assert(L->contains(Preds[j]) &&
              "All preds of loop exit blocks must be the same loop!");
-      SplitEdge(Preds[j], ExitBlock);
-    }
+      BasicBlock* MiddleBlock = SplitEdge(Preds[j], ExitBlock);
+      BasicBlock* StartBlock = Preds[j];
+      BasicBlock* EndBlock;
+      if (MiddleBlock->getSinglePredecessor() == ExitBlock) {
+        EndBlock = MiddleBlock;
+        MiddleBlock = EndBlock->getSinglePredecessor();;
+      } else {
+        EndBlock = ExitBlock;
+      }
+      
+      std::set<PHINode*> InsertedPHIs;
+      PHINode* OldLCSSA = 0;
+      for (BasicBlock::iterator I = EndBlock->begin();
+           (OldLCSSA = dyn_cast<PHINode>(I)); ++I) {
+        Value* OldValue = OldLCSSA->getIncomingValueForBlock(MiddleBlock);
+        PHINode* NewLCSSA = new PHINode(OldLCSSA->getType(),
+                                        OldLCSSA->getName() + ".us-lcssa",
+                                        MiddleBlock->getTerminator());
+        NewLCSSA->addIncoming(OldValue, StartBlock);
+        OldLCSSA->setIncomingValue(OldLCSSA->getBasicBlockIndex(MiddleBlock),
+                                   NewLCSSA);
+        InsertedPHIs.insert(NewLCSSA);
+      }
+
+      BasicBlock::iterator InsertPt = EndBlock->begin();
+      while (dyn_cast<PHINode>(InsertPt)) ++InsertPt;
+      for (BasicBlock::iterator I = MiddleBlock->begin();
+         (OldLCSSA = dyn_cast<PHINode>(I)) && InsertedPHIs.count(OldLCSSA) == 0;
+         ++I) {
+        PHINode *NewLCSSA = new PHINode(OldLCSSA->getType(),
+                                        OldLCSSA->getName() + ".us-lcssa",
+                                        InsertPt);
+        OldLCSSA->replaceAllUsesWith(NewLCSSA);
+        NewLCSSA->addIncoming(OldLCSSA, MiddleBlock);
+      }
+    }    
   }
   
   // The exit blocks may have been changed due to edge splitting, recompute.
   ExitBlocks.clear();
-  L->getExitBlocks(ExitBlocks);
-  std::sort(ExitBlocks.begin(), ExitBlocks.end());
-  ExitBlocks.erase(std::unique(ExitBlocks.begin(), ExitBlocks.end()),
-                   ExitBlocks.end());
-  
+  L->getUniqueExitBlocks(ExitBlocks);
+
   // Add exit blocks to the loop blocks.
   LoopBlocks.insert(LoopBlocks.end(), ExitBlocks.begin(), ExitBlocks.end());
 
@@ -716,7 +720,7 @@ static void RemoveFromWorklist(Instruction *I,
 /// program, replacing all uses with V and update the worklist.
 static void ReplaceUsesOfWith(Instruction *I, Value *V, 
                               std::vector<Instruction*> &Worklist) {
-  DEBUG(std::cerr << "Replace with '" << *V << "': " << *I);
+  DOUT << "Replace with '" << *V << "': " << *I;
 
   // Add uses to the worklist, which may be dead now.
   for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
@@ -775,7 +779,7 @@ void LoopUnswitch::RemoveBlockIfDead(BasicBlock *BB,
     return;
   }
 
-  DEBUG(std::cerr << "Nuking dead block: " << *BB);
+  DOUT << "Nuking dead block: " << *BB;
   
   // Remove the instructions in the basic block from the worklist.
   for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) {
@@ -949,7 +953,29 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
               // Found a dead case value.  Don't remove PHI nodes in the 
               // successor if they become single-entry, those PHI nodes may
               // be in the Users list.
-              SI->getSuccessor(i)->removePredecessor(SI->getParent(), true);
+              
+              // FIXME: This is a hack.  We need to keep the successor around
+              // and hooked up so as to preserve the loop structure, because
+              // trying to update it is complicated.  So instead we preserve the
+              // loop structure and put the block on an dead code path.
+              
+              BasicBlock* Old = SI->getParent();
+              BasicBlock* Split = SplitBlock(Old, SI);
+              
+              Instruction* OldTerm = Old->getTerminator();
+              new BranchInst(Split, SI->getSuccessor(i),
+                             ConstantBool::getTrue(), OldTerm);
+              
+              Old->getTerminator()->eraseFromParent();
+              
+              
+              PHINode *PN;
+              for (BasicBlock::iterator II = SI->getSuccessor(i)->begin();
+                   (PN = dyn_cast<PHINode>(II)); ++II) {
+                Value *InVal = PN->removeIncomingValue(Split, false);
+                PN->addIncoming(InVal, Old);
+              }
+
               SI->removeCase(i);
               break;
             }
@@ -986,7 +1012,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist) {
     
     // Simple DCE.
     if (isInstructionTriviallyDead(I)) {
-      DEBUG(std::cerr << "Remove dead instruction '" << *I);
+      DOUT << "Remove dead instruction '" << *I;
       
       // Add uses to the worklist, which may be dead now.
       for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
@@ -1039,8 +1065,8 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist) {
         if (!SinglePred) continue;  // Nothing to do.
         assert(SinglePred == Pred && "CFG broken");
 
-        DEBUG(std::cerr << "Merging blocks: " << Pred->getName() << " <- " 
-                        << Succ->getName() << "\n");
+        DOUT << "Merging blocks: " << Pred->getName() << " <- " 
+             << Succ->getName() << "\n";
         
         // Resolve any single entry PHI nodes in Succ.
         while (PHINode *PN = dyn_cast<PHINode>(Succ->begin()))
@@ -1065,7 +1091,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist) {
         // remove dead blocks.
         break;  // FIXME: Enable.
 
-        DEBUG(std::cerr << "Folded branch: " << *BI);
+        DOUT << "Folded branch: " << *BI;
         BasicBlock *DeadSucc = BI->getSuccessor(CB->getValue());
         BasicBlock *LiveSucc = BI->getSuccessor(!CB->getValue());
         DeadSucc->removePredecessor(BI->getParent(), true);