1 //===-- PTXISelLowering.cpp - PTX DAG Lowering Implementation -------------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // This file implements the PTXTargetLowering class.
12 //===----------------------------------------------------------------------===//
15 #include "PTXISelLowering.h"
16 #include "PTXMachineFunctionInfo.h"
17 #include "PTXRegisterInfo.h"
18 #include "PTXSubtarget.h"
19 #include "llvm/Function.h"
20 #include "llvm/Support/ErrorHandling.h"
21 #include "llvm/CodeGen/CallingConvLower.h"
22 #include "llvm/CodeGen/MachineFunction.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/SelectionDAG.h"
25 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/raw_ostream.h"
31 //===----------------------------------------------------------------------===//
32 // TargetLowering Implementation
33 //===----------------------------------------------------------------------===//
// Constructs the PTX lowering object and configures instruction selection:
// - registers one register class for each of i1, i16, i32, i64, f32 and f64;
// - declares which operations are expanded, which are legal, and which are
// custom-lowered through LowerOperation();
// - finally derives the remaining register properties from those classes.
// The TargetLowering base receives TM together with a newly allocated
// TargetLoweringObjectFileELF.
35 PTXTargetLowering::PTXTargetLowering(TargetMachine &TM)
36 : TargetLowering(TM, new TargetLoweringObjectFileELF()) {
37 // Set up the register classes.
38 addRegisterClass(MVT::i1, PTX::RegPredRegisterClass);
39 addRegisterClass(MVT::i16, PTX::RegI16RegisterClass);
40 addRegisterClass(MVT::i32, PTX::RegI32RegisterClass);
41 addRegisterClass(MVT::i64, PTX::RegI64RegisterClass);
42 addRegisterClass(MVT::f32, PTX::RegF32RegisterClass);
43 addRegisterClass(MVT::f64, PTX::RegF64RegisterClass);
// Booleans are represented as 0 or 1, for scalar and vector results alike.
45 setBooleanContents(ZeroOrOneBooleanContent);
46 setBooleanVectorContents(ZeroOrOneBooleanContent); // FIXME: Is this correct?
// NOTE(review): the unit of this argument (bytes vs. log2 of bytes) is not
// visible here - confirm against TargetLowering before changing it.
47 setMinFunctionAlignment(2);
49 ////////////////////////////////////
50 /////////// Expansion //////////////
51 ////////////////////////////////////
53 // (any/zero/sign) extload => load + (any/zero/sign) extend
55 setLoadExtAction(ISD::EXTLOAD, MVT::i16, Expand);
56 setLoadExtAction(ISD::ZEXTLOAD, MVT::i16, Expand);
57 setLoadExtAction(ISD::SEXTLOAD, MVT::i16, Expand);
59 // f32 extload => load + fextend
61 setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
63 // f64 truncstore => trunc + store
65 setTruncStoreAction(MVT::f64, MVT::f32, Expand);
67 // sign_extend_inreg => sign_extend
69 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
// BR_CC is expanded instead of being selected directly.
73 setOperationAction(ISD::BR_CC, MVT::Other, Expand);
// SELECT_CC is expanded for MVT::Other, f32 and f64.
77 setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
78 setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
79 setOperationAction(ISD::SELECT_CC, MVT::f64, Expand);
81 ////////////////////////////////////
82 //////////// Legal /////////////////
83 ////////////////////////////////////
// Floating-point constants are legal for both f32 and f64.
85 setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
86 setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
88 ////////////////////////////////////
89 //////////// Custom ////////////////
90 ////////////////////////////////////
92 // customise setcc to use bitwise logic if possible
94 setOperationAction(ISD::SETCC, MVT::i1, Custom);
96 // customize translation of memory addresses
98 setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
99 setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
101 // Compute derived properties from the register classes
102 computeRegisterProperties();
// Returns the value type produced by a SETCC whose operands have type VT.
// NOTE(review): presumably MVT::i1, since LowerSETCC asserts an i1 result and
// SETCC is registered as Custom for i1 only - confirm against the body.
105 EVT PTXTargetLowering::getSetCCResultType(EVT VT) const {
// Entry point for operations marked Custom in the constructor. Dispatches on
// the opcode of Op: SETCC nodes go to LowerSETCC() and GlobalAddress nodes go
// to LowerGlobalAddress(). An opcode without a handler ends in
// llvm_unreachable("Unimplemented operand").
109 SDValue PTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
110 switch (Op.getOpcode()) {
112 llvm_unreachable("Unimplemented operand");
114 return LowerSETCC(Op, DAG);
115 case ISD::GlobalAddress:
116 return LowerGlobalAddress(Op, DAG);
// Maps a PTX-specific DAG node opcode (PTXISD::*) to its printable name, which
// is the enumerator spelled out as a string literal. An opcode that has no
// entry ends in llvm_unreachable("Unknown opcode").
// Keep this table in sync with the PTXISD opcodes when adding new nodes.
120 const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
123 llvm_unreachable("Unknown opcode");
124 case PTXISD::COPY_ADDRESS:
125 return "PTXISD::COPY_ADDRESS";
126 case PTXISD::LOAD_PARAM:
127 return "PTXISD::LOAD_PARAM";
128 case PTXISD::STORE_PARAM:
129 return "PTXISD::STORE_PARAM";
130 case PTXISD::READ_PARAM:
131 return "PTXISD::READ_PARAM";
132 case PTXISD::WRITE_PARAM:
133 return "PTXISD::WRITE_PARAM";
135 return "PTXISD::EXIT";
137 return "PTXISD::RET";
139 return "PTXISD::CALL";
143 //===----------------------------------------------------------------------===//
144 // Custom Lower Operation
145 //===----------------------------------------------------------------------===//
// Custom lowering for SETCC nodes whose result type is i1 (asserted below).
// Operand 0 and operand 1 are the compared values and operand 2 carries the
// condition code. When operand 1 is the constant 0 or 1 and the condition is
// SETEQ or SETNE, the comparison is replaced by a bitwise AND of the two
// compared values; in every other case an equivalent SETCC node is rebuilt
// from the same three operands.
147 SDValue PTXTargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
148 assert(Op.getValueType() == MVT::i1 && "SetCC type must be 1-bit integer");
149 SDValue Op0 = Op.getOperand(0);
150 SDValue Op1 = Op.getOperand(1);
151 SDValue Op2 = Op.getOperand(2);
152 DebugLoc dl = Op.getDebugLoc();
// Op2 is kept as an SDValue for rebuilding the node; CC is its decoded form.
153 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
155 // Look for X == 0, X == 1, X != 0, or X != 1
156 // We can simplify these to bitwise logic
158 if (Op1.getOpcode() == ISD::Constant &&
159 (cast<ConstantSDNode>(Op1)->getZExtValue() == 1 ||
160 cast<ConstantSDNode>(Op1)->isNullValue()) &&
161 (CC == ISD::SETEQ || CC == ISD::SETNE)) {
// NOTE(review): the same AND(Op0, Op1) is emitted for all four patterns.
// With Op1 == 0 the AND is always 0, which does not match "X == 0" or
// "X != 0", and "X != 1" would need a negation of X. Only "X == 1" maps to
// X & 1 for a 1-bit X. The node is also typed i1 without checking that Op0
// is i1. Confirm the intended semantics before relying on this path.
163 return DAG.getNode(ISD::AND, dl, MVT::i1, Op0, Op1);
// Fallback: re-emit the comparison unchanged.
166 return DAG.getNode(ISD::SETCC, dl, MVT::i1, Op0, Op1, Op2);
// Custom lowering for GlobalAddress nodes (registered for i32 and i64 in the
// constructor). Builds a target global address for the referenced global in
// the target's pointer type and creates a PTXISD::COPY_ADDRESS node.
// NOTE(review): the operands of the COPY_ADDRESS node are not visible here;
// presumably it wraps the target global address - confirm.
169 SDValue PTXTargetLowering::
170 LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
171 EVT PtrVT = getPointerTy();
172 DebugLoc dl = Op.getDebugLoc();
// cast<> requires Op to be a GlobalAddressSDNode.
173 const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
175 assert(PtrVT.isSimple() && "Pointer must be to primitive type.");
177 SDValue targetGlobal = DAG.getTargetGlobalAddress(GV, dl, PtrVT);
178 SDValue movInstr = DAG.getNode(PTXISD::COPY_ADDRESS,
186 //===----------------------------------------------------------------------===//
187 // Calling Convention Implementation
188 //===----------------------------------------------------------------------===//
// Lowers the incoming (formal) arguments of the current function.
// Chain - incoming chain; used as an operand of every parameter read.
// CallConv - PTX_Kernel or PTX_Device; any other convention is rejected.
// Ins - descriptions of the incoming arguments (their VT is used here).
// InVals - output; receives one value per entry of Ins, in order.
// Varargs are rejected. The function records on PTXMachineFunctionInfo
// whether it is a kernel and then uses one of two strategies:
// - kernels, and device functions on subtargets that use the param space for
// device arguments: each argument is registered with the param manager and
// read through a PTXISD::LOAD_PARAM node;
// - otherwise: a virtual register of the matching register class is created
// for each argument and the value is produced by a PTXISD::READ_PARAM node.
190 SDValue PTXTargetLowering::
191 LowerFormalArguments(SDValue Chain,
192 CallingConv::ID CallConv,
194 const SmallVectorImpl<ISD::InputArg> &Ins,
197 SmallVectorImpl<SDValue> &InVals) const {
198 if (isVarArg) llvm_unreachable("PTX does not support varargs");
200 MachineFunction &MF = DAG.getMachineFunction();
201 const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
202 PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
203 PTXParamManager &PM = MFI->getParamManager();
207 llvm_unreachable("Unsupported calling convention");
// Remember the function kind; it selects the argument strategy below.
209 case CallingConv::PTX_Kernel:
210 MFI->setKernel(true);
212 case CallingConv::PTX_Device:
213 MFI->setKernel(false);
217 // We do one of two things here:
218 // IsKernel || SM >= 2.0 -> Use param space for arguments
219 // SM < 2.0 -> Use registers for arguments
220 if (MFI->isKernel() || ST.useParamSpaceForDeviceArgs()) {
221 // We just need to emit the proper LOAD_PARAM ISDs
222 for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
223 assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) &&
224 "Kernels cannot take pred operands");
// The parameter is sized by the argument's store size in bits and is then
// referred to by the name the param manager generated for it.
226 unsigned ParamSize = Ins[i].VT.getStoreSizeInBits();
227 unsigned Param = PM.addArgumentParam(ParamSize);
228 const std::string &ParamName = PM.getParamName(Param);
229 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
231 SDValue ArgValue = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain,
233 InVals.push_back(ArgValue);
// Register-based path: one virtual register per argument.
237 for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
238 EVT RegVT = Ins[i].VT;
239 TargetRegisterClass* TRC = 0;
241 // Determine which register class we need
242 if (RegVT == MVT::i1)
243 TRC = PTX::RegPredRegisterClass;
244 else if (RegVT == MVT::i16)
245 TRC = PTX::RegI16RegisterClass;
246 else if (RegVT == MVT::i32)
247 TRC = PTX::RegI32RegisterClass;
248 else if (RegVT == MVT::i64)
249 TRC = PTX::RegI64RegisterClass;
250 else if (RegVT == MVT::f32)
251 TRC = PTX::RegF32RegisterClass;
252 else if (RegVT == MVT::f64)
253 TRC = PTX::RegF64RegisterClass;
255 llvm_unreachable("Unknown parameter type");
257 // Use a unique index in the instruction to prevent instruction folding.
258 // Yes, this is a hack.
259 SDValue Index = DAG.getTargetConstant(i, MVT::i32);
260 unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
261 SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, RegVT, Chain,
264 InVals.push_back(ArgValue);
// Lowers a function return.
// Chain - incoming chain; threaded through every node created here.
// CallConv - PTX_Kernel or PTX_Device; any other convention is rejected.
// Outs - descriptions of the returned values (their VT is used here).
// OutVals - the values being returned, parallel to Outs.
// DAG - the DAG in which the new nodes are created.
// Varargs are rejected. A kernel must return void and lowers directly to
// PTXISD::EXIT. A device function may return at most one value:
// - on subtargets that use the param space for device arguments, the value is
// written to a return parameter with PTXISD::STORE_PARAM;
// - otherwise it is copied into a virtual register of the matching register
// class and published with PTXISD::WRITE_PARAM.
// The result is a PTXISD::RET node, with Flag appended when Flag is set.
273 SDValue PTXTargetLowering::
274 LowerReturn(SDValue Chain,
275 CallingConv::ID CallConv,
277 const SmallVectorImpl<ISD::OutputArg> &Outs,
278 const SmallVectorImpl<SDValue> &OutVals,
280 SelectionDAG &DAG) const {
281 if (isVarArg) llvm_unreachable("PTX does not support varargs");
285 llvm_unreachable("Unsupported calling convention.");
286 case CallingConv::PTX_Kernel:
287 assert(Outs.size() == 0 && "Kernel must return void.");
288 return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain);
289 case CallingConv::PTX_Device:
290 assert(Outs.size() <= 1 && "Can at most return one value.");
294 MachineFunction& MF = DAG.getMachineFunction();
295 PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
296 PTXParamManager &PM = MFI->getParamManager();
299 const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
301 if (ST.useParamSpaceForDeviceArgs()) {
302 assert(Outs.size() < 2 && "Device functions can return at most one value");
304 if (Outs.size() == 1) {
// NOTE(review): the return parameter is sized with getSizeInBits(), while
// LowerFormalArguments sizes its parameters with getStoreSizeInBits(). The
// two can differ for types that are not a whole number of bytes (e.g. i1) -
// confirm which one is intended.
305 unsigned ParamSize = OutVals[0].getValueType().getSizeInBits();
306 unsigned Param = PM.addReturnParam(ParamSize);
307 const std::string &ParamName = PM.getParamName(Param);
308 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
310 Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
311 ParamValue, OutVals[0]);
// Register-based path: one virtual register per returned value.
314 for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
315 EVT RegVT = Outs[i].VT;
316 TargetRegisterClass* TRC = 0;
318 // Determine which register class we need
319 if (RegVT == MVT::i1) {
320 TRC = PTX::RegPredRegisterClass;
322 else if (RegVT == MVT::i16) {
323 TRC = PTX::RegI16RegisterClass;
325 else if (RegVT == MVT::i32) {
326 TRC = PTX::RegI32RegisterClass;
328 else if (RegVT == MVT::i64) {
329 TRC = PTX::RegI64RegisterClass;
331 else if (RegVT == MVT::f32) {
332 TRC = PTX::RegF32RegisterClass;
334 else if (RegVT == MVT::f64) {
335 TRC = PTX::RegF64RegisterClass;
338 llvm_unreachable("Unknown parameter type");
341 unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
// Copy the value into the new register; the copy's result is the chain
// operand of WRITE_PARAM, so the write is ordered after the copy.
343 SDValue Copy = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i]/*, Flag*/);
344 SDValue OutReg = DAG.getRegister(Reg, RegVT);
346 Chain = DAG.getNode(PTXISD::WRITE_PARAM, dl, MVT::Other, Copy, OutReg);
// Pass Flag to RET only when it actually refers to a node.
352 if (Flag.getNode() == 0) {
353 return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain);
356 return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag);
// Lowers a call to a PTX device function.
// Chain - incoming chain; updated by every STORE_PARAM and by the CALL.
// Callee - the called function; must be a global address (see cast below).
// Outs - descriptions of the outgoing arguments.
// OutVals - the outgoing argument values, parallel to Outs.
// Ins - descriptions of the values returned by the callee.
// InVals - output; receives one LOAD_PARAM result per entry of Ins.
// Each outgoing argument is stored into a fresh local parameter with
// PTXISD::STORE_PARAM, and one further local parameter is created per return
// value. A single PTXISD::CALL node is built from the operand vector described
// below, and each return value is then read back with PTXISD::LOAD_PARAM.
361 PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
362 CallingConv::ID CallConv, bool isVarArg,
364 const SmallVectorImpl<ISD::OutputArg> &Outs,
365 const SmallVectorImpl<SDValue> &OutVals,
366 const SmallVectorImpl<ISD::InputArg> &Ins,
367 DebugLoc dl, SelectionDAG &DAG,
368 SmallVectorImpl<SDValue> &InVals) const {
370 MachineFunction& MF = DAG.getMachineFunction();
371 PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
372 PTXParamManager &PM = MFI->getParamManager();
374 assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() &&
375 "Calls are not handled for the target device");
377 std::vector<SDValue> Ops;
378 // The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs]
// Four fixed slots (Chain, #Ins, Callee, #Outs) plus one slot per In and Out.
379 Ops.resize(Outs.size() + Ins.size() + 4);
383 // Identify the callee function
// NOTE(review): cast<> requires Callee to be a GlobalAddressSDNode, so other
// callee kinds (e.g. external symbols or indirect calls) are not handled here.
384 const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
385 assert(cast<Function>(GV)->getCallingConv() == CallingConv::PTX_Device &&
386 "PTX function calls must be to PTX device functions");
387 Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
388 Ops[Ins.size()+2] = Callee;
390 // Generate STORE_PARAM nodes for each function argument. In PTX, function
391 // arguments are explicitly stored into .param variables and passed as
392 // arguments. There is no register/stack-based calling convention in PTX.
393 Ops[Ins.size()+3] = DAG.getTargetConstant(OutVals.size(), MVT::i32);
// NOTE(review): the loop bound is OutVals.size() while Ops was sized with
// Outs.size(); this assumes both have the same length - confirm.
394 for (unsigned i = 0; i != OutVals.size(); ++i) {
// NOTE(review): argument params use getSizeInBits() here, return params below
// use getStoreSizeInBits(); these can differ for sub-byte types - confirm.
395 unsigned Size = OutVals[i].getValueType().getSizeInBits();
396 unsigned Param = PM.addLocalParam(Size);
397 const std::string &ParamName = PM.getParamName(Param);
398 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
400 Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
401 ParamValue, OutVals[i]);
402 Ops[i+Ins.size()+4] = ParamValue;
// Return-value parameter symbols, collected alongside their Ops slots.
405 std::vector<SDValue> InParams;
407 // Generate list of .param variables to hold the return value(s).
408 Ops[1] = DAG.getTargetConstant(Ins.size(), MVT::i32);
409 for (unsigned i = 0; i < Ins.size(); ++i) {
410 unsigned Size = Ins[i].VT.getStoreSizeInBits();
411 unsigned Param = PM.addLocalParam(Size);
412 const std::string &ParamName = PM.getParamName(Param);
413 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
415 Ops[i+2] = ParamValue;
416 InParams.push_back(ParamValue);
421 // Create the CALL node.
422 Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, &Ops[0], Ops.size());
424 // Create the LOAD_PARAM nodes that retrieve the function return value(s).
425 for (unsigned i = 0; i < Ins.size(); ++i) {
426 SDValue Load = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain,
428 InVals.push_back(Load);