PTX: Add initial support for device function calls
authorJustin Holewinski <justin.holewinski@gmail.com>
Tue, 9 Aug 2011 17:36:31 +0000 (17:36 +0000)
committerJustin Holewinski <justin.holewinski@gmail.com>
Tue, 9 Aug 2011 17:36:31 +0000 (17:36 +0000)
- Calls are supported on SM 2.0+ for function with no return values

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@137125 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/PTX/PTXAsmPrinter.cpp
lib/Target/PTX/PTXISelLowering.cpp
lib/Target/PTX/PTXISelLowering.h
lib/Target/PTX/PTXInstrInfo.td
lib/Target/PTX/PTXMachineFunctionInfo.h
lib/Target/PTX/PTXSubtarget.h
test/CodeGen/PTX/simple-call.ll [new file with mode: 0644]

index bb48e0ab4ba228f0eb88c261762cec58866311f8..fc0ec701990c2c4cfa05978cf9e2835ae5aa6f8c 100644 (file)
@@ -70,6 +70,8 @@ public:
                           const char *Modifier = 0); 
   void printPredicateOperand(const MachineInstr *MI, raw_ostream &O);
 
+  void printCall(const MachineInstr *MI, raw_ostream &O);
+
   unsigned GetOrCreateSourceID(StringRef FileName,
                                StringRef DirName);
 
@@ -242,6 +244,19 @@ void PTXAsmPrinter::EmitFunctionBodyStart() {
       OutStreamer.EmitRawText(Twine(def));
     }
   }
+
+  unsigned Index = 1;
+  // Print parameter passing params
+  for (PTXMachineFunctionInfo::param_iterator
+       i = MFI->paramBegin(), e = MFI->paramEnd(); i != e; ++i) {
+    std::string def = "\t.param .b";
+    def += utostr(*i);
+    def += " __ret_";
+    def += utostr(Index);
+    Index++;
+    def += ";";
+    OutStreamer.EmitRawText(Twine(def));
+  }
 }
 
 void PTXAsmPrinter::EmitInstruction(const MachineInstr *MI) {
@@ -302,7 +317,11 @@ void PTXAsmPrinter::EmitInstruction(const MachineInstr *MI) {
   printPredicateOperand(MI, OS);
 
   // Write instruction to str
-  printInstruction(MI, OS);
+  if (MI->getOpcode() == PTX::CALL) {
+    printCall(MI, OS);
+  } else {
+    printInstruction(MI, OS);
+  }
   OS << ';';
   OS.flush();
 
@@ -569,6 +588,28 @@ printPredicateOperand(const MachineInstr *MI, raw_ostream &O) {
   }
 }
 
+void PTXAsmPrinter::
+printCall(const MachineInstr *MI, raw_ostream &O) {
+
+  O << "\tcall.uni\t";
+
+  const GlobalValue *Address = MI->getOperand(2).getGlobal();
+  O << Address->getName() << ", (";
+
+  // (0,1) : predicate register/flag
+  // (2)   : callee
+  for (unsigned i = 3; i < MI->getNumOperands(); ++i) {
+    //const MachineOperand& MO = MI->getOperand(i);
+
+    printReturnOperand(MI, i, O);
+    if (i < MI->getNumOperands()-1) {
+      O << ", ";
+    }
+  }
+
+  O << ")";
+}
+
 unsigned PTXAsmPrinter::GetOrCreateSourceID(StringRef FileName,
                                             StringRef DirName) {
   // If FE did not provide a file name, then assume stdin.
index 6fcf710e3f1fe72e8e1b9697879d38ee0d06aa8d..d52aa2a01ad20f2e942a1911f92c0c1d79ab1202 100644 (file)
@@ -22,6 +22,7 @@
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace llvm;
@@ -134,6 +135,8 @@ const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
       return "PTXISD::EXIT";
     case PTXISD::RET:
       return "PTXISD::RET";
+    case PTXISD::CALL:
+      return "PTXISD::CALL";
   }
 }
 
@@ -345,3 +348,49 @@ SDValue PTXTargetLowering::
     return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag);
   }
 }
