PTX: Generalize handling of .param types
authorJustin Holewinski <justin.holewinski@gmail.com>
Fri, 23 Sep 2011 14:18:22 +0000 (14:18 +0000)
committerJustin Holewinski <justin.holewinski@gmail.com>
Fri, 23 Sep 2011 14:18:22 +0000 (14:18 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@140375 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/PTX/CMakeLists.txt
lib/Target/PTX/PTXAsmPrinter.cpp
lib/Target/PTX/PTXISelLowering.cpp
lib/Target/PTX/PTXInstrInfo.td
lib/Target/PTX/PTXMachineFunctionInfo.h
lib/Target/PTX/PTXParamManager.cpp [new file with mode: 0644]
lib/Target/PTX/PTXParamManager.h [new file with mode: 0644]

index f6e5c2295f0dc391002e8326a58622eeec15ca0d..abf6dcd0311cb78cf8be452970a32fe5b2d43db6 100644 (file)
@@ -15,6 +15,7 @@ add_llvm_target(PTXCodeGen
   PTXFrameLowering.cpp
   PTXMCAsmStreamer.cpp
   PTXMFInfoExtract.cpp
+  PTXParamManager.cpp
   PTXRegAlloc.cpp
   PTXRegisterInfo.cpp
   PTXSubtarget.cpp
index 6337ee99705b620f0609cc59943103de02edca65..06cab0bc79132d01c56610778eec821072e488ed 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "PTX.h"
 #include "PTXMachineFunctionInfo.h"
+#include "PTXParamManager.h"
 #include "PTXRegisterInfo.h"
 #include "PTXTargetMachine.h"
 #include "llvm/DerivedTypes.h"
@@ -435,7 +436,9 @@ void PTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
 
 void PTXAsmPrinter::printParamOperand(const MachineInstr *MI, int opNum,
                                       raw_ostream &OS, const char *Modifier) {
-  OS << PARAM_PREFIX << (int) MI->getOperand(opNum).getImm() + 1;
+  const PTXMachineFunctionInfo *MFI = MI->getParent()->getParent()->
+                                      getInfo<PTXMachineFunctionInfo>();
+  OS << MFI->getParamManager().getParamName(MI->getOperand(opNum).getImm());
 }
 
 void PTXAsmPrinter::printReturnOperand(const MachineInstr *MI, int opNum,
@@ -562,6 +565,7 @@ void PTXAsmPrinter::EmitFunctionDeclaration() {
   }
 
   const PTXMachineFunctionInfo *MFI = MF->getInfo<PTXMachineFunctionInfo>();
+  const PTXParamManager &PM = MFI->getParamManager();
   const bool isKernel = MFI->isKernel();
   const PTXSubtarget& ST = TM.getSubtarget<PTXSubtarget>();
   const MachineRegisterInfo& MRI = MF->getRegInfo();
@@ -572,10 +576,18 @@ void PTXAsmPrinter::EmitFunctionDeclaration() {
 
   if (!isKernel) {
     decl += " (";
-    if (ST.useParamSpaceForDeviceArgs() && MFI->getRetParamSize() != 0) {
-      decl += ".param .b";
-      decl += utostr(MFI->getRetParamSize());
-      decl += " __ret";
+    if (ST.useParamSpaceForDeviceArgs()) {
+      for (PTXParamManager::param_iterator i = PM.ret_begin(), e = PM.ret_end(),
+           b = i; i != e; ++i) {
+        if (i != b) {
+          decl += ", ";
+        }
+
+        decl += ".param .b";
+        decl += utostr(PM.getParamSize(*i));
+        decl += " ";
+        decl += PM.getParamName(*i);
+      }
     } else {
       for (PTXMachineFunctionInfo::ret_iterator
            i = MFI->retRegBegin(), e = MFI->retRegEnd(), b = i;
@@ -602,18 +614,16 @@ void PTXAsmPrinter::EmitFunctionDeclaration() {
 
   // Print parameters
   if (isKernel || ST.useParamSpaceForDeviceArgs()) {
-    for (PTXMachineFunctionInfo::argparam_iterator
-         i = MFI->argParamBegin(), e = MFI->argParamEnd(), b = i;
-         i != e; ++i) {
+    for (PTXParamManager::param_iterator i = PM.arg_begin(), e = PM.arg_end(),
+         b = i; i != e; ++i) {
       if (i != b) {
         decl += ", ";
       }
 
       decl += ".param .b";
-      decl += utostr(*i);
+      decl += utostr(PM.getParamSize(*i));
       decl += " ";
-      decl += PARAM_PREFIX;
-      decl += utostr(++cnt);
+      decl += PM.getParamName(*i);
     }
   } else {
     for (PTXMachineFunctionInfo::reg_iterator
index 2d7756e214e14e6f2c334f953749aa1ca34a3465..79967280344d44ca62e7910c837b5a1ca4e59592 100644 (file)
@@ -199,6 +199,7 @@ SDValue PTXTargetLowering::
   MachineFunction &MF = DAG.getMachineFunction();
   const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
+  PTXParamManager &PM = MFI->getParamManager();
 
   switch (CallConv) {
     default:
@@ -221,8 +222,10 @@ SDValue PTXTargetLowering::
       assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) &&
              "Kernels cannot take pred operands");
 
+      unsigned ParamSize = Ins[i].VT.getStoreSizeInBits();
+      unsigned Param = PM.addArgumentParam(ParamSize);
       SDValue ArgValue = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain,
-                                     DAG.getTargetConstant(i, MVT::i32));
+                                     DAG.getTargetConstant(Param, MVT::i32));
       InVals.push_back(ArgValue);
 
       // Instead of storing a physical register in our argument list, we just
@@ -322,6 +325,7 @@ SDValue PTXTargetLowering::
 
   MachineFunction& MF = DAG.getMachineFunction();
   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
+  PTXParamManager &PM = MFI->getParamManager();
 
   SDValue Flag;
 
@@ -336,13 +340,15 @@ SDValue PTXTargetLowering::
     assert(Outs.size() < 2 && "Device functions can return at most one value");
 
     if (Outs.size() == 1) {
-      unsigned Size = OutVals[0].getValueType().getSizeInBits();
-      SDValue Index = DAG.getTargetConstant(MFI->getNextParam(Size), MVT::i32);
+      unsigned ParamSize = OutVals[0].getValueType().getSizeInBits();
+      unsigned Param = PM.addReturnParam(ParamSize);
+      SDValue ParamIndex = DAG.getTargetConstant(Param, MVT::i32);
       Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
-                          Index, OutVals[0]);
+                          ParamIndex, OutVals[0]);
+
 
       //Flag = Chain.getValue(1);
-      MFI->setRetParamSize(Outs[0].VT.getStoreSizeInBits());
+      //MFI->setRetParamSize(Outs[0].VT.getStoreSizeInBits());
     }
   } else {
     //SmallVector<CCValAssign, 16> RVLocs;
index 088142b2724793795922e43d891518a1e200a257..0c9b85626ecdbd69c18399098d1798d9ecdc17b2 100644 (file)
@@ -873,22 +873,22 @@ let hasSideEffects = 1 in {
                          "ld.param.f64\t$d, [$a]",
                          [(set RegF64:$d, (PTXloadparam timm:$a))]>;
 
-  def STpiPred : InstPTX<(outs), (ins MEMret:$d, RegPred:$a),
+  def STpiPred : InstPTX<(outs), (ins MEMpi:$d, RegPred:$a),
                          "st.param.pred\t[$d], $a",
                          [(PTXstoreparam timm:$d, RegPred:$a)]>;
-  def STpiU16  : InstPTX<(outs), (ins MEMret:$d, RegI16:$a),
+  def STpiU16  : InstPTX<(outs), (ins MEMpi:$d, RegI16:$a),
                          "st.param.u16\t[$d], $a",
                          [(PTXstoreparam timm:$d, RegI16:$a)]>;
-  def STpiU32  : InstPTX<(outs), (ins MEMret:$d, RegI32:$a),
+  def STpiU32  : InstPTX<(outs), (ins MEMpi:$d, RegI32:$a),
                          "st.param.u32\t[$d], $a",
                          [(PTXstoreparam timm:$d, RegI32:$a)]>;
-  def STpiU64  : InstPTX<(outs), (ins MEMret:$d, RegI64:$a),
+  def STpiU64  : InstPTX<(outs), (ins MEMpi:$d, RegI64:$a),
                          "st.param.u64\t[$d], $a",
                          [(PTXstoreparam timm:$d, RegI64:$a)]>;
-  def STpiF32  : InstPTX<(outs), (ins MEMret:$d, RegF32:$a),
+  def STpiF32  : InstPTX<(outs), (ins MEMpi:$d, RegF32:$a),
                          "st.param.f32\t[$d], $a",
                          [(PTXstoreparam timm:$d, RegF32:$a)]>;
-  def STpiF64  : InstPTX<(outs), (ins MEMret:$d, RegF64:$a),
+  def STpiF64  : InstPTX<(outs), (ins MEMpi:$d, RegF64:$a),
                          "st.param.f64\t[$d], $a",
                          [(PTXstoreparam timm:$d, RegF64:$a)]>;
 }
index 93189bbf62c9a7853358910f8bd328dda28553be..90795ea99a85c1eb3a8f90d4732cf1343b5e70b1 100644 (file)
@@ -15,6 +15,7 @@
 #define PTX_MACHINE_FUNCTION_INFO_H
 
 #include "PTX.h"
+#include "PTXParamManager.h"
 #include "PTXRegisterInfo.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
@@ -48,6 +49,8 @@ private:
 
   unsigned retParamSize;
 
+  PTXParamManager ParamManager;
+
 public:
   PTXMachineFunctionInfo(MachineFunction &MF)
     : is_kernel(false), reg_ret(PTX::NoRegister), _isDoneAddArg(false) {
@@ -61,6 +64,9 @@ public:
       retParamSize = 0;
     }
 
+  PTXParamManager& getParamManager() { return ParamManager; }
+  const PTXParamManager& getParamManager() const { return ParamManager; }
+
   void setKernel(bool _is_kernel=true) { is_kernel = _is_kernel; }
 
 
diff --git a/lib/Target/PTX/PTXParamManager.cpp b/lib/Target/PTX/PTXParamManager.cpp
new file mode 100644 (file)
index 0000000..f4945d9
--- /dev/null
@@ -0,0 +1,73 @@
+//===- PTXParamManager.cpp - Manager for .param variables -------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the PTXParamManager class.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PTX.h"
+#include "PTXParamManager.h"
+#include "llvm/ADT/StringExtras.h"
+
+using namespace llvm;
+
+PTXParamManager::PTXParamManager() {
+}
+
+unsigned PTXParamManager::addArgumentParam(unsigned Size) {
+  PTXParam Param;
+  Param.Type = PTX_PARAM_TYPE_ARGUMENT;
+  Param.Size = Size;
+
+  std::string Name;
+  Name = "__param_";
+  Name += utostr(ArgumentParams.size()+1);
+  Param.Name = Name;
+
+  unsigned Index = AllParams.size();
+  AllParams[Index] = Param;
+  ArgumentParams.insert(Index);
+
+  return Index;
+}
+
+unsigned PTXParamManager::addReturnParam(unsigned Size) {
+  PTXParam Param;
+  Param.Type = PTX_PARAM_TYPE_RETURN;
+  Param.Size = Size;
+
+  std::string Name;
+  Name = "__ret_";
+  Name += utostr(ReturnParams.size()+1);
+  Param.Name = Name;
+
+  unsigned Index = AllParams.size();
+  AllParams[Index] = Param;
+  ReturnParams.insert(Index);
+
+  return Index;
+}
+
+unsigned PTXParamManager::addLocalParam(unsigned Size) {
+  PTXParam Param;
+  Param.Type = PTX_PARAM_TYPE_LOCAL;
+  Param.Size = Size;
+
+  std::string Name;
+  Name = "__localparam_";
+  Name += utostr(LocalParams.size()+1);
+  Param.Name = Name;
+
+  unsigned Index = AllParams.size();
+  AllParams[Index] = Param;
+  LocalParams.insert(Index);
+
+  return Index;
+}
+
diff --git a/lib/Target/PTX/PTXParamManager.h b/lib/Target/PTX/PTXParamManager.h
new file mode 100644 (file)
index 0000000..05b0d31
--- /dev/null
@@ -0,0 +1,86 @@
+//===- PTXParamManager.h - Manager for .param variables ----------*- C++ -*-==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the PTXParamManager class, which manages all defined .param
+// variables for a particular function.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PTX_PARAM_MANAGER_H
+#define PTX_PARAM_MANAGER_H
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+
+namespace llvm {
+
+/// PTXParamManager - This class manages all .param variables defined for a
+/// particular function.
+class PTXParamManager {
+private:
+
+  /// PTXParamType - Type of a .param variable
+  enum PTXParamType {
+    PTX_PARAM_TYPE_ARGUMENT,
+    PTX_PARAM_TYPE_RETURN,
+    PTX_PARAM_TYPE_LOCAL
+  };
+
+  /// PTXParam - Definition of a PTX .param variable
+  struct PTXParam {
+    PTXParamType  Type;
+    unsigned      Size;
+    std::string   Name;
+  };
+
+  DenseMap<unsigned, PTXParam> AllParams;
+  DenseSet<unsigned> ArgumentParams;
+  DenseSet<unsigned> ReturnParams;
+  DenseSet<unsigned> LocalParams;
+
+public:
+
+  typedef DenseSet<unsigned>::const_iterator param_iterator;
+
+  PTXParamManager();
+
+  param_iterator arg_begin() const { return ArgumentParams.begin(); }
+  param_iterator arg_end() const { return ArgumentParams.end(); }
+  param_iterator ret_begin() const { return ReturnParams.begin(); }
+  param_iterator ret_end() const { return ReturnParams.end(); }
+  param_iterator local_begin() const { return LocalParams.begin(); }
+  param_iterator local_end() const { return LocalParams.end(); }
+
+  /// addArgumentParam - Returns a new .param used as an argument.
+  unsigned addArgumentParam(unsigned Size);
+
+  /// addReturnParam - Returns a new .param used as a return argument.
+  unsigned addReturnParam(unsigned Size);
+
+  /// addLocalParam - Returns a new .param used as a local .param variable.
+  unsigned addLocalParam(unsigned Size);
+
+  /// getParamName - Returns the name of the parameter as a string.
+  std::string getParamName(unsigned Param) const {
+    assert(AllParams.count(Param) == 1 && "Param has not been defined!");
+    return AllParams.lookup(Param).Name;
+  }
+
+  /// getParamSize - Returns the size of the parameter in bits.
+  unsigned getParamSize(unsigned Param) const {
+    assert(AllParams.count(Param) == 1 && "Param has not been defined!");
+    return AllParams.lookup(Param).Size;
+  }
+
+};
+
+}
+
+#endif
+