DAG: move sret demotion into most basic LowerCallTo implementation.
authorTim Northover <tnorthover@apple.com>
Wed, 18 Jun 2014 11:52:44 +0000 (11:52 +0000)
committerTim Northover <tnorthover@apple.com>
Wed, 18 Jun 2014 11:52:44 +0000 (11:52 +0000)
It looks like there are two versions of LowerCallTo here: the
SelectionDAGBuilder one is designed to operate on LLVM IR, and the
TargetLowering one in the case where everything is at DAG level.

Previously, only the SelectionDAGBuilder variant could handle demoting
an impossible return to sret semantics (before delegating to the
TargetLowering version), but this functionality is also useful for
certain libcalls (e.g. 128-bit operations on 32-bit x86).  So this
commit moves the sret handling down a level.

rdar://problem/17242889

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@211155 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
test/CodeGen/X86/libcall-sret.ll [new file with mode: 0644]

index 136baf56e8a7521fc70db5f825360c1e3f9845f8..e6dc27219787aaae2e8a4ddcbe7c3a75914f0b85 100644 (file)
@@ -5439,6 +5439,7 @@ SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, unsigned Intrinsic) {
 void SelectionDAGBuilder::LowerCallTo(ImmutableCallSite CS, SDValue Callee,
                                       bool isTailCall,
                                       MachineBasicBlock *LandingPad) {
+  const TargetLowering *TLI = TM.getTargetLowering();
   PointerType *PT = cast<PointerType>(CS.getCalledValue()->getType());
   FunctionType *FTy = cast<FunctionType>(PT->getElementType());
   Type *RetTy = FTy->getReturnType();
@@ -5449,45 +5450,6 @@ void SelectionDAGBuilder::LowerCallTo(ImmutableCallSite CS, SDValue Callee,
   TargetLowering::ArgListEntry Entry;
   Args.reserve(CS.arg_size());
 
-  // Check whether the function can return without sret-demotion.
-  SmallVector<ISD::OutputArg, 4> Outs;
-  const TargetLowering *TLI = TM.getTargetLowering();
-  GetReturnInfo(RetTy, CS.getAttributes(), Outs, *TLI);
-
-  bool CanLowerReturn = TLI->CanLowerReturn(CS.getCallingConv(),
-                                            DAG.getMachineFunction(),
-                                            FTy->isVarArg(), Outs,
-                                            FTy->getContext());
-
-  SDValue DemoteStackSlot;
-  int DemoteStackIdx = -100;
-
-  if (!CanLowerReturn) {
-    assert(!CS.hasInAllocaArgument() &&
-           "sret demotion is incompatible with inalloca");
-    uint64_t TySize = TLI->getDataLayout()->getTypeAllocSize(
-                      FTy->getReturnType());
-    unsigned Align  = TLI->getDataLayout()->getPrefTypeAlignment(
-                      FTy->getReturnType());
-    MachineFunction &MF = DAG.getMachineFunction();
-    DemoteStackIdx = MF.getFrameInfo()->CreateStackObject(TySize, Align, false);
-    Type *StackSlotPtrType = PointerType::getUnqual(FTy->getReturnType());
-
-    DemoteStackSlot = DAG.getFrameIndex(DemoteStackIdx, TLI->getPointerTy());
-    Entry.Node = DemoteStackSlot;
-    Entry.Ty = StackSlotPtrType;
-    Entry.isSExt = false;
-    Entry.isZExt = false;
-    Entry.isInReg = false;
-    Entry.isSRet = true;
-    Entry.isNest = false;
-    Entry.isByVal = false;
-    Entry.isReturned = false;
-    Entry.Alignment = Align;
-    Args.push_back(Entry);
-    RetTy = Type::getVoidTy(FTy->getContext());
-  }
-
   for (ImmutableCallSite::arg_iterator i = CS.arg_begin(), e = CS.arg_end();
        i != e; ++i) {
     const Value *V = *i;
@@ -5540,46 +5502,8 @@ void SelectionDAGBuilder::LowerCallTo(ImmutableCallSite CS, SDValue Callee,
          "Non-null chain expected with non-tail call!");
   assert((Result.second.getNode() || !Result.first.getNode()) &&
          "Null value expected with tail call!");
-  if (Result.first.getNode()) {
+  if (Result.first.getNode())
     setValue(CS.getInstruction(), Result.first);
-  } else if (!CanLowerReturn && Result.second.getNode()) {
-    // The instruction result is the result of loading from the
-    // hidden sret parameter.
-    SmallVector<EVT, 1> PVTs;
-    Type *PtrRetTy = PointerType::getUnqual(FTy->getReturnType());
-
-    ComputeValueVTs(*TLI, PtrRetTy, PVTs);
-    assert(PVTs.size() == 1 && "Pointers should fit in one register");
-    EVT PtrVT = PVTs[0];
-
-    SmallVector<EVT, 4> RetTys;
-    SmallVector<uint64_t, 4> Offsets;
-    RetTy = FTy->getReturnType();
-    ComputeValueVTs(*TLI, RetTy, RetTys, &Offsets);
-
-    unsigned NumValues = RetTys.size();
-    SmallVector<SDValue, 4> Values(NumValues);
-    SmallVector<SDValue, 4> Chains(NumValues);
-
-    for (unsigned i = 0; i < NumValues; ++i) {
-      SDValue Add = DAG.getNode(ISD::ADD, getCurSDLoc(), PtrVT,
-                                DemoteStackSlot,
-                                DAG.getConstant(Offsets[i], PtrVT));
-      SDValue L = DAG.getLoad(RetTys[i], getCurSDLoc(), Result.second, Add,
-                  MachinePointerInfo::getFixedStack(DemoteStackIdx, Offsets[i]),
-                              false, false, false, 1);
-      Values[i] = L;
-      Chains[i] = L.getValue(1);
-    }
-
-    SDValue Chain = DAG.getNode(ISD::TokenFactor, getCurSDLoc(),
-                                MVT::Other, Chains);
-    PendingLoads.push_back(Chain);
-
-    setValue(CS.getInstruction(),
-             DAG.getNode(ISD::MERGE_VALUES, getCurSDLoc(),
-                         DAG.getVTList(RetTys), Values));
-  }
 
   if (!Result.second.getNode()) {
     // As a special case, a null chain means that a tail call has been emitted
@@ -7121,6 +7045,21 @@ void SelectionDAGBuilder::visitPatchpoint(const CallInst &CI) {
   FuncInfo.MF->getFrameInfo()->setHasPatchPoint();
 }
 
+/// Returns an AttributeSet representing the attributes applied to the return
+/// value of the given call.
+static AttributeSet getReturnAttrs(TargetLowering::CallLoweringInfo &CLI) {
+  SmallVector<Attribute::AttrKind, 2> Attrs;
+  if (CLI.RetSExt)
+    Attrs.push_back(Attribute::SExt);
+  if (CLI.RetZExt)
+    Attrs.push_back(Attribute::ZExt);
+  if (CLI.IsInReg)
+    Attrs.push_back(Attribute::InReg);
+
+  return AttributeSet::get(CLI.RetTy->getContext(), AttributeSet::ReturnIndex,
+                           Attrs);
+}
+
 /// TargetLowering::LowerCallTo - This is the default LowerCallTo
 /// implementation, which just calls LowerCall.
 /// FIXME: When all targets are
@@ -7129,24 +7068,62 @@ std::pair<SDValue, SDValue>
 TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
   // Handle the incoming return values from the call.
   CLI.Ins.clear();
+  Type *OrigRetTy = CLI.RetTy;
   SmallVector<EVT, 4> RetTys;
-  ComputeValueVTs(*this, CLI.RetTy, RetTys);
-  for (unsigned I = 0, E = RetTys.size(); I != E; ++I) {
-    EVT VT = RetTys[I];
-    MVT RegisterVT = getRegisterType(CLI.RetTy->getContext(), VT);
-    unsigned NumRegs = getNumRegisters(CLI.RetTy->getContext(), VT);
-    for (unsigned i = 0; i != NumRegs; ++i) {
-      ISD::InputArg MyFlags;
-      MyFlags.VT = RegisterVT;
-      MyFlags.ArgVT = VT;
-      MyFlags.Used = CLI.IsReturnValueUsed;
-      if (CLI.RetSExt)
-        MyFlags.Flags.setSExt();
-      if (CLI.RetZExt)
-        MyFlags.Flags.setZExt();
-      if (CLI.IsInReg)
-        MyFlags.Flags.setInReg();
-      CLI.Ins.push_back(MyFlags);
+  SmallVector<uint64_t, 4> Offsets;
+  ComputeValueVTs(*this, CLI.RetTy, RetTys, &Offsets);
+
+  SmallVector<ISD::OutputArg, 4> Outs;
+  GetReturnInfo(CLI.RetTy, getReturnAttrs(CLI), Outs, *this);
+
+  bool CanLowerReturn =
+      this->CanLowerReturn(CLI.CallConv, CLI.DAG.getMachineFunction(),
+                           CLI.IsVarArg, Outs, CLI.RetTy->getContext());
+
+  SDValue DemoteStackSlot;
+  int DemoteStackIdx = -100;
+  if (!CanLowerReturn) {
+    // FIXME: equivalent assert?
+    // assert(!CS.hasInAllocaArgument() &&
+    //        "sret demotion is incompatible with inalloca");
+    uint64_t TySize = getDataLayout()->getTypeAllocSize(CLI.RetTy);
+    unsigned Align  = getDataLayout()->getPrefTypeAlignment(CLI.RetTy);
+    MachineFunction &MF = CLI.DAG.getMachineFunction();
+    DemoteStackIdx = MF.getFrameInfo()->CreateStackObject(TySize, Align, false);
+    Type *StackSlotPtrType = PointerType::getUnqual(CLI.RetTy);
+
+    DemoteStackSlot = CLI.DAG.getFrameIndex(DemoteStackIdx, getPointerTy());
+    ArgListEntry Entry;
+    Entry.Node = DemoteStackSlot;
+    Entry.Ty = StackSlotPtrType;
+    Entry.isSExt = false;
+    Entry.isZExt = false;
+    Entry.isInReg = false;
+    Entry.isSRet = true;
+    Entry.isNest = false;
+    Entry.isByVal = false;
+    Entry.isReturned = false;
+    Entry.Alignment = Align;
+    CLI.getArgs().insert(CLI.getArgs().begin(), Entry);
+    CLI.RetTy = Type::getVoidTy(CLI.RetTy->getContext());
+  } else {
+    for (unsigned I = 0, E = RetTys.size(); I != E; ++I) {
+      EVT VT = RetTys[I];
+      MVT RegisterVT = getRegisterType(CLI.RetTy->getContext(), VT);
+      unsigned NumRegs = getNumRegisters(CLI.RetTy->getContext(), VT);
+      for (unsigned i = 0; i != NumRegs; ++i) {
+        ISD::InputArg MyFlags;
+        MyFlags.VT = RegisterVT;
+        MyFlags.ArgVT = VT;
+        MyFlags.Used = CLI.IsReturnValueUsed;
+        if (CLI.RetSExt)
+          MyFlags.Flags.setSExt();
+        if (CLI.RetZExt)
+          MyFlags.Flags.setZExt();
+        if (CLI.IsInReg)
+          MyFlags.Flags.setInReg();
+        CLI.Ins.push_back(MyFlags);
+      }
     }
   }
 
@@ -7289,31 +7266,59 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
                  "LowerCall emitted a value with the wrong type!");
         });
 
-  // Collect the legal value parts into potentially illegal values
-  // that correspond to the original function's return values.
-  ISD::NodeType AssertOp = ISD::DELETED_NODE;
-  if (CLI.RetSExt)
-    AssertOp = ISD::AssertSext;
-  else if (CLI.RetZExt)
-    AssertOp = ISD::AssertZext;
   SmallVector<SDValue, 4> ReturnValues;
-  unsigned CurReg = 0;
-  for (unsigned I = 0, E = RetTys.size(); I != E; ++I) {
-    EVT VT = RetTys[I];
-    MVT RegisterVT = getRegisterType(CLI.RetTy->getContext(), VT);
-    unsigned NumRegs = getNumRegisters(CLI.RetTy->getContext(), VT);
-
-    ReturnValues.push_back(getCopyFromParts(CLI.DAG, CLI.DL, &InVals[CurReg],
-                                            NumRegs, RegisterVT, VT, nullptr,
-                                            AssertOp));
-    CurReg += NumRegs;
-  }
-
-  // For a function returning void, there is no return value. We can't create
-  // such a node, so we just return a null return value in that case. In
-  // that case, nothing will actually look at the value.
-  if (ReturnValues.empty())
-    return std::make_pair(SDValue(), CLI.Chain);
+  if (!CanLowerReturn) {
+    // The instruction result is the result of loading from the
+    // hidden sret parameter.
+    SmallVector<EVT, 1> PVTs;
+    Type *PtrRetTy = PointerType::getUnqual(OrigRetTy);
+
+    ComputeValueVTs(*this, PtrRetTy, PVTs);
+    assert(PVTs.size() == 1 && "Pointers should fit in one register");
+    EVT PtrVT = PVTs[0];
+
+    unsigned NumValues = RetTys.size();
+    ReturnValues.resize(NumValues);
+    SmallVector<SDValue, 4> Chains(NumValues);
+
+    for (unsigned i = 0; i < NumValues; ++i) {
+      SDValue Add = CLI.DAG.getNode(ISD::ADD, CLI.DL, PtrVT, DemoteStackSlot,
+                                    CLI.DAG.getConstant(Offsets[i], PtrVT));
+      SDValue L = CLI.DAG.getLoad(
+          RetTys[i], CLI.DL, CLI.Chain, Add,
+          MachinePointerInfo::getFixedStack(DemoteStackIdx, Offsets[i]), false,
+          false, false, 1);
+      ReturnValues[i] = L;
+      Chains[i] = L.getValue(1);
+    }
+
+    CLI.Chain = CLI.DAG.getNode(ISD::TokenFactor, CLI.DL, MVT::Other, Chains);
+  } else {
+    // Collect the legal value parts into potentially illegal values
+    // that correspond to the original function's return values.
+    ISD::NodeType AssertOp = ISD::DELETED_NODE;
+    if (CLI.RetSExt)
+      AssertOp = ISD::AssertSext;
+    else if (CLI.RetZExt)
+      AssertOp = ISD::AssertZext;
+    unsigned CurReg = 0;
+    for (unsigned I = 0, E = RetTys.size(); I != E; ++I) {
+      EVT VT = RetTys[I];
+      MVT RegisterVT = getRegisterType(CLI.RetTy->getContext(), VT);
+      unsigned NumRegs = getNumRegisters(CLI.RetTy->getContext(), VT);
+
+      ReturnValues.push_back(getCopyFromParts(CLI.DAG, CLI.DL, &InVals[CurReg],
+                                              NumRegs, RegisterVT, VT, nullptr,
+                                              AssertOp));
+      CurReg += NumRegs;
+    }
+
+    // For a function returning void, there is no return value. We can't create
+    // such a node, so we just return a null return value in that case. In
+    // that case, nothing will actually look at the value.
+    if (ReturnValues.empty())
+      return std::make_pair(SDValue(), CLI.Chain);
+  }
 
   SDValue Res = CLI.DAG.getNode(ISD::MERGE_VALUES, CLI.DL,
                                 CLI.DAG.getVTList(RetTys), ReturnValues);
diff --git a/test/CodeGen/X86/libcall-sret.ll b/test/CodeGen/X86/libcall-sret.ll
new file mode 100644 (file)
index 0000000..67b99ac
--- /dev/null
@@ -0,0 +1,28 @@
+; RUN: llc -mtriple=i686-linux-gnu -o - %s | FileCheck %s
+
+@var = global i128 0
+
+; We were trying to convert the i128 operation into a libcall, but failing to
+; perform sret demotion when we couldn't return the result in registers. Make
+; sure we marshal the return properly:
+
+define void @test_sret_libcall(i128 %l, i128 %r) {
+; CHECK-LABEL: test_sret_libcall:
+
+  ; Stack for call: 4(sret ptr), 16(i128 %l), 16(128 %r). So next logical
+  ; (aligned) place for the actual sret data is %esp + 40.
+; CHECK: leal 40(%esp), [[SRET_ADDR:%[a-z]+]]
+; CHECK: movl [[SRET_ADDR]], (%esp)
+; CHECK: calll __multi3
+; CHECK-DAG: movl 40(%esp), [[RES0:%[a-z]+]]
+; CHECK-DAG: movl 44(%esp), [[RES1:%[a-z]+]]
+; CHECK-DAG: movl 48(%esp), [[RES2:%[a-z]+]]
+; CHECK-DAG: movl 52(%esp), [[RES3:%[a-z]+]]
+; CHECK-DAG: movl [[RES0]], var
+; CHECK-DAG: movl [[RES1]], var+4
+; CHECK-DAG: movl [[RES2]], var+8
+; CHECK-DAG: movl [[RES3]], var+12
+  %prod = mul i128 %l, %r
+  store i128 %prod, i128* @var
+  ret void
+}