* Implement StrLenOptimization
[oota-llvm.git] / lib / Transforms / IPO / SimplifyLibCalls.cpp
1 //===- SimplifyLibCalls.cpp - Optimize specific well-known library calls --===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Reid Spencer and is distributed under the 
6 // University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements a variety of small optimizations for calls to specific
11 // well-known (e.g. runtime library) function calls. For example, a call to the
12 // function "exit(3)" that occurs within the main() function can be transformed
13 // into a simple "return 3" instruction. Any optimization that takes this form
14 // (replace call to library function with simpler code that provides same 
15 // result) belongs in this file. 
16 //
17 //===----------------------------------------------------------------------===//
18
19 #include "llvm/Transforms/IPO.h"
20 #include "llvm/Module.h"
21 #include "llvm/Pass.h"
22 #include "llvm/DerivedTypes.h"
23 #include "llvm/Constants.h"
24 #include "llvm/Instructions.h"
25 #include "llvm/ADT/Statistic.h"
26 #include "llvm/ADT/hash_map"
27 #include <iostream>
28 using namespace llvm;
29
30 namespace {
31   Statistic<> SimplifiedLibCalls("simplified-lib-calls", 
32       "Number of well-known library calls simplified");
33
34   /// This class is the base class for a set of small but important 
35   /// optimizations of calls to well-known functions, such as those in the c
36   /// library. This class provides the basic infrastructure for handling 
37   /// runOnModule. Subclasses register themselves and provide two methods:
38   /// RecognizeCall and OptimizeCall. Whenever this class finds a function call,
39   /// it asks the subclasses to recognize the call. If it is recognized, then
40   /// the OptimizeCall method is called on that subclass instance. In this way
41   /// the subclasses implement the calling conditions on which they trigger and
42   /// the action to perform, making it easy to add new optimizations of this
43   /// form.
44   /// @brief A ModulePass for optimizing well-known function calls
45   struct SimplifyLibCalls : public ModulePass {
46
47
48     /// For this pass, process all of the function calls in the module, calling
49     /// RecognizeCall and OptimizeCall as appropriate.
50     virtual bool runOnModule(Module &M);
51
52   };
53
54   RegisterOpt<SimplifyLibCalls> 
55     X("simplify-libcalls","Simplify well-known library calls");
56
57   struct CallOptimizer
58   {
59     /// @brief Constructor that registers the optimization
60     CallOptimizer(const char * fname );
61
62     virtual ~CallOptimizer();
63
64     /// The implementation of this function in subclasses should determine if
65     /// \p F is suitable for the optimization. This method is called by 
66     /// runOnModule to short circuit visiting all the call sites of such a
67     /// function if that function is not suitable in the first place.
68     /// If the called function is suitabe, this method should return true;
69     /// false, otherwise. This function should also perform any lazy 
70     /// initialization that the CallOptimizer needs to do, if its to return 
71     /// true. This avoids doing initialization until the optimizer is actually
72     /// going to be called upon to do some optimization.
73     virtual bool ValidateCalledFunction(
74       const Function* F ///< The function that is the target of call sites
75     ) = 0;
76
77     /// The implementations of this function in subclasses is the heart of the 
78     /// SimplifyLibCalls algorithm. Sublcasses of this class implement 
79     /// OptimizeCall to determine if (a) the conditions are right for optimizing
80     /// the call and (b) to perform the optimization. If an action is taken 
81     /// against ci, the subclass is responsible for returning true and ensuring
82     /// that ci is erased from its parent.
83     /// @param ci the call instruction under consideration
84     /// @param f the function that ci calls.
85     /// @brief Optimize a call, if possible.
86     virtual bool OptimizeCall(
87       CallInst* ci ///< The call instruction that should be optimized.
88     ) = 0;
89
90     const char * getFunctionName() const { return func_name; }
91   private:
92     const char* func_name;
93   };
94
95   /// @brief The list of optimizations deriving from CallOptimizer
96
97   hash_map<std::string,CallOptimizer*> optlist;
98
99   CallOptimizer::CallOptimizer(const char* fname)
100     : func_name(fname)
101   {
102     // Register this call optimizer
103     optlist[func_name] = this;
104   }
105
106   /// Make sure we get our virtual table in this file.
107   CallOptimizer::~CallOptimizer() { }
108
109   /// Provide some functions for accessing standard library prototypes and
110   /// caching them so we don't have to keep recomputing them
111   FunctionType* get_strlen()
112   {
113     static FunctionType* strlen_type = 0;
114     if (!strlen_type)
115     {
116       std::vector<const Type*> args;
117       args.push_back(PointerType::get(Type::SByteTy));
118       strlen_type = FunctionType::get(Type::IntTy, args, false);
119     }
120     return strlen_type;
121   }
122
123   FunctionType* get_memcpy()
124   {
125     static FunctionType* memcpy_type = 0;
126     if (!memcpy_type)
127     {
128       // Note: this is for llvm.memcpy intrinsic
129       std::vector<const Type*> args;
130       args.push_back(PointerType::get(Type::SByteTy));
131       args.push_back(PointerType::get(Type::SByteTy));
132       args.push_back(Type::IntTy);
133       args.push_back(Type::IntTy);
134       memcpy_type = FunctionType::get(
135         PointerType::get(Type::SByteTy), args, false);
136     }
137     return memcpy_type;
138   }
139
140   // Provide some utility functions for various checks common to more than
141   // one CallOptimizer
142   Constant* get_GVInitializer(Value* V)
143   {
144     User* GEP = 0;
145     // If the value not a GEP instruction nor a constant expression with a GEP 
146     // instruction, then return 0 because ConstantArray can't occur any other
147     // way
148     if (GetElementPtrInst* GEPI = dyn_cast<GetElementPtrInst>(V))
149       GEP = GEPI;
150     else if (ConstantExpr* CE = dyn_cast<ConstantExpr>(V))
151       if (CE->getOpcode() == Instruction::GetElementPtr)
152         GEP = CE;
153       else
154         return 0;
155     else
156       return 0;
157
158     // Check to make sure that the first operand of the GEP is an integer and
159     // has value 0 so that we are sure we're indexing into the initializer. 
160     if (ConstantInt* op1 = dyn_cast<ConstantInt>(GEP->getOperand(1)))
161       if (op1->isNullValue())
162         ;
163       else
164         return false;
165     else
166       return false;
167
168     // Ensure that the second operand is a ConstantInt. If it isn't then this
169     // GEP is wonky and we're not really sure what were referencing into and 
170     // better of not optimizing it.
171     if (!dyn_cast<ConstantInt>(GEP->getOperand(2)))
172       return 0;
173
174     // The GEP instruction, constant or instruction, must reference a global
175     // variable that is a constant and is initialized. The referenced constant
176     // initializer is the array that we'll use for optimization.
177     GlobalVariable* GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
178     if (!GV || !GV->isConstant() || !GV->hasInitializer())
179       return 0;
180
181     // Return the result
182     return GV->getInitializer();
183   }
184
185   /// A function to compute the length of a null-terminated string of integers.
186   /// This function can't rely on the size of the constant array because there 
187   /// could be a null terminator in the middle of the array. We also have to 
188   /// bail out if we find a non-integer constant initializer of one of the 
189   /// elements or if there is no null-terminator. The logic below checks
190   bool getCharArrayLength(ConstantArray* A, unsigned& len)
191   {
192     assert(A != 0 && "Invalid args to getCharArrayLength");
193     // Get the supposed length
194     unsigned max_elems = A->getType()->getNumElements();
195     len = 0;
196     // Examine all the elements
197     for (; len < max_elems; len++)
198     {
199       if (ConstantInt* CI = dyn_cast<ConstantInt>(A->getOperand(len)))
200       {
201         // Check for the null terminator
202         if (CI->isNullValue())
203           break; // we found end of string
204       }
205       else
206         return false; // This array isn't suitable, non-int initializer
207     }
208     if (len >= max_elems)
209       return false; // This array isn't null terminated
210     return true; // success!
211   }
212 }
213
214 ModulePass *llvm::createSimplifyLibCallsPass() 
215
216   return new SimplifyLibCalls(); 
217 }
218
219 bool SimplifyLibCalls::runOnModule(Module &M) 
220 {
221   bool result = false;
222
223   // The call optimizations can be recursive. That is, the optimization might
224   // generate a call to another function which can also be optimized. This way
225   // we make the CallOptimizer instances very specific to the case they handle.
226   // It also means we need to keep running over the function calls in the module
227   // until we don't get any more optimizations possible.
228   bool found_optimization = false;
229   do
230   {
231     found_optimization = false;
232     for (Module::iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI)
233     {
234       // All the "well-known" functions are external and have external linkage
235       // because they live in a runtime library somewhere and were (probably) 
236       // not compiled by LLVM.  So, we only act on external functions that have 
237       // external linkage and non-empty uses.
238       if (FI->isExternal() && FI->hasExternalLinkage() && !FI->use_empty())
239       {
240         // Get the optimization class that pertains to this function
241         if (CallOptimizer* CO = optlist[FI->getName().c_str()] )
242         {
243           // Make sure the called function is suitable for the optimization
244           if (CO->ValidateCalledFunction(FI))
245           {
246             // Loop over each of the uses of the function
247             for (Value::use_iterator UI = FI->use_begin(), UE = FI->use_end(); 
248                  UI != UE ; )
249             {
250               // If the use of the function is a call instruction
251               if (CallInst* CI = dyn_cast<CallInst>(*UI++))
252               {
253                 // Do the optimization on the CallOptimizer.
254                 if (CO->OptimizeCall(CI))
255                 {
256                   ++SimplifiedLibCalls;
257                   found_optimization = result = true;
258                 }
259               }
260             }
261           }
262         }
263       }
264     }
265   } while (found_optimization);
266   return result;
267 }
268
269 namespace {
270
271 /// This CallOptimizer will find instances of a call to "exit" that occurs
272 /// within the "main" function and change it to a simple "ret" instruction with
273 /// the same value as passed to the exit function. It assumes that the 
274 /// instructions after the call to exit(3) can be deleted since they are 
275 /// unreachable anyway.
276 /// @brief Replace calls to exit in main with a simple return
277 struct ExitInMainOptimization : public CallOptimizer
278 {
279   ExitInMainOptimization() : CallOptimizer("exit") {}
280   virtual ~ExitInMainOptimization() {}
281
282   // Make sure the called function looks like exit (int argument, int return
283   // type, external linkage, not varargs). 
284   virtual bool ValidateCalledFunction(const Function* f)
285   {
286     if (f->getReturnType()->getTypeID() == Type::VoidTyID && !f->isVarArg())
287       if (f->arg_size() == 1)
288         if (f->arg_begin()->getType()->isInteger())
289           return true;
290     return false;
291   }
292
293   virtual bool OptimizeCall(CallInst* ci)
294   {
295     // To be careful, we check that the call to exit is coming from "main", that
296     // main has external linkage, and the return type of main and the argument
297     // to exit have the same type. 
298     Function *from = ci->getParent()->getParent();
299     if (from->hasExternalLinkage())
300       if (from->getReturnType() == ci->getOperand(1)->getType())
301         if (from->getName() == "main")
302         {
303           // Okay, time to actually do the optimization. First, get the basic 
304           // block of the call instruction
305           BasicBlock* bb = ci->getParent();
306
307           // Create a return instruction that we'll replace the call with. 
308           // Note that the argument of the return is the argument of the call 
309           // instruction.
310           ReturnInst* ri = new ReturnInst(ci->getOperand(1), ci);
311
312           // Split the block at the call instruction which places it in a new
313           // basic block.
314           bb->splitBasicBlock(ci);
315
316           // The block split caused a branch instruction to be inserted into
317           // the end of the original block, right after the return instruction
318           // that we put there. That's not a valid block, so delete the branch
319           // instruction.
320           bb->getInstList().pop_back();
321
322           // Now we can finally get rid of the call instruction which now lives
323           // in the new basic block.
324           ci->eraseFromParent();
325
326           // Optimization succeeded, return true.
327           return true;
328         }
329     // We didn't pass the criteria for this optimization so return false
330     return false;
331   }
332 } ExitInMainOptimizer;
333
334 /// This CallOptimizer will simplify a call to the strcat library function. The
335 /// simplification is possible only if the string being concatenated is a 
336 /// constant array or a constant expression that results in a constant array. In
337 /// this case, if the array is small, we can generate a series of inline store
338 /// instructions to effect the concatenation without calling strcat.
339 /// @brief Simplify the strcat library function.
340 struct StrCatOptimization : public CallOptimizer
341 {
342 private:
343   Function* strlen_func;
344   Function* memcpy_func;
345 public:
346   StrCatOptimization() 
347     : CallOptimizer("strcat") 
348     , strlen_func(0)
349     , memcpy_func(0)
350     {}
351   virtual ~StrCatOptimization() {}
352
353   inline Function* get_strlen_func(Module*M)
354   {
355     if (strlen_func)
356       return strlen_func;
357     return strlen_func = M->getOrInsertFunction("strlen",get_strlen());
358   }
359
360   inline Function* get_memcpy_func(Module* M) 
361   {
362     if (memcpy_func)
363       return memcpy_func;
364     return memcpy_func = M->getOrInsertFunction("llvm.memcpy",get_memcpy());
365   }
366
367   /// @brief Make sure that the "strcat" function has the right prototype
368   virtual bool ValidateCalledFunction(const Function* f) 
369   {
370     if (f->getReturnType() == PointerType::get(Type::SByteTy))
371       if (f->arg_size() == 2) 
372       {
373         Function::const_arg_iterator AI = f->arg_begin();
374         if (AI++->getType() == PointerType::get(Type::SByteTy))
375           if (AI->getType() == PointerType::get(Type::SByteTy))
376           {
377             // Invalidate the pre-computed strlen_func and memcpy_func Functions
378             // because, by definition, this method is only called when a new
379             // Module is being traversed. Invalidation causes re-computation for
380             // the new Module (if necessary).
381             strlen_func = 0;
382             memcpy_func = 0;
383
384             // Indicate this is a suitable call type.
385             return true;
386           }
387       }
388     return false;
389   }
390
391   /// Perform the optimization if the length of the string concatenated
392   /// is reasonably short and it is a constant array.
393   virtual bool OptimizeCall(CallInst* ci)
394   {
395     // Extract the initializer (while making numerous checks) from the 
396     // source operand of the call to strcat. If we get null back, one of
397     // a variety of checks in get_GVInitializer failed
398     Constant* INTLZR = get_GVInitializer(ci->getOperand(2));
399     if (!INTLZR)
400       return false;
401
402     // Handle the ConstantArray case.
403     if (ConstantArray* A = dyn_cast<ConstantArray>(INTLZR))
404     {
405       // First off, we can't do this if the constant array isn't a string, 
406       // meaning its base type is sbyte and its constant initializers for all
407       // the elements are constantInt or constantInt expressions.
408       if (!A->isString())
409         return false;
410
411       // Now we need to examine the source string to find its actual length. We
412       // can't rely on the size of the constant array becasue there could be a
413       // null terminator in the middle of the array. We also have to bail out if
414       // we find a non-integer constant initializer of one of the elements. 
415       // Also, if we never find a terminator before the end of the array.
416       unsigned max_elems = A->getType()->getNumElements();
417       unsigned len = 0;
418       if (!getCharArrayLength(A,len))
419         return false;
420       else
421         len++; // increment for null terminator
422
423       // Extract some information from the instruction
424       Module* M = ci->getParent()->getParent()->getParent();
425
426       // We need to find the end of the destination string.  That's where the 
427       // memory is to be moved to. We just generate a call to strlen (further 
428       // optimized in another pass). Note that the get_strlen_func() call 
429       // caches the Function* for us.
430       CallInst* strlen_inst = 
431         new CallInst(get_strlen_func(M),ci->getOperand(1),"",ci);
432
433       // Now that we have the destination's length, we must index into the 
434       // destination's pointer to get the actual memcpy destination (end of
435       // the string .. we're concatenating).
436       std::vector<Value*> idx;
437       idx.push_back(strlen_inst);
438       GetElementPtrInst* gep = 
439         new GetElementPtrInst(ci->getOperand(1),idx,"",ci);
440
441       // We have enough information to now generate the memcpy call to
442       // do the concatenation for us.
443       std::vector<Value*> vals;
444       vals.push_back(gep); // destination
445       vals.push_back(ci->getOperand(2)); // source
446       vals.push_back(ConstantSInt::get(Type::IntTy,len)); // length
447       vals.push_back(ConstantSInt::get(Type::IntTy,1)); // alignment
448       CallInst* memcpy_inst = 
449         new CallInst(get_memcpy_func(M), vals, "", ci);
450
451       // Finally, substitute the first operand of the strcat call for the 
452       // strcat call itself since strcat returns its first operand; and, 
453       // kill the strcat CallInst.
454       ci->replaceAllUsesWith(ci->getOperand(1));
455       ci->eraseFromParent();
456       return true;
457     }
458
459     // Handle the ConstantAggregateZero case
460     else if (ConstantAggregateZero* CAZ = 
461         dyn_cast<ConstantAggregateZero>(INTLZR))
462     {
463       // We know this is the zero length string case so we can just avoid
464       // the strcat altogether and replace the CallInst with its first operand
465       // (what strcat returns).
466       ci->replaceAllUsesWith(ci->getOperand(1));
467       ci->eraseFromParent();
468       return true;
469     }
470
471     // We didn't pass the criteria for this optimization so return false.
472     return false;
473   }
474 } StrCatOptimizer;
475
476 /// This CallOptimizer will simplify a call to the strlen library function by
477 /// replacing it with a constant value if the string provided to it is a 
478 /// constant array.
479 /// @brief Simplify the strlen library function.
480 struct StrLenOptimization : public CallOptimizer
481 {
482   StrLenOptimization() : CallOptimizer("strlen") {}
483   virtual ~StrLenOptimization() {}
484
485   /// @brief Make sure that the "strlen" function has the right prototype
486   virtual bool ValidateCalledFunction(const Function* f)
487   {
488     if (f->getReturnType() == Type::IntTy)
489       if (f->arg_size() == 1) 
490         if (Function::const_arg_iterator AI = f->arg_begin())
491           if (AI->getType() == PointerType::get(Type::SByteTy))
492             return true;
493     return false;
494   }
495
496   /// @brief Perform the strlen optimization
497   virtual bool OptimizeCall(CallInst* ci)
498   {
499     // Extract the initializer (while making numerous checks) from the 
500     // source operand of the call to strlen. If we get null back, one of
501     // a variety of checks in get_GVInitializer failed
502     Constant* INTLZR = get_GVInitializer(ci->getOperand(1));
503     if (!INTLZR)
504       return false;
505
506     if (ConstantArray* A = dyn_cast<ConstantArray>(INTLZR))
507     {
508       unsigned len = 0;
509       if (!getCharArrayLength(A,len))
510         return false;
511       ci->replaceAllUsesWith(ConstantInt::get(Type::IntTy,len));
512       ci->eraseFromParent();
513       return true;
514     }
515
516     // Handle the ConstantAggregateZero case
517     else if (ConstantAggregateZero* CAZ = 
518         dyn_cast<ConstantAggregateZero>(INTLZR))
519     {
520       // We know this is the zero length string case so we can just avoid
521       // the strlen altogether and replace the CallInst with zero
522       ci->replaceAllUsesWith(ConstantInt::get(Type::IntTy,0));
523       ci->eraseFromParent();
524       return true;
525     }
526
527     // We didn't pass the criteria for this optimization so return false.
528     return false;
529   }
530 } StrLenOptimizer;
531
532 /// This CallOptimizer will simplify a call to the memcpy library function by
533 /// expanding it out to a small set of stores if the copy source is a constant
534 /// array. 
535 /// @brief Simplify the memcpy library function.
536 struct MemCpyOptimization : public CallOptimizer
537 {
538   MemCpyOptimization() : CallOptimizer("llvm.memcpy") {}
539   virtual ~MemCpyOptimization() {}
540
541   /// @brief Make sure that the "memcpy" function has the right prototype
542   virtual bool ValidateCalledFunction(const Function* f)
543   {
544     if (f->getReturnType() == PointerType::get(Type::SByteTy))
545       if (f->arg_size() == 2) 
546       {
547         Function::const_arg_iterator AI = f->arg_begin();
548         if (AI++->getType() == PointerType::get(Type::SByteTy))
549           if (AI++->getType() == PointerType::get(Type::SByteTy))
550             if (AI++->getType() == Type::IntTy)
551               if (AI->getType() == Type::IntTy)
552             return true;
553       }
554     return false;
555   }
556
557   /// Perform the optimization if the length of the string concatenated
558   /// is reasonably short and it is a constant array.
559   virtual bool OptimizeCall(CallInst* ci)
560   {
561     // 
562     // We didn't pass the criteria for this optimization so return false.
563     return false;
564   }
565 } MemCpyOptimizer;
566 }