871bc3c9b2fff2f6c5ed91cfe06a2fb763a5834a
[oota-llvm.git] / lib / Target / NVPTX / NVPTXISelLowering.cpp
1 //
2 //                     The LLVM Compiler Infrastructure
3 //
4 // This file is distributed under the University of Illinois Open Source
5 // License. See LICENSE.TXT for details.
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the interfaces that NVPTX uses to lower LLVM code into a
10 // selection DAG.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "NVPTXISelLowering.h"
15 #include "NVPTX.h"
16 #include "NVPTXTargetMachine.h"
17 #include "NVPTXTargetObjectFile.h"
18 #include "NVPTXUtilities.h"
19 #include "llvm/CodeGen/Analysis.h"
20 #include "llvm/CodeGen/MachineFrameInfo.h"
21 #include "llvm/CodeGen/MachineFunction.h"
22 #include "llvm/CodeGen/MachineInstrBuilder.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/GlobalValue.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/Intrinsics.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/MC/MCSectionELF.h"
32 #include "llvm/Support/CallSite.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <sstream>
38
39 #undef DEBUG_TYPE
40 #define DEBUG_TYPE "nvptx-lower"
41
42 using namespace llvm;
43
44 static unsigned int uniqueCallSite = 0;
45
46 static cl::opt<bool> sched4reg(
47     "nvptx-sched4reg",
48     cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false));
49
50 static bool IsPTXVectorType(MVT VT) {
51   switch (VT.SimpleTy) {
52   default:
53     return false;
54   case MVT::v2i1:
55   case MVT::v4i1:
56   case MVT::v2i8:
57   case MVT::v4i8:
58   case MVT::v2i16:
59   case MVT::v4i16:
60   case MVT::v2i32:
61   case MVT::v4i32:
62   case MVT::v2i64:
63   case MVT::v2f32:
64   case MVT::v4f32:
65   case MVT::v2f64:
66     return true;
67   }
68 }
69
70 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
71 /// EVTs that compose it.  Unlike ComputeValueVTs, this will break apart vectors
72 /// into their primitive components.
73 /// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
74 /// same number of types as the Ins/Outs arrays in LowerFormalArguments,
75 /// LowerCall, and LowerReturn.
76 static void ComputePTXValueVTs(const TargetLowering &TLI, Type *Ty,
77                                SmallVectorImpl<EVT> &ValueVTs,
78                                SmallVectorImpl<uint64_t> *Offsets = 0,
79                                uint64_t StartingOffset = 0) {
80   SmallVector<EVT, 16> TempVTs;
81   SmallVector<uint64_t, 16> TempOffsets;
82
83   ComputeValueVTs(TLI, Ty, TempVTs, &TempOffsets, StartingOffset);
84   for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
85     EVT VT = TempVTs[i];
86     uint64_t Off = TempOffsets[i];
87     if (VT.isVector())
88       for (unsigned j = 0, je = VT.getVectorNumElements(); j != je; ++j) {
89         ValueVTs.push_back(VT.getVectorElementType());
90         if (Offsets)
91           Offsets->push_back(Off+j*VT.getVectorElementType().getStoreSize());
92       }
93     else {
94       ValueVTs.push_back(VT);
95       if (Offsets)
96         Offsets->push_back(Off);
97     }
98   }
99 }
100
101 // NVPTXTargetLowering Constructor.
102 NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
103     : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM),
104       nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {
105
106   // always lower memset, memcpy, and memmove intrinsics to load/store
107   // instructions, rather
108   // then generating calls to memset, mempcy or memmove.
109   MaxStoresPerMemset = (unsigned) 0xFFFFFFFF;
110   MaxStoresPerMemcpy = (unsigned) 0xFFFFFFFF;
111   MaxStoresPerMemmove = (unsigned) 0xFFFFFFFF;
112
113   setBooleanContents(ZeroOrNegativeOneBooleanContent);
114
115   // Jump is Expensive. Don't create extra control flow for 'and', 'or'
116   // condition branches.
117   setJumpIsExpensive(true);
118
119   // By default, use the Source scheduling
120   if (sched4reg)
121     setSchedulingPreference(Sched::RegPressure);
122   else
123     setSchedulingPreference(Sched::Source);
124
125   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
126   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
127   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
128   addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
129   addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
130   addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
131
132   // Operations not directly supported by NVPTX.
133   setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
134   setOperationAction(ISD::BR_CC, MVT::f32, Expand);
135   setOperationAction(ISD::BR_CC, MVT::f64, Expand);
136   setOperationAction(ISD::BR_CC, MVT::i1, Expand);
137   setOperationAction(ISD::BR_CC, MVT::i8, Expand);
138   setOperationAction(ISD::BR_CC, MVT::i16, Expand);
139   setOperationAction(ISD::BR_CC, MVT::i32, Expand);
140   setOperationAction(ISD::BR_CC, MVT::i64, Expand);
141   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Expand);
142   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Expand);
143   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Expand);
144   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8, Expand);
145   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
146
147   if (nvptxSubtarget.hasROT64()) {
148     setOperationAction(ISD::ROTL, MVT::i64, Legal);
149     setOperationAction(ISD::ROTR, MVT::i64, Legal);
150   } else {
151     setOperationAction(ISD::ROTL, MVT::i64, Expand);
152     setOperationAction(ISD::ROTR, MVT::i64, Expand);
153   }
154   if (nvptxSubtarget.hasROT32()) {
155     setOperationAction(ISD::ROTL, MVT::i32, Legal);
156     setOperationAction(ISD::ROTR, MVT::i32, Legal);
157   } else {
158     setOperationAction(ISD::ROTL, MVT::i32, Expand);
159     setOperationAction(ISD::ROTR, MVT::i32, Expand);
160   }
161
162   setOperationAction(ISD::ROTL, MVT::i16, Expand);
163   setOperationAction(ISD::ROTR, MVT::i16, Expand);
164   setOperationAction(ISD::ROTL, MVT::i8, Expand);
165   setOperationAction(ISD::ROTR, MVT::i8, Expand);
166   setOperationAction(ISD::BSWAP, MVT::i16, Expand);
167   setOperationAction(ISD::BSWAP, MVT::i32, Expand);
168   setOperationAction(ISD::BSWAP, MVT::i64, Expand);
169
170   // Indirect branch is not supported.
171   // This also disables Jump Table creation.
172   setOperationAction(ISD::BR_JT, MVT::Other, Expand);
173   setOperationAction(ISD::BRIND, MVT::Other, Expand);
174
175   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
176   setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
177
178   // We want to legalize constant related memmove and memcopy
179   // intrinsics.
180   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
181
182   // Turn FP extload into load/fextend
183   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
184   // Turn FP truncstore into trunc + store.
185   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
186
187   // PTX does not support load / store predicate registers
188   setOperationAction(ISD::LOAD, MVT::i1, Custom);
189   setOperationAction(ISD::STORE, MVT::i1, Custom);
190
191   setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
192   setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
193   setTruncStoreAction(MVT::i64, MVT::i1, Expand);
194   setTruncStoreAction(MVT::i32, MVT::i1, Expand);
195   setTruncStoreAction(MVT::i16, MVT::i1, Expand);
196   setTruncStoreAction(MVT::i8, MVT::i1, Expand);
197
198   // This is legal in NVPTX
199   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
200   setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
201
202   // TRAP can be lowered to PTX trap
203   setOperationAction(ISD::TRAP, MVT::Other, Legal);
204
205   // Register custom handling for vector loads/stores
206   for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE;
207        ++i) {
208     MVT VT = (MVT::SimpleValueType) i;
209     if (IsPTXVectorType(VT)) {
210       setOperationAction(ISD::LOAD, VT, Custom);
211       setOperationAction(ISD::STORE, VT, Custom);
212       setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
213     }
214   }
215
216   // Custom handling for i8 intrinsics
217   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
218
219   setOperationAction(ISD::CTLZ, MVT::i16, Legal);
220   setOperationAction(ISD::CTLZ, MVT::i32, Legal);
221   setOperationAction(ISD::CTLZ, MVT::i64, Legal);
222   setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i16, Legal);
223   setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i32, Legal);
224   setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i64, Legal);
225   setOperationAction(ISD::CTTZ, MVT::i16, Expand);
226   setOperationAction(ISD::CTTZ, MVT::i32, Expand);
227   setOperationAction(ISD::CTTZ, MVT::i64, Expand);
228   setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i16, Expand);
229   setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i32, Expand);
230   setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i64, Expand);
231   setOperationAction(ISD::CTPOP, MVT::i16, Legal);
232   setOperationAction(ISD::CTPOP, MVT::i32, Legal);
233   setOperationAction(ISD::CTPOP, MVT::i64, Legal);
234
235   // Now deduce the information based on the above mentioned
236   // actions
237   computeRegisterProperties();
238 }
239
240 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
241   switch (Opcode) {
242   default:
243     return 0;
244   case NVPTXISD::CALL:
245     return "NVPTXISD::CALL";
246   case NVPTXISD::RET_FLAG:
247     return "NVPTXISD::RET_FLAG";
248   case NVPTXISD::Wrapper:
249     return "NVPTXISD::Wrapper";
250   case NVPTXISD::NVBuiltin:
251     return "NVPTXISD::NVBuiltin";
252   case NVPTXISD::DeclareParam:
253     return "NVPTXISD::DeclareParam";
254   case NVPTXISD::DeclareScalarParam:
255     return "NVPTXISD::DeclareScalarParam";
256   case NVPTXISD::DeclareRet:
257     return "NVPTXISD::DeclareRet";
258   case NVPTXISD::DeclareRetParam:
259     return "NVPTXISD::DeclareRetParam";
260   case NVPTXISD::PrintCall:
261     return "NVPTXISD::PrintCall";
262   case NVPTXISD::LoadParam:
263     return "NVPTXISD::LoadParam";
264   case NVPTXISD::LoadParamV2:
265     return "NVPTXISD::LoadParamV2";
266   case NVPTXISD::LoadParamV4:
267     return "NVPTXISD::LoadParamV4";
268   case NVPTXISD::StoreParam:
269     return "NVPTXISD::StoreParam";
270   case NVPTXISD::StoreParamV2:
271     return "NVPTXISD::StoreParamV2";
272   case NVPTXISD::StoreParamV4:
273     return "NVPTXISD::StoreParamV4";
274   case NVPTXISD::StoreParamS32:
275     return "NVPTXISD::StoreParamS32";
276   case NVPTXISD::StoreParamU32:
277     return "NVPTXISD::StoreParamU32";
278   case NVPTXISD::CallArgBegin:
279     return "NVPTXISD::CallArgBegin";
280   case NVPTXISD::CallArg:
281     return "NVPTXISD::CallArg";
282   case NVPTXISD::LastCallArg:
283     return "NVPTXISD::LastCallArg";
284   case NVPTXISD::CallArgEnd:
285     return "NVPTXISD::CallArgEnd";
286   case NVPTXISD::CallVoid:
287     return "NVPTXISD::CallVoid";
288   case NVPTXISD::CallVal:
289     return "NVPTXISD::CallVal";
290   case NVPTXISD::CallSymbol:
291     return "NVPTXISD::CallSymbol";
292   case NVPTXISD::Prototype:
293     return "NVPTXISD::Prototype";
294   case NVPTXISD::MoveParam:
295     return "NVPTXISD::MoveParam";
296   case NVPTXISD::StoreRetval:
297     return "NVPTXISD::StoreRetval";
298   case NVPTXISD::StoreRetvalV2:
299     return "NVPTXISD::StoreRetvalV2";
300   case NVPTXISD::StoreRetvalV4:
301     return "NVPTXISD::StoreRetvalV4";
302   case NVPTXISD::PseudoUseParam:
303     return "NVPTXISD::PseudoUseParam";
304   case NVPTXISD::RETURN:
305     return "NVPTXISD::RETURN";
306   case NVPTXISD::CallSeqBegin:
307     return "NVPTXISD::CallSeqBegin";
308   case NVPTXISD::CallSeqEnd:
309     return "NVPTXISD::CallSeqEnd";
310   case NVPTXISD::LoadV2:
311     return "NVPTXISD::LoadV2";
312   case NVPTXISD::LoadV4:
313     return "NVPTXISD::LoadV4";
314   case NVPTXISD::LDGV2:
315     return "NVPTXISD::LDGV2";
316   case NVPTXISD::LDGV4:
317     return "NVPTXISD::LDGV4";
318   case NVPTXISD::LDUV2:
319     return "NVPTXISD::LDUV2";
320   case NVPTXISD::LDUV4:
321     return "NVPTXISD::LDUV4";
322   case NVPTXISD::StoreV2:
323     return "NVPTXISD::StoreV2";
324   case NVPTXISD::StoreV4:
325     return "NVPTXISD::StoreV4";
326   }
327 }
328
329 bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const {
330   return VT == MVT::i1;
331 }
332
333 SDValue
334 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
335   SDLoc dl(Op);
336   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
337   Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
338   return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
339 }
340
341 /*
342 std::string NVPTXTargetLowering::getPrototype(
343     Type *retTy, const ArgListTy &Args,
344     const SmallVectorImpl<ISD::OutputArg> &Outs, unsigned retAlignment) const {
345
346   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
347
348   std::stringstream O;
349   O << "prototype_" << uniqueCallSite << " : .callprototype ";
350
351   if (retTy->getTypeID() == Type::VoidTyID)
352     O << "()";
353   else {
354     O << "(";
355     if (isABI) {
356       if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
357         unsigned size = 0;
358         if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
359           size = ITy->getBitWidth();
360           if (size < 32)
361             size = 32;
362         } else {
363           assert(retTy->isFloatingPointTy() &&
364                  "Floating point type expected here");
365           size = retTy->getPrimitiveSizeInBits();
366         }
367
368         O << ".param .b" << size << " _";
369       } else if (isa<PointerType>(retTy))
370         O << ".param .b" << getPointerTy().getSizeInBits() << " _";
371       else {
372         if ((retTy->getTypeID() == Type::StructTyID) ||
373             isa<VectorType>(retTy)) {
374           SmallVector<EVT, 16> vtparts;
375           ComputeValueVTs(*this, retTy, vtparts);
376           unsigned totalsz = 0;
377           for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
378             unsigned elems = 1;
379             EVT elemtype = vtparts[i];
380             if (vtparts[i].isVector()) {
381               elems = vtparts[i].getVectorNumElements();
382               elemtype = vtparts[i].getVectorElementType();
383             }
384             for (unsigned j = 0, je = elems; j != je; ++j) {
385               unsigned sz = elemtype.getSizeInBits();
386               if (elemtype.isInteger() && (sz < 8))
387                 sz = 8;
388               totalsz += sz / 8;
389             }
390           }
391           O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
392         } else {
393           assert(false && "Unknown return type");
394         }
395       }
396     } else {
397       SmallVector<EVT, 16> vtparts;
398       ComputeValueVTs(*this, retTy, vtparts);
399       unsigned idx = 0;
400       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
401         unsigned elems = 1;
402         EVT elemtype = vtparts[i];
403         if (vtparts[i].isVector()) {
404           elems = vtparts[i].getVectorNumElements();
405           elemtype = vtparts[i].getVectorElementType();
406         }
407
408         for (unsigned j = 0, je = elems; j != je; ++j) {
409           unsigned sz = elemtype.getSizeInBits();
410           if (elemtype.isInteger() && (sz < 32))
411             sz = 32;
412           O << ".reg .b" << sz << " _";
413           if (j < je - 1)
414             O << ", ";
415           ++idx;
416         }
417         if (i < e - 1)
418           O << ", ";
419       }
420     }
421     O << ") ";
422   }
423   O << "_ (";
424
425   bool first = true;
426   MVT thePointerTy = getPointerTy();
427
428   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
429     const Type *Ty = Args[i].Ty;
430     if (!first) {
431       O << ", ";
432     }
433     first = false;
434
435     if (Outs[i].Flags.isByVal() == false) {
436       unsigned sz = 0;
437       if (isa<IntegerType>(Ty)) {
438         sz = cast<IntegerType>(Ty)->getBitWidth();
439         if (sz < 32)
440           sz = 32;
441       } else if (isa<PointerType>(Ty))
442         sz = thePointerTy.getSizeInBits();
443       else
444         sz = Ty->getPrimitiveSizeInBits();
445       if (isABI)
446         O << ".param .b" << sz << " ";
447       else
448         O << ".reg .b" << sz << " ";
449       O << "_";
450       continue;
451     }
452     const PointerType *PTy = dyn_cast<PointerType>(Ty);
453     assert(PTy && "Param with byval attribute should be a pointer type");
454     Type *ETy = PTy->getElementType();
455
456     if (isABI) {
457       unsigned align = Outs[i].Flags.getByValAlign();
458       unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
459       O << ".param .align " << align << " .b8 ";
460       O << "_";
461       O << "[" << sz << "]";
462       continue;
463     } else {
464       SmallVector<EVT, 16> vtparts;
465       ComputeValueVTs(*this, ETy, vtparts);
466       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
467         unsigned elems = 1;
468         EVT elemtype = vtparts[i];
469         if (vtparts[i].isVector()) {
470           elems = vtparts[i].getVectorNumElements();
471           elemtype = vtparts[i].getVectorElementType();
472         }
473
474         for (unsigned j = 0, je = elems; j != je; ++j) {
475           unsigned sz = elemtype.getSizeInBits();
476           if (elemtype.isInteger() && (sz < 32))
477             sz = 32;
478           O << ".reg .b" << sz << " ";
479           O << "_";
480           if (j < je - 1)
481             O << ", ";
482         }
483         if (i < e - 1)
484           O << ", ";
485       }
486       continue;
487     }
488   }
489   O << ");";
490   return O.str();
491 }*/
492
493 std::string
494 NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
495                                   const SmallVectorImpl<ISD::OutputArg> &Outs,
496                                   unsigned retAlignment,
497                                   const ImmutableCallSite *CS) const {
498
499   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
500   assert(isABI && "Non-ABI compilation is not supported");
501   if (!isABI)
502     return "";
503
504   std::stringstream O;
505   O << "prototype_" << uniqueCallSite << " : .callprototype ";
506
507   if (retTy->getTypeID() == Type::VoidTyID) {
508     O << "()";
509   } else {
510     O << "(";
511     if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
512       unsigned size = 0;
513       if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
514         size = ITy->getBitWidth();
515         if (size < 32)
516           size = 32;
517       } else {
518         assert(retTy->isFloatingPointTy() &&
519                "Floating point type expected here");
520         size = retTy->getPrimitiveSizeInBits();
521       }
522
523       O << ".param .b" << size << " _";
524     } else if (isa<PointerType>(retTy)) {
525       O << ".param .b" << getPointerTy().getSizeInBits() << " _";
526     } else {
527       if ((retTy->getTypeID() == Type::StructTyID) || isa<VectorType>(retTy)) {
528         SmallVector<EVT, 16> vtparts;
529         ComputeValueVTs(*this, retTy, vtparts);
530         unsigned totalsz = 0;
531         for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
532           unsigned elems = 1;
533           EVT elemtype = vtparts[i];
534           if (vtparts[i].isVector()) {
535             elems = vtparts[i].getVectorNumElements();
536             elemtype = vtparts[i].getVectorElementType();
537           }
538           // TODO: no need to loop
539           for (unsigned j = 0, je = elems; j != je; ++j) {
540             unsigned sz = elemtype.getSizeInBits();
541             if (elemtype.isInteger() && (sz < 8))
542               sz = 8;
543             totalsz += sz / 8;
544           }
545         }
546         O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
547       } else {
548         assert(false && "Unknown return type");
549       }
550     }
551     O << ") ";
552   }
553   O << "_ (";
554
555   bool first = true;
556   MVT thePointerTy = getPointerTy();
557
558   unsigned OIdx = 0;
559   for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
560     Type *Ty = Args[i].Ty;
561     if (!first) {
562       O << ", ";
563     }
564     first = false;
565
566     if (Outs[OIdx].Flags.isByVal() == false) {
567       if (Ty->isAggregateType() || Ty->isVectorTy()) {
568         unsigned align = 0;
569         const CallInst *CallI = cast<CallInst>(CS->getInstruction());
570         const DataLayout *TD = getDataLayout();
571         // +1 because index 0 is reserved for return type alignment
572         if (!llvm::getAlign(*CallI, i + 1, align))
573           align = TD->getABITypeAlignment(Ty);
574         unsigned sz = TD->getTypeAllocSize(Ty);
575         O << ".param .align " << align << " .b8 ";
576         O << "_";
577         O << "[" << sz << "]";
578         // update the index for Outs
579         SmallVector<EVT, 16> vtparts;
580         ComputeValueVTs(*this, Ty, vtparts);
581         if (unsigned len = vtparts.size())
582           OIdx += len - 1;
583         continue;
584       }
585       assert(getValueType(Ty) == Outs[OIdx].VT &&
586              "type mismatch between callee prototype and arguments");
587       // scalar type
588       unsigned sz = 0;
589       if (isa<IntegerType>(Ty)) {
590         sz = cast<IntegerType>(Ty)->getBitWidth();
591         if (sz < 32)
592           sz = 32;
593       } else if (isa<PointerType>(Ty))
594         sz = thePointerTy.getSizeInBits();
595       else
596         sz = Ty->getPrimitiveSizeInBits();
597       O << ".param .b" << sz << " ";
598       O << "_";
599       continue;
600     }
601     const PointerType *PTy = dyn_cast<PointerType>(Ty);
602     assert(PTy && "Param with byval attribute should be a pointer type");
603     Type *ETy = PTy->getElementType();
604
605     unsigned align = Outs[OIdx].Flags.getByValAlign();
606     unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
607     O << ".param .align " << align << " .b8 ";
608     O << "_";
609     O << "[" << sz << "]";
610   }
611   O << ");";
612   return O.str();
613 }
614
615 unsigned
616 NVPTXTargetLowering::getArgumentAlignment(SDValue Callee,
617                                           const ImmutableCallSite *CS,
618                                           Type *Ty,
619                                           unsigned Idx) const {
620   const DataLayout *TD = getDataLayout();
621   unsigned align = 0;
622   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
623
624   if (Func) { // direct call
625     assert(CS->getCalledFunction() &&
626            "direct call cannot find callee");
627     if (!llvm::getAlign(*(CS->getCalledFunction()), Idx, align))
628       align = TD->getABITypeAlignment(Ty);
629   }
630   else { // indirect call
631     const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction());
632     if (!llvm::getAlign(*CallI, Idx, align))
633       align = TD->getABITypeAlignment(Ty);
634   }
635
636   return align;
637 }
638
639 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
640                                        SmallVectorImpl<SDValue> &InVals) const {
641   SelectionDAG &DAG = CLI.DAG;
642   SDLoc dl = CLI.DL;
643   SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
644   SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
645   SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
646   SDValue Chain = CLI.Chain;
647   SDValue Callee = CLI.Callee;
648   bool &isTailCall = CLI.IsTailCall;
649   ArgListTy &Args = CLI.Args;
650   Type *retTy = CLI.RetTy;
651   ImmutableCallSite *CS = CLI.CS;
652
653   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
654   assert(isABI && "Non-ABI compilation is not supported");
655   if (!isABI)
656     return Chain;
657   const DataLayout *TD = getDataLayout();
658   MachineFunction &MF = DAG.getMachineFunction();
659   const Function *F = MF.getFunction();
660
661   SDValue tempChain = Chain;
662   Chain =
663       DAG.getCALLSEQ_START(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
664                            dl);
665   SDValue InFlag = Chain.getValue(1);
666
667   unsigned paramCount = 0;
668   // Args.size() and Outs.size() need not match.
669   // Outs.size() will be larger
670   //   * if there is an aggregate argument with multiple fields (each field
671   //     showing up separately in Outs)
672   //   * if there is a vector argument with more than typical vector-length
673   //     elements (generally if more than 4) where each vector element is
674   //     individually present in Outs.
675   // So a different index should be used for indexing into Outs/OutVals.
676   // See similar issue in LowerFormalArguments.
677   unsigned OIdx = 0;
678   // Declare the .params or .reg need to pass values
679   // to the function
680   for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
681     EVT VT = Outs[OIdx].VT;
682     Type *Ty = Args[i].Ty;
683
684     if (Outs[OIdx].Flags.isByVal() == false) {
685       if (Ty->isAggregateType()) {
686         // aggregate
687         SmallVector<EVT, 16> vtparts;
688         ComputeValueVTs(*this, Ty, vtparts);
689
690         unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
691         // declare .param .align <align> .b8 .param<n>[<size>];
692         unsigned sz = TD->getTypeAllocSize(Ty);
693         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
694         SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
695                                       DAG.getConstant(paramCount, MVT::i32),
696                                       DAG.getConstant(sz, MVT::i32), InFlag };
697         Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
698                             DeclareParamOps, 5);
699         InFlag = Chain.getValue(1);
700         unsigned curOffset = 0;
701         for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
702           unsigned elems = 1;
703           EVT elemtype = vtparts[j];
704           if (vtparts[j].isVector()) {
705             elems = vtparts[j].getVectorNumElements();
706             elemtype = vtparts[j].getVectorElementType();
707           }
708           for (unsigned k = 0, ke = elems; k != ke; ++k) {
709             unsigned sz = elemtype.getSizeInBits();
710             if (elemtype.isInteger() && (sz < 8))
711               sz = 8;
712             SDValue StVal = OutVals[OIdx];
713             if (elemtype.getSizeInBits() < 16) {
714               StVal = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::i16, StVal);
715             }
716             SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
717             SDValue CopyParamOps[] = { Chain,
718                                        DAG.getConstant(paramCount, MVT::i32),
719                                        DAG.getConstant(curOffset, MVT::i32),
720                                        StVal, InFlag };
721             Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
722                                             CopyParamVTs, &CopyParamOps[0], 5,
723                                             elemtype, MachinePointerInfo());
724             InFlag = Chain.getValue(1);
725             curOffset += sz / 8;
726             ++OIdx;
727           }
728         }
729         if (vtparts.size() > 0)
730           --OIdx;
731         ++paramCount;
732         continue;
733       }
734       if (Ty->isVectorTy()) {
735         EVT ObjectVT = getValueType(Ty);
736         unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
737         // declare .param .align <align> .b8 .param<n>[<size>];
738         unsigned sz = TD->getTypeAllocSize(Ty);
739         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
740         SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
741                                       DAG.getConstant(paramCount, MVT::i32),
742                                       DAG.getConstant(sz, MVT::i32), InFlag };
743         Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
744                             DeclareParamOps, 5);
745         InFlag = Chain.getValue(1);
746         unsigned NumElts = ObjectVT.getVectorNumElements();
747         EVT EltVT = ObjectVT.getVectorElementType();
748         EVT MemVT = EltVT;
749         bool NeedExtend = false;
750         if (EltVT.getSizeInBits() < 16) {
751           NeedExtend = true;
752           EltVT = MVT::i16;
753         }
754
755         // V1 store
756         if (NumElts == 1) {
757           SDValue Elt = OutVals[OIdx++];
758           if (NeedExtend)
759             Elt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt);
760
761           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
762           SDValue CopyParamOps[] = { Chain,
763                                      DAG.getConstant(paramCount, MVT::i32),
764                                      DAG.getConstant(0, MVT::i32), Elt,
765                                      InFlag };
766           Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
767                                           CopyParamVTs, &CopyParamOps[0], 5,
768                                           MemVT, MachinePointerInfo());
769           InFlag = Chain.getValue(1);
770         } else if (NumElts == 2) {
771           SDValue Elt0 = OutVals[OIdx++];
772           SDValue Elt1 = OutVals[OIdx++];
773           if (NeedExtend) {
774             Elt0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt0);
775             Elt1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt1);
776           }
777
778           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
779           SDValue CopyParamOps[] = { Chain,
780                                      DAG.getConstant(paramCount, MVT::i32),
781                                      DAG.getConstant(0, MVT::i32), Elt0, Elt1,
782                                      InFlag };
783           Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParamV2, dl,
784                                           CopyParamVTs, &CopyParamOps[0], 6,
785                                           MemVT, MachinePointerInfo());
786           InFlag = Chain.getValue(1);
787         } else {
788           unsigned curOffset = 0;
789           // V4 stores
790           // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
791           // the
792           // vector will be expanded to a power of 2 elements, so we know we can
793           // always round up to the next multiple of 4 when creating the vector
794           // stores.
795           // e.g.  4 elem => 1 st.v4
796           //       6 elem => 2 st.v4
797           //       8 elem => 2 st.v4
798           //      11 elem => 3 st.v4
799           unsigned VecSize = 4;
800           if (EltVT.getSizeInBits() == 64)
801             VecSize = 2;
802
803           // This is potentially only part of a vector, so assume all elements
804           // are packed together.
805           unsigned PerStoreOffset = MemVT.getStoreSizeInBits() / 8 * VecSize;
806
807           for (unsigned i = 0; i < NumElts; i += VecSize) {
808             // Get values
809             SDValue StoreVal;
810             SmallVector<SDValue, 8> Ops;
811             Ops.push_back(Chain);
812             Ops.push_back(DAG.getConstant(paramCount, MVT::i32));
813             Ops.push_back(DAG.getConstant(curOffset, MVT::i32));
814
815             unsigned Opc = NVPTXISD::StoreParamV2;
816
817             StoreVal = OutVals[OIdx++];
818             if (NeedExtend)
819               StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
820             Ops.push_back(StoreVal);
821
822             if (i + 1 < NumElts) {
823               StoreVal = OutVals[OIdx++];
824               if (NeedExtend)
825                 StoreVal =
826                     DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
827             } else {
828               StoreVal = DAG.getUNDEF(EltVT);
829             }
830             Ops.push_back(StoreVal);
831
832             if (VecSize == 4) {
833               Opc = NVPTXISD::StoreParamV4;
834               if (i + 2 < NumElts) {
835                 StoreVal = OutVals[OIdx++];
836                 if (NeedExtend)
837                   StoreVal =
838                       DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
839               } else {
840                 StoreVal = DAG.getUNDEF(EltVT);
841               }
842               Ops.push_back(StoreVal);
843
844               if (i + 3 < NumElts) {
845                 StoreVal = OutVals[OIdx++];
846                 if (NeedExtend)
847                   StoreVal =
848                       DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
849               } else {
850                 StoreVal = DAG.getUNDEF(EltVT);
851               }
852               Ops.push_back(StoreVal);
853             }
854
855             SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
856             Chain = DAG.getMemIntrinsicNode(Opc, dl, CopyParamVTs, &Ops[0],
857                                             Ops.size(), MemVT,
858                                             MachinePointerInfo());
859             InFlag = Chain.getValue(1);
860             curOffset += PerStoreOffset;
861           }
862         }
863         ++paramCount;
864         --OIdx;
865         continue;
866       }
867       // Plain scalar
868       // for ABI,    declare .param .b<size> .param<n>;
869       unsigned sz = VT.getSizeInBits();
870       bool needExtend = false;
871       if (VT.isInteger()) {
872         if (sz < 16)
873           needExtend = true;
874         if (sz < 32)
875           sz = 32;
876       }
877       SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
878       SDValue DeclareParamOps[] = { Chain,
879                                     DAG.getConstant(paramCount, MVT::i32),
880                                     DAG.getConstant(sz, MVT::i32),
881                                     DAG.getConstant(0, MVT::i32), InFlag };
882       Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
883                           DeclareParamOps, 5);
884       InFlag = Chain.getValue(1);
885       SDValue OutV = OutVals[OIdx];
886       if (needExtend) {
887         // zext/sext i1 to i16
888         unsigned opc = ISD::ZERO_EXTEND;
889         if (Outs[OIdx].Flags.isSExt())
890           opc = ISD::SIGN_EXTEND;
891         OutV = DAG.getNode(opc, dl, MVT::i16, OutV);
892       }
893       SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
894       SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
895                                  DAG.getConstant(0, MVT::i32), OutV, InFlag };
896
897       unsigned opcode = NVPTXISD::StoreParam;
898       if (Outs[OIdx].Flags.isZExt())
899         opcode = NVPTXISD::StoreParamU32;
900       else if (Outs[OIdx].Flags.isSExt())
901         opcode = NVPTXISD::StoreParamS32;
902       Chain = DAG.getMemIntrinsicNode(opcode, dl, CopyParamVTs, CopyParamOps, 5,
903                                       VT, MachinePointerInfo());
904
905       InFlag = Chain.getValue(1);
906       ++paramCount;
907       continue;
908     }
909     // struct or vector
910     SmallVector<EVT, 16> vtparts;
911     const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
912     assert(PTy && "Type of a byval parameter should be pointer");
913     ComputeValueVTs(*this, PTy->getElementType(), vtparts);
914
915     // declare .param .align <align> .b8 .param<n>[<size>];
916     unsigned sz = Outs[OIdx].Flags.getByValSize();
917     SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
918     // The ByValAlign in the Outs[OIdx].Flags is alway set at this point,
919     // so we don't need to worry about natural alignment or not.
920     // See TargetLowering::LowerCallTo().
921     SDValue DeclareParamOps[] = {
922       Chain, DAG.getConstant(Outs[OIdx].Flags.getByValAlign(), MVT::i32),
923       DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32),
924       InFlag
925     };
926     Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
927                         DeclareParamOps, 5);
928     InFlag = Chain.getValue(1);
929     unsigned curOffset = 0;
930     for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
931       unsigned elems = 1;
932       EVT elemtype = vtparts[j];
933       if (vtparts[j].isVector()) {
934         elems = vtparts[j].getVectorNumElements();
935         elemtype = vtparts[j].getVectorElementType();
936       }
937       for (unsigned k = 0, ke = elems; k != ke; ++k) {
938         unsigned sz = elemtype.getSizeInBits();
939         if (elemtype.isInteger() && (sz < 8))
940           sz = 8;
941         SDValue srcAddr =
942             DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
943                         DAG.getConstant(curOffset, getPointerTy()));
944         SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
945                                      MachinePointerInfo(), false, false, false,
946                                      0);
947         if (elemtype.getSizeInBits() < 16) {
948           theVal = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::i16, theVal);
949         }
950         SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
951         SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
952                                    DAG.getConstant(curOffset, MVT::i32), theVal,
953                                    InFlag };
954         Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
955                                         CopyParamOps, 5, elemtype,
956                                         MachinePointerInfo());
957
958         InFlag = Chain.getValue(1);
959         curOffset += sz / 8;
960       }
961     }
962     ++paramCount;
963   }
964
965   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
966   unsigned retAlignment = 0;
967
968   // Handle Result
969   if (Ins.size() > 0) {
970     SmallVector<EVT, 16> resvtparts;
971     ComputeValueVTs(*this, retTy, resvtparts);
972
973     // Declare
974     //  .param .align 16 .b8 retval0[<size-in-bytes>], or
975     //  .param .b<size-in-bits> retval0
976     unsigned resultsz = TD->getTypeAllocSizeInBits(retTy);
977     if (retTy->isPrimitiveType() || retTy->isIntegerTy() ||
978         retTy->isPointerTy()) {
979       // Scalar needs to be at least 32bit wide
980       if (resultsz < 32)
981         resultsz = 32;
982       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
983       SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
984                                   DAG.getConstant(resultsz, MVT::i32),
985                                   DAG.getConstant(0, MVT::i32), InFlag };
986       Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
987                           DeclareRetOps, 5);
988       InFlag = Chain.getValue(1);
989     } else {
990       retAlignment = getArgumentAlignment(Callee, CS, retTy, 0);
991       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
992       SDValue DeclareRetOps[] = { Chain,
993                                   DAG.getConstant(retAlignment, MVT::i32),
994                                   DAG.getConstant(resultsz / 8, MVT::i32),
995                                   DAG.getConstant(0, MVT::i32), InFlag };
996       Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
997                           DeclareRetOps, 5);
998       InFlag = Chain.getValue(1);
999     }
1000   }
1001
1002   if (!Func) {
1003     // This is indirect function call case : PTX requires a prototype of the
1004     // form
1005     // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
1006     // to be emitted, and the label has to used as the last arg of call
1007     // instruction.
1008     // The prototype is embedded in a string and put as the operand for an
1009     // INLINEASM SDNode.
1010     SDVTList InlineAsmVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1011     std::string proto_string =
1012         getPrototype(retTy, Args, Outs, retAlignment, CS);
1013     const char *asmstr = nvTM->getManagedStrPool()
1014         ->getManagedString(proto_string.c_str())->c_str();
1015     SDValue InlineAsmOps[] = {
1016       Chain, DAG.getTargetExternalSymbol(asmstr, getPointerTy()),
1017       DAG.getMDNode(0), DAG.getTargetConstant(0, MVT::i32), InFlag
1018     };
1019     Chain = DAG.getNode(ISD::INLINEASM, dl, InlineAsmVTs, InlineAsmOps, 5);
1020     InFlag = Chain.getValue(1);
1021   }
1022   // Op to just print "call"
1023   SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1024   SDValue PrintCallOps[] = {
1025     Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, MVT::i32), InFlag
1026   };
1027   Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
1028                       dl, PrintCallVTs, PrintCallOps, 3);
1029   InFlag = Chain.getValue(1);
1030
1031   // Ops to print out the function name
1032   SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1033   SDValue CallVoidOps[] = { Chain, Callee, InFlag };
1034   Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3);
1035   InFlag = Chain.getValue(1);
1036
1037   // Ops to print out the param list
1038   SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1039   SDValue CallArgBeginOps[] = { Chain, InFlag };
1040   Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
1041                       CallArgBeginOps, 2);
1042   InFlag = Chain.getValue(1);
1043
1044   for (unsigned i = 0, e = paramCount; i != e; ++i) {
1045     unsigned opcode;
1046     if (i == (e - 1))
1047       opcode = NVPTXISD::LastCallArg;
1048     else
1049       opcode = NVPTXISD::CallArg;
1050     SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1051     SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
1052                              DAG.getConstant(i, MVT::i32), InFlag };
1053     Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4);
1054     InFlag = Chain.getValue(1);
1055   }
1056   SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1057   SDValue CallArgEndOps[] = { Chain, DAG.getConstant(Func ? 1 : 0, MVT::i32),
1058                               InFlag };
1059   Chain =
1060       DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps, 3);
1061   InFlag = Chain.getValue(1);
1062
1063   if (!Func) {
1064     SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1065     SDValue PrototypeOps[] = { Chain, DAG.getConstant(uniqueCallSite, MVT::i32),
1066                                InFlag };
1067     Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3);
1068     InFlag = Chain.getValue(1);
1069   }
1070
1071   // Generate loads from param memory/moves from registers for result
1072   if (Ins.size() > 0) {
1073     unsigned resoffset = 0;
1074     if (retTy && retTy->isVectorTy()) {
1075       EVT ObjectVT = getValueType(retTy);
1076       unsigned NumElts = ObjectVT.getVectorNumElements();
1077       EVT EltVT = ObjectVT.getVectorElementType();
1078       assert(nvTM->getTargetLowering()->getNumRegisters(F->getContext(),
1079                                                         ObjectVT) == NumElts &&
1080              "Vector was not scalarized");
1081       unsigned sz = EltVT.getSizeInBits();
1082       bool needTruncate = sz < 16 ? true : false;
1083
1084       if (NumElts == 1) {
1085         // Just a simple load
1086         std::vector<EVT> LoadRetVTs;
1087         if (needTruncate) {
1088           // If loading i1 result, generate
1089           //   load i16
1090           //   trunc i16 to i1
1091           LoadRetVTs.push_back(MVT::i16);
1092         } else
1093           LoadRetVTs.push_back(EltVT);
1094         LoadRetVTs.push_back(MVT::Other);
1095         LoadRetVTs.push_back(MVT::Glue);
1096         std::vector<SDValue> LoadRetOps;
1097         LoadRetOps.push_back(Chain);
1098         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1099         LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
1100         LoadRetOps.push_back(InFlag);
1101         SDValue retval = DAG.getMemIntrinsicNode(
1102             NVPTXISD::LoadParam, dl,
1103             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
1104             LoadRetOps.size(), EltVT, MachinePointerInfo());
1105         Chain = retval.getValue(1);
1106         InFlag = retval.getValue(2);
1107         SDValue Ret0 = retval;
1108         if (needTruncate)
1109           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Ret0);
1110         InVals.push_back(Ret0);
1111       } else if (NumElts == 2) {
1112         // LoadV2
1113         std::vector<EVT> LoadRetVTs;
1114         if (needTruncate) {
1115           // If loading i1 result, generate
1116           //   load i16
1117           //   trunc i16 to i1
1118           LoadRetVTs.push_back(MVT::i16);
1119           LoadRetVTs.push_back(MVT::i16);
1120         } else {
1121           LoadRetVTs.push_back(EltVT);
1122           LoadRetVTs.push_back(EltVT);
1123         }
1124         LoadRetVTs.push_back(MVT::Other);
1125         LoadRetVTs.push_back(MVT::Glue);
1126         std::vector<SDValue> LoadRetOps;
1127         LoadRetOps.push_back(Chain);
1128         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1129         LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
1130         LoadRetOps.push_back(InFlag);
1131         SDValue retval = DAG.getMemIntrinsicNode(
1132             NVPTXISD::LoadParamV2, dl,
1133             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
1134             LoadRetOps.size(), EltVT, MachinePointerInfo());
1135         Chain = retval.getValue(2);
1136         InFlag = retval.getValue(3);
1137         SDValue Ret0 = retval.getValue(0);
1138         SDValue Ret1 = retval.getValue(1);
1139         if (needTruncate) {
1140           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret0);
1141           InVals.push_back(Ret0);
1142           Ret1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret1);
1143           InVals.push_back(Ret1);
1144         } else {
1145           InVals.push_back(Ret0);
1146           InVals.push_back(Ret1);
1147         }
1148       } else {
1149         // Split into N LoadV4
1150         unsigned Ofst = 0;
1151         unsigned VecSize = 4;
1152         unsigned Opc = NVPTXISD::LoadParamV4;
1153         if (EltVT.getSizeInBits() == 64) {
1154           VecSize = 2;
1155           Opc = NVPTXISD::LoadParamV2;
1156         }
1157         EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
1158         for (unsigned i = 0; i < NumElts; i += VecSize) {
1159           SmallVector<EVT, 8> LoadRetVTs;
1160           if (needTruncate) {
1161             // If loading i1 result, generate
1162             //   load i16
1163             //   trunc i16 to i1
1164             for (unsigned j = 0; j < VecSize; ++j)
1165               LoadRetVTs.push_back(MVT::i16);
1166           } else {
1167             for (unsigned j = 0; j < VecSize; ++j)
1168               LoadRetVTs.push_back(EltVT);
1169           }
1170           LoadRetVTs.push_back(MVT::Other);
1171           LoadRetVTs.push_back(MVT::Glue);
1172           SmallVector<SDValue, 4> LoadRetOps;
1173           LoadRetOps.push_back(Chain);
1174           LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1175           LoadRetOps.push_back(DAG.getConstant(Ofst, MVT::i32));
1176           LoadRetOps.push_back(InFlag);
1177           SDValue retval = DAG.getMemIntrinsicNode(
1178               Opc, dl, DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()),
1179               &LoadRetOps[0], LoadRetOps.size(), EltVT, MachinePointerInfo());
1180           if (VecSize == 2) {
1181             Chain = retval.getValue(2);
1182             InFlag = retval.getValue(3);
1183           } else {
1184             Chain = retval.getValue(4);
1185             InFlag = retval.getValue(5);
1186           }
1187
1188           for (unsigned j = 0; j < VecSize; ++j) {
1189             if (i + j >= NumElts)
1190               break;
1191             SDValue Elt = retval.getValue(j);
1192             if (needTruncate)
1193               Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
1194             InVals.push_back(Elt);
1195           }
1196           Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1197         }
1198       }
1199     } else {
1200       SmallVector<EVT, 16> VTs;
1201       ComputePTXValueVTs(*this, retTy, VTs);
1202       assert(VTs.size() == Ins.size() && "Bad value decomposition");
1203       for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
1204         unsigned sz = VTs[i].getSizeInBits();
1205         bool needTruncate = sz < 8 ? true : false;
1206         if (VTs[i].isInteger() && (sz < 8))
1207           sz = 8;
1208
1209         SmallVector<EVT, 4> LoadRetVTs;
1210         EVT TheLoadType = VTs[i];
1211         if (retTy->isIntegerTy() &&
1212             TD->getTypeAllocSizeInBits(retTy) < 32) {
1213           // This is for integer types only, and specifically not for
1214           // aggregates.
1215           LoadRetVTs.push_back(MVT::i32);
1216           TheLoadType = MVT::i32;
1217         } else if (sz < 16) {
1218           // If loading i1/i8 result, generate
1219           //   load i8 (-> i16)
1220           //   trunc i16 to i1/i8
1221           LoadRetVTs.push_back(MVT::i16);
1222         } else
1223           LoadRetVTs.push_back(Ins[i].VT);
1224         LoadRetVTs.push_back(MVT::Other);
1225         LoadRetVTs.push_back(MVT::Glue);
1226
1227         SmallVector<SDValue, 4> LoadRetOps;
1228         LoadRetOps.push_back(Chain);
1229         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1230         LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32));
1231         LoadRetOps.push_back(InFlag);
1232         SDValue retval = DAG.getMemIntrinsicNode(
1233             NVPTXISD::LoadParam, dl,
1234             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
1235             LoadRetOps.size(), TheLoadType, MachinePointerInfo());
1236         Chain = retval.getValue(1);
1237         InFlag = retval.getValue(2);
1238         SDValue Ret0 = retval.getValue(0);
1239         if (needTruncate)
1240           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0);
1241         InVals.push_back(Ret0);
1242         resoffset += sz / 8;
1243       }
1244     }
1245   }
1246
1247   Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
1248                              DAG.getIntPtrConstant(uniqueCallSite + 1, true),
1249                              InFlag, dl);
1250   uniqueCallSite++;
1251
1252   // set isTailCall to false for now, until we figure out how to express
1253   // tail call optimization in PTX
1254   isTailCall = false;
1255   return Chain;
1256 }
1257
1258 // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
1259 // (see LegalizeDAG.cpp). This is slow and uses local memory.
1260 // We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
1261 SDValue
1262 NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
1263   SDNode *Node = Op.getNode();
1264   SDLoc dl(Node);
1265   SmallVector<SDValue, 8> Ops;
1266   unsigned NumOperands = Node->getNumOperands();
1267   for (unsigned i = 0; i < NumOperands; ++i) {
1268     SDValue SubOp = Node->getOperand(i);
1269     EVT VVT = SubOp.getNode()->getValueType(0);
1270     EVT EltVT = VVT.getVectorElementType();
1271     unsigned NumSubElem = VVT.getVectorNumElements();
1272     for (unsigned j = 0; j < NumSubElem; ++j) {
1273       Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
1274                                 DAG.getIntPtrConstant(j)));
1275     }
1276   }
1277   return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0), &Ops[0],
1278                      Ops.size());
1279 }
1280
1281 SDValue
1282 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
1283   switch (Op.getOpcode()) {
1284   case ISD::RETURNADDR:
1285     return SDValue();
1286   case ISD::FRAMEADDR:
1287     return SDValue();
1288   case ISD::GlobalAddress:
1289     return LowerGlobalAddress(Op, DAG);
1290   case ISD::INTRINSIC_W_CHAIN:
1291     return Op;
1292   case ISD::BUILD_VECTOR:
1293   case ISD::EXTRACT_SUBVECTOR:
1294     return Op;
1295   case ISD::CONCAT_VECTORS:
1296     return LowerCONCAT_VECTORS(Op, DAG);
1297   case ISD::STORE:
1298     return LowerSTORE(Op, DAG);
1299   case ISD::LOAD:
1300     return LowerLOAD(Op, DAG);
1301   default:
1302     llvm_unreachable("Custom lowering not defined for operation");
1303   }
1304 }
1305
1306 SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
1307   if (Op.getValueType() == MVT::i1)
1308     return LowerLOADi1(Op, DAG);
1309   else
1310     return SDValue();
1311 }
1312
1313 // v = ld i1* addr
1314 //   =>
1315 // v1 = ld i8* addr (-> i16)
1316 // v = trunc i16 to i1
1317 SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
1318   SDNode *Node = Op.getNode();
1319   LoadSDNode *LD = cast<LoadSDNode>(Node);
1320   SDLoc dl(Node);
1321   assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
1322   assert(Node->getValueType(0) == MVT::i1 &&
1323          "Custom lowering for i1 load only");
1324   SDValue newLD =
1325       DAG.getLoad(MVT::i16, dl, LD->getChain(), LD->getBasePtr(),
1326                   LD->getPointerInfo(), LD->isVolatile(), LD->isNonTemporal(),
1327                   LD->isInvariant(), LD->getAlignment());
1328   SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
1329   // The legalizer (the caller) is expecting two values from the legalized
1330   // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
1331   // in LegalizeDAG.cpp which also uses MergeValues.
1332   SDValue Ops[] = { result, LD->getChain() };
1333   return DAG.getMergeValues(Ops, 2, dl);
1334 }
1335
1336 SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
1337   EVT ValVT = Op.getOperand(1).getValueType();
1338   if (ValVT == MVT::i1)
1339     return LowerSTOREi1(Op, DAG);
1340   else if (ValVT.isVector())
1341     return LowerSTOREVector(Op, DAG);
1342   else
1343     return SDValue();
1344 }
1345
1346 SDValue
1347 NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
1348   SDNode *N = Op.getNode();
1349   SDValue Val = N->getOperand(1);
1350   SDLoc DL(N);
1351   EVT ValVT = Val.getValueType();
1352
1353   if (ValVT.isVector()) {
1354     // We only handle "native" vector sizes for now, e.g. <4 x double> is not
1355     // legal.  We can (and should) split that into 2 stores of <2 x double> here
1356     // but I'm leaving that as a TODO for now.
1357     if (!ValVT.isSimple())
1358       return SDValue();
1359     switch (ValVT.getSimpleVT().SimpleTy) {
1360     default:
1361       return SDValue();
1362     case MVT::v2i8:
1363     case MVT::v2i16:
1364     case MVT::v2i32:
1365     case MVT::v2i64:
1366     case MVT::v2f32:
1367     case MVT::v2f64:
1368     case MVT::v4i8:
1369     case MVT::v4i16:
1370     case MVT::v4i32:
1371     case MVT::v4f32:
1372       // This is a "native" vector type
1373       break;
1374     }
1375
1376     unsigned Opcode = 0;
1377     EVT EltVT = ValVT.getVectorElementType();
1378     unsigned NumElts = ValVT.getVectorNumElements();
1379
1380     // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
1381     // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
1382     // stored type to i16 and propogate the "real" type as the memory type.
1383     bool NeedSExt = false;
1384     if (EltVT.getSizeInBits() < 16)
1385       NeedSExt = true;
1386
1387     switch (NumElts) {
1388     default:
1389       return SDValue();
1390     case 2:
1391       Opcode = NVPTXISD::StoreV2;
1392       break;
1393     case 4: {
1394       Opcode = NVPTXISD::StoreV4;
1395       break;
1396     }
1397     }
1398
1399     SmallVector<SDValue, 8> Ops;
1400
1401     // First is the chain
1402     Ops.push_back(N->getOperand(0));
1403
1404     // Then the split values
1405     for (unsigned i = 0; i < NumElts; ++i) {
1406       SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
1407                                    DAG.getIntPtrConstant(i));
1408       if (NeedSExt)
1409         ExtVal = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i16, ExtVal);
1410       Ops.push_back(ExtVal);
1411     }
1412
1413     // Then any remaining arguments
1414     for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) {
1415       Ops.push_back(N->getOperand(i));
1416     }
1417
1418     MemSDNode *MemSD = cast<MemSDNode>(N);
1419
1420     SDValue NewSt = DAG.getMemIntrinsicNode(
1421         Opcode, DL, DAG.getVTList(MVT::Other), &Ops[0], Ops.size(),
1422         MemSD->getMemoryVT(), MemSD->getMemOperand());
1423
1424     //return DCI.CombineTo(N, NewSt, true);
1425     return NewSt;
1426   }
1427
1428   return SDValue();
1429 }
1430
1431 // st i1 v, addr
1432 //    =>
1433 // v1 = zxt v to i16
1434 // st.u8 i16, addr
1435 SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
1436   SDNode *Node = Op.getNode();
1437   SDLoc dl(Node);
1438   StoreSDNode *ST = cast<StoreSDNode>(Node);
1439   SDValue Tmp1 = ST->getChain();
1440   SDValue Tmp2 = ST->getBasePtr();
1441   SDValue Tmp3 = ST->getValue();
1442   assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
1443   unsigned Alignment = ST->getAlignment();
1444   bool isVolatile = ST->isVolatile();
1445   bool isNonTemporal = ST->isNonTemporal();
1446   Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
1447   SDValue Result = DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2,
1448                                      ST->getPointerInfo(), MVT::i8, isNonTemporal,
1449                                      isVolatile, Alignment);
1450   return Result;
1451 }
1452
1453 SDValue NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname,
1454                                         int idx, EVT v) const {
1455   std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
1456   std::stringstream suffix;
1457   suffix << idx;
1458   *name += suffix.str();
1459   return DAG.getTargetExternalSymbol(name->c_str(), v);
1460 }
1461
1462 SDValue
1463 NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
1464   return getExtSymb(DAG, ".PARAM", idx, v);
1465 }
1466
1467 SDValue NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
1468   return getExtSymb(DAG, ".HLPPARAM", idx);
1469 }
1470
1471 // Check to see if the kernel argument is image*_t or sampler_t
1472
1473 bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
1474   static const char *const specialTypes[] = { "struct._image2d_t",
1475                                               "struct._image3d_t",
1476                                               "struct._sampler_t" };
1477
1478   const Type *Ty = arg->getType();
1479   const PointerType *PTy = dyn_cast<PointerType>(Ty);
1480
1481   if (!PTy)
1482     return false;
1483
1484   if (!context)
1485     return false;
1486
1487   const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
1488   const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : "";
1489
1490   for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
1491     if (TypeName == specialTypes[i])
1492       return true;
1493
1494   return false;
1495 }
1496
1497 SDValue NVPTXTargetLowering::LowerFormalArguments(
1498     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
1499     const SmallVectorImpl<ISD::InputArg> &Ins, SDLoc dl, SelectionDAG &DAG,
1500     SmallVectorImpl<SDValue> &InVals) const {
1501   MachineFunction &MF = DAG.getMachineFunction();
1502   const DataLayout *TD = getDataLayout();
1503
1504   const Function *F = MF.getFunction();
1505   const AttributeSet &PAL = F->getAttributes();
1506   const TargetLowering *TLI = nvTM->getTargetLowering();
1507
1508   SDValue Root = DAG.getRoot();
1509   std::vector<SDValue> OutChains;
1510
1511   bool isKernel = llvm::isKernelFunction(*F);
1512   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1513   assert(isABI && "Non-ABI compilation is not supported");
1514   if (!isABI)
1515     return Chain;
1516
1517   std::vector<Type *> argTypes;
1518   std::vector<const Argument *> theArgs;
1519   for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
1520        I != E; ++I) {
1521     theArgs.push_back(I);
1522     argTypes.push_back(I->getType());
1523   }
1524   // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
1525   // Ins.size() will be larger
1526   //   * if there is an aggregate argument with multiple fields (each field
1527   //     showing up separately in Ins)
1528   //   * if there is a vector argument with more than typical vector-length
1529   //     elements (generally if more than 4) where each vector element is
1530   //     individually present in Ins.
1531   // So a different index should be used for indexing into Ins.
1532   // See similar issue in LowerCall.
1533   unsigned InsIdx = 0;
1534
1535   int idx = 0;
1536   for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++idx, ++InsIdx) {
1537     Type *Ty = argTypes[i];
1538
1539     // If the kernel argument is image*_t or sampler_t, convert it to
1540     // a i32 constant holding the parameter position. This can later
1541     // matched in the AsmPrinter to output the correct mangled name.
1542     if (isImageOrSamplerVal(
1543             theArgs[i],
1544             (theArgs[i]->getParent() ? theArgs[i]->getParent()->getParent()
1545                                      : 0))) {
1546       assert(isKernel && "Only kernels can have image/sampler params");
1547       InVals.push_back(DAG.getConstant(i + 1, MVT::i32));
1548       continue;
1549     }
1550
1551     if (theArgs[i]->use_empty()) {
1552       // argument is dead
1553       if (Ty->isAggregateType()) {
1554         SmallVector<EVT, 16> vtparts;
1555
1556         ComputePTXValueVTs(*this, Ty, vtparts);
1557         assert(vtparts.size() > 0 && "empty aggregate type not expected");
1558         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
1559              ++parti) {
1560           EVT partVT = vtparts[parti];
1561           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, partVT));
1562           ++InsIdx;
1563         }
1564         if (vtparts.size() > 0)
1565           --InsIdx;
1566         continue;
1567       }
1568       if (Ty->isVectorTy()) {
1569         EVT ObjectVT = getValueType(Ty);
1570         unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
1571         for (unsigned parti = 0; parti < NumRegs; ++parti) {
1572           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
1573           ++InsIdx;
1574         }
1575         if (NumRegs > 0)
1576           --InsIdx;
1577         continue;
1578       }
1579       InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
1580       continue;
1581     }
1582
1583     // In the following cases, assign a node order of "idx+1"
1584     // to newly created nodes. The SDNodes for params have to
1585     // appear in the same order as their order of appearance
1586     // in the original function. "idx+1" holds that order.
1587     if (PAL.hasAttribute(i + 1, Attribute::ByVal) == false) {
1588       if (Ty->isAggregateType()) {
1589         SmallVector<EVT, 16> vtparts;
1590         SmallVector<uint64_t, 16> offsets;
1591
1592         // NOTE: Here, we lose the ability to issue vector loads for vectors
1593         // that are a part of a struct.  This should be investigated in the
1594         // future.
1595         ComputePTXValueVTs(*this, Ty, vtparts, &offsets, 0);
1596         assert(vtparts.size() > 0 && "empty aggregate type not expected");
1597         bool aggregateIsPacked = false;
1598         if (StructType *STy = llvm::dyn_cast<StructType>(Ty))
1599           aggregateIsPacked = STy->isPacked();
1600
1601         SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1602         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
1603              ++parti) {
1604           EVT partVT = vtparts[parti];
1605           Value *srcValue = Constant::getNullValue(
1606               PointerType::get(partVT.getTypeForEVT(F->getContext()),
1607                                llvm::ADDRESS_SPACE_PARAM));
1608           SDValue srcAddr =
1609               DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1610                           DAG.getConstant(offsets[parti], getPointerTy()));
1611           unsigned partAlign =
1612               aggregateIsPacked ? 1
1613                                 : TD->getABITypeAlignment(
1614                                       partVT.getTypeForEVT(F->getContext()));
1615                     SDValue p;
1616           if (Ins[InsIdx].VT.getSizeInBits() > partVT.getSizeInBits())
1617             p = DAG.getExtLoad(ISD::SEXTLOAD, dl, Ins[InsIdx].VT, Root, srcAddr,
1618                                MachinePointerInfo(srcValue), partVT, false,
1619                                false, partAlign);
1620           else
1621             p = DAG.getLoad(partVT, dl, Root, srcAddr,
1622                             MachinePointerInfo(srcValue), false, false, false,
1623                             partAlign);
1624           if (p.getNode())
1625             p.getNode()->setIROrder(idx + 1);
1626           InVals.push_back(p);
1627           ++InsIdx;
1628         }
1629         if (vtparts.size() > 0)
1630           --InsIdx;
1631         continue;
1632       }
1633       if (Ty->isVectorTy()) {
1634         EVT ObjectVT = getValueType(Ty);
1635         SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1636         unsigned NumElts = ObjectVT.getVectorNumElements();
1637         assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
1638                "Vector was not scalarized");
1639         unsigned Ofst = 0;
1640         EVT EltVT = ObjectVT.getVectorElementType();
1641
1642         // V1 load
1643         // f32 = load ...
1644         if (NumElts == 1) {
1645           // We only have one element, so just directly load it
1646           Value *SrcValue = Constant::getNullValue(PointerType::get(
1647               EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1648           SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1649                                         DAG.getConstant(Ofst, getPointerTy()));
1650           SDValue P = DAG.getLoad(
1651               EltVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1652               false, true,
1653               TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
1654           if (P.getNode())
1655             P.getNode()->setIROrder(idx + 1);
1656
1657           if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
1658             P = DAG.getNode(ISD::SIGN_EXTEND, dl, Ins[InsIdx].VT, P);
1659           InVals.push_back(P);
1660           Ofst += TD->getTypeAllocSize(EltVT.getTypeForEVT(F->getContext()));
1661           ++InsIdx;
1662         } else if (NumElts == 2) {
1663           // V2 load
1664           // f32,f32 = load ...
1665           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2);
1666           Value *SrcValue = Constant::getNullValue(PointerType::get(
1667               VecVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1668           SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1669                                         DAG.getConstant(Ofst, getPointerTy()));
1670           SDValue P = DAG.getLoad(
1671               VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1672               false, true,
1673               TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
1674           if (P.getNode())
1675             P.getNode()->setIROrder(idx + 1);
1676
1677           SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1678                                      DAG.getIntPtrConstant(0));
1679           SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1680                                      DAG.getIntPtrConstant(1));
1681
1682           if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) {
1683             Elt0 = DAG.getNode(ISD::SIGN_EXTEND, dl, Ins[InsIdx].VT, Elt0);
1684             Elt1 = DAG.getNode(ISD::SIGN_EXTEND, dl, Ins[InsIdx].VT, Elt1);
1685           }
1686
1687           InVals.push_back(Elt0);
1688           InVals.push_back(Elt1);
1689           Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1690           InsIdx += 2;
1691         } else {
1692           // V4 loads
1693           // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
1694           // the
1695           // vector will be expanded to a power of 2 elements, so we know we can
1696           // always round up to the next multiple of 4 when creating the vector
1697           // loads.
1698           // e.g.  4 elem => 1 ld.v4
1699           //       6 elem => 2 ld.v4
1700           //       8 elem => 2 ld.v4
1701           //      11 elem => 3 ld.v4
1702           unsigned VecSize = 4;
1703           if (EltVT.getSizeInBits() == 64) {
1704             VecSize = 2;
1705           }
1706           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
1707           for (unsigned i = 0; i < NumElts; i += VecSize) {
1708             Value *SrcValue = Constant::getNullValue(
1709                 PointerType::get(VecVT.getTypeForEVT(F->getContext()),
1710                                  llvm::ADDRESS_SPACE_PARAM));
1711             SDValue SrcAddr =
1712                 DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1713                             DAG.getConstant(Ofst, getPointerTy()));
1714             SDValue P = DAG.getLoad(
1715                 VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1716                 false, true,
1717                 TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
1718             if (P.getNode())
1719               P.getNode()->setIROrder(idx + 1);
1720
1721             for (unsigned j = 0; j < VecSize; ++j) {
1722               if (i + j >= NumElts)
1723                 break;
1724               SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1725                                         DAG.getIntPtrConstant(j));
1726               if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
1727                 Elt = DAG.getNode(ISD::SIGN_EXTEND, dl, Ins[InsIdx].VT, Elt);
1728               InVals.push_back(Elt);
1729             }
1730             Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1731             InsIdx += VecSize;
1732           }
1733         }
1734
1735         if (NumElts > 0)
1736           --InsIdx;
1737         continue;
1738       }
1739       // A plain scalar.
1740       EVT ObjectVT = getValueType(Ty);
1741       // If ABI, load from the param symbol
1742       SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1743       Value *srcValue = Constant::getNullValue(PointerType::get(
1744           ObjectVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1745       SDValue p;
1746       if (ObjectVT.getSizeInBits() < Ins[InsIdx].VT.getSizeInBits())
1747         p = DAG.getExtLoad(ISD::SEXTLOAD, dl, Ins[InsIdx].VT, Root, Arg,
1748                            MachinePointerInfo(srcValue), ObjectVT, false, false,
1749               TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
1750       else
1751         p = DAG.getLoad(Ins[InsIdx].VT, dl, Root, Arg,
1752                         MachinePointerInfo(srcValue), false, false, false,
1753               TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
1754       if (p.getNode())
1755         p.getNode()->setIROrder(idx + 1);
1756       InVals.push_back(p);
1757       continue;
1758     }
1759
1760     // Param has ByVal attribute
1761     // Return MoveParam(param symbol).
1762     // Ideally, the param symbol can be returned directly,
1763     // but when SDNode builder decides to use it in a CopyToReg(),
1764     // machine instruction fails because TargetExternalSymbol
1765     // (not lowered) is target dependent, and CopyToReg assumes
1766     // the source is lowered.
1767     EVT ObjectVT = getValueType(Ty);
1768     assert(ObjectVT == Ins[InsIdx].VT &&
1769            "Ins type did not match function type");
1770     SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1771     SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
1772     if (p.getNode())
1773       p.getNode()->setIROrder(idx + 1);
1774     if (isKernel)
1775       InVals.push_back(p);
1776     else {
1777       SDValue p2 = DAG.getNode(
1778           ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
1779           DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p);
1780       InVals.push_back(p2);
1781     }
1782   }
1783
1784   // Clang will check explicit VarArg and issue error if any. However, Clang
1785   // will let code with
1786   // implicit var arg like f() pass. See bug 617733.
1787   // We treat this case as if the arg list is empty.
1788   // if (F.isVarArg()) {
1789   // assert(0 && "VarArg not supported yet!");
1790   //}
1791
1792   if (!OutChains.empty())
1793     DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &OutChains[0],
1794                             OutChains.size()));
1795
1796   return Chain;
1797 }
1798
1799
1800 SDValue
1801 NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
1802                                  bool isVarArg,
1803                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
1804                                  const SmallVectorImpl<SDValue> &OutVals,
1805                                  SDLoc dl, SelectionDAG &DAG) const {
1806   MachineFunction &MF = DAG.getMachineFunction();
1807   const Function *F = MF.getFunction();
1808   Type *RetTy = F->getReturnType();
1809   const DataLayout *TD = getDataLayout();
1810
1811   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1812   assert(isABI && "Non-ABI compilation is not supported");
1813   if (!isABI)
1814     return Chain;
1815
1816   if (VectorType *VTy = dyn_cast<VectorType>(RetTy)) {
1817     // If we have a vector type, the OutVals array will be the scalarized
1818     // components and we have combine them into 1 or more vector stores.
1819     unsigned NumElts = VTy->getNumElements();
1820     assert(NumElts == Outs.size() && "Bad scalarization of return value");
1821
1822     // const_cast can be removed in later LLVM versions
1823     EVT EltVT = getValueType(RetTy).getVectorElementType();
1824     bool NeedExtend = false;
1825     if (EltVT.getSizeInBits() < 16)
1826       NeedExtend = true;
1827
1828     // V1 store
1829     if (NumElts == 1) {
1830       SDValue StoreVal = OutVals[0];
1831       // We only have one element, so just directly store it
1832       if (NeedExtend)
1833         StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
1834       SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal };
1835       Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
1836                                       DAG.getVTList(MVT::Other), &Ops[0], 3,
1837                                       EltVT, MachinePointerInfo());
1838
1839     } else if (NumElts == 2) {
1840       // V2 store
1841       SDValue StoreVal0 = OutVals[0];
1842       SDValue StoreVal1 = OutVals[1];
1843
1844       if (NeedExtend) {
1845         StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal0);
1846         StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal1);
1847       }
1848
1849       SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal0,
1850                         StoreVal1 };
1851       Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetvalV2, dl,
1852                                       DAG.getVTList(MVT::Other), &Ops[0], 4,
1853                                       EltVT, MachinePointerInfo());
1854     } else {
1855       // V4 stores
1856       // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the
1857       // vector will be expanded to a power of 2 elements, so we know we can
1858       // always round up to the next multiple of 4 when creating the vector
1859       // stores.
1860       // e.g.  4 elem => 1 st.v4
1861       //       6 elem => 2 st.v4
1862       //       8 elem => 2 st.v4
1863       //      11 elem => 3 st.v4
1864
1865       unsigned VecSize = 4;
1866       if (OutVals[0].getValueType().getSizeInBits() == 64)
1867         VecSize = 2;
1868
1869       unsigned Offset = 0;
1870
1871       EVT VecVT =
1872           EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize);
1873       unsigned PerStoreOffset =
1874           TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1875
1876       for (unsigned i = 0; i < NumElts; i += VecSize) {
1877         // Get values
1878         SDValue StoreVal;
1879         SmallVector<SDValue, 8> Ops;
1880         Ops.push_back(Chain);
1881         Ops.push_back(DAG.getConstant(Offset, MVT::i32));
1882         unsigned Opc = NVPTXISD::StoreRetvalV2;
1883         EVT ExtendedVT = (NeedExtend) ? MVT::i16 : OutVals[0].getValueType();
1884
1885         StoreVal = OutVals[i];
1886         if (NeedExtend)
1887           StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1888         Ops.push_back(StoreVal);
1889
1890         if (i + 1 < NumElts) {
1891           StoreVal = OutVals[i + 1];
1892           if (NeedExtend)
1893             StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1894         } else {
1895           StoreVal = DAG.getUNDEF(ExtendedVT);
1896         }
1897         Ops.push_back(StoreVal);
1898
1899         if (VecSize == 4) {
1900           Opc = NVPTXISD::StoreRetvalV4;
1901           if (i + 2 < NumElts) {
1902             StoreVal = OutVals[i + 2];
1903             if (NeedExtend)
1904               StoreVal =
1905                   DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1906           } else {
1907             StoreVal = DAG.getUNDEF(ExtendedVT);
1908           }
1909           Ops.push_back(StoreVal);
1910
1911           if (i + 3 < NumElts) {
1912             StoreVal = OutVals[i + 3];
1913             if (NeedExtend)
1914               StoreVal =
1915                   DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1916           } else {
1917             StoreVal = DAG.getUNDEF(ExtendedVT);
1918           }
1919           Ops.push_back(StoreVal);
1920         }
1921
1922         // Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size());
1923         Chain =
1924             DAG.getMemIntrinsicNode(Opc, dl, DAG.getVTList(MVT::Other), &Ops[0],
1925                                     Ops.size(), EltVT, MachinePointerInfo());
1926         Offset += PerStoreOffset;
1927       }
1928     }
1929   } else {
1930     SmallVector<EVT, 16> ValVTs;
1931     // const_cast is necessary since we are still using an LLVM version from
1932     // before the type system re-write.
1933     ComputePTXValueVTs(*this, RetTy, ValVTs);
1934     assert(ValVTs.size() == OutVals.size() && "Bad return value decomposition");
1935
1936     unsigned SizeSoFar = 0;
1937     for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
1938       SDValue theVal = OutVals[i];
1939       EVT TheValType = theVal.getValueType();
1940       unsigned numElems = 1;
1941       if (TheValType.isVector())
1942         numElems = TheValType.getVectorNumElements();
1943       for (unsigned j = 0, je = numElems; j != je; ++j) {
1944         SDValue TmpVal = theVal;
1945         if (TheValType.isVector())
1946           TmpVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
1947                                TheValType.getVectorElementType(), TmpVal,
1948                                DAG.getIntPtrConstant(j));
1949         EVT TheStoreType = ValVTs[i];
1950         if (RetTy->isIntegerTy() &&
1951             TD->getTypeAllocSizeInBits(RetTy) < 32) {
1952           // The following zero-extension is for integer types only, and
1953           // specifically not for aggregates.
1954           TmpVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, TmpVal);
1955           TheStoreType = MVT::i32;
1956         }
1957         else if (TmpVal.getValueType().getSizeInBits() < 16)
1958           TmpVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, TmpVal);
1959
1960         SDValue Ops[] = { Chain, DAG.getConstant(SizeSoFar, MVT::i32), TmpVal };
1961         Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
1962                                         DAG.getVTList(MVT::Other), &Ops[0],
1963                                         3, TheStoreType,
1964                                         MachinePointerInfo());
1965         if(TheValType.isVector())
1966           SizeSoFar += 
1967             TheStoreType.getVectorElementType().getStoreSizeInBits() / 8;
1968         else
1969           SizeSoFar += TheStoreType.getStoreSizeInBits()/8;
1970       }
1971     }
1972   }
1973
1974   return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
1975 }
1976
1977
1978 void NVPTXTargetLowering::LowerAsmOperandForConstraint(
1979     SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops,
1980     SelectionDAG &DAG) const {
1981   if (Constraint.length() > 1)
1982     return;
1983   else
1984     TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
1985 }
1986
1987 // NVPTX suuport vector of legal types of any length in Intrinsics because the
1988 // NVPTX specific type legalizer
1989 // will legalize them to the PTX supported length.
1990 bool NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
1991   if (isTypeLegal(VT))
1992     return true;
1993   if (VT.isVector()) {
1994     MVT eVT = VT.getVectorElementType();
1995     if (isTypeLegal(eVT))
1996       return true;
1997   }
1998   return false;
1999 }
2000
2001 // llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
2002 // TgtMemIntrinsic
2003 // because we need the information that is only available in the "Value" type
2004 // of destination
2005 // pointer. In particular, the address space information.
2006 bool NVPTXTargetLowering::getTgtMemIntrinsic(
2007     IntrinsicInfo &Info, const CallInst &I, unsigned Intrinsic) const {
2008   switch (Intrinsic) {
2009   default:
2010     return false;
2011
2012   case Intrinsic::nvvm_atomic_load_add_f32:
2013     Info.opc = ISD::INTRINSIC_W_CHAIN;
2014     Info.memVT = MVT::f32;
2015     Info.ptrVal = I.getArgOperand(0);
2016     Info.offset = 0;
2017     Info.vol = 0;
2018     Info.readMem = true;
2019     Info.writeMem = true;
2020     Info.align = 0;
2021     return true;
2022
2023   case Intrinsic::nvvm_atomic_load_inc_32:
2024   case Intrinsic::nvvm_atomic_load_dec_32:
2025     Info.opc = ISD::INTRINSIC_W_CHAIN;
2026     Info.memVT = MVT::i32;
2027     Info.ptrVal = I.getArgOperand(0);
2028     Info.offset = 0;
2029     Info.vol = 0;
2030     Info.readMem = true;
2031     Info.writeMem = true;
2032     Info.align = 0;
2033     return true;
2034
2035   case Intrinsic::nvvm_ldu_global_i:
2036   case Intrinsic::nvvm_ldu_global_f:
2037   case Intrinsic::nvvm_ldu_global_p:
2038
2039     Info.opc = ISD::INTRINSIC_W_CHAIN;
2040     if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
2041       Info.memVT = getValueType(I.getType());
2042     else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
2043       Info.memVT = getValueType(I.getType());
2044     else
2045       Info.memVT = MVT::f32;
2046     Info.ptrVal = I.getArgOperand(0);
2047     Info.offset = 0;
2048     Info.vol = 0;
2049     Info.readMem = true;
2050     Info.writeMem = false;
2051     Info.align = 0;
2052     return true;
2053
2054   }
2055   return false;
2056 }
2057
2058 /// isLegalAddressingMode - Return true if the addressing mode represented
2059 /// by AM is legal for this target, for a load/store of the specified type.
2060 /// Used to guide target specific optimizations, like loop strength reduction
2061 /// (LoopStrengthReduce.cpp) and memory optimization for address mode
2062 /// (CodeGenPrepare.cpp)
2063 bool NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
2064                                                 Type *Ty) const {
2065
2066   // AddrMode - This represents an addressing mode of:
2067   //    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
2068   //
2069   // The legal address modes are
2070   // - [avar]
2071   // - [areg]
2072   // - [areg+immoff]
2073   // - [immAddr]
2074
2075   if (AM.BaseGV) {
2076     if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
2077       return false;
2078     return true;
2079   }
2080
2081   switch (AM.Scale) {
2082   case 0: // "r", "r+i" or "i" is allowed
2083     break;
2084   case 1:
2085     if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
2086       return false;
2087     // Otherwise we have r+i.
2088     break;
2089   default:
2090     // No scale > 1 is allowed
2091     return false;
2092   }
2093   return true;
2094 }
2095
2096 //===----------------------------------------------------------------------===//
2097 //                         NVPTX Inline Assembly Support
2098 //===----------------------------------------------------------------------===//
2099
2100 /// getConstraintType - Given a constraint letter, return the type of
2101 /// constraint it is for this target.
2102 NVPTXTargetLowering::ConstraintType
2103 NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
2104   if (Constraint.size() == 1) {
2105     switch (Constraint[0]) {
2106     default:
2107       break;
2108     case 'r':
2109     case 'h':
2110     case 'c':
2111     case 'l':
2112     case 'f':
2113     case 'd':
2114     case '0':
2115     case 'N':
2116       return C_RegisterClass;
2117     }
2118   }
2119   return TargetLowering::getConstraintType(Constraint);
2120 }
2121
2122 std::pair<unsigned, const TargetRegisterClass *>
2123 NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
2124                                                   MVT VT) const {
2125   if (Constraint.size() == 1) {
2126     switch (Constraint[0]) {
2127     case 'c':
2128       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
2129     case 'h':
2130       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
2131     case 'r':
2132       return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
2133     case 'l':
2134     case 'N':
2135       return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
2136     case 'f':
2137       return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
2138     case 'd':
2139       return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
2140     }
2141   }
2142   return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
2143 }
2144
2145 /// getFunctionAlignment - Return the Log2 alignment of this function.
2146 unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
2147   return 4;
2148 }
2149
2150 /// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
2151 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
2152                               SmallVectorImpl<SDValue> &Results) {
2153   EVT ResVT = N->getValueType(0);
2154   SDLoc DL(N);
2155
2156   assert(ResVT.isVector() && "Vector load must have vector type");
2157
2158   // We only handle "native" vector sizes for now, e.g. <4 x double> is not
2159   // legal.  We can (and should) split that into 2 loads of <2 x double> here
2160   // but I'm leaving that as a TODO for now.
2161   assert(ResVT.isSimple() && "Can only handle simple types");
2162   switch (ResVT.getSimpleVT().SimpleTy) {
2163   default:
2164     return;
2165   case MVT::v2i8:
2166   case MVT::v2i16:
2167   case MVT::v2i32:
2168   case MVT::v2i64:
2169   case MVT::v2f32:
2170   case MVT::v2f64:
2171   case MVT::v4i8:
2172   case MVT::v4i16:
2173   case MVT::v4i32:
2174   case MVT::v4f32:
2175     // This is a "native" vector type
2176     break;
2177   }
2178
2179   EVT EltVT = ResVT.getVectorElementType();
2180   unsigned NumElts = ResVT.getVectorNumElements();
2181
2182   // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
2183   // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
2184   // loaded type to i16 and propogate the "real" type as the memory type.
2185   bool NeedTrunc = false;
2186   if (EltVT.getSizeInBits() < 16) {
2187     EltVT = MVT::i16;
2188     NeedTrunc = true;
2189   }
2190
2191   unsigned Opcode = 0;
2192   SDVTList LdResVTs;
2193
2194   switch (NumElts) {
2195   default:
2196     return;
2197   case 2:
2198     Opcode = NVPTXISD::LoadV2;
2199     LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
2200     break;
2201   case 4: {
2202     Opcode = NVPTXISD::LoadV4;
2203     EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
2204     LdResVTs = DAG.getVTList(ListVTs, 5);
2205     break;
2206   }
2207   }
2208
2209   SmallVector<SDValue, 8> OtherOps;
2210
2211   // Copy regular operands
2212   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2213     OtherOps.push_back(N->getOperand(i));
2214
2215   LoadSDNode *LD = cast<LoadSDNode>(N);
2216
2217   // The select routine does not have access to the LoadSDNode instance, so
2218   // pass along the extension information
2219   OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType()));
2220
2221   SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, &OtherOps[0],
2222                                           OtherOps.size(), LD->getMemoryVT(),
2223                                           LD->getMemOperand());
2224
2225   SmallVector<SDValue, 4> ScalarRes;
2226
2227   for (unsigned i = 0; i < NumElts; ++i) {
2228     SDValue Res = NewLD.getValue(i);
2229     if (NeedTrunc)
2230       Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
2231     ScalarRes.push_back(Res);
2232   }
2233
2234   SDValue LoadChain = NewLD.getValue(NumElts);
2235
2236   SDValue BuildVec =
2237       DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
2238
2239   Results.push_back(BuildVec);
2240   Results.push_back(LoadChain);
2241 }
2242
2243 static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
2244                                      SmallVectorImpl<SDValue> &Results) {
2245   SDValue Chain = N->getOperand(0);
2246   SDValue Intrin = N->getOperand(1);
2247   SDLoc DL(N);
2248
2249   // Get the intrinsic ID
2250   unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
2251   switch (IntrinNo) {
2252   default:
2253     return;
2254   case Intrinsic::nvvm_ldg_global_i:
2255   case Intrinsic::nvvm_ldg_global_f:
2256   case Intrinsic::nvvm_ldg_global_p:
2257   case Intrinsic::nvvm_ldu_global_i:
2258   case Intrinsic::nvvm_ldu_global_f:
2259   case Intrinsic::nvvm_ldu_global_p: {
2260     EVT ResVT = N->getValueType(0);
2261
2262     if (ResVT.isVector()) {
2263       // Vector LDG/LDU
2264
2265       unsigned NumElts = ResVT.getVectorNumElements();
2266       EVT EltVT = ResVT.getVectorElementType();
2267
2268       // Since LDU/LDG are target nodes, we cannot rely on DAG type
2269       // legalization.
2270       // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
2271       // loaded type to i16 and propogate the "real" type as the memory type.
2272       bool NeedTrunc = false;
2273       if (EltVT.getSizeInBits() < 16) {
2274         EltVT = MVT::i16;
2275         NeedTrunc = true;
2276       }
2277
2278       unsigned Opcode = 0;
2279       SDVTList LdResVTs;
2280
2281       switch (NumElts) {
2282       default:
2283         return;
2284       case 2:
2285         switch (IntrinNo) {
2286         default:
2287           return;
2288         case Intrinsic::nvvm_ldg_global_i:
2289         case Intrinsic::nvvm_ldg_global_f:
2290         case Intrinsic::nvvm_ldg_global_p:
2291           Opcode = NVPTXISD::LDGV2;
2292           break;
2293         case Intrinsic::nvvm_ldu_global_i:
2294         case Intrinsic::nvvm_ldu_global_f:
2295         case Intrinsic::nvvm_ldu_global_p:
2296           Opcode = NVPTXISD::LDUV2;
2297           break;
2298         }
2299         LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
2300         break;
2301       case 4: {
2302         switch (IntrinNo) {
2303         default:
2304           return;
2305         case Intrinsic::nvvm_ldg_global_i:
2306         case Intrinsic::nvvm_ldg_global_f:
2307         case Intrinsic::nvvm_ldg_global_p:
2308           Opcode = NVPTXISD::LDGV4;
2309           break;
2310         case Intrinsic::nvvm_ldu_global_i:
2311         case Intrinsic::nvvm_ldu_global_f:
2312         case Intrinsic::nvvm_ldu_global_p:
2313           Opcode = NVPTXISD::LDUV4;
2314           break;
2315         }
2316         EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
2317         LdResVTs = DAG.getVTList(ListVTs, 5);
2318         break;
2319       }
2320       }
2321
2322       SmallVector<SDValue, 8> OtherOps;
2323
2324       // Copy regular operands
2325
2326       OtherOps.push_back(Chain); // Chain
2327                                  // Skip operand 1 (intrinsic ID)
2328       // Others
2329       for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i)
2330         OtherOps.push_back(N->getOperand(i));
2331
2332       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
2333
2334       SDValue NewLD = DAG.getMemIntrinsicNode(
2335           Opcode, DL, LdResVTs, &OtherOps[0], OtherOps.size(),
2336           MemSD->getMemoryVT(), MemSD->getMemOperand());
2337
2338       SmallVector<SDValue, 4> ScalarRes;
2339
2340       for (unsigned i = 0; i < NumElts; ++i) {
2341         SDValue Res = NewLD.getValue(i);
2342         if (NeedTrunc)
2343           Res =
2344               DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
2345         ScalarRes.push_back(Res);
2346       }
2347
2348       SDValue LoadChain = NewLD.getValue(NumElts);
2349
2350       SDValue BuildVec =
2351           DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
2352
2353       Results.push_back(BuildVec);
2354       Results.push_back(LoadChain);
2355     } else {
2356       // i8 LDG/LDU
2357       assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
2358              "Custom handling of non-i8 ldu/ldg?");
2359
2360       // Just copy all operands as-is
2361       SmallVector<SDValue, 4> Ops;
2362       for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2363         Ops.push_back(N->getOperand(i));
2364
2365       // Force output to i16
2366       SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);
2367
2368       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
2369
2370       // We make sure the memory type is i8, which will be used during isel
2371       // to select the proper instruction.
2372       SDValue NewLD =
2373           DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, &Ops[0],
2374                                   Ops.size(), MVT::i8, MemSD->getMemOperand());
2375
2376       Results.push_back(NewLD.getValue(0));
2377       Results.push_back(NewLD.getValue(1));
2378     }
2379   }
2380   }
2381 }
2382
2383 void NVPTXTargetLowering::ReplaceNodeResults(
2384     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
2385   switch (N->getOpcode()) {
2386   default:
2387     report_fatal_error("Unhandled custom legalization");
2388   case ISD::LOAD:
2389     ReplaceLoadVector(N, DAG, Results);
2390     return;
2391   case ISD::INTRINSIC_W_CHAIN:
2392     ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
2393     return;
2394   }
2395 }