[NVPTX] Clean up comparison/select/convert patterns and factor out PTX instructions...
[oota-llvm.git] / lib / Target / NVPTX / NVPTXISelLowering.cpp
1 //
2 //                     The LLVM Compiler Infrastructure
3 //
4 // This file is distributed under the University of Illinois Open Source
5 // License. See LICENSE.TXT for details.
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the interfaces that NVPTX uses to lower LLVM code into a
10 // selection DAG.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "NVPTXISelLowering.h"
15 #include "NVPTX.h"
16 #include "NVPTXTargetMachine.h"
17 #include "NVPTXTargetObjectFile.h"
18 #include "NVPTXUtilities.h"
19 #include "llvm/CodeGen/Analysis.h"
20 #include "llvm/CodeGen/MachineFrameInfo.h"
21 #include "llvm/CodeGen/MachineFunction.h"
22 #include "llvm/CodeGen/MachineInstrBuilder.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/GlobalValue.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/Intrinsics.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/MC/MCSectionELF.h"
32 #include "llvm/Support/CallSite.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <sstream>
38
39 #undef DEBUG_TYPE
40 #define DEBUG_TYPE "nvptx-lower"
41
42 using namespace llvm;
43
44 static unsigned int uniqueCallSite = 0;
45
46 static cl::opt<bool> sched4reg(
47     "nvptx-sched4reg",
48     cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false));
49
50 static bool IsPTXVectorType(MVT VT) {
51   switch (VT.SimpleTy) {
52   default:
53     return false;
54   case MVT::v2i1:
55   case MVT::v4i1:
56   case MVT::v2i8:
57   case MVT::v4i8:
58   case MVT::v2i16:
59   case MVT::v4i16:
60   case MVT::v2i32:
61   case MVT::v4i32:
62   case MVT::v2i64:
63   case MVT::v2f32:
64   case MVT::v4f32:
65   case MVT::v2f64:
66     return true;
67   }
68 }
69
70 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
71 /// EVTs that compose it.  Unlike ComputeValueVTs, this will break apart vectors
72 /// into their primitive components.
73 /// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
74 /// same number of types as the Ins/Outs arrays in LowerFormalArguments,
75 /// LowerCall, and LowerReturn.
76 static void ComputePTXValueVTs(const TargetLowering &TLI, Type *Ty,
77                                SmallVectorImpl<EVT> &ValueVTs,
78                                SmallVectorImpl<uint64_t> *Offsets = 0,
79                                uint64_t StartingOffset = 0) {
80   SmallVector<EVT, 16> TempVTs;
81   SmallVector<uint64_t, 16> TempOffsets;
82
83   ComputeValueVTs(TLI, Ty, TempVTs, &TempOffsets, StartingOffset);
84   for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
85     EVT VT = TempVTs[i];
86     uint64_t Off = TempOffsets[i];
87     if (VT.isVector())
88       for (unsigned j = 0, je = VT.getVectorNumElements(); j != je; ++j) {
89         ValueVTs.push_back(VT.getVectorElementType());
90         if (Offsets)
91           Offsets->push_back(Off+j*VT.getVectorElementType().getStoreSize());
92       }
93     else {
94       ValueVTs.push_back(VT);
95       if (Offsets)
96         Offsets->push_back(Off);
97     }
98   }
99 }
100
101 // NVPTXTargetLowering Constructor.
102 NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
103     : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM),
104       nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {
105
106   // always lower memset, memcpy, and memmove intrinsics to load/store
107   // instructions, rather
108   // then generating calls to memset, mempcy or memmove.
109   MaxStoresPerMemset = (unsigned) 0xFFFFFFFF;
110   MaxStoresPerMemcpy = (unsigned) 0xFFFFFFFF;
111   MaxStoresPerMemmove = (unsigned) 0xFFFFFFFF;
112
113   setBooleanContents(ZeroOrNegativeOneBooleanContent);
114
115   // Jump is Expensive. Don't create extra control flow for 'and', 'or'
116   // condition branches.
117   setJumpIsExpensive(true);
118
119   // By default, use the Source scheduling
120   if (sched4reg)
121     setSchedulingPreference(Sched::RegPressure);
122   else
123     setSchedulingPreference(Sched::Source);
124
125   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
126   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
127   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
128   addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
129   addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
130   addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
131
132   // Operations not directly supported by NVPTX.
133   setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
134   setOperationAction(ISD::BR_CC, MVT::f32, Expand);
135   setOperationAction(ISD::BR_CC, MVT::f64, Expand);
136   setOperationAction(ISD::BR_CC, MVT::i1, Expand);
137   setOperationAction(ISD::BR_CC, MVT::i8, Expand);
138   setOperationAction(ISD::BR_CC, MVT::i16, Expand);
139   setOperationAction(ISD::BR_CC, MVT::i32, Expand);
140   setOperationAction(ISD::BR_CC, MVT::i64, Expand);
141   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Expand);
142   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Expand);
143   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Expand);
144   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8, Expand);
145   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
146
147   if (nvptxSubtarget.hasROT64()) {
148     setOperationAction(ISD::ROTL, MVT::i64, Legal);
149     setOperationAction(ISD::ROTR, MVT::i64, Legal);
150   } else {
151     setOperationAction(ISD::ROTL, MVT::i64, Expand);
152     setOperationAction(ISD::ROTR, MVT::i64, Expand);
153   }
154   if (nvptxSubtarget.hasROT32()) {
155     setOperationAction(ISD::ROTL, MVT::i32, Legal);
156     setOperationAction(ISD::ROTR, MVT::i32, Legal);
157   } else {
158     setOperationAction(ISD::ROTL, MVT::i32, Expand);
159     setOperationAction(ISD::ROTR, MVT::i32, Expand);
160   }
161
162   setOperationAction(ISD::ROTL, MVT::i16, Expand);
163   setOperationAction(ISD::ROTR, MVT::i16, Expand);
164   setOperationAction(ISD::ROTL, MVT::i8, Expand);
165   setOperationAction(ISD::ROTR, MVT::i8, Expand);
166   setOperationAction(ISD::BSWAP, MVT::i16, Expand);
167   setOperationAction(ISD::BSWAP, MVT::i32, Expand);
168   setOperationAction(ISD::BSWAP, MVT::i64, Expand);
169
170   // Indirect branch is not supported.
171   // This also disables Jump Table creation.
172   setOperationAction(ISD::BR_JT, MVT::Other, Expand);
173   setOperationAction(ISD::BRIND, MVT::Other, Expand);
174
175   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
176   setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
177
178   // We want to legalize constant related memmove and memcopy
179   // intrinsics.
180   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
181
182   // Turn FP extload into load/fextend
183   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
184   // Turn FP truncstore into trunc + store.
185   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
186
187   // PTX does not support load / store predicate registers
188   setOperationAction(ISD::LOAD, MVT::i1, Custom);
189   setOperationAction(ISD::STORE, MVT::i1, Custom);
190
191   setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
192   setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
193   setTruncStoreAction(MVT::i64, MVT::i1, Expand);
194   setTruncStoreAction(MVT::i32, MVT::i1, Expand);
195   setTruncStoreAction(MVT::i16, MVT::i1, Expand);
196   setTruncStoreAction(MVT::i8, MVT::i1, Expand);
197
198   // This is legal in NVPTX
199   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
200   setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
201
202   // TRAP can be lowered to PTX trap
203   setOperationAction(ISD::TRAP, MVT::Other, Legal);
204
205   // Register custom handling for vector loads/stores
206   for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE;
207        ++i) {
208     MVT VT = (MVT::SimpleValueType) i;
209     if (IsPTXVectorType(VT)) {
210       setOperationAction(ISD::LOAD, VT, Custom);
211       setOperationAction(ISD::STORE, VT, Custom);
212       setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
213     }
214   }
215
216   // Custom handling for i8 intrinsics
217   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
218
219   // Now deduce the information based on the above mentioned
220   // actions
221   computeRegisterProperties();
222 }
223
224 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
225   switch (Opcode) {
226   default:
227     return 0;
228   case NVPTXISD::CALL:
229     return "NVPTXISD::CALL";
230   case NVPTXISD::RET_FLAG:
231     return "NVPTXISD::RET_FLAG";
232   case NVPTXISD::Wrapper:
233     return "NVPTXISD::Wrapper";
234   case NVPTXISD::NVBuiltin:
235     return "NVPTXISD::NVBuiltin";
236   case NVPTXISD::DeclareParam:
237     return "NVPTXISD::DeclareParam";
238   case NVPTXISD::DeclareScalarParam:
239     return "NVPTXISD::DeclareScalarParam";
240   case NVPTXISD::DeclareRet:
241     return "NVPTXISD::DeclareRet";
242   case NVPTXISD::DeclareRetParam:
243     return "NVPTXISD::DeclareRetParam";
244   case NVPTXISD::PrintCall:
245     return "NVPTXISD::PrintCall";
246   case NVPTXISD::LoadParam:
247     return "NVPTXISD::LoadParam";
248   case NVPTXISD::LoadParamV2:
249     return "NVPTXISD::LoadParamV2";
250   case NVPTXISD::LoadParamV4:
251     return "NVPTXISD::LoadParamV4";
252   case NVPTXISD::StoreParam:
253     return "NVPTXISD::StoreParam";
254   case NVPTXISD::StoreParamV2:
255     return "NVPTXISD::StoreParamV2";
256   case NVPTXISD::StoreParamV4:
257     return "NVPTXISD::StoreParamV4";
258   case NVPTXISD::StoreParamS32:
259     return "NVPTXISD::StoreParamS32";
260   case NVPTXISD::StoreParamU32:
261     return "NVPTXISD::StoreParamU32";
262   case NVPTXISD::CallArgBegin:
263     return "NVPTXISD::CallArgBegin";
264   case NVPTXISD::CallArg:
265     return "NVPTXISD::CallArg";
266   case NVPTXISD::LastCallArg:
267     return "NVPTXISD::LastCallArg";
268   case NVPTXISD::CallArgEnd:
269     return "NVPTXISD::CallArgEnd";
270   case NVPTXISD::CallVoid:
271     return "NVPTXISD::CallVoid";
272   case NVPTXISD::CallVal:
273     return "NVPTXISD::CallVal";
274   case NVPTXISD::CallSymbol:
275     return "NVPTXISD::CallSymbol";
276   case NVPTXISD::Prototype:
277     return "NVPTXISD::Prototype";
278   case NVPTXISD::MoveParam:
279     return "NVPTXISD::MoveParam";
280   case NVPTXISD::StoreRetval:
281     return "NVPTXISD::StoreRetval";
282   case NVPTXISD::StoreRetvalV2:
283     return "NVPTXISD::StoreRetvalV2";
284   case NVPTXISD::StoreRetvalV4:
285     return "NVPTXISD::StoreRetvalV4";
286   case NVPTXISD::PseudoUseParam:
287     return "NVPTXISD::PseudoUseParam";
288   case NVPTXISD::RETURN:
289     return "NVPTXISD::RETURN";
290   case NVPTXISD::CallSeqBegin:
291     return "NVPTXISD::CallSeqBegin";
292   case NVPTXISD::CallSeqEnd:
293     return "NVPTXISD::CallSeqEnd";
294   case NVPTXISD::LoadV2:
295     return "NVPTXISD::LoadV2";
296   case NVPTXISD::LoadV4:
297     return "NVPTXISD::LoadV4";
298   case NVPTXISD::LDGV2:
299     return "NVPTXISD::LDGV2";
300   case NVPTXISD::LDGV4:
301     return "NVPTXISD::LDGV4";
302   case NVPTXISD::LDUV2:
303     return "NVPTXISD::LDUV2";
304   case NVPTXISD::LDUV4:
305     return "NVPTXISD::LDUV4";
306   case NVPTXISD::StoreV2:
307     return "NVPTXISD::StoreV2";
308   case NVPTXISD::StoreV4:
309     return "NVPTXISD::StoreV4";
310   }
311 }
312
313 bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const {
314   return VT == MVT::i1;
315 }
316
317 SDValue
318 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
319   SDLoc dl(Op);
320   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
321   Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
322   return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
323 }
324
325 /*
326 std::string NVPTXTargetLowering::getPrototype(
327     Type *retTy, const ArgListTy &Args,
328     const SmallVectorImpl<ISD::OutputArg> &Outs, unsigned retAlignment) const {
329
330   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
331
332   std::stringstream O;
333   O << "prototype_" << uniqueCallSite << " : .callprototype ";
334
335   if (retTy->getTypeID() == Type::VoidTyID)
336     O << "()";
337   else {
338     O << "(";
339     if (isABI) {
340       if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
341         unsigned size = 0;
342         if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
343           size = ITy->getBitWidth();
344           if (size < 32)
345             size = 32;
346         } else {
347           assert(retTy->isFloatingPointTy() &&
348                  "Floating point type expected here");
349           size = retTy->getPrimitiveSizeInBits();
350         }
351
352         O << ".param .b" << size << " _";
353       } else if (isa<PointerType>(retTy))
354         O << ".param .b" << getPointerTy().getSizeInBits() << " _";
355       else {
356         if ((retTy->getTypeID() == Type::StructTyID) ||
357             isa<VectorType>(retTy)) {
358           SmallVector<EVT, 16> vtparts;
359           ComputeValueVTs(*this, retTy, vtparts);
360           unsigned totalsz = 0;
361           for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
362             unsigned elems = 1;
363             EVT elemtype = vtparts[i];
364             if (vtparts[i].isVector()) {
365               elems = vtparts[i].getVectorNumElements();
366               elemtype = vtparts[i].getVectorElementType();
367             }
368             for (unsigned j = 0, je = elems; j != je; ++j) {
369               unsigned sz = elemtype.getSizeInBits();
370               if (elemtype.isInteger() && (sz < 8))
371                 sz = 8;
372               totalsz += sz / 8;
373             }
374           }
375           O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
376         } else {
377           assert(false && "Unknown return type");
378         }
379       }
380     } else {
381       SmallVector<EVT, 16> vtparts;
382       ComputeValueVTs(*this, retTy, vtparts);
383       unsigned idx = 0;
384       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
385         unsigned elems = 1;
386         EVT elemtype = vtparts[i];
387         if (vtparts[i].isVector()) {
388           elems = vtparts[i].getVectorNumElements();
389           elemtype = vtparts[i].getVectorElementType();
390         }
391
392         for (unsigned j = 0, je = elems; j != je; ++j) {
393           unsigned sz = elemtype.getSizeInBits();
394           if (elemtype.isInteger() && (sz < 32))
395             sz = 32;
396           O << ".reg .b" << sz << " _";
397           if (j < je - 1)
398             O << ", ";
399           ++idx;
400         }
401         if (i < e - 1)
402           O << ", ";
403       }
404     }
405     O << ") ";
406   }
407   O << "_ (";
408
409   bool first = true;
410   MVT thePointerTy = getPointerTy();
411
412   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
413     const Type *Ty = Args[i].Ty;
414     if (!first) {
415       O << ", ";
416     }
417     first = false;
418
419     if (Outs[i].Flags.isByVal() == false) {
420       unsigned sz = 0;
421       if (isa<IntegerType>(Ty)) {
422         sz = cast<IntegerType>(Ty)->getBitWidth();
423         if (sz < 32)
424           sz = 32;
425       } else if (isa<PointerType>(Ty))
426         sz = thePointerTy.getSizeInBits();
427       else
428         sz = Ty->getPrimitiveSizeInBits();
429       if (isABI)
430         O << ".param .b" << sz << " ";
431       else
432         O << ".reg .b" << sz << " ";
433       O << "_";
434       continue;
435     }
436     const PointerType *PTy = dyn_cast<PointerType>(Ty);
437     assert(PTy && "Param with byval attribute should be a pointer type");
438     Type *ETy = PTy->getElementType();
439
440     if (isABI) {
441       unsigned align = Outs[i].Flags.getByValAlign();
442       unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
443       O << ".param .align " << align << " .b8 ";
444       O << "_";
445       O << "[" << sz << "]";
446       continue;
447     } else {
448       SmallVector<EVT, 16> vtparts;
449       ComputeValueVTs(*this, ETy, vtparts);
450       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
451         unsigned elems = 1;
452         EVT elemtype = vtparts[i];
453         if (vtparts[i].isVector()) {
454           elems = vtparts[i].getVectorNumElements();
455           elemtype = vtparts[i].getVectorElementType();
456         }
457
458         for (unsigned j = 0, je = elems; j != je; ++j) {
459           unsigned sz = elemtype.getSizeInBits();
460           if (elemtype.isInteger() && (sz < 32))
461             sz = 32;
462           O << ".reg .b" << sz << " ";
463           O << "_";
464           if (j < je - 1)
465             O << ", ";
466         }
467         if (i < e - 1)
468           O << ", ";
469       }
470       continue;
471     }
472   }
473   O << ");";
474   return O.str();
475 }*/
476
477 std::string
478 NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
479                                   const SmallVectorImpl<ISD::OutputArg> &Outs,
480                                   unsigned retAlignment,
481                                   const ImmutableCallSite *CS) const {
482
483   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
484   assert(isABI && "Non-ABI compilation is not supported");
485   if (!isABI)
486     return "";
487
488   std::stringstream O;
489   O << "prototype_" << uniqueCallSite << " : .callprototype ";
490
491   if (retTy->getTypeID() == Type::VoidTyID) {
492     O << "()";
493   } else {
494     O << "(";
495     if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
496       unsigned size = 0;
497       if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
498         size = ITy->getBitWidth();
499         if (size < 32)
500           size = 32;
501       } else {
502         assert(retTy->isFloatingPointTy() &&
503                "Floating point type expected here");
504         size = retTy->getPrimitiveSizeInBits();
505       }
506
507       O << ".param .b" << size << " _";
508     } else if (isa<PointerType>(retTy)) {
509       O << ".param .b" << getPointerTy().getSizeInBits() << " _";
510     } else {
511       if ((retTy->getTypeID() == Type::StructTyID) || isa<VectorType>(retTy)) {
512         SmallVector<EVT, 16> vtparts;
513         ComputeValueVTs(*this, retTy, vtparts);
514         unsigned totalsz = 0;
515         for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
516           unsigned elems = 1;
517           EVT elemtype = vtparts[i];
518           if (vtparts[i].isVector()) {
519             elems = vtparts[i].getVectorNumElements();
520             elemtype = vtparts[i].getVectorElementType();
521           }
522           // TODO: no need to loop
523           for (unsigned j = 0, je = elems; j != je; ++j) {
524             unsigned sz = elemtype.getSizeInBits();
525             if (elemtype.isInteger() && (sz < 8))
526               sz = 8;
527             totalsz += sz / 8;
528           }
529         }
530         O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
531       } else {
532         assert(false && "Unknown return type");
533       }
534     }
535     O << ") ";
536   }
537   O << "_ (";
538
539   bool first = true;
540   MVT thePointerTy = getPointerTy();
541
542   unsigned OIdx = 0;
543   for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
544     Type *Ty = Args[i].Ty;
545     if (!first) {
546       O << ", ";
547     }
548     first = false;
549
550     if (Outs[OIdx].Flags.isByVal() == false) {
551       if (Ty->isAggregateType() || Ty->isVectorTy()) {
552         unsigned align = 0;
553         const CallInst *CallI = cast<CallInst>(CS->getInstruction());
554         const DataLayout *TD = getDataLayout();
555         // +1 because index 0 is reserved for return type alignment
556         if (!llvm::getAlign(*CallI, i + 1, align))
557           align = TD->getABITypeAlignment(Ty);
558         unsigned sz = TD->getTypeAllocSize(Ty);
559         O << ".param .align " << align << " .b8 ";
560         O << "_";
561         O << "[" << sz << "]";
562         // update the index for Outs
563         SmallVector<EVT, 16> vtparts;
564         ComputeValueVTs(*this, Ty, vtparts);
565         if (unsigned len = vtparts.size())
566           OIdx += len - 1;
567         continue;
568       }
569       assert(getValueType(Ty) == Outs[OIdx].VT &&
570              "type mismatch between callee prototype and arguments");
571       // scalar type
572       unsigned sz = 0;
573       if (isa<IntegerType>(Ty)) {
574         sz = cast<IntegerType>(Ty)->getBitWidth();
575         if (sz < 32)
576           sz = 32;
577       } else if (isa<PointerType>(Ty))
578         sz = thePointerTy.getSizeInBits();
579       else
580         sz = Ty->getPrimitiveSizeInBits();
581       O << ".param .b" << sz << " ";
582       O << "_";
583       continue;
584     }
585     const PointerType *PTy = dyn_cast<PointerType>(Ty);
586     assert(PTy && "Param with byval attribute should be a pointer type");
587     Type *ETy = PTy->getElementType();
588
589     unsigned align = Outs[OIdx].Flags.getByValAlign();
590     unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
591     O << ".param .align " << align << " .b8 ";
592     O << "_";
593     O << "[" << sz << "]";
594   }
595   O << ");";
596   return O.str();
597 }
598
599 unsigned
600 NVPTXTargetLowering::getArgumentAlignment(SDValue Callee,
601                                           const ImmutableCallSite *CS,
602                                           Type *Ty,
603                                           unsigned Idx) const {
604   const DataLayout *TD = getDataLayout();
605   unsigned align = 0;
606   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
607
608   if (Func) { // direct call
609     assert(CS->getCalledFunction() &&
610            "direct call cannot find callee");
611     if (!llvm::getAlign(*(CS->getCalledFunction()), Idx, align))
612       align = TD->getABITypeAlignment(Ty);
613   }
614   else { // indirect call
615     const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction());
616     if (!llvm::getAlign(*CallI, Idx, align))
617       align = TD->getABITypeAlignment(Ty);
618   }
619
620   return align;
621 }
622
623 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
624                                        SmallVectorImpl<SDValue> &InVals) const {
625   SelectionDAG &DAG = CLI.DAG;
626   SDLoc dl = CLI.DL;
627   SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
628   SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
629   SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
630   SDValue Chain = CLI.Chain;
631   SDValue Callee = CLI.Callee;
632   bool &isTailCall = CLI.IsTailCall;
633   ArgListTy &Args = CLI.Args;
634   Type *retTy = CLI.RetTy;
635   ImmutableCallSite *CS = CLI.CS;
636
637   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
638   assert(isABI && "Non-ABI compilation is not supported");
639   if (!isABI)
640     return Chain;
641   const DataLayout *TD = getDataLayout();
642   MachineFunction &MF = DAG.getMachineFunction();
643   const Function *F = MF.getFunction();
644   const TargetLowering *TLI = nvTM->getTargetLowering();
645
646   SDValue tempChain = Chain;
647   Chain =
648       DAG.getCALLSEQ_START(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
649                            dl);
650   SDValue InFlag = Chain.getValue(1);
651
652   unsigned paramCount = 0;
653   // Args.size() and Outs.size() need not match.
654   // Outs.size() will be larger
655   //   * if there is an aggregate argument with multiple fields (each field
656   //     showing up separately in Outs)
657   //   * if there is a vector argument with more than typical vector-length
658   //     elements (generally if more than 4) where each vector element is
659   //     individually present in Outs.
660   // So a different index should be used for indexing into Outs/OutVals.
661   // See similar issue in LowerFormalArguments.
662   unsigned OIdx = 0;
663   // Declare the .params or .reg need to pass values
664   // to the function
665   for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
666     EVT VT = Outs[OIdx].VT;
667     Type *Ty = Args[i].Ty;
668
669     if (Outs[OIdx].Flags.isByVal() == false) {
670       if (Ty->isAggregateType()) {
671         // aggregate
672         SmallVector<EVT, 16> vtparts;
673         ComputeValueVTs(*this, Ty, vtparts);
674
675         unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
676         // declare .param .align <align> .b8 .param<n>[<size>];
677         unsigned sz = TD->getTypeAllocSize(Ty);
678         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
679         SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
680                                       DAG.getConstant(paramCount, MVT::i32),
681                                       DAG.getConstant(sz, MVT::i32), InFlag };
682         Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
683                             DeclareParamOps, 5);
684         InFlag = Chain.getValue(1);
685         unsigned curOffset = 0;
686         for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
687           unsigned elems = 1;
688           EVT elemtype = vtparts[j];
689           if (vtparts[j].isVector()) {
690             elems = vtparts[j].getVectorNumElements();
691             elemtype = vtparts[j].getVectorElementType();
692           }
693           for (unsigned k = 0, ke = elems; k != ke; ++k) {
694             unsigned sz = elemtype.getSizeInBits();
695             if (elemtype.isInteger() && (sz < 8))
696               sz = 8;
697             SDValue StVal = OutVals[OIdx];
698             if (elemtype.getSizeInBits() < 16) {
699               StVal = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::i16, StVal);
700             }
701             SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
702             SDValue CopyParamOps[] = { Chain,
703                                        DAG.getConstant(paramCount, MVT::i32),
704                                        DAG.getConstant(curOffset, MVT::i32),
705                                        StVal, InFlag };
706             Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
707                                             CopyParamVTs, &CopyParamOps[0], 5,
708                                             elemtype, MachinePointerInfo());
709             InFlag = Chain.getValue(1);
710             curOffset += sz / 8;
711             ++OIdx;
712           }
713         }
714         if (vtparts.size() > 0)
715           --OIdx;
716         ++paramCount;
717         continue;
718       }
719       if (Ty->isVectorTy()) {
720         EVT ObjectVT = getValueType(Ty);
721         unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
722         // declare .param .align <align> .b8 .param<n>[<size>];
723         unsigned sz = TD->getTypeAllocSize(Ty);
724         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
725         SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
726                                       DAG.getConstant(paramCount, MVT::i32),
727                                       DAG.getConstant(sz, MVT::i32), InFlag };
728         Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
729                             DeclareParamOps, 5);
730         InFlag = Chain.getValue(1);
731         unsigned NumElts = ObjectVT.getVectorNumElements();
732         EVT EltVT = ObjectVT.getVectorElementType();
733         EVT MemVT = EltVT;
734         bool NeedExtend = false;
735         if (EltVT.getSizeInBits() < 16) {
736           NeedExtend = true;
737           EltVT = MVT::i16;
738         }
739
740         // V1 store
741         if (NumElts == 1) {
742           SDValue Elt = OutVals[OIdx++];
743           if (NeedExtend)
744             Elt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt);
745
746           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
747           SDValue CopyParamOps[] = { Chain,
748                                      DAG.getConstant(paramCount, MVT::i32),
749                                      DAG.getConstant(0, MVT::i32), Elt,
750                                      InFlag };
751           Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
752                                           CopyParamVTs, &CopyParamOps[0], 5,
753                                           MemVT, MachinePointerInfo());
754           InFlag = Chain.getValue(1);
755         } else if (NumElts == 2) {
756           SDValue Elt0 = OutVals[OIdx++];
757           SDValue Elt1 = OutVals[OIdx++];
758           if (NeedExtend) {
759             Elt0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt0);
760             Elt1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt1);
761           }
762
763           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
764           SDValue CopyParamOps[] = { Chain,
765                                      DAG.getConstant(paramCount, MVT::i32),
766                                      DAG.getConstant(0, MVT::i32), Elt0, Elt1,
767                                      InFlag };
768           Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParamV2, dl,
769                                           CopyParamVTs, &CopyParamOps[0], 6,
770                                           MemVT, MachinePointerInfo());
771           InFlag = Chain.getValue(1);
772         } else {
773           unsigned curOffset = 0;
774           // V4 stores
775           // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
776           // the
777           // vector will be expanded to a power of 2 elements, so we know we can
778           // always round up to the next multiple of 4 when creating the vector
779           // stores.
780           // e.g.  4 elem => 1 st.v4
781           //       6 elem => 2 st.v4
782           //       8 elem => 2 st.v4
783           //      11 elem => 3 st.v4
784           unsigned VecSize = 4;
785           if (EltVT.getSizeInBits() == 64)
786             VecSize = 2;
787
788           // This is potentially only part of a vector, so assume all elements
789           // are packed together.
790           unsigned PerStoreOffset = MemVT.getStoreSizeInBits() / 8 * VecSize;
791
792           for (unsigned i = 0; i < NumElts; i += VecSize) {
793             // Get values
794             SDValue StoreVal;
795             SmallVector<SDValue, 8> Ops;
796             Ops.push_back(Chain);
797             Ops.push_back(DAG.getConstant(paramCount, MVT::i32));
798             Ops.push_back(DAG.getConstant(curOffset, MVT::i32));
799
800             unsigned Opc = NVPTXISD::StoreParamV2;
801
802             StoreVal = OutVals[OIdx++];
803             if (NeedExtend)
804               StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
805             Ops.push_back(StoreVal);
806
807             if (i + 1 < NumElts) {
808               StoreVal = OutVals[OIdx++];
809               if (NeedExtend)
810                 StoreVal =
811                     DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
812             } else {
813               StoreVal = DAG.getUNDEF(EltVT);
814             }
815             Ops.push_back(StoreVal);
816
817             if (VecSize == 4) {
818               Opc = NVPTXISD::StoreParamV4;
819               if (i + 2 < NumElts) {
820                 StoreVal = OutVals[OIdx++];
821                 if (NeedExtend)
822                   StoreVal =
823                       DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
824               } else {
825                 StoreVal = DAG.getUNDEF(EltVT);
826               }
827               Ops.push_back(StoreVal);
828
829               if (i + 3 < NumElts) {
830                 StoreVal = OutVals[OIdx++];
831                 if (NeedExtend)
832                   StoreVal =
833                       DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
834               } else {
835                 StoreVal = DAG.getUNDEF(EltVT);
836               }
837               Ops.push_back(StoreVal);
838             }
839
840             SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
841             Chain = DAG.getMemIntrinsicNode(Opc, dl, CopyParamVTs, &Ops[0],
842                                             Ops.size(), MemVT,
843                                             MachinePointerInfo());
844             InFlag = Chain.getValue(1);
845             curOffset += PerStoreOffset;
846           }
847         }
848         ++paramCount;
849         --OIdx;
850         continue;
851       }
852       // Plain scalar
853       // for ABI,    declare .param .b<size> .param<n>;
854       unsigned sz = VT.getSizeInBits();
855       bool needExtend = false;
856       if (VT.isInteger()) {
857         if (sz < 16)
858           needExtend = true;
859         if (sz < 32)
860           sz = 32;
861       }
862       SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
863       SDValue DeclareParamOps[] = { Chain,
864                                     DAG.getConstant(paramCount, MVT::i32),
865                                     DAG.getConstant(sz, MVT::i32),
866                                     DAG.getConstant(0, MVT::i32), InFlag };
867       Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
868                           DeclareParamOps, 5);
869       InFlag = Chain.getValue(1);
870       SDValue OutV = OutVals[OIdx];
871       if (needExtend) {
872         // zext/sext i1 to i16
873         unsigned opc = ISD::ZERO_EXTEND;
874         if (Outs[OIdx].Flags.isSExt())
875           opc = ISD::SIGN_EXTEND;
876         OutV = DAG.getNode(opc, dl, MVT::i16, OutV);
877       }
878       SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
879       SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
880                                  DAG.getConstant(0, MVT::i32), OutV, InFlag };
881
882       unsigned opcode = NVPTXISD::StoreParam;
883       if (Outs[OIdx].Flags.isZExt())
884         opcode = NVPTXISD::StoreParamU32;
885       else if (Outs[OIdx].Flags.isSExt())
886         opcode = NVPTXISD::StoreParamS32;
887       Chain = DAG.getMemIntrinsicNode(opcode, dl, CopyParamVTs, CopyParamOps, 5,
888                                       VT, MachinePointerInfo());
889
890       InFlag = Chain.getValue(1);
891       ++paramCount;
892       continue;
893     }
894     // struct or vector
895     SmallVector<EVT, 16> vtparts;
896     const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
897     assert(PTy && "Type of a byval parameter should be pointer");
898     ComputeValueVTs(*this, PTy->getElementType(), vtparts);
899
900     // declare .param .align <align> .b8 .param<n>[<size>];
901     unsigned sz = Outs[OIdx].Flags.getByValSize();
902     SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
903     // The ByValAlign in the Outs[OIdx].Flags is alway set at this point,
904     // so we don't need to worry about natural alignment or not.
905     // See TargetLowering::LowerCallTo().
906     SDValue DeclareParamOps[] = {
907       Chain, DAG.getConstant(Outs[OIdx].Flags.getByValAlign(), MVT::i32),
908       DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32),
909       InFlag
910     };
911     Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
912                         DeclareParamOps, 5);
913     InFlag = Chain.getValue(1);
914     unsigned curOffset = 0;
915     for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
916       unsigned elems = 1;
917       EVT elemtype = vtparts[j];
918       if (vtparts[j].isVector()) {
919         elems = vtparts[j].getVectorNumElements();
920         elemtype = vtparts[j].getVectorElementType();
921       }
922       for (unsigned k = 0, ke = elems; k != ke; ++k) {
923         unsigned sz = elemtype.getSizeInBits();
924         if (elemtype.isInteger() && (sz < 8))
925           sz = 8;
926         SDValue srcAddr =
927             DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
928                         DAG.getConstant(curOffset, getPointerTy()));
929         SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
930                                      MachinePointerInfo(), false, false, false,
931                                      0);
932         if (elemtype.getSizeInBits() < 16) {
933           theVal = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::i16, theVal);
934         }
935         SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
936         SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
937                                    DAG.getConstant(curOffset, MVT::i32), theVal,
938                                    InFlag };
939         Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
940                                         CopyParamOps, 5, elemtype,
941                                         MachinePointerInfo());
942
943         InFlag = Chain.getValue(1);
944         curOffset += sz / 8;
945       }
946     }
947     ++paramCount;
948   }
949
950   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
951   unsigned retAlignment = 0;
952
953   // Handle Result
954   if (Ins.size() > 0) {
955     SmallVector<EVT, 16> resvtparts;
956     ComputeValueVTs(*this, retTy, resvtparts);
957
958     // Declare
959     //  .param .align 16 .b8 retval0[<size-in-bytes>], or
960     //  .param .b<size-in-bits> retval0
961     unsigned resultsz = TD->getTypeAllocSizeInBits(retTy);
962     if (retTy->isPrimitiveType() || retTy->isIntegerTy() ||
963         retTy->isPointerTy()) {
964       // Scalar needs to be at least 32bit wide
965       if (resultsz < 32)
966         resultsz = 32;
967       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
968       SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
969                                   DAG.getConstant(resultsz, MVT::i32),
970                                   DAG.getConstant(0, MVT::i32), InFlag };
971       Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
972                           DeclareRetOps, 5);
973       InFlag = Chain.getValue(1);
974     } else {
975       retAlignment = getArgumentAlignment(Callee, CS, retTy, 0);
976       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
977       SDValue DeclareRetOps[] = { Chain,
978                                   DAG.getConstant(retAlignment, MVT::i32),
979                                   DAG.getConstant(resultsz / 8, MVT::i32),
980                                   DAG.getConstant(0, MVT::i32), InFlag };
981       Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
982                           DeclareRetOps, 5);
983       InFlag = Chain.getValue(1);
984     }
985   }
986
987   if (!Func) {
988     // This is indirect function call case : PTX requires a prototype of the
989     // form
990     // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
991     // to be emitted, and the label has to used as the last arg of call
992     // instruction.
993     // The prototype is embedded in a string and put as the operand for an
994     // INLINEASM SDNode.
995     SDVTList InlineAsmVTs = DAG.getVTList(MVT::Other, MVT::Glue);
996     std::string proto_string =
997         getPrototype(retTy, Args, Outs, retAlignment, CS);
998     const char *asmstr = nvTM->getManagedStrPool()
999         ->getManagedString(proto_string.c_str())->c_str();
1000     SDValue InlineAsmOps[] = {
1001       Chain, DAG.getTargetExternalSymbol(asmstr, getPointerTy()),
1002       DAG.getMDNode(0), DAG.getTargetConstant(0, MVT::i32), InFlag
1003     };
1004     Chain = DAG.getNode(ISD::INLINEASM, dl, InlineAsmVTs, InlineAsmOps, 5);
1005     InFlag = Chain.getValue(1);
1006   }
1007   // Op to just print "call"
1008   SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1009   SDValue PrintCallOps[] = {
1010     Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, MVT::i32), InFlag
1011   };
1012   Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
1013                       dl, PrintCallVTs, PrintCallOps, 3);
1014   InFlag = Chain.getValue(1);
1015
1016   // Ops to print out the function name
1017   SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1018   SDValue CallVoidOps[] = { Chain, Callee, InFlag };
1019   Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3);
1020   InFlag = Chain.getValue(1);
1021
1022   // Ops to print out the param list
1023   SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1024   SDValue CallArgBeginOps[] = { Chain, InFlag };
1025   Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
1026                       CallArgBeginOps, 2);
1027   InFlag = Chain.getValue(1);
1028
1029   for (unsigned i = 0, e = paramCount; i != e; ++i) {
1030     unsigned opcode;
1031     if (i == (e - 1))
1032       opcode = NVPTXISD::LastCallArg;
1033     else
1034       opcode = NVPTXISD::CallArg;
1035     SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1036     SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
1037                              DAG.getConstant(i, MVT::i32), InFlag };
1038     Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4);
1039     InFlag = Chain.getValue(1);
1040   }
1041   SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1042   SDValue CallArgEndOps[] = { Chain, DAG.getConstant(Func ? 1 : 0, MVT::i32),
1043                               InFlag };
1044   Chain =
1045       DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps, 3);
1046   InFlag = Chain.getValue(1);
1047
1048   if (!Func) {
1049     SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1050     SDValue PrototypeOps[] = { Chain, DAG.getConstant(uniqueCallSite, MVT::i32),
1051                                InFlag };
1052     Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3);
1053     InFlag = Chain.getValue(1);
1054   }
1055
1056   // Generate loads from param memory/moves from registers for result
1057   if (Ins.size() > 0) {
1058     unsigned resoffset = 0;
1059     if (retTy && retTy->isVectorTy()) {
1060       EVT ObjectVT = getValueType(retTy);
1061       unsigned NumElts = ObjectVT.getVectorNumElements();
1062       EVT EltVT = ObjectVT.getVectorElementType();
1063       assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
1064              "Vector was not scalarized");
1065       unsigned sz = EltVT.getSizeInBits();
1066       bool needTruncate = sz < 16 ? true : false;
1067
1068       if (NumElts == 1) {
1069         // Just a simple load
1070         std::vector<EVT> LoadRetVTs;
1071         if (needTruncate) {
1072           // If loading i1 result, generate
1073           //   load i16
1074           //   trunc i16 to i1
1075           LoadRetVTs.push_back(MVT::i16);
1076         } else
1077           LoadRetVTs.push_back(EltVT);
1078         LoadRetVTs.push_back(MVT::Other);
1079         LoadRetVTs.push_back(MVT::Glue);
1080         std::vector<SDValue> LoadRetOps;
1081         LoadRetOps.push_back(Chain);
1082         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1083         LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
1084         LoadRetOps.push_back(InFlag);
1085         SDValue retval = DAG.getMemIntrinsicNode(
1086             NVPTXISD::LoadParam, dl,
1087             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
1088             LoadRetOps.size(), EltVT, MachinePointerInfo());
1089         Chain = retval.getValue(1);
1090         InFlag = retval.getValue(2);
1091         SDValue Ret0 = retval;
1092         if (needTruncate)
1093           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Ret0);
1094         InVals.push_back(Ret0);
1095       } else if (NumElts == 2) {
1096         // LoadV2
1097         std::vector<EVT> LoadRetVTs;
1098         if (needTruncate) {
1099           // If loading i1 result, generate
1100           //   load i16
1101           //   trunc i16 to i1
1102           LoadRetVTs.push_back(MVT::i16);
1103           LoadRetVTs.push_back(MVT::i16);
1104         } else {
1105           LoadRetVTs.push_back(EltVT);
1106           LoadRetVTs.push_back(EltVT);
1107         }
1108         LoadRetVTs.push_back(MVT::Other);
1109         LoadRetVTs.push_back(MVT::Glue);
1110         std::vector<SDValue> LoadRetOps;
1111         LoadRetOps.push_back(Chain);
1112         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1113         LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
1114         LoadRetOps.push_back(InFlag);
1115         SDValue retval = DAG.getMemIntrinsicNode(
1116             NVPTXISD::LoadParamV2, dl,
1117             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
1118             LoadRetOps.size(), EltVT, MachinePointerInfo());
1119         Chain = retval.getValue(2);
1120         InFlag = retval.getValue(3);
1121         SDValue Ret0 = retval.getValue(0);
1122         SDValue Ret1 = retval.getValue(1);
1123         if (needTruncate) {
1124           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret0);
1125           InVals.push_back(Ret0);
1126           Ret1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret1);
1127           InVals.push_back(Ret1);
1128         } else {
1129           InVals.push_back(Ret0);
1130           InVals.push_back(Ret1);
1131         }
1132       } else {
1133         // Split into N LoadV4
1134         unsigned Ofst = 0;
1135         unsigned VecSize = 4;
1136         unsigned Opc = NVPTXISD::LoadParamV4;
1137         if (EltVT.getSizeInBits() == 64) {
1138           VecSize = 2;
1139           Opc = NVPTXISD::LoadParamV2;
1140         }
1141         EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
1142         for (unsigned i = 0; i < NumElts; i += VecSize) {
1143           SmallVector<EVT, 8> LoadRetVTs;
1144           if (needTruncate) {
1145             // If loading i1 result, generate
1146             //   load i16
1147             //   trunc i16 to i1
1148             for (unsigned j = 0; j < VecSize; ++j)
1149               LoadRetVTs.push_back(MVT::i16);
1150           } else {
1151             for (unsigned j = 0; j < VecSize; ++j)
1152               LoadRetVTs.push_back(EltVT);
1153           }
1154           LoadRetVTs.push_back(MVT::Other);
1155           LoadRetVTs.push_back(MVT::Glue);
1156           SmallVector<SDValue, 4> LoadRetOps;
1157           LoadRetOps.push_back(Chain);
1158           LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1159           LoadRetOps.push_back(DAG.getConstant(Ofst, MVT::i32));
1160           LoadRetOps.push_back(InFlag);
1161           SDValue retval = DAG.getMemIntrinsicNode(
1162               Opc, dl, DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()),
1163               &LoadRetOps[0], LoadRetOps.size(), EltVT, MachinePointerInfo());
1164           if (VecSize == 2) {
1165             Chain = retval.getValue(2);
1166             InFlag = retval.getValue(3);
1167           } else {
1168             Chain = retval.getValue(4);
1169             InFlag = retval.getValue(5);
1170           }
1171
1172           for (unsigned j = 0; j < VecSize; ++j) {
1173             if (i + j >= NumElts)
1174               break;
1175             SDValue Elt = retval.getValue(j);
1176             if (needTruncate)
1177               Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
1178             InVals.push_back(Elt);
1179           }
1180           Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1181         }
1182       }
1183     } else {
1184       SmallVector<EVT, 16> VTs;
1185       ComputePTXValueVTs(*this, retTy, VTs);
1186       assert(VTs.size() == Ins.size() && "Bad value decomposition");
1187       for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
1188         unsigned sz = VTs[i].getSizeInBits();
1189         bool needTruncate = sz < 8 ? true : false;
1190         if (VTs[i].isInteger() && (sz < 8))
1191           sz = 8;
1192
1193         SmallVector<EVT, 4> LoadRetVTs;
1194         if (sz < 16) {
1195           // If loading i1/i8 result, generate
1196           //   load i8 (-> i16)
1197           //   trunc i16 to i1/i8
1198           LoadRetVTs.push_back(MVT::i16);
1199         } else
1200           LoadRetVTs.push_back(Ins[i].VT);
1201         LoadRetVTs.push_back(MVT::Other);
1202         LoadRetVTs.push_back(MVT::Glue);
1203
1204         SmallVector<SDValue, 4> LoadRetOps;
1205         LoadRetOps.push_back(Chain);
1206         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1207         LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32));
1208         LoadRetOps.push_back(InFlag);
1209         SDValue retval = DAG.getMemIntrinsicNode(
1210             NVPTXISD::LoadParam, dl,
1211             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
1212             LoadRetOps.size(), VTs[i], MachinePointerInfo());
1213         Chain = retval.getValue(1);
1214         InFlag = retval.getValue(2);
1215         SDValue Ret0 = retval.getValue(0);
1216         if (needTruncate)
1217           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0);
1218         InVals.push_back(Ret0);
1219         resoffset += sz / 8;
1220       }
1221     }
1222   }
1223
1224   Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
1225                              DAG.getIntPtrConstant(uniqueCallSite + 1, true),
1226                              InFlag, dl);
1227   uniqueCallSite++;
1228
1229   // set isTailCall to false for now, until we figure out how to express
1230   // tail call optimization in PTX
1231   isTailCall = false;
1232   return Chain;
1233 }
1234
1235 // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
1236 // (see LegalizeDAG.cpp). This is slow and uses local memory.
1237 // We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
1238 SDValue
1239 NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
1240   SDNode *Node = Op.getNode();
1241   SDLoc dl(Node);
1242   SmallVector<SDValue, 8> Ops;
1243   unsigned NumOperands = Node->getNumOperands();
1244   for (unsigned i = 0; i < NumOperands; ++i) {
1245     SDValue SubOp = Node->getOperand(i);
1246     EVT VVT = SubOp.getNode()->getValueType(0);
1247     EVT EltVT = VVT.getVectorElementType();
1248     unsigned NumSubElem = VVT.getVectorNumElements();
1249     for (unsigned j = 0; j < NumSubElem; ++j) {
1250       Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
1251                                 DAG.getIntPtrConstant(j)));
1252     }
1253   }
1254   return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0), &Ops[0],
1255                      Ops.size());
1256 }
1257
1258 SDValue
1259 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
1260   switch (Op.getOpcode()) {
1261   case ISD::RETURNADDR:
1262     return SDValue();
1263   case ISD::FRAMEADDR:
1264     return SDValue();
1265   case ISD::GlobalAddress:
1266     return LowerGlobalAddress(Op, DAG);
1267   case ISD::INTRINSIC_W_CHAIN:
1268     return Op;
1269   case ISD::BUILD_VECTOR:
1270   case ISD::EXTRACT_SUBVECTOR:
1271     return Op;
1272   case ISD::CONCAT_VECTORS:
1273     return LowerCONCAT_VECTORS(Op, DAG);
1274   case ISD::STORE:
1275     return LowerSTORE(Op, DAG);
1276   case ISD::LOAD:
1277     return LowerLOAD(Op, DAG);
1278   default:
1279     llvm_unreachable("Custom lowering not defined for operation");
1280   }
1281 }
1282
1283 SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
1284   if (Op.getValueType() == MVT::i1)
1285     return LowerLOADi1(Op, DAG);
1286   else
1287     return SDValue();
1288 }
1289
1290 // v = ld i1* addr
1291 //   =>
1292 // v1 = ld i8* addr (-> i16)
1293 // v = trunc i16 to i1
1294 SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
1295   SDNode *Node = Op.getNode();
1296   LoadSDNode *LD = cast<LoadSDNode>(Node);
1297   SDLoc dl(Node);
1298   assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
1299   assert(Node->getValueType(0) == MVT::i1 &&
1300          "Custom lowering for i1 load only");
1301   SDValue newLD =
1302       DAG.getLoad(MVT::i16, dl, LD->getChain(), LD->getBasePtr(),
1303                   LD->getPointerInfo(), LD->isVolatile(), LD->isNonTemporal(),
1304                   LD->isInvariant(), LD->getAlignment());
1305   SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
1306   // The legalizer (the caller) is expecting two values from the legalized
1307   // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
1308   // in LegalizeDAG.cpp which also uses MergeValues.
1309   SDValue Ops[] = { result, LD->getChain() };
1310   return DAG.getMergeValues(Ops, 2, dl);
1311 }
1312
1313 SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
1314   EVT ValVT = Op.getOperand(1).getValueType();
1315   if (ValVT == MVT::i1)
1316     return LowerSTOREi1(Op, DAG);
1317   else if (ValVT.isVector())
1318     return LowerSTOREVector(Op, DAG);
1319   else
1320     return SDValue();
1321 }
1322
1323 SDValue
1324 NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
1325   SDNode *N = Op.getNode();
1326   SDValue Val = N->getOperand(1);
1327   SDLoc DL(N);
1328   EVT ValVT = Val.getValueType();
1329
1330   if (ValVT.isVector()) {
1331     // We only handle "native" vector sizes for now, e.g. <4 x double> is not
1332     // legal.  We can (and should) split that into 2 stores of <2 x double> here
1333     // but I'm leaving that as a TODO for now.
1334     if (!ValVT.isSimple())
1335       return SDValue();
1336     switch (ValVT.getSimpleVT().SimpleTy) {
1337     default:
1338       return SDValue();
1339     case MVT::v2i8:
1340     case MVT::v2i16:
1341     case MVT::v2i32:
1342     case MVT::v2i64:
1343     case MVT::v2f32:
1344     case MVT::v2f64:
1345     case MVT::v4i8:
1346     case MVT::v4i16:
1347     case MVT::v4i32:
1348     case MVT::v4f32:
1349       // This is a "native" vector type
1350       break;
1351     }
1352
1353     unsigned Opcode = 0;
1354     EVT EltVT = ValVT.getVectorElementType();
1355     unsigned NumElts = ValVT.getVectorNumElements();
1356
1357     // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
1358     // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
1359     // stored type to i16 and propogate the "real" type as the memory type.
1360     bool NeedSExt = false;
1361     if (EltVT.getSizeInBits() < 16)
1362       NeedSExt = true;
1363
1364     switch (NumElts) {
1365     default:
1366       return SDValue();
1367     case 2:
1368       Opcode = NVPTXISD::StoreV2;
1369       break;
1370     case 4: {
1371       Opcode = NVPTXISD::StoreV4;
1372       break;
1373     }
1374     }
1375
1376     SmallVector<SDValue, 8> Ops;
1377
1378     // First is the chain
1379     Ops.push_back(N->getOperand(0));
1380
1381     // Then the split values
1382     for (unsigned i = 0; i < NumElts; ++i) {
1383       SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
1384                                    DAG.getIntPtrConstant(i));
1385       if (NeedSExt)
1386         ExtVal = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i16, ExtVal);
1387       Ops.push_back(ExtVal);
1388     }
1389
1390     // Then any remaining arguments
1391     for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) {
1392       Ops.push_back(N->getOperand(i));
1393     }
1394
1395     MemSDNode *MemSD = cast<MemSDNode>(N);
1396
1397     SDValue NewSt = DAG.getMemIntrinsicNode(
1398         Opcode, DL, DAG.getVTList(MVT::Other), &Ops[0], Ops.size(),
1399         MemSD->getMemoryVT(), MemSD->getMemOperand());
1400
1401     //return DCI.CombineTo(N, NewSt, true);
1402     return NewSt;
1403   }
1404
1405   return SDValue();
1406 }
1407
1408 // st i1 v, addr
1409 //    =>
1410 // v1 = zxt v to i16
1411 // st.u8 i16, addr
1412 SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
1413   SDNode *Node = Op.getNode();
1414   SDLoc dl(Node);
1415   StoreSDNode *ST = cast<StoreSDNode>(Node);
1416   SDValue Tmp1 = ST->getChain();
1417   SDValue Tmp2 = ST->getBasePtr();
1418   SDValue Tmp3 = ST->getValue();
1419   assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
1420   unsigned Alignment = ST->getAlignment();
1421   bool isVolatile = ST->isVolatile();
1422   bool isNonTemporal = ST->isNonTemporal();
1423   Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
1424   SDValue Result = DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2,
1425                                      ST->getPointerInfo(), MVT::i8, isNonTemporal,
1426                                      isVolatile, Alignment);
1427   return Result;
1428 }
1429
1430 SDValue NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname,
1431                                         int idx, EVT v) const {
1432   std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
1433   std::stringstream suffix;
1434   suffix << idx;
1435   *name += suffix.str();
1436   return DAG.getTargetExternalSymbol(name->c_str(), v);
1437 }
1438
1439 SDValue
1440 NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
1441   return getExtSymb(DAG, ".PARAM", idx, v);
1442 }
1443
1444 SDValue NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
1445   return getExtSymb(DAG, ".HLPPARAM", idx);
1446 }
1447
1448 // Check to see if the kernel argument is image*_t or sampler_t
1449
1450 bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
1451   static const char *const specialTypes[] = { "struct._image2d_t",
1452                                               "struct._image3d_t",
1453                                               "struct._sampler_t" };
1454
1455   const Type *Ty = arg->getType();
1456   const PointerType *PTy = dyn_cast<PointerType>(Ty);
1457
1458   if (!PTy)
1459     return false;
1460
1461   if (!context)
1462     return false;
1463
1464   const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
1465   const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : "";
1466
1467   for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
1468     if (TypeName == specialTypes[i])
1469       return true;
1470
1471   return false;
1472 }
1473
1474 SDValue NVPTXTargetLowering::LowerFormalArguments(
1475     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
1476     const SmallVectorImpl<ISD::InputArg> &Ins, SDLoc dl, SelectionDAG &DAG,
1477     SmallVectorImpl<SDValue> &InVals) const {
1478   MachineFunction &MF = DAG.getMachineFunction();
1479   const DataLayout *TD = getDataLayout();
1480
1481   const Function *F = MF.getFunction();
1482   const AttributeSet &PAL = F->getAttributes();
1483   const TargetLowering *TLI = nvTM->getTargetLowering();
1484
1485   SDValue Root = DAG.getRoot();
1486   std::vector<SDValue> OutChains;
1487
1488   bool isKernel = llvm::isKernelFunction(*F);
1489   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1490   assert(isABI && "Non-ABI compilation is not supported");
1491   if (!isABI)
1492     return Chain;
1493
1494   std::vector<Type *> argTypes;
1495   std::vector<const Argument *> theArgs;
1496   for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
1497        I != E; ++I) {
1498     theArgs.push_back(I);
1499     argTypes.push_back(I->getType());
1500   }
1501   // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
1502   // Ins.size() will be larger
1503   //   * if there is an aggregate argument with multiple fields (each field
1504   //     showing up separately in Ins)
1505   //   * if there is a vector argument with more than typical vector-length
1506   //     elements (generally if more than 4) where each vector element is
1507   //     individually present in Ins.
1508   // So a different index should be used for indexing into Ins.
1509   // See similar issue in LowerCall.
1510   unsigned InsIdx = 0;
1511
1512   int idx = 0;
1513   for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++idx, ++InsIdx) {
1514     Type *Ty = argTypes[i];
1515
1516     // If the kernel argument is image*_t or sampler_t, convert it to
1517     // a i32 constant holding the parameter position. This can later
1518     // matched in the AsmPrinter to output the correct mangled name.
1519     if (isImageOrSamplerVal(
1520             theArgs[i],
1521             (theArgs[i]->getParent() ? theArgs[i]->getParent()->getParent()
1522                                      : 0))) {
1523       assert(isKernel && "Only kernels can have image/sampler params");
1524       InVals.push_back(DAG.getConstant(i + 1, MVT::i32));
1525       continue;
1526     }
1527
1528     if (theArgs[i]->use_empty()) {
1529       // argument is dead
1530       if (Ty->isAggregateType()) {
1531         SmallVector<EVT, 16> vtparts;
1532
1533         ComputePTXValueVTs(*this, Ty, vtparts);
1534         assert(vtparts.size() > 0 && "empty aggregate type not expected");
1535         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
1536              ++parti) {
1537           EVT partVT = vtparts[parti];
1538           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, partVT));
1539           ++InsIdx;
1540         }
1541         if (vtparts.size() > 0)
1542           --InsIdx;
1543         continue;
1544       }
1545       if (Ty->isVectorTy()) {
1546         EVT ObjectVT = getValueType(Ty);
1547         unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
1548         for (unsigned parti = 0; parti < NumRegs; ++parti) {
1549           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
1550           ++InsIdx;
1551         }
1552         if (NumRegs > 0)
1553           --InsIdx;
1554         continue;
1555       }
1556       InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
1557       continue;
1558     }
1559
1560     // In the following cases, assign a node order of "idx+1"
1561     // to newly created nodes. The SDNodes for params have to
1562     // appear in the same order as their order of appearance
1563     // in the original function. "idx+1" holds that order.
1564     if (PAL.hasAttribute(i + 1, Attribute::ByVal) == false) {
1565       if (Ty->isAggregateType()) {
1566         SmallVector<EVT, 16> vtparts;
1567         SmallVector<uint64_t, 16> offsets;
1568
1569         // NOTE: Here, we lose the ability to issue vector loads for vectors
1570         // that are a part of a struct.  This should be investigated in the
1571         // future.
1572         ComputePTXValueVTs(*this, Ty, vtparts, &offsets, 0);
1573         assert(vtparts.size() > 0 && "empty aggregate type not expected");
1574         bool aggregateIsPacked = false;
1575         if (StructType *STy = llvm::dyn_cast<StructType>(Ty))
1576           aggregateIsPacked = STy->isPacked();
1577
1578         SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1579         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
1580              ++parti) {
1581           EVT partVT = vtparts[parti];
1582           Value *srcValue = Constant::getNullValue(
1583               PointerType::get(partVT.getTypeForEVT(F->getContext()),
1584                                llvm::ADDRESS_SPACE_PARAM));
1585           SDValue srcAddr =
1586               DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1587                           DAG.getConstant(offsets[parti], getPointerTy()));
1588           unsigned partAlign =
1589               aggregateIsPacked ? 1
1590                                 : TD->getABITypeAlignment(
1591                                       partVT.getTypeForEVT(F->getContext()));
1592                     SDValue p;
1593           if (Ins[InsIdx].VT.getSizeInBits() > partVT.getSizeInBits())
1594             p = DAG.getExtLoad(ISD::SEXTLOAD, dl, Ins[InsIdx].VT, Root, srcAddr,
1595                                MachinePointerInfo(srcValue), partVT, false,
1596                                false, partAlign);
1597           else
1598             p = DAG.getLoad(partVT, dl, Root, srcAddr,
1599                             MachinePointerInfo(srcValue), false, false, false,
1600                             partAlign);
1601           if (p.getNode())
1602             p.getNode()->setIROrder(idx + 1);
1603           InVals.push_back(p);
1604           ++InsIdx;
1605         }
1606         if (vtparts.size() > 0)
1607           --InsIdx;
1608         continue;
1609       }
1610       if (Ty->isVectorTy()) {
1611         EVT ObjectVT = getValueType(Ty);
1612         SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1613         unsigned NumElts = ObjectVT.getVectorNumElements();
1614         assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
1615                "Vector was not scalarized");
1616         unsigned Ofst = 0;
1617         EVT EltVT = ObjectVT.getVectorElementType();
1618
1619         // V1 load
1620         // f32 = load ...
1621         if (NumElts == 1) {
1622           // We only have one element, so just directly load it
1623           Value *SrcValue = Constant::getNullValue(PointerType::get(
1624               EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1625           SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1626                                         DAG.getConstant(Ofst, getPointerTy()));
1627           SDValue P = DAG.getLoad(
1628               EltVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1629               false, true,
1630               TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
1631           if (P.getNode())
1632             P.getNode()->setIROrder(idx + 1);
1633
1634           if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
1635             P = DAG.getNode(ISD::SIGN_EXTEND, dl, Ins[InsIdx].VT, P);
1636           InVals.push_back(P);
1637           Ofst += TD->getTypeAllocSize(EltVT.getTypeForEVT(F->getContext()));
1638           ++InsIdx;
1639         } else if (NumElts == 2) {
1640           // V2 load
1641           // f32,f32 = load ...
1642           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2);
1643           Value *SrcValue = Constant::getNullValue(PointerType::get(
1644               VecVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1645           SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1646                                         DAG.getConstant(Ofst, getPointerTy()));
1647           SDValue P = DAG.getLoad(
1648               VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1649               false, true,
1650               TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
1651           if (P.getNode())
1652             P.getNode()->setIROrder(idx + 1);
1653
1654           SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1655                                      DAG.getIntPtrConstant(0));
1656           SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1657                                      DAG.getIntPtrConstant(1));
1658
1659           if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) {
1660             Elt0 = DAG.getNode(ISD::SIGN_EXTEND, dl, Ins[InsIdx].VT, Elt0);
1661             Elt1 = DAG.getNode(ISD::SIGN_EXTEND, dl, Ins[InsIdx].VT, Elt1);
1662           }
1663
1664           InVals.push_back(Elt0);
1665           InVals.push_back(Elt1);
1666           Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1667           InsIdx += 2;
1668         } else {
1669           // V4 loads
1670           // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
1671           // the
1672           // vector will be expanded to a power of 2 elements, so we know we can
1673           // always round up to the next multiple of 4 when creating the vector
1674           // loads.
1675           // e.g.  4 elem => 1 ld.v4
1676           //       6 elem => 2 ld.v4
1677           //       8 elem => 2 ld.v4
1678           //      11 elem => 3 ld.v4
1679           unsigned VecSize = 4;
1680           if (EltVT.getSizeInBits() == 64) {
1681             VecSize = 2;
1682           }
1683           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
1684           for (unsigned i = 0; i < NumElts; i += VecSize) {
1685             Value *SrcValue = Constant::getNullValue(
1686                 PointerType::get(VecVT.getTypeForEVT(F->getContext()),
1687                                  llvm::ADDRESS_SPACE_PARAM));
1688             SDValue SrcAddr =
1689                 DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1690                             DAG.getConstant(Ofst, getPointerTy()));
1691             SDValue P = DAG.getLoad(
1692                 VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1693                 false, true,
1694                 TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
1695             if (P.getNode())
1696               P.getNode()->setIROrder(idx + 1);
1697
1698             for (unsigned j = 0; j < VecSize; ++j) {
1699               if (i + j >= NumElts)
1700                 break;
1701               SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1702                                         DAG.getIntPtrConstant(j));
1703               if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
1704                 Elt = DAG.getNode(ISD::SIGN_EXTEND, dl, Ins[InsIdx].VT, Elt);
1705               InVals.push_back(Elt);
1706             }
1707             Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1708             InsIdx += VecSize;
1709           }
1710         }
1711
1712         if (NumElts > 0)
1713           --InsIdx;
1714         continue;
1715       }
1716       // A plain scalar.
1717       EVT ObjectVT = getValueType(Ty);
1718       // If ABI, load from the param symbol
1719       SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1720       Value *srcValue = Constant::getNullValue(PointerType::get(
1721           ObjectVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1722       SDValue p;
1723       if (ObjectVT.getSizeInBits() < Ins[InsIdx].VT.getSizeInBits())
1724         p = DAG.getExtLoad(ISD::SEXTLOAD, dl, Ins[InsIdx].VT, Root, Arg,
1725                            MachinePointerInfo(srcValue), ObjectVT, false, false,
1726               TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
1727       else
1728         p = DAG.getLoad(Ins[InsIdx].VT, dl, Root, Arg,
1729                         MachinePointerInfo(srcValue), false, false, false,
1730               TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
1731       if (p.getNode())
1732         p.getNode()->setIROrder(idx + 1);
1733       InVals.push_back(p);
1734       continue;
1735     }
1736
1737     // Param has ByVal attribute
1738     // Return MoveParam(param symbol).
1739     // Ideally, the param symbol can be returned directly,
1740     // but when SDNode builder decides to use it in a CopyToReg(),
1741     // machine instruction fails because TargetExternalSymbol
1742     // (not lowered) is target dependent, and CopyToReg assumes
1743     // the source is lowered.
1744     EVT ObjectVT = getValueType(Ty);
1745     assert(ObjectVT == Ins[InsIdx].VT &&
1746            "Ins type did not match function type");
1747     SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1748     SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
1749     if (p.getNode())
1750       p.getNode()->setIROrder(idx + 1);
1751     if (isKernel)
1752       InVals.push_back(p);
1753     else {
1754       SDValue p2 = DAG.getNode(
1755           ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
1756           DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p);
1757       InVals.push_back(p2);
1758     }
1759   }
1760
1761   // Clang will check explicit VarArg and issue error if any. However, Clang
1762   // will let code with
1763   // implicit var arg like f() pass. See bug 617733.
1764   // We treat this case as if the arg list is empty.
1765   // if (F.isVarArg()) {
1766   // assert(0 && "VarArg not supported yet!");
1767   //}
1768
1769   if (!OutChains.empty())
1770     DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &OutChains[0],
1771                             OutChains.size()));
1772
1773   return Chain;
1774 }
1775
1776
1777 SDValue
1778 NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
1779                                  bool isVarArg,
1780                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
1781                                  const SmallVectorImpl<SDValue> &OutVals,
1782                                  SDLoc dl, SelectionDAG &DAG) const {
1783   MachineFunction &MF = DAG.getMachineFunction();
1784   const Function *F = MF.getFunction();
1785   const Type *RetTy = F->getReturnType();
1786   const DataLayout *TD = getDataLayout();
1787
1788   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1789   assert(isABI && "Non-ABI compilation is not supported");
1790   if (!isABI)
1791     return Chain;
1792
1793   if (const VectorType *VTy = dyn_cast<const VectorType>(RetTy)) {
1794     // If we have a vector type, the OutVals array will be the scalarized
1795     // components and we have combine them into 1 or more vector stores.
1796     unsigned NumElts = VTy->getNumElements();
1797     assert(NumElts == Outs.size() && "Bad scalarization of return value");
1798
1799     // const_cast can be removed in later LLVM versions
1800     EVT EltVT = getValueType(const_cast<Type *>(RetTy)).getVectorElementType();
1801     bool NeedExtend = false;
1802     if (EltVT.getSizeInBits() < 16)
1803       NeedExtend = true;
1804
1805     // V1 store
1806     if (NumElts == 1) {
1807       SDValue StoreVal = OutVals[0];
1808       // We only have one element, so just directly store it
1809       if (NeedExtend)
1810         StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
1811       SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal };
1812       Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
1813                                       DAG.getVTList(MVT::Other), &Ops[0], 3,
1814                                       EltVT, MachinePointerInfo());
1815
1816     } else if (NumElts == 2) {
1817       // V2 store
1818       SDValue StoreVal0 = OutVals[0];
1819       SDValue StoreVal1 = OutVals[1];
1820
1821       if (NeedExtend) {
1822         StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal0);
1823         StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal1);
1824       }
1825
1826       SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal0,
1827                         StoreVal1 };
1828       Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetvalV2, dl,
1829                                       DAG.getVTList(MVT::Other), &Ops[0], 4,
1830                                       EltVT, MachinePointerInfo());
1831     } else {
1832       // V4 stores
1833       // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the
1834       // vector will be expanded to a power of 2 elements, so we know we can
1835       // always round up to the next multiple of 4 when creating the vector
1836       // stores.
1837       // e.g.  4 elem => 1 st.v4
1838       //       6 elem => 2 st.v4
1839       //       8 elem => 2 st.v4
1840       //      11 elem => 3 st.v4
1841
1842       unsigned VecSize = 4;
1843       if (OutVals[0].getValueType().getSizeInBits() == 64)
1844         VecSize = 2;
1845
1846       unsigned Offset = 0;
1847
1848       EVT VecVT =
1849           EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize);
1850       unsigned PerStoreOffset =
1851           TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1852
1853       for (unsigned i = 0; i < NumElts; i += VecSize) {
1854         // Get values
1855         SDValue StoreVal;
1856         SmallVector<SDValue, 8> Ops;
1857         Ops.push_back(Chain);
1858         Ops.push_back(DAG.getConstant(Offset, MVT::i32));
1859         unsigned Opc = NVPTXISD::StoreRetvalV2;
1860         EVT ExtendedVT = (NeedExtend) ? MVT::i16 : OutVals[0].getValueType();
1861
1862         StoreVal = OutVals[i];
1863         if (NeedExtend)
1864           StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1865         Ops.push_back(StoreVal);
1866
1867         if (i + 1 < NumElts) {
1868           StoreVal = OutVals[i + 1];
1869           if (NeedExtend)
1870             StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1871         } else {
1872           StoreVal = DAG.getUNDEF(ExtendedVT);
1873         }
1874         Ops.push_back(StoreVal);
1875
1876         if (VecSize == 4) {
1877           Opc = NVPTXISD::StoreRetvalV4;
1878           if (i + 2 < NumElts) {
1879             StoreVal = OutVals[i + 2];
1880             if (NeedExtend)
1881               StoreVal =
1882                   DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1883           } else {
1884             StoreVal = DAG.getUNDEF(ExtendedVT);
1885           }
1886           Ops.push_back(StoreVal);
1887
1888           if (i + 3 < NumElts) {
1889             StoreVal = OutVals[i + 3];
1890             if (NeedExtend)
1891               StoreVal =
1892                   DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1893           } else {
1894             StoreVal = DAG.getUNDEF(ExtendedVT);
1895           }
1896           Ops.push_back(StoreVal);
1897         }
1898
1899         // Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size());
1900         Chain =
1901             DAG.getMemIntrinsicNode(Opc, dl, DAG.getVTList(MVT::Other), &Ops[0],
1902                                     Ops.size(), EltVT, MachinePointerInfo());
1903         Offset += PerStoreOffset;
1904       }
1905     }
1906   } else {
1907     SmallVector<EVT, 16> ValVTs;
1908     // const_cast is necessary since we are still using an LLVM version from
1909     // before the type system re-write.
1910     ComputePTXValueVTs(*this, const_cast<Type *>(RetTy), ValVTs);
1911     assert(ValVTs.size() == OutVals.size() && "Bad return value decomposition");
1912
1913     unsigned sizesofar = 0;
1914     for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
1915       SDValue theVal = OutVals[i];
1916       EVT theValType = theVal.getValueType();
1917       unsigned numElems = 1;
1918       if (theValType.isVector())
1919         numElems = theValType.getVectorNumElements();
1920       for (unsigned j = 0, je = numElems; j != je; ++j) {
1921         SDValue tmpval = theVal;
1922         if (theValType.isVector())
1923           tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
1924                                theValType.getVectorElementType(), tmpval,
1925                                DAG.getIntPtrConstant(j));
1926         EVT theStoreType = tmpval.getValueType();
1927         if (theStoreType.getSizeInBits() < 8)
1928           tmpval = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, tmpval);
1929         SDValue Ops[] = { Chain, DAG.getConstant(sizesofar, MVT::i32), tmpval };
1930         Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
1931                                         DAG.getVTList(MVT::Other), &Ops[0], 3,
1932                                         ValVTs[i], MachinePointerInfo());
1933         if (theValType.isVector())
1934           sizesofar +=
1935               ValVTs[i].getVectorElementType().getStoreSizeInBits() / 8;
1936         else
1937           sizesofar += ValVTs[i].getStoreSizeInBits() / 8;
1938       }
1939     }
1940   }
1941
1942   return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
1943 }
1944
1945
1946 void NVPTXTargetLowering::LowerAsmOperandForConstraint(
1947     SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops,
1948     SelectionDAG &DAG) const {
1949   if (Constraint.length() > 1)
1950     return;
1951   else
1952     TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
1953 }
1954
1955 // NVPTX suuport vector of legal types of any length in Intrinsics because the
1956 // NVPTX specific type legalizer
1957 // will legalize them to the PTX supported length.
1958 bool NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
1959   if (isTypeLegal(VT))
1960     return true;
1961   if (VT.isVector()) {
1962     MVT eVT = VT.getVectorElementType();
1963     if (isTypeLegal(eVT))
1964       return true;
1965   }
1966   return false;
1967 }
1968
1969 // llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
1970 // TgtMemIntrinsic
1971 // because we need the information that is only available in the "Value" type
1972 // of destination
1973 // pointer. In particular, the address space information.
1974 bool NVPTXTargetLowering::getTgtMemIntrinsic(
1975     IntrinsicInfo &Info, const CallInst &I, unsigned Intrinsic) const {
1976   switch (Intrinsic) {
1977   default:
1978     return false;
1979
1980   case Intrinsic::nvvm_atomic_load_add_f32:
1981     Info.opc = ISD::INTRINSIC_W_CHAIN;
1982     Info.memVT = MVT::f32;
1983     Info.ptrVal = I.getArgOperand(0);
1984     Info.offset = 0;
1985     Info.vol = 0;
1986     Info.readMem = true;
1987     Info.writeMem = true;
1988     Info.align = 0;
1989     return true;
1990
1991   case Intrinsic::nvvm_atomic_load_inc_32:
1992   case Intrinsic::nvvm_atomic_load_dec_32:
1993     Info.opc = ISD::INTRINSIC_W_CHAIN;
1994     Info.memVT = MVT::i32;
1995     Info.ptrVal = I.getArgOperand(0);
1996     Info.offset = 0;
1997     Info.vol = 0;
1998     Info.readMem = true;
1999     Info.writeMem = true;
2000     Info.align = 0;
2001     return true;
2002
2003   case Intrinsic::nvvm_ldu_global_i:
2004   case Intrinsic::nvvm_ldu_global_f:
2005   case Intrinsic::nvvm_ldu_global_p:
2006
2007     Info.opc = ISD::INTRINSIC_W_CHAIN;
2008     if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
2009       Info.memVT = getValueType(I.getType());
2010     else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
2011       Info.memVT = getValueType(I.getType());
2012     else
2013       Info.memVT = MVT::f32;
2014     Info.ptrVal = I.getArgOperand(0);
2015     Info.offset = 0;
2016     Info.vol = 0;
2017     Info.readMem = true;
2018     Info.writeMem = false;
2019     Info.align = 0;
2020     return true;
2021
2022   }
2023   return false;
2024 }
2025
2026 /// isLegalAddressingMode - Return true if the addressing mode represented
2027 /// by AM is legal for this target, for a load/store of the specified type.
2028 /// Used to guide target specific optimizations, like loop strength reduction
2029 /// (LoopStrengthReduce.cpp) and memory optimization for address mode
2030 /// (CodeGenPrepare.cpp)
2031 bool NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
2032                                                 Type *Ty) const {
2033
2034   // AddrMode - This represents an addressing mode of:
2035   //    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
2036   //
2037   // The legal address modes are
2038   // - [avar]
2039   // - [areg]
2040   // - [areg+immoff]
2041   // - [immAddr]
2042
2043   if (AM.BaseGV) {
2044     if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
2045       return false;
2046     return true;
2047   }
2048
2049   switch (AM.Scale) {
2050   case 0: // "r", "r+i" or "i" is allowed
2051     break;
2052   case 1:
2053     if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
2054       return false;
2055     // Otherwise we have r+i.
2056     break;
2057   default:
2058     // No scale > 1 is allowed
2059     return false;
2060   }
2061   return true;
2062 }
2063
2064 //===----------------------------------------------------------------------===//
2065 //                         NVPTX Inline Assembly Support
2066 //===----------------------------------------------------------------------===//
2067
2068 /// getConstraintType - Given a constraint letter, return the type of
2069 /// constraint it is for this target.
2070 NVPTXTargetLowering::ConstraintType
2071 NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
2072   if (Constraint.size() == 1) {
2073     switch (Constraint[0]) {
2074     default:
2075       break;
2076     case 'r':
2077     case 'h':
2078     case 'c':
2079     case 'l':
2080     case 'f':
2081     case 'd':
2082     case '0':
2083     case 'N':
2084       return C_RegisterClass;
2085     }
2086   }
2087   return TargetLowering::getConstraintType(Constraint);
2088 }
2089
2090 std::pair<unsigned, const TargetRegisterClass *>
2091 NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
2092                                                   MVT VT) const {
2093   if (Constraint.size() == 1) {
2094     switch (Constraint[0]) {
2095     case 'c':
2096       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
2097     case 'h':
2098       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
2099     case 'r':
2100       return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
2101     case 'l':
2102     case 'N':
2103       return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
2104     case 'f':
2105       return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
2106     case 'd':
2107       return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
2108     }
2109   }
2110   return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
2111 }
2112
2113 /// getFunctionAlignment - Return the Log2 alignment of this function.
2114 unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
2115   return 4;
2116 }
2117
2118 /// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
2119 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
2120                               SmallVectorImpl<SDValue> &Results) {
2121   EVT ResVT = N->getValueType(0);
2122   SDLoc DL(N);
2123
2124   assert(ResVT.isVector() && "Vector load must have vector type");
2125
2126   // We only handle "native" vector sizes for now, e.g. <4 x double> is not
2127   // legal.  We can (and should) split that into 2 loads of <2 x double> here
2128   // but I'm leaving that as a TODO for now.
2129   assert(ResVT.isSimple() && "Can only handle simple types");
2130   switch (ResVT.getSimpleVT().SimpleTy) {
2131   default:
2132     return;
2133   case MVT::v2i8:
2134   case MVT::v2i16:
2135   case MVT::v2i32:
2136   case MVT::v2i64:
2137   case MVT::v2f32:
2138   case MVT::v2f64:
2139   case MVT::v4i8:
2140   case MVT::v4i16:
2141   case MVT::v4i32:
2142   case MVT::v4f32:
2143     // This is a "native" vector type
2144     break;
2145   }
2146
2147   EVT EltVT = ResVT.getVectorElementType();
2148   unsigned NumElts = ResVT.getVectorNumElements();
2149
2150   // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
2151   // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
2152   // loaded type to i16 and propogate the "real" type as the memory type.
2153   bool NeedTrunc = false;
2154   if (EltVT.getSizeInBits() < 16) {
2155     EltVT = MVT::i16;
2156     NeedTrunc = true;
2157   }
2158
2159   unsigned Opcode = 0;
2160   SDVTList LdResVTs;
2161
2162   switch (NumElts) {
2163   default:
2164     return;
2165   case 2:
2166     Opcode = NVPTXISD::LoadV2;
2167     LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
2168     break;
2169   case 4: {
2170     Opcode = NVPTXISD::LoadV4;
2171     EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
2172     LdResVTs = DAG.getVTList(ListVTs, 5);
2173     break;
2174   }
2175   }
2176
2177   SmallVector<SDValue, 8> OtherOps;
2178
2179   // Copy regular operands
2180   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2181     OtherOps.push_back(N->getOperand(i));
2182
2183   LoadSDNode *LD = cast<LoadSDNode>(N);
2184
2185   // The select routine does not have access to the LoadSDNode instance, so
2186   // pass along the extension information
2187   OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType()));
2188
2189   SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, &OtherOps[0],
2190                                           OtherOps.size(), LD->getMemoryVT(),
2191                                           LD->getMemOperand());
2192
2193   SmallVector<SDValue, 4> ScalarRes;
2194
2195   for (unsigned i = 0; i < NumElts; ++i) {
2196     SDValue Res = NewLD.getValue(i);
2197     if (NeedTrunc)
2198       Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
2199     ScalarRes.push_back(Res);
2200   }
2201
2202   SDValue LoadChain = NewLD.getValue(NumElts);
2203
2204   SDValue BuildVec =
2205       DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
2206
2207   Results.push_back(BuildVec);
2208   Results.push_back(LoadChain);
2209 }
2210
2211 static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
2212                                      SmallVectorImpl<SDValue> &Results) {
2213   SDValue Chain = N->getOperand(0);
2214   SDValue Intrin = N->getOperand(1);
2215   SDLoc DL(N);
2216
2217   // Get the intrinsic ID
2218   unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
2219   switch (IntrinNo) {
2220   default:
2221     return;
2222   case Intrinsic::nvvm_ldg_global_i:
2223   case Intrinsic::nvvm_ldg_global_f:
2224   case Intrinsic::nvvm_ldg_global_p:
2225   case Intrinsic::nvvm_ldu_global_i:
2226   case Intrinsic::nvvm_ldu_global_f:
2227   case Intrinsic::nvvm_ldu_global_p: {
2228     EVT ResVT = N->getValueType(0);
2229
2230     if (ResVT.isVector()) {
2231       // Vector LDG/LDU
2232
2233       unsigned NumElts = ResVT.getVectorNumElements();
2234       EVT EltVT = ResVT.getVectorElementType();
2235
2236       // Since LDU/LDG are target nodes, we cannot rely on DAG type
2237       // legalization.
2238       // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
2239       // loaded type to i16 and propogate the "real" type as the memory type.
2240       bool NeedTrunc = false;
2241       if (EltVT.getSizeInBits() < 16) {
2242         EltVT = MVT::i16;
2243         NeedTrunc = true;
2244       }
2245
2246       unsigned Opcode = 0;
2247       SDVTList LdResVTs;
2248
2249       switch (NumElts) {
2250       default:
2251         return;
2252       case 2:
2253         switch (IntrinNo) {
2254         default:
2255           return;
2256         case Intrinsic::nvvm_ldg_global_i:
2257         case Intrinsic::nvvm_ldg_global_f:
2258         case Intrinsic::nvvm_ldg_global_p:
2259           Opcode = NVPTXISD::LDGV2;
2260           break;
2261         case Intrinsic::nvvm_ldu_global_i:
2262         case Intrinsic::nvvm_ldu_global_f:
2263         case Intrinsic::nvvm_ldu_global_p:
2264           Opcode = NVPTXISD::LDUV2;
2265           break;
2266         }
2267         LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
2268         break;
2269       case 4: {
2270         switch (IntrinNo) {
2271         default:
2272           return;
2273         case Intrinsic::nvvm_ldg_global_i:
2274         case Intrinsic::nvvm_ldg_global_f:
2275         case Intrinsic::nvvm_ldg_global_p:
2276           Opcode = NVPTXISD::LDGV4;
2277           break;
2278         case Intrinsic::nvvm_ldu_global_i:
2279         case Intrinsic::nvvm_ldu_global_f:
2280         case Intrinsic::nvvm_ldu_global_p:
2281           Opcode = NVPTXISD::LDUV4;
2282           break;
2283         }
2284         EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
2285         LdResVTs = DAG.getVTList(ListVTs, 5);
2286         break;
2287       }
2288       }
2289
2290       SmallVector<SDValue, 8> OtherOps;
2291
2292       // Copy regular operands
2293
2294       OtherOps.push_back(Chain); // Chain
2295                                  // Skip operand 1 (intrinsic ID)
2296       // Others
2297       for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i)
2298         OtherOps.push_back(N->getOperand(i));
2299
2300       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
2301
2302       SDValue NewLD = DAG.getMemIntrinsicNode(
2303           Opcode, DL, LdResVTs, &OtherOps[0], OtherOps.size(),
2304           MemSD->getMemoryVT(), MemSD->getMemOperand());
2305
2306       SmallVector<SDValue, 4> ScalarRes;
2307
2308       for (unsigned i = 0; i < NumElts; ++i) {
2309         SDValue Res = NewLD.getValue(i);
2310         if (NeedTrunc)
2311           Res =
2312               DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
2313         ScalarRes.push_back(Res);
2314       }
2315
2316       SDValue LoadChain = NewLD.getValue(NumElts);
2317
2318       SDValue BuildVec =
2319           DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
2320
2321       Results.push_back(BuildVec);
2322       Results.push_back(LoadChain);
2323     } else {
2324       // i8 LDG/LDU
2325       assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
2326              "Custom handling of non-i8 ldu/ldg?");
2327
2328       // Just copy all operands as-is
2329       SmallVector<SDValue, 4> Ops;
2330       for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2331         Ops.push_back(N->getOperand(i));
2332
2333       // Force output to i16
2334       SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);
2335
2336       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
2337
2338       // We make sure the memory type is i8, which will be used during isel
2339       // to select the proper instruction.
2340       SDValue NewLD =
2341           DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, &Ops[0],
2342                                   Ops.size(), MVT::i8, MemSD->getMemOperand());
2343
2344       Results.push_back(NewLD.getValue(0));
2345       Results.push_back(NewLD.getValue(1));
2346     }
2347   }
2348   }
2349 }
2350
2351 void NVPTXTargetLowering::ReplaceNodeResults(
2352     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
2353   switch (N->getOpcode()) {
2354   default:
2355     report_fatal_error("Unhandled custom legalization");
2356   case ISD::LOAD:
2357     ReplaceLoadVector(N, DAG, Results);
2358     return;
2359   case ISD::INTRINSIC_W_CHAIN:
2360     ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
2361     return;
2362   }
2363 }