Don't use PassInfo* as a type identifier for passes. Instead, use the address of...
[oota-llvm.git] / lib / Transforms / IPO / StructRetPromotion.cpp
index 7771bc4b74d0eaa4b69bc4320aa435186137951d..aa470b954d6a5ad53fd635608cecba1238b5d7df 100644 (file)
@@ -34,7 +34,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
-#include "llvm/Support/Compiler.h"
+#include "llvm/Support/raw_ostream.h"
 using namespace llvm;
 
 STATISTIC(NumRejectedSRETUses , "Number of sret rejected due to unexpected uses");
@@ -43,75 +43,79 @@ namespace {
   /// SRETPromotion - This pass removes sret parameter and updates
   /// function to use multiple return value.
   ///
-  struct VISIBILITY_HIDDEN SRETPromotion : public CallGraphSCCPass {
+  struct SRETPromotion : public CallGraphSCCPass {
     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
       CallGraphSCCPass::getAnalysisUsage(AU);
     }
 
-    virtual bool runOnSCC(const std::vector<CallGraphNode *> &SCC);
+    virtual bool runOnSCC(CallGraphSCC &SCC);
     static char ID; // Pass identification, replacement for typeid
-    SRETPromotion() : CallGraphSCCPass(&ID) {}
+    SRETPromotion() : CallGraphSCCPass(ID) {}
 
   private:
-    bool PromoteReturn(CallGraphNode *CGN);
+    CallGraphNode *PromoteReturn(CallGraphNode *CGN);
     bool isSafeToUpdateAllCallers(Function *F);
     Function *cloneFunctionBody(Function *F, const StructType *STy);
-    void updateCallSites(Function *F, Function *NF);
+    CallGraphNode *updateCallSites(Function *F, Function *NF);
     bool nestedStructType(const StructType *STy);
   };
 }
 
 char SRETPromotion::ID = 0;
-static RegisterPass<SRETPromotion>
-X("sretpromotion", "Promote sret arguments to multiple ret values");
+INITIALIZE_PASS(SRETPromotion, "sretpromotion",
+                "Promote sret arguments to multiple ret values", false, false);
 
 Pass *llvm::createStructRetPromotionPass() {
   return new SRETPromotion();
 }
 
