[NVPTX] Add support for cttz/ctlz/ctpop
[oota-llvm.git] / lib / Target / NVPTX / NVPTXISelLowering.cpp
1 //
2 //                     The LLVM Compiler Infrastructure
3 //
4 // This file is distributed under the University of Illinois Open Source
5 // License. See LICENSE.TXT for details.
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the interfaces that NVPTX uses to lower LLVM code into a
10 // selection DAG.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "NVPTXISelLowering.h"
15 #include "NVPTX.h"
16 #include "NVPTXTargetMachine.h"
17 #include "NVPTXTargetObjectFile.h"
18 #include "NVPTXUtilities.h"
19 #include "llvm/CodeGen/Analysis.h"
20 #include "llvm/CodeGen/MachineFrameInfo.h"
21 #include "llvm/CodeGen/MachineFunction.h"
22 #include "llvm/CodeGen/MachineInstrBuilder.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/GlobalValue.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/Intrinsics.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/MC/MCSectionELF.h"
32 #include "llvm/Support/CallSite.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <sstream>
38
39 #undef DEBUG_TYPE
40 #define DEBUG_TYPE "nvptx-lower"
41
42 using namespace llvm;
43
44 static unsigned int uniqueCallSite = 0;
45
46 static cl::opt<bool> sched4reg(
47     "nvptx-sched4reg",
48     cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false));
49
50 static bool IsPTXVectorType(MVT VT) {
51   switch (VT.SimpleTy) {
52   default:
53     return false;
54   case MVT::v2i1:
55   case MVT::v4i1:
56   case MVT::v2i8:
57   case MVT::v4i8:
58   case MVT::v2i16:
59   case MVT::v4i16:
60   case MVT::v2i32:
61   case MVT::v4i32:
62   case MVT::v2i64:
63   case MVT::v2f32:
64   case MVT::v4f32:
65   case MVT::v2f64:
66     return true;
67   }
68 }
69
70 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
71 /// EVTs that compose it.  Unlike ComputeValueVTs, this will break apart vectors
72 /// into their primitive components.
73 /// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
74 /// same number of types as the Ins/Outs arrays in LowerFormalArguments,
75 /// LowerCall, and LowerReturn.
76 static void ComputePTXValueVTs(const TargetLowering &TLI, Type *Ty,
77                                SmallVectorImpl<EVT> &ValueVTs,
78                                SmallVectorImpl<uint64_t> *Offsets = 0,
79                                uint64_t StartingOffset = 0) {
80   SmallVector<EVT, 16> TempVTs;
81   SmallVector<uint64_t, 16> TempOffsets;
82
83   ComputeValueVTs(TLI, Ty, TempVTs, &TempOffsets, StartingOffset);
84   for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
85     EVT VT = TempVTs[i];
86     uint64_t Off = TempOffsets[i];
87     if (VT.isVector())
88       for (unsigned j = 0, je = VT.getVectorNumElements(); j != je; ++j) {
89         ValueVTs.push_back(VT.getVectorElementType());
90         if (Offsets)
91           Offsets->push_back(Off+j*VT.getVectorElementType().getStoreSize());
92       }
93     else {
94       ValueVTs.push_back(VT);
95       if (Offsets)
96         Offsets->push_back(Off);
97     }
98   }
99 }
100
101 // NVPTXTargetLowering Constructor.
102 NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
103     : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM),
104       nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {
105
106   // always lower memset, memcpy, and memmove intrinsics to load/store
107   // instructions, rather
108   // then generating calls to memset, mempcy or memmove.
109   MaxStoresPerMemset = (unsigned) 0xFFFFFFFF;
110   MaxStoresPerMemcpy = (unsigned) 0xFFFFFFFF;
111   MaxStoresPerMemmove = (unsigned) 0xFFFFFFFF;
112
113   setBooleanContents(ZeroOrNegativeOneBooleanContent);
114
115   // Jump is Expensive. Don't create extra control flow for 'and', 'or'
116   // condition branches.
117   setJumpIsExpensive(true);
118
119   // By default, use the Source scheduling
120   if (sched4reg)
121     setSchedulingPreference(Sched::RegPressure);
122   else
123     setSchedulingPreference(Sched::Source);
124
125   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
126   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
127   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
128   addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
129   addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
130   addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
131
132   // Operations not directly supported by NVPTX.
133   setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
134   setOperationAction(ISD::BR_CC, MVT::f32, Expand);
135   setOperationAction(ISD::BR_CC, MVT::f64, Expand);
136   setOperationAction(ISD::BR_CC, MVT::i1, Expand);
137   setOperationAction(ISD::BR_CC, MVT::i8, Expand);
138   setOperationAction(ISD::BR_CC, MVT::i16, Expand);
139   setOperationAction(ISD::BR_CC, MVT::i32, Expand);
140   setOperationAction(ISD::BR_CC, MVT::i64, Expand);
141   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Expand);
142   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Expand);
143   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Expand);
144   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8, Expand);
145   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
146
147   if (nvptxSubtarget.hasROT64()) {
148     setOperationAction(ISD::ROTL, MVT::i64, Legal);
149     setOperationAction(ISD::ROTR, MVT::i64, Legal);
150   } else {
151     setOperationAction(ISD::ROTL, MVT::i64, Expand);
152     setOperationAction(ISD::ROTR, MVT::i64, Expand);
153   }
154   if (nvptxSubtarget.hasROT32()) {
155     setOperationAction(ISD::ROTL, MVT::i32, Legal);
156     setOperationAction(ISD::ROTR, MVT::i32, Legal);
157   } else {
158     setOperationAction(ISD::ROTL, MVT::i32, Expand);
159     setOperationAction(ISD::ROTR, MVT::i32, Expand);
160   }
161
162   setOperationAction(ISD::ROTL, MVT::i16, Expand);
163   setOperationAction(ISD::ROTR, MVT::i16, Expand);
164   setOperationAction(ISD::ROTL, MVT::i8, Expand);
165   setOperationAction(ISD::ROTR, MVT::i8, Expand);
166   setOperationAction(ISD::BSWAP, MVT::i16, Expand);
167   setOperationAction(ISD::BSWAP, MVT::i32, Expand);
168   setOperationAction(ISD::BSWAP, MVT::i64, Expand);
169
170   // Indirect branch is not supported.
171   // This also disables Jump Table creation.
172   setOperationAction(ISD::BR_JT, MVT::Other, Expand);
173   setOperationAction(ISD::BRIND, MVT::Other, Expand);
174
175   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
176   setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
177
178   // We want to legalize constant related memmove and memcopy
179   // intrinsics.
180   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
181
182   // Turn FP extload into load/fextend
183   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
184   // Turn FP truncstore into trunc + store.
185   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
186
187   // PTX does not support load / store predicate registers
188   setOperationAction(ISD::LOAD, MVT::i1, Custom);
189   setOperationAction(ISD::STORE, MVT::i1, Custom);
190
191   setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
192   setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
193   setTruncStoreAction(MVT::i64, MVT::i1, Expand);
194   setTruncStoreAction(MVT::i32, MVT::i1, Expand);
195   setTruncStoreAction(MVT::i16, MVT::i1, Expand);
196   setTruncStoreAction(MVT::i8, MVT::i1, Expand);
197
198   // This is legal in NVPTX
199   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
200   setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
201
202   // TRAP can be lowered to PTX trap
203   setOperationAction(ISD::TRAP, MVT::Other, Legal);
204
205   // Register custom handling for vector loads/stores
206   for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE;
207        ++i) {
208     MVT VT = (MVT::SimpleValueType) i;
209     if (IsPTXVectorType(VT)) {
210       setOperationAction(ISD::LOAD, VT, Custom);
211       setOperationAction(ISD::STORE, VT, Custom);
212       setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
213     }
214   }
215
216   // Custom handling for i8 intrinsics
217   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
218
219   setOperationAction(ISD::CTLZ, MVT::i16, Legal);
220   setOperationAction(ISD::CTLZ, MVT::i32, Legal);
221   setOperationAction(ISD::CTLZ, MVT::i64, Legal);
222   setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i16, Legal);
223   setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i32, Legal);
224   setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i64, Legal);
225   setOperationAction(ISD::CTTZ, MVT::i16, Expand);
226   setOperationAction(ISD::CTTZ, MVT::i32, Expand);
227   setOperationAction(ISD::CTTZ, MVT::i64, Expand);
228   setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i16, Expand);
229   setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i32, Expand);
230   setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i64, Expand);
231   setOperationAction(ISD::CTPOP, MVT::i16, Legal);
232   setOperationAction(ISD::CTPOP, MVT::i32, Legal);
233   setOperationAction(ISD::CTPOP, MVT::i64, Legal);
234
235   // Now deduce the information based on the above mentioned
236   // actions
237   computeRegisterProperties();
238 }
239
240 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
241   switch (Opcode) {
242   default:
243     return 0;
244   case NVPTXISD::CALL:
245     return "NVPTXISD::CALL";
246   case NVPTXISD::RET_FLAG:
247     return "NVPTXISD::RET_FLAG";
248   case NVPTXISD::Wrapper:
249     return "NVPTXISD::Wrapper";
250   case NVPTXISD::NVBuiltin:
251     return "NVPTXISD::NVBuiltin";
252   case NVPTXISD::DeclareParam:
253     return "NVPTXISD::DeclareParam";
254   case NVPTXISD::DeclareScalarParam:
255     return "NVPTXISD::DeclareScalarParam";
256   case NVPTXISD::DeclareRet:
257     return "NVPTXISD::DeclareRet";
258   case NVPTXISD::DeclareRetParam:
259     return "NVPTXISD::DeclareRetParam";
260   case NVPTXISD::PrintCall:
261     return "NVPTXISD::PrintCall";
262   case NVPTXISD::LoadParam:
263     return "NVPTXISD::LoadParam";
264   case NVPTXISD::LoadParamV2:
265     return "NVPTXISD::LoadParamV2";
266   case NVPTXISD::LoadParamV4:
267     return "NVPTXISD::LoadParamV4";
268   case NVPTXISD::StoreParam:
269     return "NVPTXISD::StoreParam";
270   case NVPTXISD::StoreParamV2:
271     return "NVPTXISD::StoreParamV2";
272   case NVPTXISD::StoreParamV4:
273     return "NVPTXISD::StoreParamV4";
274   case NVPTXISD::StoreParamS32:
275     return "NVPTXISD::StoreParamS32";
276   case NVPTXISD::StoreParamU32:
277     return "NVPTXISD::StoreParamU32";
278   case NVPTXISD::CallArgBegin:
279     return "NVPTXISD::CallArgBegin";
280   case NVPTXISD::CallArg:
281     return "NVPTXISD::CallArg";
282   case NVPTXISD::LastCallArg:
283     return "NVPTXISD::LastCallArg";
284   case NVPTXISD::CallArgEnd:
285     return "NVPTXISD::CallArgEnd";
286   case NVPTXISD::CallVoid:
287     return "NVPTXISD::CallVoid";
288   case NVPTXISD::CallVal:
289     return "NVPTXISD::CallVal";
290   case NVPTXISD::CallSymbol:
291     return "NVPTXISD::CallSymbol";
292   case NVPTXISD::Prototype:
293     return "NVPTXISD::Prototype";
294   case NVPTXISD::MoveParam:
295     return "NVPTXISD::MoveParam";
296   case NVPTXISD::StoreRetval:
297     return "NVPTXISD::StoreRetval";
298   case NVPTXISD::StoreRetvalV2:
299     return "NVPTXISD::StoreRetvalV2";
300   case NVPTXISD::StoreRetvalV4:
301     return "NVPTXISD::StoreRetvalV4";
302   case NVPTXISD::PseudoUseParam:
303     return "NVPTXISD::PseudoUseParam";
304   case NVPTXISD::RETURN:
305     return "NVPTXISD::RETURN";
306   case NVPTXISD::CallSeqBegin:
307     return "NVPTXISD::CallSeqBegin";
308   case NVPTXISD::CallSeqEnd:
309     return "NVPTXISD::CallSeqEnd";
310   case NVPTXISD::LoadV2:
311     return "NVPTXISD::LoadV2";
312   case NVPTXISD::LoadV4:
313     return "NVPTXISD::LoadV4";
314   case NVPTXISD::LDGV2:
315     return "NVPTXISD::LDGV2";
316   case NVPTXISD::LDGV4:
317     return "NVPTXISD::LDGV4";
318   case NVPTXISD::LDUV2:
319     return "NVPTXISD::LDUV2";
320   case NVPTXISD::LDUV4:
321     return "NVPTXISD::LDUV4";
322   case NVPTXISD::StoreV2:
323     return "NVPTXISD::StoreV2";
324   case NVPTXISD::StoreV4:
325     return "NVPTXISD::StoreV4";
326   }
327 }
328
329 bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const {
330   return VT == MVT::i1;
331 }
332
333 SDValue
334 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
335   SDLoc dl(Op);
336   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
337   Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
338   return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
339 }
340
341 /*
342 std::string NVPTXTargetLowering::getPrototype(
343     Type *retTy, const ArgListTy &Args,
344     const SmallVectorImpl<ISD::OutputArg> &Outs, unsigned retAlignment) const {
345
346   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
347
348   std::stringstream O;
349   O << "prototype_" << uniqueCallSite << " : .callprototype ";
350
351   if (retTy->getTypeID() == Type::VoidTyID)
352     O << "()";
353   else {
354     O << "(";
355     if (isABI) {
356       if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
357         unsigned size = 0;
358         if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
359           size = ITy->getBitWidth();
360           if (size < 32)
361             size = 32;
362         } else {
363           assert(retTy->isFloatingPointTy() &&
364                  "Floating point type expected here");
365           size = retTy->getPrimitiveSizeInBits();
366         }
367
368         O << ".param .b" << size << " _";
369       } else if (isa<PointerType>(retTy))
370         O << ".param .b" << getPointerTy().getSizeInBits() << " _";
371       else {
372         if ((retTy->getTypeID() == Type::StructTyID) ||
373             isa<VectorType>(retTy)) {
374           SmallVector<EVT, 16> vtparts;
375           ComputeValueVTs(*this, retTy, vtparts);
376           unsigned totalsz = 0;
377           for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
378             unsigned elems = 1;
379             EVT elemtype = vtparts[i];
380             if (vtparts[i].isVector()) {
381               elems = vtparts[i].getVectorNumElements();
382               elemtype = vtparts[i].getVectorElementType();
383             }
384             for (unsigned j = 0, je = elems; j != je; ++j) {
385               unsigned sz = elemtype.getSizeInBits();
386               if (elemtype.isInteger() && (sz < 8))
387                 sz = 8;
388               totalsz += sz / 8;
389             }
390           }
391           O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
392         } else {
393           assert(false && "Unknown return type");
394         }
395       }
396     } else {
397       SmallVector<EVT, 16> vtparts;
398       ComputeValueVTs(*this, retTy, vtparts);
399       unsigned idx = 0;
400       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
401         unsigned elems = 1;
402         EVT elemtype = vtparts[i];
403         if (vtparts[i].isVector()) {
404           elems = vtparts[i].getVectorNumElements();
405           elemtype = vtparts[i].getVectorElementType();
406         }
407
408         for (unsigned j = 0, je = elems; j != je; ++j) {
409           unsigned sz = elemtype.getSizeInBits();
410           if (elemtype.isInteger() && (sz < 32))
411             sz = 32;
412           O << ".reg .b" << sz << " _";
413           if (j < je - 1)
414             O << ", ";
415           ++idx;
416         }
417         if (i < e - 1)
418           O << ", ";
419       }
420     }
421     O << ") ";
422   }
423   O << "_ (";
424
425   bool first = true;
426   MVT thePointerTy = getPointerTy();
427
428   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
429     const Type *Ty = Args[i].Ty;
430     if (!first) {
431       O << ", ";
432     }
433     first = false;
434
435     if (Outs[i].Flags.isByVal() == false) {
436       unsigned sz = 0;
437       if (isa<IntegerType>(Ty)) {
438         sz = cast<IntegerType>(Ty)->getBitWidth();
439         if (sz < 32)
440           sz = 32;
441       } else if (isa<PointerType>(Ty))
442         sz = thePointerTy.getSizeInBits();
443       else
444         sz = Ty->getPrimitiveSizeInBits();
445       if (isABI)
446         O << ".param .b" << sz << " ";
447       else
448         O << ".reg .b" << sz << " ";
449       O << "_";
450       continue;
451     }
452     const PointerType *PTy = dyn_cast<PointerType>(Ty);
453     assert(PTy && "Param with byval attribute should be a pointer type");
454     Type *ETy = PTy->getElementType();
455
456     if (isABI) {
457       unsigned align = Outs[i].Flags.getByValAlign();
458       unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
459       O << ".param .align " << align << " .b8 ";
460       O << "_";
461       O << "[" << sz << "]";
462       continue;
463     } else {
464       SmallVector<EVT, 16> vtparts;
465       ComputeValueVTs(*this, ETy, vtparts);
466       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
467         unsigned elems = 1;
468         EVT elemtype = vtparts[i];
469         if (vtparts[i].isVector()) {
470           elems = vtparts[i].getVectorNumElements();
471           elemtype = vtparts[i].getVectorElementType();
472         }
473
474         for (unsigned j = 0, je = elems; j != je; ++j) {
475           unsigned sz = elemtype.getSizeInBits();
476           if (elemtype.isInteger() && (sz < 32))
477             sz = 32;
478           O << ".reg .b" << sz << " ";
479           O << "_";
480           if (j < je - 1)
481             O << ", ";
482         }
483         if (i < e - 1)
484           O << ", ";
485       }
486       continue;
487     }
488   }
489   O << ");";
490   return O.str();
491 }*/
492
493 std::string
494 NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
495                                   const SmallVectorImpl<ISD::OutputArg> &Outs,
496                                   unsigned retAlignment,
497                                   const ImmutableCallSite *CS) const {
498
499   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
500   assert(isABI && "Non-ABI compilation is not supported");
501   if (!isABI)
502     return "";
503
504   std::stringstream O;
505   O << "prototype_" << uniqueCallSite << " : .callprototype ";
506
507   if (retTy->getTypeID() == Type::VoidTyID) {
508     O << "()";
509   } else {
510     O << "(";
511     if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
512       unsigned size = 0;
513       if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
514         size = ITy->getBitWidth();
515         if (size < 32)
516           size = 32;
517       } else {
518         assert(retTy->isFloatingPointTy() &&
519                "Floating point type expected here");
520         size = retTy->getPrimitiveSizeInBits();
521       }
522
523       O << ".param .b" << size << " _";
524     } else if (isa<PointerType>(retTy)) {
525       O << ".param .b" << getPointerTy().getSizeInBits() << " _";
526     } else {
527       if ((retTy->getTypeID() == Type::StructTyID) || isa<VectorType>(retTy)) {
528         SmallVector<EVT, 16> vtparts;
529         ComputeValueVTs(*this, retTy, vtparts);
530         unsigned totalsz = 0;
531         for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
532           unsigned elems = 1;
533           EVT elemtype = vtparts[i];
534           if (vtparts[i].isVector()) {
535             elems = vtparts[i].getVectorNumElements();
536             elemtype = vtparts[i].getVectorElementType();
537           }
538           // TODO: no need to loop
539           for (unsigned j = 0, je = elems; j != je; ++j) {
540             unsigned sz = elemtype.getSizeInBits();
541             if (elemtype.isInteger() && (sz < 8))
542               sz = 8;
543             totalsz += sz / 8;
544           }
545         }
546         O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
547       } else {
548         assert(false && "Unknown return type");
549       }
550     }
551     O << ") ";
552   }
553   O << "_ (";
554
555   bool first = true;
556   MVT thePointerTy = getPointerTy();
557
558   unsigned OIdx = 0;
559   for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
560     Type *Ty = Args[i].Ty;
561     if (!first) {
562       O << ", ";
563     }
564     first = false;
565
566     if (Outs[OIdx].Flags.isByVal() == false) {
567       if (Ty->isAggregateType() || Ty->isVectorTy()) {
568         unsigned align = 0;
569         const CallInst *CallI = cast<CallInst>(CS->getInstruction());
570         const DataLayout *TD = getDataLayout();
571         // +1 because index 0 is reserved for return type alignment
572         if (!llvm::getAlign(*CallI, i + 1, align))
573           align = TD->getABITypeAlignment(Ty);
574         unsigned sz = TD->getTypeAllocSize(Ty);
575         O << ".param .align " << align << " .b8 ";
576         O << "_";
577         O << "[" << sz << "]";
578         // update the index for Outs
579         SmallVector<EVT, 16> vtparts;
580         ComputeValueVTs(*this, Ty, vtparts);
581         if (unsigned len = vtparts.size())
582           OIdx += len - 1;
583         continue;
584       }
585       assert(getValueType(Ty) == Outs[OIdx].VT &&
586              "type mismatch between callee prototype and arguments");
587       // scalar type
588       unsigned sz = 0;
589       if (isa<IntegerType>(Ty)) {
590         sz = cast<IntegerType>(Ty)->getBitWidth();
591         if (sz < 32)
592           sz = 32;
593       } else if (isa<PointerType>(Ty))
594         sz = thePointerTy.getSizeInBits();
595       else
596         sz = Ty->getPrimitiveSizeInBits();
597       O << ".param .b" << sz << " ";
598       O << "_";
599       continue;
600     }
601     const PointerType *PTy = dyn_cast<PointerType>(Ty);
602     assert(PTy && "Param with byval attribute should be a pointer type");
603     Type *ETy = PTy->getElementType();
604
605     unsigned align = Outs[OIdx].Flags.getByValAlign();
606     unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
607     O << ".param .align " << align << " .b8 ";
608     O << "_";
609     O << "[" << sz << "]";
610   }
611   O << ");";
612   return O.str();
613 }
614
615 unsigned
616 NVPTXTargetLowering::getArgumentAlignment(SDValue Callee,
617                                           const ImmutableCallSite *CS,
618                                           Type *Ty,
619                                           unsigned Idx) const {
620   const DataLayout *TD = getDataLayout();
621   unsigned align = 0;
622   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
623
624   if (Func) { // direct call
625     assert(CS->getCalledFunction() &&
626            "direct call cannot find callee");
627     if (!llvm::getAlign(*(CS->getCalledFunction()), Idx, align))
628       align = TD->getABITypeAlignment(Ty);
629   }
630   else { // indirect call
631     const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction());
632     if (!llvm::getAlign(*CallI, Idx, align))
633       align = TD->getABITypeAlignment(Ty);
634   }
635
636   return align;
637 }
638
639 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
640                                        SmallVectorImpl<SDValue> &InVals) const {
641   SelectionDAG &DAG = CLI.DAG;
642   SDLoc dl = CLI.DL;
643   SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
644   SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
645   SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
646   SDValue Chain = CLI.Chain;
647   SDValue Callee = CLI.Callee;
648   bool &isTailCall = CLI.IsTailCall;
649   ArgListTy &Args = CLI.Args;
650   Type *retTy = CLI.RetTy;
651   ImmutableCallSite *CS = CLI.CS;
652
653   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
654   assert(isABI && "Non-ABI compilation is not supported");
655   if (!isABI)
656     return Chain;
657   const DataLayout *TD = getDataLayout();
658   MachineFunction &MF = DAG.getMachineFunction();
659   const Function *F = MF.getFunction();
660   const TargetLowering *TLI = nvTM->getTargetLowering();
661
662   SDValue tempChain = Chain;
663   Chain =
664       DAG.getCALLSEQ_START(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
665                            dl);
666   SDValue InFlag = Chain.getValue(1);
667
668   unsigned paramCount = 0;
669   // Args.size() and Outs.size() need not match.
670   // Outs.size() will be larger
671   //   * if there is an aggregate argument with multiple fields (each field
672   //     showing up separately in Outs)
673   //   * if there is a vector argument with more than typical vector-length
674   //     elements (generally if more than 4) where each vector element is
675   //     individually present in Outs.
676   // So a different index should be used for indexing into Outs/OutVals.
677   // See similar issue in LowerFormalArguments.
678   unsigned OIdx = 0;
679   // Declare the .params or .reg need to pass values
680   // to the function
681   for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
682     EVT VT = Outs[OIdx].VT;
683     Type *Ty = Args[i].Ty;
684
685     if (Outs[OIdx].Flags.isByVal() == false) {
686       if (Ty->isAggregateType()) {
687         // aggregate
688         SmallVector<EVT, 16> vtparts;
689         ComputeValueVTs(*this, Ty, vtparts);
690
691         unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
692         // declare .param .align <align> .b8 .param<n>[<size>];
693         unsigned sz = TD->getTypeAllocSize(Ty);
694         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
695         SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
696                                       DAG.getConstant(paramCount, MVT::i32),
697                                       DAG.getConstant(sz, MVT::i32), InFlag };
698         Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
699                             DeclareParamOps, 5);
700         InFlag = Chain.getValue(1);
701         unsigned curOffset = 0;
702         for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
703           unsigned elems = 1;
704           EVT elemtype = vtparts[j];
705           if (vtparts[j].isVector()) {
706             elems = vtparts[j].getVectorNumElements();
707             elemtype = vtparts[j].getVectorElementType();
708           }
709           for (unsigned k = 0, ke = elems; k != ke; ++k) {
710             unsigned sz = elemtype.getSizeInBits();
711             if (elemtype.isInteger() && (sz < 8))
712               sz = 8;
713             SDValue StVal = OutVals[OIdx];
714             if (elemtype.getSizeInBits() < 16) {
715               StVal = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::i16, StVal);
716             }
717             SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
718             SDValue CopyParamOps[] = { Chain,
719                                        DAG.getConstant(paramCount, MVT::i32),
720                                        DAG.getConstant(curOffset, MVT::i32),
721                                        StVal, InFlag };
722             Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
723                                             CopyParamVTs, &CopyParamOps[0], 5,
724                                             elemtype, MachinePointerInfo());
725             InFlag = Chain.getValue(1);
726             curOffset += sz / 8;
727             ++OIdx;
728           }
729         }
730         if (vtparts.size() > 0)
731           --OIdx;
732         ++paramCount;
733         continue;
734       }
735       if (Ty->isVectorTy()) {
736         EVT ObjectVT = getValueType(Ty);
737         unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
738         // declare .param .align <align> .b8 .param<n>[<size>];
739         unsigned sz = TD->getTypeAllocSize(Ty);
740         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
741         SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
742                                       DAG.getConstant(paramCount, MVT::i32),
743                                       DAG.getConstant(sz, MVT::i32), InFlag };
744         Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
745                             DeclareParamOps, 5);
746         InFlag = Chain.getValue(1);
747         unsigned NumElts = ObjectVT.getVectorNumElements();
748         EVT EltVT = ObjectVT.getVectorElementType();
749         EVT MemVT = EltVT;
750         bool NeedExtend = false;
751         if (EltVT.getSizeInBits() < 16) {
752           NeedExtend = true;
753           EltVT = MVT::i16;
754         }
755
756         // V1 store
757         if (NumElts == 1) {
758           SDValue Elt = OutVals[OIdx++];
759           if (NeedExtend)
760             Elt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt);
761
762           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
763           SDValue CopyParamOps[] = { Chain,
764                                      DAG.getConstant(paramCount, MVT::i32),
765                                      DAG.getConstant(0, MVT::i32), Elt,
766                                      InFlag };
767           Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
768                                           CopyParamVTs, &CopyParamOps[0], 5,
769                                           MemVT, MachinePointerInfo());
770           InFlag = Chain.getValue(1);
771         } else if (NumElts == 2) {
772           SDValue Elt0 = OutVals[OIdx++];
773           SDValue Elt1 = OutVals[OIdx++];
774           if (NeedExtend) {
775             Elt0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt0);
776             Elt1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt1);
777           }
778
779           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
780           SDValue CopyParamOps[] = { Chain,
781                                      DAG.getConstant(paramCount, MVT::i32),
782                                      DAG.getConstant(0, MVT::i32), Elt0, Elt1,
783                                      InFlag };
784           Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParamV2, dl,
785                                           CopyParamVTs, &CopyParamOps[0], 6,
786                                           MemVT, MachinePointerInfo());
787           InFlag = Chain.getValue(1);
788         } else {
789           unsigned curOffset = 0;
790           // V4 stores
791           // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
792           // the
793           // vector will be expanded to a power of 2 elements, so we know we can
794           // always round up to the next multiple of 4 when creating the vector
795           // stores.
796           // e.g.  4 elem => 1 st.v4
797           //       6 elem => 2 st.v4
798           //       8 elem => 2 st.v4
799           //      11 elem => 3 st.v4
800           unsigned VecSize = 4;
801           if (EltVT.getSizeInBits() == 64)
802             VecSize = 2;
803
804           // This is potentially only part of a vector, so assume all elements
805           // are packed together.
806           unsigned PerStoreOffset = MemVT.getStoreSizeInBits() / 8 * VecSize;
807
808           for (unsigned i = 0; i < NumElts; i += VecSize) {
809             // Get values
810             SDValue StoreVal;
811             SmallVector<SDValue, 8> Ops;
812             Ops.push_back(Chain);
813             Ops.push_back(DAG.getConstant(paramCount, MVT::i32));
814             Ops.push_back(DAG.getConstant(curOffset, MVT::i32));
815
816             unsigned Opc = NVPTXISD::StoreParamV2;
817
818             StoreVal = OutVals[OIdx++];
819             if (NeedExtend)
820               StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
821             Ops.push_back(StoreVal);
822
823             if (i + 1 < NumElts) {
824               StoreVal = OutVals[OIdx++];
825               if (NeedExtend)
826                 StoreVal =
827                     DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
828             } else {
829               StoreVal = DAG.getUNDEF(EltVT);
830             }
831             Ops.push_back(StoreVal);
832
833             if (VecSize == 4) {
834               Opc = NVPTXISD::StoreParamV4;
835               if (i + 2 < NumElts) {
836                 StoreVal = OutVals[OIdx++];
837                 if (NeedExtend)
838                   StoreVal =
839                       DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
840               } else {
841                 StoreVal = DAG.getUNDEF(EltVT);
842               }
843               Ops.push_back(StoreVal);
844
845               if (i + 3 < NumElts) {
846                 StoreVal = OutVals[OIdx++];
847                 if (NeedExtend)
848                   StoreVal =
849                       DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
850               } else {
851                 StoreVal = DAG.getUNDEF(EltVT);
852               }
853               Ops.push_back(StoreVal);
854             }
855
856             SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
857             Chain = DAG.getMemIntrinsicNode(Opc, dl, CopyParamVTs, &Ops[0],
858                                             Ops.size(), MemVT,
859                                             MachinePointerInfo());
860             InFlag = Chain.getValue(1);
861             curOffset += PerStoreOffset;
862           }
863         }
864         ++paramCount;
865         --OIdx;
866         continue;
867       }
868       // Plain scalar
869       // for ABI,    declare .param .b<size> .param<n>;
870       unsigned sz = VT.getSizeInBits();
871       bool needExtend = false;
872       if (VT.isInteger()) {
873         if (sz < 16)
874           needExtend = true;
875         if (sz < 32)
876           sz = 32;
877       }
878       SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
879       SDValue DeclareParamOps[] = { Chain,
880                                     DAG.getConstant(paramCount, MVT::i32),
881                                     DAG.getConstant(sz, MVT::i32),
882                                     DAG.getConstant(0, MVT::i32), InFlag };
883       Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
884                           DeclareParamOps, 5);
885       InFlag = Chain.getValue(1);
886       SDValue OutV = OutVals[OIdx];
887       if (needExtend) {
888         // zext/sext i1 to i16
889         unsigned opc = ISD::ZERO_EXTEND;
890         if (Outs[OIdx].Flags.isSExt())
891           opc = ISD::SIGN_EXTEND;
892         OutV = DAG.getNode(opc, dl, MVT::i16, OutV);
893       }
894       SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
895       SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
896                                  DAG.getConstant(0, MVT::i32), OutV, InFlag };
897
898       unsigned opcode = NVPTXISD::StoreParam;
899       if (Outs[OIdx].Flags.isZExt())
900         opcode = NVPTXISD::StoreParamU32;
901       else if (Outs[OIdx].Flags.isSExt())
902         opcode = NVPTXISD::StoreParamS32;
903       Chain = DAG.getMemIntrinsicNode(opcode, dl, CopyParamVTs, CopyParamOps, 5,
904                                       VT, MachinePointerInfo());
905
906       InFlag = Chain.getValue(1);
907       ++paramCount;
908       continue;
909     }
910     // struct or vector
911     SmallVector<EVT, 16> vtparts;
912     const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
913     assert(PTy && "Type of a byval parameter should be pointer");
914     ComputeValueVTs(*this, PTy->getElementType(), vtparts);
915
916     // declare .param .align <align> .b8 .param<n>[<size>];
917     unsigned sz = Outs[OIdx].Flags.getByValSize();
918     SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
919     // The ByValAlign in the Outs[OIdx].Flags is alway set at this point,
920     // so we don't need to worry about natural alignment or not.
921     // See TargetLowering::LowerCallTo().
922     SDValue DeclareParamOps[] = {
923       Chain, DAG.getConstant(Outs[OIdx].Flags.getByValAlign(), MVT::i32),
924       DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32),
925       InFlag
926     };
927     Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
928                         DeclareParamOps, 5);
929     InFlag = Chain.getValue(1);
930     unsigned curOffset = 0;
931     for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
932       unsigned elems = 1;
933       EVT elemtype = vtparts[j];
934       if (vtparts[j].isVector()) {
935         elems = vtparts[j].getVectorNumElements();
936         elemtype = vtparts[j].getVectorElementType();
937       }
938       for (unsigned k = 0, ke = elems; k != ke; ++k) {
939         unsigned sz = elemtype.getSizeInBits();
940         if (elemtype.isInteger() && (sz < 8))
941           sz = 8;
942         SDValue srcAddr =
943             DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
944                         DAG.getConstant(curOffset, getPointerTy()));
945         SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
946                                      MachinePointerInfo(), false, false, false,
947                                      0);
948         if (elemtype.getSizeInBits() < 16) {
949           theVal = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::i16, theVal);
950         }
951         SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
952         SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
953                                    DAG.getConstant(curOffset, MVT::i32), theVal,
954                                    InFlag };
955         Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
956                                         CopyParamOps, 5, elemtype,
957                                         MachinePointerInfo());
958
959         InFlag = Chain.getValue(1);
960         curOffset += sz / 8;
961       }
962     }
963     ++paramCount;
964   }
965
966   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
967   unsigned retAlignment = 0;
968
969   // Handle Result
970   if (Ins.size() > 0) {
971     SmallVector<EVT, 16> resvtparts;
972     ComputeValueVTs(*this, retTy, resvtparts);
973
974     // Declare
975     //  .param .align 16 .b8 retval0[<size-in-bytes>], or
976     //  .param .b<size-in-bits> retval0
977     unsigned resultsz = TD->getTypeAllocSizeInBits(retTy);
978     if (retTy->isPrimitiveType() || retTy->isIntegerTy() ||
979         retTy->isPointerTy()) {
980       // Scalar needs to be at least 32bit wide
981       if (resultsz < 32)
982         resultsz = 32;
983       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
984       SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
985                                   DAG.getConstant(resultsz, MVT::i32),
986                                   DAG.getConstant(0, MVT::i32), InFlag };
987       Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
988                           DeclareRetOps, 5);
989       InFlag = Chain.getValue(1);
990     } else {
991       retAlignment = getArgumentAlignment(Callee, CS, retTy, 0);
992       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
993       SDValue DeclareRetOps[] = { Chain,
994                                   DAG.getConstant(retAlignment, MVT::i32),
995                                   DAG.getConstant(resultsz / 8, MVT::i32),
996                                   DAG.getConstant(0, MVT::i32), InFlag };
997       Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
998                           DeclareRetOps, 5);
999       InFlag = Chain.getValue(1);
1000     }
1001   }
1002
1003   if (!Func) {
1004     // This is indirect function call case : PTX requires a prototype of the
1005     // form
1006     // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
1007     // to be emitted, and the label has to used as the last arg of call
1008     // instruction.
1009     // The prototype is embedded in a string and put as the operand for an
1010     // INLINEASM SDNode.
1011     SDVTList InlineAsmVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1012     std::string proto_string =
1013         getPrototype(retTy, Args, Outs, retAlignment, CS);
1014     const char *asmstr = nvTM->getManagedStrPool()
1015         ->getManagedString(proto_string.c_str())->c_str();
1016     SDValue InlineAsmOps[] = {
1017       Chain, DAG.getTargetExternalSymbol(asmstr, getPointerTy()),
1018       DAG.getMDNode(0), DAG.getTargetConstant(0, MVT::i32), InFlag
1019     };
1020     Chain = DAG.getNode(ISD::INLINEASM, dl, InlineAsmVTs, InlineAsmOps, 5);
1021     InFlag = Chain.getValue(1);
1022   }
1023   // Op to just print "call"
1024   SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1025   SDValue PrintCallOps[] = {
1026     Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, MVT::i32), InFlag
1027   };
1028   Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
1029                       dl, PrintCallVTs, PrintCallOps, 3);
1030   InFlag = Chain.getValue(1);
1031
1032   // Ops to print out the function name
1033   SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1034   SDValue CallVoidOps[] = { Chain, Callee, InFlag };
1035   Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3);
1036   InFlag = Chain.getValue(1);
1037
1038   // Ops to print out the param list
1039   SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1040   SDValue CallArgBeginOps[] = { Chain, InFlag };
1041   Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
1042                       CallArgBeginOps, 2);
1043   InFlag = Chain.getValue(1);
1044
1045   for (unsigned i = 0, e = paramCount; i != e; ++i) {
1046     unsigned opcode;
1047     if (i == (e - 1))
1048       opcode = NVPTXISD::LastCallArg;
1049     else
1050       opcode = NVPTXISD::CallArg;
1051     SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1052     SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
1053                              DAG.getConstant(i, MVT::i32), InFlag };
1054     Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4);
1055     InFlag = Chain.getValue(1);
1056   }
1057   SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1058   SDValue CallArgEndOps[] = { Chain, DAG.getConstant(Func ? 1 : 0, MVT::i32),
1059                               InFlag };
1060   Chain =
1061       DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps, 3);
1062   InFlag = Chain.getValue(1);
1063
1064   if (!Func) {
1065     SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1066     SDValue PrototypeOps[] = { Chain, DAG.getConstant(uniqueCallSite, MVT::i32),
1067                                InFlag };
1068     Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3);
1069     InFlag = Chain.getValue(1);
1070   }
1071
1072   // Generate loads from param memory/moves from registers for result
1073   if (Ins.size() > 0) {
1074     unsigned resoffset = 0;
1075     if (retTy && retTy->isVectorTy()) {
1076       EVT ObjectVT = getValueType(retTy);
1077       unsigned NumElts = ObjectVT.getVectorNumElements();
1078       EVT EltVT = ObjectVT.getVectorElementType();
1079       assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
1080              "Vector was not scalarized");
1081       unsigned sz = EltVT.getSizeInBits();
1082       bool needTruncate = sz < 16 ? true : false;
1083
1084       if (NumElts == 1) {
1085         // Just a simple load
1086         std::vector<EVT> LoadRetVTs;
1087         if (needTruncate) {
1088           // If loading i1 result, generate
1089           //   load i16
1090           //   trunc i16 to i1
1091           LoadRetVTs.push_back(MVT::i16);
1092         } else
1093           LoadRetVTs.push_back(EltVT);
1094         LoadRetVTs.push_back(MVT::Other);
1095         LoadRetVTs.push_back(MVT::Glue);
1096         std::vector<SDValue> LoadRetOps;
1097         LoadRetOps.push_back(Chain);
1098         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1099         LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
1100         LoadRetOps.push_back(InFlag);
1101         SDValue retval = DAG.getMemIntrinsicNode(
1102             NVPTXISD::LoadParam, dl,
1103             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
1104             LoadRetOps.size(), EltVT, MachinePointerInfo());
1105         Chain = retval.getValue(1);
1106         InFlag = retval.getValue(2);
1107         SDValue Ret0 = retval;
1108         if (needTruncate)
1109           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Ret0);
1110         InVals.push_back(Ret0);
1111       } else if (NumElts == 2) {
1112         // LoadV2
1113         std::vector<EVT> LoadRetVTs;
1114         if (needTruncate) {
1115           // If loading i1 result, generate
1116           //   load i16
1117           //   trunc i16 to i1
1118           LoadRetVTs.push_back(MVT::i16);
1119           LoadRetVTs.push_back(MVT::i16);
1120         } else {
1121           LoadRetVTs.push_back(EltVT);
1122           LoadRetVTs.push_back(EltVT);
1123         }
1124         LoadRetVTs.push_back(MVT::Other);
1125         LoadRetVTs.push_back(MVT::Glue);
1126         std::vector<SDValue> LoadRetOps;
1127         LoadRetOps.push_back(Chain);
1128         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1129         LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
1130         LoadRetOps.push_back(InFlag);
1131         SDValue retval = DAG.getMemIntrinsicNode(
1132             NVPTXISD::LoadParamV2, dl,
1133             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
1134             LoadRetOps.size(), EltVT, MachinePointerInfo());
1135         Chain = retval.getValue(2);
1136         InFlag = retval.getValue(3);
1137         SDValue Ret0 = retval.getValue(0);
1138         SDValue Ret1 = retval.getValue(1);
1139         if (needTruncate) {
1140           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret0);
1141           InVals.push_back(Ret0);
1142           Ret1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret1);
1143           InVals.push_back(Ret1);
1144         } else {
1145           InVals.push_back(Ret0);
1146           InVals.push_back(Ret1);
1147         }
1148       } else {
1149         // Split into N LoadV4
1150         unsigned Ofst = 0;
1151         unsigned VecSize = 4;
1152         unsigned Opc = NVPTXISD::LoadParamV4;
1153         if (EltVT.getSizeInBits() == 64) {
1154           VecSize = 2;
1155           Opc = NVPTXISD::LoadParamV2;
1156         }
1157         EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
1158         for (unsigned i = 0; i < NumElts; i += VecSize) {
1159           SmallVector<EVT, 8> LoadRetVTs;
1160           if (needTruncate) {
1161             // If loading i1 result, generate
1162             //   load i16
1163             //   trunc i16 to i1
1164             for (unsigned j = 0; j < VecSize; ++j)
1165               LoadRetVTs.push_back(MVT::i16);
1166           } else {
1167             for (unsigned j = 0; j < VecSize; ++j)
1168               LoadRetVTs.push_back(EltVT);
1169           }
1170           LoadRetVTs.push_back(MVT::Other);
1171           LoadRetVTs.push_back(MVT::Glue);
1172           SmallVector<SDValue, 4> LoadRetOps;
1173           LoadRetOps.push_back(Chain);
1174           LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1175           LoadRetOps.push_back(DAG.getConstant(Ofst, MVT::i32));
1176           LoadRetOps.push_back(InFlag);
1177           SDValue retval = DAG.getMemIntrinsicNode(
1178               Opc, dl, DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()),
1179               &LoadRetOps[0], LoadRetOps.size(), EltVT, MachinePointerInfo());
1180           if (VecSize == 2) {
1181             Chain = retval.getValue(2);
1182             InFlag = retval.getValue(3);
1183           } else {
1184             Chain = retval.getValue(4);
1185             InFlag = retval.getValue(5);
1186           }
1187
1188           for (unsigned j = 0; j < VecSize; ++j) {
1189             if (i + j >= NumElts)
1190               break;
1191             SDValue Elt = retval.getValue(j);
1192             if (needTruncate)
1193               Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
1194             InVals.push_back(Elt);
1195           }
1196           Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1197         }
1198       }
1199     } else {
1200       SmallVector<EVT, 16> VTs;
1201       ComputePTXValueVTs(*this, retTy, VTs);
1202       assert(VTs.size() == Ins.size() && "Bad value decomposition");
1203       for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
1204         unsigned sz = VTs[i].getSizeInBits();
1205         bool needTruncate = sz < 8 ? true : false;
1206         if (VTs[i].isInteger() && (sz < 8))
1207           sz = 8;
1208
1209         SmallVector<EVT, 4> LoadRetVTs;
1210         if (sz < 16) {
1211           // If loading i1/i8 result, generate
1212           //   load i8 (-> i16)
1213           //   trunc i16 to i1/i8
1214           LoadRetVTs.push_back(MVT::i16);
1215         } else
1216           LoadRetVTs.push_back(Ins[i].VT);
1217         LoadRetVTs.push_back(MVT::Other);
1218         LoadRetVTs.push_back(MVT::Glue);
1219
1220         SmallVector<SDValue, 4> LoadRetOps;
1221         LoadRetOps.push_back(Chain);
1222         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1223         LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32));
1224         LoadRetOps.push_back(InFlag);
1225         SDValue retval = DAG.getMemIntrinsicNode(
1226             NVPTXISD::LoadParam, dl,
1227             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
1228             LoadRetOps.size(), VTs[i], MachinePointerInfo());
1229         Chain = retval.getValue(1);
1230         InFlag = retval.getValue(2);
1231         SDValue Ret0 = retval.getValue(0);
1232         if (needTruncate)
1233           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0);
1234         InVals.push_back(Ret0);
1235         resoffset += sz / 8;
1236       }
1237     }
1238   }
1239
1240   Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
1241                              DAG.getIntPtrConstant(uniqueCallSite + 1, true),
1242                              InFlag, dl);
1243   uniqueCallSite++;
1244
1245   // set isTailCall to false for now, until we figure out how to express
1246   // tail call optimization in PTX
1247   isTailCall = false;
1248   return Chain;
1249 }
1250
1251 // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
1252 // (see LegalizeDAG.cpp). This is slow and uses local memory.
1253 // We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
1254 SDValue
1255 NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
1256   SDNode *Node = Op.getNode();
1257   SDLoc dl(Node);
1258   SmallVector<SDValue, 8> Ops;
1259   unsigned NumOperands = Node->getNumOperands();
1260   for (unsigned i = 0; i < NumOperands; ++i) {
1261     SDValue SubOp = Node->getOperand(i);
1262     EVT VVT = SubOp.getNode()->getValueType(0);
1263     EVT EltVT = VVT.getVectorElementType();
1264     unsigned NumSubElem = VVT.getVectorNumElements();
1265     for (unsigned j = 0; j < NumSubElem; ++j) {
1266       Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
1267                                 DAG.getIntPtrConstant(j)));
1268     }
1269   }
1270   return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0), &Ops[0],
1271                      Ops.size());
1272 }
1273
1274 SDValue
1275 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
1276   switch (Op.getOpcode()) {
1277   case ISD::RETURNADDR:
1278     return SDValue();
1279   case ISD::FRAMEADDR:
1280     return SDValue();
1281   case ISD::GlobalAddress:
1282     return LowerGlobalAddress(Op, DAG);
1283   case ISD::INTRINSIC_W_CHAIN:
1284     return Op;
1285   case ISD::BUILD_VECTOR:
1286   case ISD::EXTRACT_SUBVECTOR:
1287     return Op;
1288   case ISD::CONCAT_VECTORS:
1289     return LowerCONCAT_VECTORS(Op, DAG);
1290   case ISD::STORE:
1291     return LowerSTORE(Op, DAG);
1292   case ISD::LOAD:
1293     return LowerLOAD(Op, DAG);
1294   default:
1295     llvm_unreachable("Custom lowering not defined for operation");
1296   }
1297 }
1298
1299 SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
1300   if (Op.getValueType() == MVT::i1)
1301     return LowerLOADi1(Op, DAG);
1302   else
1303     return SDValue();
1304 }
1305
1306 // v = ld i1* addr
1307 //   =>
1308 // v1 = ld i8* addr (-> i16)
1309 // v = trunc i16 to i1
1310 SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
1311   SDNode *Node = Op.getNode();
1312   LoadSDNode *LD = cast<LoadSDNode>(Node);
1313   SDLoc dl(Node);
1314   assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
1315   assert(Node->getValueType(0) == MVT::i1 &&
1316          "Custom lowering for i1 load only");
1317   SDValue newLD =
1318       DAG.getLoad(MVT::i16, dl, LD->getChain(), LD->getBasePtr(),
1319                   LD->getPointerInfo(), LD->isVolatile(), LD->isNonTemporal(),
1320                   LD->isInvariant(), LD->getAlignment());
1321   SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
1322   // The legalizer (the caller) is expecting two values from the legalized
1323   // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
1324   // in LegalizeDAG.cpp which also uses MergeValues.
1325   SDValue Ops[] = { result, LD->getChain() };
1326   return DAG.getMergeValues(Ops, 2, dl);
1327 }
1328
1329 SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
1330   EVT ValVT = Op.getOperand(1).getValueType();
1331   if (ValVT == MVT::i1)
1332     return LowerSTOREi1(Op, DAG);
1333   else if (ValVT.isVector())
1334     return LowerSTOREVector(Op, DAG);
1335   else
1336     return SDValue();
1337 }
1338
1339 SDValue
1340 NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
1341   SDNode *N = Op.getNode();
1342   SDValue Val = N->getOperand(1);
1343   SDLoc DL(N);
1344   EVT ValVT = Val.getValueType();
1345
1346   if (ValVT.isVector()) {
1347     // We only handle "native" vector sizes for now, e.g. <4 x double> is not
1348     // legal.  We can (and should) split that into 2 stores of <2 x double> here
1349     // but I'm leaving that as a TODO for now.
1350     if (!ValVT.isSimple())
1351       return SDValue();
1352     switch (ValVT.getSimpleVT().SimpleTy) {
1353     default:
1354       return SDValue();
1355     case MVT::v2i8:
1356     case MVT::v2i16:
1357     case MVT::v2i32:
1358     case MVT::v2i64:
1359     case MVT::v2f32:
1360     case MVT::v2f64:
1361     case MVT::v4i8:
1362     case MVT::v4i16:
1363     case MVT::v4i32:
1364     case MVT::v4f32:
1365       // This is a "native" vector type
1366       break;
1367     }
1368
1369     unsigned Opcode = 0;
1370     EVT EltVT = ValVT.getVectorElementType();
1371     unsigned NumElts = ValVT.getVectorNumElements();
1372
1373     // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
1374     // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
1375     // stored type to i16 and propogate the "real" type as the memory type.
1376     bool NeedSExt = false;
1377     if (EltVT.getSizeInBits() < 16)
1378       NeedSExt = true;
1379
1380     switch (NumElts) {
1381     default:
1382       return SDValue();
1383     case 2:
1384       Opcode = NVPTXISD::StoreV2;
1385       break;
1386     case 4: {
1387       Opcode = NVPTXISD::StoreV4;
1388       break;
1389     }
1390     }
1391
1392     SmallVector<SDValue, 8> Ops;
1393
1394     // First is the chain
1395     Ops.push_back(N->getOperand(0));
1396
1397     // Then the split values
1398     for (unsigned i = 0; i < NumElts; ++i) {
1399       SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
1400                                    DAG.getIntPtrConstant(i));
1401       if (NeedSExt)
1402         ExtVal = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i16, ExtVal);
1403       Ops.push_back(ExtVal);
1404     }
1405
1406     // Then any remaining arguments
1407     for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) {
1408       Ops.push_back(N->getOperand(i));
1409     }
1410
1411     MemSDNode *MemSD = cast<MemSDNode>(N);
1412
1413     SDValue NewSt = DAG.getMemIntrinsicNode(
1414         Opcode, DL, DAG.getVTList(MVT::Other), &Ops[0], Ops.size(),
1415         MemSD->getMemoryVT(), MemSD->getMemOperand());
1416
1417     //return DCI.CombineTo(N, NewSt, true);
1418     return NewSt;
1419   }
1420
1421   return SDValue();
1422 }
1423
1424 // st i1 v, addr
1425 //    =>
1426 // v1 = zxt v to i16
1427 // st.u8 i16, addr
1428 SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
1429   SDNode *Node = Op.getNode();
1430   SDLoc dl(Node);
1431   StoreSDNode *ST = cast<StoreSDNode>(Node);
1432   SDValue Tmp1 = ST->getChain();
1433   SDValue Tmp2 = ST->getBasePtr();
1434   SDValue Tmp3 = ST->getValue();
1435   assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
1436   unsigned Alignment = ST->getAlignment();
1437   bool isVolatile = ST->isVolatile();
1438   bool isNonTemporal = ST->isNonTemporal();
1439   Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
1440   SDValue Result = DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2,
1441                                      ST->getPointerInfo(), MVT::i8, isNonTemporal,
1442                                      isVolatile, Alignment);
1443   return Result;
1444 }
1445
1446 SDValue NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname,
1447                                         int idx, EVT v) const {
1448   std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
1449   std::stringstream suffix;
1450   suffix << idx;
1451   *name += suffix.str();
1452   return DAG.getTargetExternalSymbol(name->c_str(), v);
1453 }
1454
1455 SDValue
1456 NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
1457   return getExtSymb(DAG, ".PARAM", idx, v);
1458 }
1459
1460 SDValue NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
1461   return getExtSymb(DAG, ".HLPPARAM", idx);
1462 }
1463
1464 // Check to see if the kernel argument is image*_t or sampler_t
1465
1466 bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
1467   static const char *const specialTypes[] = { "struct._image2d_t",
1468                                               "struct._image3d_t",
1469                                               "struct._sampler_t" };
1470
1471   const Type *Ty = arg->getType();
1472   const PointerType *PTy = dyn_cast<PointerType>(Ty);
1473
1474   if (!PTy)
1475     return false;
1476
1477   if (!context)
1478     return false;
1479
1480   const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
1481   const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : "";
1482
1483   for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
1484     if (TypeName == specialTypes[i])
1485       return true;
1486
1487   return false;
1488 }
1489
1490 SDValue NVPTXTargetLowering::LowerFormalArguments(
1491     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
1492     const SmallVectorImpl<ISD::InputArg> &Ins, SDLoc dl, SelectionDAG &DAG,
1493     SmallVectorImpl<SDValue> &InVals) const {
1494   MachineFunction &MF = DAG.getMachineFunction();
1495   const DataLayout *TD = getDataLayout();
1496
1497   const Function *F = MF.getFunction();
1498   const AttributeSet &PAL = F->getAttributes();
1499   const TargetLowering *TLI = nvTM->getTargetLowering();
1500
1501   SDValue Root = DAG.getRoot();
1502   std::vector<SDValue> OutChains;
1503
1504   bool isKernel = llvm::isKernelFunction(*F);
1505   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1506   assert(isABI && "Non-ABI compilation is not supported");
1507   if (!isABI)
1508     return Chain;
1509
1510   std::vector<Type *> argTypes;
1511   std::vector<const Argument *> theArgs;
1512   for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
1513        I != E; ++I) {
1514     theArgs.push_back(I);
1515     argTypes.push_back(I->getType());
1516   }
1517   // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
1518   // Ins.size() will be larger
1519   //   * if there is an aggregate argument with multiple fields (each field
1520   //     showing up separately in Ins)
1521   //   * if there is a vector argument with more than typical vector-length
1522   //     elements (generally if more than 4) where each vector element is
1523   //     individually present in Ins.
1524   // So a different index should be used for indexing into Ins.
1525   // See similar issue in LowerCall.
1526   unsigned InsIdx = 0;
1527
1528   int idx = 0;
1529   for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++idx, ++InsIdx) {
1530     Type *Ty = argTypes[i];
1531
1532     // If the kernel argument is image*_t or sampler_t, convert it to
1533     // a i32 constant holding the parameter position. This can later
1534     // matched in the AsmPrinter to output the correct mangled name.
1535     if (isImageOrSamplerVal(
1536             theArgs[i],
1537             (theArgs[i]->getParent() ? theArgs[i]->getParent()->getParent()
1538                                      : 0))) {
1539       assert(isKernel && "Only kernels can have image/sampler params");
1540       InVals.push_back(DAG.getConstant(i + 1, MVT::i32));
1541       continue;
1542     }
1543
1544     if (theArgs[i]->use_empty()) {
1545       // argument is dead
1546       if (Ty->isAggregateType()) {
1547         SmallVector<EVT, 16> vtparts;
1548
1549         ComputePTXValueVTs(*this, Ty, vtparts);
1550         assert(vtparts.size() > 0 && "empty aggregate type not expected");
1551         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
1552              ++parti) {
1553           EVT partVT = vtparts[parti];
1554           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, partVT));
1555           ++InsIdx;
1556         }
1557         if (vtparts.size() > 0)
1558           --InsIdx;
1559         continue;
1560       }
1561       if (Ty->isVectorTy()) {
1562         EVT ObjectVT = getValueType(Ty);
1563         unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
1564         for (unsigned parti = 0; parti < NumRegs; ++parti) {
1565           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
1566           ++InsIdx;
1567         }
1568         if (NumRegs > 0)
1569           --InsIdx;
1570         continue;
1571       }
1572       InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
1573       continue;
1574     }
1575
1576     // In the following cases, assign a node order of "idx+1"
1577     // to newly created nodes. The SDNodes for params have to
1578     // appear in the same order as their order of appearance
1579     // in the original function. "idx+1" holds that order.
1580     if (PAL.hasAttribute(i + 1, Attribute::ByVal) == false) {
1581       if (Ty->isAggregateType()) {
1582         SmallVector<EVT, 16> vtparts;
1583         SmallVector<uint64_t, 16> offsets;
1584
1585         // NOTE: Here, we lose the ability to issue vector loads for vectors
1586         // that are a part of a struct.  This should be investigated in the
1587         // future.
1588         ComputePTXValueVTs(*this, Ty, vtparts, &offsets, 0);
1589         assert(vtparts.size() > 0 && "empty aggregate type not expected");
1590         bool aggregateIsPacked = false;
1591         if (StructType *STy = llvm::dyn_cast<StructType>(Ty))
1592           aggregateIsPacked = STy->isPacked();
1593
1594         SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1595         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
1596              ++parti) {
1597           EVT partVT = vtparts[parti];
1598           Value *srcValue = Constant::getNullValue(
1599               PointerType::get(partVT.getTypeForEVT(F->getContext()),
1600                                llvm::ADDRESS_SPACE_PARAM));
1601           SDValue srcAddr =
1602               DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1603                           DAG.getConstant(offsets[parti], getPointerTy()));
1604           unsigned partAlign =
1605               aggregateIsPacked ? 1
1606                                 : TD->getABITypeAlignment(
1607                                       partVT.getTypeForEVT(F->getContext()));
1608                     SDValue p;
1609           if (Ins[InsIdx].VT.getSizeInBits() > partVT.getSizeInBits())
1610             p = DAG.getExtLoad(ISD::SEXTLOAD, dl, Ins[InsIdx].VT, Root, srcAddr,
1611                                MachinePointerInfo(srcValue), partVT, false,
1612                                false, partAlign);
1613           else
1614             p = DAG.getLoad(partVT, dl, Root, srcAddr,
1615                             MachinePointerInfo(srcValue), false, false, false,
1616                             partAlign);
1617           if (p.getNode())
1618             p.getNode()->setIROrder(idx + 1);
1619           InVals.push_back(p);
1620           ++InsIdx;
1621         }
1622         if (vtparts.size() > 0)
1623           --InsIdx;
1624         continue;
1625       }
1626       if (Ty->isVectorTy()) {
1627         EVT ObjectVT = getValueType(Ty);
1628         SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1629         unsigned NumElts = ObjectVT.getVectorNumElements();
1630         assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
1631                "Vector was not scalarized");
1632         unsigned Ofst = 0;
1633         EVT EltVT = ObjectVT.getVectorElementType();
1634
1635         // V1 load
1636         // f32 = load ...
1637         if (NumElts == 1) {
1638           // We only have one element, so just directly load it
1639           Value *SrcValue = Constant::getNullValue(PointerType::get(
1640               EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1641           SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1642                                         DAG.getConstant(Ofst, getPointerTy()));
1643           SDValue P = DAG.getLoad(
1644               EltVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1645               false, true,
1646               TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
1647           if (P.getNode())
1648             P.getNode()->setIROrder(idx + 1);
1649
1650           if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
1651             P = DAG.getNode(ISD::SIGN_EXTEND, dl, Ins[InsIdx].VT, P);
1652           InVals.push_back(P);
1653           Ofst += TD->getTypeAllocSize(EltVT.getTypeForEVT(F->getContext()));
1654           ++InsIdx;
1655         } else if (NumElts == 2) {
1656           // V2 load
1657           // f32,f32 = load ...
1658           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2);
1659           Value *SrcValue = Constant::getNullValue(PointerType::get(
1660               VecVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1661           SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1662                                         DAG.getConstant(Ofst, getPointerTy()));
1663           SDValue P = DAG.getLoad(
1664               VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1665               false, true,
1666               TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
1667           if (P.getNode())
1668             P.getNode()->setIROrder(idx + 1);
1669
1670           SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1671                                      DAG.getIntPtrConstant(0));
1672           SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1673                                      DAG.getIntPtrConstant(1));
1674
1675           if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) {
1676             Elt0 = DAG.getNode(ISD::SIGN_EXTEND, dl, Ins[InsIdx].VT, Elt0);
1677             Elt1 = DAG.getNode(ISD::SIGN_EXTEND, dl, Ins[InsIdx].VT, Elt1);
1678           }
1679
1680           InVals.push_back(Elt0);
1681           InVals.push_back(Elt1);
1682           Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1683           InsIdx += 2;
1684         } else {
1685           // V4 loads
1686           // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
1687           // the
1688           // vector will be expanded to a power of 2 elements, so we know we can
1689           // always round up to the next multiple of 4 when creating the vector
1690           // loads.
1691           // e.g.  4 elem => 1 ld.v4
1692           //       6 elem => 2 ld.v4
1693           //       8 elem => 2 ld.v4
1694           //      11 elem => 3 ld.v4
1695           unsigned VecSize = 4;
1696           if (EltVT.getSizeInBits() == 64) {
1697             VecSize = 2;
1698           }
1699           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
1700           for (unsigned i = 0; i < NumElts; i += VecSize) {
1701             Value *SrcValue = Constant::getNullValue(
1702                 PointerType::get(VecVT.getTypeForEVT(F->getContext()),
1703                                  llvm::ADDRESS_SPACE_PARAM));
1704             SDValue SrcAddr =
1705                 DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1706                             DAG.getConstant(Ofst, getPointerTy()));
1707             SDValue P = DAG.getLoad(
1708                 VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1709                 false, true,
1710                 TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
1711             if (P.getNode())
1712               P.getNode()->setIROrder(idx + 1);
1713
1714             for (unsigned j = 0; j < VecSize; ++j) {
1715               if (i + j >= NumElts)
1716                 break;
1717               SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1718                                         DAG.getIntPtrConstant(j));
1719               if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
1720                 Elt = DAG.getNode(ISD::SIGN_EXTEND, dl, Ins[InsIdx].VT, Elt);
1721               InVals.push_back(Elt);
1722             }
1723             Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1724             InsIdx += VecSize;
1725           }
1726         }
1727
1728         if (NumElts > 0)
1729           --InsIdx;
1730         continue;
1731       }
1732       // A plain scalar.
1733       EVT ObjectVT = getValueType(Ty);
1734       // If ABI, load from the param symbol
1735       SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1736       Value *srcValue = Constant::getNullValue(PointerType::get(
1737           ObjectVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1738       SDValue p;
1739       if (ObjectVT.getSizeInBits() < Ins[InsIdx].VT.getSizeInBits())
1740         p = DAG.getExtLoad(ISD::SEXTLOAD, dl, Ins[InsIdx].VT, Root, Arg,
1741                            MachinePointerInfo(srcValue), ObjectVT, false, false,
1742               TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
1743       else
1744         p = DAG.getLoad(Ins[InsIdx].VT, dl, Root, Arg,
1745                         MachinePointerInfo(srcValue), false, false, false,
1746               TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
1747       if (p.getNode())
1748         p.getNode()->setIROrder(idx + 1);
1749       InVals.push_back(p);
1750       continue;
1751     }
1752
1753     // Param has ByVal attribute
1754     // Return MoveParam(param symbol).
1755     // Ideally, the param symbol can be returned directly,
1756     // but when SDNode builder decides to use it in a CopyToReg(),
1757     // machine instruction fails because TargetExternalSymbol
1758     // (not lowered) is target dependent, and CopyToReg assumes
1759     // the source is lowered.
1760     EVT ObjectVT = getValueType(Ty);
1761     assert(ObjectVT == Ins[InsIdx].VT &&
1762            "Ins type did not match function type");
1763     SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1764     SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
1765     if (p.getNode())
1766       p.getNode()->setIROrder(idx + 1);
1767     if (isKernel)
1768       InVals.push_back(p);
1769     else {
1770       SDValue p2 = DAG.getNode(
1771           ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
1772           DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p);
1773       InVals.push_back(p2);
1774     }
1775   }
1776
1777   // Clang will check explicit VarArg and issue error if any. However, Clang
1778   // will let code with
1779   // implicit var arg like f() pass. See bug 617733.
1780   // We treat this case as if the arg list is empty.
1781   // if (F.isVarArg()) {
1782   // assert(0 && "VarArg not supported yet!");
1783   //}
1784
1785   if (!OutChains.empty())
1786     DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &OutChains[0],
1787                             OutChains.size()));
1788
1789   return Chain;
1790 }
1791
1792
1793 SDValue
1794 NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
1795                                  bool isVarArg,
1796                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
1797                                  const SmallVectorImpl<SDValue> &OutVals,
1798                                  SDLoc dl, SelectionDAG &DAG) const {
1799   MachineFunction &MF = DAG.getMachineFunction();
1800   const Function *F = MF.getFunction();
1801   const Type *RetTy = F->getReturnType();
1802   const DataLayout *TD = getDataLayout();
1803
1804   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1805   assert(isABI && "Non-ABI compilation is not supported");
1806   if (!isABI)
1807     return Chain;
1808
1809   if (const VectorType *VTy = dyn_cast<const VectorType>(RetTy)) {
1810     // If we have a vector type, the OutVals array will be the scalarized
1811     // components and we have combine them into 1 or more vector stores.
1812     unsigned NumElts = VTy->getNumElements();
1813     assert(NumElts == Outs.size() && "Bad scalarization of return value");
1814
1815     // const_cast can be removed in later LLVM versions
1816     EVT EltVT = getValueType(const_cast<Type *>(RetTy)).getVectorElementType();
1817     bool NeedExtend = false;
1818     if (EltVT.getSizeInBits() < 16)
1819       NeedExtend = true;
1820
1821     // V1 store
1822     if (NumElts == 1) {
1823       SDValue StoreVal = OutVals[0];
1824       // We only have one element, so just directly store it
1825       if (NeedExtend)
1826         StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
1827       SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal };
1828       Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
1829                                       DAG.getVTList(MVT::Other), &Ops[0], 3,
1830                                       EltVT, MachinePointerInfo());
1831
1832     } else if (NumElts == 2) {
1833       // V2 store
1834       SDValue StoreVal0 = OutVals[0];
1835       SDValue StoreVal1 = OutVals[1];
1836
1837       if (NeedExtend) {
1838         StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal0);
1839         StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal1);
1840       }
1841
1842       SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal0,
1843                         StoreVal1 };
1844       Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetvalV2, dl,
1845                                       DAG.getVTList(MVT::Other), &Ops[0], 4,
1846                                       EltVT, MachinePointerInfo());
1847     } else {
1848       // V4 stores
1849       // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the
1850       // vector will be expanded to a power of 2 elements, so we know we can
1851       // always round up to the next multiple of 4 when creating the vector
1852       // stores.
1853       // e.g.  4 elem => 1 st.v4
1854       //       6 elem => 2 st.v4
1855       //       8 elem => 2 st.v4
1856       //      11 elem => 3 st.v4
1857
1858       unsigned VecSize = 4;
1859       if (OutVals[0].getValueType().getSizeInBits() == 64)
1860         VecSize = 2;
1861
1862       unsigned Offset = 0;
1863
1864       EVT VecVT =
1865           EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize);
1866       unsigned PerStoreOffset =
1867           TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1868
1869       for (unsigned i = 0; i < NumElts; i += VecSize) {
1870         // Get values
1871         SDValue StoreVal;
1872         SmallVector<SDValue, 8> Ops;
1873         Ops.push_back(Chain);
1874         Ops.push_back(DAG.getConstant(Offset, MVT::i32));
1875         unsigned Opc = NVPTXISD::StoreRetvalV2;
1876         EVT ExtendedVT = (NeedExtend) ? MVT::i16 : OutVals[0].getValueType();
1877
1878         StoreVal = OutVals[i];
1879         if (NeedExtend)
1880           StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1881         Ops.push_back(StoreVal);
1882
1883         if (i + 1 < NumElts) {
1884           StoreVal = OutVals[i + 1];
1885           if (NeedExtend)
1886             StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1887         } else {
1888           StoreVal = DAG.getUNDEF(ExtendedVT);
1889         }
1890         Ops.push_back(StoreVal);
1891
1892         if (VecSize == 4) {
1893           Opc = NVPTXISD::StoreRetvalV4;
1894           if (i + 2 < NumElts) {
1895             StoreVal = OutVals[i + 2];
1896             if (NeedExtend)
1897               StoreVal =
1898                   DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1899           } else {
1900             StoreVal = DAG.getUNDEF(ExtendedVT);
1901           }
1902           Ops.push_back(StoreVal);
1903
1904           if (i + 3 < NumElts) {
1905             StoreVal = OutVals[i + 3];
1906             if (NeedExtend)
1907               StoreVal =
1908                   DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1909           } else {
1910             StoreVal = DAG.getUNDEF(ExtendedVT);
1911           }
1912           Ops.push_back(StoreVal);
1913         }
1914
1915         // Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size());
1916         Chain =
1917             DAG.getMemIntrinsicNode(Opc, dl, DAG.getVTList(MVT::Other), &Ops[0],
1918                                     Ops.size(), EltVT, MachinePointerInfo());
1919         Offset += PerStoreOffset;
1920       }
1921     }
1922   } else {
1923     SmallVector<EVT, 16> ValVTs;
1924     // const_cast is necessary since we are still using an LLVM version from
1925     // before the type system re-write.
1926     ComputePTXValueVTs(*this, const_cast<Type *>(RetTy), ValVTs);
1927     assert(ValVTs.size() == OutVals.size() && "Bad return value decomposition");
1928
1929     unsigned sizesofar = 0;
1930     for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
1931       SDValue theVal = OutVals[i];
1932       EVT theValType = theVal.getValueType();
1933       unsigned numElems = 1;
1934       if (theValType.isVector())
1935         numElems = theValType.getVectorNumElements();
1936       for (unsigned j = 0, je = numElems; j != je; ++j) {
1937         SDValue tmpval = theVal;
1938         if (theValType.isVector())
1939           tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
1940                                theValType.getVectorElementType(), tmpval,
1941                                DAG.getIntPtrConstant(j));
1942         EVT theStoreType = tmpval.getValueType();
1943         if (theStoreType.getSizeInBits() < 8)
1944           tmpval = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, tmpval);
1945         SDValue Ops[] = { Chain, DAG.getConstant(sizesofar, MVT::i32), tmpval };
1946         Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
1947                                         DAG.getVTList(MVT::Other), &Ops[0], 3,
1948                                         ValVTs[i], MachinePointerInfo());
1949         if (theValType.isVector())
1950           sizesofar +=
1951               ValVTs[i].getVectorElementType().getStoreSizeInBits() / 8;
1952         else
1953           sizesofar += ValVTs[i].getStoreSizeInBits() / 8;
1954       }
1955     }
1956   }
1957
1958   return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
1959 }
1960
1961
1962 void NVPTXTargetLowering::LowerAsmOperandForConstraint(
1963     SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops,
1964     SelectionDAG &DAG) const {
1965   if (Constraint.length() > 1)
1966     return;
1967   else
1968     TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
1969 }
1970
1971 // NVPTX suuport vector of legal types of any length in Intrinsics because the
1972 // NVPTX specific type legalizer
1973 // will legalize them to the PTX supported length.
1974 bool NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
1975   if (isTypeLegal(VT))
1976     return true;
1977   if (VT.isVector()) {
1978     MVT eVT = VT.getVectorElementType();
1979     if (isTypeLegal(eVT))
1980       return true;
1981   }
1982   return false;
1983 }
1984
1985 // llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
1986 // TgtMemIntrinsic
1987 // because we need the information that is only available in the "Value" type
1988 // of destination
1989 // pointer. In particular, the address space information.
1990 bool NVPTXTargetLowering::getTgtMemIntrinsic(
1991     IntrinsicInfo &Info, const CallInst &I, unsigned Intrinsic) const {
1992   switch (Intrinsic) {
1993   default:
1994     return false;
1995
1996   case Intrinsic::nvvm_atomic_load_add_f32:
1997     Info.opc = ISD::INTRINSIC_W_CHAIN;
1998     Info.memVT = MVT::f32;
1999     Info.ptrVal = I.getArgOperand(0);
2000     Info.offset = 0;
2001     Info.vol = 0;
2002     Info.readMem = true;
2003     Info.writeMem = true;
2004     Info.align = 0;
2005     return true;
2006
2007   case Intrinsic::nvvm_atomic_load_inc_32:
2008   case Intrinsic::nvvm_atomic_load_dec_32:
2009     Info.opc = ISD::INTRINSIC_W_CHAIN;
2010     Info.memVT = MVT::i32;
2011     Info.ptrVal = I.getArgOperand(0);
2012     Info.offset = 0;
2013     Info.vol = 0;
2014     Info.readMem = true;
2015     Info.writeMem = true;
2016     Info.align = 0;
2017     return true;
2018
2019   case Intrinsic::nvvm_ldu_global_i:
2020   case Intrinsic::nvvm_ldu_global_f:
2021   case Intrinsic::nvvm_ldu_global_p:
2022
2023     Info.opc = ISD::INTRINSIC_W_CHAIN;
2024     if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
2025       Info.memVT = getValueType(I.getType());
2026     else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
2027       Info.memVT = getValueType(I.getType());
2028     else
2029       Info.memVT = MVT::f32;
2030     Info.ptrVal = I.getArgOperand(0);
2031     Info.offset = 0;
2032     Info.vol = 0;
2033     Info.readMem = true;
2034     Info.writeMem = false;
2035     Info.align = 0;
2036     return true;
2037
2038   }
2039   return false;
2040 }
2041
2042 /// isLegalAddressingMode - Return true if the addressing mode represented
2043 /// by AM is legal for this target, for a load/store of the specified type.
2044 /// Used to guide target specific optimizations, like loop strength reduction
2045 /// (LoopStrengthReduce.cpp) and memory optimization for address mode
2046 /// (CodeGenPrepare.cpp)
2047 bool NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
2048                                                 Type *Ty) const {
2049
2050   // AddrMode - This represents an addressing mode of:
2051   //    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
2052   //
2053   // The legal address modes are
2054   // - [avar]
2055   // - [areg]
2056   // - [areg+immoff]
2057   // - [immAddr]
2058
2059   if (AM.BaseGV) {
2060     if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
2061       return false;
2062     return true;
2063   }
2064
2065   switch (AM.Scale) {
2066   case 0: // "r", "r+i" or "i" is allowed
2067     break;
2068   case 1:
2069     if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
2070       return false;
2071     // Otherwise we have r+i.
2072     break;
2073   default:
2074     // No scale > 1 is allowed
2075     return false;
2076   }
2077   return true;
2078 }
2079
2080 //===----------------------------------------------------------------------===//
2081 //                         NVPTX Inline Assembly Support
2082 //===----------------------------------------------------------------------===//
2083
2084 /// getConstraintType - Given a constraint letter, return the type of
2085 /// constraint it is for this target.
2086 NVPTXTargetLowering::ConstraintType
2087 NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
2088   if (Constraint.size() == 1) {
2089     switch (Constraint[0]) {
2090     default:
2091       break;
2092     case 'r':
2093     case 'h':
2094     case 'c':
2095     case 'l':
2096     case 'f':
2097     case 'd':
2098     case '0':
2099     case 'N':
2100       return C_RegisterClass;
2101     }
2102   }
2103   return TargetLowering::getConstraintType(Constraint);
2104 }
2105
2106 std::pair<unsigned, const TargetRegisterClass *>
2107 NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
2108                                                   MVT VT) const {
2109   if (Constraint.size() == 1) {
2110     switch (Constraint[0]) {
2111     case 'c':
2112       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
2113     case 'h':
2114       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
2115     case 'r':
2116       return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
2117     case 'l':
2118     case 'N':
2119       return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
2120     case 'f':
2121       return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
2122     case 'd':
2123       return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
2124     }
2125   }
2126   return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
2127 }
2128
2129 /// getFunctionAlignment - Return the Log2 alignment of this function.
2130 unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
2131   return 4;
2132 }
2133
2134 /// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
2135 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
2136                               SmallVectorImpl<SDValue> &Results) {
2137   EVT ResVT = N->getValueType(0);
2138   SDLoc DL(N);
2139
2140   assert(ResVT.isVector() && "Vector load must have vector type");
2141
2142   // We only handle "native" vector sizes for now, e.g. <4 x double> is not
2143   // legal.  We can (and should) split that into 2 loads of <2 x double> here
2144   // but I'm leaving that as a TODO for now.
2145   assert(ResVT.isSimple() && "Can only handle simple types");
2146   switch (ResVT.getSimpleVT().SimpleTy) {
2147   default:
2148     return;
2149   case MVT::v2i8:
2150   case MVT::v2i16:
2151   case MVT::v2i32:
2152   case MVT::v2i64:
2153   case MVT::v2f32:
2154   case MVT::v2f64:
2155   case MVT::v4i8:
2156   case MVT::v4i16:
2157   case MVT::v4i32:
2158   case MVT::v4f32:
2159     // This is a "native" vector type
2160     break;
2161   }
2162
2163   EVT EltVT = ResVT.getVectorElementType();
2164   unsigned NumElts = ResVT.getVectorNumElements();
2165
2166   // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
2167   // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
2168   // loaded type to i16 and propogate the "real" type as the memory type.
2169   bool NeedTrunc = false;
2170   if (EltVT.getSizeInBits() < 16) {
2171     EltVT = MVT::i16;
2172     NeedTrunc = true;
2173   }
2174
2175   unsigned Opcode = 0;
2176   SDVTList LdResVTs;
2177
2178   switch (NumElts) {
2179   default:
2180     return;
2181   case 2:
2182     Opcode = NVPTXISD::LoadV2;
2183     LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
2184     break;
2185   case 4: {
2186     Opcode = NVPTXISD::LoadV4;
2187     EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
2188     LdResVTs = DAG.getVTList(ListVTs, 5);
2189     break;
2190   }
2191   }
2192
2193   SmallVector<SDValue, 8> OtherOps;
2194
2195   // Copy regular operands
2196   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2197     OtherOps.push_back(N->getOperand(i));
2198
2199   LoadSDNode *LD = cast<LoadSDNode>(N);
2200
2201   // The select routine does not have access to the LoadSDNode instance, so
2202   // pass along the extension information
2203   OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType()));
2204
2205   SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, &OtherOps[0],
2206                                           OtherOps.size(), LD->getMemoryVT(),
2207                                           LD->getMemOperand());
2208
2209   SmallVector<SDValue, 4> ScalarRes;
2210
2211   for (unsigned i = 0; i < NumElts; ++i) {
2212     SDValue Res = NewLD.getValue(i);
2213     if (NeedTrunc)
2214       Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
2215     ScalarRes.push_back(Res);
2216   }
2217
2218   SDValue LoadChain = NewLD.getValue(NumElts);
2219
2220   SDValue BuildVec =
2221       DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
2222
2223   Results.push_back(BuildVec);
2224   Results.push_back(LoadChain);
2225 }
2226
2227 static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
2228                                      SmallVectorImpl<SDValue> &Results) {
2229   SDValue Chain = N->getOperand(0);
2230   SDValue Intrin = N->getOperand(1);
2231   SDLoc DL(N);
2232
2233   // Get the intrinsic ID
2234   unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
2235   switch (IntrinNo) {
2236   default:
2237     return;
2238   case Intrinsic::nvvm_ldg_global_i:
2239   case Intrinsic::nvvm_ldg_global_f:
2240   case Intrinsic::nvvm_ldg_global_p:
2241   case Intrinsic::nvvm_ldu_global_i:
2242   case Intrinsic::nvvm_ldu_global_f:
2243   case Intrinsic::nvvm_ldu_global_p: {
2244     EVT ResVT = N->getValueType(0);
2245
2246     if (ResVT.isVector()) {
2247       // Vector LDG/LDU
2248
2249       unsigned NumElts = ResVT.getVectorNumElements();
2250       EVT EltVT = ResVT.getVectorElementType();
2251
2252       // Since LDU/LDG are target nodes, we cannot rely on DAG type
2253       // legalization.
2254       // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
2255       // loaded type to i16 and propogate the "real" type as the memory type.
2256       bool NeedTrunc = false;
2257       if (EltVT.getSizeInBits() < 16) {
2258         EltVT = MVT::i16;
2259         NeedTrunc = true;
2260       }
2261
2262       unsigned Opcode = 0;
2263       SDVTList LdResVTs;
2264
2265       switch (NumElts) {
2266       default:
2267         return;
2268       case 2:
2269         switch (IntrinNo) {
2270         default:
2271           return;
2272         case Intrinsic::nvvm_ldg_global_i:
2273         case Intrinsic::nvvm_ldg_global_f:
2274         case Intrinsic::nvvm_ldg_global_p:
2275           Opcode = NVPTXISD::LDGV2;
2276           break;
2277         case Intrinsic::nvvm_ldu_global_i:
2278         case Intrinsic::nvvm_ldu_global_f:
2279         case Intrinsic::nvvm_ldu_global_p:
2280           Opcode = NVPTXISD::LDUV2;
2281           break;
2282         }
2283         LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
2284         break;
2285       case 4: {
2286         switch (IntrinNo) {
2287         default:
2288           return;
2289         case Intrinsic::nvvm_ldg_global_i:
2290         case Intrinsic::nvvm_ldg_global_f:
2291         case Intrinsic::nvvm_ldg_global_p:
2292           Opcode = NVPTXISD::LDGV4;
2293           break;
2294         case Intrinsic::nvvm_ldu_global_i:
2295         case Intrinsic::nvvm_ldu_global_f:
2296         case Intrinsic::nvvm_ldu_global_p:
2297           Opcode = NVPTXISD::LDUV4;
2298           break;
2299         }
2300         EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
2301         LdResVTs = DAG.getVTList(ListVTs, 5);
2302         break;
2303       }
2304       }
2305
2306       SmallVector<SDValue, 8> OtherOps;
2307
2308       // Copy regular operands
2309
2310       OtherOps.push_back(Chain); // Chain
2311                                  // Skip operand 1 (intrinsic ID)
2312       // Others
2313       for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i)
2314         OtherOps.push_back(N->getOperand(i));
2315
2316       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
2317
2318       SDValue NewLD = DAG.getMemIntrinsicNode(
2319           Opcode, DL, LdResVTs, &OtherOps[0], OtherOps.size(),
2320           MemSD->getMemoryVT(), MemSD->getMemOperand());
2321
2322       SmallVector<SDValue, 4> ScalarRes;
2323
2324       for (unsigned i = 0; i < NumElts; ++i) {
2325         SDValue Res = NewLD.getValue(i);
2326         if (NeedTrunc)
2327           Res =
2328               DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
2329         ScalarRes.push_back(Res);
2330       }
2331
2332       SDValue LoadChain = NewLD.getValue(NumElts);
2333
2334       SDValue BuildVec =
2335           DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
2336
2337       Results.push_back(BuildVec);
2338       Results.push_back(LoadChain);
2339     } else {
2340       // i8 LDG/LDU
2341       assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
2342              "Custom handling of non-i8 ldu/ldg?");
2343
2344       // Just copy all operands as-is
2345       SmallVector<SDValue, 4> Ops;
2346       for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2347         Ops.push_back(N->getOperand(i));
2348
2349       // Force output to i16
2350       SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);
2351
2352       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
2353
2354       // We make sure the memory type is i8, which will be used during isel
2355       // to select the proper instruction.
2356       SDValue NewLD =
2357           DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, &Ops[0],
2358                                   Ops.size(), MVT::i8, MemSD->getMemOperand());
2359
2360       Results.push_back(NewLD.getValue(0));
2361       Results.push_back(NewLD.getValue(1));
2362     }
2363   }
2364   }
2365 }
2366
2367 void NVPTXTargetLowering::ReplaceNodeResults(
2368     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
2369   switch (N->getOpcode()) {
2370   default:
2371     report_fatal_error("Unhandled custom legalization");
2372   case ISD::LOAD:
2373     ReplaceLoadVector(N, DAG, Results);
2374     return;
2375   case ISD::INTRINSIC_W_CHAIN:
2376     ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
2377     return;
2378   }
2379 }