+
+SDValue
+PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
+                             CallingConv::ID CallConv, bool isVarArg,
+                             bool &isTailCall,
+                             const SmallVectorImpl<ISD::OutputArg> &Outs,
+                             const SmallVectorImpl<SDValue> &OutVals,
+                             const SmallVectorImpl<ISD::InputArg> &Ins,
+                             DebugLoc dl, SelectionDAG &DAG,
+                             SmallVectorImpl<SDValue> &InVals) const {
+
+  MachineFunction& MF = DAG.getMachineFunction();
+  PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
+  const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
+
+  assert(ST.callsAreHandled() && "Calls are not handled for the target device");
+
+  // Is there a more "LLVM"-way to create a variable-length array of values?
+  SDValue* ops = new SDValue[OutVals.size() + 2];
+
+  ops[0] = Chain;
+
+  if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Callee)) {
+    const GlobalValue *GV = G->getGlobal();
+    Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
+    ops[1] = Callee;
+  } else {
+    assert(false && "Function must be a GlobalAddressSDNode");
+  }
+
+  for (unsigned i = 0; i != OutVals.size(); ++i) {
+    unsigned Size = OutVals[i].getValueType().getSizeInBits();
+    SDValue Index = DAG.getTargetConstant(MFI->getNextParam(Size), MVT::i32);
+    Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
+                        Index, OutVals[i]);
+    ops[i+2] = Index;
+  }
+
+  ops[0] = Chain;
+
+  Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, ops, OutVals.size()+2);
+
+  delete [] ops;
+
+  return Chain;
+}
index 43185416e1fc7fca8e4a5ed5c85e986b72c4f58d..f99ac7bc789eb701ba54dc6db9504f7cdebff2ff 100644 (file)
@@ -28,7 +28,8 @@ namespace PTXISD {
     STORE_PARAM,
     EXIT,
     RET,
-    COPY_ADDRESS
+    COPY_ADDRESS,
+    CALL
   };
 }                               // namespace PTXISD
 
@@ -60,6 +61,16 @@ class PTXTargetLowering : public TargetLowering {
                   DebugLoc dl,
                   SelectionDAG &DAG) const;
 
+    virtual SDValue
+      LowerCall(SDValue Chain, SDValue Callee,
+                CallingConv::ID CallConv, bool isVarArg,
+                bool &isTailCall,
+                const SmallVectorImpl<ISD::OutputArg> &Outs,
+                const SmallVectorImpl<SDValue> &OutVals,
+                const SmallVectorImpl<ISD::InputArg> &Ins,
+                DebugLoc dl, SelectionDAG &DAG,
+                SmallVectorImpl<SDValue> &InVals) const;
+
     virtual MVT::SimpleValueType getSetCCResultType(EVT VT) const;
 
   private:
index 6bfe906d40abe5c45154ba7d09a65dc04d4969aa..11caa7f1f9d795db93fcd51a7f01e5cb0998d07d 100644 (file)
@@ -168,6 +168,18 @@ def MEMret : Operand<i32> {
   let MIOperandInfo = (ops i32imm);
 }
 
+// def SDT_PTXCallSeqStart : SDCallSeqStart<[SDTCisVT<0, i32>]>;
+// def SDT_PTXCallSeqEnd   : SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>;
+
+// def PTXcallseq_start : SDNode<"ISD::CALLSEQ_START", SDT_PTXCallSeqStart,
+//                               [SDNPHasChain, SDNPOutGlue]>;
+// def PTXcallseq_end   : SDNode<"ISD::CALLSEQ_END", SDT_PTXCallSeqEnd,
+//                               [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue]>;
+
+def PTXcall : SDNode<"PTXISD::CALL", SDTNone,
+                     [SDNPHasChain, SDNPVariadic, SDNPOptInGlue, SDNPOutGlue]>;
+
+
 // Branch & call targets have OtherVT type.
 def brtarget   : Operand<OtherVT>;
 def calltarget : Operand<i32>;
@@ -1073,6 +1085,11 @@ let isReturn = 1, isTerminator = 1, isBarrier = 1 in {
   def RET  : InstPTX<(outs), (ins), "ret",  [(PTXret)]>;
 }
 
+let hasSideEffects = 1 in {
+  def CALL : InstPTX<(outs), (ins), "call", [(PTXcall)]>;
+}
+
+
 ///===- Spill Instructions ------------------------------------------------===//
 // Special instructions used for stack spilling
 def STACKSTOREI16 : InstPTX<(outs), (ins i32imm:$d, RegI16:$a),
