Simplify code a lot by using the Module::getFunction & getOrInsertFunction
[oota-llvm.git] / lib / Transforms / Utils / LowerAllocations.cpp
1 //===- ChangeAllocations.cpp - Modify %malloc & %free calls -----------------=//
2 //
3 // This file defines two passes that convert malloc and free instructions to
4 // calls to and from %malloc & %free function calls.  The LowerAllocations
5 // transformation is a target dependant tranformation because it depends on the
6 // size of data types and alignment constraints.
7 //
8 //===----------------------------------------------------------------------===//
9
10 #include "llvm/Transforms/ChangeAllocations.h"
11 #include "llvm/Target/TargetData.h"
12 #include "llvm/Module.h"
13 #include "llvm/DerivedTypes.h"
14 #include "llvm/iMemory.h"
15 #include "llvm/iOther.h"
16 #include "llvm/ConstantVals.h"
17 #include "llvm/Pass.h"
18 #include "TransformInternals.h"
19 using std::vector;
20
21 namespace {
22
23 // LowerAllocations - Turn malloc and free instructions into %malloc and %free
24 // calls.
25 //
26 class LowerAllocations : public BasicBlockPass {
27   Function *MallocFunc;   // Functions in the module we are processing
28   Function *FreeFunc;     // Initialized by doInitialization
29
30   const TargetData &DataLayout;
31 public:
32   inline LowerAllocations(const TargetData &TD) : DataLayout(TD) {
33     MallocFunc = FreeFunc = 0;
34   }
35
36   // doPassInitialization - For the lower allocations pass, this ensures that a
37   // module contains a declaration for a malloc and a free function.
38   //
39   bool doInitialization(Module *M);
40
41   // runOnBasicBlock - This method does the actual work of converting
42   // instructions over, assuming that the pass has already been initialized.
43   //
44   bool runOnBasicBlock(BasicBlock *BB);
45 };
46
47 // RaiseAllocations - Turn %malloc and %free calls into the appropriate
48 // instruction.
49 //
50 class RaiseAllocations : public BasicBlockPass {
51   Function *MallocFunc;   // Functions in the module we are processing
52   Function *FreeFunc;     // Initialized by doPassInitializationVirt
53 public:
54   inline RaiseAllocations() : MallocFunc(0), FreeFunc(0) {}
55
56   // doPassInitialization - For the raise allocations pass, this finds a
57   // declaration for malloc and free if they exist.
58   //
59   bool doInitialization(Module *M);
60
61   // runOnBasicBlock - This method does the actual work of converting
62   // instructions over, assuming that the pass has already been initialized.
63   //
64   bool runOnBasicBlock(BasicBlock *BB);
65 };
66
67 }  // end anonymous namespace
68
69 // doInitialization - For the lower allocations pass, this ensures that a
70 // module contains a declaration for a malloc and a free function.
71 //
72 // This function is always successful.
73 //
74 bool LowerAllocations::doInitialization(Module *M) {
75   const FunctionType *MallocType = 
76     FunctionType::get(PointerType::get(Type::SByteTy),
77                       vector<const Type*>(1, Type::UIntTy), false);
78   const FunctionType *FreeType = 
79     FunctionType::get(Type::VoidTy,
80                       vector<const Type*>(1, PointerType::get(Type::SByteTy)),
81                       false);
82
83   MallocFunc = M->getOrInsertFunction("malloc", MallocType);
84   FreeFunc   = M->getOrInsertFunction("free"  , FreeType);
85
86   return false;
87 }
88
89 // runOnBasicBlock - This method does the actual work of converting
90 // instructions over, assuming that the pass has already been initialized.
91 //
92 bool LowerAllocations::runOnBasicBlock(BasicBlock *BB) {
93   bool Changed = false;
94   assert(MallocFunc && FreeFunc && BB && "Pass not initialized!");
95
96   // Loop over all of the instructions, looking for malloc or free instructions
97   for (unsigned i = 0; i < BB->size(); ++i) {
98     BasicBlock::InstListType &BBIL = BB->getInstList();
99     if (MallocInst *MI = dyn_cast<MallocInst>(*(BBIL.begin()+i))) {
100       BBIL.remove(BBIL.begin()+i);   // remove the malloc instr...
101         
102       const Type *AllocTy =cast<PointerType>(MI->getType())->getElementType();
103       
104       // Get the number of bytes to be allocated for one element of the
105       // requested type...
106       unsigned Size = DataLayout.getTypeSize(AllocTy);
107       
108       // malloc(type) becomes sbyte *malloc(constint)
109       Value *MallocArg = ConstantUInt::get(Type::UIntTy, Size);
110       if (MI->getNumOperands() && Size == 1) {
111         MallocArg = MI->getOperand(0);         // Operand * 1 = Operand
112       } else if (MI->getNumOperands()) {
113         // Multiply it by the array size if neccesary...
114         MallocArg = BinaryOperator::create(Instruction::Mul,MI->getOperand(0),
115                                            MallocArg);
116         BBIL.insert(BBIL.begin()+i++, cast<Instruction>(MallocArg));
117       }
118       
119       // Create the call to Malloc...
120       CallInst *MCall = new CallInst(MallocFunc,
121                                      vector<Value*>(1, MallocArg));
122       BBIL.insert(BBIL.begin()+i, MCall);
123       
124       // Create a cast instruction to convert to the right type...
125       CastInst *MCast = new CastInst(MCall, MI->getType());
126       BBIL.insert(BBIL.begin()+i+1, MCast);
127       
128       // Replace all uses of the old malloc inst with the cast inst
129       MI->replaceAllUsesWith(MCast);
130       delete MI;                          // Delete the malloc inst
131       Changed = true;
132     } else if (FreeInst *FI = dyn_cast<FreeInst>(*(BBIL.begin()+i))) {
133       BBIL.remove(BB->getInstList().begin()+i);
134       
135       // Cast the argument to free into a ubyte*...
136       CastInst *MCast = new CastInst(FI->getOperand(0), 
137                                      PointerType::get(Type::UByteTy));
138       BBIL.insert(BBIL.begin()+i, MCast);
139       
140       // Insert a call to the free function...
141       CallInst *FCall = new CallInst(FreeFunc,
142                                      vector<Value*>(1, MCast));
143       BBIL.insert(BBIL.begin()+i+1, FCall);
144       
145       // Delete the old free instruction
146       delete FI;
147       Changed = true;
148     }
149   }
150
151   return Changed;
152 }
153
154 bool RaiseAllocations::doInitialization(Module *M) {
155   // If the module has a symbol table, they might be referring to the malloc
156   // and free functions.  If this is the case, grab the method pointers that 
157   // the module is using.
158   //
159   // Lookup %malloc and %free in the symbol table, for later use.  If they
160   // don't exist, or are not external, we do not worry about converting calls
161   // to that function into the appropriate instruction.
162   //
163   const FunctionType *MallocType =   // Get the type for malloc
164     FunctionType::get(PointerType::get(Type::SByteTy),
165                       vector<const Type*>(1, Type::UIntTy), false);
166
167   const FunctionType *FreeType =     // Get the type for free
168     FunctionType::get(Type::VoidTy,
169                       vector<const Type*>(1, PointerType::get(Type::SByteTy)),
170                       false);
171
172   MallocFunc = M->getFunction("malloc", MallocType);
173   FreeFunc   = M->getFunction("free"  , FreeType);
174
175   // Don't mess with locally defined versions of these functions...
176   if (MallocFunc && !MallocFunc->isExternal()) MallocFunc = 0;
177   if (FreeFunc && !FreeFunc->isExternal())     FreeFunc = 0;
178   return false;
179 }
180
181 // doOneCleanupPass - Do one pass over the input method, fixing stuff up.
182 //
183 bool RaiseAllocations::runOnBasicBlock(BasicBlock *BB) {
184   bool Changed = false;
185   BasicBlock::InstListType &BIL = BB->getInstList();
186
187   for (BasicBlock::iterator BI = BB->begin(); BI != BB->end();) {
188     Instruction *I = *BI;
189
190     if (CallInst *CI = dyn_cast<CallInst>(I)) {
191       if (CI->getCalledValue() == MallocFunc) {      // Replace call to malloc?
192         const Type *PtrSByte = PointerType::get(Type::SByteTy);
193         MallocInst *MallocI = new MallocInst(PtrSByte, CI->getOperand(1),
194                                              CI->getName());
195         CI->setName("");
196         ReplaceInstWithInst(BIL, BI, MallocI);
197         Changed = true;
198         continue;  // Skip the ++BI
199       } else if (CI->getCalledValue() == FreeFunc) { // Replace call to free?
200         ReplaceInstWithInst(BIL, BI, new FreeInst(CI->getOperand(1)));
201         Changed = true;
202         continue;  // Skip the ++BI
203       }
204     }
205
206     ++BI;
207   }
208
209   return Changed;
210 }
211
212 Pass *createLowerAllocationsPass(const TargetData &TD) {
213   return new LowerAllocations(TD);
214 }
215 Pass *createRaiseAllocationsPass() {
216   return new RaiseAllocations();
217 }
218
219