add some notes, move some code around. Implement unswitching of loops
authorChris Lattner <sabre@nondot.org>
Fri, 10 Feb 2006 02:30:37 +0000 (02:30 +0000)
committerChris Lattner <sabre@nondot.org>
Fri, 10 Feb 2006 02:30:37 +0000 (02:30 +0000)
with branches on partially invariant computations.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@26104 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/LoopUnswitch.cpp

index 524b0104f26fd38b096ad525efcd4966cc22bbd0..600061de354e51d98edd6cd01de622ddd9afd4d7 100644 (file)
@@ -115,6 +115,27 @@ static bool LoopValuesUsedOutsideLoop(Loop *L) {
   return false;
 }
 
+/// FindTrivialLoopExitBlock - We know that we have a branch from the loop
+/// header to the specified latch block.   See if one of the successors of the
+/// latch block is an exit, and if so what block it is.
+static BasicBlock *FindTrivialLoopExitBlock(Loop *L, BasicBlock *Latch) {
+  BasicBlock *Header = L->getHeader();
+  BranchInst *LatchBranch = dyn_cast<BranchInst>(Latch->getTerminator());
+  if (!LatchBranch || !LatchBranch->isConditional()) return 0;
+  
+  // Simple case, the latch block is a conditional branch.  The target that
+  // doesn't go to the loop header is our block if it is not in the loop.
+  if (LatchBranch->getSuccessor(0) == Header) {
+    if (L->contains(LatchBranch->getSuccessor(1))) return false;
+    return LatchBranch->getSuccessor(1);
+  } else {
+    assert(LatchBranch->getSuccessor(1) == Header);
+    if (L->contains(LatchBranch->getSuccessor(0))) return false;
+    return LatchBranch->getSuccessor(0);
+  }
+}
+
+
 /// IsTrivialUnswitchCondition - Check to see if this unswitch condition is
 /// trivial: that is, that the condition controls whether or not the loop does
 /// anything at all.  If this is a trivial condition, unswitching produces no
@@ -149,17 +170,9 @@ static bool IsTrivialUnswitchCondition(Loop *L, Value *Cond,
   
   // The latch block must end with a conditional branch where one edge goes to
   // the header (this much we know) and one edge goes OUT of the loop.
-  BranchInst *LatchBranch = dyn_cast<BranchInst>(Latch->getTerminator());
-  if (!LatchBranch || !LatchBranch->isConditional()) return false;
-
-  if (LatchBranch->getSuccessor(0) == Header) {
-    if (L->contains(LatchBranch->getSuccessor(1))) return false;
-    if (LoopExit) *LoopExit = LatchBranch->getSuccessor(1);
-  } else {
-    assert(LatchBranch->getSuccessor(1) == Header);
-    if (L->contains(LatchBranch->getSuccessor(0))) return false;
-    if (LoopExit) *LoopExit = LatchBranch->getSuccessor(0);
-  }
+  BasicBlock *LoopExitBlock = FindTrivialLoopExitBlock(L, Latch);
+  if (!LoopExitBlock) return 0;
+  if (LoopExit) *LoopExit = LoopExitBlock;
   
   // We already know that nothing uses any scalar values defined inside of this
   // loop.  As such, we just have to check to see if this loop will execute any
@@ -201,6 +214,32 @@ unsigned LoopUnswitch::getLoopUnswitchCost(Loop *L, Value *LIC) {
   return Cost;
 }
 
+/// FindLIVLoopCondition - Cond is a condition that occurs in L.  If it is
+/// invariant in the loop, or has an invariant piece, return the invariant.
+/// Otherwise, return null.
+static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) {
+  // Constants should be folded, not unswitched on!
+  if (isa<Constant>(Cond)) return false;
+  
+  // TODO: Handle: br (VARIANT|INVARIANT).
+  // TODO: Hoist simple expressions out of loops.
+  if (L->isLoopInvariant(Cond)) return Cond;
+  
+  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond))
+    if (BO->getOpcode() == Instruction::And ||
+        BO->getOpcode() == Instruction::Or) {
+      // If either the left or right side is invariant, we can unswitch on this,
+      // which will cause the branch to go away in one loop and the condition to
+      // simplify in the other one.
+      if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed))
+        return LHS;
+      if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed))
+        return RHS;
+    }
+  
+  return 0;
+}
+
 bool LoopUnswitch::visitLoop(Loop *L) {
   bool Changed = false;
 
@@ -217,6 +256,8 @@ bool LoopUnswitch::visitLoop(Loop *L) {
   for (Loop::block_iterator I = L->block_begin(), E = L->block_end();
        I != E; ++I) {
     TerminatorInst *TI = (*I)->getTerminator();
+    // FIXME: Handle invariant select instructions.
+    
     if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
       if (!isa<Constant>(SI) && L->isLoopInvariant(SI->getCondition()))
         DEBUG(std::cerr << "TODO: Implement unswitching 'switch' loop %"
@@ -229,12 +270,16 @@ bool LoopUnswitch::visitLoop(Loop *L) {
     if (!BI) continue;
     
     // If this isn't branching on an invariant condition, we can't unswitch it.
-    if (!BI->isConditional() || isa<Constant>(BI->getCondition()) ||
-        !L->isLoopInvariant(BI->getCondition()))
+    if (!BI->isConditional())
       continue;
     
+    // See if this, or some part of it, is loop invariant.  If so, we can
+    // unswitch on it if we desire.
+    Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), L, Changed);
+    if (LoopCond == 0) continue;
+    
     // Check to see if it would be profitable to unswitch this loop.
-    if (getLoopUnswitchCost(L, BI->getCondition()) > Threshold) {
+    if (getLoopUnswitchCost(L, LoopCond) > Threshold) {
       // FIXME: this should estimate growth by the amount of code shared by the
       // resultant unswitched loops.  This should have no code growth:
       //    for () { if (iv) {...} }
@@ -263,13 +308,11 @@ bool LoopUnswitch::visitLoop(Loop *L) {
     // duplication), do it now.
     bool EntersLoopOnCond;
     BasicBlock *ExitBlock;
-    if (IsTrivialUnswitchCondition(L, BI->getCondition(), &EntersLoopOnCond,
-                                   &ExitBlock)) {
-      UnswitchTrivialCondition(L, BI->getCondition(), 
-                               EntersLoopOnCond, ExitBlock);
+    if (IsTrivialUnswitchCondition(L, LoopCond, &EntersLoopOnCond, &ExitBlock)){
+      UnswitchTrivialCondition(L, LoopCond, EntersLoopOnCond, ExitBlock);
       NewLoop1 = L;
     } else {
-      VersionLoop(BI->getCondition(), L, NewLoop1, NewLoop2);
+      VersionLoop(LoopCond, L, NewLoop1, NewLoop2);
     }
     
     //std::cerr << "AFTER:\n"; LI->dump();
@@ -489,6 +532,8 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
   //      ...
   ConstantBool *BoolVal = ConstantBool::get(Val);
 
+  // FOLD boolean conditions (X|LIC), (X&LIC).  Fold conditional branches,
+  // selects, switches.
   std::vector<User*> Users(LIC->use_begin(), LIC->use_end());
   for (unsigned i = 0, e = Users.size(); i != e; ++i)
     if (Instruction *U = cast<Instruction>(Users[i]))