InstCombine: Shrink ((zext X) & C1) == C2 to fold away the cast if the "zext" and...
[oota-llvm.git] / lib / Transforms / Scalar / LoopUnswitch.cpp
index ae7bf40e0e1ad20402f724cf26c3e820190e50c7..e05f29c3e13f37d5449fdf50e8cbf69a6e10d04e 100644 (file)
 #include "llvm/DerivedTypes.h"
 #include "llvm/Function.h"
 #include "llvm/Instructions.h"
-#include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InlineCost.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/Dominators.h"
+#include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -77,7 +77,6 @@ namespace {
     bool redoLoop;
 
     Loop *currentLoop;
-    DominanceFrontier *DF;
     DominatorTree *DT;
     BasicBlock *loopHeader;
     BasicBlock *loopPreheader;
@@ -92,15 +91,17 @@ namespace {
   public:
     static char ID; // Pass ID, replacement for typeid
     explicit LoopUnswitch(bool Os = false) : 
-      LoopPass(&ID), OptimizeForSize(Os), redoLoop(false), 
-      currentLoop(NULL), DF(NULL), DT(NULL), loopHeader(NULL),
-      loopPreheader(NULL) {}
+      LoopPass(ID), OptimizeForSize(Os), redoLoop(false), 
+      currentLoop(NULL), DT(NULL), loopHeader(NULL),
+      loopPreheader(NULL) {
+        initializeLoopUnswitchPass(*PassRegistry::getPassRegistry());
+      }
 
     bool runOnLoop(Loop *L, LPPassManager &LPM);
     bool processCurrentLoop();
 
     /// This transformation requires natural loop information & requires that
-    /// loop preheaders be inserted into the CFG...
+    /// loop preheaders be inserted into the CFG.
     ///
     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
       AU.addRequiredID(LoopSimplifyID);
@@ -110,7 +111,7 @@ namespace {
       AU.addRequiredID(LCSSAID);
       AU.addPreservedID(LCSSAID);
       AU.addPreserved<DominatorTree>();
-      AU.addPreserved<DominanceFrontier>();
+      AU.addPreserved<ScalarEvolution>();
     }
 
   private:
@@ -160,7 +161,13 @@ namespace {
   };
 }
 char LoopUnswitch::ID = 0;
-static RegisterPass<LoopUnswitch> X("loop-unswitch", "Unswitch loops");
+INITIALIZE_PASS_BEGIN(LoopUnswitch, "loop-unswitch", "Unswitch loops",
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
+INITIALIZE_PASS_DEPENDENCY(LoopInfo)
+INITIALIZE_PASS_DEPENDENCY(LCSSA)
+INITIALIZE_PASS_END(LoopUnswitch, "loop-unswitch", "Unswitch loops",
+                      false, false)
 
 Pass *llvm::createLoopUnswitchPass(bool Os) { 
   return new LoopUnswitch(Os); 
@@ -201,7 +208,6 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) {
 bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) {
   LI = &getAnalysis<LoopInfo>();
   LPM = &LPM_Ref;
-  DF = getAnalysisIfAvailable<DominanceFrontier>();
   DT = getAnalysisIfAvailable<DominatorTree>();
   currentLoop = L;
   Function *F = currentLoop->getHeader()->getParent();
@@ -216,8 +222,6 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) {
     // FIXME: Reconstruct dom info, because it is not preserved properly.
     if (DT)
       DT->runOnFunction(*F);
-    if (DF)
-      DF->runOnFunction(*F);
   }
   return Changed;
 }
