implement unswitching of loops with switch stmts and selects in them
authorChris Lattner <sabre@nondot.org>
Sat, 11 Feb 2006 00:43:37 +0000 (00:43 +0000)
committerChris Lattner <sabre@nondot.org>
Sat, 11 Feb 2006 00:43:37 +0000 (00:43 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@26114 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/LoopUnswitch.cpp

index 15e1d992f85be2307b086e0197f20f20e1c2a3cd..932bfe307877e0db6335bfde5e79bd74eafbdf19 100644 (file)
@@ -66,10 +66,13 @@ namespace {
     }
 
   private:
+    bool UnswitchIfProfitable(Value *LoopCond, Constant *Val,Loop *L);
     unsigned getLoopUnswitchCost(Loop *L, Value *LIC);
-    void VersionLoop(Value *LIC, Loop *L, Loop *&Out1, Loop *&Out2);
+    void VersionLoop(Value *LIC, Constant *OnVal,
+                     Loop *L, Loop *&Out1, Loop *&Out2);
     BasicBlock *SplitEdge(BasicBlock *From, BasicBlock *To);
-    void RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, bool Val);
+    void RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,Constant *Val,
+                                              bool isEqual);
     void UnswitchTrivialCondition(Loop *L, Value *Cond, bool EntersLoopOnCond,
                                   BasicBlock *ExitBlock);
   };
@@ -256,85 +259,86 @@ bool LoopUnswitch::visitLoop(Loop *L) {
   // loop.
   for (Loop::block_iterator I = L->block_begin(), E = L->block_end();
        I != E; ++I) {
+    TerminatorInst *TI = (*I)->getTerminator();
+    if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
+      // If this isn't branching on an invariant condition, we can't unswitch
+      // it.
+      if (BI->isConditional()) {
+        // See if this, or some part of it, is loop invariant.  If so, we can
+        // unswitch on it if we desire.
+        Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), L, Changed);
+        if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantBool::True, L))
+          return true;
+      }      
+    } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
+      Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), L, Changed);
+      if (LoopCond && SI->getNumCases() > 1) {
+        // Find a value to unswitch on:
+        // FIXME: this should chose the most expensive case!
+        Constant *UnswitchVal = SI->getCaseValue(1);
+        if (UnswitchIfProfitable(LoopCond, UnswitchVal, L))
+          return true;
+      }
+    }
+    
+    // Scan the instructions to check for unswitchable values.
     for (BasicBlock::iterator BBI = (*I)->begin(), E = (*I)->end(); 
          BBI != E; ++BBI)
       if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) {
         Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), L, Changed);
-        if (LoopCond == 0) continue;
-        
-        //if (UnswitchIfProfitable(LoopCond, 
-        std::cerr << "LOOP INVARIANT SELECT: " << *SI;
+        if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantBool::True, L))
+          return true;
       }
-        
-    TerminatorInst *TI = (*I)->getTerminator();
-    // FIXME: Handle invariant select instructions.
-    
-    if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
-      if (!isa<Constant>(SI) && L->isLoopInvariant(SI->getCondition()))
-        DEBUG(std::cerr << "TODO: Implement unswitching 'switch' loop %"
-              << L->getHeader()->getName() << ", cost = "
-              << L->getBlocks().size() << "\n" << **I);
-      continue;
-    }
-    
-    BranchInst *BI = dyn_cast<BranchInst>(TI);
-    if (!BI) continue;
-    
-    // If this isn't branching on an invariant condition, we can't unswitch it.
-    if (!BI->isConditional())
-      continue;
-    
-    // See if this, or some part of it, is loop invariant.  If so, we can
-    // unswitch on it if we desire.
-    Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), L, Changed);
-    if (LoopCond == 0) continue;
+  }
     
