instcombine: Migrate strchr and strrchr optimizations
[oota-llvm.git] / lib / Transforms / Utils / SimplifyLibCalls.cpp
1 //===------ SimplifyLibCalls.cpp - Library calls simplifier ---------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This is a utility pass used for testing the InstructionSimplify analysis.
11 // The analysis is applied to every instruction, and if it simplifies then the
12 // instruction is replaced by the simplification.  If you are looking for a pass
13 // that performs serious instruction folding, use the instcombine pass instead.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "llvm/Transforms/Utils/SimplifyLibCalls.h"
18 #include "llvm/DataLayout.h"
19 #include "llvm/ADT/StringMap.h"
20 #include "llvm/Analysis/ValueTracking.h"
21 #include "llvm/Function.h"
22 #include "llvm/IRBuilder.h"
23 #include "llvm/LLVMContext.h"
24 #include "llvm/Target/TargetLibraryInfo.h"
25 #include "llvm/Transforms/Utils/BuildLibCalls.h"
26
27 using namespace llvm;
28
29 /// This class is the abstract base class for the set of optimizations that
30 /// corresponds to one library call.
31 namespace {
32 class LibCallOptimization {
33 protected:
34   Function *Caller;
35   const DataLayout *TD;
36   const TargetLibraryInfo *TLI;
37   LLVMContext* Context;
38 public:
39   LibCallOptimization() { }
40   virtual ~LibCallOptimization() {}
41
42   /// callOptimizer - This pure virtual method is implemented by base classes to
43   /// do various optimizations.  If this returns null then no transformation was
44   /// performed.  If it returns CI, then it transformed the call and CI is to be
45   /// deleted.  If it returns something else, replace CI with the new value and
46   /// delete CI.
47   virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B)
48     =0;
49
50   Value *optimizeCall(CallInst *CI, const DataLayout *TD,
51                       const TargetLibraryInfo *TLI, IRBuilder<> &B) {
52     Caller = CI->getParent()->getParent();
53     this->TD = TD;
54     this->TLI = TLI;
55     if (CI->getCalledFunction())
56       Context = &CI->getCalledFunction()->getContext();
57
58     // We never change the calling convention.
59     if (CI->getCallingConv() != llvm::CallingConv::C)
60       return NULL;
61
62     return callOptimizer(CI->getCalledFunction(), CI, B);
63   }
64 };
65
66 //===----------------------------------------------------------------------===//
67 // Fortified Library Call Optimizations
68 //===----------------------------------------------------------------------===//
69
70 struct FortifiedLibCallOptimization : public LibCallOptimization {
71 protected:
72   virtual bool isFoldable(unsigned SizeCIOp, unsigned SizeArgOp,
73                           bool isString) const = 0;
74 };
75
76 struct InstFortifiedLibCallOptimization : public FortifiedLibCallOptimization {
77   CallInst *CI;
78
79   bool isFoldable(unsigned SizeCIOp, unsigned SizeArgOp, bool isString) const {
80     if (CI->getArgOperand(SizeCIOp) == CI->getArgOperand(SizeArgOp))
81       return true;
82     if (ConstantInt *SizeCI =
83                            dyn_cast<ConstantInt>(CI->getArgOperand(SizeCIOp))) {
84       if (SizeCI->isAllOnesValue())
85         return true;
86       if (isString) {
87         uint64_t Len = GetStringLength(CI->getArgOperand(SizeArgOp));
88         // If the length is 0 we don't know how long it is and so we can't
89         // remove the check.
90         if (Len == 0) return false;
91         return SizeCI->getZExtValue() >= Len;
92       }
93       if (ConstantInt *Arg = dyn_cast<ConstantInt>(
94                                                   CI->getArgOperand(SizeArgOp)))
95         return SizeCI->getZExtValue() >= Arg->getZExtValue();
96     }
97     return false;
98   }
99 };
100
101 struct MemCpyChkOpt : public InstFortifiedLibCallOptimization {
102   virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
103     this->CI = CI;
104     FunctionType *FT = Callee->getFunctionType();
105     LLVMContext &Context = CI->getParent()->getContext();
106
107     // Check if this has the right signature.
108     if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) ||
109         !FT->getParamType(0)->isPointerTy() ||
110         !FT->getParamType(1)->isPointerTy() ||
111         FT->getParamType(2) != TD->getIntPtrType(Context) ||
112         FT->getParamType(3) != TD->getIntPtrType(Context))
113       return 0;
114
115     if (isFoldable(3, 2, false)) {
116       B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1),
117                      CI->getArgOperand(2), 1);
118       return CI->getArgOperand(0);
119     }
120     return 0;
121   }
122 };
123
124 struct MemMoveChkOpt : public InstFortifiedLibCallOptimization {
125   virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
126     this->CI = CI;
127     FunctionType *FT = Callee->getFunctionType();
128     LLVMContext &Context = CI->getParent()->getContext();
129
130     // Check if this has the right signature.
131     if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) ||
132         !FT->getParamType(0)->isPointerTy() ||
133         !FT->getParamType(1)->isPointerTy() ||
134         FT->getParamType(2) != TD->getIntPtrType(Context) ||
135         FT->getParamType(3) != TD->getIntPtrType(Context))
136       return 0;
137
138     if (isFoldable(3, 2, false)) {
139       B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1),
140                       CI->getArgOperand(2), 1);
141       return CI->getArgOperand(0);
142     }
143     return 0;
144   }
145 };
146
147 struct MemSetChkOpt : public InstFortifiedLibCallOptimization {
148   virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
149     this->CI = CI;
150     FunctionType *FT = Callee->getFunctionType();
151     LLVMContext &Context = CI->getParent()->getContext();
152
153     // Check if this has the right signature.
154     if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) ||
155         !FT->getParamType(0)->isPointerTy() ||
156         !FT->getParamType(1)->isIntegerTy() ||
157         FT->getParamType(2) != TD->getIntPtrType(Context) ||
158         FT->getParamType(3) != TD->getIntPtrType(Context))
159       return 0;
160
161     if (isFoldable(3, 2, false)) {
162       Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(),
163                                    false);
164       B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1);
165       return CI->getArgOperand(0);
166     }
167     return 0;
168   }
169 };
170
171 struct StrCpyChkOpt : public InstFortifiedLibCallOptimization {
172   virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
173     this->CI = CI;
174     StringRef Name = Callee->getName();
175     FunctionType *FT = Callee->getFunctionType();
176     LLVMContext &Context = CI->getParent()->getContext();
177
178     // Check if this has the right signature.
179     if (FT->getNumParams() != 3 ||
180         FT->getReturnType() != FT->getParamType(0) ||
181         FT->getParamType(0) != FT->getParamType(1) ||
182         FT->getParamType(0) != Type::getInt8PtrTy(Context) ||
183         FT->getParamType(2) != TD->getIntPtrType(Context))
184       return 0;
185
186     // If a) we don't have any length information, or b) we know this will
187     // fit then just lower to a plain st[rp]cpy. Otherwise we'll keep our
188     // st[rp]cpy_chk call which may fail at runtime if the size is too long.
189     // TODO: It might be nice to get a maximum length out of the possible
190     // string lengths for varying.
191     if (isFoldable(2, 1, true)) {
192       Value *Ret = EmitStrCpy(CI->getArgOperand(0), CI->getArgOperand(1), B, TD,
193                               TLI, Name.substr(2, 6));
194       return Ret;
195     }
196     return 0;
197   }
198 };
199
200 struct StrNCpyChkOpt : public InstFortifiedLibCallOptimization {
201   virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
202     this->CI = CI;
203     StringRef Name = Callee->getName();
204     FunctionType *FT = Callee->getFunctionType();
205     LLVMContext &Context = CI->getParent()->getContext();
206
207     // Check if this has the right signature.
208     if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) ||
209         FT->getParamType(0) != FT->getParamType(1) ||
210         FT->getParamType(0) != Type::getInt8PtrTy(Context) ||
211         !FT->getParamType(2)->isIntegerTy() ||
212         FT->getParamType(3) != TD->getIntPtrType(Context))
213       return 0;
214
215     if (isFoldable(3, 2, false)) {
216       Value *Ret = EmitStrNCpy(CI->getArgOperand(0), CI->getArgOperand(1),
217                                CI->getArgOperand(2), B, TD, TLI,
218                                Name.substr(2, 7));
219       return Ret;
220     }
221     return 0;
222   }
223 };
224
225 //===----------------------------------------------------------------------===//
226 // String and Memory Library Call Optimizations
227 //===----------------------------------------------------------------------===//
228
229 struct StrCatOpt : public LibCallOptimization {
230   virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
231     // Verify the "strcat" function prototype.
232     FunctionType *FT = Callee->getFunctionType();
233     if (FT->getNumParams() != 2 ||
234         FT->getReturnType() != B.getInt8PtrTy() ||
235         FT->getParamType(0) != FT->getReturnType() ||
236         FT->getParamType(1) != FT->getReturnType())
237       return 0;
238
239     // Extract some information from the instruction
240     Value *Dst = CI->getArgOperand(0);
241     Value *Src = CI->getArgOperand(1);
242
243     // See if we can get the length of the input string.
244     uint64_t Len = GetStringLength(Src);
245     if (Len == 0) return 0;
246     --Len;  // Unbias length.
247
248     // Handle the simple, do-nothing case: strcat(x, "") -> x
249     if (Len == 0)
250       return Dst;
251
252     // These optimizations require DataLayout.
253     if (!TD) return 0;
254
255     return emitStrLenMemCpy(Src, Dst, Len, B);
256   }
257
258   Value *emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len,
259                           IRBuilder<> &B) {
260     // We need to find the end of the destination string.  That's where the
261     // memory is to be moved to. We just generate a call to strlen.
262     Value *DstLen = EmitStrLen(Dst, B, TD, TLI);
263     if (!DstLen)
264       return 0;
265
266     // Now that we have the destination's length, we must index into the
267     // destination's pointer to get the actual memcpy destination (end of
268     // the string .. we're concatenating).
269     Value *CpyDst = B.CreateGEP(Dst, DstLen, "endptr");
270
271     // We have enough information to now generate the memcpy call to do the
272     // concatenation for us.  Make a memcpy to copy the nul byte with align = 1.
273     B.CreateMemCpy(CpyDst, Src,
274                    ConstantInt::get(TD->getIntPtrType(*Context), Len + 1), 1);
275     return Dst;
276   }
277 };
278
279 struct StrNCatOpt : public StrCatOpt {
280   virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
281     // Verify the "strncat" function prototype.
282     FunctionType *FT = Callee->getFunctionType();
283     if (FT->getNumParams() != 3 ||
284         FT->getReturnType() != B.getInt8PtrTy() ||
285         FT->getParamType(0) != FT->getReturnType() ||
286         FT->getParamType(1) != FT->getReturnType() ||
287         !FT->getParamType(2)->isIntegerTy())
288       return 0;
289
290     // Extract some information from the instruction
291     Value *Dst = CI->getArgOperand(0);
292     Value *Src = CI->getArgOperand(1);
293     uint64_t Len;
294
295     // We don't do anything if length is not constant
296     if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2)))
297       Len = LengthArg->getZExtValue();
298     else
299       return 0;
300
301     // See if we can get the length of the input string.
302     uint64_t SrcLen = GetStringLength(Src);
303     if (SrcLen == 0) return 0;
304     --SrcLen;  // Unbias length.
305
306     // Handle the simple, do-nothing cases:
307     // strncat(x, "", c) -> x
308     // strncat(x,  c, 0) -> x
309     if (SrcLen == 0 || Len == 0) return Dst;
310
311     // These optimizations require DataLayout.
312     if (!TD) return 0;
313
314     // We don't optimize this case
315     if (Len < SrcLen) return 0;
316
317     // strncat(x, s, c) -> strcat(x, s)
318     // s is constant so the strcat can be optimized further
319     return emitStrLenMemCpy(Src, Dst, SrcLen, B);
320   }
321 };
322
323 struct StrChrOpt : public LibCallOptimization {
324   virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
325     // Verify the "strchr" function prototype.
326     FunctionType *FT = Callee->getFunctionType();
327     if (FT->getNumParams() != 2 ||
328         FT->getReturnType() != B.getInt8PtrTy() ||
329         FT->getParamType(0) != FT->getReturnType() ||
330         !FT->getParamType(1)->isIntegerTy(32))
331       return 0;
332
333     Value *SrcStr = CI->getArgOperand(0);
334
335     // If the second operand is non-constant, see if we can compute the length
336     // of the input string and turn this into memchr.
337     ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1));
338     if (CharC == 0) {
339       // These optimizations require DataLayout.
340       if (!TD) return 0;
341
342       uint64_t Len = GetStringLength(SrcStr);
343       if (Len == 0 || !FT->getParamType(1)->isIntegerTy(32))// memchr needs i32.
344         return 0;
345
346       return EmitMemChr(SrcStr, CI->getArgOperand(1), // include nul.
347                         ConstantInt::get(TD->getIntPtrType(*Context), Len),
348                         B, TD, TLI);
349     }
350
351     // Otherwise, the character is a constant, see if the first argument is
352     // a string literal.  If so, we can constant fold.
353     StringRef Str;
354     if (!getConstantStringInfo(SrcStr, Str))
355       return 0;
356
357     // Compute the offset, make sure to handle the case when we're searching for
358     // zero (a weird way to spell strlen).
359     size_t I = CharC->getSExtValue() == 0 ?
360         Str.size() : Str.find(CharC->getSExtValue());
361     if (I == StringRef::npos) // Didn't find the char.  strchr returns null.
362       return Constant::getNullValue(CI->getType());
363
364     // strchr(s+n,c)  -> gep(s+n+i,c)
365     return B.CreateGEP(SrcStr, B.getInt64(I), "strchr");
366   }
367 };
368
369 struct StrRChrOpt : public LibCallOptimization {
370   virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
371     // Verify the "strrchr" function prototype.
372     FunctionType *FT = Callee->getFunctionType();
373     if (FT->getNumParams() != 2 ||
374         FT->getReturnType() != B.getInt8PtrTy() ||
375         FT->getParamType(0) != FT->getReturnType() ||
376         !FT->getParamType(1)->isIntegerTy(32))
377       return 0;
378
379     Value *SrcStr = CI->getArgOperand(0);
380     ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1));
381
382     // Cannot fold anything if we're not looking for a constant.
383     if (!CharC)
384       return 0;
385
386     StringRef Str;
387     if (!getConstantStringInfo(SrcStr, Str)) {
388       // strrchr(s, 0) -> strchr(s, 0)
389       if (TD && CharC->isZero())
390         return EmitStrChr(SrcStr, '\0', B, TD, TLI);
391       return 0;
392     }
393
394     // Compute the offset.
395     size_t I = CharC->getSExtValue() == 0 ?
396         Str.size() : Str.rfind(CharC->getSExtValue());
397     if (I == StringRef::npos) // Didn't find the char. Return null.
398       return Constant::getNullValue(CI->getType());
399
400     // strrchr(s+n,c) -> gep(s+n+i,c)
401     return B.CreateGEP(SrcStr, B.getInt64(I), "strrchr");
402   }
403 };
404
405 } // End anonymous namespace.
406
407 namespace llvm {
408
409 class LibCallSimplifierImpl {
410   LibCallSimplifier *Simplifier;
411   const DataLayout *TD;
412   const TargetLibraryInfo *TLI;
413   StringMap<LibCallOptimization*> Optimizations;
414
415   // Fortified library call optimizations.
416   MemCpyChkOpt MemCpyChk;
417   MemMoveChkOpt MemMoveChk;
418   MemSetChkOpt MemSetChk;
419   StrCpyChkOpt StrCpyChk;
420   StrNCpyChkOpt StrNCpyChk;
421
422   // String and memory library call optimizations.
423   StrCatOpt StrCat;
424   StrNCatOpt StrNCat;
425   StrChrOpt StrChr;
426   StrRChrOpt StrRChr;
427
428   void initOptimizations();
429 public:
430   LibCallSimplifierImpl(const DataLayout *TD, const TargetLibraryInfo *TLI) {
431     this->TD = TD;
432     this->TLI = TLI;
433   }
434
435   Value *optimizeCall(CallInst *CI);
436 };
437
438 void LibCallSimplifierImpl::initOptimizations() {
439   // Fortified library call optimizations.
440   Optimizations["__memcpy_chk"] = &MemCpyChk;
441   Optimizations["__memmove_chk"] = &MemMoveChk;
442   Optimizations["__memset_chk"] = &MemSetChk;
443   Optimizations["__strcpy_chk"] = &StrCpyChk;
444   Optimizations["__stpcpy_chk"] = &StrCpyChk;
445   Optimizations["__strncpy_chk"] = &StrNCpyChk;
446   Optimizations["__stpncpy_chk"] = &StrNCpyChk;
447
448   // String and memory library call optimizations.
449   Optimizations["strcat"] = &StrCat;
450   Optimizations["strncat"] = &StrNCat;
451   Optimizations["strchr"] = &StrChr;
452   Optimizations["strrchr"] = &StrRChr;
453 }
454
455 Value *LibCallSimplifierImpl::optimizeCall(CallInst *CI) {
456   if (Optimizations.empty())
457     initOptimizations();
458
459   Function *Callee = CI->getCalledFunction();
460   LibCallOptimization *LCO = Optimizations.lookup(Callee->getName());
461   if (LCO) {
462     IRBuilder<> Builder(CI);
463     return LCO->optimizeCall(CI, TD, TLI, Builder);
464   }
465   return 0;
466 }
467
468 LibCallSimplifier::LibCallSimplifier(const DataLayout *TD,
469                                      const TargetLibraryInfo *TLI) {
470   Impl = new LibCallSimplifierImpl(TD, TLI);
471 }
472
473 LibCallSimplifier::~LibCallSimplifier() {
474   delete Impl;
475 }
476
477 Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
478   return Impl->optimizeCall(CI);
479 }
480
481 }