Generate correct Sparc32 ABI compliant code for functions that return a struct.
[oota-llvm.git] / lib / Target / Sparc / SparcISelLowering.cpp
index 196b87dd58d0e6246c70a203839cebd8fbdf17b1..f39e91bb2f3a2836340ba5c72f0b74957379efd6 100644 (file)
@@ -16,7 +16,9 @@
 #include "SparcISelLowering.h"
 #include "SparcTargetMachine.h"
 #include "SparcMachineFunctionInfo.h"
+#include "llvm/DerivedTypes.h"
 #include "llvm/Function.h"
+#include "llvm/Module.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunction.h"
@@ -116,6 +118,8 @@ SparcTargetLowering::LowerReturn(SDValue Chain,
     // Guarantee that all emitted copies are stuck together with flags.
     Flag = Chain.getValue(1);
   }
+
+  unsigned RetAddrOffset = 8; //Call Inst + Delay Slot
   // If the function returns a struct, copy the SRetReturnReg to I0
   if (MF.getFunction()->hasStructRetAttr()) {
     SparcMachineFunctionInfo *SFI = MF.getInfo<SparcMachineFunctionInfo>();
@@ -127,11 +131,16 @@ SparcTargetLowering::LowerReturn(SDValue Chain,
     Flag = Chain.getValue(1);
     if (MF.getRegInfo().liveout_empty())
       MF.getRegInfo().addLiveOut(SP::I0);
+    RetAddrOffset = 12; // CallInst + Delay Slot + Unimp
   }
 
+  SDValue RetAddrOffsetNode = DAG.getConstant(RetAddrOffset, MVT::i32);
+
   if (Flag.getNode())
-    return DAG.getNode(SPISD::RET_FLAG, dl, MVT::Other, Chain, Flag);
-  return DAG.getNode(SPISD::RET_FLAG, dl, MVT::Other, Chain);
+    return DAG.getNode(SPISD::RET_FLAG, dl, MVT::Other, Chain,
+                       RetAddrOffsetNode, Flag);
+  return DAG.getNode(SPISD::RET_FLAG, dl, MVT::Other, Chain, 
+                     RetAddrOffsetNode);
 }
 
 /// LowerFormalArguments - V8 uses a very simple ABI, where all values are
@@ -393,6 +402,7 @@ SparcTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
   SmallVector<SDValue, 8> MemOpChains;
 
   const unsigned StackOffset = 92;
+  bool hasStructRetAttr = false;
   // Walk the register/memloc assignments, inserting copies/loads.
   for (unsigned i = 0, realArgIdx = 0, byvalArgIdx = 0, e = ArgLocs.size();
        i != e;
@@ -433,6 +443,7 @@ SparcTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
       MemOpChains.push_back(DAG.getStore(Chain, dl, Arg, PtrOff,
                                          MachinePointerInfo(),
                                          false, false, 0));
+      hasStructRetAttr = true;
       continue;
     }
 
@@ -546,6 +557,8 @@ SparcTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
     InFlag = Chain.getValue(1);
   }
 
+  unsigned SRetArgSize = (hasStructRetAttr)? getSRetArgSize(DAG, Callee):0;
+
   // If the callee is a GlobalAddress node (quite common, every direct call is)
   // turn it into a TargetGlobalAddress node so that legalize doesn't hack it.
   // Likewise ExternalSymbol -> TargetExternalSymbol.
@@ -559,6 +572,8 @@ SparcTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
   SmallVector<SDValue, 8> Ops;
   Ops.push_back(Chain);
   Ops.push_back(Callee);
+  if (hasStructRetAttr)
+    Ops.push_back(DAG.getTargetConstant(SRetArgSize, MVT::i32));
   for (unsigned i = 0, e = RegsToPass.size(); i != e; ++i) {
     unsigned Reg = RegsToPass[i].first;
     if (Reg >= SP::I0 && Reg <= SP::I7)
@@ -600,7 +615,29 @@ SparcTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
   return Chain;
 }
 
+unsigned
+SparcTargetLowering::getSRetArgSize(SelectionDAG &DAG, SDValue Callee) const
+{
+  const Function *CalleeFn = 0;
+  if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Callee)) {
+    CalleeFn = dyn_cast<Function>(G->getGlobal());
+  } else if (ExternalSymbolSDNode *E =
+             dyn_cast<ExternalSymbolSDNode>(Callee)) {
+    const Function *Fn = DAG.getMachineFunction().getFunction();
+    const Module *M = Fn->getParent();
+    CalleeFn = M->getFunction(E->getSymbol());
+  }
+
+  if (!CalleeFn)
+    return 0;
 
+  assert(CalleeFn->hasStructRetAttr() &&
+         "Callee does not have the StructRet attribute.");
+
+  const PointerType *Ty = cast<PointerType>(CalleeFn->arg_begin()->getType());
+  const Type *ElementTy = Ty->getElementType();
+  return getTargetData()->getTypeAllocSize(ElementTy);
+}
 
 //===----------------------------------------------------------------------===//
 // TargetLowering Implementation