The inliner/cloner can now optionally take TargetData info, which can be
[oota-llvm.git] / lib / Transforms / Utils / CloneFunction.cpp
index 623da058db971f8002101447ece86a40d8dcca19..12dff3bf588a870f7d17638827969055bf68b844 100644 (file)
 #include "llvm/DerivedTypes.h"
 #include "llvm/Instructions.h"
 #include "llvm/Function.h"
+#include "llvm/Support/CFG.h"
 #include "ValueMapper.h"
 #include "llvm/Transforms/Utils/Local.h"
+#include "llvm/ADT/SmallVector.h"
 using namespace llvm;
 
 // CloneBasicBlock - See comments in Cloning.h
@@ -156,15 +158,17 @@ namespace {
     std::vector<ReturnInst*> &Returns;
     const char *NameSuffix;
     ClonedCodeInfo *CodeInfo;
+    const TargetData *TD;
 
   public:
     PruningFunctionCloner(Function *newFunc, const Function *oldFunc,
                           std::map<const Value*, Value*> &valueMap,
                           std::vector<ReturnInst*> &returns,
                           const char *nameSuffix, 
-                          ClonedCodeInfo *codeInfo)
+                          ClonedCodeInfo *codeInfo,
+                          const TargetData *td)
     : NewFunc(newFunc), OldFunc(oldFunc), ValueMap(valueMap), Returns(returns),
