instcombine: Migrate fwrite optimizations
[oota-llvm.git] / lib / Transforms / Utils / SimplifyLibCalls.cpp
index 27ccdc1095ec01cc410d09d76b85ee5d4d3dd5f5..00a5d89cb2e5674357cc26efc78de7d364656283 100644 (file)
@@ -102,6 +102,15 @@ static bool isOnlyUsedInEqualityComparison(Value *V, Value *With) {
   return true;
 }
 
+static bool callHasFloatingPointArgument(const CallInst *CI) {
+  for (CallInst::const_op_iterator it = CI->op_begin(), e = CI->op_end();
+       it != e; ++it) {
+    if ((*it)->getType()->isFloatingPointTy())
+      return true;
+  }
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // Fortified Library Call Optimizations
 //===----------------------------------------------------------------------===//
@@ -1306,12 +1315,301 @@ struct ToAsciiOpt : public LibCallOptimization {
         !FT->getParamType(0)->isIntegerTy(32))
       return 0;
 
-    // isascii(c) -> c & 0x7f
+    // toascii(c) -> c & 0x7f
     return B.CreateAnd(CI->getArgOperand(0),
                        ConstantInt::get(CI->getType(),0x7F));
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Formatting and IO Library Call Optimizations
+//===----------------------------------------------------------------------===//
+
+struct PrintFOpt : public LibCallOptimization {
+  Value *optimizeFixedFormatString(Function *Callee, CallInst *CI,
+                                   IRBuilder<> &B) {
+    // Check for a fixed format string.
+    StringRef FormatStr;
+    if (!getConstantStringInfo(CI->getArgOperand(0), FormatStr))
+      return 0;
+
+    // Empty format string -> noop.
+    if (FormatStr.empty())  // Tolerate printf's declared void.
+      return CI->use_empty() ? (Value*)CI :
+                               ConstantInt::get(CI->getType(), 0);
+
+    // Do not do any of the following transformations if the printf return value
+    // is used, in general the printf return value is not compatible with either
+    // putchar() or puts().
+    if (!CI->use_empty())
+      return 0;
+
+    // printf("x") -> putchar('x'), even for '%'.
+    if (FormatStr.size() == 1) {
+      Value *Res = EmitPutChar(B.getInt32(FormatStr[0]), B, TD, TLI);
+      if (CI->use_empty() || !Res) return Res;
+      return B.CreateIntCast(Res, CI->getType(), true);
+    }
+
+    // printf("foo\n") --> puts("foo")
+    if (FormatStr[FormatStr.size()-1] == '\n' &&
+        FormatStr.find('%') == std::string::npos) {  // no format characters.
+      // Create a string literal with no \n on it.  We expect the constant merge
+      // pass to be run after this pass, to merge duplicate strings.
+      FormatStr = FormatStr.drop_back();
+      Value *GV = B.CreateGlobalString(FormatStr, "str");
+      Value *NewCI = EmitPutS(GV, B, TD, TLI);
+      return (CI->use_empty() || !NewCI) ?
+              NewCI :
+              ConstantInt::get(CI->getType(), FormatStr.size()+1);
+    }
+
+    // Optimize specific format strings.
+    // printf("%c", chr) --> putchar(chr)
+    if (FormatStr == "%c" && CI->getNumArgOperands() > 1 &&
+        CI->getArgOperand(1)->getType()->isIntegerTy()) {
+      Value *Res = EmitPutChar(CI->getArgOperand(1), B, TD, TLI);
+
+      if (CI->use_empty() || !Res) return Res;
+      return B.CreateIntCast(Res, CI->getType(), true);
+    }
+
+    // printf("%s\n", str) --> puts(str)
+    if (FormatStr == "%s\n" && CI->getNumArgOperands() > 1 &&
+        CI->getArgOperand(1)->getType()->isPointerTy()) {
+      return EmitPutS(CI->getArgOperand(1), B, TD, TLI);
+    }
+    return 0;
+  }
+
+  virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+    // Require one fixed pointer argument and an integer/void result.
+    FunctionType *FT = Callee->getFunctionType();
+    if (FT->getNumParams() < 1 || !FT->getParamType(0)->isPointerTy() ||
+        !(FT->getReturnType()->isIntegerTy() ||
+          FT->getReturnType()->isVoidTy()))
+      return 0;
+
+    if (Value *V = optimizeFixedFormatString(Callee, CI, B)) {
+      return V;
+    }
+
+    // printf(format, ...) -> iprintf(format, ...) if no floating point
+    // arguments.
+    if (TLI->has(LibFunc::iprintf) && !callHasFloatingPointArgument(CI)) {
+      Module *M = B.GetInsertBlock()->getParent()->getParent();
+      Constant *IPrintFFn =
+        M->getOrInsertFunction("iprintf", FT, Callee->getAttributes());
+      CallInst *New = cast<CallInst>(CI->clone());
+      New->setCalledFunction(IPrintFFn);
+      B.Insert(New);
+      return New;
+    }
+    return 0;
+  }
+};
+
+struct SPrintFOpt : public LibCallOptimization {
+  Value *OptimizeFixedFormatString(Function *Callee, CallInst *CI,
+                                   IRBuilder<> &B) {
+    // Check for a fixed format string.
+    StringRef FormatStr;
+    if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr))
+      return 0;
+
+    // If we just have a format string (nothing else crazy) transform it.
+    if (CI->getNumArgOperands() == 2) {
+      // Make sure there's no % in the constant array.  We could try to handle
+      // %% -> % in the future if we cared.
+      for (unsigned i = 0, e = FormatStr.size(); i != e; ++i)
+        if (FormatStr[i] == '%')
+          return 0; // we found a format specifier, bail out.
+
+      // These optimizations require DataLayout.
+      if (!TD) return 0;
+
+      // sprintf(str, fmt) -> llvm.memcpy(str, fmt, strlen(fmt)+1, 1)
+      B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1),
+                     ConstantInt::get(TD->getIntPtrType(*Context), // Copy the
+                                      FormatStr.size() + 1), 1);   // nul byte.
+      return ConstantInt::get(CI->getType(), FormatStr.size());
+    }
+
+    // The remaining optimizations require the format string to be "%s" or "%c"
+    // and have an extra operand.
+    if (FormatStr.size() != 2 || FormatStr[0] != '%' ||
+        CI->getNumArgOperands() < 3)
+      return 0;
+
+    // Decode the second character of the format string.
+    if (FormatStr[1] == 'c') {
+      // sprintf(dst, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0
+      if (!CI->getArgOperand(2)->getType()->isIntegerTy()) return 0;
+      Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char");
+      Value *Ptr = CastToCStr(CI->getArgOperand(0), B);
+      B.CreateStore(V, Ptr);
+      Ptr = B.CreateGEP(Ptr, B.getInt32(1), "nul");
+      B.CreateStore(B.getInt8(0), Ptr);
+
+      return ConstantInt::get(CI->getType(), 1);
+    }
+
+    if (FormatStr[1] == 's') {
+      // These optimizations require DataLayout.
+      if (!TD) return 0;
+
+      // sprintf(dest, "%s", str) -> llvm.memcpy(dest, str, strlen(str)+1, 1)
+      if (!CI->getArgOperand(2)->getType()->isPointerTy()) return 0;
+
+      Value *Len = EmitStrLen(CI->getArgOperand(2), B, TD, TLI);
+      if (!Len)
+        return 0;
+      Value *IncLen = B.CreateAdd(Len,
+                                  ConstantInt::get(Len->getType(), 1),
+                                  "leninc");
+      B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(2), IncLen, 1);
+
+      // The sprintf result is the unincremented number of bytes in the string.
+      return B.CreateIntCast(Len, CI->getType(), false);
+    }
+    return 0;
+  }
+
+  virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+    // Require two fixed pointer arguments and an integer result.
+    FunctionType *FT = Callee->getFunctionType();
+    if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() ||
+        !FT->getParamType(1)->isPointerTy() ||
+        !FT->getReturnType()->isIntegerTy())
+      return 0;
+
+    if (Value *V = OptimizeFixedFormatString(Callee, CI, B)) {
+      return V;
+    }
+
+    // sprintf(str, format, ...) -> siprintf(str, format, ...) if no floating
+    // point arguments.
+    if (TLI->has(LibFunc::siprintf) && !callHasFloatingPointArgument(CI)) {
+      Module *M = B.GetInsertBlock()->getParent()->getParent();
+      Constant *SIPrintFFn =
+        M->getOrInsertFunction("siprintf", FT, Callee->getAttributes());
+      CallInst *New = cast<CallInst>(CI->clone());
+      New->setCalledFunction(SIPrintFFn);
+      B.Insert(New);
+      return New;
+    }
+    return 0;
+  }
+};
+
+struct FPrintFOpt : public LibCallOptimization {
+  Value *optimizeFixedFormatString(Function *Callee, CallInst *CI,
+                                   IRBuilder<> &B) {
+    // All the optimizations depend on the format string.
+    StringRef FormatStr;
+    if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr))
+      return 0;
+
+    // fprintf(F, "foo") --> fwrite("foo", 3, 1, F)
+    if (CI->getNumArgOperands() == 2) {
+      for (unsigned i = 0, e = FormatStr.size(); i != e; ++i)
+        if (FormatStr[i] == '%')  // Could handle %% -> % if we cared.
+          return 0; // We found a format specifier.
+
+      // These optimizations require DataLayout.
+      if (!TD) return 0;
+
+      Value *NewCI = EmitFWrite(CI->getArgOperand(1),
+                                ConstantInt::get(TD->getIntPtrType(*Context),
+                                                 FormatStr.size()),
+                                CI->getArgOperand(0), B, TD, TLI);
+      return NewCI ? ConstantInt::get(CI->getType(), FormatStr.size()) : 0;
+    }
+
+    // The remaining optimizations require the format string to be "%s" or "%c"
+    // and have an extra operand.
+    if (FormatStr.size() != 2 || FormatStr[0] != '%' ||
+        CI->getNumArgOperands() < 3)
+      return 0;
+
+    // Decode the second character of the format string.
+    if (FormatStr[1] == 'c') {
+      // fprintf(F, "%c", chr) --> fputc(chr, F)
+      if (!CI->getArgOperand(2)->getType()->isIntegerTy()) return 0;
+      Value *NewCI = EmitFPutC(CI->getArgOperand(2), CI->getArgOperand(0), B,
+                               TD, TLI);
+      return NewCI ? ConstantInt::get(CI->getType(), 1) : 0;
+    }
+
+    if (FormatStr[1] == 's') {
+      // fprintf(F, "%s", str) --> fputs(str, F)
+      if (!CI->getArgOperand(2)->getType()->isPointerTy() || !CI->use_empty())
+        return 0;
+      return EmitFPutS(CI->getArgOperand(2), CI->getArgOperand(0), B, TD, TLI);
+    }
+    return 0;
+  }
+
+  virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+    // Require two fixed paramters as pointers and integer result.
+    FunctionType *FT = Callee->getFunctionType();
+    if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() ||
+        !FT->getParamType(1)->isPointerTy() ||
+        !FT->getReturnType()->isIntegerTy())
+      return 0;
+
+    if (Value *V = optimizeFixedFormatString(Callee, CI, B)) {
+      return V;
+    }
+
+    // fprintf(stream, format, ...) -> fiprintf(stream, format, ...) if no
+    // floating point arguments.
+    if (TLI->has(LibFunc::fiprintf) && !callHasFloatingPointArgument(CI)) {
+      Module *M = B.GetInsertBlock()->getParent()->getParent();
+      Constant *FIPrintFFn =
+        M->getOrInsertFunction("fiprintf", FT, Callee->getAttributes());
+      CallInst *New = cast<CallInst>(CI->clone());
+      New->setCalledFunction(FIPrintFFn);
+      B.Insert(New);
+      return New;
+    }
+    return 0;
+  }
+};
+
+struct FWriteOpt : public LibCallOptimization {
+  virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+    // Require a pointer, an integer, an integer, a pointer, returning integer.
+    FunctionType *FT = Callee->getFunctionType();
+    if (FT->getNumParams() != 4 || !FT->getParamType(0)->isPointerTy() ||
+        !FT->getParamType(1)->isIntegerTy() ||
+        !FT->getParamType(2)->isIntegerTy() ||
+        !FT->getParamType(3)->isPointerTy() ||
+        !FT->getReturnType()->isIntegerTy())
+      return 0;
+
+    // Get the element size and count.
+    ConstantInt *SizeC = dyn_cast<ConstantInt>(CI->getArgOperand(1));
+    ConstantInt *CountC = dyn_cast<ConstantInt>(CI->getArgOperand(2));
+    if (!SizeC || !CountC) return 0;
+    uint64_t Bytes = SizeC->getZExtValue()*CountC->getZExtValue();
+
+    // If this is writing zero records, remove the call (it's a noop).
+    if (Bytes == 0)
+      return ConstantInt::get(CI->getType(), 0);
+
+    // If this is writing one byte, turn it into fputc.
+    // This optimisation is only valid, if the return value is unused.
+    if (Bytes == 1 && CI->use_empty()) {  // fwrite(S,1,1,F) -> fputc(S[0],F)
+      Value *Char = B.CreateLoad(CastToCStr(CI->getArgOperand(0), B), "char");
+      Value *NewCI = EmitFPutC(Char, CI->getArgOperand(3), B, TD, TLI);
+      return NewCI ? ConstantInt::get(CI->getType(), 1) : 0;
+    }
+
+    return 0;
+  }
+};
+
 } // End anonymous namespace.
 
 namespace llvm {
@@ -1365,6 +1663,12 @@ class LibCallSimplifierImpl {
   IsAsciiOpt IsAscii;
   ToAsciiOpt ToAscii;
 
+  // Formatting and IO library call optimizations.
+  PrintFOpt PrintF;
+  SPrintFOpt SPrintF;
+  FPrintFOpt FPrintF;
+  FWriteOpt FWrite;
+
   void initOptimizations();
   void addOpt(LibFunc::Func F, LibCallOptimization* Opt);
   void addOpt(LibFunc::Func F1, LibFunc::Func F2, LibCallOptimization* Opt);
@@ -1485,6 +1789,12 @@ void LibCallSimplifierImpl::initOptimizations() {
   addOpt(LibFunc::isdigit, &IsDigit);
   addOpt(LibFunc::isascii, &IsAscii);
   addOpt(LibFunc::toascii, &ToAscii);
+
+  // Formatting and IO library call optimizations.
+  addOpt(LibFunc::printf, &PrintF);
+  addOpt(LibFunc::sprintf, &SPrintF);
+  addOpt(LibFunc::fprintf, &FPrintF);
+  addOpt(LibFunc::fwrite, &FWrite);
 }
 
 Value *LibCallSimplifierImpl::optimizeCall(CallInst *CI) {