[NVPTX] Add support for vectorized function return values
[oota-llvm.git] / lib / Target / NVPTX / NVPTXISelLowering.cpp
1 //
2 //                     The LLVM Compiler Infrastructure
3 //
4 // This file is distributed under the University of Illinois Open Source
5 // License. See LICENSE.TXT for details.
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the interfaces that NVPTX uses to lower LLVM code into a
10 // selection DAG.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "NVPTXISelLowering.h"
15 #include "NVPTX.h"
16 #include "NVPTXTargetMachine.h"
17 #include "NVPTXTargetObjectFile.h"
18 #include "NVPTXUtilities.h"
19 #include "llvm/CodeGen/Analysis.h"
20 #include "llvm/CodeGen/MachineFrameInfo.h"
21 #include "llvm/CodeGen/MachineFunction.h"
22 #include "llvm/CodeGen/MachineInstrBuilder.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/GlobalValue.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/Intrinsics.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/MC/MCSectionELF.h"
32 #include "llvm/Support/CallSite.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <sstream>
38
39 #undef DEBUG_TYPE
40 #define DEBUG_TYPE "nvptx-lower"
41
42 using namespace llvm;
43
44 static unsigned int uniqueCallSite = 0;
45
46 static cl::opt<bool> sched4reg(
47     "nvptx-sched4reg",
48     cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false));
49
50 static bool IsPTXVectorType(MVT VT) {
51   switch (VT.SimpleTy) {
52   default:
53     return false;
54   case MVT::v2i8:
55   case MVT::v4i8:
56   case MVT::v2i16:
57   case MVT::v4i16:
58   case MVT::v2i32:
59   case MVT::v4i32:
60   case MVT::v2i64:
61   case MVT::v2f32:
62   case MVT::v4f32:
63   case MVT::v2f64:
64     return true;
65   }
66 }
67
68 // NVPTXTargetLowering Constructor.
69 NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
70     : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM),
71       nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {
72
73   // always lower memset, memcpy, and memmove intrinsics to load/store
74   // instructions, rather
75   // then generating calls to memset, mempcy or memmove.
76   MaxStoresPerMemset = (unsigned) 0xFFFFFFFF;
77   MaxStoresPerMemcpy = (unsigned) 0xFFFFFFFF;
78   MaxStoresPerMemmove = (unsigned) 0xFFFFFFFF;
79
80   setBooleanContents(ZeroOrNegativeOneBooleanContent);
81
82   // Jump is Expensive. Don't create extra control flow for 'and', 'or'
83   // condition branches.
84   setJumpIsExpensive(true);
85
86   // By default, use the Source scheduling
87   if (sched4reg)
88     setSchedulingPreference(Sched::RegPressure);
89   else
90     setSchedulingPreference(Sched::Source);
91
92   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
93   addRegisterClass(MVT::i8, &NVPTX::Int8RegsRegClass);
94   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
95   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
96   addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
97   addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
98   addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
99
100   // Operations not directly supported by NVPTX.
101   setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
102   setOperationAction(ISD::BR_CC, MVT::f32, Expand);
103   setOperationAction(ISD::BR_CC, MVT::f64, Expand);
104   setOperationAction(ISD::BR_CC, MVT::i1, Expand);
105   setOperationAction(ISD::BR_CC, MVT::i8, Expand);
106   setOperationAction(ISD::BR_CC, MVT::i16, Expand);
107   setOperationAction(ISD::BR_CC, MVT::i32, Expand);
108   setOperationAction(ISD::BR_CC, MVT::i64, Expand);
109   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Expand);
110   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Expand);
111   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Expand);
112   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8, Expand);
113   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
114
115   if (nvptxSubtarget.hasROT64()) {
116     setOperationAction(ISD::ROTL, MVT::i64, Legal);
117     setOperationAction(ISD::ROTR, MVT::i64, Legal);
118   } else {
119     setOperationAction(ISD::ROTL, MVT::i64, Expand);
120     setOperationAction(ISD::ROTR, MVT::i64, Expand);
121   }
122   if (nvptxSubtarget.hasROT32()) {
123     setOperationAction(ISD::ROTL, MVT::i32, Legal);
124     setOperationAction(ISD::ROTR, MVT::i32, Legal);
125   } else {
126     setOperationAction(ISD::ROTL, MVT::i32, Expand);
127     setOperationAction(ISD::ROTR, MVT::i32, Expand);
128   }
129
130   setOperationAction(ISD::ROTL, MVT::i16, Expand);
131   setOperationAction(ISD::ROTR, MVT::i16, Expand);
132   setOperationAction(ISD::ROTL, MVT::i8, Expand);
133   setOperationAction(ISD::ROTR, MVT::i8, Expand);
134   setOperationAction(ISD::BSWAP, MVT::i16, Expand);
135   setOperationAction(ISD::BSWAP, MVT::i32, Expand);
136   setOperationAction(ISD::BSWAP, MVT::i64, Expand);
137
138   // Indirect branch is not supported.
139   // This also disables Jump Table creation.
140   setOperationAction(ISD::BR_JT, MVT::Other, Expand);
141   setOperationAction(ISD::BRIND, MVT::Other, Expand);
142
143   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
144   setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
145
146   // We want to legalize constant related memmove and memcopy
147   // intrinsics.
148   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
149
150   // Turn FP extload into load/fextend
151   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
152   // Turn FP truncstore into trunc + store.
153   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
154
155   // PTX does not support load / store predicate registers
156   setOperationAction(ISD::LOAD, MVT::i1, Custom);
157   setOperationAction(ISD::STORE, MVT::i1, Custom);
158
159   setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
160   setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
161   setTruncStoreAction(MVT::i64, MVT::i1, Expand);
162   setTruncStoreAction(MVT::i32, MVT::i1, Expand);
163   setTruncStoreAction(MVT::i16, MVT::i1, Expand);
164   setTruncStoreAction(MVT::i8, MVT::i1, Expand);
165
166   // This is legal in NVPTX
167   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
168   setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
169
170   // TRAP can be lowered to PTX trap
171   setOperationAction(ISD::TRAP, MVT::Other, Legal);
172
173   // Register custom handling for vector loads/stores
174   for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE;
175        ++i) {
176     MVT VT = (MVT::SimpleValueType) i;
177     if (IsPTXVectorType(VT)) {
178       setOperationAction(ISD::LOAD, VT, Custom);
179       setOperationAction(ISD::STORE, VT, Custom);
180       setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
181     }
182   }
183
184   // Now deduce the information based on the above mentioned
185   // actions
186   computeRegisterProperties();
187 }
188
189 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
190   switch (Opcode) {
191   default:
192     return 0;
193   case NVPTXISD::CALL:
194     return "NVPTXISD::CALL";
195   case NVPTXISD::RET_FLAG:
196     return "NVPTXISD::RET_FLAG";
197   case NVPTXISD::Wrapper:
198     return "NVPTXISD::Wrapper";
199   case NVPTXISD::NVBuiltin:
200     return "NVPTXISD::NVBuiltin";
201   case NVPTXISD::DeclareParam:
202     return "NVPTXISD::DeclareParam";
203   case NVPTXISD::DeclareScalarParam:
204     return "NVPTXISD::DeclareScalarParam";
205   case NVPTXISD::DeclareRet:
206     return "NVPTXISD::DeclareRet";
207   case NVPTXISD::DeclareRetParam:
208     return "NVPTXISD::DeclareRetParam";
209   case NVPTXISD::PrintCall:
210     return "NVPTXISD::PrintCall";
211   case NVPTXISD::LoadParam:
212     return "NVPTXISD::LoadParam";
213   case NVPTXISD::LoadParamV2:
214     return "NVPTXISD::LoadParamV2";
215   case NVPTXISD::LoadParamV4:
216     return "NVPTXISD::LoadParamV4";
217   case NVPTXISD::StoreParam:
218     return "NVPTXISD::StoreParam";
219   case NVPTXISD::StoreParamV2:
220     return "NVPTXISD::StoreParamV2";
221   case NVPTXISD::StoreParamV4:
222     return "NVPTXISD::StoreParamV4";
223   case NVPTXISD::StoreParamS32:
224     return "NVPTXISD::StoreParamS32";
225   case NVPTXISD::StoreParamU32:
226     return "NVPTXISD::StoreParamU32";
227   case NVPTXISD::MoveToParam:
228     return "NVPTXISD::MoveToParam";
229   case NVPTXISD::CallArgBegin:
230     return "NVPTXISD::CallArgBegin";
231   case NVPTXISD::CallArg:
232     return "NVPTXISD::CallArg";
233   case NVPTXISD::LastCallArg:
234     return "NVPTXISD::LastCallArg";
235   case NVPTXISD::CallArgEnd:
236     return "NVPTXISD::CallArgEnd";
237   case NVPTXISD::CallVoid:
238     return "NVPTXISD::CallVoid";
239   case NVPTXISD::CallVal:
240     return "NVPTXISD::CallVal";
241   case NVPTXISD::CallSymbol:
242     return "NVPTXISD::CallSymbol";
243   case NVPTXISD::Prototype:
244     return "NVPTXISD::Prototype";
245   case NVPTXISD::MoveParam:
246     return "NVPTXISD::MoveParam";
247   case NVPTXISD::MoveRetval:
248     return "NVPTXISD::MoveRetval";
249   case NVPTXISD::MoveToRetval:
250     return "NVPTXISD::MoveToRetval";
251   case NVPTXISD::StoreRetval:
252     return "NVPTXISD::StoreRetval";
253   case NVPTXISD::StoreRetvalV2:
254     return "NVPTXISD::StoreRetvalV2";
255   case NVPTXISD::StoreRetvalV4:
256     return "NVPTXISD::StoreRetvalV4";
257   case NVPTXISD::PseudoUseParam:
258     return "NVPTXISD::PseudoUseParam";
259   case NVPTXISD::RETURN:
260     return "NVPTXISD::RETURN";
261   case NVPTXISD::CallSeqBegin:
262     return "NVPTXISD::CallSeqBegin";
263   case NVPTXISD::CallSeqEnd:
264     return "NVPTXISD::CallSeqEnd";
265   case NVPTXISD::LoadV2:
266     return "NVPTXISD::LoadV2";
267   case NVPTXISD::LoadV4:
268     return "NVPTXISD::LoadV4";
269   case NVPTXISD::LDGV2:
270     return "NVPTXISD::LDGV2";
271   case NVPTXISD::LDGV4:
272     return "NVPTXISD::LDGV4";
273   case NVPTXISD::LDUV2:
274     return "NVPTXISD::LDUV2";
275   case NVPTXISD::LDUV4:
276     return "NVPTXISD::LDUV4";
277   case NVPTXISD::StoreV2:
278     return "NVPTXISD::StoreV2";
279   case NVPTXISD::StoreV4:
280     return "NVPTXISD::StoreV4";
281   }
282 }
283
284 bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const {
285   return VT == MVT::i1;
286 }
287
288 SDValue
289 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
290   SDLoc dl(Op);
291   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
292   Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
293   return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
294 }
295
296 std::string NVPTXTargetLowering::getPrototype(
297     Type *retTy, const ArgListTy &Args,
298     const SmallVectorImpl<ISD::OutputArg> &Outs, unsigned retAlignment) const {
299
300   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
301
302   std::stringstream O;
303   O << "prototype_" << uniqueCallSite << " : .callprototype ";
304
305   if (retTy->getTypeID() == Type::VoidTyID)
306     O << "()";
307   else {
308     O << "(";
309     if (isABI) {
310       if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
311         unsigned size = 0;
312         if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
313           size = ITy->getBitWidth();
314           if (size < 32)
315             size = 32;
316         } else {
317           assert(retTy->isFloatingPointTy() &&
318                  "Floating point type expected here");
319           size = retTy->getPrimitiveSizeInBits();
320         }
321
322         O << ".param .b" << size << " _";
323       } else if (isa<PointerType>(retTy))
324         O << ".param .b" << getPointerTy().getSizeInBits() << " _";
325       else {
326         if ((retTy->getTypeID() == Type::StructTyID) ||
327             isa<VectorType>(retTy)) {
328           SmallVector<EVT, 16> vtparts;
329           ComputeValueVTs(*this, retTy, vtparts);
330           unsigned totalsz = 0;
331           for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
332             unsigned elems = 1;
333             EVT elemtype = vtparts[i];
334             if (vtparts[i].isVector()) {
335               elems = vtparts[i].getVectorNumElements();
336               elemtype = vtparts[i].getVectorElementType();
337             }
338             for (unsigned j = 0, je = elems; j != je; ++j) {
339               unsigned sz = elemtype.getSizeInBits();
340               if (elemtype.isInteger() && (sz < 8))
341                 sz = 8;
342               totalsz += sz / 8;
343             }
344           }
345           O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
346         } else {
347           assert(false && "Unknown return type");
348         }
349       }
350     } else {
351       SmallVector<EVT, 16> vtparts;
352       ComputeValueVTs(*this, retTy, vtparts);
353       unsigned idx = 0;
354       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
355         unsigned elems = 1;
356         EVT elemtype = vtparts[i];
357         if (vtparts[i].isVector()) {
358           elems = vtparts[i].getVectorNumElements();
359           elemtype = vtparts[i].getVectorElementType();
360         }
361
362         for (unsigned j = 0, je = elems; j != je; ++j) {
363           unsigned sz = elemtype.getSizeInBits();
364           if (elemtype.isInteger() && (sz < 32))
365             sz = 32;
366           O << ".reg .b" << sz << " _";
367           if (j < je - 1)
368             O << ", ";
369           ++idx;
370         }
371         if (i < e - 1)
372           O << ", ";
373       }
374     }
375     O << ") ";
376   }
377   O << "_ (";
378
379   bool first = true;
380   MVT thePointerTy = getPointerTy();
381
382   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
383     const Type *Ty = Args[i].Ty;
384     if (!first) {
385       O << ", ";
386     }
387     first = false;
388
389     if (Outs[i].Flags.isByVal() == false) {
390       unsigned sz = 0;
391       if (isa<IntegerType>(Ty)) {
392         sz = cast<IntegerType>(Ty)->getBitWidth();
393         if (sz < 32)
394           sz = 32;
395       } else if (isa<PointerType>(Ty))
396         sz = thePointerTy.getSizeInBits();
397       else
398         sz = Ty->getPrimitiveSizeInBits();
399       if (isABI)
400         O << ".param .b" << sz << " ";
401       else
402         O << ".reg .b" << sz << " ";
403       O << "_";
404       continue;
405     }
406     const PointerType *PTy = dyn_cast<PointerType>(Ty);
407     assert(PTy && "Param with byval attribute should be a pointer type");
408     Type *ETy = PTy->getElementType();
409
410     if (isABI) {
411       unsigned align = Outs[i].Flags.getByValAlign();
412       unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
413       O << ".param .align " << align << " .b8 ";
414       O << "_";
415       O << "[" << sz << "]";
416       continue;
417     } else {
418       SmallVector<EVT, 16> vtparts;
419       ComputeValueVTs(*this, ETy, vtparts);
420       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
421         unsigned elems = 1;
422         EVT elemtype = vtparts[i];
423         if (vtparts[i].isVector()) {
424           elems = vtparts[i].getVectorNumElements();
425           elemtype = vtparts[i].getVectorElementType();
426         }
427
428         for (unsigned j = 0, je = elems; j != je; ++j) {
429           unsigned sz = elemtype.getSizeInBits();
430           if (elemtype.isInteger() && (sz < 32))
431             sz = 32;
432           O << ".reg .b" << sz << " ";
433           O << "_";
434           if (j < je - 1)
435             O << ", ";
436         }
437         if (i < e - 1)
438           O << ", ";
439       }
440       continue;
441     }
442   }
443   O << ");";
444   return O.str();
445 }
446
447 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
448                                        SmallVectorImpl<SDValue> &InVals) const {
449   SelectionDAG &DAG = CLI.DAG;
450   SDLoc dl = CLI.DL;
451   SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
452   SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
453   SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
454   SDValue Chain = CLI.Chain;
455   SDValue Callee = CLI.Callee;
456   bool &isTailCall = CLI.IsTailCall;
457   ArgListTy &Args = CLI.Args;
458   Type *retTy = CLI.RetTy;
459   ImmutableCallSite *CS = CLI.CS;
460
461   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
462
463   SDValue tempChain = Chain;
464   Chain = DAG.getCALLSEQ_START(Chain,
465                                DAG.getIntPtrConstant(uniqueCallSite, true),
466                                dl);
467   SDValue InFlag = Chain.getValue(1);
468
469   assert((Outs.size() == Args.size()) &&
470          "Unexpected number of arguments to function call");
471   unsigned paramCount = 0;
472   // Declare the .params or .reg need to pass values
473   // to the function
474   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
475     EVT VT = Outs[i].VT;
476
477     if (Outs[i].Flags.isByVal() == false) {
478       // Plain scalar
479       // for ABI,    declare .param .b<size> .param<n>;
480       // for nonABI, declare .reg .b<size> .param<n>;
481       unsigned isReg = 1;
482       if (isABI)
483         isReg = 0;
484       unsigned sz = VT.getSizeInBits();
485       if (VT.isInteger() && (sz < 32))
486         sz = 32;
487       SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
488       SDValue DeclareParamOps[] = { Chain,
489                                     DAG.getConstant(paramCount, MVT::i32),
490                                     DAG.getConstant(sz, MVT::i32),
491                                     DAG.getConstant(isReg, MVT::i32), InFlag };
492       Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
493                           DeclareParamOps, 5);
494       InFlag = Chain.getValue(1);
495       SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
496       SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
497                                  DAG.getConstant(0, MVT::i32), OutVals[i],
498                                  InFlag };
499
500       unsigned opcode = NVPTXISD::StoreParam;
501       if (isReg)
502         opcode = NVPTXISD::MoveToParam;
503       else {
504         if (Outs[i].Flags.isZExt())
505           opcode = NVPTXISD::StoreParamU32;
506         else if (Outs[i].Flags.isSExt())
507           opcode = NVPTXISD::StoreParamS32;
508       }
509       Chain = DAG.getNode(opcode, dl, CopyParamVTs, CopyParamOps, 5);
510
511       InFlag = Chain.getValue(1);
512       ++paramCount;
513       continue;
514     }
515     // struct or vector
516     SmallVector<EVT, 16> vtparts;
517     const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
518     assert(PTy && "Type of a byval parameter should be pointer");
519     ComputeValueVTs(*this, PTy->getElementType(), vtparts);
520
521     if (isABI) {
522       // declare .param .align 16 .b8 .param<n>[<size>];
523       unsigned sz = Outs[i].Flags.getByValSize();
524       SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
525       // The ByValAlign in the Outs[i].Flags is alway set at this point, so we
526       // don't need to
527       // worry about natural alignment or not. See TargetLowering::LowerCallTo()
528       SDValue DeclareParamOps[] = {
529         Chain, DAG.getConstant(Outs[i].Flags.getByValAlign(), MVT::i32),
530         DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32),
531         InFlag
532       };
533       Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
534                           DeclareParamOps, 5);
535       InFlag = Chain.getValue(1);
536       unsigned curOffset = 0;
537       for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
538         unsigned elems = 1;
539         EVT elemtype = vtparts[j];
540         if (vtparts[j].isVector()) {
541           elems = vtparts[j].getVectorNumElements();
542           elemtype = vtparts[j].getVectorElementType();
543         }
544         for (unsigned k = 0, ke = elems; k != ke; ++k) {
545           unsigned sz = elemtype.getSizeInBits();
546           if (elemtype.isInteger() && (sz < 8))
547             sz = 8;
548           SDValue srcAddr =
549               DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[i],
550                           DAG.getConstant(curOffset, getPointerTy()));
551           SDValue theVal =
552               DAG.getLoad(elemtype, dl, tempChain, srcAddr,
553                           MachinePointerInfo(), false, false, false, 0);
554           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
555           SDValue CopyParamOps[] = { Chain,
556                                      DAG.getConstant(paramCount, MVT::i32),
557                                      DAG.getConstant(curOffset, MVT::i32),
558                                      theVal, InFlag };
559           Chain = DAG.getNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
560                               CopyParamOps, 5);
561           InFlag = Chain.getValue(1);
562           curOffset += sz / 8;
563         }
564       }
565       ++paramCount;
566       continue;
567     }
568     // Non-abi, struct or vector
569     // Declare a bunch or .reg .b<size> .param<n>
570     unsigned curOffset = 0;
571     for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
572       unsigned elems = 1;
573       EVT elemtype = vtparts[j];
574       if (vtparts[j].isVector()) {
575         elems = vtparts[j].getVectorNumElements();
576         elemtype = vtparts[j].getVectorElementType();
577       }
578       for (unsigned k = 0, ke = elems; k != ke; ++k) {
579         unsigned sz = elemtype.getSizeInBits();
580         if (elemtype.isInteger() && (sz < 32))
581           sz = 32;
582         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
583         SDValue DeclareParamOps[] = { Chain,
584                                       DAG.getConstant(paramCount, MVT::i32),
585                                       DAG.getConstant(sz, MVT::i32),
586                                       DAG.getConstant(1, MVT::i32), InFlag };
587         Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
588                             DeclareParamOps, 5);
589         InFlag = Chain.getValue(1);
590         SDValue srcAddr =
591             DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[i],
592                         DAG.getConstant(curOffset, getPointerTy()));
593         SDValue theVal =
594             DAG.getLoad(elemtype, dl, tempChain, srcAddr, MachinePointerInfo(),
595                         false, false, false, 0);
596         SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
597         SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
598                                    DAG.getConstant(0, MVT::i32), theVal,
599                                    InFlag };
600         Chain = DAG.getNode(NVPTXISD::MoveToParam, dl, CopyParamVTs,
601                             CopyParamOps, 5);
602         InFlag = Chain.getValue(1);
603         ++paramCount;
604       }
605     }
606   }
607
608   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
609   unsigned retAlignment = 0;
610
611   // Handle Result
612   unsigned retCount = 0;
613   if (Ins.size() > 0) {
614     SmallVector<EVT, 16> resvtparts;
615     ComputeValueVTs(*this, retTy, resvtparts);
616
617     // Declare one .param .align 16 .b8 func_retval0[<size>] for ABI or
618     // individual .reg .b<size> func_retval<0..> for non ABI
619     unsigned resultsz = 0;
620     for (unsigned i = 0, e = resvtparts.size(); i != e; ++i) {
621       unsigned elems = 1;
622       EVT elemtype = resvtparts[i];
623       if (resvtparts[i].isVector()) {
624         elems = resvtparts[i].getVectorNumElements();
625         elemtype = resvtparts[i].getVectorElementType();
626       }
627       for (unsigned j = 0, je = elems; j != je; ++j) {
628         unsigned sz = elemtype.getSizeInBits();
629         if (isABI == false) {
630           if (elemtype.isInteger() && (sz < 32))
631             sz = 32;
632         } else {
633           if (elemtype.isInteger() && (sz < 8))
634             sz = 8;
635         }
636         if (isABI == false) {
637           SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
638           SDValue DeclareRetOps[] = { Chain, DAG.getConstant(2, MVT::i32),
639                                       DAG.getConstant(sz, MVT::i32),
640                                       DAG.getConstant(retCount, MVT::i32),
641                                       InFlag };
642           Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
643                               DeclareRetOps, 5);
644           InFlag = Chain.getValue(1);
645           ++retCount;
646         }
647         resultsz += sz;
648       }
649     }
650     if (isABI) {
651       if (retTy->isPrimitiveType() || retTy->isIntegerTy() ||
652           retTy->isPointerTy()) {
653         // Scalar needs to be at least 32bit wide
654         if (resultsz < 32)
655           resultsz = 32;
656         SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
657         SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
658                                     DAG.getConstant(resultsz, MVT::i32),
659                                     DAG.getConstant(0, MVT::i32), InFlag };
660         Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
661                             DeclareRetOps, 5);
662         InFlag = Chain.getValue(1);
663       } else {
664         if (Func) { // direct call
665           if (!llvm::getAlign(*(CS->getCalledFunction()), 0, retAlignment))
666             retAlignment = getDataLayout()->getABITypeAlignment(retTy);
667         } else { // indirect call
668           const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction());
669           if (!llvm::getAlign(*CallI, 0, retAlignment))
670             retAlignment = getDataLayout()->getABITypeAlignment(retTy);
671         }
672         SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
673         SDValue DeclareRetOps[] = { Chain,
674                                     DAG.getConstant(retAlignment, MVT::i32),
675                                     DAG.getConstant(resultsz / 8, MVT::i32),
676                                     DAG.getConstant(0, MVT::i32), InFlag };
677         Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
678                             DeclareRetOps, 5);
679         InFlag = Chain.getValue(1);
680       }
681     }
682   }
683
684   if (!Func) {
685     // This is indirect function call case : PTX requires a prototype of the
686     // form
687     // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
688     // to be emitted, and the label has to used as the last arg of call
689     // instruction.
690     // The prototype is embedded in a string and put as the operand for an
691     // INLINEASM SDNode.
692     SDVTList InlineAsmVTs = DAG.getVTList(MVT::Other, MVT::Glue);
693     std::string proto_string = getPrototype(retTy, Args, Outs, retAlignment);
694     const char *asmstr = nvTM->getManagedStrPool()
695         ->getManagedString(proto_string.c_str())->c_str();
696     SDValue InlineAsmOps[] = {
697       Chain, DAG.getTargetExternalSymbol(asmstr, getPointerTy()),
698       DAG.getMDNode(0), DAG.getTargetConstant(0, MVT::i32), InFlag
699     };
700     Chain = DAG.getNode(ISD::INLINEASM, dl, InlineAsmVTs, InlineAsmOps, 5);
701     InFlag = Chain.getValue(1);
702   }
703   // Op to just print "call"
704   SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
705   SDValue PrintCallOps[] = {
706     Chain,
707     DAG.getConstant(isABI ? ((Ins.size() == 0) ? 0 : 1) : retCount, MVT::i32),
708     InFlag
709   };
710   Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
711                       dl, PrintCallVTs, PrintCallOps, 3);
712   InFlag = Chain.getValue(1);
713
714   // Ops to print out the function name
715   SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
716   SDValue CallVoidOps[] = { Chain, Callee, InFlag };
717   Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3);
718   InFlag = Chain.getValue(1);
719
720   // Ops to print out the param list
721   SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
722   SDValue CallArgBeginOps[] = { Chain, InFlag };
723   Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
724                       CallArgBeginOps, 2);
725   InFlag = Chain.getValue(1);
726
727   for (unsigned i = 0, e = paramCount; i != e; ++i) {
728     unsigned opcode;
729     if (i == (e - 1))
730       opcode = NVPTXISD::LastCallArg;
731     else
732       opcode = NVPTXISD::CallArg;
733     SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
734     SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
735                              DAG.getConstant(i, MVT::i32), InFlag };
736     Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4);
737     InFlag = Chain.getValue(1);
738   }
739   SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
740   SDValue CallArgEndOps[] = { Chain, DAG.getConstant(Func ? 1 : 0, MVT::i32),
741                               InFlag };
742   Chain =
743       DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps, 3);
744   InFlag = Chain.getValue(1);
745
746   if (!Func) {
747     SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
748     SDValue PrototypeOps[] = { Chain, DAG.getConstant(uniqueCallSite, MVT::i32),
749                                InFlag };
750     Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3);
751     InFlag = Chain.getValue(1);
752   }
753
754   // Generate loads from param memory/moves from registers for result
755   if (Ins.size() > 0) {
756     if (isABI) {
757       unsigned resoffset = 0;
758       for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
759         unsigned sz = Ins[i].VT.getSizeInBits();
760         if (Ins[i].VT.isInteger() && (sz < 8))
761           sz = 8;
762         EVT LoadRetVTs[] = { Ins[i].VT, MVT::Other, MVT::Glue };
763         SDValue LoadRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
764                                  DAG.getConstant(resoffset, MVT::i32), InFlag };
765         SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, LoadRetVTs,
766                                      LoadRetOps, array_lengthof(LoadRetOps));
767         Chain = retval.getValue(1);
768         InFlag = retval.getValue(2);
769         InVals.push_back(retval);
770         resoffset += sz / 8;
771       }
772     } else {
773       SmallVector<EVT, 16> resvtparts;
774       ComputeValueVTs(*this, retTy, resvtparts);
775
776       assert(Ins.size() == resvtparts.size() &&
777              "Unexpected number of return values in non-ABI case");
778       unsigned paramNum = 0;
779       for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
780         assert(EVT(Ins[i].VT) == resvtparts[i] &&
781                "Unexpected EVT type in non-ABI case");
782         unsigned numelems = 1;
783         EVT elemtype = Ins[i].VT;
784         if (Ins[i].VT.isVector()) {
785           numelems = Ins[i].VT.getVectorNumElements();
786           elemtype = Ins[i].VT.getVectorElementType();
787         }
788         std::vector<SDValue> tempRetVals;
789         for (unsigned j = 0; j < numelems; ++j) {
790           EVT MoveRetVTs[] = { elemtype, MVT::Other, MVT::Glue };
791           SDValue MoveRetOps[] = { Chain, DAG.getConstant(0, MVT::i32),
792                                    DAG.getConstant(paramNum, MVT::i32),
793                                    InFlag };
794           SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, MoveRetVTs,
795                                        MoveRetOps, array_lengthof(MoveRetOps));
796           Chain = retval.getValue(1);
797           InFlag = retval.getValue(2);
798           tempRetVals.push_back(retval);
799           ++paramNum;
800         }
801         if (Ins[i].VT.isVector())
802           InVals.push_back(DAG.getNode(ISD::BUILD_VECTOR, dl, Ins[i].VT,
803                                        &tempRetVals[0], tempRetVals.size()));
804         else
805           InVals.push_back(tempRetVals[0]);
806       }
807     }
808   }
809   Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
810                              DAG.getIntPtrConstant(uniqueCallSite + 1, true),
811                              InFlag, dl);
812   uniqueCallSite++;
813
814   // set isTailCall to false for now, until we figure out how to express
815   // tail call optimization in PTX
816   isTailCall = false;
817   return Chain;
818 }
819
820 // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
821 // (see LegalizeDAG.cpp). This is slow and uses local memory.
822 // We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
823 SDValue
824 NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
825   SDNode *Node = Op.getNode();
826   SDLoc dl(Node);
827   SmallVector<SDValue, 8> Ops;
828   unsigned NumOperands = Node->getNumOperands();
829   for (unsigned i = 0; i < NumOperands; ++i) {
830     SDValue SubOp = Node->getOperand(i);
831     EVT VVT = SubOp.getNode()->getValueType(0);
832     EVT EltVT = VVT.getVectorElementType();
833     unsigned NumSubElem = VVT.getVectorNumElements();
834     for (unsigned j = 0; j < NumSubElem; ++j) {
835       Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
836                                 DAG.getIntPtrConstant(j)));
837     }
838   }
839   return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0), &Ops[0],
840                      Ops.size());
841 }
842
843 SDValue
844 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
845   switch (Op.getOpcode()) {
846   case ISD::RETURNADDR:
847     return SDValue();
848   case ISD::FRAMEADDR:
849     return SDValue();
850   case ISD::GlobalAddress:
851     return LowerGlobalAddress(Op, DAG);
852   case ISD::INTRINSIC_W_CHAIN:
853     return Op;
854   case ISD::BUILD_VECTOR:
855   case ISD::EXTRACT_SUBVECTOR:
856     return Op;
857   case ISD::CONCAT_VECTORS:
858     return LowerCONCAT_VECTORS(Op, DAG);
859   case ISD::STORE:
860     return LowerSTORE(Op, DAG);
861   case ISD::LOAD:
862     return LowerLOAD(Op, DAG);
863   default:
864     llvm_unreachable("Custom lowering not defined for operation");
865   }
866 }
867
868 SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
869   if (Op.getValueType() == MVT::i1)
870     return LowerLOADi1(Op, DAG);
871   else
872     return SDValue();
873 }
874
875 // v = ld i1* addr
876 //   =>
877 // v1 = ld i8* addr
878 // v = trunc v1 to i1
879 SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
880   SDNode *Node = Op.getNode();
881   LoadSDNode *LD = cast<LoadSDNode>(Node);
882   SDLoc dl(Node);
883   assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
884   assert(Node->getValueType(0) == MVT::i1 &&
885          "Custom lowering for i1 load only");
886   SDValue newLD =
887       DAG.getLoad(MVT::i8, dl, LD->getChain(), LD->getBasePtr(),
888                   LD->getPointerInfo(), LD->isVolatile(), LD->isNonTemporal(),
889                   LD->isInvariant(), LD->getAlignment());
890   SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
891   // The legalizer (the caller) is expecting two values from the legalized
892   // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
893   // in LegalizeDAG.cpp which also uses MergeValues.
894   SDValue Ops[] = { result, LD->getChain() };
895   return DAG.getMergeValues(Ops, 2, dl);
896 }
897
898 SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
899   EVT ValVT = Op.getOperand(1).getValueType();
900   if (ValVT == MVT::i1)
901     return LowerSTOREi1(Op, DAG);
902   else if (ValVT.isVector())
903     return LowerSTOREVector(Op, DAG);
904   else
905     return SDValue();
906 }
907
908 SDValue
909 NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
910   SDNode *N = Op.getNode();
911   SDValue Val = N->getOperand(1);
912   SDLoc DL(N);
913   EVT ValVT = Val.getValueType();
914
915   if (ValVT.isVector()) {
916     // We only handle "native" vector sizes for now, e.g. <4 x double> is not
917     // legal.  We can (and should) split that into 2 stores of <2 x double> here
918     // but I'm leaving that as a TODO for now.
919     if (!ValVT.isSimple())
920       return SDValue();
921     switch (ValVT.getSimpleVT().SimpleTy) {
922     default:
923       return SDValue();
924     case MVT::v2i8:
925     case MVT::v2i16:
926     case MVT::v2i32:
927     case MVT::v2i64:
928     case MVT::v2f32:
929     case MVT::v2f64:
930     case MVT::v4i8:
931     case MVT::v4i16:
932     case MVT::v4i32:
933     case MVT::v4f32:
934       // This is a "native" vector type
935       break;
936     }
937
938     unsigned Opcode = 0;
939     EVT EltVT = ValVT.getVectorElementType();
940     unsigned NumElts = ValVT.getVectorNumElements();
941
942     // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
943     // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
944     // stored type to i16 and propogate the "real" type as the memory type.
945     bool NeedExt = false;
946     if (EltVT.getSizeInBits() < 16)
947       NeedExt = true;
948
949     switch (NumElts) {
950     default:
951       return SDValue();
952     case 2:
953       Opcode = NVPTXISD::StoreV2;
954       break;
955     case 4: {
956       Opcode = NVPTXISD::StoreV4;
957       break;
958     }
959     }
960
961     SmallVector<SDValue, 8> Ops;
962
963     // First is the chain
964     Ops.push_back(N->getOperand(0));
965
966     // Then the split values
967     for (unsigned i = 0; i < NumElts; ++i) {
968       SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
969                                    DAG.getIntPtrConstant(i));
970       if (NeedExt)
971         // ANY_EXTEND is correct here since the store will only look at the
972         // lower-order bits anyway.
973         ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
974       Ops.push_back(ExtVal);
975     }
976
977     // Then any remaining arguments
978     for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) {
979       Ops.push_back(N->getOperand(i));
980     }
981
982     MemSDNode *MemSD = cast<MemSDNode>(N);
983
984     SDValue NewSt = DAG.getMemIntrinsicNode(
985         Opcode, DL, DAG.getVTList(MVT::Other), &Ops[0], Ops.size(),
986         MemSD->getMemoryVT(), MemSD->getMemOperand());
987
988     //return DCI.CombineTo(N, NewSt, true);
989     return NewSt;
990   }
991
992   return SDValue();
993 }
994
995 // st i1 v, addr
996 //    =>
997 // v1 = zxt v to i8
998 // st i8, addr
999 SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
1000   SDNode *Node = Op.getNode();
1001   SDLoc dl(Node);
1002   StoreSDNode *ST = cast<StoreSDNode>(Node);
1003   SDValue Tmp1 = ST->getChain();
1004   SDValue Tmp2 = ST->getBasePtr();
1005   SDValue Tmp3 = ST->getValue();
1006   assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
1007   unsigned Alignment = ST->getAlignment();
1008   bool isVolatile = ST->isVolatile();
1009   bool isNonTemporal = ST->isNonTemporal();
1010   Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, Tmp3);
1011   SDValue Result = DAG.getStore(Tmp1, dl, Tmp3, Tmp2, ST->getPointerInfo(),
1012                                 isVolatile, isNonTemporal, Alignment);
1013   return Result;
1014 }
1015
1016 SDValue NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname,
1017                                         int idx, EVT v) const {
1018   std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
1019   std::stringstream suffix;
1020   suffix << idx;
1021   *name += suffix.str();
1022   return DAG.getTargetExternalSymbol(name->c_str(), v);
1023 }
1024
1025 SDValue
1026 NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
1027   return getExtSymb(DAG, ".PARAM", idx, v);
1028 }
1029
1030 SDValue NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
1031   return getExtSymb(DAG, ".HLPPARAM", idx);
1032 }
1033
1034 // Check to see if the kernel argument is image*_t or sampler_t
1035
1036 bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
1037   static const char *const specialTypes[] = { "struct._image2d_t",
1038                                               "struct._image3d_t",
1039                                               "struct._sampler_t" };
1040
1041   const Type *Ty = arg->getType();
1042   const PointerType *PTy = dyn_cast<PointerType>(Ty);
1043
1044   if (!PTy)
1045     return false;
1046
1047   if (!context)
1048     return false;
1049
1050   const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
1051   const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : "";
1052
1053   for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
1054     if (TypeName == specialTypes[i])
1055       return true;
1056
1057   return false;
1058 }
1059
1060 SDValue NVPTXTargetLowering::LowerFormalArguments(
1061     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
1062     const SmallVectorImpl<ISD::InputArg> &Ins, SDLoc dl, SelectionDAG &DAG,
1063     SmallVectorImpl<SDValue> &InVals) const {
1064   MachineFunction &MF = DAG.getMachineFunction();
1065   const DataLayout *TD = getDataLayout();
1066
1067   const Function *F = MF.getFunction();
1068   const AttributeSet &PAL = F->getAttributes();
1069   const TargetLowering *TLI = nvTM->getTargetLowering();
1070
1071   SDValue Root = DAG.getRoot();
1072   std::vector<SDValue> OutChains;
1073
1074   bool isKernel = llvm::isKernelFunction(*F);
1075   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1076   assert(isABI && "Non-ABI compilation is not supported");
1077   if (!isABI)
1078     return Chain;
1079
1080   std::vector<Type *> argTypes;
1081   std::vector<const Argument *> theArgs;
1082   for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
1083        I != E; ++I) {
1084     theArgs.push_back(I);
1085     argTypes.push_back(I->getType());
1086   }
1087   // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
1088   // Ins.size() will be larger
1089   //   * if there is an aggregate argument with multiple fields (each field
1090   //     showing up separately in Ins)
1091   //   * if there is a vector argument with more than typical vector-length
1092   //     elements (generally if more than 4) where each vector element is
1093   //     individually present in Ins.
1094   // So a different index should be used for indexing into Ins.
1095   // See similar issue in LowerCall.
1096   unsigned InsIdx = 0;
1097
1098   int idx = 0;
1099   for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++idx, ++InsIdx) {
1100     Type *Ty = argTypes[i];
1101
1102     // If the kernel argument is image*_t or sampler_t, convert it to
1103     // a i32 constant holding the parameter position. This can later
1104     // matched in the AsmPrinter to output the correct mangled name.
1105     if (isImageOrSamplerVal(
1106             theArgs[i],
1107             (theArgs[i]->getParent() ? theArgs[i]->getParent()->getParent()
1108                                      : 0))) {
1109       assert(isKernel && "Only kernels can have image/sampler params");
1110       InVals.push_back(DAG.getConstant(i + 1, MVT::i32));
1111       continue;
1112     }
1113
1114     if (theArgs[i]->use_empty()) {
1115       // argument is dead
1116       if (Ty->isAggregateType()) {
1117         SmallVector<EVT, 16> vtparts;
1118
1119         ComputeValueVTs(*this, Ty, vtparts);
1120         assert(vtparts.size() > 0 && "empty aggregate type not expected");
1121         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
1122              ++parti) {
1123           EVT partVT = vtparts[parti];
1124           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, partVT));
1125           ++InsIdx;
1126         }
1127         if (vtparts.size() > 0)
1128           --InsIdx;
1129         continue;
1130       }
1131       if (Ty->isVectorTy()) {
1132         EVT ObjectVT = getValueType(Ty);
1133         unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
1134         for (unsigned parti = 0; parti < NumRegs; ++parti) {
1135           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
1136           ++InsIdx;
1137         }
1138         if (NumRegs > 0)
1139           --InsIdx;
1140         continue;
1141       }
1142       InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
1143       continue;
1144     }
1145
1146     // In the following cases, assign a node order of "idx+1"
1147     // to newly created nodes. The SDNodes for params have to
1148     // appear in the same order as their order of appearance
1149     // in the original function. "idx+1" holds that order.
1150     if (PAL.hasAttribute(i + 1, Attribute::ByVal) == false) {
1151       if (Ty->isAggregateType()) {
1152         SmallVector<EVT, 16> vtparts;
1153         SmallVector<uint64_t, 16> offsets;
1154
1155         ComputeValueVTs(*this, Ty, vtparts, &offsets, 0);
1156         assert(vtparts.size() > 0 && "empty aggregate type not expected");
1157         bool aggregateIsPacked = false;
1158         if (StructType *STy = llvm::dyn_cast<StructType>(Ty))
1159           aggregateIsPacked = STy->isPacked();
1160
1161         SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1162         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
1163              ++parti) {
1164           EVT partVT = vtparts[parti];
1165           Value *srcValue = Constant::getNullValue(
1166               PointerType::get(partVT.getTypeForEVT(F->getContext()),
1167                                llvm::ADDRESS_SPACE_PARAM));
1168           SDValue srcAddr =
1169               DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1170                           DAG.getConstant(offsets[parti], getPointerTy()));
1171           unsigned partAlign =
1172               aggregateIsPacked ? 1
1173                                 : TD->getABITypeAlignment(
1174                                       partVT.getTypeForEVT(F->getContext()));
1175           SDValue p = DAG.getLoad(partVT, dl, Root, srcAddr,
1176                                   MachinePointerInfo(srcValue), false, false,
1177                                   true, partAlign);
1178           if (p.getNode())
1179             p.getNode()->setIROrder(idx + 1);
1180           InVals.push_back(p);
1181           ++InsIdx;
1182         }
1183         if (vtparts.size() > 0)
1184           --InsIdx;
1185         continue;
1186       }
1187       if (Ty->isVectorTy()) {
1188         EVT ObjectVT = getValueType(Ty);
1189         SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1190         unsigned NumElts = ObjectVT.getVectorNumElements();
1191         assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
1192                "Vector was not scalarized");
1193         unsigned Ofst = 0;
1194         EVT EltVT = ObjectVT.getVectorElementType();
1195
1196         // V1 load
1197         // f32 = load ...
1198         if (NumElts == 1) {
1199           // We only have one element, so just directly load it
1200           Value *SrcValue = Constant::getNullValue(PointerType::get(
1201               EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1202           SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1203                                         DAG.getConstant(Ofst, getPointerTy()));
1204           SDValue P = DAG.getLoad(
1205               EltVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1206               false, true,
1207               TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
1208           if (P.getNode())
1209             P.getNode()->setIROrder(idx + 1);
1210
1211           InVals.push_back(P);
1212           Ofst += TD->getTypeAllocSize(EltVT.getTypeForEVT(F->getContext()));
1213           ++InsIdx;
1214         } else if (NumElts == 2) {
1215           // V2 load
1216           // f32,f32 = load ...
1217           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2);
1218           Value *SrcValue = Constant::getNullValue(PointerType::get(
1219               VecVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1220           SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1221                                         DAG.getConstant(Ofst, getPointerTy()));
1222           SDValue P = DAG.getLoad(
1223               VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1224               false, true,
1225               TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
1226           if (P.getNode())
1227             P.getNode()->setIROrder(idx + 1);
1228
1229           SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1230                                      DAG.getIntPtrConstant(0));
1231           SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1232                                      DAG.getIntPtrConstant(1));
1233           InVals.push_back(Elt0);
1234           InVals.push_back(Elt1);
1235           Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1236           InsIdx += 2;
1237         } else {
1238           // V4 loads
1239           // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
1240           // the
1241           // vector will be expanded to a power of 2 elements, so we know we can
1242           // always round up to the next multiple of 4 when creating the vector
1243           // loads.
1244           // e.g.  4 elem => 1 ld.v4
1245           //       6 elem => 2 ld.v4
1246           //       8 elem => 2 ld.v4
1247           //      11 elem => 3 ld.v4
1248           unsigned VecSize = 4;
1249           if (EltVT.getSizeInBits() == 64) {
1250             VecSize = 2;
1251           }
1252           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
1253           for (unsigned i = 0; i < NumElts; i += VecSize) {
1254             Value *SrcValue = Constant::getNullValue(
1255                 PointerType::get(VecVT.getTypeForEVT(F->getContext()),
1256                                  llvm::ADDRESS_SPACE_PARAM));
1257             SDValue SrcAddr =
1258                 DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1259                             DAG.getConstant(Ofst, getPointerTy()));
1260             SDValue P = DAG.getLoad(
1261                 VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1262                 false, true,
1263                 TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
1264             if (P.getNode())
1265               P.getNode()->setIROrder(idx + 1);
1266
1267             for (unsigned j = 0; j < VecSize; ++j) {
1268               if (i + j >= NumElts)
1269                 break;
1270               SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1271                                         DAG.getIntPtrConstant(j));
1272               InVals.push_back(Elt);
1273             }
1274             Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1275             InsIdx += VecSize;
1276           }
1277         }
1278
1279         if (NumElts > 0)
1280           --InsIdx;
1281         continue;
1282       }
1283       // A plain scalar.
1284       EVT ObjectVT = getValueType(Ty);
1285       assert(ObjectVT == Ins[InsIdx].VT &&
1286              "Ins type did not match function type");
1287       // If ABI, load from the param symbol
1288       SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1289       Value *srcValue = Constant::getNullValue(PointerType::get(
1290           ObjectVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1291       SDValue p = DAG.getLoad(
1292           ObjectVT, dl, Root, Arg, MachinePointerInfo(srcValue), false, false,
1293           true,
1294           TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
1295       if (p.getNode())
1296         p.getNode()->setIROrder(idx + 1);
1297       InVals.push_back(p);
1298       continue;
1299     }
1300
1301     // Param has ByVal attribute
1302     // Return MoveParam(param symbol).
1303     // Ideally, the param symbol can be returned directly,
1304     // but when SDNode builder decides to use it in a CopyToReg(),
1305     // machine instruction fails because TargetExternalSymbol
1306     // (not lowered) is target dependent, and CopyToReg assumes
1307     // the source is lowered.
1308     EVT ObjectVT = getValueType(Ty);
1309     assert(ObjectVT == Ins[InsIdx].VT &&
1310            "Ins type did not match function type");
1311     SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1312     SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
1313     if (p.getNode())
1314       p.getNode()->setIROrder(idx + 1);
1315     if (isKernel)
1316       InVals.push_back(p);
1317     else {
1318       SDValue p2 = DAG.getNode(
1319           ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
1320           DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p);
1321       InVals.push_back(p2);
1322     }
1323   }
1324
1325   // Clang will check explicit VarArg and issue error if any. However, Clang
1326   // will let code with
1327   // implicit var arg like f() pass. See bug 617733.
1328   // We treat this case as if the arg list is empty.
1329   // if (F.isVarArg()) {
1330   // assert(0 && "VarArg not supported yet!");
1331   //}
1332
1333   if (!OutChains.empty())
1334     DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &OutChains[0],
1335                             OutChains.size()));
1336
1337   return Chain;
1338 }
1339
1340
1341 SDValue
1342 NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
1343                                  bool isVarArg,
1344                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
1345                                  const SmallVectorImpl<SDValue> &OutVals,
1346                                  SDLoc dl, SelectionDAG &DAG) const {
1347   MachineFunction &MF = DAG.getMachineFunction();
1348   const Function *F = MF.getFunction();
1349   const Type *RetTy = F->getReturnType();
1350   const DataLayout *TD = getDataLayout();
1351
1352   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1353   assert(isABI && "Non-ABI compilation is not supported");
1354   if (!isABI)
1355     return Chain;
1356
1357   if (const VectorType *VTy = dyn_cast<const VectorType>(RetTy)) {
1358     // If we have a vector type, the OutVals array will be the scalarized
1359     // components and we have combine them into 1 or more vector stores.
1360     unsigned NumElts = VTy->getNumElements();
1361     assert(NumElts == Outs.size() && "Bad scalarization of return value");
1362
1363     // V1 store
1364     if (NumElts == 1) {
1365       SDValue StoreVal = OutVals[0];
1366       // We only have one element, so just directly store it
1367       if (StoreVal.getValueType().getSizeInBits() < 8)
1368         StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
1369       Chain = DAG.getNode(NVPTXISD::StoreRetval, dl, MVT::Other, Chain,
1370                           DAG.getConstant(0, MVT::i32), StoreVal);
1371     } else if (NumElts == 2) {
1372       // V2 store
1373       SDValue StoreVal0 = OutVals[0];
1374       SDValue StoreVal1 = OutVals[1];
1375
1376       if (StoreVal0.getValueType().getSizeInBits() < 8) {
1377         StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal0);
1378         StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal1);
1379       }
1380
1381       Chain = DAG.getNode(NVPTXISD::StoreRetvalV2, dl, MVT::Other, Chain,
1382                           DAG.getConstant(0, MVT::i32), StoreVal0, StoreVal1);
1383     } else {
1384       // V4 stores
1385       // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the
1386       // vector will be expanded to a power of 2 elements, so we know we can
1387       // always round up to the next multiple of 4 when creating the vector
1388       // stores.
1389       // e.g.  4 elem => 1 st.v4
1390       //       6 elem => 2 st.v4
1391       //       8 elem => 2 st.v4
1392       //      11 elem => 3 st.v4
1393
1394       unsigned VecSize = 4;
1395       if (OutVals[0].getValueType().getSizeInBits() == 64)
1396         VecSize = 2;
1397
1398       unsigned Offset = 0;
1399
1400       EVT VecVT =
1401           EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize);
1402       unsigned PerStoreOffset =
1403           TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1404
1405       bool Extend = false;
1406       if (OutVals[0].getValueType().getSizeInBits() < 8)
1407         Extend = true;
1408
1409       for (unsigned i = 0; i < NumElts; i += VecSize) {
1410         // Get values
1411         SDValue StoreVal;
1412         SmallVector<SDValue, 8> Ops;
1413         Ops.push_back(Chain);
1414         Ops.push_back(DAG.getConstant(Offset, MVT::i32));
1415         unsigned Opc = NVPTXISD::StoreRetvalV2;
1416         EVT ExtendedVT = (Extend) ? MVT::i8 : OutVals[0].getValueType();
1417
1418         StoreVal = OutVals[i];
1419         if (Extend)
1420           StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
1421         Ops.push_back(StoreVal);
1422
1423         if (i + 1 < NumElts) {
1424           StoreVal = OutVals[i + 1];
1425           if (Extend)
1426             StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
1427         } else {
1428           StoreVal = DAG.getUNDEF(ExtendedVT);
1429         }
1430         Ops.push_back(StoreVal);
1431
1432         if (VecSize == 4) {
1433           Opc = NVPTXISD::StoreRetvalV4;
1434           if (i + 2 < NumElts) {
1435             StoreVal = OutVals[i + 2];
1436             if (Extend)
1437               StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
1438           } else {
1439             StoreVal = DAG.getUNDEF(ExtendedVT);
1440           }
1441           Ops.push_back(StoreVal);
1442
1443           if (i + 3 < NumElts) {
1444             StoreVal = OutVals[i + 3];
1445             if (Extend)
1446               StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
1447           } else {
1448             StoreVal = DAG.getUNDEF(ExtendedVT);
1449           }
1450           Ops.push_back(StoreVal);
1451         }
1452
1453         Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size());
1454         Offset += PerStoreOffset;
1455       }
1456     }
1457   } else {
1458     unsigned sizesofar = 0;
1459     for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
1460       SDValue theVal = OutVals[i];
1461       EVT theValType = theVal.getValueType();
1462       unsigned numElems = 1;
1463       if (theValType.isVector())
1464         numElems = theValType.getVectorNumElements();
1465       for (unsigned j = 0, je = numElems; j != je; ++j) {
1466         SDValue tmpval = theVal;
1467         if (theValType.isVector())
1468           tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
1469                                theValType.getVectorElementType(), tmpval,
1470                                DAG.getIntPtrConstant(j));
1471         EVT theStoreType = tmpval.getValueType();
1472         if (theStoreType.getSizeInBits() < 8)
1473           tmpval = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, tmpval);
1474         Chain = DAG.getNode(NVPTXISD::StoreRetval, dl, MVT::Other, Chain,
1475                             DAG.getConstant(sizesofar, MVT::i32), tmpval);
1476         if (theValType.isVector())
1477           sizesofar +=
1478               theValType.getVectorElementType().getStoreSizeInBits() / 8;
1479         else
1480           sizesofar += theValType.getStoreSizeInBits() / 8;
1481       }
1482     }
1483   }
1484
1485   return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
1486 }
1487
1488 void NVPTXTargetLowering::LowerAsmOperandForConstraint(
1489     SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops,
1490     SelectionDAG &DAG) const {
1491   if (Constraint.length() > 1)
1492     return;
1493   else
1494     TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
1495 }
1496
1497 // NVPTX suuport vector of legal types of any length in Intrinsics because the
1498 // NVPTX specific type legalizer
1499 // will legalize them to the PTX supported length.
1500 bool NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
1501   if (isTypeLegal(VT))
1502     return true;
1503   if (VT.isVector()) {
1504     MVT eVT = VT.getVectorElementType();
1505     if (isTypeLegal(eVT))
1506       return true;
1507   }
1508   return false;
1509 }
1510
1511 // llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
1512 // TgtMemIntrinsic
1513 // because we need the information that is only available in the "Value" type
1514 // of destination
1515 // pointer. In particular, the address space information.
1516 bool NVPTXTargetLowering::getTgtMemIntrinsic(
1517     IntrinsicInfo &Info, const CallInst &I, unsigned Intrinsic) const {
1518   switch (Intrinsic) {
1519   default:
1520     return false;
1521
1522   case Intrinsic::nvvm_atomic_load_add_f32:
1523     Info.opc = ISD::INTRINSIC_W_CHAIN;
1524     Info.memVT = MVT::f32;
1525     Info.ptrVal = I.getArgOperand(0);
1526     Info.offset = 0;
1527     Info.vol = 0;
1528     Info.readMem = true;
1529     Info.writeMem = true;
1530     Info.align = 0;
1531     return true;
1532
1533   case Intrinsic::nvvm_atomic_load_inc_32:
1534   case Intrinsic::nvvm_atomic_load_dec_32:
1535     Info.opc = ISD::INTRINSIC_W_CHAIN;
1536     Info.memVT = MVT::i32;
1537     Info.ptrVal = I.getArgOperand(0);
1538     Info.offset = 0;
1539     Info.vol = 0;
1540     Info.readMem = true;
1541     Info.writeMem = true;
1542     Info.align = 0;
1543     return true;
1544
1545   case Intrinsic::nvvm_ldu_global_i:
1546   case Intrinsic::nvvm_ldu_global_f:
1547   case Intrinsic::nvvm_ldu_global_p:
1548
1549     Info.opc = ISD::INTRINSIC_W_CHAIN;
1550     if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
1551       Info.memVT = MVT::i32;
1552     else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
1553       Info.memVT = getPointerTy();
1554     else
1555       Info.memVT = MVT::f32;
1556     Info.ptrVal = I.getArgOperand(0);
1557     Info.offset = 0;
1558     Info.vol = 0;
1559     Info.readMem = true;
1560     Info.writeMem = false;
1561     Info.align = 0;
1562     return true;
1563
1564   }
1565   return false;
1566 }
1567
1568 /// isLegalAddressingMode - Return true if the addressing mode represented
1569 /// by AM is legal for this target, for a load/store of the specified type.
1570 /// Used to guide target specific optimizations, like loop strength reduction
1571 /// (LoopStrengthReduce.cpp) and memory optimization for address mode
1572 /// (CodeGenPrepare.cpp)
1573 bool NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
1574                                                 Type *Ty) const {
1575
1576   // AddrMode - This represents an addressing mode of:
1577   //    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
1578   //
1579   // The legal address modes are
1580   // - [avar]
1581   // - [areg]
1582   // - [areg+immoff]
1583   // - [immAddr]
1584
1585   if (AM.BaseGV) {
1586     if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
1587       return false;
1588     return true;
1589   }
1590
1591   switch (AM.Scale) {
1592   case 0: // "r", "r+i" or "i" is allowed
1593     break;
1594   case 1:
1595     if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
1596       return false;
1597     // Otherwise we have r+i.
1598     break;
1599   default:
1600     // No scale > 1 is allowed
1601     return false;
1602   }
1603   return true;
1604 }
1605
1606 //===----------------------------------------------------------------------===//
1607 //                         NVPTX Inline Assembly Support
1608 //===----------------------------------------------------------------------===//
1609
1610 /// getConstraintType - Given a constraint letter, return the type of
1611 /// constraint it is for this target.
1612 NVPTXTargetLowering::ConstraintType
1613 NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
1614   if (Constraint.size() == 1) {
1615     switch (Constraint[0]) {
1616     default:
1617       break;
1618     case 'r':
1619     case 'h':
1620     case 'c':
1621     case 'l':
1622     case 'f':
1623     case 'd':
1624     case '0':
1625     case 'N':
1626       return C_RegisterClass;
1627     }
1628   }
1629   return TargetLowering::getConstraintType(Constraint);
1630 }
1631
1632 std::pair<unsigned, const TargetRegisterClass *>
1633 NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
1634                                                   MVT VT) const {
1635   if (Constraint.size() == 1) {
1636     switch (Constraint[0]) {
1637     case 'c':
1638       return std::make_pair(0U, &NVPTX::Int8RegsRegClass);
1639     case 'h':
1640       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
1641     case 'r':
1642       return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
1643     case 'l':
1644     case 'N':
1645       return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
1646     case 'f':
1647       return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
1648     case 'd':
1649       return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
1650     }
1651   }
1652   return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
1653 }
1654
1655 /// getFunctionAlignment - Return the Log2 alignment of this function.
1656 unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
1657   return 4;
1658 }
1659
1660 /// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
1661 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
1662                               SmallVectorImpl<SDValue> &Results) {
1663   EVT ResVT = N->getValueType(0);
1664   SDLoc DL(N);
1665
1666   assert(ResVT.isVector() && "Vector load must have vector type");
1667
1668   // We only handle "native" vector sizes for now, e.g. <4 x double> is not
1669   // legal.  We can (and should) split that into 2 loads of <2 x double> here
1670   // but I'm leaving that as a TODO for now.
1671   assert(ResVT.isSimple() && "Can only handle simple types");
1672   switch (ResVT.getSimpleVT().SimpleTy) {
1673   default:
1674     return;
1675   case MVT::v2i8:
1676   case MVT::v2i16:
1677   case MVT::v2i32:
1678   case MVT::v2i64:
1679   case MVT::v2f32:
1680   case MVT::v2f64:
1681   case MVT::v4i8:
1682   case MVT::v4i16:
1683   case MVT::v4i32:
1684   case MVT::v4f32:
1685     // This is a "native" vector type
1686     break;
1687   }
1688
1689   EVT EltVT = ResVT.getVectorElementType();
1690   unsigned NumElts = ResVT.getVectorNumElements();
1691
1692   // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
1693   // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
1694   // loaded type to i16 and propogate the "real" type as the memory type.
1695   bool NeedTrunc = false;
1696   if (EltVT.getSizeInBits() < 16) {
1697     EltVT = MVT::i16;
1698     NeedTrunc = true;
1699   }
1700
1701   unsigned Opcode = 0;
1702   SDVTList LdResVTs;
1703
1704   switch (NumElts) {
1705   default:
1706     return;
1707   case 2:
1708     Opcode = NVPTXISD::LoadV2;
1709     LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
1710     break;
1711   case 4: {
1712     Opcode = NVPTXISD::LoadV4;
1713     EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
1714     LdResVTs = DAG.getVTList(ListVTs, 5);
1715     break;
1716   }
1717   }
1718
1719   SmallVector<SDValue, 8> OtherOps;
1720
1721   // Copy regular operands
1722   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
1723     OtherOps.push_back(N->getOperand(i));
1724
1725   LoadSDNode *LD = cast<LoadSDNode>(N);
1726
1727   // The select routine does not have access to the LoadSDNode instance, so
1728   // pass along the extension information
1729   OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType()));
1730
1731   SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, &OtherOps[0],
1732                                           OtherOps.size(), LD->getMemoryVT(),
1733                                           LD->getMemOperand());
1734
1735   SmallVector<SDValue, 4> ScalarRes;
1736
1737   for (unsigned i = 0; i < NumElts; ++i) {
1738     SDValue Res = NewLD.getValue(i);
1739     if (NeedTrunc)
1740       Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
1741     ScalarRes.push_back(Res);
1742   }
1743
1744   SDValue LoadChain = NewLD.getValue(NumElts);
1745
1746   SDValue BuildVec =
1747       DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
1748
1749   Results.push_back(BuildVec);
1750   Results.push_back(LoadChain);
1751 }
1752
1753 static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
1754                                      SmallVectorImpl<SDValue> &Results) {
1755   SDValue Chain = N->getOperand(0);
1756   SDValue Intrin = N->getOperand(1);
1757   SDLoc DL(N);
1758
1759   // Get the intrinsic ID
1760   unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
1761   switch (IntrinNo) {
1762   default:
1763     return;
1764   case Intrinsic::nvvm_ldg_global_i:
1765   case Intrinsic::nvvm_ldg_global_f:
1766   case Intrinsic::nvvm_ldg_global_p:
1767   case Intrinsic::nvvm_ldu_global_i:
1768   case Intrinsic::nvvm_ldu_global_f:
1769   case Intrinsic::nvvm_ldu_global_p: {
1770     EVT ResVT = N->getValueType(0);
1771
1772     if (ResVT.isVector()) {
1773       // Vector LDG/LDU
1774
1775       unsigned NumElts = ResVT.getVectorNumElements();
1776       EVT EltVT = ResVT.getVectorElementType();
1777
1778       // Since LDU/LDG are target nodes, we cannot rely on DAG type legalization.
1779       // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
1780       // loaded type to i16 and propogate the "real" type as the memory type.
1781       bool NeedTrunc = false;
1782       if (EltVT.getSizeInBits() < 16) {
1783         EltVT = MVT::i16;
1784         NeedTrunc = true;
1785       }
1786
1787       unsigned Opcode = 0;
1788       SDVTList LdResVTs;
1789
1790       switch (NumElts) {
1791       default:
1792         return;
1793       case 2:
1794         switch (IntrinNo) {
1795         default:
1796           return;
1797         case Intrinsic::nvvm_ldg_global_i:
1798         case Intrinsic::nvvm_ldg_global_f:
1799         case Intrinsic::nvvm_ldg_global_p:
1800           Opcode = NVPTXISD::LDGV2;
1801           break;
1802         case Intrinsic::nvvm_ldu_global_i:
1803         case Intrinsic::nvvm_ldu_global_f:
1804         case Intrinsic::nvvm_ldu_global_p:
1805           Opcode = NVPTXISD::LDUV2;
1806           break;
1807         }
1808         LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
1809         break;
1810       case 4: {
1811         switch (IntrinNo) {
1812         default:
1813           return;
1814         case Intrinsic::nvvm_ldg_global_i:
1815         case Intrinsic::nvvm_ldg_global_f:
1816         case Intrinsic::nvvm_ldg_global_p:
1817           Opcode = NVPTXISD::LDGV4;
1818           break;
1819         case Intrinsic::nvvm_ldu_global_i:
1820         case Intrinsic::nvvm_ldu_global_f:
1821         case Intrinsic::nvvm_ldu_global_p:
1822           Opcode = NVPTXISD::LDUV4;
1823           break;
1824         }
1825         EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
1826         LdResVTs = DAG.getVTList(ListVTs, 5);
1827         break;
1828       }
1829       }
1830
1831       SmallVector<SDValue, 8> OtherOps;
1832
1833       // Copy regular operands
1834
1835       OtherOps.push_back(Chain); // Chain
1836                                  // Skip operand 1 (intrinsic ID)
1837                                  // Others
1838       for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i)
1839         OtherOps.push_back(N->getOperand(i));
1840
1841       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
1842
1843       SDValue NewLD = DAG.getMemIntrinsicNode(
1844           Opcode, DL, LdResVTs, &OtherOps[0], OtherOps.size(),
1845           MemSD->getMemoryVT(), MemSD->getMemOperand());
1846
1847       SmallVector<SDValue, 4> ScalarRes;
1848
1849       for (unsigned i = 0; i < NumElts; ++i) {
1850         SDValue Res = NewLD.getValue(i);
1851         if (NeedTrunc)
1852           Res =
1853               DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
1854         ScalarRes.push_back(Res);
1855       }
1856
1857       SDValue LoadChain = NewLD.getValue(NumElts);
1858
1859       SDValue BuildVec =
1860           DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
1861
1862       Results.push_back(BuildVec);
1863       Results.push_back(LoadChain);
1864     } else {
1865       // i8 LDG/LDU
1866       assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
1867              "Custom handling of non-i8 ldu/ldg?");
1868
1869       // Just copy all operands as-is
1870       SmallVector<SDValue, 4> Ops;
1871       for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
1872         Ops.push_back(N->getOperand(i));
1873
1874       // Force output to i16
1875       SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);
1876
1877       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
1878
1879       // We make sure the memory type is i8, which will be used during isel
1880       // to select the proper instruction.
1881       SDValue NewLD =
1882           DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, &Ops[0],
1883                                   Ops.size(), MVT::i8, MemSD->getMemOperand());
1884
1885       Results.push_back(NewLD.getValue(0));
1886       Results.push_back(NewLD.getValue(1));
1887     }
1888   }
1889   }
1890 }
1891
1892 void NVPTXTargetLowering::ReplaceNodeResults(
1893     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
1894   switch (N->getOpcode()) {
1895   default:
1896     report_fatal_error("Unhandled custom legalization");
1897   case ISD::LOAD:
1898     ReplaceLoadVector(N, DAG, Results);
1899     return;
1900   case ISD::INTRINSIC_W_CHAIN:
1901     ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
1902     return;
1903   }
1904 }