PTX: Handle function call return values
authorJustin Holewinski <justin.holewinski@gmail.com>
Fri, 23 Sep 2011 16:48:41 +0000 (16:48 +0000)
committerJustin Holewinski <justin.holewinski@gmail.com>
Fri, 23 Sep 2011 16:48:41 +0000 (16:48 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@140386 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/PTX/PTXAsmPrinter.cpp
lib/Target/PTX/PTXISelLowering.cpp
test/CodeGen/PTX/simple-call.ll

index 77164cac881573e26242e303fd5281267ca87b85..d2b7c5f6b55a60254c93ef76a946cd7ef6575c9e 100644 (file)
@@ -677,21 +677,36 @@ printPredicateOperand(const MachineInstr *MI, raw_ostream &O) {
 
 void PTXAsmPrinter::
 printCall(const MachineInstr *MI, raw_ostream &O) {
-
   O << "\tcall.uni\t";
+  // The first two operands are the predicate slot
+  unsigned Index = 2;
+  while (!MI->getOperand(Index).isGlobal()) {
+    if (Index == 2) {
+      O << "(";
+    } else {
+      O << ", ";
+    }
+    printParamOperand(MI, Index, O);
+    Index++;
+  }
 
-  const GlobalValue *Address = MI->getOperand(2).getGlobal();
-  O << Address->getName() << ", (";
+  if (Index != 2) {
+    O << "), ";
+  }
 
-  // (0,1) : predicate register/flag
-  // (2)   : callee
-  for (unsigned i = 3; i < MI->getNumOperands(); ++i) {
-    //const MachineOperand& MO = MI->getOperand(i);
+  assert(MI->getOperand(Index).isGlobal() &&
+         "A GlobalAddress must follow the return arguments");
+
+  const GlobalValue *Address = MI->getOperand(Index).getGlobal();
+  O << Address->getName() << ", (";
+  Index++;
 
-    printParamOperand(MI, i, O);
-    if (i < MI->getNumOperands()-1) {
+  while (Index < MI->getNumOperands()) {
+    printParamOperand(MI, Index, O);
+    if (Index < MI->getNumOperands()-1) {
       O << ", ";
     }
+    Index++;
   }
 
   O << ")";
index 3fdfcdf57498e3da133624b7b634ff7f10e282ad..053e140efe8e0dfbadce8335f53b035533beafb3 100644 (file)
@@ -16,6 +16,7 @@
 #include "PTXMachineFunctionInfo.h"
 #include "PTXRegisterInfo.h"
 #include "PTXSubtarget.h"
+#include "llvm/Function.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFunction.h"
@@ -440,15 +441,22 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
   assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() &&
          "Calls are not handled for the target device");
 
-  // Is there a more "LLVM"-way to create a variable-length array of values?
-  SDValue* ops = new SDValue[OutVals.size() + 2];
+  std::vector<SDValue> Ops;
+  // The layout of the ops will be [Chain, Ins, Callee, Outs]
+  Ops.resize(Outs.size() + Ins.size() + 2);
 
-  ops[0] = Chain;
+  Ops[0] = Chain;
 
   if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Callee)) {
     const GlobalValue *GV = G->getGlobal();
-    Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
-    ops[1] = Callee;
+    if (const Function *F = dyn_cast<Function>(GV)) {
+      assert(F->getCallingConv() == CallingConv::PTX_Device &&
+             "PTX function calls must be to PTX device functions");
+      Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
+      Ops[Ins.size()+1] = Callee;
+    } else {
+      assert(false && "GlobalValue is not a function");
+    }
   } else {
     assert(false && "Function must be a GlobalAddressSDNode");
   }
@@ -459,14 +467,28 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
     SDValue Index = DAG.getTargetConstant(Param, MVT::i32);
     Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
                         Index, OutVals[i]);
-    ops[i+2] = Index;
+    Ops[i+Ins.size()+2] = Index;
   }
 
-  ops[0] = Chain;
+  std::vector<unsigned> InParams;
 
-  Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, ops, OutVals.size()+2);
+  for (unsigned i = 0; i < Ins.size(); ++i) {
+    unsigned Size = Ins[i].VT.getStoreSizeInBits();
+    unsigned Param = PM.addLocalParam(Size);
+    SDValue Index = DAG.getTargetConstant(Param, MVT::i32);
+    Ops[i+1] = Index;
+    InParams.push_back(Param);
+  }
 
-  delete [] ops;
+  Ops[0] = Chain;
+
+  Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, &Ops[0], Ops.size());
+
+  for (unsigned i = 0; i < Ins.size(); ++i) {
+    SDValue Index = DAG.getTargetConstant(InParams[i], MVT::i32);
+    SDValue Load = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain, Index);
+    InVals.push_back(Load);
+  }
 
   return Chain;
 }
index 1e980655d3e629eb2517ccda97a9e1f5658bc378..77ea29eae8bd3da9bef155910807887c5841dd9a 100644 (file)
@@ -12,3 +12,16 @@ define ptx_device float @test_call(float %x, float %y) {
   call void @test_add(float %a, float %y)
   ret float %a
 }
+
+define ptx_device float @test_compute(float %x, float %y) {
+; CHECK: ret;
+  %z = fadd float %x, %y
+  ret float %z
+}
+
+define ptx_device float @test_call_compute(float %x, float %y) {
+; CHECK: call.uni (__localparam_{{[0-9]+}}), test_compute, (__localparam_{{[0-9]+}}, __localparam_{{[0-9]+}})
+  %z = call float @test_compute(float %x, float %y)
+  ret float %z
+}
+