From ca891ecf9152791f72f33a0dafff6b4a022642ee Mon Sep 17 00:00:00 2001 From: Devang Patel Date: Fri, 29 Feb 2008 23:34:08 +0000 Subject: [PATCH] Add pass to promote sret. This pass transforms %struct._Point = type { i32, i32, i32, i32, i32, i32 } define internal void @foo(%struct._Point* sret %agg.result) into %struct._Point = type { i32, i32, i32, i32, i32, i32 } define internal %struct._Point @foo() This pass updates foo() clients appropriately to use getresult instruction to extract return values. This pass is not yet ready for prime time. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@47776 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/Transforms/IPO.h | 1 + lib/Transforms/IPO/StructRetPromotion.cpp | 292 ++++++++++++++++++++++ 2 files changed, 293 insertions(+) create mode 100644 lib/Transforms/IPO/StructRetPromotion.cpp diff --git a/include/llvm/Transforms/IPO.h b/include/llvm/Transforms/IPO.h index e7590ac019b..8d59fa6355c 100644 --- a/include/llvm/Transforms/IPO.h +++ b/include/llvm/Transforms/IPO.h @@ -125,6 +125,7 @@ ModulePass *createDeadArgHackingPass(); /// be passed by value. /// Pass *createArgumentPromotionPass(); +Pass *createStructRetPromotionPass(); //===----------------------------------------------------------------------===// /// createIPConstantPropagationPass - This pass propagates constants from call diff --git a/lib/Transforms/IPO/StructRetPromotion.cpp b/lib/Transforms/IPO/StructRetPromotion.cpp new file mode 100644 index 00000000000..dd626a59123 --- /dev/null +++ b/lib/Transforms/IPO/StructRetPromotion.cpp @@ -0,0 +1,292 @@ +//===-- StructRetPromotion.cpp - Promote sret arguments -000000------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass promotes "by reference" arguments to be "by value" arguments. In +// practice, this means looking for internal functions that have pointer +// arguments. If it can prove, through the use of alias analysis, that an +// argument is *only* loaded, then it can pass the value into the function +// instead of the address of the value. This can cause recursive simplification +// of code and lead to the elimination of allocas (especially in C++ template +// code like the STL). +// +// This pass also handles aggregate arguments that are passed into a function, +// scalarizing them if the elements of the aggregate are only loaded. Note that +// it refuses to scalarize aggregates which would require passing in more than +// three operands to the function, because passing thousands of operands for a +// large array or structure is unprofitable! +// +// Note that this transformation could also be done for arguments that are only +// stored to (returning the value instead), but does not currently. This case +// would be best handled when and if LLVM begins supporting multiple return +// values from functions. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "sretpromotion" +#include "llvm/Transforms/IPO.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Module.h" +#include "llvm/CallGraphSCCPass.h" +#include "llvm/Instructions.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/Support/CallSite.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Compiler.h" +using namespace llvm; + +namespace { + /// SRETPromotion - This pass removes sret parameter and updates + /// function to use multiple return value. + /// + struct VISIBILITY_HIDDEN SRETPromotion : public CallGraphSCCPass { + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + CallGraphSCCPass::getAnalysisUsage(AU); + } + + virtual bool runOnSCC(const std::vector &SCC); + static char ID; // Pass identification, replacement for typeid + SRETPromotion() : CallGraphSCCPass((intptr_t)&ID) {} + + private: + bool PromoteReturn(CallGraphNode *CGN); + bool isSafeToUpdateAllCallers(Function *F); + Function *cloneFunctionBody(Function *F, const StructType *STy); + void updateCallSites(Function *F, Function *NF); + }; + + char SRETPromotion::ID = 0; + RegisterPass X("sretpromotion", + "Promote sret arguments to multiple ret values"); +} + +Pass *llvm::createStructRetPromotionPass() { + return new SRETPromotion(); +} + +bool SRETPromotion::runOnSCC(const std::vector &SCC) { + bool Changed = false; + + for (unsigned i = 0, e = SCC.size(); i != e; ++i) + Changed |= PromoteReturn(SCC[i]); + + return Changed; +} + +/// PromoteReturn - This method promotes function that uses StructRet paramater +/// into a function that uses mulitple return value. +bool SRETPromotion::PromoteReturn(CallGraphNode *CGN) { + Function *F = CGN->getFunction(); + + // Make sure that it is local to this module. + if (!F || !F->hasInternalLinkage()) + return false; + + // Make sure that function returns struct. + if (F->arg_size() == 0 || !F->isStructReturn() || F->doesNotReturn()) + return false; + + assert (F->getReturnType() == Type::VoidTy && "Invalid function return type"); + Function::arg_iterator AI = F->arg_begin(); + const llvm::PointerType *FArgType = dyn_cast(AI->getType()); + assert (FArgType && "Invalid sret paramater type"); + const llvm::StructType *STy = + dyn_cast(FArgType->getElementType()); + assert (STy && "Invalid sret parameter element type"); + + // Check if it is ok to perform this promotion. + if (isSafeToUpdateAllCallers(F) == false) + return false; + + // [1] Replace use of sret parameter + AllocaInst *TheAlloca = new AllocaInst (STy, NULL, "mrv", F->getEntryBlock().begin()); + Value *NFirstArg = F->arg_begin(); + NFirstArg->replaceAllUsesWith(TheAlloca); + + // Find and replace ret instructions + SmallVector RetVals; + for (Function::iterator FI = F->begin(), FE = F->end(); FI != FE; ++FI) + for(BasicBlock::iterator BI = FI->begin(), BE = FI->end(); BI != BE; ) { + Instruction *I = BI; + ++BI; + if (isa(I)) { + RetVals.clear(); + for (unsigned idx = 0; idx < STy->getNumElements(); ++idx) { + SmallVector GEPIdx; + GEPIdx.push_back(ConstantInt::get(Type::Int32Ty, 0)); + GEPIdx.push_back(ConstantInt::get(Type::Int32Ty, idx)); + Value *NGEPI = new GetElementPtrInst(TheAlloca, GEPIdx.begin(), GEPIdx.end(), + "mrv.gep", I); + Value *NV = new LoadInst(NGEPI, "mrv.ld", I); + RetVals.push_back(NV); + } + + ReturnInst *NR = new ReturnInst(&RetVals[0], RetVals.size(), I); + I->replaceAllUsesWith(NR); + I->eraseFromParent(); + } + } + + // Create the new function body and insert it into the module. + Function *NF = cloneFunctionBody(F, STy); + + // Update all call sites to use new function + updateCallSites(F, NF); + + F->eraseFromParent(); + getAnalysis().changeFunction(F, NF); + return true; +} + + // Check if it is ok to perform this promotion. +bool SRETPromotion::isSafeToUpdateAllCallers(Function *F) { + + if (F->use_empty()) + // No users. OK to modify signature. + return true; + + for (Value::use_iterator FnUseI = F->use_begin(), FnUseE = F->use_end(); + FnUseI != FnUseE; ++FnUseI) { + + CallSite CS = CallSite::get(*FnUseI); + Instruction *Call = CS.getInstruction(); + CallSite::arg_iterator AI = CS.arg_begin(); + Value *FirstArg = *AI; + + if (!isa(FirstArg)) + return false; + + // Check FirstArg's users. + for (Value::use_iterator ArgI = FirstArg->use_begin(), + ArgE = FirstArg->use_end(); ArgI != ArgE; ++ArgI) { + + // If FirstArg user is a CallInst that does not correspond to current + // call site then this function F is not suitable for sret promotion. + if (CallInst *CI = dyn_cast(ArgI)) { + if (CI != Call) + return false; + } + // If FirstArg user is a GEP whose all users are not LoadInst then + // this function F is not suitable for sret promotion. + else if (GetElementPtrInst *GEP = dyn_cast(ArgI)) { + for (Value::use_iterator GEPI = GEP->use_begin(), GEPE = GEP->use_end(); + GEPI != GEPE; ++GEPI) + if (!isa(GEPI)) + return false; + } + // Any other FirstArg users make this function unsuitable for sret + // promotion. + else + return false; + } + } + + return true; +} + +/// cloneFunctionBody - Create a new function based on F and +/// insert it into module. Remove first argument. Use STy as +/// the return type for new function. +Function *SRETPromotion::cloneFunctionBody(Function *F, + const StructType *STy) { + + // FIXME : Do not drop param attributes on the floor. + const FunctionType *FTy = F->getFunctionType(); + std::vector Params; + + // Skip first argument. + Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); + ++I; + while (I != E) { + Params.push_back(I->getType()); + ++I; + } + + FunctionType *NFTy = FunctionType::get(STy, Params, FTy->isVarArg()); + Function *NF = new Function(NFTy, F->getLinkage(), F->getName()); + NF->setCallingConv(F->getCallingConv()); + F->getParent()->getFunctionList().insert(F, NF); + NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList()); + + // Replace arguments + I = F->arg_begin(); + E = F->arg_end(); + Function::arg_iterator NI = NF->arg_begin(); + ++I; + while (I != E) { + I->replaceAllUsesWith(NI); + NI->takeName(I); + ++I; + ++NI; + } + + return NF; +} + +/// updateCallSites - Update all sites that call F to use NF. +void SRETPromotion::updateCallSites(Function *F, Function *NF) { + + // FIXME : Handle parameter attributes + SmallVector Args; + + for (Value::use_iterator FUI = F->use_begin(), FUE = F->use_end(); FUI != FUE;) { + CallSite CS = CallSite::get(*FUI); + ++FUI; + Instruction *Call = CS.getInstruction(); + + // Copy arguments, however skip first one. + CallSite::arg_iterator AI = CS.arg_begin(), AE = CS.arg_end(); + Value *FirstCArg = *AI; + ++AI; + while (AI != AE) { + Args.push_back(*AI); + ++AI; + } + + // Build new call instruction. + Instruction *New; + if (InvokeInst *II = dyn_cast(Call)) { + New = new InvokeInst(NF, II->getNormalDest(), II->getUnwindDest(), + Args.begin(), Args.end(), "", Call); + cast(New)->setCallingConv(CS.getCallingConv()); + } else { + New = new CallInst(NF, Args.begin(), Args.end(), "", Call); + cast(New)->setCallingConv(CS.getCallingConv()); + if (cast(Call)->isTailCall()) + cast(New)->setTailCall(); + } + Args.clear(); + New->takeName(Call); + + // Update all users of sret parameter to extract value using getresult. + for (Value::use_iterator UI = FirstCArg->use_begin(), + UE = FirstCArg->use_end(); UI != UE; ) { + User *U2 = *UI++; + CallInst *C2 = dyn_cast(U2); + if (C2 && (C2 == Call)) + continue; + else if (GetElementPtrInst *UGEP = dyn_cast(U2)) { + Value *GR = new GetResultInst(New, 5, "xxx", UGEP); + for (Value::use_iterator GI = UGEP->use_begin(), + GE = UGEP->use_end(); GI != GE; ++GI) { + if (LoadInst *L = dyn_cast(*GI)) { + L->replaceAllUsesWith(GR); + L->eraseFromParent(); + } + } + UGEP->eraseFromParent(); + } + else assert( 0 && "Unexpected sret parameter use"); + } + Call->eraseFromParent(); + } +} -- 2.34.1