[NVPTX] Remove i8 register class. PTX support for i8 (.b8, .u8, .s8) is rather poor...
[oota-llvm.git] / lib / Target / NVPTX / NVPTXISelLowering.cpp
1 //
2 //                     The LLVM Compiler Infrastructure
3 //
4 // This file is distributed under the University of Illinois Open Source
5 // License. See LICENSE.TXT for details.
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the interfaces that NVPTX uses to lower LLVM code into a
10 // selection DAG.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "NVPTXISelLowering.h"
15 #include "NVPTX.h"
16 #include "NVPTXTargetMachine.h"
17 #include "NVPTXTargetObjectFile.h"
18 #include "NVPTXUtilities.h"
19 #include "llvm/CodeGen/Analysis.h"
20 #include "llvm/CodeGen/MachineFrameInfo.h"
21 #include "llvm/CodeGen/MachineFunction.h"
22 #include "llvm/CodeGen/MachineInstrBuilder.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/GlobalValue.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/Intrinsics.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/MC/MCSectionELF.h"
32 #include "llvm/Support/CallSite.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <sstream>
38
39 #undef DEBUG_TYPE
40 #define DEBUG_TYPE "nvptx-lower"
41
42 using namespace llvm;
43
44 static unsigned int uniqueCallSite = 0;
45
46 static cl::opt<bool> sched4reg(
47     "nvptx-sched4reg",
48     cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false));
49
50 static bool IsPTXVectorType(MVT VT) {
51   switch (VT.SimpleTy) {
52   default:
53     return false;
54   case MVT::v2i1:
55   case MVT::v4i1:
56   case MVT::v2i8:
57   case MVT::v4i8:
58   case MVT::v2i16:
59   case MVT::v4i16:
60   case MVT::v2i32:
61   case MVT::v4i32:
62   case MVT::v2i64:
63   case MVT::v2f32:
64   case MVT::v4f32:
65   case MVT::v2f64:
66     return true;
67   }
68 }
69
70 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
71 /// EVTs that compose it.  Unlike ComputeValueVTs, this will break apart vectors
72 /// into their primitive components.
73 /// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
74 /// same number of types as the Ins/Outs arrays in LowerFormalArguments,
75 /// LowerCall, and LowerReturn.
76 static void ComputePTXValueVTs(const TargetLowering &TLI, Type *Ty,
77                                SmallVectorImpl<EVT> &ValueVTs,
78                                SmallVectorImpl<uint64_t> *Offsets = 0,
79                                uint64_t StartingOffset = 0) {
80   SmallVector<EVT, 16> TempVTs;
81   SmallVector<uint64_t, 16> TempOffsets;
82
83   ComputeValueVTs(TLI, Ty, TempVTs, &TempOffsets, StartingOffset);
84   for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
85     EVT VT = TempVTs[i];
86     uint64_t Off = TempOffsets[i];
87     if (VT.isVector())
88       for (unsigned j = 0, je = VT.getVectorNumElements(); j != je; ++j) {
89         ValueVTs.push_back(VT.getVectorElementType());
90         if (Offsets)
91           Offsets->push_back(Off+j*VT.getVectorElementType().getStoreSize());
92       }
93     else {
94       ValueVTs.push_back(VT);
95       if (Offsets)
96         Offsets->push_back(Off);
97     }
98   }
99 }
100
101 // NVPTXTargetLowering Constructor.
102 NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
103     : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM),
104       nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {
105
106   // always lower memset, memcpy, and memmove intrinsics to load/store
107   // instructions, rather
108   // then generating calls to memset, mempcy or memmove.
109   MaxStoresPerMemset = (unsigned) 0xFFFFFFFF;
110   MaxStoresPerMemcpy = (unsigned) 0xFFFFFFFF;
111   MaxStoresPerMemmove = (unsigned) 0xFFFFFFFF;
112
113   setBooleanContents(ZeroOrNegativeOneBooleanContent);
114
115   // Jump is Expensive. Don't create extra control flow for 'and', 'or'
116   // condition branches.
117   setJumpIsExpensive(true);
118
119   // By default, use the Source scheduling
120   if (sched4reg)
121     setSchedulingPreference(Sched::RegPressure);
122   else
123     setSchedulingPreference(Sched::Source);
124
125   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
126   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
127   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
128   addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
129   addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
130   addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
131
132   // Operations not directly supported by NVPTX.
133   setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
134   setOperationAction(ISD::BR_CC, MVT::f32, Expand);
135   setOperationAction(ISD::BR_CC, MVT::f64, Expand);
136   setOperationAction(ISD::BR_CC, MVT::i1, Expand);
137   setOperationAction(ISD::BR_CC, MVT::i8, Expand);
138   setOperationAction(ISD::BR_CC, MVT::i16, Expand);
139   setOperationAction(ISD::BR_CC, MVT::i32, Expand);
140   setOperationAction(ISD::BR_CC, MVT::i64, Expand);
141   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Expand);
142   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Expand);
143   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Expand);
144   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8, Expand);
145   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
146
147   if (nvptxSubtarget.hasROT64()) {
148     setOperationAction(ISD::ROTL, MVT::i64, Legal);
149     setOperationAction(ISD::ROTR, MVT::i64, Legal);
150   } else {
151     setOperationAction(ISD::ROTL, MVT::i64, Expand);
152     setOperationAction(ISD::ROTR, MVT::i64, Expand);
153   }
154   if (nvptxSubtarget.hasROT32()) {
155     setOperationAction(ISD::ROTL, MVT::i32, Legal);
156     setOperationAction(ISD::ROTR, MVT::i32, Legal);
157   } else {
158     setOperationAction(ISD::ROTL, MVT::i32, Expand);
159     setOperationAction(ISD::ROTR, MVT::i32, Expand);
160   }
161
162   setOperationAction(ISD::ROTL, MVT::i16, Expand);
163   setOperationAction(ISD::ROTR, MVT::i16, Expand);
164   setOperationAction(ISD::ROTL, MVT::i8, Expand);
165   setOperationAction(ISD::ROTR, MVT::i8, Expand);
166   setOperationAction(ISD::BSWAP, MVT::i16, Expand);
167   setOperationAction(ISD::BSWAP, MVT::i32, Expand);
168   setOperationAction(ISD::BSWAP, MVT::i64, Expand);
169
170   // Indirect branch is not supported.
171   // This also disables Jump Table creation.
172   setOperationAction(ISD::BR_JT, MVT::Other, Expand);
173   setOperationAction(ISD::BRIND, MVT::Other, Expand);
174
175   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
176   setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
177
178   // We want to legalize constant related memmove and memcopy
179   // intrinsics.
180   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
181
182   // Turn FP extload into load/fextend
183   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
184   // Turn FP truncstore into trunc + store.
185   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
186
187   // PTX does not support load / store predicate registers
188   setOperationAction(ISD::LOAD, MVT::i1, Custom);
189   setOperationAction(ISD::STORE, MVT::i1, Custom);
190
191   setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
192   setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
193   setTruncStoreAction(MVT::i64, MVT::i1, Expand);
194   setTruncStoreAction(MVT::i32, MVT::i1, Expand);
195   setTruncStoreAction(MVT::i16, MVT::i1, Expand);
196   setTruncStoreAction(MVT::i8, MVT::i1, Expand);
197
198   // This is legal in NVPTX
199   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
200   setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
201
202   // TRAP can be lowered to PTX trap
203   setOperationAction(ISD::TRAP, MVT::Other, Legal);
204
205   // Register custom handling for vector loads/stores
206   for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE;
207        ++i) {
208     MVT VT = (MVT::SimpleValueType) i;
209     if (IsPTXVectorType(VT)) {
210       setOperationAction(ISD::LOAD, VT, Custom);
211       setOperationAction(ISD::STORE, VT, Custom);
212       setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
213     }
214   }
215
216   // Custom handling for i8 intrinsics
217   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
218
219   // Now deduce the information based on the above mentioned
220   // actions
221   computeRegisterProperties();
222 }
223
224 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
225   switch (Opcode) {
226   default:
227     return 0;
228   case NVPTXISD::CALL:
229     return "NVPTXISD::CALL";
230   case NVPTXISD::RET_FLAG:
231     return "NVPTXISD::RET_FLAG";
232   case NVPTXISD::Wrapper:
233     return "NVPTXISD::Wrapper";
234   case NVPTXISD::NVBuiltin:
235     return "NVPTXISD::NVBuiltin";
236   case NVPTXISD::DeclareParam:
237     return "NVPTXISD::DeclareParam";
238   case NVPTXISD::DeclareScalarParam:
239     return "NVPTXISD::DeclareScalarParam";
240   case NVPTXISD::DeclareRet:
241     return "NVPTXISD::DeclareRet";
242   case NVPTXISD::DeclareRetParam:
243     return "NVPTXISD::DeclareRetParam";
244   case NVPTXISD::PrintCall:
245     return "NVPTXISD::PrintCall";
246   case NVPTXISD::LoadParam:
247     return "NVPTXISD::LoadParam";
248   case NVPTXISD::LoadParamV2:
249     return "NVPTXISD::LoadParamV2";
250   case NVPTXISD::LoadParamV4:
251     return "NVPTXISD::LoadParamV4";
252   case NVPTXISD::StoreParam:
253     return "NVPTXISD::StoreParam";
254   case NVPTXISD::StoreParamV2:
255     return "NVPTXISD::StoreParamV2";
256   case NVPTXISD::StoreParamV4:
257     return "NVPTXISD::StoreParamV4";
258   case NVPTXISD::StoreParamS32:
259     return "NVPTXISD::StoreParamS32";
260   case NVPTXISD::StoreParamU32:
261     return "NVPTXISD::StoreParamU32";
262   case NVPTXISD::MoveToParam:
263     return "NVPTXISD::MoveToParam";
264   case NVPTXISD::CallArgBegin:
265     return "NVPTXISD::CallArgBegin";
266   case NVPTXISD::CallArg:
267     return "NVPTXISD::CallArg";
268   case NVPTXISD::LastCallArg:
269     return "NVPTXISD::LastCallArg";
270   case NVPTXISD::CallArgEnd:
271     return "NVPTXISD::CallArgEnd";
272   case NVPTXISD::CallVoid:
273     return "NVPTXISD::CallVoid";
274   case NVPTXISD::CallVal:
275     return "NVPTXISD::CallVal";
276   case NVPTXISD::CallSymbol:
277     return "NVPTXISD::CallSymbol";
278   case NVPTXISD::Prototype:
279     return "NVPTXISD::Prototype";
280   case NVPTXISD::MoveParam:
281     return "NVPTXISD::MoveParam";
282   case NVPTXISD::MoveRetval:
283     return "NVPTXISD::MoveRetval";
284   case NVPTXISD::MoveToRetval:
285     return "NVPTXISD::MoveToRetval";
286   case NVPTXISD::StoreRetval:
287     return "NVPTXISD::StoreRetval";
288   case NVPTXISD::StoreRetvalV2:
289     return "NVPTXISD::StoreRetvalV2";
290   case NVPTXISD::StoreRetvalV4:
291     return "NVPTXISD::StoreRetvalV4";
292   case NVPTXISD::PseudoUseParam:
293     return "NVPTXISD::PseudoUseParam";
294   case NVPTXISD::RETURN:
295     return "NVPTXISD::RETURN";
296   case NVPTXISD::CallSeqBegin:
297     return "NVPTXISD::CallSeqBegin";
298   case NVPTXISD::CallSeqEnd:
299     return "NVPTXISD::CallSeqEnd";
300   case NVPTXISD::LoadV2:
301     return "NVPTXISD::LoadV2";
302   case NVPTXISD::LoadV4:
303     return "NVPTXISD::LoadV4";
304   case NVPTXISD::LDGV2:
305     return "NVPTXISD::LDGV2";
306   case NVPTXISD::LDGV4:
307     return "NVPTXISD::LDGV4";
308   case NVPTXISD::LDUV2:
309     return "NVPTXISD::LDUV2";
310   case NVPTXISD::LDUV4:
311     return "NVPTXISD::LDUV4";
312   case NVPTXISD::StoreV2:
313     return "NVPTXISD::StoreV2";
314   case NVPTXISD::StoreV4:
315     return "NVPTXISD::StoreV4";
316   }
317 }
318
319 bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const {
320   return VT == MVT::i1;
321 }
322
323 SDValue
324 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
325   SDLoc dl(Op);
326   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
327   Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
328   return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
329 }
330
331 /*
332 std::string NVPTXTargetLowering::getPrototype(
333     Type *retTy, const ArgListTy &Args,
334     const SmallVectorImpl<ISD::OutputArg> &Outs, unsigned retAlignment) const {
335
336   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
337
338   std::stringstream O;
339   O << "prototype_" << uniqueCallSite << " : .callprototype ";
340
341   if (retTy->getTypeID() == Type::VoidTyID)
342     O << "()";
343   else {
344     O << "(";
345     if (isABI) {
346       if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
347         unsigned size = 0;
348         if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
349           size = ITy->getBitWidth();
350           if (size < 32)
351             size = 32;
352         } else {
353           assert(retTy->isFloatingPointTy() &&
354                  "Floating point type expected here");
355           size = retTy->getPrimitiveSizeInBits();
356         }
357
358         O << ".param .b" << size << " _";
359       } else if (isa<PointerType>(retTy))
360         O << ".param .b" << getPointerTy().getSizeInBits() << " _";
361       else {
362         if ((retTy->getTypeID() == Type::StructTyID) ||
363             isa<VectorType>(retTy)) {
364           SmallVector<EVT, 16> vtparts;
365           ComputeValueVTs(*this, retTy, vtparts);
366           unsigned totalsz = 0;
367           for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
368             unsigned elems = 1;
369             EVT elemtype = vtparts[i];
370             if (vtparts[i].isVector()) {
371               elems = vtparts[i].getVectorNumElements();
372               elemtype = vtparts[i].getVectorElementType();
373             }
374             for (unsigned j = 0, je = elems; j != je; ++j) {
375               unsigned sz = elemtype.getSizeInBits();
376               if (elemtype.isInteger() && (sz < 8))
377                 sz = 8;
378               totalsz += sz / 8;
379             }
380           }
381           O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
382         } else {
383           assert(false && "Unknown return type");
384         }
385       }
386     } else {
387       SmallVector<EVT, 16> vtparts;
388       ComputeValueVTs(*this, retTy, vtparts);
389       unsigned idx = 0;
390       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
391         unsigned elems = 1;
392         EVT elemtype = vtparts[i];
393         if (vtparts[i].isVector()) {
394           elems = vtparts[i].getVectorNumElements();
395           elemtype = vtparts[i].getVectorElementType();
396         }
397
398         for (unsigned j = 0, je = elems; j != je; ++j) {
399           unsigned sz = elemtype.getSizeInBits();
400           if (elemtype.isInteger() && (sz < 32))
401             sz = 32;
402           O << ".reg .b" << sz << " _";
403           if (j < je - 1)
404             O << ", ";
405           ++idx;
406         }
407         if (i < e - 1)
408           O << ", ";
409       }
410     }
411     O << ") ";
412   }
413   O << "_ (";
414
415   bool first = true;
416   MVT thePointerTy = getPointerTy();
417
418   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
419     const Type *Ty = Args[i].Ty;
420     if (!first) {
421       O << ", ";
422     }
423     first = false;
424
425     if (Outs[i].Flags.isByVal() == false) {
426       unsigned sz = 0;
427       if (isa<IntegerType>(Ty)) {
428         sz = cast<IntegerType>(Ty)->getBitWidth();
429         if (sz < 32)
430           sz = 32;
431       } else if (isa<PointerType>(Ty))
432         sz = thePointerTy.getSizeInBits();
433       else
434         sz = Ty->getPrimitiveSizeInBits();
435       if (isABI)
436         O << ".param .b" << sz << " ";
437       else
438         O << ".reg .b" << sz << " ";
439       O << "_";
440       continue;
441     }
442     const PointerType *PTy = dyn_cast<PointerType>(Ty);
443     assert(PTy && "Param with byval attribute should be a pointer type");
444     Type *ETy = PTy->getElementType();
445
446     if (isABI) {
447       unsigned align = Outs[i].Flags.getByValAlign();
448       unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
449       O << ".param .align " << align << " .b8 ";
450       O << "_";
451       O << "[" << sz << "]";
452       continue;
453     } else {
454       SmallVector<EVT, 16> vtparts;
455       ComputeValueVTs(*this, ETy, vtparts);
456       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
457         unsigned elems = 1;
458         EVT elemtype = vtparts[i];
459         if (vtparts[i].isVector()) {
460           elems = vtparts[i].getVectorNumElements();
461           elemtype = vtparts[i].getVectorElementType();
462         }
463
464         for (unsigned j = 0, je = elems; j != je; ++j) {
465           unsigned sz = elemtype.getSizeInBits();
466           if (elemtype.isInteger() && (sz < 32))
467             sz = 32;
468           O << ".reg .b" << sz << " ";
469           O << "_";
470           if (j < je - 1)
471             O << ", ";
472         }
473         if (i < e - 1)
474           O << ", ";
475       }
476       continue;
477     }
478   }
479   O << ");";
480   return O.str();
481 }*/
482
483 std::string
484 NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
485                                   const SmallVectorImpl<ISD::OutputArg> &Outs,
486                                   unsigned retAlignment,
487                                   const ImmutableCallSite *CS) const {
488
489   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
490   assert(isABI && "Non-ABI compilation is not supported");
491   if (!isABI)
492     return "";
493
494   std::stringstream O;
495   O << "prototype_" << uniqueCallSite << " : .callprototype ";
496
497   if (retTy->getTypeID() == Type::VoidTyID) {
498     O << "()";
499   } else {
500     O << "(";
501     if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
502       unsigned size = 0;
503       if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
504         size = ITy->getBitWidth();
505         if (size < 32)
506           size = 32;
507       } else {
508         assert(retTy->isFloatingPointTy() &&
509                "Floating point type expected here");
510         size = retTy->getPrimitiveSizeInBits();
511       }
512
513       O << ".param .b" << size << " _";
514     } else if (isa<PointerType>(retTy)) {
515       O << ".param .b" << getPointerTy().getSizeInBits() << " _";
516     } else {
517       if ((retTy->getTypeID() == Type::StructTyID) || isa<VectorType>(retTy)) {
518         SmallVector<EVT, 16> vtparts;
519         ComputeValueVTs(*this, retTy, vtparts);
520         unsigned totalsz = 0;
521         for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
522           unsigned elems = 1;
523           EVT elemtype = vtparts[i];
524           if (vtparts[i].isVector()) {
525             elems = vtparts[i].getVectorNumElements();
526             elemtype = vtparts[i].getVectorElementType();
527           }
528           // TODO: no need to loop
529           for (unsigned j = 0, je = elems; j != je; ++j) {
530             unsigned sz = elemtype.getSizeInBits();
531             if (elemtype.isInteger() && (sz < 8))
532               sz = 8;
533             totalsz += sz / 8;
534           }
535         }
536         O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
537       } else {
538         assert(false && "Unknown return type");
539       }
540     }
541     O << ") ";
542   }
543   O << "_ (";
544
545   bool first = true;
546   MVT thePointerTy = getPointerTy();
547
548   unsigned OIdx = 0;
549   for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
550     Type *Ty = Args[i].Ty;
551     if (!first) {
552       O << ", ";
553     }
554     first = false;
555
556     if (Outs[OIdx].Flags.isByVal() == false) {
557       if (Ty->isAggregateType() || Ty->isVectorTy()) {
558         unsigned align = 0;
559         const CallInst *CallI = cast<CallInst>(CS->getInstruction());
560         const DataLayout *TD = getDataLayout();
561         // +1 because index 0 is reserved for return type alignment
562         if (!llvm::getAlign(*CallI, i + 1, align))
563           align = TD->getABITypeAlignment(Ty);
564         unsigned sz = TD->getTypeAllocSize(Ty);
565         O << ".param .align " << align << " .b8 ";
566         O << "_";
567         O << "[" << sz << "]";
568         // update the index for Outs
569         SmallVector<EVT, 16> vtparts;
570         ComputeValueVTs(*this, Ty, vtparts);
571         if (unsigned len = vtparts.size())
572           OIdx += len - 1;
573         continue;
574       }
575       assert(getValueType(Ty) == Outs[OIdx].VT &&
576              "type mismatch between callee prototype and arguments");
577       // scalar type
578       unsigned sz = 0;
579       if (isa<IntegerType>(Ty)) {
580         sz = cast<IntegerType>(Ty)->getBitWidth();
581         if (sz < 32)
582           sz = 32;
583       } else if (isa<PointerType>(Ty))
584         sz = thePointerTy.getSizeInBits();
585       else
586         sz = Ty->getPrimitiveSizeInBits();
587       O << ".param .b" << sz << " ";
588       O << "_";
589       continue;
590     }
591     const PointerType *PTy = dyn_cast<PointerType>(Ty);
592     assert(PTy && "Param with byval attribute should be a pointer type");
593     Type *ETy = PTy->getElementType();
594
595     unsigned align = Outs[OIdx].Flags.getByValAlign();
596     unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
597     O << ".param .align " << align << " .b8 ";
598     O << "_";
599     O << "[" << sz << "]";
600   }
601   O << ");";
602   return O.str();
603 }
604
605 unsigned
606 NVPTXTargetLowering::getArgumentAlignment(SDValue Callee,
607                                           const ImmutableCallSite *CS,
608                                           Type *Ty,
609                                           unsigned Idx) const {
610   const DataLayout *TD = getDataLayout();
611   unsigned align = 0;
612   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
613
614   if (Func) { // direct call
615     assert(CS->getCalledFunction() &&
616            "direct call cannot find callee");
617     if (!llvm::getAlign(*(CS->getCalledFunction()), Idx, align))
618       align = TD->getABITypeAlignment(Ty);
619   }
620   else { // indirect call
621     const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction());
622     if (!llvm::getAlign(*CallI, Idx, align))
623       align = TD->getABITypeAlignment(Ty);
624   }
625
626   return align;
627 }
628
629 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
630                                        SmallVectorImpl<SDValue> &InVals) const {
631   SelectionDAG &DAG = CLI.DAG;
632   SDLoc dl = CLI.DL;
633   SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
634   SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
635   SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
636   SDValue Chain = CLI.Chain;
637   SDValue Callee = CLI.Callee;
638   bool &isTailCall = CLI.IsTailCall;
639   ArgListTy &Args = CLI.Args;
640   Type *retTy = CLI.RetTy;
641   ImmutableCallSite *CS = CLI.CS;
642
643   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
644   assert(isABI && "Non-ABI compilation is not supported");
645   if (!isABI)
646     return Chain;
647   const DataLayout *TD = getDataLayout();
648   MachineFunction &MF = DAG.getMachineFunction();
649   const Function *F = MF.getFunction();
650   const TargetLowering *TLI = nvTM->getTargetLowering();
651
652   SDValue tempChain = Chain;
653   Chain =
654       DAG.getCALLSEQ_START(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
655                            dl);
656   SDValue InFlag = Chain.getValue(1);
657
658   unsigned paramCount = 0;
659   // Args.size() and Outs.size() need not match.
660   // Outs.size() will be larger
661   //   * if there is an aggregate argument with multiple fields (each field
662   //     showing up separately in Outs)
663   //   * if there is a vector argument with more than typical vector-length
664   //     elements (generally if more than 4) where each vector element is
665   //     individually present in Outs.
666   // So a different index should be used for indexing into Outs/OutVals.
667   // See similar issue in LowerFormalArguments.
668   unsigned OIdx = 0;
669   // Declare the .params or .reg need to pass values
670   // to the function
671   for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
672     EVT VT = Outs[OIdx].VT;
673     Type *Ty = Args[i].Ty;
674
675     if (Outs[OIdx].Flags.isByVal() == false) {
676       if (Ty->isAggregateType()) {
677         // aggregate
678         SmallVector<EVT, 16> vtparts;
679         ComputeValueVTs(*this, Ty, vtparts);
680
681         unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
682         // declare .param .align <align> .b8 .param<n>[<size>];
683         unsigned sz = TD->getTypeAllocSize(Ty);
684         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
685         SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
686                                       DAG.getConstant(paramCount, MVT::i32),
687                                       DAG.getConstant(sz, MVT::i32), InFlag };
688         Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
689                             DeclareParamOps, 5);
690         InFlag = Chain.getValue(1);
691         unsigned curOffset = 0;
692         for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
693           unsigned elems = 1;
694           EVT elemtype = vtparts[j];
695           if (vtparts[j].isVector()) {
696             elems = vtparts[j].getVectorNumElements();
697             elemtype = vtparts[j].getVectorElementType();
698           }
699           for (unsigned k = 0, ke = elems; k != ke; ++k) {
700             unsigned sz = elemtype.getSizeInBits();
701             if (elemtype.isInteger() && (sz < 8))
702               sz = 8;
703             SDValue StVal = OutVals[OIdx];
704             if (elemtype.getSizeInBits() < 16) {
705               StVal = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::i16, StVal);
706             }
707             SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
708             SDValue CopyParamOps[] = { Chain,
709                                        DAG.getConstant(paramCount, MVT::i32),
710                                        DAG.getConstant(curOffset, MVT::i32),
711                                        StVal, InFlag };
712             Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
713                                             CopyParamVTs, &CopyParamOps[0], 5,
714                                             elemtype, MachinePointerInfo());
715             InFlag = Chain.getValue(1);
716             curOffset += sz / 8;
717             ++OIdx;
718           }
719         }
720         if (vtparts.size() > 0)
721           --OIdx;
722         ++paramCount;
723         continue;
724       }
725       if (Ty->isVectorTy()) {
726         EVT ObjectVT = getValueType(Ty);
727         unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
728         // declare .param .align <align> .b8 .param<n>[<size>];
729         unsigned sz = TD->getTypeAllocSize(Ty);
730         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
731         SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
732                                       DAG.getConstant(paramCount, MVT::i32),
733                                       DAG.getConstant(sz, MVT::i32), InFlag };
734         Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
735                             DeclareParamOps, 5);
736         InFlag = Chain.getValue(1);
737         unsigned NumElts = ObjectVT.getVectorNumElements();
738         EVT EltVT = ObjectVT.getVectorElementType();
739         EVT MemVT = EltVT;
740         bool NeedExtend = false;
741         if (EltVT.getSizeInBits() < 16) {
742           NeedExtend = true;
743           EltVT = MVT::i16;
744         }
745
746         // V1 store
747         if (NumElts == 1) {
748           SDValue Elt = OutVals[OIdx++];
749           if (NeedExtend)
750             Elt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt);
751
752           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
753           SDValue CopyParamOps[] = { Chain,
754                                      DAG.getConstant(paramCount, MVT::i32),
755                                      DAG.getConstant(0, MVT::i32), Elt,
756                                      InFlag };
757           Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
758                                           CopyParamVTs, &CopyParamOps[0], 5,
759                                           MemVT, MachinePointerInfo());
760           InFlag = Chain.getValue(1);
761         } else if (NumElts == 2) {
762           SDValue Elt0 = OutVals[OIdx++];
763           SDValue Elt1 = OutVals[OIdx++];
764           if (NeedExtend) {
765             Elt0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt0);
766             Elt1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt1);
767           }
768
769           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
770           SDValue CopyParamOps[] = { Chain,
771                                      DAG.getConstant(paramCount, MVT::i32),
772                                      DAG.getConstant(0, MVT::i32), Elt0, Elt1,
773                                      InFlag };
774           Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParamV2, dl,
775                                           CopyParamVTs, &CopyParamOps[0], 6,
776                                           MemVT, MachinePointerInfo());
777           InFlag = Chain.getValue(1);
778         } else {
779           unsigned curOffset = 0;
780           // V4 stores
781           // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
782           // the
783           // vector will be expanded to a power of 2 elements, so we know we can
784           // always round up to the next multiple of 4 when creating the vector
785           // stores.
786           // e.g.  4 elem => 1 st.v4
787           //       6 elem => 2 st.v4
788           //       8 elem => 2 st.v4
789           //      11 elem => 3 st.v4
790           unsigned VecSize = 4;
791           if (EltVT.getSizeInBits() == 64)
792             VecSize = 2;
793
794           // This is potentially only part of a vector, so assume all elements
795           // are packed together.
796           unsigned PerStoreOffset = MemVT.getStoreSizeInBits() / 8 * VecSize;
797
798           for (unsigned i = 0; i < NumElts; i += VecSize) {
799             // Get values
800             SDValue StoreVal;
801             SmallVector<SDValue, 8> Ops;
802             Ops.push_back(Chain);
803             Ops.push_back(DAG.getConstant(paramCount, MVT::i32));
804             Ops.push_back(DAG.getConstant(curOffset, MVT::i32));
805
806             unsigned Opc = NVPTXISD::StoreParamV2;
807
808             StoreVal = OutVals[OIdx++];
809             if (NeedExtend)
810               StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
811             Ops.push_back(StoreVal);
812
813             if (i + 1 < NumElts) {
814               StoreVal = OutVals[OIdx++];
815               if (NeedExtend)
816                 StoreVal =
817                     DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
818             } else {
819               StoreVal = DAG.getUNDEF(EltVT);
820             }
821             Ops.push_back(StoreVal);
822
823             if (VecSize == 4) {
824               Opc = NVPTXISD::StoreParamV4;
825               if (i + 2 < NumElts) {
826                 StoreVal = OutVals[OIdx++];
827                 if (NeedExtend)
828                   StoreVal =
829                       DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
830               } else {
831                 StoreVal = DAG.getUNDEF(EltVT);
832               }
833               Ops.push_back(StoreVal);
834
835               if (i + 3 < NumElts) {
836                 StoreVal = OutVals[OIdx++];
837                 if (NeedExtend)
838                   StoreVal =
839                       DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
840               } else {
841                 StoreVal = DAG.getUNDEF(EltVT);
842               }
843               Ops.push_back(StoreVal);
844             }
845
846             SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
847             Chain = DAG.getMemIntrinsicNode(Opc, dl, CopyParamVTs, &Ops[0],
848                                             Ops.size(), MemVT,
849                                             MachinePointerInfo());
850             InFlag = Chain.getValue(1);
851             curOffset += PerStoreOffset;
852           }
853         }
854         ++paramCount;
855         --OIdx;
856         continue;
857       }
858       // Plain scalar
859       // for ABI,    declare .param .b<size> .param<n>;
860       unsigned sz = VT.getSizeInBits();
861       bool needExtend = false;
862       if (VT.isInteger()) {
863         if (sz < 16)
864           needExtend = true;
865         if (sz < 32)
866           sz = 32;
867       }
868       SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
869       SDValue DeclareParamOps[] = { Chain,
870                                     DAG.getConstant(paramCount, MVT::i32),
871                                     DAG.getConstant(sz, MVT::i32),
872                                     DAG.getConstant(0, MVT::i32), InFlag };
873       Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
874                           DeclareParamOps, 5);
875       InFlag = Chain.getValue(1);
876       SDValue OutV = OutVals[OIdx];
877       if (needExtend) {
878         // zext/sext i1 to i16
879         unsigned opc = ISD::ZERO_EXTEND;
880         if (Outs[OIdx].Flags.isSExt())
881           opc = ISD::SIGN_EXTEND;
882         OutV = DAG.getNode(opc, dl, MVT::i16, OutV);
883       }
884       SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
885       SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
886                                  DAG.getConstant(0, MVT::i32), OutV, InFlag };
887
888       unsigned opcode = NVPTXISD::StoreParam;
889       if (Outs[OIdx].Flags.isZExt())
890         opcode = NVPTXISD::StoreParamU32;
891       else if (Outs[OIdx].Flags.isSExt())
892         opcode = NVPTXISD::StoreParamS32;
893       Chain = DAG.getMemIntrinsicNode(opcode, dl, CopyParamVTs, CopyParamOps, 5,
894                                       VT, MachinePointerInfo());
895
896       InFlag = Chain.getValue(1);
897       ++paramCount;
898       continue;
899     }
900     // struct or vector
901     SmallVector<EVT, 16> vtparts;
902     const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
903     assert(PTy && "Type of a byval parameter should be pointer");
904     ComputeValueVTs(*this, PTy->getElementType(), vtparts);
905
906     // declare .param .align <align> .b8 .param<n>[<size>];
907     unsigned sz = Outs[OIdx].Flags.getByValSize();
908     SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
909     // The ByValAlign in the Outs[OIdx].Flags is alway set at this point,
910     // so we don't need to worry about natural alignment or not.
911     // See TargetLowering::LowerCallTo().
912     SDValue DeclareParamOps[] = {
913       Chain, DAG.getConstant(Outs[OIdx].Flags.getByValAlign(), MVT::i32),
914       DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32),
915       InFlag
916     };
917     Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
918                         DeclareParamOps, 5);
919     InFlag = Chain.getValue(1);
920     unsigned curOffset = 0;
921     for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
922       unsigned elems = 1;
923       EVT elemtype = vtparts[j];
924       if (vtparts[j].isVector()) {
925         elems = vtparts[j].getVectorNumElements();
926         elemtype = vtparts[j].getVectorElementType();
927       }
928       for (unsigned k = 0, ke = elems; k != ke; ++k) {
929         unsigned sz = elemtype.getSizeInBits();
930         if (elemtype.isInteger() && (sz < 8))
931           sz = 8;
932         SDValue srcAddr =
933             DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
934                         DAG.getConstant(curOffset, getPointerTy()));
935         SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
936                                      MachinePointerInfo(), false, false, false,
937                                      0);
938         if (elemtype.getSizeInBits() < 16) {
939           theVal = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::i16, theVal);
940         }
941         SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
942         SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
943                                    DAG.getConstant(curOffset, MVT::i32), theVal,
944                                    InFlag };
945         Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
946                                         CopyParamOps, 5, elemtype,
947                                         MachinePointerInfo());
948
949         InFlag = Chain.getValue(1);
950         curOffset += sz / 8;
951       }
952     }
953     ++paramCount;
954   }
955
956   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
957   unsigned retAlignment = 0;
958
959   // Handle Result
960   if (Ins.size() > 0) {
961     SmallVector<EVT, 16> resvtparts;
962     ComputeValueVTs(*this, retTy, resvtparts);
963
964     // Declare
965     //  .param .align 16 .b8 retval0[<size-in-bytes>], or
966     //  .param .b<size-in-bits> retval0
967     unsigned resultsz = TD->getTypeAllocSizeInBits(retTy);
968     if (retTy->isPrimitiveType() || retTy->isIntegerTy() ||
969         retTy->isPointerTy()) {
970       // Scalar needs to be at least 32bit wide
971       if (resultsz < 32)
972         resultsz = 32;
973       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
974       SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
975                                   DAG.getConstant(resultsz, MVT::i32),
976                                   DAG.getConstant(0, MVT::i32), InFlag };
977       Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
978                           DeclareRetOps, 5);
979       InFlag = Chain.getValue(1);
980     } else {
981       retAlignment = getArgumentAlignment(Callee, CS, retTy, 0);
982       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
983       SDValue DeclareRetOps[] = { Chain,
984                                   DAG.getConstant(retAlignment, MVT::i32),
985                                   DAG.getConstant(resultsz / 8, MVT::i32),
986                                   DAG.getConstant(0, MVT::i32), InFlag };
987       Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
988                           DeclareRetOps, 5);
989       InFlag = Chain.getValue(1);
990     }
991   }
992
993   if (!Func) {
994     // This is indirect function call case : PTX requires a prototype of the
995     // form
996     // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
997     // to be emitted, and the label has to used as the last arg of call
998     // instruction.
999     // The prototype is embedded in a string and put as the operand for an
1000     // INLINEASM SDNode.
1001     SDVTList InlineAsmVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1002     std::string proto_string =
1003         getPrototype(retTy, Args, Outs, retAlignment, CS);
1004     const char *asmstr = nvTM->getManagedStrPool()
1005         ->getManagedString(proto_string.c_str())->c_str();
1006     SDValue InlineAsmOps[] = {
1007       Chain, DAG.getTargetExternalSymbol(asmstr, getPointerTy()),
1008       DAG.getMDNode(0), DAG.getTargetConstant(0, MVT::i32), InFlag
1009     };
1010     Chain = DAG.getNode(ISD::INLINEASM, dl, InlineAsmVTs, InlineAsmOps, 5);
1011     InFlag = Chain.getValue(1);
1012   }
1013   // Op to just print "call"
1014   SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1015   SDValue PrintCallOps[] = {
1016     Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, MVT::i32), InFlag
1017   };
1018   Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
1019                       dl, PrintCallVTs, PrintCallOps, 3);
1020   InFlag = Chain.getValue(1);
1021
1022   // Ops to print out the function name
1023   SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1024   SDValue CallVoidOps[] = { Chain, Callee, InFlag };
1025   Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3);
1026   InFlag = Chain.getValue(1);
1027
1028   // Ops to print out the param list
1029   SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1030   SDValue CallArgBeginOps[] = { Chain, InFlag };
1031   Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
1032                       CallArgBeginOps, 2);
1033   InFlag = Chain.getValue(1);
1034
1035   for (unsigned i = 0, e = paramCount; i != e; ++i) {
1036     unsigned opcode;
1037     if (i == (e - 1))
1038       opcode = NVPTXISD::LastCallArg;
1039     else
1040       opcode = NVPTXISD::CallArg;
1041     SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1042     SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
1043                              DAG.getConstant(i, MVT::i32), InFlag };
1044     Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4);
1045     InFlag = Chain.getValue(1);
1046   }
1047   SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1048   SDValue CallArgEndOps[] = { Chain, DAG.getConstant(Func ? 1 : 0, MVT::i32),
1049                               InFlag };
1050   Chain =
1051       DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps, 3);
1052   InFlag = Chain.getValue(1);
1053
1054   if (!Func) {
1055     SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1056     SDValue PrototypeOps[] = { Chain, DAG.getConstant(uniqueCallSite, MVT::i32),
1057                                InFlag };
1058     Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3);
1059     InFlag = Chain.getValue(1);
1060   }
1061
1062   // Generate loads from param memory/moves from registers for result
1063   if (Ins.size() > 0) {
1064     unsigned resoffset = 0;
1065     if (retTy && retTy->isVectorTy()) {
1066       EVT ObjectVT = getValueType(retTy);
1067       unsigned NumElts = ObjectVT.getVectorNumElements();
1068       EVT EltVT = ObjectVT.getVectorElementType();
1069       assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
1070              "Vector was not scalarized");
1071       unsigned sz = EltVT.getSizeInBits();
1072       bool needTruncate = sz < 16 ? true : false;
1073
1074       if (NumElts == 1) {
1075         // Just a simple load
1076         std::vector<EVT> LoadRetVTs;
1077         if (needTruncate) {
1078           // If loading i1 result, generate
1079           //   load i16
1080           //   trunc i16 to i1
1081           LoadRetVTs.push_back(MVT::i16);
1082         } else
1083           LoadRetVTs.push_back(EltVT);
1084         LoadRetVTs.push_back(MVT::Other);
1085         LoadRetVTs.push_back(MVT::Glue);
1086         std::vector<SDValue> LoadRetOps;
1087         LoadRetOps.push_back(Chain);
1088         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1089         LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
1090         LoadRetOps.push_back(InFlag);
1091         SDValue retval = DAG.getMemIntrinsicNode(
1092             NVPTXISD::LoadParam, dl,
1093             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
1094             LoadRetOps.size(), EltVT, MachinePointerInfo());
1095         Chain = retval.getValue(1);
1096         InFlag = retval.getValue(2);
1097         SDValue Ret0 = retval;
1098         if (needTruncate)
1099           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Ret0);
1100         InVals.push_back(Ret0);
1101       } else if (NumElts == 2) {
1102         // LoadV2
1103         std::vector<EVT> LoadRetVTs;
1104         if (needTruncate) {
1105           // If loading i1 result, generate
1106           //   load i16
1107           //   trunc i16 to i1
1108           LoadRetVTs.push_back(MVT::i16);
1109           LoadRetVTs.push_back(MVT::i16);
1110         } else {
1111           LoadRetVTs.push_back(EltVT);
1112           LoadRetVTs.push_back(EltVT);
1113         }
1114         LoadRetVTs.push_back(MVT::Other);
1115         LoadRetVTs.push_back(MVT::Glue);
1116         std::vector<SDValue> LoadRetOps;
1117         LoadRetOps.push_back(Chain);
1118         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1119         LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
1120         LoadRetOps.push_back(InFlag);
1121         SDValue retval = DAG.getMemIntrinsicNode(
1122             NVPTXISD::LoadParamV2, dl,
1123             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
1124             LoadRetOps.size(), EltVT, MachinePointerInfo());
1125         Chain = retval.getValue(2);
1126         InFlag = retval.getValue(3);
1127         SDValue Ret0 = retval.getValue(0);
1128         SDValue Ret1 = retval.getValue(1);
1129         if (needTruncate) {
1130           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret0);
1131           InVals.push_back(Ret0);
1132           Ret1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret1);
1133           InVals.push_back(Ret1);
1134         } else {
1135           InVals.push_back(Ret0);
1136           InVals.push_back(Ret1);
1137         }
1138       } else {
1139         // Split into N LoadV4
1140         unsigned Ofst = 0;
1141         unsigned VecSize = 4;
1142         unsigned Opc = NVPTXISD::LoadParamV4;
1143         if (EltVT.getSizeInBits() == 64) {
1144           VecSize = 2;
1145           Opc = NVPTXISD::LoadParamV2;
1146         }
1147         EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
1148         for (unsigned i = 0; i < NumElts; i += VecSize) {
1149           SmallVector<EVT, 8> LoadRetVTs;
1150           if (needTruncate) {
1151             // If loading i1 result, generate
1152             //   load i16
1153             //   trunc i16 to i1
1154             for (unsigned j = 0; j < VecSize; ++j)
1155               LoadRetVTs.push_back(MVT::i16);
1156           } else {
1157             for (unsigned j = 0; j < VecSize; ++j)
1158               LoadRetVTs.push_back(EltVT);
1159           }
1160           LoadRetVTs.push_back(MVT::Other);
1161           LoadRetVTs.push_back(MVT::Glue);
1162           SmallVector<SDValue, 4> LoadRetOps;
1163           LoadRetOps.push_back(Chain);
1164           LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1165           LoadRetOps.push_back(DAG.getConstant(Ofst, MVT::i32));
1166           LoadRetOps.push_back(InFlag);
1167           SDValue retval = DAG.getMemIntrinsicNode(
1168               Opc, dl, DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()),
1169               &LoadRetOps[0], LoadRetOps.size(), EltVT, MachinePointerInfo());
1170           if (VecSize == 2) {
1171             Chain = retval.getValue(2);
1172             InFlag = retval.getValue(3);
1173           } else {
1174             Chain = retval.getValue(4);
1175             InFlag = retval.getValue(5);
1176           }
1177
1178           for (unsigned j = 0; j < VecSize; ++j) {
1179             if (i + j >= NumElts)
1180               break;
1181             SDValue Elt = retval.getValue(j);
1182             if (needTruncate)
1183               Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
1184             InVals.push_back(Elt);
1185           }
1186           Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1187         }
1188       }
1189     } else {
1190       SmallVector<EVT, 16> VTs;
1191       ComputePTXValueVTs(*this, retTy, VTs);
1192       assert(VTs.size() == Ins.size() && "Bad value decomposition");
1193       for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
1194         unsigned sz = VTs[i].getSizeInBits();
1195         bool needTruncate = sz < 8 ? true : false;
1196         if (VTs[i].isInteger() && (sz < 8))
1197           sz = 8;
1198
1199         SmallVector<EVT, 4> LoadRetVTs;
1200         if (sz < 16) {
1201           // If loading i1/i8 result, generate
1202           //   load i8 (-> i16)
1203           //   trunc i16 to i1/i8
1204           LoadRetVTs.push_back(MVT::i16);
1205         } else
1206           LoadRetVTs.push_back(Ins[i].VT);
1207         LoadRetVTs.push_back(MVT::Other);
1208         LoadRetVTs.push_back(MVT::Glue);
1209
1210         SmallVector<SDValue, 4> LoadRetOps;
1211         LoadRetOps.push_back(Chain);
1212         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1213         LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32));
1214         LoadRetOps.push_back(InFlag);
1215         SDValue retval = DAG.getMemIntrinsicNode(
1216             NVPTXISD::LoadParam, dl,
1217             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
1218             LoadRetOps.size(), VTs[i], MachinePointerInfo());
1219         Chain = retval.getValue(1);
1220         InFlag = retval.getValue(2);
1221         SDValue Ret0 = retval.getValue(0);
1222         if (needTruncate)
1223           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0);
1224         InVals.push_back(Ret0);
1225         resoffset += sz / 8;
1226       }
1227     }
1228   }
1229
1230   Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
1231                              DAG.getIntPtrConstant(uniqueCallSite + 1, true),
1232                              InFlag, dl);
1233   uniqueCallSite++;
1234
1235   // set isTailCall to false for now, until we figure out how to express
1236   // tail call optimization in PTX
1237   isTailCall = false;
1238   return Chain;
1239 }
1240
1241 // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
1242 // (see LegalizeDAG.cpp). This is slow and uses local memory.
1243 // We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
1244 SDValue
1245 NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
1246   SDNode *Node = Op.getNode();
1247   SDLoc dl(Node);
1248   SmallVector<SDValue, 8> Ops;
1249   unsigned NumOperands = Node->getNumOperands();
1250   for (unsigned i = 0; i < NumOperands; ++i) {
1251     SDValue SubOp = Node->getOperand(i);
1252     EVT VVT = SubOp.getNode()->getValueType(0);
1253     EVT EltVT = VVT.getVectorElementType();
1254     unsigned NumSubElem = VVT.getVectorNumElements();
1255     for (unsigned j = 0; j < NumSubElem; ++j) {
1256       Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
1257                                 DAG.getIntPtrConstant(j)));
1258     }
1259   }
1260   return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0), &Ops[0],
1261                      Ops.size());
1262 }
1263
1264 SDValue
1265 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
1266   switch (Op.getOpcode()) {
1267   case ISD::RETURNADDR:
1268     return SDValue();
1269   case ISD::FRAMEADDR:
1270     return SDValue();
1271   case ISD::GlobalAddress:
1272     return LowerGlobalAddress(Op, DAG);
1273   case ISD::INTRINSIC_W_CHAIN:
1274     return Op;
1275   case ISD::BUILD_VECTOR:
1276   case ISD::EXTRACT_SUBVECTOR:
1277     return Op;
1278   case ISD::CONCAT_VECTORS:
1279     return LowerCONCAT_VECTORS(Op, DAG);
1280   case ISD::STORE:
1281     return LowerSTORE(Op, DAG);
1282   case ISD::LOAD:
1283     return LowerLOAD(Op, DAG);
1284   default:
1285     llvm_unreachable("Custom lowering not defined for operation");
1286   }
1287 }
1288
1289 SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
1290   if (Op.getValueType() == MVT::i1)
1291     return LowerLOADi1(Op, DAG);
1292   else
1293     return SDValue();
1294 }
1295
1296 // v = ld i1* addr
1297 //   =>
1298 // v1 = ld i8* addr (-> i16)
1299 // v = trunc i16 to i1
1300 SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
1301   SDNode *Node = Op.getNode();
1302   LoadSDNode *LD = cast<LoadSDNode>(Node);
1303   SDLoc dl(Node);
1304   assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
1305   assert(Node->getValueType(0) == MVT::i1 &&
1306          "Custom lowering for i1 load only");
1307   SDValue newLD =
1308       DAG.getLoad(MVT::i16, dl, LD->getChain(), LD->getBasePtr(),
1309                   LD->getPointerInfo(), LD->isVolatile(), LD->isNonTemporal(),
1310                   LD->isInvariant(), LD->getAlignment());
1311   SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
1312   // The legalizer (the caller) is expecting two values from the legalized
1313   // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
1314   // in LegalizeDAG.cpp which also uses MergeValues.
1315   SDValue Ops[] = { result, LD->getChain() };
1316   return DAG.getMergeValues(Ops, 2, dl);
1317 }
1318
1319 SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
1320   EVT ValVT = Op.getOperand(1).getValueType();
1321   if (ValVT == MVT::i1)
1322     return LowerSTOREi1(Op, DAG);
1323   else if (ValVT.isVector())
1324     return LowerSTOREVector(Op, DAG);
1325   else
1326     return SDValue();
1327 }
1328
1329 SDValue
1330 NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
1331   SDNode *N = Op.getNode();
1332   SDValue Val = N->getOperand(1);
1333   SDLoc DL(N);
1334   EVT ValVT = Val.getValueType();
1335
1336   if (ValVT.isVector()) {
1337     // We only handle "native" vector sizes for now, e.g. <4 x double> is not
1338     // legal.  We can (and should) split that into 2 stores of <2 x double> here
1339     // but I'm leaving that as a TODO for now.
1340     if (!ValVT.isSimple())
1341       return SDValue();
1342     switch (ValVT.getSimpleVT().SimpleTy) {
1343     default:
1344       return SDValue();
1345     case MVT::v2i8:
1346     case MVT::v2i16:
1347     case MVT::v2i32:
1348     case MVT::v2i64:
1349     case MVT::v2f32:
1350     case MVT::v2f64:
1351     case MVT::v4i8:
1352     case MVT::v4i16:
1353     case MVT::v4i32:
1354     case MVT::v4f32:
1355       // This is a "native" vector type
1356       break;
1357     }
1358
1359     unsigned Opcode = 0;
1360     EVT EltVT = ValVT.getVectorElementType();
1361     unsigned NumElts = ValVT.getVectorNumElements();
1362
1363     // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
1364     // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
1365     // stored type to i16 and propogate the "real" type as the memory type.
1366     bool NeedSExt = false;
1367     if (EltVT.getSizeInBits() < 16)
1368       NeedSExt = true;
1369
1370     switch (NumElts) {
1371     default:
1372       return SDValue();
1373     case 2:
1374       Opcode = NVPTXISD::StoreV2;
1375       break;
1376     case 4: {
1377       Opcode = NVPTXISD::StoreV4;
1378       break;
1379     }
1380     }
1381
1382     SmallVector<SDValue, 8> Ops;
1383
1384     // First is the chain
1385     Ops.push_back(N->getOperand(0));
1386
1387     // Then the split values
1388     for (unsigned i = 0; i < NumElts; ++i) {
1389       SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
1390                                    DAG.getIntPtrConstant(i));
1391       if (NeedSExt)
1392         ExtVal = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i16, ExtVal);
1393       Ops.push_back(ExtVal);
1394     }
1395
1396     // Then any remaining arguments
1397     for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) {
1398       Ops.push_back(N->getOperand(i));
1399     }
1400
1401     MemSDNode *MemSD = cast<MemSDNode>(N);
1402
1403     SDValue NewSt = DAG.getMemIntrinsicNode(
1404         Opcode, DL, DAG.getVTList(MVT::Other), &Ops[0], Ops.size(),
1405         MemSD->getMemoryVT(), MemSD->getMemOperand());
1406
1407     //return DCI.CombineTo(N, NewSt, true);
1408     return NewSt;
1409   }
1410
1411   return SDValue();
1412 }
1413
1414 // st i1 v, addr
1415 //    =>
1416 // v1 = zxt v to i16
1417 // st.u8 i16, addr
1418 SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
1419   SDNode *Node = Op.getNode();
1420   SDLoc dl(Node);
1421   StoreSDNode *ST = cast<StoreSDNode>(Node);
1422   SDValue Tmp1 = ST->getChain();
1423   SDValue Tmp2 = ST->getBasePtr();
1424   SDValue Tmp3 = ST->getValue();
1425   assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
1426   unsigned Alignment = ST->getAlignment();
1427   bool isVolatile = ST->isVolatile();
1428   bool isNonTemporal = ST->isNonTemporal();
1429   Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
1430   SDValue Result = DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2,
1431                                      ST->getPointerInfo(), MVT::i8, isNonTemporal,
1432                                      isVolatile, Alignment);
1433   return Result;
1434 }
1435
1436 SDValue NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname,
1437                                         int idx, EVT v) const {
1438   std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
1439   std::stringstream suffix;
1440   suffix << idx;
1441   *name += suffix.str();
1442   return DAG.getTargetExternalSymbol(name->c_str(), v);
1443 }
1444
1445 SDValue
1446 NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
1447   return getExtSymb(DAG, ".PARAM", idx, v);
1448 }
1449
1450 SDValue NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
1451   return getExtSymb(DAG, ".HLPPARAM", idx);
1452 }
1453
1454 // Check to see if the kernel argument is image*_t or sampler_t
1455
1456 bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
1457   static const char *const specialTypes[] = { "struct._image2d_t",
1458                                               "struct._image3d_t",
1459                                               "struct._sampler_t" };
1460
1461   const Type *Ty = arg->getType();
1462   const PointerType *PTy = dyn_cast<PointerType>(Ty);
1463
1464   if (!PTy)
1465     return false;
1466
1467   if (!context)
1468     return false;
1469
1470   const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
1471   const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : "";
1472
1473   for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
1474     if (TypeName == specialTypes[i])
1475       return true;
1476
1477   return false;
1478 }
1479
1480 SDValue NVPTXTargetLowering::LowerFormalArguments(
1481     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
1482     const SmallVectorImpl<ISD::InputArg> &Ins, SDLoc dl, SelectionDAG &DAG,
1483     SmallVectorImpl<SDValue> &InVals) const {
1484   MachineFunction &MF = DAG.getMachineFunction();
1485   const DataLayout *TD = getDataLayout();
1486
1487   const Function *F = MF.getFunction();
1488   const AttributeSet &PAL = F->getAttributes();
1489   const TargetLowering *TLI = nvTM->getTargetLowering();
1490
1491   SDValue Root = DAG.getRoot();
1492   std::vector<SDValue> OutChains;
1493
1494   bool isKernel = llvm::isKernelFunction(*F);
1495   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1496   assert(isABI && "Non-ABI compilation is not supported");
1497   if (!isABI)
1498     return Chain;
1499
1500   std::vector<Type *> argTypes;
1501   std::vector<const Argument *> theArgs;
1502   for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
1503        I != E; ++I) {
1504     theArgs.push_back(I);
1505     argTypes.push_back(I->getType());
1506   }
1507   // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
1508   // Ins.size() will be larger
1509   //   * if there is an aggregate argument with multiple fields (each field
1510   //     showing up separately in Ins)
1511   //   * if there is a vector argument with more than typical vector-length
1512   //     elements (generally if more than 4) where each vector element is
1513   //     individually present in Ins.
1514   // So a different index should be used for indexing into Ins.
1515   // See similar issue in LowerCall.
1516   unsigned InsIdx = 0;
1517
1518   int idx = 0;
1519   for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++idx, ++InsIdx) {
1520     Type *Ty = argTypes[i];
1521
1522     // If the kernel argument is image*_t or sampler_t, convert it to
1523     // a i32 constant holding the parameter position. This can later
1524     // matched in the AsmPrinter to output the correct mangled name.
1525     if (isImageOrSamplerVal(
1526             theArgs[i],
1527             (theArgs[i]->getParent() ? theArgs[i]->getParent()->getParent()
1528                                      : 0))) {
1529       assert(isKernel && "Only kernels can have image/sampler params");
1530       InVals.push_back(DAG.getConstant(i + 1, MVT::i32));
1531       continue;
1532     }
1533
1534     if (theArgs[i]->use_empty()) {
1535       // argument is dead
1536       if (Ty->isAggregateType()) {
1537         SmallVector<EVT, 16> vtparts;
1538
1539         ComputePTXValueVTs(*this, Ty, vtparts);
1540         assert(vtparts.size() > 0 && "empty aggregate type not expected");
1541         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
1542              ++parti) {
1543           EVT partVT = vtparts[parti];
1544           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, partVT));
1545           ++InsIdx;
1546         }
1547         if (vtparts.size() > 0)
1548           --InsIdx;
1549         continue;
1550       }
1551       if (Ty->isVectorTy()) {
1552         EVT ObjectVT = getValueType(Ty);
1553         unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
1554         for (unsigned parti = 0; parti < NumRegs; ++parti) {
1555           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
1556           ++InsIdx;
1557         }
1558         if (NumRegs > 0)
1559           --InsIdx;
1560         continue;
1561       }
1562       InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
1563       continue;
1564     }
1565
1566     // In the following cases, assign a node order of "idx+1"
1567     // to newly created nodes. The SDNodes for params have to
1568     // appear in the same order as their order of appearance
1569     // in the original function. "idx+1" holds that order.
1570     if (PAL.hasAttribute(i + 1, Attribute::ByVal) == false) {
1571       if (Ty->isAggregateType()) {
1572         SmallVector<EVT, 16> vtparts;
1573         SmallVector<uint64_t, 16> offsets;
1574
1575         // NOTE: Here, we lose the ability to issue vector loads for vectors
1576         // that are a part of a struct.  This should be investigated in the
1577         // future.
1578         ComputePTXValueVTs(*this, Ty, vtparts, &offsets, 0);
1579         assert(vtparts.size() > 0 && "empty aggregate type not expected");
1580         bool aggregateIsPacked = false;
1581         if (StructType *STy = llvm::dyn_cast<StructType>(Ty))
1582           aggregateIsPacked = STy->isPacked();
1583
1584         SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1585         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
1586              ++parti) {
1587           EVT partVT = vtparts[parti];
1588           Value *srcValue = Constant::getNullValue(
1589               PointerType::get(partVT.getTypeForEVT(F->getContext()),
1590                                llvm::ADDRESS_SPACE_PARAM));
1591           SDValue srcAddr =
1592               DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1593                           DAG.getConstant(offsets[parti], getPointerTy()));
1594           unsigned partAlign =
1595               aggregateIsPacked ? 1
1596                                 : TD->getABITypeAlignment(
1597                                       partVT.getTypeForEVT(F->getContext()));
1598                     SDValue p;
1599           if (Ins[InsIdx].VT.getSizeInBits() > partVT.getSizeInBits())
1600             p = DAG.getExtLoad(ISD::SEXTLOAD, dl, Ins[InsIdx].VT, Root, srcAddr,
1601                                MachinePointerInfo(srcValue), partVT, false,
1602                                false, partAlign);
1603           else
1604             p = DAG.getLoad(partVT, dl, Root, srcAddr,
1605                             MachinePointerInfo(srcValue), false, false, false,
1606                             partAlign);
1607           if (p.getNode())
1608             p.getNode()->setIROrder(idx + 1);
1609           InVals.push_back(p);
1610           ++InsIdx;
1611         }
1612         if (vtparts.size() > 0)
1613           --InsIdx;
1614         continue;
1615       }
1616       if (Ty->isVectorTy()) {
1617         EVT ObjectVT = getValueType(Ty);
1618         SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1619         unsigned NumElts = ObjectVT.getVectorNumElements();
1620         assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
1621                "Vector was not scalarized");
1622         unsigned Ofst = 0;
1623         EVT EltVT = ObjectVT.getVectorElementType();
1624
1625         // V1 load
1626         // f32 = load ...
1627         if (NumElts == 1) {
1628           // We only have one element, so just directly load it
1629           Value *SrcValue = Constant::getNullValue(PointerType::get(
1630               EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1631           SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1632                                         DAG.getConstant(Ofst, getPointerTy()));
1633           SDValue P = DAG.getLoad(
1634               EltVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1635               false, true,
1636               TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
1637           if (P.getNode())
1638             P.getNode()->setIROrder(idx + 1);
1639
1640           if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
1641             P = DAG.getNode(ISD::SIGN_EXTEND, dl, Ins[InsIdx].VT, P);
1642           InVals.push_back(P);
1643           Ofst += TD->getTypeAllocSize(EltVT.getTypeForEVT(F->getContext()));
1644           ++InsIdx;
1645         } else if (NumElts == 2) {
1646           // V2 load
1647           // f32,f32 = load ...
1648           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2);
1649           Value *SrcValue = Constant::getNullValue(PointerType::get(
1650               VecVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1651           SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1652                                         DAG.getConstant(Ofst, getPointerTy()));
1653           SDValue P = DAG.getLoad(
1654               VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1655               false, true,
1656               TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
1657           if (P.getNode())
1658             P.getNode()->setIROrder(idx + 1);
1659
1660           SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1661                                      DAG.getIntPtrConstant(0));
1662           SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1663                                      DAG.getIntPtrConstant(1));
1664
1665           if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) {
1666             Elt0 = DAG.getNode(ISD::SIGN_EXTEND, dl, Ins[InsIdx].VT, Elt0);
1667             Elt1 = DAG.getNode(ISD::SIGN_EXTEND, dl, Ins[InsIdx].VT, Elt1);
1668           }
1669
1670           InVals.push_back(Elt0);
1671           InVals.push_back(Elt1);
1672           Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1673           InsIdx += 2;
1674         } else {
1675           // V4 loads
1676           // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
1677           // the
1678           // vector will be expanded to a power of 2 elements, so we know we can
1679           // always round up to the next multiple of 4 when creating the vector
1680           // loads.
1681           // e.g.  4 elem => 1 ld.v4
1682           //       6 elem => 2 ld.v4
1683           //       8 elem => 2 ld.v4
1684           //      11 elem => 3 ld.v4
1685           unsigned VecSize = 4;
1686           if (EltVT.getSizeInBits() == 64) {
1687             VecSize = 2;
1688           }
1689           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
1690           for (unsigned i = 0; i < NumElts; i += VecSize) {
1691             Value *SrcValue = Constant::getNullValue(
1692                 PointerType::get(VecVT.getTypeForEVT(F->getContext()),
1693                                  llvm::ADDRESS_SPACE_PARAM));
1694             SDValue SrcAddr =
1695                 DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1696                             DAG.getConstant(Ofst, getPointerTy()));
1697             SDValue P = DAG.getLoad(
1698                 VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1699                 false, true,
1700                 TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
1701             if (P.getNode())
1702               P.getNode()->setIROrder(idx + 1);
1703
1704             for (unsigned j = 0; j < VecSize; ++j) {
1705               if (i + j >= NumElts)
1706                 break;
1707               SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1708                                         DAG.getIntPtrConstant(j));
1709               if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
1710                 Elt = DAG.getNode(ISD::SIGN_EXTEND, dl, Ins[InsIdx].VT, Elt);
1711               InVals.push_back(Elt);
1712             }
1713             Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1714             InsIdx += VecSize;
1715           }
1716         }
1717
1718         if (NumElts > 0)
1719           --InsIdx;
1720         continue;
1721       }
1722       // A plain scalar.
1723       EVT ObjectVT = getValueType(Ty);
1724       // If ABI, load from the param symbol
1725       SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1726       Value *srcValue = Constant::getNullValue(PointerType::get(
1727           ObjectVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1728       SDValue p;
1729       if (ObjectVT.getSizeInBits() < Ins[InsIdx].VT.getSizeInBits())
1730         p = DAG.getExtLoad(ISD::SEXTLOAD, dl, Ins[InsIdx].VT, Root, Arg,
1731                            MachinePointerInfo(srcValue), ObjectVT, false, false,
1732               TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
1733       else
1734         p = DAG.getLoad(Ins[InsIdx].VT, dl, Root, Arg,
1735                         MachinePointerInfo(srcValue), false, false, false,
1736               TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
1737       if (p.getNode())
1738         p.getNode()->setIROrder(idx + 1);
1739       InVals.push_back(p);
1740       continue;
1741     }
1742
1743     // Param has ByVal attribute
1744     // Return MoveParam(param symbol).
1745     // Ideally, the param symbol can be returned directly,
1746     // but when SDNode builder decides to use it in a CopyToReg(),
1747     // machine instruction fails because TargetExternalSymbol
1748     // (not lowered) is target dependent, and CopyToReg assumes
1749     // the source is lowered.
1750     EVT ObjectVT = getValueType(Ty);
1751     assert(ObjectVT == Ins[InsIdx].VT &&
1752            "Ins type did not match function type");
1753     SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1754     SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
1755     if (p.getNode())
1756       p.getNode()->setIROrder(idx + 1);
1757     if (isKernel)
1758       InVals.push_back(p);
1759     else {
1760       SDValue p2 = DAG.getNode(
1761           ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
1762           DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p);
1763       InVals.push_back(p2);
1764     }
1765   }
1766
1767   // Clang will check explicit VarArg and issue error if any. However, Clang
1768   // will let code with
1769   // implicit var arg like f() pass. See bug 617733.
1770   // We treat this case as if the arg list is empty.
1771   // if (F.isVarArg()) {
1772   // assert(0 && "VarArg not supported yet!");
1773   //}
1774
1775   if (!OutChains.empty())
1776     DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &OutChains[0],
1777                             OutChains.size()));
1778
1779   return Chain;
1780 }
1781
1782
1783 SDValue
1784 NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
1785                                  bool isVarArg,
1786                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
1787                                  const SmallVectorImpl<SDValue> &OutVals,
1788                                  SDLoc dl, SelectionDAG &DAG) const {
1789   MachineFunction &MF = DAG.getMachineFunction();
1790   const Function *F = MF.getFunction();
1791   const Type *RetTy = F->getReturnType();
1792   const DataLayout *TD = getDataLayout();
1793
1794   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1795   assert(isABI && "Non-ABI compilation is not supported");
1796   if (!isABI)
1797     return Chain;
1798
1799   if (const VectorType *VTy = dyn_cast<const VectorType>(RetTy)) {
1800     // If we have a vector type, the OutVals array will be the scalarized
1801     // components and we have combine them into 1 or more vector stores.
1802     unsigned NumElts = VTy->getNumElements();
1803     assert(NumElts == Outs.size() && "Bad scalarization of return value");
1804
1805     // const_cast can be removed in later LLVM versions
1806     EVT EltVT = getValueType(const_cast<Type *>(RetTy)).getVectorElementType();
1807     bool NeedExtend = false;
1808     if (EltVT.getSizeInBits() < 16)
1809       NeedExtend = true;
1810
1811     // V1 store
1812     if (NumElts == 1) {
1813       SDValue StoreVal = OutVals[0];
1814       // We only have one element, so just directly store it
1815       if (NeedExtend)
1816         StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
1817       SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal };
1818       Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
1819                                       DAG.getVTList(MVT::Other), &Ops[0], 3,
1820                                       EltVT, MachinePointerInfo());
1821
1822     } else if (NumElts == 2) {
1823       // V2 store
1824       SDValue StoreVal0 = OutVals[0];
1825       SDValue StoreVal1 = OutVals[1];
1826
1827       if (NeedExtend) {
1828         StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal0);
1829         StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal1);
1830       }
1831
1832       SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal0,
1833                         StoreVal1 };
1834       Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetvalV2, dl,
1835                                       DAG.getVTList(MVT::Other), &Ops[0], 4,
1836                                       EltVT, MachinePointerInfo());
1837     } else {
1838       // V4 stores
1839       // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the
1840       // vector will be expanded to a power of 2 elements, so we know we can
1841       // always round up to the next multiple of 4 when creating the vector
1842       // stores.
1843       // e.g.  4 elem => 1 st.v4
1844       //       6 elem => 2 st.v4
1845       //       8 elem => 2 st.v4
1846       //      11 elem => 3 st.v4
1847
1848       unsigned VecSize = 4;
1849       if (OutVals[0].getValueType().getSizeInBits() == 64)
1850         VecSize = 2;
1851
1852       unsigned Offset = 0;
1853
1854       EVT VecVT =
1855           EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize);
1856       unsigned PerStoreOffset =
1857           TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1858
1859       for (unsigned i = 0; i < NumElts; i += VecSize) {
1860         // Get values
1861         SDValue StoreVal;
1862         SmallVector<SDValue, 8> Ops;
1863         Ops.push_back(Chain);
1864         Ops.push_back(DAG.getConstant(Offset, MVT::i32));
1865         unsigned Opc = NVPTXISD::StoreRetvalV2;
1866         EVT ExtendedVT = (NeedExtend) ? MVT::i16 : OutVals[0].getValueType();
1867
1868         StoreVal = OutVals[i];
1869         if (NeedExtend)
1870           StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1871         Ops.push_back(StoreVal);
1872
1873         if (i + 1 < NumElts) {
1874           StoreVal = OutVals[i + 1];
1875           if (NeedExtend)
1876             StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1877         } else {
1878           StoreVal = DAG.getUNDEF(ExtendedVT);
1879         }
1880         Ops.push_back(StoreVal);
1881
1882         if (VecSize == 4) {
1883           Opc = NVPTXISD::StoreRetvalV4;
1884           if (i + 2 < NumElts) {
1885             StoreVal = OutVals[i + 2];
1886             if (NeedExtend)
1887               StoreVal =
1888                   DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1889           } else {
1890             StoreVal = DAG.getUNDEF(ExtendedVT);
1891           }
1892           Ops.push_back(StoreVal);
1893
1894           if (i + 3 < NumElts) {
1895             StoreVal = OutVals[i + 3];
1896             if (NeedExtend)
1897               StoreVal =
1898                   DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1899           } else {
1900             StoreVal = DAG.getUNDEF(ExtendedVT);
1901           }
1902           Ops.push_back(StoreVal);
1903         }
1904
1905         // Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size());
1906         Chain =
1907             DAG.getMemIntrinsicNode(Opc, dl, DAG.getVTList(MVT::Other), &Ops[0],
1908                                     Ops.size(), EltVT, MachinePointerInfo());
1909         Offset += PerStoreOffset;
1910       }
1911     }
1912   } else {
1913     SmallVector<EVT, 16> ValVTs;
1914     // const_cast is necessary since we are still using an LLVM version from
1915     // before the type system re-write.
1916     ComputePTXValueVTs(*this, const_cast<Type *>(RetTy), ValVTs);
1917     assert(ValVTs.size() == OutVals.size() && "Bad return value decomposition");
1918
1919     unsigned sizesofar = 0;
1920     for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
1921       SDValue theVal = OutVals[i];
1922       EVT theValType = theVal.getValueType();
1923       unsigned numElems = 1;
1924       if (theValType.isVector())
1925         numElems = theValType.getVectorNumElements();
1926       for (unsigned j = 0, je = numElems; j != je; ++j) {
1927         SDValue tmpval = theVal;
1928         if (theValType.isVector())
1929           tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
1930                                theValType.getVectorElementType(), tmpval,
1931                                DAG.getIntPtrConstant(j));
1932         EVT theStoreType = tmpval.getValueType();
1933         if (theStoreType.getSizeInBits() < 8)
1934           tmpval = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, tmpval);
1935         SDValue Ops[] = { Chain, DAG.getConstant(sizesofar, MVT::i32), tmpval };
1936         Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
1937                                         DAG.getVTList(MVT::Other), &Ops[0], 3,
1938                                         ValVTs[i], MachinePointerInfo());
1939         if (theValType.isVector())
1940           sizesofar +=
1941               ValVTs[i].getVectorElementType().getStoreSizeInBits() / 8;
1942         else
1943           sizesofar += ValVTs[i].getStoreSizeInBits() / 8;
1944       }
1945     }
1946   }
1947
1948   return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
1949 }
1950
1951
1952 void NVPTXTargetLowering::LowerAsmOperandForConstraint(
1953     SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops,
1954     SelectionDAG &DAG) const {
1955   if (Constraint.length() > 1)
1956     return;
1957   else
1958     TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
1959 }
1960
1961 // NVPTX suuport vector of legal types of any length in Intrinsics because the
1962 // NVPTX specific type legalizer
1963 // will legalize them to the PTX supported length.
1964 bool NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
1965   if (isTypeLegal(VT))
1966     return true;
1967   if (VT.isVector()) {
1968     MVT eVT = VT.getVectorElementType();
1969     if (isTypeLegal(eVT))
1970       return true;
1971   }
1972   return false;
1973 }
1974
1975 // llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
1976 // TgtMemIntrinsic
1977 // because we need the information that is only available in the "Value" type
1978 // of destination
1979 // pointer. In particular, the address space information.
1980 bool NVPTXTargetLowering::getTgtMemIntrinsic(
1981     IntrinsicInfo &Info, const CallInst &I, unsigned Intrinsic) const {
1982   switch (Intrinsic) {
1983   default:
1984     return false;
1985
1986   case Intrinsic::nvvm_atomic_load_add_f32:
1987     Info.opc = ISD::INTRINSIC_W_CHAIN;
1988     Info.memVT = MVT::f32;
1989     Info.ptrVal = I.getArgOperand(0);
1990     Info.offset = 0;
1991     Info.vol = 0;
1992     Info.readMem = true;
1993     Info.writeMem = true;
1994     Info.align = 0;
1995     return true;
1996
1997   case Intrinsic::nvvm_atomic_load_inc_32:
1998   case Intrinsic::nvvm_atomic_load_dec_32:
1999     Info.opc = ISD::INTRINSIC_W_CHAIN;
2000     Info.memVT = MVT::i32;
2001     Info.ptrVal = I.getArgOperand(0);
2002     Info.offset = 0;
2003     Info.vol = 0;
2004     Info.readMem = true;
2005     Info.writeMem = true;
2006     Info.align = 0;
2007     return true;
2008
2009   case Intrinsic::nvvm_ldu_global_i:
2010   case Intrinsic::nvvm_ldu_global_f:
2011   case Intrinsic::nvvm_ldu_global_p:
2012
2013     Info.opc = ISD::INTRINSIC_W_CHAIN;
2014     if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
2015       Info.memVT = getValueType(I.getType());
2016     else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
2017       Info.memVT = getValueType(I.getType());
2018     else
2019       Info.memVT = MVT::f32;
2020     Info.ptrVal = I.getArgOperand(0);
2021     Info.offset = 0;
2022     Info.vol = 0;
2023     Info.readMem = true;
2024     Info.writeMem = false;
2025     Info.align = 0;
2026     return true;
2027
2028   }
2029   return false;
2030 }
2031
2032 /// isLegalAddressingMode - Return true if the addressing mode represented
2033 /// by AM is legal for this target, for a load/store of the specified type.
2034 /// Used to guide target specific optimizations, like loop strength reduction
2035 /// (LoopStrengthReduce.cpp) and memory optimization for address mode
2036 /// (CodeGenPrepare.cpp)
2037 bool NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
2038                                                 Type *Ty) const {
2039
2040   // AddrMode - This represents an addressing mode of:
2041   //    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
2042   //
2043   // The legal address modes are
2044   // - [avar]
2045   // - [areg]
2046   // - [areg+immoff]
2047   // - [immAddr]
2048
2049   if (AM.BaseGV) {
2050     if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
2051       return false;
2052     return true;
2053   }
2054
2055   switch (AM.Scale) {
2056   case 0: // "r", "r+i" or "i" is allowed
2057     break;
2058   case 1:
2059     if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
2060       return false;
2061     // Otherwise we have r+i.
2062     break;
2063   default:
2064     // No scale > 1 is allowed
2065     return false;
2066   }
2067   return true;
2068 }
2069
2070 //===----------------------------------------------------------------------===//
2071 //                         NVPTX Inline Assembly Support
2072 //===----------------------------------------------------------------------===//
2073
2074 /// getConstraintType - Given a constraint letter, return the type of
2075 /// constraint it is for this target.
2076 NVPTXTargetLowering::ConstraintType
2077 NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
2078   if (Constraint.size() == 1) {
2079     switch (Constraint[0]) {
2080     default:
2081       break;
2082     case 'r':
2083     case 'h':
2084     case 'c':
2085     case 'l':
2086     case 'f':
2087     case 'd':
2088     case '0':
2089     case 'N':
2090       return C_RegisterClass;
2091     }
2092   }
2093   return TargetLowering::getConstraintType(Constraint);
2094 }
2095
2096 std::pair<unsigned, const TargetRegisterClass *>
2097 NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
2098                                                   MVT VT) const {
2099   if (Constraint.size() == 1) {
2100     switch (Constraint[0]) {
2101     case 'c':
2102       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
2103     case 'h':
2104       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
2105     case 'r':
2106       return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
2107     case 'l':
2108     case 'N':
2109       return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
2110     case 'f':
2111       return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
2112     case 'd':
2113       return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
2114     }
2115   }
2116   return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
2117 }
2118
2119 /// getFunctionAlignment - Return the Log2 alignment of this function.
2120 unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
2121   return 4;
2122 }
2123
2124 /// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
2125 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
2126                               SmallVectorImpl<SDValue> &Results) {
2127   EVT ResVT = N->getValueType(0);
2128   SDLoc DL(N);
2129
2130   assert(ResVT.isVector() && "Vector load must have vector type");
2131
2132   // We only handle "native" vector sizes for now, e.g. <4 x double> is not
2133   // legal.  We can (and should) split that into 2 loads of <2 x double> here
2134   // but I'm leaving that as a TODO for now.
2135   assert(ResVT.isSimple() && "Can only handle simple types");
2136   switch (ResVT.getSimpleVT().SimpleTy) {
2137   default:
2138     return;
2139   case MVT::v2i8:
2140   case MVT::v2i16:
2141   case MVT::v2i32:
2142   case MVT::v2i64:
2143   case MVT::v2f32:
2144   case MVT::v2f64:
2145   case MVT::v4i8:
2146   case MVT::v4i16:
2147   case MVT::v4i32:
2148   case MVT::v4f32:
2149     // This is a "native" vector type
2150     break;
2151   }
2152
2153   EVT EltVT = ResVT.getVectorElementType();
2154   unsigned NumElts = ResVT.getVectorNumElements();
2155
2156   // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
2157   // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
2158   // loaded type to i16 and propogate the "real" type as the memory type.
2159   bool NeedTrunc = false;
2160   if (EltVT.getSizeInBits() < 16) {
2161     EltVT = MVT::i16;
2162     NeedTrunc = true;
2163   }
2164
2165   unsigned Opcode = 0;
2166   SDVTList LdResVTs;
2167
2168   switch (NumElts) {
2169   default:
2170     return;
2171   case 2:
2172     Opcode = NVPTXISD::LoadV2;
2173     LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
2174     break;
2175   case 4: {
2176     Opcode = NVPTXISD::LoadV4;
2177     EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
2178     LdResVTs = DAG.getVTList(ListVTs, 5);
2179     break;
2180   }
2181   }
2182
2183   SmallVector<SDValue, 8> OtherOps;
2184
2185   // Copy regular operands
2186   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2187     OtherOps.push_back(N->getOperand(i));
2188
2189   LoadSDNode *LD = cast<LoadSDNode>(N);
2190
2191   // The select routine does not have access to the LoadSDNode instance, so
2192   // pass along the extension information
2193   OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType()));
2194
2195   SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, &OtherOps[0],
2196                                           OtherOps.size(), LD->getMemoryVT(),
2197                                           LD->getMemOperand());
2198
2199   SmallVector<SDValue, 4> ScalarRes;
2200
2201   for (unsigned i = 0; i < NumElts; ++i) {
2202     SDValue Res = NewLD.getValue(i);
2203     if (NeedTrunc)
2204       Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
2205     ScalarRes.push_back(Res);
2206   }
2207
2208   SDValue LoadChain = NewLD.getValue(NumElts);
2209
2210   SDValue BuildVec =
2211       DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
2212
2213   Results.push_back(BuildVec);
2214   Results.push_back(LoadChain);
2215 }
2216
2217 static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
2218                                      SmallVectorImpl<SDValue> &Results) {
2219   SDValue Chain = N->getOperand(0);
2220   SDValue Intrin = N->getOperand(1);
2221   SDLoc DL(N);
2222
2223   // Get the intrinsic ID
2224   unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
2225   switch (IntrinNo) {
2226   default:
2227     return;
2228   case Intrinsic::nvvm_ldg_global_i:
2229   case Intrinsic::nvvm_ldg_global_f:
2230   case Intrinsic::nvvm_ldg_global_p:
2231   case Intrinsic::nvvm_ldu_global_i:
2232   case Intrinsic::nvvm_ldu_global_f:
2233   case Intrinsic::nvvm_ldu_global_p: {
2234     EVT ResVT = N->getValueType(0);
2235
2236     if (ResVT.isVector()) {
2237       // Vector LDG/LDU
2238
2239       unsigned NumElts = ResVT.getVectorNumElements();
2240       EVT EltVT = ResVT.getVectorElementType();
2241
2242       // Since LDU/LDG are target nodes, we cannot rely on DAG type
2243       // legalization.
2244       // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
2245       // loaded type to i16 and propogate the "real" type as the memory type.
2246       bool NeedTrunc = false;
2247       if (EltVT.getSizeInBits() < 16) {
2248         EltVT = MVT::i16;
2249         NeedTrunc = true;
2250       }
2251
2252       unsigned Opcode = 0;
2253       SDVTList LdResVTs;
2254
2255       switch (NumElts) {
2256       default:
2257         return;
2258       case 2:
2259         switch (IntrinNo) {
2260         default:
2261           return;
2262         case Intrinsic::nvvm_ldg_global_i:
2263         case Intrinsic::nvvm_ldg_global_f:
2264         case Intrinsic::nvvm_ldg_global_p:
2265           Opcode = NVPTXISD::LDGV2;
2266           break;
2267         case Intrinsic::nvvm_ldu_global_i:
2268         case Intrinsic::nvvm_ldu_global_f:
2269         case Intrinsic::nvvm_ldu_global_p:
2270           Opcode = NVPTXISD::LDUV2;
2271           break;
2272         }
2273         LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
2274         break;
2275       case 4: {
2276         switch (IntrinNo) {
2277         default:
2278           return;
2279         case Intrinsic::nvvm_ldg_global_i:
2280         case Intrinsic::nvvm_ldg_global_f:
2281         case Intrinsic::nvvm_ldg_global_p:
2282           Opcode = NVPTXISD::LDGV4;
2283           break;
2284         case Intrinsic::nvvm_ldu_global_i:
2285         case Intrinsic::nvvm_ldu_global_f:
2286         case Intrinsic::nvvm_ldu_global_p:
2287           Opcode = NVPTXISD::LDUV4;
2288           break;
2289         }
2290         EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
2291         LdResVTs = DAG.getVTList(ListVTs, 5);
2292         break;
2293       }
2294       }
2295
2296       SmallVector<SDValue, 8> OtherOps;
2297
2298       // Copy regular operands
2299
2300       OtherOps.push_back(Chain); // Chain
2301                                  // Skip operand 1 (intrinsic ID)
2302       // Others
2303       for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i)
2304         OtherOps.push_back(N->getOperand(i));
2305
2306       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
2307
2308       SDValue NewLD = DAG.getMemIntrinsicNode(
2309           Opcode, DL, LdResVTs, &OtherOps[0], OtherOps.size(),
2310           MemSD->getMemoryVT(), MemSD->getMemOperand());
2311
2312       SmallVector<SDValue, 4> ScalarRes;
2313
2314       for (unsigned i = 0; i < NumElts; ++i) {
2315         SDValue Res = NewLD.getValue(i);
2316         if (NeedTrunc)
2317           Res =
2318               DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
2319         ScalarRes.push_back(Res);
2320       }
2321
2322       SDValue LoadChain = NewLD.getValue(NumElts);
2323
2324       SDValue BuildVec =
2325           DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
2326
2327       Results.push_back(BuildVec);
2328       Results.push_back(LoadChain);
2329     } else {
2330       // i8 LDG/LDU
2331       assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
2332              "Custom handling of non-i8 ldu/ldg?");
2333
2334       // Just copy all operands as-is
2335       SmallVector<SDValue, 4> Ops;
2336       for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2337         Ops.push_back(N->getOperand(i));
2338
2339       // Force output to i16
2340       SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);
2341
2342       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
2343
2344       // We make sure the memory type is i8, which will be used during isel
2345       // to select the proper instruction.
2346       SDValue NewLD =
2347           DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, &Ops[0],
2348                                   Ops.size(), MVT::i8, MemSD->getMemOperand());
2349
2350       Results.push_back(NewLD.getValue(0));
2351       Results.push_back(NewLD.getValue(1));
2352     }
2353   }
2354   }
2355 }
2356
2357 void NVPTXTargetLowering::ReplaceNodeResults(
2358     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
2359   switch (N->getOpcode()) {
2360   default:
2361     report_fatal_error("Unhandled custom legalization");
2362   case ISD::LOAD:
2363     ReplaceLoadVector(N, DAG, Results);
2364     return;
2365   case ISD::INTRINSIC_W_CHAIN:
2366     ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
2367     return;
2368   }
2369 }