deb4405754d7d2ea5a75d825e1b22752a5e66854
[oota-llvm.git] / lib / Transforms / IPO / RaiseAllocations.cpp
1 //===- RaiseAllocations.cpp - Convert @free calls to insts ------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file defines the RaiseAllocations pass which convert free calls to free
11 // instructions.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #define DEBUG_TYPE "raiseallocs"
16 #include "llvm/Transforms/IPO.h"
17 #include "llvm/Constants.h"
18 #include "llvm/DerivedTypes.h"
19 #include "llvm/LLVMContext.h"
20 #include "llvm/Module.h"
21 #include "llvm/Instructions.h"
22 #include "llvm/Pass.h"
23 #include "llvm/Support/CallSite.h"
24 #include "llvm/Support/Compiler.h"
25 #include "llvm/ADT/Statistic.h"
26 #include <algorithm>
27 using namespace llvm;
28
29 STATISTIC(NumRaised, "Number of allocations raised");
30
31 namespace {
32   // RaiseAllocations - Turn @free calls into the appropriate
33   // instruction.
34   //
35   class VISIBILITY_HIDDEN RaiseAllocations : public ModulePass {
36     Function *FreeFunc;   // Functions in the module we are processing
37                           // Initialized by doPassInitializationVirt
38   public:
39     static char ID; // Pass identification, replacement for typeid
40     RaiseAllocations() 
41       : ModulePass(&ID), FreeFunc(0) {}
42
43     // doPassInitialization - For the raise allocations pass, this finds a
44     // declaration for free if it exists.
45     //
46     void doInitialization(Module &M);
47
48     // run - This method does the actual work of converting instructions over.
49     //
50     bool runOnModule(Module &M);
51   };
52 }  // end anonymous namespace
53
54 char RaiseAllocations::ID = 0;
55 static RegisterPass<RaiseAllocations>
56 X("raiseallocs", "Raise allocations from calls to instructions");
57
58 // createRaiseAllocationsPass - The interface to this file...
59 ModulePass *llvm::createRaiseAllocationsPass() {
60   return new RaiseAllocations();
61 }
62
63
64 // If the module has a symbol table, they might be referring to the free 
65 // function.  If this is the case, grab the method pointers that the module is
66 // using.
67 //
68 // Lookup @free in the symbol table, for later use.  If they don't
69 // exist, or are not external, we do not worry about converting calls to that
70 // function into the appropriate instruction.
71 //
72 void RaiseAllocations::doInitialization(Module &M) {
73   // Get free prototype if it exists!
74   FreeFunc = M.getFunction("free");
75   if (FreeFunc) {
76     const FunctionType* TyWeHave = FreeFunc->getFunctionType();
77     
78     // Get the expected prototype for void free(i8*)
79     const FunctionType *Free1Type =
80       FunctionType::get(Type::getVoidTy(M.getContext()),
81         std::vector<const Type*>(1, PointerType::getUnqual(
82                                  Type::getInt8Ty(M.getContext()))), 
83                                  false);
84
85     if (TyWeHave != Free1Type) {
86       // Check to see if the prototype was forgotten, giving us 
87       // void (...) * free
88       // This handles the common forward declaration of: 'void free();'
89       const FunctionType* Free2Type =
90                     FunctionType::get(Type::getVoidTy(M.getContext()), true);
91
92       if (TyWeHave != Free2Type) {
93         // One last try, check to see if we can find free as 
94         // int (...)* free.  This handles the case where NOTHING was declared.
95         const FunctionType* Free3Type =
96                     FunctionType::get(Type::getInt32Ty(M.getContext()), true);
97         
98         if (TyWeHave != Free3Type) {
99           // Give up.
100           FreeFunc = 0;
101         }
102       }
103     }
104   }
105
106   // Don't mess with locally defined versions of these functions...
107   if (FreeFunc && !FreeFunc->isDeclaration())     FreeFunc = 0;
108 }
109
110 // run - Transform calls into instructions...
111 //
112 bool RaiseAllocations::runOnModule(Module &M) {
113   // Find the free prototype...
114   doInitialization(M);
115   
116   bool Changed = false;
117
118   // Process all free calls...
119   if (FreeFunc) {
120     std::vector<User*> Users(FreeFunc->use_begin(), FreeFunc->use_end());
121     std::vector<Value*> EqPointers;   // Values equal to FreeFunc
122
123     while (!Users.empty()) {
124       User *U = Users.back();
125       Users.pop_back();
126
127       if (Instruction *I = dyn_cast<Instruction>(U)) {
128         if (isa<InvokeInst>(I))
129           continue;
130         CallSite CS = CallSite::get(I);
131         if (CS.getInstruction() && !CS.arg_empty() &&
132             (CS.getCalledFunction() == FreeFunc ||
133              std::find(EqPointers.begin(), EqPointers.end(),
134                        CS.getCalledValue()) != EqPointers.end())) {
135
136           // If no prototype was provided for free, we may need to cast the
137           // source pointer.  This should be really uncommon, but it's necessary
138           // just in case we are dealing with weird code like this:
139           //   free((long)ptr);
140           //
141           Value *Source = *CS.arg_begin();
142           if (!isa<PointerType>(Source->getType()))
143             Source = new IntToPtrInst(Source,           
144                         Type::getInt8PtrTy(M.getContext()), 
145                                       "FreePtrCast", I);
146           new FreeInst(Source, I);
147
148           // If the old instruction was an invoke, add an unconditional branch
149           // before the invoke, which will become the new terminator.
150           if (InvokeInst *II = dyn_cast<InvokeInst>(I))
151             BranchInst::Create(II->getNormalDest(), I);
152
153           // Delete the old call site
154           if (I->getType() != Type::getVoidTy(M.getContext()))
155             I->replaceAllUsesWith(UndefValue::get(I->getType()));
156           I->eraseFromParent();
157           Changed = true;
158           ++NumRaised;
159         }
160       } else if (GlobalValue *GV = dyn_cast<GlobalValue>(U)) {
161         Users.insert(Users.end(), GV->use_begin(), GV->use_end());
162         EqPointers.push_back(GV);
163       } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U)) {
164         if (CE->isCast()) {
165           Users.insert(Users.end(), CE->use_begin(), CE->use_end());
166           EqPointers.push_back(CE);
167         }
168       }
169     }
170   }
171
172   return Changed;
173 }