@@ -1097,6 +1114,15 @@ def STACKLOADF32 : InstPTX<(outs), (ins RegF32:$d, i32imm:$a),
 def STACKLOADF64 : InstPTX<(outs), (ins RegF64:$d, i32imm:$a),
                            "mov.f64\t$d, s$a", []>;
 
+
+// Call handling
+// def ADJCALLSTACKUP :
+//   InstPTX<(outs), (ins i32imm:$amt1, i32imm:$amt2), "",
+//           [(PTXcallseq_end timm:$amt1, timm:$amt2)]>;
+// def ADJCALLSTACKDOWN :
+//   InstPTX<(outs), (ins i32imm:$amt), "",
+//           [(PTXcallseq_start timm:$amt)]>;
+
 ///===- Intrinsic Instructions --------------------------------------------===//
 
 include "PTXIntrinsicInstrInfo.td"
index 9d65f5bd1adedead0b695325b8b83db0f8386dfb..a3b0f324feb8ca863190f963d661537df8ca1b23 100644 (file)
@@ -27,6 +27,7 @@ private:
   bool is_kernel;
   std::vector<unsigned> reg_arg, reg_local_var;
   std::vector<unsigned> reg_ret;
+  std::vector<unsigned> call_params;
   bool _isDoneAddArg;
 
 public:
@@ -56,6 +57,7 @@ public:
   typedef std::vector<unsigned>::const_iterator         reg_iterator;
   typedef std::vector<unsigned>::const_reverse_iterator reg_reverse_iterator;
   typedef std::vector<unsigned>::const_iterator         ret_iterator;
+  typedef std::vector<unsigned>::const_iterator         param_iterator;
 
   bool         argRegEmpty() const { return reg_arg.empty(); }
   int          getNumArg() const { return reg_arg.size(); }
@@ -73,6 +75,13 @@ public:
   ret_iterator retRegBegin() const { return reg_ret.begin(); }
   ret_iterator retRegEnd()   const { return reg_ret.end(); }
 
+  param_iterator paramBegin() const { return call_params.begin(); }
+  param_iterator paramEnd() const { return call_params.end(); }
+  unsigned       getNextParam(unsigned size) {
+    call_params.push_back(size);
+    return call_params.size()-1;
+  }
+
   bool isArgReg(unsigned reg) const {
     return std::find(reg_arg.begin(), reg_arg.end(), reg) != reg_arg.end();
   }
index 0921f1f22c49b472165aa7e55bb8a5fd6e4e8c42..0404200992000d83e6c1eeda014cf05b400b246f 100644 (file)
@@ -114,7 +114,12 @@ class StringRef;
                (PTXTarget >= PTX_COMPUTE_2_0 && PTXTarget < PTX_LAST_COMPUTE);
       }
 
-    void ParseSubtargetFeatures(StringRef CPU, StringRef FS);
+      bool callsAreHandled() const {
+        return (PTXTarget >= PTX_SM_2_0 && PTXTarget < PTX_LAST_SM) ||
+               (PTXTarget >= PTX_COMPUTE_2_0 && PTXTarget < PTX_LAST_COMPUTE);
+      }
+
+      void ParseSubtargetFeatures(StringRef CPU, StringRef FS);
   }; // class PTXSubtarget
 } // namespace llvm
 
diff --git a/test/CodeGen/PTX/simple-call.ll b/test/CodeGen/PTX/simple-call.ll
new file mode 100644 (file)
index 0000000..36f6d8c
--- /dev/null
@@ -0,0 +1,14 @@
+; RUN: llc < %s -march=ptx32 -mattr=sm20 | FileCheck %s
+
+define ptx_device void @test_add(float %x, float %y) {
+; CHECK: ret;
+       %z = fadd float %x, %y
+       ret void
+}
+
+define ptx_device float @test_call(float %x, float %y) {
+  %a = fadd float %x, %y
+; CHECK: call.uni test_add, (__ret_{{[0-9]+}}, __ret_{{[0-9]+}});
+  call void @test_add(float %a, float %y)
+  ret float %a
+}