-    // Check to see if it would be profitable to unswitch this loop.
-    if (getLoopUnswitchCost(L, LoopCond) > Threshold) {
-      // FIXME: this should estimate growth by the amount of code shared by the
-      // resultant unswitched loops.  This should have no code growth:
-      //    for () { if (iv) {...} }
-      // as one copy of the loop will be empty.
-      //
-      DEBUG(std::cerr << "NOT unswitching loop %"
-            << L->getHeader()->getName() << ", cost too high: "
-            << L->getBlocks().size() << "\n");
-      continue;
-    }
+  return Changed;
+}
+
+/// UnswitchIfProfitable - We have found that we can unswitch L when
+/// LoopCond == Val to simplify the loop.  If we decide that this is profitable,
+/// unswitch the loop, reprocess the pieces, then return true.
+bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val,Loop *L){
+  // Check to see if it would be profitable to unswitch this loop.
+  if (getLoopUnswitchCost(L, LoopCond) > Threshold) {
+    // FIXME: this should estimate growth by the amount of code shared by the
+    // resultant unswitched loops.
+    //
+    DEBUG(std::cerr << "NOT unswitching loop %"
+                    << L->getHeader()->getName() << ", cost too high: "
+                    << L->getBlocks().size() << "\n");
+    return false;
+  }
     
-    // If this loop has live-out values, we can't unswitch it. We need something
-    // like loop-closed SSA form in order to know how to insert PHI nodes for
-    // these values.
-    if (LoopValuesUsedOutsideLoop(L)) {
-      DEBUG(std::cerr << "NOT unswitching loop %"
-                      << L->getHeader()->getName()
-                      << ", a loop value is used outside loop!\n");
-      continue;
-    }
+  // If this loop has live-out values, we can't unswitch it. We need something
+  // like loop-closed SSA form in order to know how to insert PHI nodes for
+  // these values.
+  if (LoopValuesUsedOutsideLoop(L)) {
+    DEBUG(std::cerr << "NOT unswitching loop %" << L->getHeader()->getName()
+                    << ", a loop value is used outside loop!\n");
+    return false;
+  }
       
-    //std::cerr << "BEFORE:\n"; LI->dump();
-    Loop *NewLoop1 = 0, *NewLoop2 = 0;
+  //std::cerr << "BEFORE:\n"; LI->dump();
+  Loop *NewLoop1 = 0, *NewLoop2 = 0;
  
-    // If this is a trivial condition to unswitch (which results in no code
-    // duplication), do it now.
-    bool EntersLoopOnCond;
-    BasicBlock *ExitBlock;
-    if (IsTrivialUnswitchCondition(L, LoopCond, &EntersLoopOnCond, &ExitBlock)){
-      UnswitchTrivialCondition(L, LoopCond, EntersLoopOnCond, ExitBlock);
-      NewLoop1 = L;
-    } else {
-      VersionLoop(LoopCond, L, NewLoop1, NewLoop2);
-    }
-    
-    //std::cerr << "AFTER:\n"; LI->dump();
-    
-    // Try to unswitch each of our new loops now!
-    if (NewLoop1) visitLoop(NewLoop1);
-    if (NewLoop2) visitLoop(NewLoop2);
-    return true;
+  // If this is a trivial condition to unswitch (which results in no code
+  // duplication), do it now.
+  bool EntersLoopOnCond;
+  BasicBlock *ExitBlock;
+  if (IsTrivialUnswitchCondition(L, LoopCond, &EntersLoopOnCond, &ExitBlock)){
+    UnswitchTrivialCondition(L, LoopCond, EntersLoopOnCond, ExitBlock);
+    NewLoop1 = L;
+  } else {
+    VersionLoop(LoopCond, Val, L, NewLoop1, NewLoop2);
   }
-
-  return Changed;
+  ++NumUnswitched;
+  
+  //std::cerr << "AFTER:\n"; LI->dump();
+  
+  // Try to unswitch each of our new loops now!
+  if (NewLoop1) visitLoop(NewLoop1);
+  if (NewLoop2) visitLoop(NewLoop2);
+  return true;
 }
 
 BasicBlock *LoopUnswitch::SplitEdge(BasicBlock *BB, BasicBlock *Succ) {
@@ -456,23 +460,22 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond,
   // Now that we know that the loop is never entered when this condition is a
   // particular value, rewrite the loop with this info.  We know that this will
   // at least eliminate the old branch.
-  RewriteLoopBodyWithConditionConstant(L, Cond, EnterOnCond);
-  
-  ++NumUnswitched;
+  RewriteLoopBodyWithConditionConstant(L, Cond, ConstantBool::get(EnterOnCond),
+                                       true);
 }
 
 
