This patch corrects the handling of byval arguments for tailcall
authorArnold Schwaighofer <arnold.schwaighofer@gmail.com>
Sat, 12 Apr 2008 18:11:06 +0000 (18:11 +0000)
committerArnold Schwaighofer <arnold.schwaighofer@gmail.com>
Sat, 12 Apr 2008 18:11:06 +0000 (18:11 +0000)
optimized x86-64 (and x86) calls so that they work (... at least for
my test cases).

Should fix the following problems:

Problem 1: When i introduced the optimized handling of arguments for
tail called functions (using a sequence of copyto/copyfrom virtual
registers instead of always lowering to top of the stack) i did not
handle byval arguments correctly e.g they did not work at all :).

Problem 2: On x86-64 after the arguments of the tail called function
are moved to their registers (which include ESI/RSI etc), tail call
optimization performs byval lowering which causes xSI,xDI, xCX
registers to be overwritten. This is handled in this patch by moving
the arguments to virtual registers first and after the byval lowering
the arguments are moved from those virtual registers back to
RSI/RDI/RCX.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@49584 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp
lib/Target/X86/X86ISelLowering.h
test/CodeGen/X86/tailcallbyval.ll
test/CodeGen/X86/tailcallbyval64.ll [new file with mode: 0644]

index 9db0288c4e38b252ba487489507f5c7c0a1bbaa0..ac58ab4f05c1cedd354d687b47b6705455751d31 100644 (file)
@@ -979,8 +979,8 @@ static bool ArgsAreStructReturn(SDOperand Op) {
   return cast<ARG_FLAGSSDNode>(Op.getOperand(3))->getArgFlags().isSRet();
 }
 
