Enable first-class aggregates support.
[oota-llvm.git] / lib / Transforms / IPO / StructRetPromotion.cpp
1 //===-- StructRetPromotion.cpp - Promote sret arguments ------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass finds functions that return a struct (using a pointer to the struct
11 // as the first argument of the function, marked with the 'sret' attribute) and
12 // replaces them with a new function that simply returns each of the elements of
13 // that struct (using multiple return values).
14 //
15 // This pass works under a number of conditions:
16 //  1. The returned struct must not contain other structs
17 //  2. The returned struct must only be used to load values from
18 //  3. The placeholder struct passed in is the result of an alloca
19 //
20 //===----------------------------------------------------------------------===//
21
22 #define DEBUG_TYPE "sretpromotion"
23 #include "llvm/Transforms/IPO.h"
24 #include "llvm/Constants.h"
25 #include "llvm/DerivedTypes.h"
26 #include "llvm/Module.h"
27 #include "llvm/CallGraphSCCPass.h"
28 #include "llvm/Instructions.h"
29 #include "llvm/Analysis/CallGraph.h"
30 #include "llvm/Support/CallSite.h"
31 #include "llvm/Support/CFG.h"
32 #include "llvm/Support/Debug.h"
33 #include "llvm/ADT/Statistic.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/ADT/Statistic.h"
36 #include "llvm/Support/Compiler.h"
37 using namespace llvm;
38
39 STATISTIC(NumRejectedSRETUses , "Number of sret rejected due to unexpected uses");
40 STATISTIC(NumSRET , "Number of sret promoted");
41 namespace {
42   /// SRETPromotion - This pass removes sret parameter and updates
43   /// function to use multiple return value.
44   ///
45   struct VISIBILITY_HIDDEN SRETPromotion : public CallGraphSCCPass {
46     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
47       CallGraphSCCPass::getAnalysisUsage(AU);
48     }
49
50     virtual bool runOnSCC(const std::vector<CallGraphNode *> &SCC);
51     static char ID; // Pass identification, replacement for typeid
52     SRETPromotion() : CallGraphSCCPass((intptr_t)&ID) {}
53
54   private:
55     bool PromoteReturn(CallGraphNode *CGN);
56     bool isSafeToUpdateAllCallers(Function *F);
57     Function *cloneFunctionBody(Function *F, const StructType *STy);
58     void updateCallSites(Function *F, Function *NF);
59     bool nestedStructType(const StructType *STy);
60   };
61 }
62
63 char SRETPromotion::ID = 0;
64 static RegisterPass<SRETPromotion>
65 X("sretpromotion", "Promote sret arguments to multiple ret values");
66
67 Pass *llvm::createStructRetPromotionPass() {
68   return new SRETPromotion();
69 }
70
71 bool SRETPromotion::runOnSCC(const std::vector<CallGraphNode *> &SCC) {
72   bool Changed = false;
73
74   for (unsigned i = 0, e = SCC.size(); i != e; ++i)
75     Changed |= PromoteReturn(SCC[i]);
76
77   return Changed;
78 }
79
80 /// PromoteReturn - This method promotes function that uses StructRet paramater 
81 /// into a function that uses mulitple return value.
82 bool SRETPromotion::PromoteReturn(CallGraphNode *CGN) {
83   Function *F = CGN->getFunction();
84
85   if (!F || F->isDeclaration() || !F->hasInternalLinkage())
86     return false;
87
88   // Make sure that function returns struct.
89   if (F->arg_size() == 0 || !F->hasStructRetAttr() || F->doesNotReturn())
90     return false;
91
92   assert (F->getReturnType() == Type::VoidTy && "Invalid function return type");
93   Function::arg_iterator AI = F->arg_begin();
94   const llvm::PointerType *FArgType = dyn_cast<PointerType>(AI->getType());
95   assert (FArgType && "Invalid sret parameter type");
96   const llvm::StructType *STy = 
97     dyn_cast<StructType>(FArgType->getElementType());
98   assert (STy && "Invalid sret parameter element type");
99
100   // Check if it is ok to perform this promotion.
101   if (isSafeToUpdateAllCallers(F) == false) {
102     NumRejectedSRETUses++;
103     return false;
104   }
105
106   NumSRET++;
107   // [1] Replace use of sret parameter 
108   AllocaInst *TheAlloca = new AllocaInst (STy, NULL, "mrv", 
109                                           F->getEntryBlock().begin());
110   Value *NFirstArg = F->arg_begin();
111   NFirstArg->replaceAllUsesWith(TheAlloca);
112
113   // [2] Find and replace ret instructions
114   for (Function::iterator FI = F->begin(), FE = F->end();  FI != FE; ++FI) 
115     for(BasicBlock::iterator BI = FI->begin(), BE = FI->end(); BI != BE; ) {
116       Instruction *I = BI;
117       ++BI;
118       if (isa<ReturnInst>(I)) {
119         Value *NV = new LoadInst(TheAlloca, "mrv.ld", I);
120         ReturnInst *NR = ReturnInst::Create(NV);
121         I->replaceAllUsesWith(NR);
122         I->eraseFromParent();
123       }
124     }
125
126   // [3] Create the new function body and insert it into the module.
127   Function *NF = cloneFunctionBody(F, STy);
128
129   // [4] Update all call sites to use new function
130   updateCallSites(F, NF);
131
132   F->eraseFromParent();
133   getAnalysis<CallGraph>().changeFunction(F, NF);
134   return true;
135 }
136
137 // Check if it is ok to perform this promotion.
138 bool SRETPromotion::isSafeToUpdateAllCallers(Function *F) {
139
140   if (F->use_empty())
141     // No users. OK to modify signature.
142     return true;
143
144   for (Value::use_iterator FnUseI = F->use_begin(), FnUseE = F->use_end();
145        FnUseI != FnUseE; ++FnUseI) {
146     // The function is passed in as an argument to (possibly) another function,
147     // we can't change it!
148     if (FnUseI.getOperandNo() != 0)
149       return false;
150
151     CallSite CS = CallSite::get(*FnUseI);
152     Instruction *Call = CS.getInstruction();
153     // The function is used by something else than a call or invoke instruction,
154     // we can't change it!
155     if (!Call)
156       return false;
157     CallSite::arg_iterator AI = CS.arg_begin();
158     Value *FirstArg = *AI;
159
160     if (!isa<AllocaInst>(FirstArg))
161       return false;
162
163     // Check FirstArg's users.
164     for (Value::use_iterator ArgI = FirstArg->use_begin(), 
165            ArgE = FirstArg->use_end(); ArgI != ArgE; ++ArgI) {
166
167       // If FirstArg user is a CallInst that does not correspond to current
168       // call site then this function F is not suitable for sret promotion.
169       if (CallInst *CI = dyn_cast<CallInst>(ArgI)) {
170         if (CI != Call)
171           return false;
172       }
173       // If FirstArg user is a GEP whose all users are not LoadInst then
174       // this function F is not suitable for sret promotion.
175       else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(ArgI)) {
176         // TODO : Use dom info and insert PHINodes to collect get results
177         // from multiple call sites for this GEP.
178         if (GEP->getParent() != Call->getParent())
179           return false;
180         for (Value::use_iterator GEPI = GEP->use_begin(), GEPE = GEP->use_end();
181              GEPI != GEPE; ++GEPI) 
182           if (!isa<LoadInst>(GEPI))
183             return false;
184       } 
185       // Any other FirstArg users make this function unsuitable for sret 
186       // promotion.
187       else
188         return false;
189     }
190   }
191
192   return true;
193 }
194
195 /// cloneFunctionBody - Create a new function based on F and
196 /// insert it into module. Remove first argument. Use STy as
197 /// the return type for new function.
198 Function *SRETPromotion::cloneFunctionBody(Function *F, 
199                                            const StructType *STy) {
200
201   const FunctionType *FTy = F->getFunctionType();
202   std::vector<const Type*> Params;
203
204   // ParamAttrs - Keep track of the parameter attributes for the arguments.
205   SmallVector<ParamAttrsWithIndex, 8> ParamAttrsVec;
206   const PAListPtr &PAL = F->getParamAttrs();
207
208   // Add any return attributes.
209   if (ParameterAttributes attrs = PAL.getParamAttrs(0))
210     ParamAttrsVec.push_back(ParamAttrsWithIndex::get(0, attrs));
211
212   // Skip first argument.
213   Function::arg_iterator I = F->arg_begin(), E = F->arg_end();
214   ++I;
215   // 0th parameter attribute is reserved for return type.
216   // 1th parameter attribute is for first 1st sret argument.
217   unsigned ParamIndex = 2; 
218   while (I != E) {
219     Params.push_back(I->getType());
220     if (ParameterAttributes Attrs = PAL.getParamAttrs(ParamIndex))
221       ParamAttrsVec.push_back(ParamAttrsWithIndex::get(ParamIndex - 1, Attrs));
222     ++I;
223     ++ParamIndex;
224   }
225
226   FunctionType *NFTy = FunctionType::get(STy, Params, FTy->isVarArg());
227   Function *NF = Function::Create(NFTy, F->getLinkage(), F->getName());
228   NF->copyAttributesFrom(F);
229   NF->setParamAttrs(PAListPtr::get(ParamAttrsVec.begin(), ParamAttrsVec.end()));
230   F->getParent()->getFunctionList().insert(F, NF);
231   NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList());
232
233   // Replace arguments
234   I = F->arg_begin();
235   E = F->arg_end();
236   Function::arg_iterator NI = NF->arg_begin();
237   ++I;
238   while (I != E) {
239       I->replaceAllUsesWith(NI);
240       NI->takeName(I);
241       ++I;
242       ++NI;
243   }
244
245   return NF;
246 }
247
248 /// updateCallSites - Update all sites that call F to use NF.
249 void SRETPromotion::updateCallSites(Function *F, Function *NF) {
250
251   SmallVector<Value*, 16> Args;
252
253   // ParamAttrs - Keep track of the parameter attributes for the arguments.
254   SmallVector<ParamAttrsWithIndex, 8> ArgAttrsVec;
255
256   for (Value::use_iterator FUI = F->use_begin(), FUE = F->use_end();
257        FUI != FUE;) {
258     CallSite CS = CallSite::get(*FUI);
259     ++FUI;
260     Instruction *Call = CS.getInstruction();
261
262     const PAListPtr &PAL = F->getParamAttrs();
263     // Add any return attributes.
264     if (ParameterAttributes attrs = PAL.getParamAttrs(0))
265       ArgAttrsVec.push_back(ParamAttrsWithIndex::get(0, attrs));
266
267     // Copy arguments, however skip first one.
268     CallSite::arg_iterator AI = CS.arg_begin(), AE = CS.arg_end();
269     Value *FirstCArg = *AI;
270     ++AI;
271     // 0th parameter attribute is reserved for return type.
272     // 1th parameter attribute is for first 1st sret argument.
273     unsigned ParamIndex = 2; 
274     while (AI != AE) {
275       Args.push_back(*AI); 
276       if (ParameterAttributes Attrs = PAL.getParamAttrs(ParamIndex))
277         ArgAttrsVec.push_back(ParamAttrsWithIndex::get(ParamIndex - 1, Attrs));
278       ++ParamIndex;
279       ++AI;
280     }
281
282     
283     PAListPtr NewPAL = PAListPtr::get(ArgAttrsVec.begin(), ArgAttrsVec.end());
284     
285     // Build new call instruction.
286     Instruction *New;
287     if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) {
288       New = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
289                                Args.begin(), Args.end(), "", Call);
290       cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv());
291       cast<InvokeInst>(New)->setParamAttrs(NewPAL);
292     } else {
293       New = CallInst::Create(NF, Args.begin(), Args.end(), "", Call);
294       cast<CallInst>(New)->setCallingConv(CS.getCallingConv());
295       cast<CallInst>(New)->setParamAttrs(NewPAL);
296       if (cast<CallInst>(Call)->isTailCall())
297         cast<CallInst>(New)->setTailCall();
298     }
299     Args.clear();
300     ArgAttrsVec.clear();
301     New->takeName(Call);
302
303     // Update all users of sret parameter to extract value using extractvalue.
304     for (Value::use_iterator UI = FirstCArg->use_begin(), 
305            UE = FirstCArg->use_end(); UI != UE; ) {
306       User *U2 = *UI++;
307       CallInst *C2 = dyn_cast<CallInst>(U2);
308       if (C2 && (C2 == Call))
309         continue;
310       else if (GetElementPtrInst *UGEP = dyn_cast<GetElementPtrInst>(U2)) {
311         ConstantInt *Idx = dyn_cast<ConstantInt>(UGEP->getOperand(2));
312         assert (Idx && "Unexpected getelementptr index!");
313         Value *GR = ExtractValueInst::Create(New, Idx->getZExtValue(),
314                                              "evi", UGEP);
315         for (Value::use_iterator GI = UGEP->use_begin(),
316                GE = UGEP->use_end(); GI != GE; ++GI) {
317           if (LoadInst *L = dyn_cast<LoadInst>(*GI)) {
318             L->replaceAllUsesWith(GR);
319             L->eraseFromParent();
320           }
321         }
322         UGEP->eraseFromParent();
323       }
324       else assert( 0 && "Unexpected sret parameter use");
325     }
326     Call->eraseFromParent();
327   }
328 }
329
330 /// nestedStructType - Return true if STy includes any
331 /// other aggregate types
332 bool SRETPromotion::nestedStructType(const StructType *STy) {
333   unsigned Num = STy->getNumElements();
334   for (unsigned i = 0; i < Num; i++) {
335     const Type *Ty = STy->getElementType(i);
336     if (!Ty->isSingleValueType() && Ty != Type::VoidTy)
337       return true;
338   }
339   return false;
340 }