Fix some nasty callgraph dangling pointer problems in
[oota-llvm.git] / lib / Transforms / IPO / StructRetPromotion.cpp
index e28fc42ed248fd8526b1573032512b68c8080943..4c4c6d6828d9043131005eeff52f87ea1af0f1ca 100644 (file)
@@ -49,15 +49,15 @@ namespace {
       CallGraphSCCPass::getAnalysisUsage(AU);
     }
 
-    virtual bool runOnSCC(const std::vector<CallGraphNode *> &SCC);
+    virtual bool runOnSCC(std::vector<CallGraphNode *> &SCC);
     static char ID; // Pass identification, replacement for typeid
     SRETPromotion() : CallGraphSCCPass(&ID) {}
 
   private:
-    bool PromoteReturn(CallGraphNode *CGN);
+    CallGraphNode *PromoteReturn(CallGraphNode *CGN);
     bool isSafeToUpdateAllCallers(Function *F);
     Function *cloneFunctionBody(Function *F, const StructType *STy);
-    void updateCallSites(Function *F, Function *NF);
+    CallGraphNode *updateCallSites(Function *F, Function *NF);
     bool nestedStructType(const StructType *STy);
   };
 }
@@ -70,44 +70,47 @@ Pass *llvm::createStructRetPromotionPass() {
   return new SRETPromotion();
 }
 
-bool SRETPromotion::runOnSCC(const std::vector<CallGraphNode *> &SCC) {
+bool SRETPromotion::runOnSCC(std::vector<CallGraphNode *> &SCC) {
   bool Changed = false;
 
   for (unsigned i = 0, e = SCC.size(); i != e; ++i)
-    Changed |= PromoteReturn(SCC[i]);
+    if (CallGraphNode *NewNode = PromoteReturn(SCC[i])) {
+      SCC[i] = NewNode;
+      Changed = true;
+    }
 
   return Changed;
 }
 
 /// PromoteReturn - This method promotes function that uses StructRet paramater 
-/// into a function that uses mulitple return value.
-bool SRETPromotion::PromoteReturn(CallGraphNode *CGN) {
+/// into a function that uses multiple return values.
+CallGraphNode *SRETPromotion::PromoteReturn(CallGraphNode *CGN) {
   Function *F = CGN->getFunction();
 
   if (!F || F->isDeclaration() || !F->hasLocalLinkage())
-    return false;
+    return 0;
 
   // Make sure that function returns struct.
   if (F->arg_size() == 0 || !F->hasStructRetAttr() || F->doesNotReturn())
-    return false;
+    return 0;
 
   DEBUG(errs() << "SretPromotion: Looking at sret function " 
         << F->getName() << "\n");
 
-  assert (F->getReturnType() == Type::getVoidTy(F->getContext()) &&
-          "Invalid function return type");
+  assert(F->getReturnType() == Type::getVoidTy(F->getContext()) &&
+         "Invalid function return type");
   Function::arg_iterator AI = F->arg_begin();
   const llvm::PointerType *FArgType = dyn_cast<PointerType>(AI->getType());
-  assert (FArgType && "Invalid sret parameter type");
+  assert(FArgType && "Invalid sret parameter type");
   const llvm::StructType *STy = 
     dyn_cast<StructType>(FArgType->getElementType());
-  assert (STy && "Invalid sret parameter element type");
+  assert(STy && "Invalid sret parameter element type");
 
   // Check if it is ok to perform this promotion.
   if (isSafeToUpdateAllCallers(F) == false) {
     DEBUG(errs() << "SretPromotion: Not all callers can be updated\n");
     NumRejectedSRETUses++;
-    return false;
+    return 0;
   }
 
   DEBUG(errs() << "SretPromotion: sret argument will be promoted\n");
@@ -135,11 +138,13 @@ bool SRETPromotion::PromoteReturn(CallGraphNode *CGN) {
   Function *NF = cloneFunctionBody(F, STy);
 
   // [4] Update all call sites to use new function
-  updateCallSites(F, NF);
+  CallGraphNode *NF_CFN = updateCallSites(F, NF);
 
-  F->eraseFromParent();
-  getAnalysis<CallGraph>().changeFunction(F, NF);
-  return true;
+  CallGraph &CG = getAnalysis<CallGraph>();
+  NF_CFN->stealCalledFunctionsFrom(CG[F]);
+
+  delete CG.removeFunctionFromModule(F);
+  return NF_CFN;
 }
 
 // Check if it is ok to perform this promotion.
@@ -247,23 +252,26 @@ Function *SRETPromotion::cloneFunctionBody(Function *F,
   Function::arg_iterator NI = NF->arg_begin();
   ++I;
   while (I != E) {
-      I->replaceAllUsesWith(NI);
-      NI->takeName(I);
-      ++I;
-      ++NI;
+    I->replaceAllUsesWith(NI);
+    NI->takeName(I);
+    ++I;
+    ++NI;
   }
 
   return NF;
 }
 
 /// updateCallSites - Update all sites that call F to use NF.
-void SRETPromotion::updateCallSites(Function *F, Function *NF) {
+CallGraphNode *SRETPromotion::updateCallSites(Function *F, Function *NF) {
   CallGraph &CG = getAnalysis<CallGraph>();
   SmallVector<Value*, 16> Args;
 
   // Attributes - Keep track of the parameter attributes for the arguments.
   SmallVector<AttributeWithIndex, 8> ArgAttrsVec;
 
+  // Get a new callgraph node for NF.
+  CallGraphNode *NF_CGN = CG.getOrInsertFunction(NF);
+
   while (!F->use_empty()) {
     CallSite CS = CallSite::get(*F->use_begin());
     Instruction *Call = CS.getInstruction();
@@ -313,7 +321,7 @@ void SRETPromotion::updateCallSites(Function *F, Function *NF) {
     New->takeName(Call);
 
     // Update the callgraph to know that the callsite has been transformed.
-    CG[Call->getParent()->getParent()]->replaceCallSite(Call, New);
+    CG[Call->getParent()->getParent()]->replaceCallSite(Call, New, NF_CGN);
 
     // Update all users of sret parameter to extract value using extractvalue.
     for (Value::use_iterator UI = FirstCArg->use_begin(), 
@@ -322,7 +330,8 @@ void SRETPromotion::updateCallSites(Function *F, Function *NF) {
       CallInst *C2 = dyn_cast<CallInst>(U2);
       if (C2 && (C2 == Call))
         continue;
-      else if (GetElementPtrInst *UGEP = dyn_cast<GetElementPtrInst>(U2)) {
+      
+      if (GetElementPtrInst *UGEP = dyn_cast<GetElementPtrInst>(U2)) {
         ConstantInt *Idx = dyn_cast<ConstantInt>(UGEP->getOperand(2));
         assert (Idx && "Unexpected getelementptr index!");
         Value *GR = ExtractValueInst::Create(New, Idx->getZExtValue(),
@@ -335,11 +344,15 @@ void SRETPromotion::updateCallSites(Function *F, Function *NF) {
           L->eraseFromParent();
         }
         UGEP->eraseFromParent();
+        continue;
       }
-      else assert( 0 && "Unexpected sret parameter use");
+      
+      assert(0 && "Unexpected sret parameter use");
     }
     Call->eraseFromParent();
   }
+  
+  return NF_CGN;
 }
 
 /// nestedStructType - Return true if STy includes any