Add pass to promote sret.
[oota-llvm.git] / lib / Transforms / IPO / StructRetPromotion.cpp
1 //===-- StructRetPromotion.cpp - Promote sret arguments -000000------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass promotes "by reference" arguments to be "by value" arguments.  In
11 // practice, this means looking for internal functions that have pointer
12 // arguments.  If it can prove, through the use of alias analysis, that an
13 // argument is *only* loaded, then it can pass the value into the function
14 // instead of the address of the value.  This can cause recursive simplification
15 // of code and lead to the elimination of allocas (especially in C++ template
16 // code like the STL).
17 //
18 // This pass also handles aggregate arguments that are passed into a function,
19 // scalarizing them if the elements of the aggregate are only loaded.  Note that
20 // it refuses to scalarize aggregates which would require passing in more than
21 // three operands to the function, because passing thousands of operands for a
22 // large array or structure is unprofitable!
23 //
24 // Note that this transformation could also be done for arguments that are only
25 // stored to (returning the value instead), but does not currently.  This case
26 // would be best handled when and if LLVM begins supporting multiple return
27 // values from functions.
28 //
29 //===----------------------------------------------------------------------===//
30
31 #define DEBUG_TYPE "sretpromotion"
32 #include "llvm/Transforms/IPO.h"
33 #include "llvm/Constants.h"
34 #include "llvm/DerivedTypes.h"
35 #include "llvm/Module.h"
36 #include "llvm/CallGraphSCCPass.h"
37 #include "llvm/Instructions.h"
38 #include "llvm/Analysis/CallGraph.h"
39 #include "llvm/Support/CallSite.h"
40 #include "llvm/Support/CFG.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/ADT/Statistic.h"
43 #include "llvm/ADT/SmallVector.h"
44 #include "llvm/Support/Compiler.h"
45 using namespace llvm;
46
47 namespace {
48   /// SRETPromotion - This pass removes sret parameter and updates
49   /// function to use multiple return value.
50   ///
51   struct VISIBILITY_HIDDEN SRETPromotion : public CallGraphSCCPass {
52     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
53       CallGraphSCCPass::getAnalysisUsage(AU);
54     }
55
56     virtual bool runOnSCC(const std::vector<CallGraphNode *> &SCC);
57     static char ID; // Pass identification, replacement for typeid
58     SRETPromotion() : CallGraphSCCPass((intptr_t)&ID) {}
59
60   private:
61     bool PromoteReturn(CallGraphNode *CGN);
62     bool isSafeToUpdateAllCallers(Function *F);
63     Function *cloneFunctionBody(Function *F, const StructType *STy);
64     void updateCallSites(Function *F, Function *NF);
65   };
66
67   char SRETPromotion::ID = 0;
68   RegisterPass<SRETPromotion> X("sretpromotion",
69                                "Promote sret arguments to multiple ret values");
70 }
71
72 Pass *llvm::createStructRetPromotionPass() {
73   return new SRETPromotion();
74 }
75
76 bool SRETPromotion::runOnSCC(const std::vector<CallGraphNode *> &SCC) {
77   bool Changed = false;
78
79   for (unsigned i = 0, e = SCC.size(); i != e; ++i)
80     Changed |= PromoteReturn(SCC[i]);
81
82   return Changed;
83 }
84
85 /// PromoteReturn - This method promotes function that uses StructRet paramater 
86 /// into a function that uses mulitple return value.
87 bool SRETPromotion::PromoteReturn(CallGraphNode *CGN) {
88   Function *F = CGN->getFunction();
89
90   // Make sure that it is local to this module.
91   if (!F || !F->hasInternalLinkage())
92     return false;
93
94   // Make sure that function returns struct.
95   if (F->arg_size() == 0 || !F->isStructReturn() || F->doesNotReturn())
96     return false;
97
98   assert (F->getReturnType() == Type::VoidTy && "Invalid function return type");
99   Function::arg_iterator AI = F->arg_begin();
100   const llvm::PointerType *FArgType = dyn_cast<PointerType>(AI->getType());
101   assert (FArgType && "Invalid sret paramater type");
102   const llvm::StructType *STy = 
103     dyn_cast<StructType>(FArgType->getElementType());
104   assert (STy && "Invalid sret parameter element type");
105
106   // Check if it is ok to perform this promotion.
107   if (isSafeToUpdateAllCallers(F) == false)
108     return false;
109
110   // [1] Replace use of sret parameter 
111   AllocaInst *TheAlloca = new AllocaInst (STy, NULL, "mrv", F->getEntryBlock().begin());
112   Value *NFirstArg = F->arg_begin();
113   NFirstArg->replaceAllUsesWith(TheAlloca);
114
115   // Find and replace ret instructions
116   SmallVector<Value *,4> RetVals;
117   for (Function::iterator FI = F->begin(), FE = F->end();  FI != FE; ++FI) 
118     for(BasicBlock::iterator BI = FI->begin(), BE = FI->end(); BI != BE; ) {
119       Instruction *I = BI;
120       ++BI;
121       if (isa<ReturnInst>(I)) {
122         RetVals.clear();
123         for (unsigned idx = 0; idx < STy->getNumElements(); ++idx) {
124           SmallVector<Value*, 2> GEPIdx;
125           GEPIdx.push_back(ConstantInt::get(Type::Int32Ty, 0));
126           GEPIdx.push_back(ConstantInt::get(Type::Int32Ty, idx));
127           Value *NGEPI = new GetElementPtrInst(TheAlloca, GEPIdx.begin(), GEPIdx.end(),
128                                                "mrv.gep", I);
129           Value *NV = new LoadInst(NGEPI, "mrv.ld", I);
130           RetVals.push_back(NV);
131         }
132     
133         ReturnInst *NR = new ReturnInst(&RetVals[0], RetVals.size(), I);
134         I->replaceAllUsesWith(NR);
135         I->eraseFromParent();
136       }
137     }
138
139   // Create the new function body and insert it into the module.
140   Function *NF = cloneFunctionBody(F, STy);
141
142   // Update all call sites to use new function
143   updateCallSites(F, NF);
144
145   F->eraseFromParent();
146   getAnalysis<CallGraph>().changeFunction(F, NF);
147   return true;
148 }
149
150   // Check if it is ok to perform this promotion.
151 bool SRETPromotion::isSafeToUpdateAllCallers(Function *F) {
152
153   if (F->use_empty())
154     // No users. OK to modify signature.
155     return true;
156
157   for (Value::use_iterator FnUseI = F->use_begin(), FnUseE = F->use_end();
158        FnUseI != FnUseE; ++FnUseI) {
159
160     CallSite CS = CallSite::get(*FnUseI);
161     Instruction *Call = CS.getInstruction();
162     CallSite::arg_iterator AI = CS.arg_begin();
163     Value *FirstArg = *AI;
164
165     if (!isa<AllocaInst>(FirstArg))
166       return false;
167
168     // Check FirstArg's users.
169     for (Value::use_iterator ArgI = FirstArg->use_begin(), 
170            ArgE = FirstArg->use_end(); ArgI != ArgE; ++ArgI) {
171
172       // If FirstArg user is a CallInst that does not correspond to current
173       // call site then this function F is not suitable for sret promotion.
174       if (CallInst *CI = dyn_cast<CallInst>(ArgI)) {
175         if (CI != Call)
176           return false;
177       }
178       // If FirstArg user is a GEP whose all users are not LoadInst then
179       // this function F is not suitable for sret promotion.
180       else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(ArgI)) {
181         for (Value::use_iterator GEPI = GEP->use_begin(), GEPE = GEP->use_end();
182              GEPI != GEPE; ++GEPI) 
183           if (!isa<LoadInst>(GEPI))
184             return false;
185       } 
186       // Any other FirstArg users make this function unsuitable for sret 
187       // promotion.
188       else
189         return false;
190     }
191   }
192
193   return true;
194 }
195
196 /// cloneFunctionBody - Create a new function based on F and
197 /// insert it into module. Remove first argument. Use STy as
198 /// the return type for new function.
199 Function *SRETPromotion::cloneFunctionBody(Function *F, 
200                                            const StructType *STy) {
201
202   // FIXME : Do not drop param attributes on the floor.
203   const FunctionType *FTy = F->getFunctionType();
204   std::vector<const Type*> Params;
205
206   // Skip first argument.
207   Function::arg_iterator I = F->arg_begin(), E = F->arg_end();
208   ++I;
209   while (I != E) {
210     Params.push_back(I->getType());
211     ++I;
212   }
213
214   FunctionType *NFTy = FunctionType::get(STy, Params, FTy->isVarArg());
215   Function *NF = new Function(NFTy, F->getLinkage(), F->getName());
216   NF->setCallingConv(F->getCallingConv());
217   F->getParent()->getFunctionList().insert(F, NF);
218   NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList());
219
220   // Replace arguments
221   I = F->arg_begin();
222   E = F->arg_end();
223   Function::arg_iterator NI = NF->arg_begin();
224   ++I;
225   while (I != E) {
226       I->replaceAllUsesWith(NI);
227       NI->takeName(I);
228       ++I;
229       ++NI;
230   }
231
232   return NF;
233 }
234
235 /// updateCallSites - Update all sites that call F to use NF.
236 void SRETPromotion::updateCallSites(Function *F, Function *NF) {
237
238   // FIXME : Handle parameter attributes
239   SmallVector<Value*, 16> Args;
240
241   for (Value::use_iterator FUI = F->use_begin(), FUE = F->use_end(); FUI != FUE;) {
242     CallSite CS = CallSite::get(*FUI);
243     ++FUI;
244     Instruction *Call = CS.getInstruction();
245
246     // Copy arguments, however skip first one.
247     CallSite::arg_iterator AI = CS.arg_begin(), AE = CS.arg_end();
248     Value *FirstCArg = *AI;
249     ++AI;
250     while (AI != AE) {
251       Args.push_back(*AI); 
252       ++AI;
253     }
254
255     // Build new call instruction.
256     Instruction *New;
257     if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) {
258       New = new InvokeInst(NF, II->getNormalDest(), II->getUnwindDest(),
259                            Args.begin(), Args.end(), "", Call);
260       cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv());
261     } else {
262       New = new CallInst(NF, Args.begin(), Args.end(), "", Call);
263       cast<CallInst>(New)->setCallingConv(CS.getCallingConv());
264       if (cast<CallInst>(Call)->isTailCall())
265         cast<CallInst>(New)->setTailCall();
266     }
267     Args.clear();
268     New->takeName(Call);
269
270     // Update all users of sret parameter to extract value using getresult.
271     for (Value::use_iterator UI = FirstCArg->use_begin(), 
272            UE = FirstCArg->use_end(); UI != UE; ) {
273       User *U2 = *UI++;
274       CallInst *C2 = dyn_cast<CallInst>(U2);
275       if (C2 && (C2 == Call))
276         continue;
277       else if (GetElementPtrInst *UGEP = dyn_cast<GetElementPtrInst>(U2)) {
278         Value *GR = new GetResultInst(New, 5, "xxx", UGEP);
279         for (Value::use_iterator GI = UGEP->use_begin(),
280                GE = UGEP->use_end(); GI != GE; ++GI) {
281           if (LoadInst *L = dyn_cast<LoadInst>(*GI)) {
282             L->replaceAllUsesWith(GR);
283             L->eraseFromParent();
284           }
285         }
286         UGEP->eraseFromParent();
287       }
288       else assert( 0 && "Unexpected sret parameter use");
289     }
290     Call->eraseFromParent();
291   }
292 }