-/// VersionLoop - We determined that the loop is profitable to unswitch and
-/// contains a branch on a loop invariant condition.  Split it into loop
-/// versions and test the condition outside of either loop.  Return the loops
-/// created as Out1/Out2.
-void LoopUnswitch::VersionLoop(Value *LIC, Loop *L, Loop *&Out1, Loop *&Out2) {
+/// VersionLoop - We determined that the loop is profitable to unswitch when LIC
+/// equal Val.  Split it into loop versions and test the condition outside of
+/// either loop.  Return the loops created as Out1/Out2.
+void LoopUnswitch::VersionLoop(Value *LIC, Constant *Val, Loop *L,
+                               Loop *&Out1, Loop *&Out2) {
   Function *F = L->getHeader()->getParent();
   
   DEBUG(std::cerr << "loop-unswitch: Unswitching loop %"
-        << L->getHeader()->getName() << " [" << L->getBlocks().size()
-        << " blocks] in Function " << F->getName()
-        << " on cond:" << *LIC << "\n");
+                  << L->getHeader()->getName() << " [" << L->getBlocks().size()
+                  << " blocks] in Function " << F->getName()
+                  << " when '" << *Val << "' == " << *LIC << "\n");
 
   // LoopBlocks contains all of the basic blocks of the loop, including the
   // preheader of the loop, the body of the loop, and the exit blocks of the 
@@ -572,41 +575,79 @@ void LoopUnswitch::VersionLoop(Value *LIC, Loop *L, Loop *&Out1, Loop *&Out2) {
          cast<BranchInst>(OrigPreheader->getTerminator())->isUnconditional() &&
          OrigPreheader->getTerminator()->getSuccessor(0) == LoopBlocks[0] &&
          "Preheader splitting did not work correctly!");
-  // Remove the unconditional branch to LoopBlocks[0].
-  OrigPreheader->getInstList().pop_back();
 
   // Insert a conditional branch on LIC to the two preheaders.  The original
   // code is the true version and the new code is the false version.
-  new BranchInst(LoopBlocks[0], NewBlocks[0], LIC, OrigPreheader);
+  Value *BranchVal = LIC;
+  if (!isa<ConstantBool>(BranchVal)) {
+    BranchVal = BinaryOperator::createSetEQ(LIC, Val, "tmp",
+                                            OrigPreheader->getTerminator());
+  } else if (Val != ConstantBool::True) {
+    // We want to enter the new loop when the condition is true.
+    BranchVal = BinaryOperator::createNot(BranchVal, "tmp",
+                                          OrigPreheader->getTerminator());
+  }
+  
+  // Remove the unconditional branch to LoopBlocks[0] and insert the new branch.
+  OrigPreheader->getInstList().pop_back();
+  new BranchInst(NewBlocks[0], LoopBlocks[0], BranchVal, OrigPreheader);
 
   // Now we rewrite the original code to know that the condition is true and the
   // new code to know that the condition is false.
-  RewriteLoopBodyWithConditionConstant(L, LIC, true);
-  RewriteLoopBodyWithConditionConstant(NewLoop, LIC, false);
-  ++NumUnswitched;
+  RewriteLoopBodyWithConditionConstant(L, LIC, Val, false);
+  RewriteLoopBodyWithConditionConstant(NewLoop, LIC, Val, true);
   Out1 = L;
   Out2 = NewLoop;
 }
 
-// RewriteLoopBodyWithConditionConstant - We know that the boolean value LIC has
-// the value specified by Val in the specified loop.  Rewrite any uses of LIC or
-// of properties correlated to it.
+// RewriteLoopBodyWithConditionConstant - We know either that the value LIC has
+// the value specified by Val in the specified loop, or we know it does NOT have
+// that value.  Rewrite any uses of LIC or of properties correlated to it.
 void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
-                                                        bool Val) {
+                                                        Constant *Val,
+                                                        bool IsEqual) {
   assert(!isa<Constant>(LIC) && "Why are we unswitching on a constant?");
+  
   // FIXME: Support correlated properties, like:
   //  for (...)
   //    if (li1 < li2)
   //      ...
   //    if (li1 > li2)
   //      ...
-  ConstantBool *BoolVal = ConstantBool::get(Val);
 
+  // NotVal - If Val is a bool, this contains its inverse.
+  Constant *NotVal = 0;
+  if (ConstantBool *CB = dyn_cast<ConstantBool>(Val))
+    NotVal = ConstantBool::get(!CB->getValue());
+  
   // FOLD boolean conditions (X|LIC), (X&LIC).  Fold conditional branches,
   // selects, switches.
   std::vector<User*> Users(LIC->use_begin(), LIC->use_end());
+  
+  // Haha, this loop could be unswitched.  Get it? The unswitch pass could
+  // unswitch itself. Amazing.
   for (unsigned i = 0, e = Users.size(); i != e; ++i)
     if (Instruction *U = cast<Instruction>(Users[i]))
       if (L->contains(U->getParent()))
-        U->replaceUsesOfWith(LIC, BoolVal);
+        if (IsEqual) {
+          U->replaceUsesOfWith(LIC, Val);
+        } else if (NotVal) {
+          U->replaceUsesOfWith(LIC, NotVal);
+        } else {
+          // If we know that LIC is not Val, use this info to simplify code.
+          if (SwitchInst *SI = dyn_cast<SwitchInst>(U)) {
+            for (unsigned i = 1, e = SI->getNumCases(); i != e; ++i) {
+              if (SI->getCaseValue(i) == Val) {
+                // Found a dead case value.  Don't remove PHI nodes in the 
+                // successor if they become single-entry, those PHI nodes may
+                // be in the Users list.
+                SI->getSuccessor(i)->removePredecessor(SI->getParent(), true);
+                SI->removeCase(i);
+                break;
+              }
+            }
+          }
+
+          // TODO: We could simplify stuff like X == C.
+        }
 }