-/// IsCalleePop - Determines whether a CALL or FORMAL_ARGUMENTS node requires the
-/// callee to pop its own arguments. Callee pop is necessary to support tail
+/// IsCalleePop - Determines whether a CALL or FORMAL_ARGUMENTS node requires
+/// the callee to pop its own arguments. Callee pop is necessary to support tail
 /// calls.
 bool X86TargetLowering::IsCalleePop(SDOperand Op) {
   bool IsVarArg = cast<ConstantSDNode>(Op.getOperand(2))->getValue() != 0;
@@ -1104,8 +1104,8 @@ CopyTailCallClobberedArgumentsToVRegs(SDOperand Chain,
 
 /// CreateCopyOfByValArgument - Make a copy of an aggregate at address specified
 /// by "Src" to address "Dst" with size and alignment information specified by
-/// the specific parameter attribute. The copy will be passed as a byval function
-/// parameter.
+/// the specific parameter attribute. The copy will be passed as a byval
+/// function parameter.
 static SDOperand 
 CreateCopyOfByValArgument(SDOperand Src, SDOperand Dst, SDOperand Chain,
                           ISD::ArgFlagsTy Flags, SelectionDAG &DAG) {
@@ -1347,6 +1347,99 @@ X86TargetLowering::LowerMemOpCallTo(SDOperand Op, SelectionDAG &DAG,
                       PseudoSourceValue::getStack(), LocMemOffset);
 }
 
+/// EmitTailCallLoadRetAddr - Emit a load of return adress if tail call
+/// optimization is performed and it is required.
+SDOperand 
+X86TargetLowering::EmitTailCallLoadRetAddr(SelectionDAG &DAG, 
+                                           SDOperand &OutRetAddr,
+                                           SDOperand Chain, 
+                                           bool IsTailCall, 
+                                           bool Is64Bit, 
+                                           int FPDiff) {
+  if (!IsTailCall || FPDiff==0) return Chain;
+
+  // Adjust the Return address stack slot.
+  MVT::ValueType VT = getPointerTy();
+  OutRetAddr = getReturnAddressFrameIndex(DAG);
+  // Load the "old" Return address.
+  OutRetAddr = DAG.getLoad(VT, Chain,OutRetAddr, NULL, 0);
+  return SDOperand(OutRetAddr.Val, 1);
+}
+
+/// EmitTailCallStoreRetAddr - Emit a store of the return adress if tail call
+/// optimization is performed and it is required (FPDiff!=0).
+static SDOperand 
+EmitTailCallStoreRetAddr(SelectionDAG & DAG, MachineFunction &MF, 
+                         SDOperand Chain, SDOperand RetAddrFrIdx,
+                         bool Is64Bit, int FPDiff) {
+  // Store the return address to the appropriate stack slot.
+  if (!FPDiff) return Chain;
+  // Calculate the new stack slot for the return address.
+  int SlotSize = Is64Bit ? 8 : 4;
+  int NewReturnAddrFI = 
+    MF.getFrameInfo()->CreateFixedObject(SlotSize, FPDiff-SlotSize);
+  MVT::ValueType VT = Is64Bit ? MVT::i64 : MVT::i32;
+  SDOperand NewRetAddrFrIdx = DAG.getFrameIndex(NewReturnAddrFI, VT);
+  Chain = DAG.getStore(Chain, RetAddrFrIdx, NewRetAddrFrIdx, 
+                       PseudoSourceValue::getFixedStack(), NewReturnAddrFI);
+  return Chain;
+}
+
+/// CopyTailCallByValClobberedRegToVirtReg - Copy arguments with register target
+/// which might be overwritten by later byval tail call lowering to a virtual
+/// register.
+bool
+X86TargetLowering::CopyTailCallByValClobberedRegToVirtReg(bool containsByValArg,
+    SmallVector< std::pair<unsigned, unsigned>, 8> &TailCallByValClobberedVRegs,
+    SmallVector<MVT::ValueType, 8> &TailCallByValClobberedVRegTypes,
+    std::pair<unsigned, SDOperand> &RegToPass,
+    SDOperand &OutChain,
+    SDOperand &OutFlag,
+    MachineFunction &MF,
+    SelectionDAG & DAG) {
+  if (!containsByValArg) return false;
+
+  std::pair<unsigned, unsigned> ArgRegVReg;
+  MVT::ValueType VT = RegToPass.second.getValueType();
+  
+  ArgRegVReg.first = RegToPass.first;
+  ArgRegVReg.second = MF.getRegInfo().createVirtualRegister(getRegClassFor(VT));
+  
+  // Copy Argument to virtual register.
+  OutChain = DAG.getCopyToReg(OutChain, ArgRegVReg.second,
+                              RegToPass.second, OutFlag);
+  OutFlag = OutChain.getValue(1);
+  // Remember virtual register and type.
+  TailCallByValClobberedVRegs.push_back(ArgRegVReg);
+  TailCallByValClobberedVRegTypes.push_back(VT);
+  return true;
+}
+
+
+/// RestoreTailCallByValClobberedReg - Restore registers which were saved to
+/// virtual registers to prevent tail call byval lowering from overwriting
+/// parameter registers.
+static SDOperand 
+RestoreTailCallByValClobberedRegs(SelectionDAG & DAG, SDOperand Chain,
+    SmallVector< std::pair<unsigned, unsigned>, 8> &TailCallByValClobberedVRegs,
+    SmallVector<MVT::ValueType, 8> &TailCallByValClobberedVRegTypes) {
+  if (TailCallByValClobberedVRegs.size()==0) return Chain;
+  
+  SmallVector<SDOperand, 8> RegOpChains;
+  for (unsigned i = 0, e=TailCallByValClobberedVRegs.size(); i != e; i++) {
+    SDOperand InFlag;
+    unsigned DestReg = TailCallByValClobberedVRegs[i].first;
+    unsigned VirtReg = TailCallByValClobberedVRegs[i].second;
+    MVT::ValueType VT = TailCallByValClobberedVRegTypes[i];
+    SDOperand Tmp = DAG.getCopyFromReg(Chain, VirtReg, VT, InFlag);
+    Chain = DAG.getCopyToReg(Chain, DestReg, Tmp, InFlag);
+    RegOpChains.push_back(Chain);
+  }
+  if (!RegOpChains.empty())
+    Chain = DAG.getNode(ISD::TokenFactor, MVT::Other,
+                        &RegOpChains[0], RegOpChains.size());
+  return Chain;
+}
 
 SDOperand X86TargetLowering::LowerCALL(SDOperand Op, SelectionDAG &DAG) {
   MachineFunction &MF = DAG.getMachineFunction();
@@ -1396,30 +1489,29 @@ SDOperand X86TargetLowering::LowerCALL(SDOperand Op, SelectionDAG &DAG) {
   Chain = DAG.getCALLSEQ_START(Chain, DAG.getIntPtrConstant(NumBytes));
 
   SDOperand RetAddrFrIdx;
-  if (IsTailCall) {
-    // Adjust the Return address stack slot.
-    if (FPDiff) {
-      MVT::ValueType VT = Is64Bit ? MVT::i64 : MVT::i32;
-      RetAddrFrIdx = getReturnAddressFrameIndex(DAG);
-      // Load the "old" Return address.
-      RetAddrFrIdx = 
-        DAG.getLoad(VT, Chain,RetAddrFrIdx, NULL, 0);
-      Chain = SDOperand(RetAddrFrIdx.Val, 1);
-    }
-  }
+  // Load return adress for tail calls.
+  Chain = EmitTailCallLoadRetAddr(DAG, RetAddrFrIdx, Chain, IsTailCall, Is64Bit,
+                                  FPDiff);
 
   SmallVector<std::pair<unsigned, SDOperand>, 8> RegsToPass;
   SmallVector<std::pair<unsigned, SDOperand>, 8> TailCallClobberedVRegs;
+  
   SmallVector<SDOperand, 8> MemOpChains;
 
   SDOperand StackPtr;
+  bool containsTailCallByValArg = false;
+  SmallVector<std::pair<unsigned, unsigned>, 8> TailCallByValClobberedVRegs;
+  SmallVector<MVT::ValueType, 8> TailCallByValClobberedVRegTypes;
+  
 
   // Walk the register/memloc assignments, inserting copies/loads.  For tail
   // calls, remember all arguments for later special lowering.
   for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) {
     CCValAssign &VA = ArgLocs[i];
     SDOperand Arg = Op.getOperand(5+2*VA.getValNo());
-    
+    bool isByVal = cast<ARG_FLAGSSDNode>(Op.getOperand(6+2*VA.getValNo()))->
+      getArgFlags().isByVal();
+  
     // Promote the value if needed.
     switch (VA.getLocInfo()) {
     default: assert(0 && "Unknown loc info!");
@@ -1438,13 +1530,15 @@ SDOperand X86TargetLowering::LowerCALL(SDOperand Op, SelectionDAG &DAG) {
     if (VA.isRegLoc()) {
       RegsToPass.push_back(std::make_pair(VA.getLocReg(), Arg));
     } else {
-      if (!IsTailCall) {
+      if (!IsTailCall || (IsTailCall && isByVal)) {
         assert(VA.isMemLoc());
         if (StackPtr.Val == 0)
           StackPtr = DAG.getCopyFromReg(Chain, X86StackPtr, getPointerTy());
         
         MemOpChains.push_back(LowerMemOpCallTo(Op, DAG, StackPtr, VA, Chain,
                                                Arg));
+        // Remember fact that this call contains byval arguments.
+        containsTailCallByValArg |= IsTailCall && isByVal;
       } else if (IsPossiblyOverwrittenArgumentOfTailCall(Arg, MFI)) {
         TailCallClobberedVRegs.push_back(std::make_pair(i,Arg));
       }
@@ -1459,6 +1553,16 @@ SDOperand X86TargetLowering::LowerCALL(SDOperand Op, SelectionDAG &DAG) {
   // and flag operands which copy the outgoing args into registers.
   SDOperand InFlag;
   for (unsigned i = 0, e = RegsToPass.size(); i != e; ++i) {
+    // Tail call byval lowering might overwrite argument registers so arguments
+    // passed to be copied to a virtual register for
+    // later processing.
+    if (CopyTailCallByValClobberedRegToVirtReg(containsTailCallByValArg,
+                                               TailCallByValClobberedVRegs,
+                                               TailCallByValClobberedVRegTypes,
+                                               RegsToPass[i], Chain, InFlag, MF,
+                                               DAG))
+      continue;
+     
     Chain = DAG.getCopyToReg(Chain, RegsToPass[i].first, RegsToPass[i].second,
                              InFlag);
     InFlag = Chain.getValue(1);
@@ -1533,7 +1637,7 @@ SDOperand X86TargetLowering::LowerCALL(SDOperand Op, SelectionDAG &DAG) {
         int32_t Offset = VA.getLocMemOffset()+FPDiff;
         uint32_t OpSize = (MVT::getSizeInBits(VA.getLocVT())+7)/8;
         FI = MF.getFrameInfo()->CreateFixedObject(OpSize, Offset);
-        FIN = DAG.getFrameIndex(FI, MVT::i32);
+        FIN = DAG.getFrameIndex(FI, getPointerTy());
 
         // Find virtual register for this argument.
         bool Found=false;
@@ -1548,7 +1652,12 @@ SDOperand X86TargetLowering::LowerCALL(SDOperand Op, SelectionDAG &DAG) {
 
         if (Flags.isByVal()) {
           // Copy relative to framepointer.
-          MemOpChains2.push_back(CreateCopyOfByValArgument(Arg, FIN, Chain,
+          SDOperand Source = DAG.getIntPtrConstant(VA.getLocMemOffset());
+          if (StackPtr.Val == 0)
+            StackPtr = DAG.getCopyFromReg(Chain, X86StackPtr, getPointerTy());
+          Source = DAG.getNode(ISD::ADD, getPointerTy(), StackPtr, Source);
+
+          MemOpChains2.push_back(CreateCopyOfByValArgument(Source, FIN, Chain,
                                                            Flags, DAG));
         } else {
           // Store relative to framepointer.
@@ -1563,17 +1672,14 @@ SDOperand X86TargetLowering::LowerCALL(SDOperand Op, SelectionDAG &DAG) {
       Chain = DAG.getNode(ISD::TokenFactor, MVT::Other,
                           &MemOpChains2[0], MemOpChains2.size());
 
+    // Restore byval lowering clobbered registers.
+    Chain = RestoreTailCallByValClobberedRegs(DAG, Chain,
+                                             TailCallByValClobberedVRegs,
+                                             TailCallByValClobberedVRegTypes);
+
     // Store the return address to the appropriate stack slot.
-    if (FPDiff) {
-      // Calculate the new stack slot for the return address.
-      int SlotSize = Is64Bit ? 8 : 4;
-      int NewReturnAddrFI = 
-        MF.getFrameInfo()->CreateFixedObject(SlotSize, FPDiff-SlotSize);
-      MVT::ValueType VT = Is64Bit ? MVT::i64 : MVT::i32;
-      SDOperand NewRetAddrFrIdx = DAG.getFrameIndex(NewReturnAddrFI, VT);
-      Chain = DAG.getStore(Chain, RetAddrFrIdx, NewRetAddrFrIdx, 
-                           PseudoSourceValue::getFixedStack(), NewReturnAddrFI);
-    }
+    Chain = EmitTailCallStoreRetAddr(DAG, MF, Chain, RetAddrFrIdx, Is64Bit,
+                                     FPDiff);
   }
 
   // If the callee is a GlobalAddress node (quite common, every direct call is)
index 2abe237ed825302fcaaf0238d884ecaac14c10d9..fea2d2b357775af8cbdd94af11de5c0bac50e8cf 100644 (file)
@@ -484,6 +484,19 @@ namespace llvm {
     bool IsCalleePop(SDOperand Op);
     bool CallRequiresGOTPtrInReg(bool Is64Bit, bool IsTailCall);
     bool CallRequiresFnAddressInReg(bool Is64Bit, bool IsTailCall);
+    SDOperand EmitTailCallLoadRetAddr(SelectionDAG &DAG, SDOperand &OutRetAddr,
+                                SDOperand Chain, bool IsTailCall, bool Is64Bit,
+                                int FPDiff);
+    
+    bool CopyTailCallByValClobberedRegToVirtReg(bool containsByValArg,
+     SmallVector< std::pair<unsigned, unsigned>,8> &TailCallByValClobberedVRegs,
+     SmallVector<MVT::ValueType, 8> &TailCallByValClobberedVRegTypes,
+     std::pair<unsigned, SDOperand> &RegToPass,
+     SDOperand &OutChain,
+     SDOperand &OutFlag,
+     MachineFunction &MF,
+     SelectionDAG & DAG);
+
     CCAssignFn *CCAssignFnForNode(SDOperand Op) const;
     NameDecorationStyle NameDecorationForFORMAL_ARGUMENTS(SDOperand Op);
     unsigned GetAlignedArgumentStackSize(unsigned StackSize, SelectionDAG &DAG);
index dc1dea7e113ee804f461e02397fd4ed84d91a7bb..9085b050ec541469d7c3e272ccecfbd4f1639cf5 100644 (file)
@@ -1,5 +1,9 @@
 ; RUN: llvm-as < %s | llc -march=x86 -tailcallopt | grep TAILCALL
-%struct.s = type { i32, i32, i32 }
+; check for the 2 byval moves
+; RUN: llvm-as < %s | llc -march=x86 -tailcallopt | grep rep | wc -l | grep 2
+%struct.s = type {i32, i32, i32, i32, i32, i32, i32, i32,
+                  i32, i32, i32, i32, i32, i32, i32, i32,
+                  i32, i32, i32, i32, i32, i32, i32, i32 }
 
 define  fastcc i32 @tailcallee(%struct.s* byval %a) {
 entry:
diff --git a/test/CodeGen/X86/tailcallbyval64.ll b/test/CodeGen/X86/tailcallbyval64.ll
new file mode 100644 (file)
index 0000000..7b65863
--- /dev/null
@@ -0,0 +1,29 @@
+; RUN: llvm-as < %s | llc -march=x86-64  -tailcallopt  | grep TAILCALL
+; Expect 2 rep;movs because of tail call byval lowering.
+; RUN: llvm-as < %s | llc -march=x86-64  -tailcallopt  | grep rep | wc -l | grep 2
+; A sequence of copyto/copyfrom virtual registers is used to deal with byval
+; lowering appearing after moving arguments to registers. The following two
+; checks verify that the register allocator changes those sequences to direct
+; moves to argument register where it can (for registers that are not used in 
+; byval lowering - not rsi, not rdi, not rcx).
+; Expect argument 4 to be moved directly to register edx.
+; RUN: llvm-as < %s | llc -march=x86-64  -tailcallopt  | grep movl | grep {7} | grep edx
+; Expect argument 6 to be moved directly to register r8.
+; RUN: llvm-as < %s | llc -march=x86-64  -tailcallopt  | grep movl | grep {17} | grep r8
+
+%struct.s = type { i64, i64, i64, i64, i64, i64, i64, i64,
+                   i64, i64, i64, i64, i64, i64, i64, i64,
+                   i64, i64, i64, i64, i64, i64, i64, i64 }
+
+declare  fastcc i64 @tailcallee(%struct.s* byval %a, i64 %val, i64 %val2, i64 %val3, i64 %val4, i64 %val5)
+
+
+define  fastcc i64 @tailcaller(i64 %b, %struct.s* byval %a) {
+entry:
+        %tmp2 = getelementptr %struct.s* %a, i32 0, i32 1
+        %tmp3 = load i64* %tmp2, align 8
+        %tmp4 = tail call fastcc i64 @tailcallee(%struct.s* %a byval, i64 %tmp3, i64 %b, i64 7, i64 13, i64 17)
+        ret i64 %tmp4
+}
+
+