From: Arnold Schwaighofer Date: Sat, 12 Apr 2008 18:11:06 +0000 (+0000) Subject: This patch corrects the handling of byval arguments for tailcall X-Git-Url: http://plrg.eecs.uci.edu/git/?a=commitdiff_plain;h=4b5324ad2cbf774c9c6ed02ea0fcc864f2f5f885;p=oota-llvm.git This patch corrects the handling of byval arguments for tailcall optimized x86-64 (and x86) calls so that they work (... at least for my test cases). Should fix the following problems: Problem 1: When i introduced the optimized handling of arguments for tail called functions (using a sequence of copyto/copyfrom virtual registers instead of always lowering to top of the stack) i did not handle byval arguments correctly e.g they did not work at all :). Problem 2: On x86-64 after the arguments of the tail called function are moved to their registers (which include ESI/RSI etc), tail call optimization performs byval lowering which causes xSI,xDI, xCX registers to be overwritten. This is handled in this patch by moving the arguments to virtual registers first and after the byval lowering the arguments are moved from those virtual registers back to RSI/RDI/RCX. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@49584 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index 9db0288c4e3..ac58ab4f05c 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -979,8 +979,8 @@ static bool ArgsAreStructReturn(SDOperand Op) { return cast(Op.getOperand(3))->getArgFlags().isSRet(); } -/// IsCalleePop - Determines whether a CALL or FORMAL_ARGUMENTS node requires the -/// callee to pop its own arguments. Callee pop is necessary to support tail +/// IsCalleePop - Determines whether a CALL or FORMAL_ARGUMENTS node requires +/// the callee to pop its own arguments. Callee pop is necessary to support tail /// calls. bool X86TargetLowering::IsCalleePop(SDOperand Op) { bool IsVarArg = cast(Op.getOperand(2))->getValue() != 0; @@ -1104,8 +1104,8 @@ CopyTailCallClobberedArgumentsToVRegs(SDOperand Chain, /// CreateCopyOfByValArgument - Make a copy of an aggregate at address specified /// by "Src" to address "Dst" with size and alignment information specified by -/// the specific parameter attribute. The copy will be passed as a byval function -/// parameter. +/// the specific parameter attribute. The copy will be passed as a byval +/// function parameter. static SDOperand CreateCopyOfByValArgument(SDOperand Src, SDOperand Dst, SDOperand Chain, ISD::ArgFlagsTy Flags, SelectionDAG &DAG) { @@ -1347,6 +1347,99 @@ X86TargetLowering::LowerMemOpCallTo(SDOperand Op, SelectionDAG &DAG, PseudoSourceValue::getStack(), LocMemOffset); } +/// EmitTailCallLoadRetAddr - Emit a load of return adress if tail call +/// optimization is performed and it is required. +SDOperand +X86TargetLowering::EmitTailCallLoadRetAddr(SelectionDAG &DAG, + SDOperand &OutRetAddr, + SDOperand Chain, + bool IsTailCall, + bool Is64Bit, + int FPDiff) { + if (!IsTailCall || FPDiff==0) return Chain; + + // Adjust the Return address stack slot. + MVT::ValueType VT = getPointerTy(); + OutRetAddr = getReturnAddressFrameIndex(DAG); + // Load the "old" Return address. + OutRetAddr = DAG.getLoad(VT, Chain,OutRetAddr, NULL, 0); + return SDOperand(OutRetAddr.Val, 1); +} + +/// EmitTailCallStoreRetAddr - Emit a store of the return adress if tail call +/// optimization is performed and it is required (FPDiff!=0). +static SDOperand +EmitTailCallStoreRetAddr(SelectionDAG & DAG, MachineFunction &MF, + SDOperand Chain, SDOperand RetAddrFrIdx, + bool Is64Bit, int FPDiff) { + // Store the return address to the appropriate stack slot. + if (!FPDiff) return Chain; + // Calculate the new stack slot for the return address. + int SlotSize = Is64Bit ? 8 : 4; + int NewReturnAddrFI = + MF.getFrameInfo()->CreateFixedObject(SlotSize, FPDiff-SlotSize); + MVT::ValueType VT = Is64Bit ? MVT::i64 : MVT::i32; + SDOperand NewRetAddrFrIdx = DAG.getFrameIndex(NewReturnAddrFI, VT); + Chain = DAG.getStore(Chain, RetAddrFrIdx, NewRetAddrFrIdx, + PseudoSourceValue::getFixedStack(), NewReturnAddrFI); + return Chain; +} + +/// CopyTailCallByValClobberedRegToVirtReg - Copy arguments with register target +/// which might be overwritten by later byval tail call lowering to a virtual +/// register. +bool +X86TargetLowering::CopyTailCallByValClobberedRegToVirtReg(bool containsByValArg, + SmallVector< std::pair, 8> &TailCallByValClobberedVRegs, + SmallVector &TailCallByValClobberedVRegTypes, + std::pair &RegToPass, + SDOperand &OutChain, + SDOperand &OutFlag, + MachineFunction &MF, + SelectionDAG & DAG) { + if (!containsByValArg) return false; + + std::pair ArgRegVReg; + MVT::ValueType VT = RegToPass.second.getValueType(); + + ArgRegVReg.first = RegToPass.first; + ArgRegVReg.second = MF.getRegInfo().createVirtualRegister(getRegClassFor(VT)); + + // Copy Argument to virtual register. + OutChain = DAG.getCopyToReg(OutChain, ArgRegVReg.second, + RegToPass.second, OutFlag); + OutFlag = OutChain.getValue(1); + // Remember virtual register and type. + TailCallByValClobberedVRegs.push_back(ArgRegVReg); + TailCallByValClobberedVRegTypes.push_back(VT); + return true; +} + + +/// RestoreTailCallByValClobberedReg - Restore registers which were saved to +/// virtual registers to prevent tail call byval lowering from overwriting +/// parameter registers. +static SDOperand +RestoreTailCallByValClobberedRegs(SelectionDAG & DAG, SDOperand Chain, + SmallVector< std::pair, 8> &TailCallByValClobberedVRegs, + SmallVector &TailCallByValClobberedVRegTypes) { + if (TailCallByValClobberedVRegs.size()==0) return Chain; + + SmallVector RegOpChains; + for (unsigned i = 0, e=TailCallByValClobberedVRegs.size(); i != e; i++) { + SDOperand InFlag; + unsigned DestReg = TailCallByValClobberedVRegs[i].first; + unsigned VirtReg = TailCallByValClobberedVRegs[i].second; + MVT::ValueType VT = TailCallByValClobberedVRegTypes[i]; + SDOperand Tmp = DAG.getCopyFromReg(Chain, VirtReg, VT, InFlag); + Chain = DAG.getCopyToReg(Chain, DestReg, Tmp, InFlag); + RegOpChains.push_back(Chain); + } + if (!RegOpChains.empty()) + Chain = DAG.getNode(ISD::TokenFactor, MVT::Other, + &RegOpChains[0], RegOpChains.size()); + return Chain; +} SDOperand X86TargetLowering::LowerCALL(SDOperand Op, SelectionDAG &DAG) { MachineFunction &MF = DAG.getMachineFunction(); @@ -1396,30 +1489,29 @@ SDOperand X86TargetLowering::LowerCALL(SDOperand Op, SelectionDAG &DAG) { Chain = DAG.getCALLSEQ_START(Chain, DAG.getIntPtrConstant(NumBytes)); SDOperand RetAddrFrIdx; - if (IsTailCall) { - // Adjust the Return address stack slot. - if (FPDiff) { - MVT::ValueType VT = Is64Bit ? MVT::i64 : MVT::i32; - RetAddrFrIdx = getReturnAddressFrameIndex(DAG); - // Load the "old" Return address. - RetAddrFrIdx = - DAG.getLoad(VT, Chain,RetAddrFrIdx, NULL, 0); - Chain = SDOperand(RetAddrFrIdx.Val, 1); - } - } + // Load return adress for tail calls. + Chain = EmitTailCallLoadRetAddr(DAG, RetAddrFrIdx, Chain, IsTailCall, Is64Bit, + FPDiff); SmallVector, 8> RegsToPass; SmallVector, 8> TailCallClobberedVRegs; + SmallVector MemOpChains; SDOperand StackPtr; + bool containsTailCallByValArg = false; + SmallVector, 8> TailCallByValClobberedVRegs; + SmallVector TailCallByValClobberedVRegTypes; + // Walk the register/memloc assignments, inserting copies/loads. For tail // calls, remember all arguments for later special lowering. for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) { CCValAssign &VA = ArgLocs[i]; SDOperand Arg = Op.getOperand(5+2*VA.getValNo()); - + bool isByVal = cast(Op.getOperand(6+2*VA.getValNo()))-> + getArgFlags().isByVal(); + // Promote the value if needed. switch (VA.getLocInfo()) { default: assert(0 && "Unknown loc info!"); @@ -1438,13 +1530,15 @@ SDOperand X86TargetLowering::LowerCALL(SDOperand Op, SelectionDAG &DAG) { if (VA.isRegLoc()) { RegsToPass.push_back(std::make_pair(VA.getLocReg(), Arg)); } else { - if (!IsTailCall) { + if (!IsTailCall || (IsTailCall && isByVal)) { assert(VA.isMemLoc()); if (StackPtr.Val == 0) StackPtr = DAG.getCopyFromReg(Chain, X86StackPtr, getPointerTy()); MemOpChains.push_back(LowerMemOpCallTo(Op, DAG, StackPtr, VA, Chain, Arg)); + // Remember fact that this call contains byval arguments. + containsTailCallByValArg |= IsTailCall && isByVal; } else if (IsPossiblyOverwrittenArgumentOfTailCall(Arg, MFI)) { TailCallClobberedVRegs.push_back(std::make_pair(i,Arg)); } @@ -1459,6 +1553,16 @@ SDOperand X86TargetLowering::LowerCALL(SDOperand Op, SelectionDAG &DAG) { // and flag operands which copy the outgoing args into registers. SDOperand InFlag; for (unsigned i = 0, e = RegsToPass.size(); i != e; ++i) { + // Tail call byval lowering might overwrite argument registers so arguments + // passed to be copied to a virtual register for + // later processing. + if (CopyTailCallByValClobberedRegToVirtReg(containsTailCallByValArg, + TailCallByValClobberedVRegs, + TailCallByValClobberedVRegTypes, + RegsToPass[i], Chain, InFlag, MF, + DAG)) + continue; + Chain = DAG.getCopyToReg(Chain, RegsToPass[i].first, RegsToPass[i].second, InFlag); InFlag = Chain.getValue(1); @@ -1533,7 +1637,7 @@ SDOperand X86TargetLowering::LowerCALL(SDOperand Op, SelectionDAG &DAG) { int32_t Offset = VA.getLocMemOffset()+FPDiff; uint32_t OpSize = (MVT::getSizeInBits(VA.getLocVT())+7)/8; FI = MF.getFrameInfo()->CreateFixedObject(OpSize, Offset); - FIN = DAG.getFrameIndex(FI, MVT::i32); + FIN = DAG.getFrameIndex(FI, getPointerTy()); // Find virtual register for this argument. bool Found=false; @@ -1548,7 +1652,12 @@ SDOperand X86TargetLowering::LowerCALL(SDOperand Op, SelectionDAG &DAG) { if (Flags.isByVal()) { // Copy relative to framepointer. - MemOpChains2.push_back(CreateCopyOfByValArgument(Arg, FIN, Chain, + SDOperand Source = DAG.getIntPtrConstant(VA.getLocMemOffset()); + if (StackPtr.Val == 0) + StackPtr = DAG.getCopyFromReg(Chain, X86StackPtr, getPointerTy()); + Source = DAG.getNode(ISD::ADD, getPointerTy(), StackPtr, Source); + + MemOpChains2.push_back(CreateCopyOfByValArgument(Source, FIN, Chain, Flags, DAG)); } else { // Store relative to framepointer. @@ -1563,17 +1672,14 @@ SDOperand X86TargetLowering::LowerCALL(SDOperand Op, SelectionDAG &DAG) { Chain = DAG.getNode(ISD::TokenFactor, MVT::Other, &MemOpChains2[0], MemOpChains2.size()); + // Restore byval lowering clobbered registers. + Chain = RestoreTailCallByValClobberedRegs(DAG, Chain, + TailCallByValClobberedVRegs, + TailCallByValClobberedVRegTypes); + // Store the return address to the appropriate stack slot. - if (FPDiff) { - // Calculate the new stack slot for the return address. - int SlotSize = Is64Bit ? 8 : 4; - int NewReturnAddrFI = - MF.getFrameInfo()->CreateFixedObject(SlotSize, FPDiff-SlotSize); - MVT::ValueType VT = Is64Bit ? MVT::i64 : MVT::i32; - SDOperand NewRetAddrFrIdx = DAG.getFrameIndex(NewReturnAddrFI, VT); - Chain = DAG.getStore(Chain, RetAddrFrIdx, NewRetAddrFrIdx, - PseudoSourceValue::getFixedStack(), NewReturnAddrFI); - } + Chain = EmitTailCallStoreRetAddr(DAG, MF, Chain, RetAddrFrIdx, Is64Bit, + FPDiff); } // If the callee is a GlobalAddress node (quite common, every direct call is) diff --git a/lib/Target/X86/X86ISelLowering.h b/lib/Target/X86/X86ISelLowering.h index 2abe237ed82..fea2d2b3577 100644 --- a/lib/Target/X86/X86ISelLowering.h +++ b/lib/Target/X86/X86ISelLowering.h @@ -484,6 +484,19 @@ namespace llvm { bool IsCalleePop(SDOperand Op); bool CallRequiresGOTPtrInReg(bool Is64Bit, bool IsTailCall); bool CallRequiresFnAddressInReg(bool Is64Bit, bool IsTailCall); + SDOperand EmitTailCallLoadRetAddr(SelectionDAG &DAG, SDOperand &OutRetAddr, + SDOperand Chain, bool IsTailCall, bool Is64Bit, + int FPDiff); + + bool CopyTailCallByValClobberedRegToVirtReg(bool containsByValArg, + SmallVector< std::pair,8> &TailCallByValClobberedVRegs, + SmallVector &TailCallByValClobberedVRegTypes, + std::pair &RegToPass, + SDOperand &OutChain, + SDOperand &OutFlag, + MachineFunction &MF, + SelectionDAG & DAG); + CCAssignFn *CCAssignFnForNode(SDOperand Op) const; NameDecorationStyle NameDecorationForFORMAL_ARGUMENTS(SDOperand Op); unsigned GetAlignedArgumentStackSize(unsigned StackSize, SelectionDAG &DAG); diff --git a/test/CodeGen/X86/tailcallbyval.ll b/test/CodeGen/X86/tailcallbyval.ll index dc1dea7e113..9085b050ec5 100644 --- a/test/CodeGen/X86/tailcallbyval.ll +++ b/test/CodeGen/X86/tailcallbyval.ll @@ -1,5 +1,9 @@ ; RUN: llvm-as < %s | llc -march=x86 -tailcallopt | grep TAILCALL -%struct.s = type { i32, i32, i32 } +; check for the 2 byval moves +; RUN: llvm-as < %s | llc -march=x86 -tailcallopt | grep rep | wc -l | grep 2 +%struct.s = type {i32, i32, i32, i32, i32, i32, i32, i32, + i32, i32, i32, i32, i32, i32, i32, i32, + i32, i32, i32, i32, i32, i32, i32, i32 } define fastcc i32 @tailcallee(%struct.s* byval %a) { entry: diff --git a/test/CodeGen/X86/tailcallbyval64.ll b/test/CodeGen/X86/tailcallbyval64.ll new file mode 100644 index 00000000000..7b65863f00b --- /dev/null +++ b/test/CodeGen/X86/tailcallbyval64.ll @@ -0,0 +1,29 @@ +; RUN: llvm-as < %s | llc -march=x86-64 -tailcallopt | grep TAILCALL +; Expect 2 rep;movs because of tail call byval lowering. +; RUN: llvm-as < %s | llc -march=x86-64 -tailcallopt | grep rep | wc -l | grep 2 +; A sequence of copyto/copyfrom virtual registers is used to deal with byval +; lowering appearing after moving arguments to registers. The following two +; checks verify that the register allocator changes those sequences to direct +; moves to argument register where it can (for registers that are not used in +; byval lowering - not rsi, not rdi, not rcx). +; Expect argument 4 to be moved directly to register edx. +; RUN: llvm-as < %s | llc -march=x86-64 -tailcallopt | grep movl | grep {7} | grep edx +; Expect argument 6 to be moved directly to register r8. +; RUN: llvm-as < %s | llc -march=x86-64 -tailcallopt | grep movl | grep {17} | grep r8 + +%struct.s = type { i64, i64, i64, i64, i64, i64, i64, i64, + i64, i64, i64, i64, i64, i64, i64, i64, + i64, i64, i64, i64, i64, i64, i64, i64 } + +declare fastcc i64 @tailcallee(%struct.s* byval %a, i64 %val, i64 %val2, i64 %val3, i64 %val4, i64 %val5) + + +define fastcc i64 @tailcaller(i64 %b, %struct.s* byval %a) { +entry: + %tmp2 = getelementptr %struct.s* %a, i32 0, i32 1 + %tmp3 = load i64* %tmp2, align 8 + %tmp4 = tail call fastcc i64 @tailcallee(%struct.s* %a byval, i64 %tmp3, i64 %b, i64 7, i64 13, i64 17) + ret i64 %tmp4 +} + +