instcombine: Migrate strncpy optimizations
[oota-llvm.git] / lib / Transforms / Utils / SimplifyLibCalls.cpp
index 5984a96cd0fd86fec5d7bcfa06bd0e5f83f2a168..cc03573326e883495f11dbb704c404a084071ec7 100644 (file)
@@ -592,6 +592,89 @@ struct StrCpyOpt : public LibCallOptimization {
   }
 };
 
+struct StpCpyOpt: public LibCallOptimization {
+  virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+    // Verify the "stpcpy" function prototype.
+    FunctionType *FT = Callee->getFunctionType();
+    if (FT->getNumParams() != 2 ||
+        FT->getReturnType() != FT->getParamType(0) ||
+        FT->getParamType(0) != FT->getParamType(1) ||
+        FT->getParamType(0) != B.getInt8PtrTy())
+      return 0;
+
+    // These optimizations require DataLayout.
+    if (!TD) return 0;
+
+    Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1);
+    if (Dst == Src) {  // stpcpy(x,x)  -> x+strlen(x)
+      Value *StrLen = EmitStrLen(Src, B, TD, TLI);
+      return StrLen ? B.CreateInBoundsGEP(Dst, StrLen) : 0;
+    }
+
+    // See if we can get the length of the input string.
+    uint64_t Len = GetStringLength(Src);
+    if (Len == 0) return 0;
+
+    Type *PT = FT->getParamType(0);
+    Value *LenV = ConstantInt::get(TD->getIntPtrType(PT), Len);
+    Value *DstEnd = B.CreateGEP(Dst,
+                                ConstantInt::get(TD->getIntPtrType(PT),
+                                                 Len - 1));
+
+    // We have enough information to now generate the memcpy call to do the
+    // copy for us.  Make a memcpy to copy the nul byte with align = 1.
+    B.CreateMemCpy(Dst, Src, LenV, 1);
+    return DstEnd;
+  }
+};
+
+struct StrNCpyOpt : public LibCallOptimization {
+  virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+    FunctionType *FT = Callee->getFunctionType();
+    if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) ||
+        FT->getParamType(0) != FT->getParamType(1) ||
+        FT->getParamType(0) != B.getInt8PtrTy() ||
+        !FT->getParamType(2)->isIntegerTy())
+      return 0;
+
+    Value *Dst = CI->getArgOperand(0);
+    Value *Src = CI->getArgOperand(1);
+    Value *LenOp = CI->getArgOperand(2);
+
+    // See if we can get the length of the input string.
+    uint64_t SrcLen = GetStringLength(Src);
+    if (SrcLen == 0) return 0;
+    --SrcLen;
+
+    if (SrcLen == 0) {
+      // strncpy(x, "", y) -> memset(x, '\0', y, 1)
+      B.CreateMemSet(Dst, B.getInt8('\0'), LenOp, 1);
+      return Dst;
+    }
+
+    uint64_t Len;
+    if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(LenOp))
+      Len = LengthArg->getZExtValue();
+    else
+      return 0;
+
+    if (Len == 0) return Dst; // strncpy(x, y, 0) -> x
+
+    // These optimizations require DataLayout.
+    if (!TD) return 0;
+
+    // Let strncpy handle the zero padding
+    if (Len > SrcLen+1) return 0;
+
+    Type *PT = FT->getParamType(0);
+    // strncpy(x, s, c) -> memcpy(x, s, c, 1) [s and c are constant]
+    B.CreateMemCpy(Dst, Src,
+                   ConstantInt::get(TD->getIntPtrType(PT), Len), 1);
+
+    return Dst;
+  }
+};
+
 } // End anonymous namespace.
 
 namespace llvm {
@@ -617,6 +700,8 @@ class LibCallSimplifierImpl {
   StrCmpOpt StrCmp;
   StrNCmpOpt StrNCmp;
   StrCpyOpt StrCpy;
+  StpCpyOpt StpCpy;
+  StrNCpyOpt StrNCpy;
 
   void initOptimizations();
 public:
@@ -646,6 +731,8 @@ void LibCallSimplifierImpl::initOptimizations() {
   Optimizations["strcmp"] = &StrCmp;
   Optimizations["strncmp"] = &StrNCmp;
   Optimizations["strcpy"] = &StrCpy;
+  Optimizations["stpcpy"] = &StpCpy;
+  Optimizations["strncpy"] = &StrNCpy;
 }
 
 Value *LibCallSimplifierImpl::optimizeCall(CallInst *CI) {