-      NameSuffix(nameSuffix), CodeInfo(codeInfo) {
+      NameSuffix(nameSuffix), CodeInfo(codeInfo), TD(td) {
     }
 
     /// CloneBlock - The specified block is found to be reachable, clone it and
@@ -195,7 +199,7 @@ void PruningFunctionCloner::CloneBlock(const BasicBlock *BB) {
   
   // Loop over all instructions, and copy them over, DCE'ing as we go.  This
   // loop doesn't include the terminator.
-  for (BasicBlock::const_iterator II = BB->begin(), IE = BB->end();
+  for (BasicBlock::const_iterator II = BB->begin(), IE = --BB->end();
        II != IE; ++II) {
     // If this instruction constant folds, don't bother cloning the instruction,
     // instead, just add the constant to the value map.
@@ -219,9 +223,54 @@ void PruningFunctionCloner::CloneBlock(const BasicBlock *BB) {
     }
   }
   
+  // Finally, clone over the terminator.
+  const TerminatorInst *OldTI = BB->getTerminator();
+  bool TerminatorDone = false;
+  if (const BranchInst *BI = dyn_cast<BranchInst>(OldTI)) {
+    if (BI->isConditional()) {
+      // If the condition was a known constant in the callee...
+      ConstantInt *Cond = dyn_cast<ConstantInt>(BI->getCondition());
+      // Or is a known constant in the caller...
+      if (Cond == 0)  
+        Cond = dyn_cast_or_null<ConstantInt>(ValueMap[BI->getCondition()]);
+
+      // Constant fold to uncond branch!
+      if (Cond) {
+        BasicBlock *Dest = BI->getSuccessor(!Cond->getZExtValue());
+        ValueMap[OldTI] = new BranchInst(Dest, NewBB);
+        CloneBlock(Dest);
+        TerminatorDone = true;
+      }
+    }
+  } else if (const SwitchInst *SI = dyn_cast<SwitchInst>(OldTI)) {
+    // If switching on a value known constant in the caller.
+    ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition());
+    if (Cond == 0)  // Or known constant after constant prop in the callee...
+      Cond = dyn_cast_or_null<ConstantInt>(ValueMap[SI->getCondition()]);
+    if (Cond) {     // Constant fold to uncond branch!
+      BasicBlock *Dest = SI->getSuccessor(SI->findCaseValue(Cond));
+      ValueMap[OldTI] = new BranchInst(Dest, NewBB);
+      CloneBlock(Dest);
+      TerminatorDone = true;
+    }
+  }
+  
+  if (!TerminatorDone) {
+    Instruction *NewInst = OldTI->clone();
+    if (OldTI->hasName())
+      NewInst->setName(OldTI->getName()+NameSuffix);
+    NewBB->getInstList().push_back(NewInst);
+    ValueMap[OldTI] = NewInst;             // Add instruction map to value.
+    
+    // Recursively clone any reachable successor blocks.
+    const TerminatorInst *TI = BB->getTerminator();
+    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
+      CloneBlock(TI->getSuccessor(i));
+  }
+  
   if (CodeInfo) {
     CodeInfo->ContainsCalls          |= hasCalls;
-    CodeInfo->ContainsUnwinds        |= isa<UnwindInst>(BB->getTerminator());
+    CodeInfo->ContainsUnwinds        |= isa<UnwindInst>(OldTI);
     CodeInfo->ContainsDynamicAllocas |= hasDynamicAllocas;
     CodeInfo->ContainsDynamicAllocas |= hasStaticAllocas && 
       BB != &BB->getParent()->front();
@@ -229,27 +278,13 @@ void PruningFunctionCloner::CloneBlock(const BasicBlock *BB) {
   
   if (ReturnInst *RI = dyn_cast<ReturnInst>(NewBB->getTerminator()))
     Returns.push_back(RI);
-  
-  // Recursively clone any reachable successor blocks.
-  const TerminatorInst *TI = BB->getTerminator();
-  for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
-    CloneBlock(TI->getSuccessor(i));
 }
 
 /// ConstantFoldMappedInstruction - Constant fold the specified instruction,
 /// mapping its operands through ValueMap if they are available.
 Constant *PruningFunctionCloner::
 ConstantFoldMappedInstruction(const Instruction *I) {
-  if (isa<BinaryOperator>(I) || isa<ShiftInst>(I)) {
-    if (Constant *Op0 = dyn_cast_or_null<Constant>(MapValue(I->getOperand(0),
-                                                            ValueMap)))
-      if (Constant *Op1 = dyn_cast_or_null<Constant>(MapValue(I->getOperand(1),
-                                                              ValueMap)))
-        return ConstantExpr::get(I->getOpcode(), Op0, Op1);
-    return 0;
-  }
-
-  std::vector<Constant*> Ops;
+  SmallVector<Constant*, 8> Ops;
   for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
     if (Constant *Op = dyn_cast_or_null<Constant>(MapValue(I->getOperand(i),
                                                            ValueMap)))
@@ -257,10 +292,9 @@ ConstantFoldMappedInstruction(const Instruction *I) {
     else
       return 0;  // All operands not constant!
 
-  return ConstantFoldInstOperands(I->getOpcode(), I->getType(), Ops);
+  return ConstantFoldInstOperands(I, &Ops[0], Ops.size(), TD);
 }
 
-
 /// CloneAndPruneFunctionInto - This works exactly like CloneFunctionInto,
 /// except that it does some simple constant prop and DCE on the fly.  The
 /// effect of this is to copy significantly less code in cases where (for
@@ -272,17 +306,18 @@ void llvm::CloneAndPruneFunctionInto(Function *NewFunc, const Function *OldFunc,
                                      std::map<const Value*, Value*> &ValueMap,
                                      std::vector<ReturnInst*> &Returns,
                                      const char *NameSuffix, 
-                                     ClonedCodeInfo *CodeInfo) {
+                                     ClonedCodeInfo *CodeInfo,
+                                     const TargetData *TD) {
   assert(NameSuffix && "NameSuffix cannot be null!");
   
 #ifndef NDEBUG
-  for (Function::const_arg_iterator I = OldFunc->arg_begin(), 
-       E = OldFunc->arg_end(); I != E; ++I)
-    assert(ValueMap.count(I) && "No mapping from source argument specified!");
+  for (Function::const_arg_iterator II = OldFunc->arg_begin(), 
+       E = OldFunc->arg_end(); II != E; ++II)
+    assert(ValueMap.count(II) && "No mapping from source argument specified!");
 #endif
   
   PruningFunctionCloner PFC(NewFunc, OldFunc, ValueMap, Returns, 
-                            NameSuffix, CodeInfo);
+                            NameSuffix, CodeInfo, TD);
 
   // Clone the entry block, and anything recursively reachable from it.
   PFC.CloneBlock(&OldFunc->getEntryBlock());
@@ -291,11 +326,13 @@ void llvm::CloneAndPruneFunctionInto(Function *NewFunc, const Function *OldFunc,
   // reachable, we have cloned it and the old block is now in the value map:
   // insert it into the new function in the right order.  If not, ignore it.
   //
+  // Defer PHI resolution until rest of function is resolved.
+  std::vector<const PHINode*> PHIToResolve;
   for (Function::const_iterator BI = OldFunc->begin(), BE = OldFunc->end();
        BI != BE; ++BI) {
     BasicBlock *NewBB = cast_or_null<BasicBlock>(ValueMap[BI]);
     if (NewBB == 0) continue;  // Dead block.
-    
+
     // Add the new block to the new function.
     NewFunc->getBasicBlockList().push_back(NewBB);
     
@@ -307,27 +344,132 @@ void llvm::CloneAndPruneFunctionInto(Function *NewFunc, const Function *OldFunc,
     // Handle PHI nodes specially, as we have to remove references to dead
     // blocks.
     if (PHINode *PN = dyn_cast<PHINode>(I)) {
-      unsigned NumPreds = PN->getNumIncomingValues();
-      for (; (PN = dyn_cast<PHINode>(I)); ++I) {
-        for (unsigned pred = 0, e = NumPreds; pred != e; ++pred) {
-          if (BasicBlock *MappedBlock = 
-                   cast<BasicBlock>(ValueMap[PN->getIncomingBlock(pred)])) {
-            Value *InVal = MapValue(PN->getIncomingValue(pred), ValueMap);
-            assert(InVal && "Unknown input value?");
-            PN->setIncomingValue(pred, InVal);
-            PN->setIncomingBlock(pred, MappedBlock);
-          } else {
-            PN->removeIncomingValue(pred, false);
-            --pred, --e;  // Revisit the next entry.
-          }
-        }
-      }
+      // Skip over all PHI nodes, remembering them for later.
+      BasicBlock::const_iterator OldI = BI->begin();
+      for (; (PN = dyn_cast<PHINode>(I)); ++I, ++OldI)
+        PHIToResolve.push_back(cast<PHINode>(OldI));
     }
     
     // Otherwise, remap the rest of the instructions normally.
     for (; I != NewBB->end(); ++I)
       RemapInstruction(I, ValueMap);
   }
-}
+  
+  // Defer PHI resolution until rest of function is resolved, PHI resolution
+  // requires the CFG to be up-to-date.
+  for (unsigned phino = 0, e = PHIToResolve.size(); phino != e; ) {
+    const PHINode *OPN = PHIToResolve[phino];
+    unsigned NumPreds = OPN->getNumIncomingValues();
+    const BasicBlock *OldBB = OPN->getParent();
+    BasicBlock *NewBB = cast<BasicBlock>(ValueMap[OldBB]);
 
+    // Map operands for blocks that are live and remove operands for blocks
+    // that are dead.
+    for (; phino != PHIToResolve.size() &&
+         PHIToResolve[phino]->getParent() == OldBB; ++phino) {
+      OPN = PHIToResolve[phino];
+      PHINode *PN = cast<PHINode>(ValueMap[OPN]);
+      for (unsigned pred = 0, e = NumPreds; pred != e; ++pred) {
+        if (BasicBlock *MappedBlock = 
+            cast_or_null<BasicBlock>(ValueMap[PN->getIncomingBlock(pred)])) {
+          Value *InVal = MapValue(PN->getIncomingValue(pred), ValueMap);
+          assert(InVal && "Unknown input value?");
+          PN->setIncomingValue(pred, InVal);
+          PN->setIncomingBlock(pred, MappedBlock);
+        } else {
+          PN->removeIncomingValue(pred, false);
+          --pred, --e;  // Revisit the next entry.
+        }
+      } 
+    }
+    
+    // The loop above has removed PHI entries for those blocks that are dead
+    // and has updated others.  However, if a block is live (i.e. copied over)
+    // but its terminator has been changed to not go to this block, then our
+    // phi nodes will have invalid entries.  Update the PHI nodes in this
+    // case.
+    PHINode *PN = cast<PHINode>(NewBB->begin());
+    NumPreds = std::distance(pred_begin(NewBB), pred_end(NewBB));
+    if (NumPreds != PN->getNumIncomingValues()) {
+      assert(NumPreds < PN->getNumIncomingValues());
+      // Count how many times each predecessor comes to this block.
+      std::map<BasicBlock*, unsigned> PredCount;
+      for (pred_iterator PI = pred_begin(NewBB), E = pred_end(NewBB);
+           PI != E; ++PI)
+        --PredCount[*PI];
+      
+      // Figure out how many entries to remove from each PHI.
+      for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
+        ++PredCount[PN->getIncomingBlock(i)];
+      
+      // At this point, the excess predecessor entries are positive in the
+      // map.  Loop over all of the PHIs and remove excess predecessor
+      // entries.
+      BasicBlock::iterator I = NewBB->begin();
+      for (; (PN = dyn_cast<PHINode>(I)); ++I) {
+        for (std::map<BasicBlock*, unsigned>::iterator PCI =PredCount.begin(),
+             E = PredCount.end(); PCI != E; ++PCI) {
+          BasicBlock *Pred     = PCI->first;
+          for (unsigned NumToRemove = PCI->second; NumToRemove; --NumToRemove)
+            PN->removeIncomingValue(Pred, false);
+        }
+      }
+    }
+    
+    // If the loops above have made these phi nodes have 0 or 1 operand,
+    // replace them with undef or the input value.  We must do this for
+    // correctness, because 0-operand phis are not valid.
+    PN = cast<PHINode>(NewBB->begin());
+    if (PN->getNumIncomingValues() == 0) {
+      BasicBlock::iterator I = NewBB->begin();
+      BasicBlock::const_iterator OldI = OldBB->begin();
+      while ((PN = dyn_cast<PHINode>(I++))) {
+        Value *NV = UndefValue::get(PN->getType());
+        PN->replaceAllUsesWith(NV);
+        assert(ValueMap[OldI] == PN && "ValueMap mismatch");
+        ValueMap[OldI] = NV;
+        PN->eraseFromParent();
+        ++OldI;
+      }
+    } else if (PN->getNumIncomingValues() == 1) {
+      BasicBlock::iterator I = NewBB->begin();
+      BasicBlock::const_iterator OldI = OldBB->begin();
+      while ((PN = dyn_cast<PHINode>(I++))) {
+        Value *NV = PN->getIncomingValue(0);
+        PN->replaceAllUsesWith(NV);
+        assert(ValueMap[OldI] == PN && "ValueMap mismatch");
+        ValueMap[OldI] = NV;
+        PN->eraseFromParent();
+        ++OldI;
+      }
+    }
+  }
+  
+  // Now that the inlined function body has been fully constructed, go through
+  // and zap unconditional fall-through branches.  This happen all the time when
+  // specializing code: code specialization turns conditional branches into
+  // uncond branches, and this code folds them.
+  Function::iterator I = cast<BasicBlock>(ValueMap[&OldFunc->getEntryBlock()]);
+  while (I != NewFunc->end()) {
+    BranchInst *BI = dyn_cast<BranchInst>(I->getTerminator());
+    if (!BI || BI->isConditional()) { ++I; continue; }
+    
+    BasicBlock *Dest = BI->getSuccessor(0);
+    if (!Dest->getSinglePredecessor()) { ++I; continue; }
+    
+    // We know all single-entry PHI nodes in the inlined function have been
+    // removed, so we just need to splice the blocks.
+    BI->eraseFromParent();
+    
+    // Move all the instructions in the succ to the pred.
+    I->getInstList().splice(I->end(), Dest->getInstList());
+    
+    // Make all PHI nodes that referred to Dest now refer to I as their source.
+    Dest->replaceAllUsesWith(I);
 
+    // Remove the dest block.
+    Dest->eraseFromParent();
+    
+    // Do not increment I, iteratively merge all things this block branches to.
+  }
+}