2 // The LLVM Compiler Infrastructure
4 // This file is distributed under the University of Illinois Open Source
5 // License. See LICENSE.TXT for details.
7 //===----------------------------------------------------------------------===//
9 // This file defines the interfaces that NVPTX uses to lower LLVM code into a
12 //===----------------------------------------------------------------------===//
14 #include "NVPTXISelLowering.h"
16 #include "NVPTXTargetMachine.h"
17 #include "NVPTXTargetObjectFile.h"
18 #include "NVPTXUtilities.h"
19 #include "llvm/CodeGen/Analysis.h"
20 #include "llvm/CodeGen/MachineFrameInfo.h"
21 #include "llvm/CodeGen/MachineFunction.h"
22 #include "llvm/CodeGen/MachineInstrBuilder.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/GlobalValue.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/Intrinsics.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/MC/MCSectionELF.h"
32 #include "llvm/Support/CallSite.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/raw_ostream.h"
40 #define DEBUG_TYPE "nvptx-lower"
// Monotonically increasing id for lowered call sites. It names the
// "prototype_<n>" label built in getPrototype and brackets each call's
// CALLSEQ_START/CALLSEQ_END markers in LowerCall below.
44 static unsigned int uniqueCallSite = 0;
// Command-line flag selecting register-pressure-aware scheduling.
// NOTE(review): "pressue" in the cl::desc string is a typo for "pressure";
// it is user-visible -help text, so fixing it changes program output -- flag
// only, do not silently rewrite. The option-name argument line is not
// visible in this chunk.
46 static cl::opt<bool> sched4reg(
48 cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false));
// Returns true when VT is a vector type that PTX can load/store directly
// (used below to register custom LOAD/STORE handling for those types).
// NOTE(review): the switch body is not visible in this chunk, so the exact
// accepted set of MVTs cannot be confirmed from here.
50 static bool IsPTXVectorType(MVT VT) {
51 switch (VT.SimpleTy) {
68 // NVPTXTargetLowering Constructor.
// Registers the PTX register classes, declares which generic ISD operations
// must be expanded/promoted/custom-lowered for PTX, and configures the
// scheduling preference and memset/memcpy/memmove lowering policy.
69 NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
70 : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM),
71 nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {
73 // always lower memset, memcpy, and memmove intrinsics to load/store
74 // instructions, rather
75 // than generating calls to memset, memcpy or memmove.
// The UINT_MAX limits effectively disable the "emit a library call" fallback;
// there is no libc to call into from PTX.
76 MaxStoresPerMemset = (unsigned) 0xFFFFFFFF;
77 MaxStoresPerMemcpy = (unsigned) 0xFFFFFFFF;
78 MaxStoresPerMemmove = (unsigned) 0xFFFFFFFF;
80 setBooleanContents(ZeroOrNegativeOneBooleanContent);
82 // Jump is Expensive. Don't create extra control flow for 'and', 'or'
83 // condition branches.
84 setJumpIsExpensive(true);
86 // By default, use the Source scheduling
// NOTE(review): both preferences are set back to back here; the guarding
// conditional (presumably testing the sched4reg flag) is not visible in this
// chunk. As literally written, Sched::Source is the call that takes effect.
88 setSchedulingPreference(Sched::RegPressure);
90 setSchedulingPreference(Sched::Source);
// One virtual register class per PTX scalar type.
92 addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
93 addRegisterClass(MVT::i8, &NVPTX::Int8RegsRegClass);
94 addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
95 addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
96 addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
97 addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
98 addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
100 // Operations not directly supported by NVPTX.
101 setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
102 setOperationAction(ISD::BR_CC, MVT::f32, Expand);
103 setOperationAction(ISD::BR_CC, MVT::f64, Expand);
104 setOperationAction(ISD::BR_CC, MVT::i1, Expand);
105 setOperationAction(ISD::BR_CC, MVT::i8, Expand);
106 setOperationAction(ISD::BR_CC, MVT::i16, Expand);
107 setOperationAction(ISD::BR_CC, MVT::i32, Expand);
108 setOperationAction(ISD::BR_CC, MVT::i64, Expand);
109 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Expand);
110 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Expand);
111 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Expand);
112 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8, Expand);
113 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
// Rotates are only legal when the subtarget provides them.
// NOTE(review): the `} else {` lines between the Legal and Expand calls are
// not visible in this chunk; as shown the calls appear sequential.
115 if (nvptxSubtarget.hasROT64()) {
116 setOperationAction(ISD::ROTL, MVT::i64, Legal);
117 setOperationAction(ISD::ROTR, MVT::i64, Legal);
119 setOperationAction(ISD::ROTL, MVT::i64, Expand);
120 setOperationAction(ISD::ROTR, MVT::i64, Expand);
122 if (nvptxSubtarget.hasROT32()) {
123 setOperationAction(ISD::ROTL, MVT::i32, Legal);
124 setOperationAction(ISD::ROTR, MVT::i32, Legal);
126 setOperationAction(ISD::ROTL, MVT::i32, Expand);
127 setOperationAction(ISD::ROTR, MVT::i32, Expand);
// Sub-32-bit rotates and all byte swaps are always expanded.
130 setOperationAction(ISD::ROTL, MVT::i16, Expand);
131 setOperationAction(ISD::ROTR, MVT::i16, Expand);
132 setOperationAction(ISD::ROTL, MVT::i8, Expand);
133 setOperationAction(ISD::ROTR, MVT::i8, Expand);
134 setOperationAction(ISD::BSWAP, MVT::i16, Expand);
135 setOperationAction(ISD::BSWAP, MVT::i32, Expand);
136 setOperationAction(ISD::BSWAP, MVT::i64, Expand);
138 // Indirect branch is not supported.
139 // This also disables Jump Table creation.
140 setOperationAction(ISD::BR_JT, MVT::Other, Expand);
141 setOperationAction(ISD::BRIND, MVT::Other, Expand);
// Global addresses get wrapped in NVPTXISD::Wrapper (see LowerGlobalAddress).
143 setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
144 setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
146 // We want to legalize constant related memmove and memcopy
148 setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
150 // Turn FP extload into load/fextend
151 setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
152 // Turn FP truncstore into trunc + store.
153 setTruncStoreAction(MVT::f64, MVT::f32, Expand);
155 // PTX does not support load / store predicate registers
// i1 loads/stores go through LowerLOADi1/LowerSTOREi1 (i8 memory accesses).
156 setOperationAction(ISD::LOAD, MVT::i1, Custom);
157 setOperationAction(ISD::STORE, MVT::i1, Custom);
159 setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
160 setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
161 setTruncStoreAction(MVT::i64, MVT::i1, Expand);
162 setTruncStoreAction(MVT::i32, MVT::i1, Expand);
163 setTruncStoreAction(MVT::i16, MVT::i1, Expand);
164 setTruncStoreAction(MVT::i8, MVT::i1, Expand);
166 // This is legal in NVPTX
167 setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
168 setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
170 // TRAP can be lowered to PTX trap
171 setOperationAction(ISD::TRAP, MVT::Other, Legal);
173 // Register custom handling for vector loads/stores
// Only types accepted by IsPTXVectorType get custom lowering; the rest go
// through generic legalization.
174 for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE;
176 MVT VT = (MVT::SimpleValueType) i;
177 if (IsPTXVectorType(VT)) {
178 setOperationAction(ISD::LOAD, VT, Custom);
179 setOperationAction(ISD::STORE, VT, Custom);
180 setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
184 // Now deduce the information based on the above mentioned
186 computeRegisterProperties();
// Returns the debug-printing name for each NVPTX-specific SelectionDAG node.
// NOTE(review): the `switch (Opcode)` header and its default case are not
// visible in this chunk.
189 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
194 return "NVPTXISD::CALL";
195 case NVPTXISD::RET_FLAG:
196 return "NVPTXISD::RET_FLAG";
197 case NVPTXISD::Wrapper:
198 return "NVPTXISD::Wrapper";
199 case NVPTXISD::NVBuiltin:
200 return "NVPTXISD::NVBuiltin";
201 case NVPTXISD::DeclareParam:
202 return "NVPTXISD::DeclareParam";
203 case NVPTXISD::DeclareScalarParam:
204 return "NVPTXISD::DeclareScalarParam";
205 case NVPTXISD::DeclareRet:
206 return "NVPTXISD::DeclareRet";
207 case NVPTXISD::DeclareRetParam:
208 return "NVPTXISD::DeclareRetParam";
209 case NVPTXISD::PrintCall:
210 return "NVPTXISD::PrintCall";
211 case NVPTXISD::LoadParam:
212 return "NVPTXISD::LoadParam";
213 case NVPTXISD::LoadParamV2:
214 return "NVPTXISD::LoadParamV2";
215 case NVPTXISD::LoadParamV4:
216 return "NVPTXISD::LoadParamV4";
217 case NVPTXISD::StoreParam:
218 return "NVPTXISD::StoreParam";
219 case NVPTXISD::StoreParamV2:
220 return "NVPTXISD::StoreParamV2";
221 case NVPTXISD::StoreParamV4:
222 return "NVPTXISD::StoreParamV4";
223 case NVPTXISD::StoreParamS32:
224 return "NVPTXISD::StoreParamS32";
225 case NVPTXISD::StoreParamU32:
226 return "NVPTXISD::StoreParamU32";
227 case NVPTXISD::MoveToParam:
228 return "NVPTXISD::MoveToParam";
229 case NVPTXISD::CallArgBegin:
230 return "NVPTXISD::CallArgBegin";
231 case NVPTXISD::CallArg:
232 return "NVPTXISD::CallArg";
233 case NVPTXISD::LastCallArg:
234 return "NVPTXISD::LastCallArg";
235 case NVPTXISD::CallArgEnd:
236 return "NVPTXISD::CallArgEnd";
237 case NVPTXISD::CallVoid:
238 return "NVPTXISD::CallVoid";
239 case NVPTXISD::CallVal:
240 return "NVPTXISD::CallVal";
241 case NVPTXISD::CallSymbol:
242 return "NVPTXISD::CallSymbol";
243 case NVPTXISD::Prototype:
244 return "NVPTXISD::Prototype";
245 case NVPTXISD::MoveParam:
246 return "NVPTXISD::MoveParam";
247 case NVPTXISD::MoveRetval:
248 return "NVPTXISD::MoveRetval";
249 case NVPTXISD::MoveToRetval:
250 return "NVPTXISD::MoveToRetval";
251 case NVPTXISD::StoreRetval:
252 return "NVPTXISD::StoreRetval";
253 case NVPTXISD::StoreRetvalV2:
254 return "NVPTXISD::StoreRetvalV2";
255 case NVPTXISD::StoreRetvalV4:
256 return "NVPTXISD::StoreRetvalV4";
257 case NVPTXISD::PseudoUseParam:
258 return "NVPTXISD::PseudoUseParam";
259 case NVPTXISD::RETURN:
260 return "NVPTXISD::RETURN";
261 case NVPTXISD::CallSeqBegin:
262 return "NVPTXISD::CallSeqBegin";
263 case NVPTXISD::CallSeqEnd:
264 return "NVPTXISD::CallSeqEnd";
265 case NVPTXISD::LoadV2:
266 return "NVPTXISD::LoadV2";
267 case NVPTXISD::LoadV4:
268 return "NVPTXISD::LoadV4";
269 case NVPTXISD::LDGV2:
270 return "NVPTXISD::LDGV2";
271 case NVPTXISD::LDGV4:
272 return "NVPTXISD::LDGV4";
273 case NVPTXISD::LDUV2:
274 return "NVPTXISD::LDUV2";
275 case NVPTXISD::LDUV4:
276 return "NVPTXISD::LDUV4";
277 case NVPTXISD::StoreV2:
278 return "NVPTXISD::StoreV2";
279 case NVPTXISD::StoreV4:
280 return "NVPTXISD::StoreV4";
// Ask the legalizer to split vectors of i1 rather than widening them; i1
// lives in predicate registers (see the i1 custom LOAD/STORE handling above).
284 bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const {
285 return VT == MVT::i1;
// Custom lowering for ISD::GlobalAddress: rebuild the address as a
// TargetGlobalAddress node and wrap it in NVPTXISD::Wrapper so instruction
// selection can match it.
289 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
291 const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
292 Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
293 return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
// Builds the PTX ".callprototype" string ("prototype_<n> : .callprototype
// <ret> _ (<params>)") that an indirect call must reference. The string is
// later emitted via an INLINEASM node in LowerCall.
// retTy/Args/Outs describe the callee signature; retAlignment is used for
// aggregate/vector returns. NOTE(review): several interior lines (the
// raw_string_ostream declaration, brace closers, totalsz accumulation) are
// not visible in this chunk.
296 std::string NVPTXTargetLowering::getPrototype(
297 Type *retTy, const ArgListTy &Args,
298 const SmallVectorImpl<ISD::OutputArg> &Outs, unsigned retAlignment) const {
// ABI (sm_20+) uses .param space; pre-ABI uses .reg declarations.
300 bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
303 O << "prototype_" << uniqueCallSite << " : .callprototype ";
// void return emits no return-value declaration.
305 if (retTy->getTypeID() == Type::VoidTyID)
310 if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
312 if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
313 size = ITy->getBitWidth();
317 assert(retTy->isFloatingPointTy() &&
318 "Floating point type expected here");
319 size = retTy->getPrimitiveSizeInBits();
// Scalar return: .param .b<bits> _
322 O << ".param .b" << size << " _";
323 } else if (isa<PointerType>(retTy))
324 O << ".param .b" << getPointerTy().getSizeInBits() << " _";
// Aggregate/vector return: byte array sized from the flattened parts.
326 if ((retTy->getTypeID() == Type::StructTyID) ||
327 isa<VectorType>(retTy)) {
328 SmallVector<EVT, 16> vtparts;
329 ComputeValueVTs(*this, retTy, vtparts);
330 unsigned totalsz = 0;
331 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
333 EVT elemtype = vtparts[i];
334 if (vtparts[i].isVector()) {
335 elems = vtparts[i].getVectorNumElements();
336 elemtype = vtparts[i].getVectorElementType();
338 for (unsigned j = 0, je = elems; j != je; ++j) {
339 unsigned sz = elemtype.getSizeInBits();
// Sub-byte integers still occupy at least one byte of .param space.
340 if (elemtype.isInteger() && (sz < 8))
345 O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
347 assert(false && "Unknown return type");
// Non-ABI path: one .reg per flattened return part, widened to >= 32 bits.
351 SmallVector<EVT, 16> vtparts;
352 ComputeValueVTs(*this, retTy, vtparts);
354 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
356 EVT elemtype = vtparts[i];
357 if (vtparts[i].isVector()) {
358 elems = vtparts[i].getVectorNumElements();
359 elemtype = vtparts[i].getVectorElementType();
362 for (unsigned j = 0, je = elems; j != je; ++j) {
363 unsigned sz = elemtype.getSizeInBits();
364 if (elemtype.isInteger() && (sz < 32))
366 O << ".reg .b" << sz << " _";
380 MVT thePointerTy = getPointerTy();
// Now the parameter list, one declaration per argument.
382 for (unsigned i = 0, e = Args.size(); i != e; ++i) {
383 const Type *Ty = Args[i].Ty;
389 if (Outs[i].Flags.isByVal() == false) {
391 if (isa<IntegerType>(Ty)) {
392 sz = cast<IntegerType>(Ty)->getBitWidth();
395 } else if (isa<PointerType>(Ty))
396 sz = thePointerTy.getSizeInBits();
398 sz = Ty->getPrimitiveSizeInBits();
400 O << ".param .b" << sz << " ";
402 O << ".reg .b" << sz << " ";
// byval aggregate: declared as an aligned byte array of its alloc size.
406 const PointerType *PTy = dyn_cast<PointerType>(Ty);
407 assert(PTy && "Param with byval attribute should be a pointer type");
408 Type *ETy = PTy->getElementType();
411 unsigned align = Outs[i].Flags.getByValAlign();
412 unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
413 O << ".param .align " << align << " .b8 ";
415 O << "[" << sz << "]";
// Non-ABI byval: one .reg per flattened element, widened to >= 32 bits.
418 SmallVector<EVT, 16> vtparts;
419 ComputeValueVTs(*this, ETy, vtparts);
420 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
422 EVT elemtype = vtparts[i];
423 if (vtparts[i].isVector()) {
424 elems = vtparts[i].getVectorNumElements();
425 elemtype = vtparts[i].getVectorElementType();
428 for (unsigned j = 0, je = elems; j != je; ++j) {
429 unsigned sz = elemtype.getSizeInBits();
430 if (elemtype.isInteger() && (sz < 32))
432 O << ".reg .b" << sz << " ";
// Lowers a call for PTX: declares and populates .param/.reg space for each
// outgoing argument, declares the return value, emits the prototype label
// (indirect calls only), prints the call itself via a sequence of glued
// NVPTX pseudo nodes, and finally loads the results back into InVals.
// NOTE(review): many interior lines (brace closers, some declarations such
// as `isReg`, `srcAddr`, `theVal`, and several operand-count arguments) are
// not visible in this chunk; comments below describe only what is shown.
447 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
448 SmallVectorImpl<SDValue> &InVals) const {
449 SelectionDAG &DAG = CLI.DAG;
451 SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
452 SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
453 SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
454 SDValue Chain = CLI.Chain;
455 SDValue Callee = CLI.Callee;
456 bool &isTailCall = CLI.IsTailCall;
457 ArgListTy &Args = CLI.Args;
458 Type *retTy = CLI.RetTy;
459 ImmutableCallSite *CS = CLI.CS;
// ABI (sm_20+) passes arguments in .param space; pre-ABI uses .reg moves.
461 bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
// tempChain is kept so byval element loads below don't serialize on the
// call-sequence chain.
463 SDValue tempChain = Chain;
464 Chain = DAG.getCALLSEQ_START(Chain,
465 DAG.getIntPtrConstant(uniqueCallSite, true),
467 SDValue InFlag = Chain.getValue(1);
469 assert((Outs.size() == Args.size()) &&
470 "Unexpected number of arguments to function call");
471 unsigned paramCount = 0;
472 // Declare the .params or .reg need to pass values
474 for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
477 if (Outs[i].Flags.isByVal() == false) {
479 // for ABI, declare .param .b<size> .param<n>;
480 // for nonABI, declare .reg .b<size> .param<n>;
// Integer args narrower than 32 bits get widened before declaration.
484 unsigned sz = VT.getSizeInBits();
485 if (VT.isInteger() && (sz < 32))
487 SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
488 SDValue DeclareParamOps[] = { Chain,
489 DAG.getConstant(paramCount, MVT::i32),
490 DAG.getConstant(sz, MVT::i32),
491 DAG.getConstant(isReg, MVT::i32), InFlag };
492 Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
494 InFlag = Chain.getValue(1);
// Copy the value into the declared param: StoreParam for .param space,
// MoveToParam for .reg, with sign/zero-extending variants for widened ints.
495 SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
496 SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
497 DAG.getConstant(0, MVT::i32), OutVals[i],
500 unsigned opcode = NVPTXISD::StoreParam;
502 opcode = NVPTXISD::MoveToParam;
504 if (Outs[i].Flags.isZExt())
505 opcode = NVPTXISD::StoreParamU32;
506 else if (Outs[i].Flags.isSExt())
507 opcode = NVPTXISD::StoreParamS32;
509 Chain = DAG.getNode(opcode, dl, CopyParamVTs, CopyParamOps, 5);
511 InFlag = Chain.getValue(1);
// byval aggregate argument: flatten the pointee type into parts.
516 SmallVector<EVT, 16> vtparts;
517 const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
518 assert(PTy && "Type of a byval parameter should be pointer");
519 ComputeValueVTs(*this, PTy->getElementType(), vtparts);
522 // declare .param .align 16 .b8 .param<n>[<size>];
523 unsigned sz = Outs[i].Flags.getByValSize();
524 SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
525 // The ByValAlign in the Outs[i].Flags is always set at this point, so we
527 // worry about natural alignment or not. See TargetLowering::LowerCallTo()
528 SDValue DeclareParamOps[] = {
529 Chain, DAG.getConstant(Outs[i].Flags.getByValAlign(), MVT::i32),
530 DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32),
533 Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
535 InFlag = Chain.getValue(1);
// Load each element from the caller's copy and store it into the param
// array at its running byte offset.
536 unsigned curOffset = 0;
537 for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
539 EVT elemtype = vtparts[j];
540 if (vtparts[j].isVector()) {
541 elems = vtparts[j].getVectorNumElements();
542 elemtype = vtparts[j].getVectorElementType();
544 for (unsigned k = 0, ke = elems; k != ke; ++k) {
545 unsigned sz = elemtype.getSizeInBits();
546 if (elemtype.isInteger() && (sz < 8))
549 DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[i],
550 DAG.getConstant(curOffset, getPointerTy()));
552 DAG.getLoad(elemtype, dl, tempChain, srcAddr,
553 MachinePointerInfo(), false, false, false, 0);
554 SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
555 SDValue CopyParamOps[] = { Chain,
556 DAG.getConstant(paramCount, MVT::i32),
557 DAG.getConstant(curOffset, MVT::i32),
559 Chain = DAG.getNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
561 InFlag = Chain.getValue(1);
568 // Non-abi, struct or vector
569 // Declare a bunch or .reg .b<size> .param<n>
570 unsigned curOffset = 0;
571 for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
573 EVT elemtype = vtparts[j];
574 if (vtparts[j].isVector()) {
575 elems = vtparts[j].getVectorNumElements();
576 elemtype = vtparts[j].getVectorElementType();
578 for (unsigned k = 0, ke = elems; k != ke; ++k) {
579 unsigned sz = elemtype.getSizeInBits();
580 if (elemtype.isInteger() && (sz < 32))
582 SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
583 SDValue DeclareParamOps[] = { Chain,
584 DAG.getConstant(paramCount, MVT::i32),
585 DAG.getConstant(sz, MVT::i32),
586 DAG.getConstant(1, MVT::i32), InFlag };
587 Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
589 InFlag = Chain.getValue(1);
591 DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[i],
592 DAG.getConstant(curOffset, getPointerTy()));
594 DAG.getLoad(elemtype, dl, tempChain, srcAddr, MachinePointerInfo(),
595 false, false, false, 0);
596 SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
597 SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
598 DAG.getConstant(0, MVT::i32), theVal,
600 Chain = DAG.getNode(NVPTXISD::MoveToParam, dl, CopyParamVTs,
602 InFlag = Chain.getValue(1);
// Func is non-null for a direct call (callee is a GlobalAddress).
608 GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
609 unsigned retAlignment = 0;
// Declare the return value, if any.
612 unsigned retCount = 0;
613 if (Ins.size() > 0) {
614 SmallVector<EVT, 16> resvtparts;
615 ComputeValueVTs(*this, retTy, resvtparts);
617 // Declare one .param .align 16 .b8 func_retval0[<size>] for ABI or
618 // individual .reg .b<size> func_retval<0..> for non ABI
619 unsigned resultsz = 0;
620 for (unsigned i = 0, e = resvtparts.size(); i != e; ++i) {
622 EVT elemtype = resvtparts[i];
623 if (resvtparts[i].isVector()) {
624 elems = resvtparts[i].getVectorNumElements();
625 elemtype = resvtparts[i].getVectorElementType();
627 for (unsigned j = 0, je = elems; j != je; ++j) {
628 unsigned sz = elemtype.getSizeInBits();
629 if (isABI == false) {
630 if (elemtype.isInteger() && (sz < 32))
633 if (elemtype.isInteger() && (sz < 8))
636 if (isABI == false) {
637 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
638 SDValue DeclareRetOps[] = { Chain, DAG.getConstant(2, MVT::i32),
639 DAG.getConstant(sz, MVT::i32),
640 DAG.getConstant(retCount, MVT::i32),
642 Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
644 InFlag = Chain.getValue(1);
// ABI path: scalar returns use DeclareRet, aggregates use DeclareRetParam
// with the alignment taken from the callee or computed from the DataLayout.
651 if (retTy->isPrimitiveType() || retTy->isIntegerTy() ||
652 retTy->isPointerTy()) {
653 // Scalar needs to be at least 32bit wide
656 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
657 SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
658 DAG.getConstant(resultsz, MVT::i32),
659 DAG.getConstant(0, MVT::i32), InFlag };
660 Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
662 InFlag = Chain.getValue(1);
664 if (Func) { // direct call
665 if (!llvm::getAlign(*(CS->getCalledFunction()), 0, retAlignment))
666 retAlignment = getDataLayout()->getABITypeAlignment(retTy);
667 } else { // indirect call
668 const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction());
669 if (!llvm::getAlign(*CallI, 0, retAlignment))
670 retAlignment = getDataLayout()->getABITypeAlignment(retTy);
672 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
673 SDValue DeclareRetOps[] = { Chain,
674 DAG.getConstant(retAlignment, MVT::i32),
675 DAG.getConstant(resultsz / 8, MVT::i32),
676 DAG.getConstant(0, MVT::i32), InFlag };
677 Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
679 InFlag = Chain.getValue(1);
685 // This is indirect function call case : PTX requires a prototype of the
687 // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
688 // to be emitted, and the label has to used as the last arg of call
690 // The prototype is embedded in a string and put as the operand for an
692 SDVTList InlineAsmVTs = DAG.getVTList(MVT::Other, MVT::Glue);
693 std::string proto_string = getPrototype(retTy, Args, Outs, retAlignment);
// Pool the string so its lifetime outlives this SelectionDAG.
694 const char *asmstr = nvTM->getManagedStrPool()
695 ->getManagedString(proto_string.c_str())->c_str();
696 SDValue InlineAsmOps[] = {
697 Chain, DAG.getTargetExternalSymbol(asmstr, getPointerTy()),
698 DAG.getMDNode(0), DAG.getTargetConstant(0, MVT::i32), InFlag
700 Chain = DAG.getNode(ISD::INLINEASM, dl, InlineAsmVTs, InlineAsmOps, 5);
701 InFlag = Chain.getValue(1);
703 // Op to just print "call"
704 SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
705 SDValue PrintCallOps[] = {
707 DAG.getConstant(isABI ? ((Ins.size() == 0) ? 0 : 1) : retCount, MVT::i32),
// PrintCallUni for direct calls, PrintCall for indirect.
710 Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
711 dl, PrintCallVTs, PrintCallOps, 3);
712 InFlag = Chain.getValue(1);
714 // Ops to print out the function name
715 SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
716 SDValue CallVoidOps[] = { Chain, Callee, InFlag };
717 Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3);
718 InFlag = Chain.getValue(1);
720 // Ops to print out the param list
721 SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
722 SDValue CallArgBeginOps[] = { Chain, InFlag };
723 Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
725 InFlag = Chain.getValue(1);
// One CallArg per declared param; the last one uses LastCallArg so the
// printer can omit the trailing comma.
727 for (unsigned i = 0, e = paramCount; i != e; ++i) {
730 opcode = NVPTXISD::LastCallArg;
732 opcode = NVPTXISD::CallArg;
733 SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
734 SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
735 DAG.getConstant(i, MVT::i32), InFlag };
736 Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4);
737 InFlag = Chain.getValue(1);
739 SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
740 SDValue CallArgEndOps[] = { Chain, DAG.getConstant(Func ? 1 : 0, MVT::i32),
743 DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps, 3);
744 InFlag = Chain.getValue(1);
// Indirect calls additionally reference the prototype_<n> label.
747 SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
748 SDValue PrototypeOps[] = { Chain, DAG.getConstant(uniqueCallSite, MVT::i32),
750 Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3);
751 InFlag = Chain.getValue(1);
754 // Generate loads from param memory/moves from registers for result
755 if (Ins.size() > 0) {
// ABI: LoadParam reads each part back from the retval .param array at its
// byte offset.
757 unsigned resoffset = 0;
758 for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
759 unsigned sz = Ins[i].VT.getSizeInBits();
760 if (Ins[i].VT.isInteger() && (sz < 8))
762 EVT LoadRetVTs[] = { Ins[i].VT, MVT::Other, MVT::Glue };
763 SDValue LoadRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
764 DAG.getConstant(resoffset, MVT::i32), InFlag };
765 SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, LoadRetVTs,
766 LoadRetOps, array_lengthof(LoadRetOps));
767 Chain = retval.getValue(1);
768 InFlag = retval.getValue(2);
769 InVals.push_back(retval);
// Non-ABI: one LoadParam per scalar retval register; vector results are
// rebuilt with BUILD_VECTOR.
773 SmallVector<EVT, 16> resvtparts;
774 ComputeValueVTs(*this, retTy, resvtparts);
776 assert(Ins.size() == resvtparts.size() &&
777 "Unexpected number of return values in non-ABI case");
778 unsigned paramNum = 0;
779 for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
780 assert(EVT(Ins[i].VT) == resvtparts[i] &&
781 "Unexpected EVT type in non-ABI case");
782 unsigned numelems = 1;
783 EVT elemtype = Ins[i].VT;
784 if (Ins[i].VT.isVector()) {
785 numelems = Ins[i].VT.getVectorNumElements();
786 elemtype = Ins[i].VT.getVectorElementType();
788 std::vector<SDValue> tempRetVals;
789 for (unsigned j = 0; j < numelems; ++j) {
790 EVT MoveRetVTs[] = { elemtype, MVT::Other, MVT::Glue };
791 SDValue MoveRetOps[] = { Chain, DAG.getConstant(0, MVT::i32),
792 DAG.getConstant(paramNum, MVT::i32),
794 SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, MoveRetVTs,
795 MoveRetOps, array_lengthof(MoveRetOps));
796 Chain = retval.getValue(1);
797 InFlag = retval.getValue(2);
798 tempRetVals.push_back(retval);
801 if (Ins[i].VT.isVector())
802 InVals.push_back(DAG.getNode(ISD::BUILD_VECTOR, dl, Ins[i].VT,
803 &tempRetVals[0], tempRetVals.size()));
805 InVals.push_back(tempRetVals[0]);
// Close the call sequence and bump the id for the next call site.
809 Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
810 DAG.getIntPtrConstant(uniqueCallSite + 1, true),
814 // set isTailCall to false for now, until we figure out how to express
815 // tail call optimization in PTX
820 // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
821 // (see LegalizeDAG.cpp). This is slow and uses local memory.
822 // We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
// Extracts every element of every input sub-vector and rebuilds the result
// with a single BUILD_VECTOR, avoiding the stack round-trip.
824 NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
825 SDNode *Node = Op.getNode();
827 SmallVector<SDValue, 8> Ops;
828 unsigned NumOperands = Node->getNumOperands();
829 for (unsigned i = 0; i < NumOperands; ++i) {
830 SDValue SubOp = Node->getOperand(i);
831 EVT VVT = SubOp.getNode()->getValueType(0);
832 EVT EltVT = VVT.getVectorElementType();
833 unsigned NumSubElem = VVT.getVectorNumElements();
834 for (unsigned j = 0; j < NumSubElem; ++j) {
835 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
836 DAG.getIntPtrConstant(j)));
839 return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0), &Ops[0],
// Dispatches every operation registered as Custom in the constructor to its
// specific lowering routine. NOTE(review): several case labels/returns
// between the visible lines are not shown in this chunk.
844 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
845 switch (Op.getOpcode()) {
846 case ISD::RETURNADDR:
850 case ISD::GlobalAddress:
851 return LowerGlobalAddress(Op, DAG);
852 case ISD::INTRINSIC_W_CHAIN:
854 case ISD::BUILD_VECTOR:
855 case ISD::EXTRACT_SUBVECTOR:
857 case ISD::CONCAT_VECTORS:
858 return LowerCONCAT_VECTORS(Op, DAG);
860 return LowerSTORE(Op, DAG);
862 return LowerLOAD(Op, DAG);
864 llvm_unreachable("Custom lowering not defined for operation");
// Custom LOAD lowering entry point: only i1 loads need special handling
// (PTX cannot load predicate registers); see LowerLOADi1.
868 SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
869 if (Op.getValueType() == MVT::i1)
870 return LowerLOADi1(Op, DAG);
878 // v = trunc v1 to i1
// Lowers an i1 load as an i8 load followed by a TRUNCATE, since PTX has no
// predicate-register loads.
879 SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
880 SDNode *Node = Op.getNode();
881 LoadSDNode *LD = cast<LoadSDNode>(Node);
883 assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
884 assert(Node->getValueType(0) == MVT::i1 &&
885 "Custom lowering for i1 load only");
// Re-issue the load at i8, preserving all memory-operand attributes.
887 DAG.getLoad(MVT::i8, dl, LD->getChain(), LD->getBasePtr(),
888 LD->getPointerInfo(), LD->isVolatile(), LD->isNonTemporal(),
889 LD->isInvariant(), LD->getAlignment());
890 SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
891 // The legalizer (the caller) is expecting two values from the legalized
892 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
893 // in LegalizeDAG.cpp which also uses MergeValues.
894 SDValue Ops[] = { result, LD->getChain() };
895 return DAG.getMergeValues(Ops, 2, dl);
// Custom STORE lowering entry point: i1 stores go through LowerSTOREi1,
// vector stores through LowerSTOREVector.
898 SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
899 EVT ValVT = Op.getOperand(1).getValueType();
900 if (ValVT == MVT::i1)
901 return LowerSTOREi1(Op, DAG);
902 else if (ValVT.isVector())
903 return LowerSTOREVector(Op, DAG);
// Lowers a store of a PTX-native vector type to a StoreV2/StoreV4 memory
// intrinsic node, scalarizing the value into per-element operands.
// NOTE(review): the switch body selecting valid simple VTs and the NumElts
// dispatch (2 vs 4 elements) are only partially visible in this chunk.
909 NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
910 SDNode *N = Op.getNode();
911 SDValue Val = N->getOperand(1);
913 EVT ValVT = Val.getValueType();
915 if (ValVT.isVector()) {
916 // We only handle "native" vector sizes for now, e.g. <4 x double> is not
917 // legal. We can (and should) split that into 2 stores of <2 x double> here
918 // but I'm leaving that as a TODO for now.
919 if (!ValVT.isSimple())
921 switch (ValVT.getSimpleVT().SimpleTy) {
934 // This is a "native" vector type
939 EVT EltVT = ValVT.getVectorElementType();
940 unsigned NumElts = ValVT.getVectorNumElements();
942 // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
943 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
944 // stored type to i16 and propagate the "real" type as the memory type.
945 bool NeedExt = false;
946 if (EltVT.getSizeInBits() < 16)
953 Opcode = NVPTXISD::StoreV2;
956 Opcode = NVPTXISD::StoreV4;
961 SmallVector<SDValue, 8> Ops;
963 // First is the chain
964 Ops.push_back(N->getOperand(0));
966 // Then the split values
967 for (unsigned i = 0; i < NumElts; ++i) {
968 SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
969 DAG.getIntPtrConstant(i));
971 // ANY_EXTEND is correct here since the store will only look at the
972 // lower-order bits anyway.
973 ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
974 Ops.push_back(ExtVal);
977 // Then any remaining arguments
978 for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) {
979 Ops.push_back(N->getOperand(i));
982 MemSDNode *MemSD = cast<MemSDNode>(N);
// Keep the original memory VT/operand so alias analysis and the printer
// still see the true (possibly sub-i16) element type.
984 SDValue NewSt = DAG.getMemIntrinsicNode(
985 Opcode, DL, DAG.getVTList(MVT::Other), &Ops[0], Ops.size(),
986 MemSD->getMemoryVT(), MemSD->getMemOperand());
988 //return DCI.CombineTo(N, NewSt, true);
// Lowers an i1 store as ZERO_EXTEND to i8 followed by an i8 store, since
// PTX has no predicate-register stores (mirror of LowerLOADi1).
999 SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
1000 SDNode *Node = Op.getNode();
1002 StoreSDNode *ST = cast<StoreSDNode>(Node);
1003 SDValue Tmp1 = ST->getChain();
1004 SDValue Tmp2 = ST->getBasePtr();
1005 SDValue Tmp3 = ST->getValue();
1006 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
1007 unsigned Alignment = ST->getAlignment();
1008 bool isVolatile = ST->isVolatile();
1009 bool isNonTemporal = ST->isNonTemporal();
// Widen the stored value to i8, keeping the original memory attributes.
1010 Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, Tmp3);
1011 SDValue Result = DAG.getStore(Tmp1, dl, Tmp3, Tmp2, ST->getPointerInfo(),
1012 isVolatile, isNonTemporal, Alignment);
// Builds a TargetExternalSymbol named "<inname><idx>" of type v, pooling the
// string in the target machine so it outlives the DAG.
// NOTE(review): the line writing idx into `suffix` is not visible in this
// chunk; the suffix is presumably the decimal index -- confirm upstream.
1016 SDValue NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname,
1017 int idx, EVT v) const {
1018 std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
1019 std::stringstream suffix;
1021 *name += suffix.str();
1022 return DAG.getTargetExternalSymbol(name->c_str(), v);
// Convenience wrapper: symbol ".PARAM<idx>" of type v (see getExtSymb).
1026 NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
1027 return getExtSymb(DAG, ".PARAM", idx, v);
// Convenience wrapper: helper symbol ".HLPPARAM<idx>" (see getExtSymb).
1030 SDValue NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
1031 return getExtSymb(DAG, ".HLPPARAM", idx);
1034 // Check to see if the kernel argument is image*_t or sampler_t
// Matches the argument's pointee struct name against the known OpenCL
// image/sampler type names. Literal (unnamed) structs never match.
// NOTE(review): early-return lines for the non-pointer case are not visible
// in this chunk.
1036 bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
1037 static const char *const specialTypes[] = { "struct._image2d_t",
1038 "struct._image3d_t",
1039 "struct._sampler_t" };
1041 const Type *Ty = arg->getType();
1042 const PointerType *PTy = dyn_cast<PointerType>(Ty);
1050 const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
1051 const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : "";
1053 for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
1054 if (TypeName == specialTypes[i])
// LowerFormalArguments - Lower incoming formal arguments into SelectionDAG
// nodes.  Visible behavior in this chunk: image/sampler kernel arguments
// become an i32 constant holding their parameter position; unused
// arguments get UNDEF nodes of the matching types; non-byval arguments
// are loaded from their ".PARAM" symbols (scalars, aggregates piecewise,
// vectors via scalar/v2/v4 loads); byval arguments are wrapped in an
// NVPTXISD::MoveParam node (plus a local-to-generic pointer conversion).
// NOTE(review): many original lines (loop tails, else-branches, closing
// braces) are elided from this chunk; comments describe only visible code.
1060 SDValue NVPTXTargetLowering::LowerFormalArguments(
1061     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
1062     const SmallVectorImpl<ISD::InputArg> &Ins, SDLoc dl, SelectionDAG &DAG,
1063     SmallVectorImpl<SDValue> &InVals) const {
1064   MachineFunction &MF = DAG.getMachineFunction();
1065   const DataLayout *TD = getDataLayout();
1067   const Function *F = MF.getFunction();
1068   const AttributeSet &PAL = F->getAttributes();
1069   const TargetLowering *TLI = nvTM->getTargetLowering();
1071   SDValue Root = DAG.getRoot();
1072   std::vector<SDValue> OutChains;
1074   bool isKernel = llvm::isKernelFunction(*F);
     // Only the >= sm_20 ABI path is supported; non-ABI lowering asserts out.
1075   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1076   assert(isABI && "Non-ABI compilation is not supported");
     // Collect the IR-level arguments and their types so we can walk them
     // with an index that is independent of Ins (see comment below).
1080   std::vector<Type *> argTypes;
1081   std::vector<const Argument *> theArgs;
1082   for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
1084     theArgs.push_back(I);
1085     argTypes.push_back(I->getType());
1087   // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
1088   // Ins.size() will be larger
1089   //   * if there is an aggregate argument with multiple fields (each field
1090   //     showing up separately in Ins)
1091   //   * if there is a vector argument with more than typical vector-length
1092   //     elements (generally if more than 4) where each vector element is
1093   //     individually present in Ins.
1094   // So a different index should be used for indexing into Ins.
1095   // See similar issue in LowerCall.
1096   unsigned InsIdx = 0;
1099   for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++idx, ++InsIdx) {
1100     Type *Ty = argTypes[i];
1102     // If the kernel argument is image*_t or sampler_t, convert it to
1103     // a i32 constant holding the parameter position. This can later
1104     // matched in the AsmPrinter to output the correct mangled name.
1105     if (isImageOrSamplerVal(
1107             (theArgs[i]->getParent() ? theArgs[i]->getParent()->getParent()
1109       assert(isKernel && "Only kernels can have image/sampler params");
         // 1-based parameter position, matched later by the AsmPrinter.
1110       InVals.push_back(DAG.getConstant(i + 1, MVT::i32));
       // Unused arguments don't need real loads -- push UNDEFs of the right
       // component types so InVals stays in sync with Ins.
1114     if (theArgs[i]->use_empty()) {
1116       if (Ty->isAggregateType()) {
1117         SmallVector<EVT, 16> vtparts;
1119         ComputeValueVTs(*this, Ty, vtparts);
1120         assert(vtparts.size() > 0 && "empty aggregate type not expected");
1121         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
1123           EVT partVT = vtparts[parti];
1124           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, partVT));
1127         if (vtparts.size() > 0)
1131       if (Ty->isVectorTy()) {
1132         EVT ObjectVT = getValueType(Ty);
1133         unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
1134         for (unsigned parti = 0; parti < NumRegs; ++parti) {
1135           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
1142       InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
1146     // In the following cases, assign a node order of "idx+1"
1147     // to newly created nodes. The SDNodes for params have to
1148     // appear in the same order as their order of appearance
1149     // in the original function. "idx+1" holds that order.
1150     if (PAL.hasAttribute(i + 1, Attribute::ByVal) == false) {
         // Non-byval aggregate: load each scalar part from the param symbol
         // at its ComputeValueVTs offset.
1151       if (Ty->isAggregateType()) {
1152         SmallVector<EVT, 16> vtparts;
1153         SmallVector<uint64_t, 16> offsets;
1155         ComputeValueVTs(*this, Ty, vtparts, &offsets, 0);
1156         assert(vtparts.size() > 0 && "empty aggregate type not expected");
         // Packed structs have no ABI padding, so each part is 1-aligned.
1157         bool aggregateIsPacked = false;
1158         if (StructType *STy = llvm::dyn_cast<StructType>(Ty))
1159           aggregateIsPacked = STy->isPacked();
1161         SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1162         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
1164           EVT partVT = vtparts[parti];
             // Null pointer in the param address space, used only as the
             // MachinePointerInfo "IR value" for the load.
1165           Value *srcValue = Constant::getNullValue(
1166               PointerType::get(partVT.getTypeForEVT(F->getContext()),
1167                                llvm::ADDRESS_SPACE_PARAM));
1169               DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1170                           DAG.getConstant(offsets[parti], getPointerTy()));
1171           unsigned partAlign =
1172               aggregateIsPacked ? 1
1173                                 : TD->getABITypeAlignment(
1174                                       partVT.getTypeForEVT(F->getContext()));
1175           SDValue p = DAG.getLoad(partVT, dl, Root, srcAddr,
1176                                   MachinePointerInfo(srcValue), false, false,
1179             p.getNode()->setIROrder(idx + 1);
1180           InVals.push_back(p);
1183         if (vtparts.size() > 0)
         // Non-byval vector: load as 1, 2, or groups-of-4 elements.
1187       if (Ty->isVectorTy()) {
1188         EVT ObjectVT = getValueType(Ty);
1189         SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1190         unsigned NumElts = ObjectVT.getVectorNumElements();
1191         assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
1192                "Vector was not scalarized");
1194         EVT EltVT = ObjectVT.getVectorElementType();
1199           // We only have one element, so just directly load it
1200           Value *SrcValue = Constant::getNullValue(PointerType::get(
1201               EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1202           SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1203                                         DAG.getConstant(Ofst, getPointerTy()));
1204           SDValue P = DAG.getLoad(
1205               EltVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1207               TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
1209             P.getNode()->setIROrder(idx + 1);
1211           InVals.push_back(P);
1212           Ofst += TD->getTypeAllocSize(EltVT.getTypeForEVT(F->getContext()));
1214         } else if (NumElts == 2) {
             // Load both elements with a single v2 load, then extract.
1216           // f32,f32 = load ...
1217           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2);
1218           Value *SrcValue = Constant::getNullValue(PointerType::get(
1219               VecVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1220           SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1221                                         DAG.getConstant(Ofst, getPointerTy()));
1222           SDValue P = DAG.getLoad(
1223               VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1225               TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
1227             P.getNode()->setIROrder(idx + 1);
1229           SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1230                                      DAG.getIntPtrConstant(0));
1231           SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1232                                      DAG.getIntPtrConstant(1));
1233           InVals.push_back(Elt0);
1234           InVals.push_back(Elt1);
1235           Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1239           // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
1241           // vector will be expanded to a power of 2 elements, so we know we can
1242           // always round up to the next multiple of 4 when creating the vector
1244           // e.g.  4 elem => 1 ld.v4
1245           //       6 elem => 2 ld.v4
1246           //       8 elem => 2 ld.v4
1247           //      11 elem => 3 ld.v4
1248           unsigned VecSize = 4;
             // NOTE(review): the 64-bit branch body (orig lines 1250-1251,
             // presumably "VecSize = 2") is elided from this chunk.
1249           if (EltVT.getSizeInBits() == 64) {
1252           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
1253           for (unsigned i = 0; i < NumElts; i += VecSize) {
1254             Value *SrcValue = Constant::getNullValue(
1255                 PointerType::get(VecVT.getTypeForEVT(F->getContext()),
1256                                  llvm::ADDRESS_SPACE_PARAM));
1258                 DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1259                             DAG.getConstant(Ofst, getPointerTy()));
1260             SDValue P = DAG.getLoad(
1261                 VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1263                 TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
1265               P.getNode()->setIROrder(idx + 1);
               // Extract only the elements that actually exist; the tail of
               // the last v4/v2 group may be padding.
1267             for (unsigned j = 0; j < VecSize; ++j) {
1268               if (i + j >= NumElts)
1270               SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1271                                         DAG.getIntPtrConstant(j));
1272               InVals.push_back(Elt);
1274             Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
         // Plain scalar (non-aggregate, non-vector) argument.
1284       EVT ObjectVT = getValueType(Ty);
1285       assert(ObjectVT == Ins[InsIdx].VT &&
1286              "Ins type did not match function type");
1287       // If ABI, load from the param symbol
1288       SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1289       Value *srcValue = Constant::getNullValue(PointerType::get(
1290           ObjectVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1291       SDValue p = DAG.getLoad(
1292           ObjectVT, dl, Root, Arg, MachinePointerInfo(srcValue), false, false,
1294           TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
1296         p.getNode()->setIROrder(idx + 1);
1297       InVals.push_back(p);
1301     // Param has ByVal attribute
1302     // Return MoveParam(param symbol).
1303     // Ideally, the param symbol can be returned directly,
1304     // but when SDNode builder decides to use it in a CopyToReg(),
1305     // machine instruction fails because TargetExternalSymbol
1306     // (not lowered) is target dependent, and CopyToReg assumes
1307     // the source is lowered.
1308     EVT ObjectVT = getValueType(Ty);
1309     assert(ObjectVT == Ins[InsIdx].VT &&
1310            "Ins type did not match function type");
1311     SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1312     SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
1314       p.getNode()->setIROrder(idx + 1);
1316       InVals.push_back(p);
       // Convert the param-space pointer to a generic pointer via the
       // nvvm.ptr.local.to.gen intrinsic before handing it to the caller.
1318       SDValue p2 = DAG.getNode(
1319           ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
1320           DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p);
1321       InVals.push_back(p2);
1325   // Clang will check explicit VarArg and issue error if any. However, Clang
1326   // will let code with
1327   // implicit var arg like f() pass. See bug 617733.
1328   // We treat this case as if the arg list is empty.
1329   // if (F.isVarArg()) {
1330   // assert(0 && "VarArg not supported yet!");
     // Merge any chains produced above into the DAG root.
1333   if (!OutChains.empty())
1334     DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &OutChains[0],
// LowerReturn - Lower function return values into NVPTXISD::StoreRetval /
// StoreRetvalV2 / StoreRetvalV4 nodes that store into the return-value
// slot, finishing with an NVPTXISD::RET_FLAG.  Vector returns are stored
// as one scalar, one v2 store, or a loop of v2/v4 stores; sub-byte values
// are zero-extended to i8 before storing.
// NOTE(review): several original lines (branch headers, closing braces)
// are elided from this chunk; comments describe only the visible code.
1342 NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
1344                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
1345                                  const SmallVectorImpl<SDValue> &OutVals,
1346                                  SDLoc dl, SelectionDAG &DAG) const {
1347   MachineFunction &MF = DAG.getMachineFunction();
1348   const Function *F = MF.getFunction();
1349   const Type *RetTy = F->getReturnType();
1350   const DataLayout *TD = getDataLayout();
     // Only the >= sm_20 ABI path is supported.
1352   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1353   assert(isABI && "Non-ABI compilation is not supported");
1357   if (const VectorType *VTy = dyn_cast<const VectorType>(RetTy)) {
1358     // If we have a vector type, the OutVals array will be the scalarized
1359     // components and we have combine them into 1 or more vector stores.
1360     unsigned NumElts = VTy->getNumElements();
1361     assert(NumElts == Outs.size() && "Bad scalarization of return value");
1365       SDValue StoreVal = OutVals[0];
1366       // We only have one element, so just directly store it
1367       if (StoreVal.getValueType().getSizeInBits() < 8)
1368         StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
1369       Chain = DAG.getNode(NVPTXISD::StoreRetval, dl, MVT::Other, Chain,
1370                           DAG.getConstant(0, MVT::i32), StoreVal);
1371     } else if (NumElts == 2) {
         // Two elements fit a single StoreRetvalV2.
1373       SDValue StoreVal0 = OutVals[0];
1374       SDValue StoreVal1 = OutVals[1];
1376       if (StoreVal0.getValueType().getSizeInBits() < 8) {
1377         StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal0);
1378         StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal1);
1381       Chain = DAG.getNode(NVPTXISD::StoreRetvalV2, dl, MVT::Other, Chain,
1382                           DAG.getConstant(0, MVT::i32), StoreVal0, StoreVal1);
1385       // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the
1386       // vector will be expanded to a power of 2 elements, so we know we can
1387       // always round up to the next multiple of 4 when creating the vector
1389       // e.g.  4 elem => 1 st.v4
1390       //       6 elem => 2 st.v4
1391       //       8 elem => 2 st.v4
1392       //      11 elem => 3 st.v4
         // 64-bit elements use v2 stores (VecSize = 2, per the visible
         // condition below; its branch body is elided from this chunk).
1394       unsigned VecSize = 4;
1395       if (OutVals[0].getValueType().getSizeInBits() == 64)
1398       unsigned Offset = 0;
1401           EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize);
1402       unsigned PerStoreOffset =
1403           TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
         // Sub-byte elements are widened to i8 before storing.
1405       bool Extend = false;
1406       if (OutVals[0].getValueType().getSizeInBits() < 8)
1409       for (unsigned i = 0; i < NumElts; i += VecSize) {
1412         SmallVector<SDValue, 8> Ops;
1413         Ops.push_back(Chain);
1414         Ops.push_back(DAG.getConstant(Offset, MVT::i32));
1415         unsigned Opc = NVPTXISD::StoreRetvalV2;
1416         EVT ExtendedVT = (Extend) ? MVT::i8 : OutVals[0].getValueType();
1418         StoreVal = OutVals[i];
1420           StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
1421         Ops.push_back(StoreVal);
           // Slots past NumElts in the final group are filled with UNDEF.
1423         if (i + 1 < NumElts) {
1424           StoreVal = OutVals[i + 1];
1426             StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
1428           StoreVal = DAG.getUNDEF(ExtendedVT);
1430         Ops.push_back(StoreVal);
           // Upgrade to a v4 store when storing four values per group.
1433           Opc = NVPTXISD::StoreRetvalV4;
1434           if (i + 2 < NumElts) {
1435             StoreVal = OutVals[i + 2];
1437               StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
1439             StoreVal = DAG.getUNDEF(ExtendedVT);
1441           Ops.push_back(StoreVal);
1443           if (i + 3 < NumElts) {
1444             StoreVal = OutVals[i + 3];
1446               StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
1448             StoreVal = DAG.getUNDEF(ExtendedVT);
1450           Ops.push_back(StoreVal);
1453         Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size());
1454         Offset += PerStoreOffset;
     // Non-vector return type: emit one StoreRetval per scalar component,
     // tracking the running byte offset in 'sizesofar'.
1458     unsigned sizesofar = 0;
1459     for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
1460       SDValue theVal = OutVals[i];
1461       EVT theValType = theVal.getValueType();
1462       unsigned numElems = 1;
1463       if (theValType.isVector())
1464         numElems = theValType.getVectorNumElements();
1465       for (unsigned j = 0, je = numElems; j != je; ++j) {
1466         SDValue tmpval = theVal;
1467         if (theValType.isVector())
1468           tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
1469                                theValType.getVectorElementType(), tmpval,
1470                                DAG.getIntPtrConstant(j));
1471         EVT theStoreType = tmpval.getValueType();
1472         if (theStoreType.getSizeInBits() < 8)
1473           tmpval = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, tmpval);
1474         Chain = DAG.getNode(NVPTXISD::StoreRetval, dl, MVT::Other, Chain,
1475                             DAG.getConstant(sizesofar, MVT::i32), tmpval);
         // Advance by element store-size for vectors, full store-size otherwise.
1476         if (theValType.isVector())
1478               theValType.getVectorElementType().getStoreSizeInBits() / 8;
1480           sizesofar += theValType.getStoreSizeInBits() / 8;
1485   return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
// LowerAsmOperandForConstraint - Lower an inline-asm operand for a given
// constraint.  Multi-character constraints take the branch whose body is
// elided from this chunk (orig lines 1492-1493, presumably an early
// return); single-character constraints fall through to the default
// TargetLowering implementation.
1488 void NVPTXTargetLowering::LowerAsmOperandForConstraint(
1489     SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops,
1490     SelectionDAG &DAG) const {
1491   if (Constraint.length() > 1)
1494   TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
// NVPTX supports vectors of legal types of any length in intrinsics,
// because the NVPTX-specific type legalizer will legalize them to a
// PTX-supported length.
// NOTE(review): the return statements of both legality checks and the
// final return (orig lines 1502, 1506-1509) are elided from this chunk.
1500 bool NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
1501   if (isTypeLegal(VT))
1503   if (VT.isVector()) {
       // A vector is acceptable if its element type is legal, regardless of
       // its length.
1504     MVT eVT = VT.getVectorElementType();
1505     if (isTypeLegal(eVT))
// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
// target memory intrinsics because we need the information that is only
// available in the "Value" type of the pointer -- in particular, the
// address space information.
// getTgtMemIntrinsic - Fill in IntrinsicInfo for NVVM intrinsics that
// touch memory: float atomic add, 32-bit inc/dec atomics, and ldu global
// loads.  All are modeled as INTRINSIC_W_CHAIN with the pointer operand
// taken from argument 0.
// NOTE(review): the default case of the switch and several field
// assignments (e.g. vol/alignment, orig lines 1526-1527 etc.) are elided
// from this chunk.
1516 bool NVPTXTargetLowering::getTgtMemIntrinsic(
1517     IntrinsicInfo &Info, const CallInst &I, unsigned Intrinsic) const {
1518   switch (Intrinsic) {
     // Atomic add on f32: reads and writes the pointed-to float.
1522   case Intrinsic::nvvm_atomic_load_add_f32:
1523     Info.opc = ISD::INTRINSIC_W_CHAIN;
1524     Info.memVT = MVT::f32;
1525     Info.ptrVal = I.getArgOperand(0);
1528     Info.readMem = true;
1529     Info.writeMem = true;
     // 32-bit atomic increment/decrement: read-modify-write on an i32.
1533   case Intrinsic::nvvm_atomic_load_inc_32:
1534   case Intrinsic::nvvm_atomic_load_dec_32:
1535     Info.opc = ISD::INTRINSIC_W_CHAIN;
1536     Info.memVT = MVT::i32;
1537     Info.ptrVal = I.getArgOperand(0);
1540     Info.readMem = true;
1541     Info.writeMem = true;
     // ldu (load-uniform) global loads: read-only; the memory type depends
     // on the intrinsic flavor (i32 / pointer-width / f32).
1545   case Intrinsic::nvvm_ldu_global_i:
1546   case Intrinsic::nvvm_ldu_global_f:
1547   case Intrinsic::nvvm_ldu_global_p:
1549     Info.opc = ISD::INTRINSIC_W_CHAIN;
1550     if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
1551       Info.memVT = MVT::i32;
1552     else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
1553       Info.memVT = getPointerTy();
1555       Info.memVT = MVT::f32;
1556     Info.ptrVal = I.getArgOperand(0);
1559     Info.readMem = true;
1560     Info.writeMem = false;
/// isLegalAddressingMode - Return true if the addressing mode represented
/// by AM is legal for this target, for a load/store of the specified type.
/// Used to guide target specific optimizations, like loop strength reduction
/// (LoopStrengthReduce.cpp) and memory optimization for address mode
/// (CodeGenPrepare.cpp)
/// NOTE(review): the return statements and the scale-switch structure
/// (orig lines 1587-1605) are partly elided from this chunk; the visible
/// logic rejects any combination of offset/base-register with a global
/// base, and disallows r+r forms and scales > 1.
1573 bool NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
1576   // AddrMode - This represents an addressing mode of:
1577   //  BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
1579   // The legal address modes are
     // A global base may not be combined with an offset, base register, or
     // scaled register.
1586     if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
1592   case 0: // "r", "r+i" or "i" is allowed
1595     if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
1597     // Otherwise we have r+i.
1600     // No scale > 1 is allowed
1606 //===----------------------------------------------------------------------===//
1607 // NVPTX Inline Assembly Support
1608 //===----------------------------------------------------------------------===//
/// getConstraintType - Given a constraint letter, return the type of
/// constraint it is for this target.  The single-character register-class
/// constraint letters (case labels elided from this chunk, orig lines
/// 1616-1625) map to C_RegisterClass; everything else is deferred to the
/// default TargetLowering implementation.
1612 NVPTXTargetLowering::ConstraintType
1613 NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
1614   if (Constraint.size() == 1) {
1615     switch (Constraint[0]) {
1626       return C_RegisterClass;
1629   return TargetLowering::getConstraintType(Constraint);
// getRegForInlineAsmConstraint - Map a single-character inline-asm
// constraint letter to the matching NVPTX register class (i8/i16/i32/i64
// integer classes and f32/f64 float classes; the case labels themselves,
// orig lines 1637/1639/1641/..., are elided from this chunk).  Unhandled
// constraints defer to the default TargetLowering implementation.
1632 std::pair<unsigned, const TargetRegisterClass *>
1633 NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
1635   if (Constraint.size() == 1) {
1636     switch (Constraint[0]) {
1638       return std::make_pair(0U, &NVPTX::Int8RegsRegClass);
1640       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
1642       return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
1645       return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
1647       return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
1649       return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
1652   return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
/// getFunctionAlignment - Return the Log2 alignment of this function.
/// NOTE(review): the return statement (orig line ~1657) is elided from
/// this chunk; confirm the returned constant against the full file.
1656 unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
/// ReplaceLoadVector - Convert vector loads into multi-output scalar loads
/// (NVPTXISD::LoadV2 / LoadV4), then rebuild the original vector value
/// with a BUILD_VECTOR so uses of the node are unchanged.
/// NOTE(review): several original lines (switch cases, the <16-bit widen
/// assignments, closing braces) are elided from this chunk.
1661 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
1662                               SmallVectorImpl<SDValue> &Results) {
1663   EVT ResVT = N->getValueType(0);
1666   assert(ResVT.isVector() && "Vector load must have vector type");
1668   // We only handle "native" vector sizes for now, e.g. <4 x double> is not
1669   // legal.  We can (and should) split that into 2 loads of <2 x double> here
1670   // but I'm leaving that as a TODO for now.
1671   assert(ResVT.isSimple() && "Can only handle simple types");
1672   switch (ResVT.getSimpleVT().SimpleTy) {
1685   // This is a "native" vector type
1689   EVT EltVT = ResVT.getVectorElementType();
1690   unsigned NumElts = ResVT.getVectorNumElements();
1692   // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
1693   // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
1694   // loaded type to i16 and propagate the "real" type as the memory type.
1695   bool NeedTrunc = false;
1696   if (EltVT.getSizeInBits() < 16) {
     // Select LoadV2 or LoadV4 based on element count; the result VT list
     // has one entry per element plus a trailing chain.
1701   unsigned Opcode = 0;
1708     Opcode = NVPTXISD::LoadV2;
1709     LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
1712     Opcode = NVPTXISD::LoadV4;
1713     EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
1714     LdResVTs = DAG.getVTList(ListVTs, 5);
1719   SmallVector<SDValue, 8> OtherOps;
1721   // Copy regular operands
1722   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
1723     OtherOps.push_back(N->getOperand(i));
1725   LoadSDNode *LD = cast<LoadSDNode>(N);
1727   // The select routine does not have access to the LoadSDNode instance, so
1728   // pass along the extension information
1729   OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType()));
     // Memory VT / operand come from the original load so aliasing info is
     // preserved on the new target node.
1731   SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, &OtherOps[0],
1732                                           OtherOps.size(), LD->getMemoryVT(),
1733                                           LD->getMemOperand());
     // Truncate each widened result back to the real element type, then
     // reassemble the vector.
1735   SmallVector<SDValue, 4> ScalarRes;
1737   for (unsigned i = 0; i < NumElts; ++i) {
1738     SDValue Res = NewLD.getValue(i);
1740       Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
1741     ScalarRes.push_back(Res);
1744   SDValue LoadChain = NewLD.getValue(NumElts);
1747       DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
1749   Results.push_back(BuildVec);
1750   Results.push_back(LoadChain);
// ReplaceINTRINSIC_W_CHAIN - Custom-legalize the results of ldg/ldu global
// load intrinsics: vector results become LDGV2/LDGV4/LDUV2/LDUV4 target
// memory nodes rebuilt with BUILD_VECTOR; scalar i8 results are widened to
// an i16 node whose memory type stays i8.
// NOTE(review): several original lines (default cases, widen assignments,
// closing braces) are elided from this chunk.
1754 static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
1755                                      SmallVectorImpl<SDValue> &Results) {
1756   SDValue Chain = N->getOperand(0);
1757   SDValue Intrin = N->getOperand(1);
1759   // Get the intrinsic ID
1760   unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
1764   case Intrinsic::nvvm_ldg_global_i:
1765   case Intrinsic::nvvm_ldg_global_f:
1766   case Intrinsic::nvvm_ldg_global_p:
1767   case Intrinsic::nvvm_ldu_global_i:
1768   case Intrinsic::nvvm_ldu_global_f:
1769   case Intrinsic::nvvm_ldu_global_p: {
1770     EVT ResVT = N->getValueType(0);
1772     if (ResVT.isVector()) {
1775       unsigned NumElts = ResVT.getVectorNumElements();
1776       EVT EltVT = ResVT.getVectorElementType();
1778       // Since LDU/LDG are target nodes, we cannot rely on DAG type legalization.
1779       // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
1780       // loaded type to i16 and propagate the "real" type as the memory type.
1781       bool NeedTrunc = false;
1782       if (EltVT.getSizeInBits() < 16) {
         // Pick the v2 flavor for 2-element results...
1787       unsigned Opcode = 0;
1797         case Intrinsic::nvvm_ldg_global_i:
1798         case Intrinsic::nvvm_ldg_global_f:
1799         case Intrinsic::nvvm_ldg_global_p:
1800           Opcode = NVPTXISD::LDGV2;
1802         case Intrinsic::nvvm_ldu_global_i:
1803         case Intrinsic::nvvm_ldu_global_f:
1804         case Intrinsic::nvvm_ldu_global_p:
1805           Opcode = NVPTXISD::LDUV2;
1808         LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
         // ...and the v4 flavor otherwise (4-element results per the
         // surrounding, partially-elided switch).
1814         case Intrinsic::nvvm_ldg_global_i:
1815         case Intrinsic::nvvm_ldg_global_f:
1816         case Intrinsic::nvvm_ldg_global_p:
1817           Opcode = NVPTXISD::LDGV4;
1819         case Intrinsic::nvvm_ldu_global_i:
1820         case Intrinsic::nvvm_ldu_global_f:
1821         case Intrinsic::nvvm_ldu_global_p:
1822           Opcode = NVPTXISD::LDUV4;
1825         EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
1826         LdResVTs = DAG.getVTList(ListVTs, 5);
1831       SmallVector<SDValue, 8> OtherOps;
1833       // Copy regular operands
1835       OtherOps.push_back(Chain); // Chain
1836       // Skip operand 1 (intrinsic ID)
1838       for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i)
1839         OtherOps.push_back(N->getOperand(i));
1841       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
         // Preserve the original memory VT/operand on the new target node.
1843       SDValue NewLD = DAG.getMemIntrinsicNode(
1844           Opcode, DL, LdResVTs, &OtherOps[0], OtherOps.size(),
1845           MemSD->getMemoryVT(), MemSD->getMemOperand());
         // Truncate widened results back to the real element type and
         // reassemble the vector.
1847       SmallVector<SDValue, 4> ScalarRes;
1849       for (unsigned i = 0; i < NumElts; ++i) {
1850         SDValue Res = NewLD.getValue(i);
1853             DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
1854         ScalarRes.push_back(Res);
1857       SDValue LoadChain = NewLD.getValue(NumElts);
1860           DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
1862       Results.push_back(BuildVec);
1863       Results.push_back(LoadChain);
       // Scalar path: only i8 needs custom handling here.
1866       assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
1867              "Custom handling of non-i8 ldu/ldg?");
1869       // Just copy all operands as-is
1870       SmallVector<SDValue, 4> Ops;
1871       for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
1872         Ops.push_back(N->getOperand(i));
1874       // Force output to i16
1875       SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);
1877       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
1879       // We make sure the memory type is i8, which will be used during isel
1880       // to select the proper instruction.
1882           DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, &Ops[0],
1883                                   Ops.size(), MVT::i8, MemSD->getMemOperand());
1885       Results.push_back(NewLD.getValue(0));
1886       Results.push_back(NewLD.getValue(1));
1892 void NVPTXTargetLowering::ReplaceNodeResults(
1893 SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
1894 switch (N->getOpcode()) {
1896 report_fatal_error("Unhandled custom legalization");
1898 ReplaceLoadVector(N, DAG, Results);
1900 case ISD::INTRINSIC_W_CHAIN:
1901 ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);