From: Reid Spencer Date: Wed, 4 May 2005 03:20:21 +0000 (+0000) Subject: * Correct the function prototypes for some of the functions to match the X-Git-Url: http://plrg.eecs.uci.edu/git/?a=commitdiff_plain;h=58b563ce4351996dd403646231fed40794a0aa4c;p=oota-llvm.git * Correct the function prototypes for some of the functions to match the actual spec (int -> uint) * Add the ability to get/cache the strlen function prototype. * Make sure generated values are appropriately named for debugging purposes * Add the SPrintFOptimiation for 4 casts of sprintf optimization: sprintf(str,cstr) -> llvm.memcpy(str,cstr) (if cstr has no %) sprintf(str,"") -> store sbyte 0, str sprintf(str,"%s",src) -> llvm.memcpy(str,src) (if src is constant) sprintf(str,"%c",chr) -> store chr, str ; store sbyte 0, str+1 The sprintf optimization didn't fire as much as I had hoped: 2 MultiSource/Applications/SPASS 5 MultiSource/Benchmarks/McCat/18-imp 22 MultiSource/Benchmarks/Prolangs-C/TimberWolfMC 1 MultiSource/Benchmarks/Prolangs-C/assembler 6 MultiSource/Benchmarks/Prolangs-C/unix-smail 2 MultiSource/Benchmarks/mediabench/mpeg2/mpeg2dec git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@21679 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Transforms/IPO/SimplifyLibCalls.cpp b/lib/Transforms/IPO/SimplifyLibCalls.cpp index e6978100be9..4c1fe29c2a3 100644 --- a/lib/Transforms/IPO/SimplifyLibCalls.cpp +++ b/lib/Transforms/IPO/SimplifyLibCalls.cpp @@ -256,6 +256,21 @@ public: return sqrt_func; } + /// @brief Return a Function* for the strlen libcall + Function* get_strcpy() + { + if (!strcpy_func) + { + std::vector args; + args.push_back(PointerType::get(Type::SByteTy)); + args.push_back(PointerType::get(Type::SByteTy)); + FunctionType* strcpy_type = + FunctionType::get(PointerType::get(Type::SByteTy), args, false); + strcpy_func = M->getOrInsertFunction("strcpy",strcpy_type); + } + return strcpy_func; + } + /// @brief Return a Function* for the strlen libcall Function* get_strlen() { @@ -295,8 +310,8 @@ public: std::vector args; args.push_back(PointerType::get(Type::SByteTy)); args.push_back(PointerType::get(Type::SByteTy)); - args.push_back(Type::IntTy); - args.push_back(Type::IntTy); + args.push_back(Type::UIntTy); + args.push_back(Type::UIntTy); FunctionType* memcpy_type = FunctionType::get(Type::VoidTy, args, false); memcpy_func = M->getOrInsertFunction("llvm.memcpy",memcpy_type); } @@ -314,6 +329,7 @@ private: memcpy_func = 0; memchr_func = 0; sqrt_func = 0; + strcpy_func = 0; strlen_func = 0; } @@ -323,6 +339,7 @@ private: Function* memcpy_func; ///< Cached llvm.memcpy function Function* memchr_func; ///< Cached memchr function Function* sqrt_func; ///< Cached sqrt function + Function* strcpy_func; ///< Cached strcpy function Function* strlen_func; ///< Cached strlen function Module* M; ///< Cached Module TargetData* TD; ///< Cached TargetData @@ -493,8 +510,8 @@ public: std::vector vals; vals.push_back(gep); // destination vals.push_back(ci->getOperand(2)); // source - vals.push_back(ConstantSInt::get(Type::IntTy,len)); // length - vals.push_back(ConstantSInt::get(Type::IntTy,1)); // alignment + vals.push_back(ConstantUInt::get(Type::UIntTy,len)); // length + vals.push_back(ConstantUInt::get(Type::UIntTy,1)); // alignment new CallInst(SLC.get_memcpy(), vals, "", ci); // Finally, substitute the first operand of the strcat call for the @@ -862,8 +879,8 @@ public: std::vector vals; vals.push_back(dest); // destination vals.push_back(src); // source - vals.push_back(ConstantSInt::get(Type::IntTy,len)); // length - vals.push_back(ConstantSInt::get(Type::IntTy,1)); // alignment + vals.push_back(ConstantUInt::get(Type::UIntTy,len)); // length + vals.push_back(ConstantUInt::get(Type::UIntTy,1)); // alignment new CallInst(SLC.get_memcpy(), vals, "", ci); // Finally, substitute the first operand of the strcat call for the @@ -1255,7 +1272,8 @@ public: args.push_back(ConstantUInt::get(SLC.getIntPtrType(),len)); args.push_back(ConstantUInt::get(SLC.getIntPtrType(),1)); args.push_back(ci->getOperand(1)); - new CallInst(fwrite_func,args,"",ci); + new CallInst(fwrite_func,args,ci->getName(),ci); + ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,len)); ci->eraseFromParent(); return true; } @@ -1281,7 +1299,7 @@ public: if (!getConstantStringLength(ci->getOperand(3), len, &CA)) return false; - // fprintf(file,fmt) -> fwrite(fmt,strlen(fmt),1,file) + // fprintf(file,"%s",str) -> fwrite(fmt,strlen(fmt),1,file) const Type* FILEptr_type = ci->getOperand(1)->getType(); Function* fwrite_func = SLC.get_fwrite(FILEptr_type); if (!fwrite_func) @@ -1291,7 +1309,8 @@ public: args.push_back(ConstantUInt::get(SLC.getIntPtrType(),len)); args.push_back(ConstantUInt::get(SLC.getIntPtrType(),1)); args.push_back(ci->getOperand(1)); - new CallInst(fwrite_func,args,"",ci); + new CallInst(fwrite_func,args,ci->getName(),ci); + ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,len)); break; } case 'c': @@ -1306,6 +1325,7 @@ public: return false; CastInst* cast = new CastInst(CI,Type::IntTy,CI->getName()+".int",ci); new CallInst(fputc_func,cast,ci->getOperand(1),"",ci); + ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,1)); break; } default: @@ -1317,6 +1337,149 @@ public: } FPrintFOptimizer; +/// This LibCallOptimization will simplify calls to the "sprintf" library +/// function. It looks for cases where the result of sprintf is not used and the +/// operation can be reduced to something simpler. +/// @brief Simplify the pow library function. +struct SPrintFOptimization : public LibCallOptimization +{ +public: + /// @brief Default Constructor + SPrintFOptimization() : LibCallOptimization("sprintf", + "simplify-libcalls:sprintf", "Number of 'sprintf' calls simplified") {} + + /// @brief Destructor + virtual ~SPrintFOptimization() {} + + /// @brief Make sure that the "fprintf" function has the right prototype + virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC) + { + // Just make sure this has at least 2 arguments + return (f->getReturnType() == Type::IntTy && f->arg_size() >= 2); + } + + /// @brief Perform the sprintf optimization. + virtual bool OptimizeCall(CallInst* ci, SimplifyLibCalls& SLC) + { + // If the call has more than 3 operands, we can't optimize it + if (ci->getNumOperands() > 4 || ci->getNumOperands() < 3) + return false; + + // All the optimizations depend on the length of the second argument and the + // fact that it is a constant string array. Check that now + uint64_t len = 0; + ConstantArray* CA = 0; + if (!getConstantStringLength(ci->getOperand(2), len, &CA)) + return false; + + if (ci->getNumOperands() == 3) + { + if (len == 0) + { + // If the length is 0, we just need to store a null byte + new StoreInst(ConstantInt::get(Type::SByteTy,0),ci->getOperand(1),ci); + ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,0)); + ci->eraseFromParent(); + return true; + } + + // Make sure there's no % in the constant array + for (unsigned i = 0; i < len; ++i) + { + if (ConstantInt* CI = dyn_cast(CA->getOperand(i))) + { + // Check for the null terminator + if (CI->getRawValue() == '%') + return false; // we found a %, can't optimize + } + else + return false; // initializer is not constant int, can't optimize + } + + // Increment length because we want to copy the null byte too + len++; + + // sprintf(str,fmt) -> llvm.memcpy(str,fmt,strlen(fmt),1) + Function* memcpy_func = SLC.get_memcpy(); + if (!memcpy_func) + return false; + std::vector args; + args.push_back(ci->getOperand(1)); + args.push_back(ci->getOperand(2)); + args.push_back(ConstantUInt::get(Type::UIntTy,len)); + args.push_back(ConstantUInt::get(Type::UIntTy,1)); + new CallInst(memcpy_func,args,"",ci); + ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,len)); + ci->eraseFromParent(); + return true; + } + + // The remaining optimizations require the format string to be length 2 + // "%s" or "%c". + if (len != 2) + return false; + + // The first character has to be a % + if (ConstantInt* CI = dyn_cast(CA->getOperand(0))) + if (CI->getRawValue() != '%') + return false; + + // Get the second character and switch on its value + ConstantInt* CI = dyn_cast(CA->getOperand(1)); + switch (CI->getRawValue()) + { + case 's': + { + uint64_t len = 0; + if (ci->hasNUses(0)) + { + // sprintf(dest,"%s",str) -> strcpy(dest,str) + Function* strcpy_func = SLC.get_strcpy(); + if (!strcpy_func) + return false; + std::vector args; + args.push_back(ci->getOperand(1)); + args.push_back(ci->getOperand(3)); + new CallInst(strcpy_func,args,"",ci); + } + else if (getConstantStringLength(ci->getOperand(3),len)) + { + // sprintf(dest,"%s",cstr) -> llvm.memcpy(dest,str,strlen(str),1) + len++; // get the null-terminator + Function* memcpy_func = SLC.get_memcpy(); + if (!memcpy_func) + return false; + std::vector args; + args.push_back(ci->getOperand(1)); + args.push_back(ci->getOperand(3)); + args.push_back(ConstantUInt::get(Type::UIntTy,len)); + args.push_back(ConstantUInt::get(Type::UIntTy,1)); + new CallInst(memcpy_func,args,"",ci); + ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,len)); + } + break; + } + case 'c': + { + // sprintf(dest,"%c",chr) -> store chr, dest + CastInst* cast = + new CastInst(ci->getOperand(3),Type::SByteTy,"char",ci); + new StoreInst(cast, ci->getOperand(1), ci); + GetElementPtrInst* gep = new GetElementPtrInst(ci->getOperand(1), + ConstantUInt::get(Type::UIntTy,1),ci->getOperand(1)->getName()+".end", + ci); + new StoreInst(ConstantInt::get(Type::SByteTy,0),gep,ci); + ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,1)); + break; + } + default: + return false; + } + ci->eraseFromParent(); + return true; + } +} SPrintFOptimizer; + /// This LibCallOptimization will simplify calls to the "fputs" library /// function. It looks for cases where the result of fputs is not used and the /// operation can be reduced to something simpler.