@@ -254,6 +258,7 @@ bool LoopUnswitch::processCurrentLoop() {
       if (LoopCond && SI->getNumCases() > 1) {
         // Find a value to unswitch on:
         // FIXME: this should chose the most expensive case!
+        // FIXME: scan for a case with a non-critical edge?
         Constant *UnswitchVal = SI->getCaseValue(1);
         // Do not process same value again and again.
         if (!UnswitchedVals.insert(UnswitchVal))
@@ -282,19 +287,18 @@ bool LoopUnswitch::processCurrentLoop() {
   return Changed;
 }
 
-/// isTrivialLoopExitBlock - Check to see if all paths from BB either:
-///   1. Exit the loop with no side effects.
-///   2. Branch to the latch block with no side-effects.
+/// isTrivialLoopExitBlock - Check to see if all paths from BB exit the
+/// loop with no side effects (including infinite loops).
 ///
-/// If these conditions are true, we return true and set ExitBB to the block we
+/// If true, we return true and set ExitBB to the block we
 /// exit through.
 ///
 static bool isTrivialLoopExitBlockHelper(Loop *L, BasicBlock *BB,
                                          BasicBlock *&ExitBB,
                                          std::set<BasicBlock*> &Visited) {
   if (!Visited.insert(BB).second) {
-    // Already visited and Ok, end of recursion.
-    return true;
+    // Already visited. Without more analysis, this could indicate an infinte loop.
+    return false;
   } else if (!L->contains(BB)) {
     // Otherwise, this is a loop exit, this is fine so long as this is the
     // first exit.
@@ -324,7 +328,7 @@ static bool isTrivialLoopExitBlockHelper(Loop *L, BasicBlock *BB,
 /// process.  If so, return the block that is exited to, otherwise return null.
 static BasicBlock *isTrivialLoopExitBlock(Loop *L, BasicBlock *BB) {
   std::set<BasicBlock*> Visited;
-  Visited.insert(L->getHeader());  // Branches to header are ok.
+  Visited.insert(L->getHeader());  // Branches to header make infinite loops.
   BasicBlock *ExitBB = 0;
   if (isTrivialLoopExitBlockHelper(L, BB, ExitBB, Visited))
     return ExitBB;
@@ -356,8 +360,8 @@ bool LoopUnswitch::IsTrivialUnswitchCondition(Value *Cond, Constant **Val,
     if (!BI->isConditional() || BI->getCondition() != Cond)
       return false;
   
-    // Check to see if a successor of the branch is guaranteed to go to the
-    // latch block or exit through a one exit block without having any 
+    // Check to see if a successor of the branch is guaranteed to 
+    // exit through a unique exit block without having any 
     // side-effects.  If so, determine the value of Cond that causes it to do
     // this.
     if ((LoopExitBB = isTrivialLoopExitBlock(currentLoop, 
@@ -445,7 +449,7 @@ bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val) {
   // This is a very ad-hoc heuristic.
   if (Metrics.NumInsts > Threshold ||
       Metrics.NumBlocks * 5 > Threshold ||
-      Metrics.NeverInline) {
+      Metrics.containsIndirectBr || Metrics.isRecursive) {
     DEBUG(dbgs() << "NOT unswitching loop %"
           << currentLoop->getHeader()->getName() << ", cost too high: "
           << currentLoop->getBlocks().size() << "\n");
@@ -456,22 +460,9 @@ bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val) {
   return true;
 }
 
-// RemapInstruction - Convert the instruction operands from referencing the
-// current values into those specified by ValueMap.
-//
-static inline void RemapInstruction(Instruction *I,
-                                    DenseMap<const Value *, Value*> &ValueMap) {
-  for (unsigned op = 0, E = I->getNumOperands(); op != E; ++op) {
-    Value *Op = I->getOperand(op);
-    DenseMap<const Value *, Value*>::iterator It = ValueMap.find(Op);
-    if (It != ValueMap.end()) Op = It->second;
-    I->setOperand(op, Op);
-  }
-}
-
 /// CloneLoop - Recursively clone the specified loop and all of its children,
 /// mapping the blocks with the specified map.
-static Loop *CloneLoop(Loop *L, Loop *PL, DenseMap<const Value*, Value*> &VM,
+static Loop *CloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM,
                        LoopInfo *LI, LPPassManager *LPM) {
   Loop *New = new Loop();
   LPM->insertLoop(New, PL);
@@ -570,6 +561,8 @@ void LoopUnswitch::SplitExitEdges(Loop *L,
     BasicBlock *ExitBlock = ExitBlocks[i];
     SmallVector<BasicBlock *, 4> Preds(pred_begin(ExitBlock),
                                        pred_end(ExitBlock));
+    // Although SplitBlockPredecessors doesn't preserve loop-simplify in
+    // general, if we call it on all predecessors of all exits then it does.
     SplitBlockPredecessors(ExitBlock, Preds.data(), Preds.size(),
                            ".us-lcssa", this);
   }
@@ -586,6 +579,9 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
         << " blocks] in Function " << F->getName()
         << " when '" << *Val << "' == " << *LIC << "\n");
 
+  if (ScalarEvolution *SE = getAnalysisIfAvailable<ScalarEvolution>())
+    SE->forgetLoop(L);
+
   LoopBlocks.clear();
   NewBlocks.clear();
 
@@ -615,11 +611,11 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
   // the loop preheader and exit blocks), keeping track of the mapping between
   // the instructions and blocks.
   NewBlocks.reserve(LoopBlocks.size());
-  DenseMap<const Value*, Value*> ValueMap;
+  ValueToValueMapTy VMap;
   for (unsigned i = 0, e = LoopBlocks.size(); i != e; ++i) {
-    BasicBlock *NewBB = CloneBasicBlock(LoopBlocks[i], ValueMap, ".us", F);
+    BasicBlock *NewBB = CloneBasicBlock(LoopBlocks[i], VMap, ".us", F);
     NewBlocks.push_back(NewBB);
-    ValueMap[LoopBlocks[i]] = NewBB;  // Keep the BB mapping.
+    VMap[LoopBlocks[i]] = NewBB;  // Keep the BB mapping.
     LPM->cloneBasicBlockSimpleAnalysis(LoopBlocks[i], NewBB, L);
   }
 
@@ -629,7 +625,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
                                 NewBlocks[0], F->end());
 
   // Now we create the new Loop object for the versioned loop.
-  Loop *NewLoop = CloneLoop(L, L->getParentLoop(), ValueMap, LI, LPM);
+  Loop *NewLoop = CloneLoop(L, L->getParentLoop(), VMap, LI, LPM);
   Loop *ParentLoop = L->getParentLoop();
   if (ParentLoop) {
     // Make sure to add the cloned preheader and exit blocks to the parent loop
@@ -638,7 +634,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
   }
   
   for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) {
-    BasicBlock *NewExit = cast<BasicBlock>(ValueMap[ExitBlocks[i]]);
+    BasicBlock *NewExit = cast<BasicBlock>(VMap[ExitBlocks[i]]);
     // The new exit block should be in the same loop as the old one.
     if (Loop *ExitBBLoop = LI->getLoopFor(ExitBlocks[i]))
       ExitBBLoop->addBasicBlockToLoop(NewExit, LI->getBase());
@@ -653,8 +649,8 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
     for (BasicBlock::iterator I = ExitSucc->begin(); isa<PHINode>(I); ++I) {
       PN = cast<PHINode>(I);
       Value *V = PN->getIncomingValueForBlock(ExitBlocks[i]);
-      DenseMap<const Value *, Value*>::iterator It = ValueMap.find(V);
-      if (It != ValueMap.end()) V = It->second;
+      ValueToValueMapTy::iterator It = VMap.find(V);
+      if (It != VMap.end()) V = It->second;
       PN->addIncoming(V, NewExit);
     }
   }
@@ -663,7 +659,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
   for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i)
     for (BasicBlock::iterator I = NewBlocks[i]->begin(),
            E = NewBlocks[i]->end(); I != E; ++I)
-      RemapInstruction(I, ValueMap);
+      RemapInstruction(I, VMap,RF_NoModuleLevelChanges|RF_IgnoreMissingEntries);
   
   // Rewrite the original preheader to select between versions of the loop.
   BranchInst *OldBR = cast<BranchInst>(loopPreheader->getTerminator());
@@ -793,8 +789,13 @@ void LoopUnswitch::RemoveBlockIfDead(BasicBlock *BB,
   // If this is the edge to the header block for a loop, remove the loop and
   // promote all subloops.
   if (Loop *BBLoop = LI->getLoopFor(BB)) {
-    if (BBLoop->getLoopLatch() == BB)
+    if (BBLoop->getLoopLatch() == BB) {
       RemoveLoopFromHierarchy(BBLoop);
+      if (currentLoop == BBLoop) {
+        currentLoop = 0;
+        redoLoop = false;
+      }
+    }
   }
 
   // Remove the block from the loop info, which removes it from any loops it
@@ -866,7 +867,6 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
   
   // FOLD boolean conditions (X|LIC), (X&LIC).  Fold conditional branches,
   // selects, switches.
-  std::vector<User*> Users(LIC->use_begin(), LIC->use_end());
   std::vector<Instruction*> Worklist;
   LLVMContext &Context = Val->getContext();
 
@@ -882,13 +882,14 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
       Replacement = ConstantInt::get(Type::getInt1Ty(Val->getContext()), 
                                      !cast<ConstantInt>(Val)->getZExtValue());
     
-    for (unsigned i = 0, e = Users.size(); i != e; ++i)
-      if (Instruction *U = cast<Instruction>(Users[i])) {
-        if (!L->contains(U))
-          continue;
-        U->replaceUsesOfWith(LIC, Replacement);
-        Worklist.push_back(U);
-      }
+    for (Value::use_iterator UI = LIC->use_begin(), E = LIC->use_end();
+         UI != E; ++UI) {
+      Instruction *U = dyn_cast<Instruction>(*UI);
+      if (!U || !L->contains(U))
+        continue;
+      U->replaceUsesOfWith(LIC, Replacement);
+      Worklist.push_back(U);
+    }
     SimplifyCode(Worklist, L);
     return;
   }
@@ -896,9 +897,10 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
   // Otherwise, we don't know the precise value of LIC, but we do know that it
   // is certainly NOT "Val".  As such, simplify any uses in the loop that we
   // can.  This case occurs when we unswitch switch statements.
-  for (unsigned i = 0, e = Users.size(); i != e; ++i) {
-    Instruction *U = cast<Instruction>(Users[i]);
-    if (!L->contains(U))
+  for (Value::use_iterator UI = LIC->use_begin(), E = LIC->use_end();
+       UI != E; ++UI) {
+    Instruction *U = dyn_cast<Instruction>(*UI);
+    if (!U || !L->contains(U))
       continue;
 
     Worklist.push_back(U);
@@ -916,13 +918,22 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
     // Found a dead case value.  Don't remove PHI nodes in the 
     // successor if they become single-entry, those PHI nodes may
     // be in the Users list.
-        
+
+    BasicBlock *Switch = SI->getParent();
+    BasicBlock *SISucc = SI->getSuccessor(DeadCase);
+    BasicBlock *Latch = L->getLoopLatch();
+    if (!SI->findCaseDest(SISucc)) continue;  // Edge is critical.
+    // If the DeadCase successor dominates the loop latch, then the
+    // transformation isn't safe since it will delete the sole predecessor edge
+    // to the latch.
+    if (Latch && DT->dominates(SISucc, Latch))
+      continue;
+
     // FIXME: This is a hack.  We need to keep the successor around
     // and hooked up so as to preserve the loop structure, because
     // trying to update it is complicated.  So instead we preserve the
     // loop structure and put the block on a dead code path.
-    BasicBlock *Switch = SI->getParent();
-    SplitEdge(Switch, SI->getSuccessor(DeadCase), this);
+    SplitEdge(Switch, SISucc, this);
     // Compute the successors instead of relying on the return value
     // of SplitEdge, since it may have split the switch successor
     // after PHI nodes.
@@ -967,13 +978,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {
   while (!Worklist.empty()) {
     Instruction *I = Worklist.back();
     Worklist.pop_back();
-    
-    // Simple constant folding.
-    if (Constant *C = ConstantFoldInstruction(I)) {
-      ReplaceUsesOfWith(I, C, Worklist, L, LPM);
-      continue;
-    }
-    
+
     // Simple DCE.
     if (isInstructionTriviallyDead(I)) {
       DEBUG(dbgs() << "Remove dead instruction '" << *I);
@@ -988,15 +993,16 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {
       ++NumSimplify;
       continue;
     }
-    
+
     // See if instruction simplification can hack this up.  This is common for
     // things like "select false, X, Y" after unswitching made the condition be
     // 'false'.
-    if (Value *V = SimplifyInstruction(I)) {
-      ReplaceUsesOfWith(I, V, Worklist, L, LPM);
-      continue;
-    }
-    
+    if (Value *V = SimplifyInstruction(I, 0, DT))
+      if (LI->replacementPreservesLCSSAForm(I, V)) {
+        ReplaceUsesOfWith(I, V, Worklist, L, LPM);
+        continue;
+      }
+
     // Special case hacks that appear commonly in unswitched code.
     if (BranchInst *BI = dyn_cast<BranchInst>(I)) {
       if (BI->isUnconditional()) {