DAG: move sret demotion into most basic LowerCallTo implementation.
[oota-llvm.git] / lib / CodeGen / SelectionDAG / SelectionDAGBuilder.cpp
index 136baf56e8a7521fc70db5f825360c1e3f9845f8..e6dc27219787aaae2e8a4ddcbe7c3a75914f0b85 100644 (file)
@@ -5439,6 +5439,7 @@ SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, unsigned Intrinsic) {
 void SelectionDAGBuilder::LowerCallTo(ImmutableCallSite CS, SDValue Callee,
                                       bool isTailCall,
                                       MachineBasicBlock *LandingPad) {
+  const TargetLowering *TLI = TM.getTargetLowering();
   PointerType *PT = cast<PointerType>(CS.getCalledValue()->getType());
   FunctionType *FTy = cast<FunctionType>(PT->getElementType());
   Type *RetTy = FTy->getReturnType();
@@ -5449,45 +5450,6 @@ void SelectionDAGBuilder::LowerCallTo(ImmutableCallSite CS, SDValue Callee,
   TargetLowering::ArgListEntry Entry;
   Args.reserve(CS.arg_size());
 
-  // Check whether the function can return without sret-demotion.
-  SmallVector<ISD::OutputArg, 4> Outs;
-  const TargetLowering *TLI = TM.getTargetLowering();
-  GetReturnInfo(RetTy, CS.getAttributes(), Outs, *TLI);
-
-  bool CanLowerReturn = TLI->CanLowerReturn(CS.getCallingConv(),
-                                            DAG.getMachineFunction(),
-                                            FTy->isVarArg(), Outs,
-                                            FTy->getContext());
-
-  SDValue DemoteStackSlot;
-  int DemoteStackIdx = -100;
-
-  if (!CanLowerReturn) {
-    assert(!CS.hasInAllocaArgument() &&
-           "sret demotion is incompatible with inalloca");
-    uint64_t TySize = TLI->getDataLayout()->getTypeAllocSize(
-                      FTy->getReturnType());
-    unsigned Align  = TLI->getDataLayout()->getPrefTypeAlignment(
-                      FTy->getReturnType());
-    MachineFunction &MF = DAG.getMachineFunction();
-    DemoteStackIdx = MF.getFrameInfo()->CreateStackObject(TySize, Align, false);
-    Type *StackSlotPtrType = PointerType::getUnqual(FTy->getReturnType());
-
-    DemoteStackSlot = DAG.getFrameIndex(DemoteStackIdx, TLI->getPointerTy());
-    Entry.Node = DemoteStackSlot;
-    Entry.Ty = StackSlotPtrType;
-    Entry.isSExt = false;
-    Entry.isZExt = false;
-    Entry.isInReg = false;
-    Entry.isSRet = true;
-    Entry.isNest = false;
-    Entry.isByVal = false;
-    Entry.isReturned = false;
-    Entry.Alignment = Align;
-    Args.push_back(Entry);
-    RetTy = Type::getVoidTy(FTy->getContext());
-  }
-
   for (ImmutableCallSite::arg_iterator i = CS.arg_begin(), e = CS.arg_end();
        i != e; ++i) {
     const Value *V = *i;
@@ -5540,46 +5502,8 @@ void SelectionDAGBuilder::LowerCallTo(ImmutableCallSite CS, SDValue Callee,
          "Non-null chain expected with non-tail call!");
   assert((Result.second.getNode() || !Result.first.getNode()) &&
          "Null value expected with tail call!");
-  if (Result.first.getNode()) {
+  if (Result.first.getNode())
     setValue(CS.getInstruction(), Result.first);
-  } else if (!CanLowerReturn && Result.second.getNode()) {
-    // The instruction result is the result of loading from the
-    // hidden sret parameter.
-    SmallVector<EVT, 1> PVTs;
-    Type *PtrRetTy = PointerType::getUnqual(FTy->getReturnType());
-
-    ComputeValueVTs(*TLI, PtrRetTy, PVTs);
-    assert(PVTs.size() == 1 && "Pointers should fit in one register");
-    EVT PtrVT = PVTs[0];
-
-    SmallVector<EVT, 4> RetTys;
-    SmallVector<uint64_t, 4> Offsets;
-    RetTy = FTy->getReturnType();
-    ComputeValueVTs(*TLI, RetTy, RetTys, &Offsets);
-
-    unsigned NumValues = RetTys.size();
-    SmallVector<SDValue, 4> Values(NumValues);
-    SmallVector<SDValue, 4> Chains(NumValues);
-
-    for (unsigned i = 0; i < NumValues; ++i) {
-      SDValue Add = DAG.getNode(ISD::ADD, getCurSDLoc(), PtrVT,
-                                DemoteStackSlot,
-                                DAG.getConstant(Offsets[i], PtrVT));
-      SDValue L = DAG.getLoad(RetTys[i], getCurSDLoc(), Result.second, Add,
-                  MachinePointerInfo::getFixedStack(DemoteStackIdx, Offsets[i]),
-                              false, false, false, 1);
-      Values[i] = L;
-      Chains[i] = L.getValue(1);
-    }
-
-    SDValue Chain = DAG.getNode(ISD::TokenFactor, getCurSDLoc(),
-                                MVT::Other, Chains);
-    PendingLoads.push_back(Chain);
-
-    setValue(CS.getInstruction(),
-             DAG.getNode(ISD::MERGE_VALUES, getCurSDLoc(),
-                         DAG.getVTList(RetTys), Values));
-  }
 
   if (!Result.second.getNode()) {
     // As a special case, a null chain means that a tail call has been emitted
@@ -7121,6 +7045,21 @@ void SelectionDAGBuilder::visitPatchpoint(const CallInst &CI) {
   FuncInfo.MF->getFrameInfo()->setHasPatchPoint();
 }
 
+/// Returns an AttributeSet representing the attributes applied to the return
+/// value of the given call.
+static AttributeSet getReturnAttrs(TargetLowering::CallLoweringInfo &CLI) {
+  SmallVector<Attribute::AttrKind, 2> Attrs;
+  if (CLI.RetSExt)
+    Attrs.push_back(Attribute::SExt);
+  if (CLI.RetZExt)
+    Attrs.push_back(Attribute::ZExt);
+  if (CLI.IsInReg)
+    Attrs.push_back(Attribute::InReg);
+
+  return AttributeSet::get(CLI.RetTy->getContext(), AttributeSet::ReturnIndex,
+                           Attrs);
+}
+
 /// TargetLowering::LowerCallTo - This is the default LowerCallTo
 /// implementation, which just calls LowerCall.
 /// FIXME: When all targets are
@@ -7129,24 +7068,62 @@ std::pair<SDValue, SDValue>
 TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
   // Handle the incoming return values from the call.
   CLI.Ins.clear();
+  Type *OrigRetTy = CLI.RetTy;
   SmallVector<EVT, 4> RetTys;
-  ComputeValueVTs(*this, CLI.RetTy, RetTys);
-  for (unsigned I = 0, E = RetTys.size(); I != E; ++I) {
-    EVT VT = RetTys[I];
-    MVT RegisterVT = getRegisterType(CLI.RetTy->getContext(), VT);
-    unsigned NumRegs = getNumRegisters(CLI.RetTy->getContext(), VT);
-    for (unsigned i = 0; i != NumRegs; ++i) {
-      ISD::InputArg MyFlags;
-      MyFlags.VT = RegisterVT;
-      MyFlags.ArgVT = VT;
-      MyFlags.Used = CLI.IsReturnValueUsed;
-      if (CLI.RetSExt)
-        MyFlags.Flags.setSExt();
-      if (CLI.RetZExt)
-        MyFlags.Flags.setZExt();
-      if (CLI.IsInReg)
-        MyFlags.Flags.setInReg();
-      CLI.Ins.push_back(MyFlags);
+  SmallVector<uint64_t, 4> Offsets;
+  ComputeValueVTs(*this, CLI.RetTy, RetTys, &Offsets);
+
+  SmallVector<ISD::OutputArg, 4> Outs;
+  GetReturnInfo(CLI.RetTy, getReturnAttrs(CLI), Outs, *this);
+
+  bool CanLowerReturn =
+      this->CanLowerReturn(CLI.CallConv, CLI.DAG.getMachineFunction(),
+                           CLI.IsVarArg, Outs, CLI.RetTy->getContext());
+
+  SDValue DemoteStackSlot;
+  int DemoteStackIdx = -100;
+  if (!CanLowerReturn) {
+    // FIXME: equivalent assert?
+    // assert(!CS.hasInAllocaArgument() &&
+    //        "sret demotion is incompatible with inalloca");
+    uint64_t TySize = getDataLayout()->getTypeAllocSize(CLI.RetTy);
+    unsigned Align  = getDataLayout()->getPrefTypeAlignment(CLI.RetTy);
+    MachineFunction &MF = CLI.DAG.getMachineFunction();
+    DemoteStackIdx = MF.getFrameInfo()->CreateStackObject(TySize, Align, false);
+    Type *StackSlotPtrType = PointerType::getUnqual(CLI.RetTy);
+
+    DemoteStackSlot = CLI.DAG.getFrameIndex(DemoteStackIdx, getPointerTy());
+    ArgListEntry Entry;
+    Entry.Node = DemoteStackSlot;
+    Entry.Ty = StackSlotPtrType;
+    Entry.isSExt = false;
+    Entry.isZExt = false;
+    Entry.isInReg = false;
+    Entry.isSRet = true;
+    Entry.isNest = false;
+    Entry.isByVal = false;
+    Entry.isReturned = false;
+    Entry.Alignment = Align;
+    CLI.getArgs().insert(CLI.getArgs().begin(), Entry);
+    CLI.RetTy = Type::getVoidTy(CLI.RetTy->getContext());
+  } else {
+    for (unsigned I = 0, E = RetTys.size(); I != E; ++I) {
+      EVT VT = RetTys[I];
+      MVT RegisterVT = getRegisterType(CLI.RetTy->getContext(), VT);
+      unsigned NumRegs = getNumRegisters(CLI.RetTy->getContext(), VT);
+      for (unsigned i = 0; i != NumRegs; ++i) {
+        ISD::InputArg MyFlags;
+        MyFlags.VT = RegisterVT;
+        MyFlags.ArgVT = VT;
+        MyFlags.Used = CLI.IsReturnValueUsed;
+        if (CLI.RetSExt)
+          MyFlags.Flags.setSExt();
+        if (CLI.RetZExt)
+          MyFlags.Flags.setZExt();
+        if (CLI.IsInReg)
+          MyFlags.Flags.setInReg();
+        CLI.Ins.push_back(MyFlags);
+      }
     }
   }
 
@@ -7289,31 +7266,59 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
                  "LowerCall emitted a value with the wrong type!");
         });
 
-  // Collect the legal value parts into potentially illegal values
-  // that correspond to the original function's return values.
-  ISD::NodeType AssertOp = ISD::DELETED_NODE;
-  if (CLI.RetSExt)
-    AssertOp = ISD::AssertSext;
-  else if (CLI.RetZExt)
-    AssertOp = ISD::AssertZext;
   SmallVector<SDValue, 4> ReturnValues;
-  unsigned CurReg = 0;
-  for (unsigned I = 0, E = RetTys.size(); I != E; ++I) {
-    EVT VT = RetTys[I];
-    MVT RegisterVT = getRegisterType(CLI.RetTy->getContext(), VT);
-    unsigned NumRegs = getNumRegisters(CLI.RetTy->getContext(), VT);
-
-    ReturnValues.push_back(getCopyFromParts(CLI.DAG, CLI.DL, &InVals[CurReg],
-                                            NumRegs, RegisterVT, VT, nullptr,
-                                            AssertOp));
-    CurReg += NumRegs;
-  }
-
-  // For a function returning void, there is no return value. We can't create
-  // such a node, so we just return a null return value in that case. In
-  // that case, nothing will actually look at the value.
-  if (ReturnValues.empty())
-    return std::make_pair(SDValue(), CLI.Chain);
+  if (!CanLowerReturn) {
+    // The instruction result is the result of loading from the
+    // hidden sret parameter.
+    SmallVector<EVT, 1> PVTs;
+    Type *PtrRetTy = PointerType::getUnqual(OrigRetTy);
+
+    ComputeValueVTs(*this, PtrRetTy, PVTs);
+    assert(PVTs.size() == 1 && "Pointers should fit in one register");
+    EVT PtrVT = PVTs[0];
+
+    unsigned NumValues = RetTys.size();
+    ReturnValues.resize(NumValues);
+    SmallVector<SDValue, 4> Chains(NumValues);
+
+    for (unsigned i = 0; i < NumValues; ++i) {
+      SDValue Add = CLI.DAG.getNode(ISD::ADD, CLI.DL, PtrVT, DemoteStackSlot,
+                                    CLI.DAG.getConstant(Offsets[i], PtrVT));
+      SDValue L = CLI.DAG.getLoad(
+          RetTys[i], CLI.DL, CLI.Chain, Add,
+          MachinePointerInfo::getFixedStack(DemoteStackIdx, Offsets[i]), false,
+          false, false, 1);
+      ReturnValues[i] = L;
+      Chains[i] = L.getValue(1);
+    }
+
+    CLI.Chain = CLI.DAG.getNode(ISD::TokenFactor, CLI.DL, MVT::Other, Chains);
+  } else {
+    // Collect the legal value parts into potentially illegal values
+    // that correspond to the original function's return values.
+    ISD::NodeType AssertOp = ISD::DELETED_NODE;
+    if (CLI.RetSExt)
+      AssertOp = ISD::AssertSext;
+    else if (CLI.RetZExt)
+      AssertOp = ISD::AssertZext;
+    unsigned CurReg = 0;
+    for (unsigned I = 0, E = RetTys.size(); I != E; ++I) {
+      EVT VT = RetTys[I];
+      MVT RegisterVT = getRegisterType(CLI.RetTy->getContext(), VT);
+      unsigned NumRegs = getNumRegisters(CLI.RetTy->getContext(), VT);
+
+      ReturnValues.push_back(getCopyFromParts(CLI.DAG, CLI.DL, &InVals[CurReg],
+                                              NumRegs, RegisterVT, VT, nullptr,
+                                              AssertOp));
+      CurReg += NumRegs;
+    }
+
+    // For a function returning void, there is no return value. We can't create
+    // such a node, so we just return a null return value in that case. In
+    // that case, nothing will actually look at the value.
+    if (ReturnValues.empty())
+      return std::make_pair(SDValue(), CLI.Chain);
+  }
 
   SDValue Res = CLI.DAG.getNode(ISD::MERGE_VALUES, CLI.DL,
                                 CLI.DAG.getVTList(RetTys), ReturnValues);