-bool SRETPromotion::runOnSCC(const std::vector<CallGraphNode *> &SCC) {
+bool SRETPromotion::runOnSCC(CallGraphSCC &SCC) {
   bool Changed = false;
 
-  for (unsigned i = 0, e = SCC.size(); i != e; ++i)
-    Changed |= PromoteReturn(SCC[i]);
+  for (CallGraphSCC::iterator I = SCC.begin(), E = SCC.end(); I != E; ++I)
+    if (CallGraphNode *NewNode = PromoteReturn(*I)) {
+      SCC.ReplaceNode(*I, NewNode);
+      Changed = true;
+    }
 
   return Changed;
 }
 
 /// PromoteReturn - This method promotes function that uses StructRet paramater 
-/// into a function that uses mulitple return value.
-bool SRETPromotion::PromoteReturn(CallGraphNode *CGN) {
+/// into a function that uses multiple return values.
+CallGraphNode *SRETPromotion::PromoteReturn(CallGraphNode *CGN) {
   Function *F = CGN->getFunction();
 
   if (!F || F->isDeclaration() || !F->hasLocalLinkage())
-    return false;
+    return 0;
 
   // Make sure that function returns struct.
   if (F->arg_size() == 0 || !F->hasStructRetAttr() || F->doesNotReturn())
-    return false;
+    return 0;
 
-  DOUT << "SretPromotion: Looking at sret function " << F->getNameStart() << "\n";
+  DEBUG(dbgs() << "SretPromotion: Looking at sret function " 
+        << F->getName() << "\n");
 
-  assert (F->getReturnType() == Type::VoidTy && "Invalid function return type");
+  assert(F->getReturnType()->isVoidTy() && "Invalid function return type");
   Function::arg_iterator AI = F->arg_begin();
   const llvm::PointerType *FArgType = dyn_cast<PointerType>(AI->getType());
-  assert (FArgType && "Invalid sret parameter type");
+  assert(FArgType && "Invalid sret parameter type");
   const llvm::StructType *STy = 
     dyn_cast<StructType>(FArgType->getElementType());
-  assert (STy && "Invalid sret parameter element type");
+  assert(STy && "Invalid sret parameter element type");
 
   // Check if it is ok to perform this promotion.
   if (isSafeToUpdateAllCallers(F) == false) {
-    DOUT << "SretPromotion: Not all callers can be updated\n";
-    NumRejectedSRETUses++;
-    return false;
+    DEBUG(dbgs() << "SretPromotion: Not all callers can be updated\n");
+    ++NumRejectedSRETUses;
+    return 0;
   }
 
-  DOUT << "SretPromotion: sret argument will be promoted\n";
-  NumSRET++;
+  DEBUG(dbgs() << "SretPromotion: sret argument will be promoted\n");
+  ++NumSRET;
   // [1] Replace use of sret parameter 
-  AllocaInst *TheAlloca = new AllocaInst (*Context, STy, NULL, "mrv", 
-                                          F->getEntryBlock().begin());
+  AllocaInst *TheAlloca = new AllocaInst(STy, NULL, "mrv", 
+                                         F->getEntryBlock().begin());
   Value *NFirstArg = F->arg_begin();
   NFirstArg->replaceAllUsesWith(TheAlloca);
 
@@ -122,7 +126,7 @@ bool SRETPromotion::PromoteReturn(CallGraphNode *CGN) {
       ++BI;
       if (isa<ReturnInst>(I)) {
         Value *NV = new LoadInst(TheAlloca, "mrv.ld", I);
-        ReturnInst *NR = ReturnInst::Create(NV, I);
+        ReturnInst *NR = ReturnInst::Create(F->getContext(), NV, I);
         I->replaceAllUsesWith(NR);
         I->eraseFromParent();
       }
@@ -132,11 +136,13 @@ bool SRETPromotion::PromoteReturn(CallGraphNode *CGN) {
   Function *NF = cloneFunctionBody(F, STy);
 
   // [4] Update all call sites to use new function
-  updateCallSites(F, NF);
+  CallGraphNode *NF_CFN = updateCallSites(F, NF);
 
-  F->eraseFromParent();
-  getAnalysis<CallGraph>().changeFunction(F, NF);
-  return true;
+  CallGraph &CG = getAnalysis<CallGraph>();
+  NF_CFN->stealCalledFunctionsFrom(CG[F]);
+
+  delete CG.removeFunctionFromModule(F);
+  return NF_CFN;
 }
 
 // Check if it is ok to perform this promotion.
@@ -150,7 +156,7 @@ bool SRETPromotion::isSafeToUpdateAllCallers(Function *F) {
        FnUseI != FnUseE; ++FnUseI) {
     // The function is passed in as an argument to (possibly) another function,
     // we can't change it!
-    CallSite CS = CallSite::get(*FnUseI);
+    CallSite CS(*FnUseI);
     Instruction *Call = CS.getInstruction();
     // The function is used by something else than a call or invoke instruction,
     // we can't change it!
@@ -165,23 +171,23 @@ bool SRETPromotion::isSafeToUpdateAllCallers(Function *F) {
     // Check FirstArg's users.
     for (Value::use_iterator ArgI = FirstArg->use_begin(), 
            ArgE = FirstArg->use_end(); ArgI != ArgE; ++ArgI) {
-
+      User *U = *ArgI;
       // If FirstArg user is a CallInst that does not correspond to current
       // call site then this function F is not suitable for sret promotion.
-      if (CallInst *CI = dyn_cast<CallInst>(ArgI)) {
+      if (CallInst *CI = dyn_cast<CallInst>(U)) {
         if (CI != Call)
           return false;
       }
       // If FirstArg user is a GEP whose all users are not LoadInst then
       // this function F is not suitable for sret promotion.
-      else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(ArgI)) {
+      else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
         // TODO : Use dom info and insert PHINodes to collect get results
         // from multiple call sites for this GEP.
         if (GEP->getParent() != Call->getParent())
           return false;
         for (Value::use_iterator GEPI = GEP->use_begin(), GEPE = GEP->use_end();
              GEPI != GEPE; ++GEPI) 
-          if (!isa<LoadInst>(GEPI))
+          if (!isa<LoadInst>(*GEPI))
             return false;
       } 
       // Any other FirstArg users make this function unsuitable for sret 
@@ -230,7 +236,7 @@ Function *SRETPromotion::cloneFunctionBody(Function *F,
     AttributesVec.push_back(AttributeWithIndex::get(~0, attrs));
 
 
-  FunctionType *NFTy = Context->getFunctionType(STy, Params, FTy->isVarArg());
+  FunctionType *NFTy = FunctionType::get(STy, Params, FTy->isVarArg());
   Function *NF = Function::Create(NFTy, F->getLinkage());
   NF->takeName(F);
   NF->copyAttributesFrom(F);
@@ -244,25 +250,28 @@ Function *SRETPromotion::cloneFunctionBody(Function *F,
   Function::arg_iterator NI = NF->arg_begin();
   ++I;
   while (I != E) {
-      I->replaceAllUsesWith(NI);
-      NI->takeName(I);
-      ++I;
-      ++NI;
+    I->replaceAllUsesWith(NI);
+    NI->takeName(I);
+    ++I;
+    ++NI;
   }
 
   return NF;
 }
 
 /// updateCallSites - Update all sites that call F to use NF.
-void SRETPromotion::updateCallSites(Function *F, Function *NF) {
+CallGraphNode *SRETPromotion::updateCallSites(Function *F, Function *NF) {
   CallGraph &CG = getAnalysis<CallGraph>();
   SmallVector<Value*, 16> Args;
 
   // Attributes - Keep track of the parameter attributes for the arguments.
   SmallVector<AttributeWithIndex, 8> ArgAttrsVec;
 
+  // Get a new callgraph node for NF.
+  CallGraphNode *NF_CGN = CG.getOrInsertFunction(NF);
+
   while (!F->use_empty()) {
-    CallSite CS = CallSite::get(*F->use_begin());
+    CallSite CS(*F->use_begin());
     Instruction *Call = CS.getInstruction();
 
     const AttrListPtr &PAL = F->getAttributes();
@@ -310,8 +319,10 @@ void SRETPromotion::updateCallSites(Function *F, Function *NF) {
     New->takeName(Call);
 
     // Update the callgraph to know that the callsite has been transformed.
-    CG[Call->getParent()->getParent()]->replaceCallSite(Call, New);
-
+    CallGraphNode *CalleeNode = CG[Call->getParent()->getParent()];
+    CalleeNode->removeCallEdgeFor(Call);
+    CalleeNode->addCalledFunction(New, NF_CGN);
+    
     // Update all users of sret parameter to extract value using extractvalue.
     for (Value::use_iterator UI = FirstCArg->use_begin(), 
            UE = FirstCArg->use_end(); UI != UE; ) {
@@ -319,24 +330,25 @@ void SRETPromotion::updateCallSites(Function *F, Function *NF) {
       CallInst *C2 = dyn_cast<CallInst>(U2);
       if (C2 && (C2 == Call))
         continue;
-      else if (GetElementPtrInst *UGEP = dyn_cast<GetElementPtrInst>(U2)) {
-        ConstantInt *Idx = dyn_cast<ConstantInt>(UGEP->getOperand(2));
-        assert (Idx && "Unexpected getelementptr index!");
-        Value *GR = ExtractValueInst::Create(New, Idx->getZExtValue(),
-                                             "evi", UGEP);
-        while(!UGEP->use_empty()) {
-          // isSafeToUpdateAllCallers has checked that all GEP uses are
-          // LoadInsts
-          LoadInst *L = cast<LoadInst>(*UGEP->use_begin());
-          L->replaceAllUsesWith(GR);
-          L->eraseFromParent();
-        }
-        UGEP->eraseFromParent();
+      
+      GetElementPtrInst *UGEP = cast<GetElementPtrInst>(U2);
+      ConstantInt *Idx = cast<ConstantInt>(UGEP->getOperand(2));
+      Value *GR = ExtractValueInst::Create(New, Idx->getZExtValue(),
+                                           "evi", UGEP);
+      while(!UGEP->use_empty()) {
+        // isSafeToUpdateAllCallers has checked that all GEP uses are
+        // LoadInsts
+        LoadInst *L = cast<LoadInst>(*UGEP->use_begin());
+        L->replaceAllUsesWith(GR);
+        L->eraseFromParent();
       }
-      else assert( 0 && "Unexpected sret parameter use");
+      UGEP->eraseFromParent();
+      continue;
     }
     Call->eraseFromParent();
   }
+  
+  return NF_CGN;
 }
 
 /// nestedStructType - Return true if STy includes any
@@ -345,7 +357,7 @@ bool SRETPromotion::nestedStructType(const StructType *STy) {
   unsigned Num = STy->getNumElements();
   for (unsigned i = 0; i < Num; i++) {
     const Type *Ty = STy->getElementType(i);
-    if (!Ty->isSingleValueType() && Ty != Type::VoidTy)
+    if (!Ty->isSingleValueType() && !Ty->isVoidTy())
       return true;
   }
   return false;