The getRegForInlineAsmConstraint function should only accept MVT value types.
[oota-llvm.git] / lib / Target / NVPTX / NVPTXISelLowering.cpp
1 //
2 //                     The LLVM Compiler Infrastructure
3 //
4 // This file is distributed under the University of Illinois Open Source
5 // License. See LICENSE.TXT for details.
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the interfaces that NVPTX uses to lower LLVM code into a
10 // selection DAG.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "NVPTXISelLowering.h"
15 #include "NVPTX.h"
16 #include "NVPTXTargetMachine.h"
17 #include "NVPTXTargetObjectFile.h"
18 #include "NVPTXUtilities.h"
19 #include "llvm/CodeGen/Analysis.h"
20 #include "llvm/CodeGen/MachineFrameInfo.h"
21 #include "llvm/CodeGen/MachineFunction.h"
22 #include "llvm/CodeGen/MachineInstrBuilder.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/GlobalValue.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/Intrinsics.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/MC/MCSectionELF.h"
32 #include "llvm/Support/CallSite.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <sstream>
38
39 #undef DEBUG_TYPE
40 #define DEBUG_TYPE "nvptx-lower"
41
42 using namespace llvm;
43
44 static unsigned int uniqueCallSite = 0;
45
46 static cl::opt<bool> sched4reg(
47     "nvptx-sched4reg",
48     cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false));
49
50 static bool IsPTXVectorType(MVT VT) {
51   switch (VT.SimpleTy) {
52   default:
53     return false;
54   case MVT::v2i8:
55   case MVT::v4i8:
56   case MVT::v2i16:
57   case MVT::v4i16:
58   case MVT::v2i32:
59   case MVT::v4i32:
60   case MVT::v2i64:
61   case MVT::v2f32:
62   case MVT::v4f32:
63   case MVT::v2f64:
64     return true;
65   }
66 }
67
68 // NVPTXTargetLowering Constructor.
69 NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
70     : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM),
71       nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {
72
73   // always lower memset, memcpy, and memmove intrinsics to load/store
74   // instructions, rather
75   // then generating calls to memset, mempcy or memmove.
76   MaxStoresPerMemset = (unsigned) 0xFFFFFFFF;
77   MaxStoresPerMemcpy = (unsigned) 0xFFFFFFFF;
78   MaxStoresPerMemmove = (unsigned) 0xFFFFFFFF;
79
80   setBooleanContents(ZeroOrNegativeOneBooleanContent);
81
82   // Jump is Expensive. Don't create extra control flow for 'and', 'or'
83   // condition branches.
84   setJumpIsExpensive(true);
85
86   // By default, use the Source scheduling
87   if (sched4reg)
88     setSchedulingPreference(Sched::RegPressure);
89   else
90     setSchedulingPreference(Sched::Source);
91
92   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
93   addRegisterClass(MVT::i8, &NVPTX::Int8RegsRegClass);
94   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
95   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
96   addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
97   addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
98   addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
99
100   // Operations not directly supported by NVPTX.
101   setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
102   setOperationAction(ISD::BR_CC, MVT::f32, Expand);
103   setOperationAction(ISD::BR_CC, MVT::f64, Expand);
104   setOperationAction(ISD::BR_CC, MVT::i1, Expand);
105   setOperationAction(ISD::BR_CC, MVT::i8, Expand);
106   setOperationAction(ISD::BR_CC, MVT::i16, Expand);
107   setOperationAction(ISD::BR_CC, MVT::i32, Expand);
108   setOperationAction(ISD::BR_CC, MVT::i64, Expand);
109   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Expand);
110   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Expand);
111   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Expand);
112   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8, Expand);
113   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
114
115   if (nvptxSubtarget.hasROT64()) {
116     setOperationAction(ISD::ROTL, MVT::i64, Legal);
117     setOperationAction(ISD::ROTR, MVT::i64, Legal);
118   } else {
119     setOperationAction(ISD::ROTL, MVT::i64, Expand);
120     setOperationAction(ISD::ROTR, MVT::i64, Expand);
121   }
122   if (nvptxSubtarget.hasROT32()) {
123     setOperationAction(ISD::ROTL, MVT::i32, Legal);
124     setOperationAction(ISD::ROTR, MVT::i32, Legal);
125   } else {
126     setOperationAction(ISD::ROTL, MVT::i32, Expand);
127     setOperationAction(ISD::ROTR, MVT::i32, Expand);
128   }
129
130   setOperationAction(ISD::ROTL, MVT::i16, Expand);
131   setOperationAction(ISD::ROTR, MVT::i16, Expand);
132   setOperationAction(ISD::ROTL, MVT::i8, Expand);
133   setOperationAction(ISD::ROTR, MVT::i8, Expand);
134   setOperationAction(ISD::BSWAP, MVT::i16, Expand);
135   setOperationAction(ISD::BSWAP, MVT::i32, Expand);
136   setOperationAction(ISD::BSWAP, MVT::i64, Expand);
137
138   // Indirect branch is not supported.
139   // This also disables Jump Table creation.
140   setOperationAction(ISD::BR_JT, MVT::Other, Expand);
141   setOperationAction(ISD::BRIND, MVT::Other, Expand);
142
143   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
144   setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
145
146   // We want to legalize constant related memmove and memcopy
147   // intrinsics.
148   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
149
150   // Turn FP extload into load/fextend
151   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
152   // Turn FP truncstore into trunc + store.
153   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
154
155   // PTX does not support load / store predicate registers
156   setOperationAction(ISD::LOAD, MVT::i1, Custom);
157   setOperationAction(ISD::STORE, MVT::i1, Custom);
158
159   setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
160   setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
161   setTruncStoreAction(MVT::i64, MVT::i1, Expand);
162   setTruncStoreAction(MVT::i32, MVT::i1, Expand);
163   setTruncStoreAction(MVT::i16, MVT::i1, Expand);
164   setTruncStoreAction(MVT::i8, MVT::i1, Expand);
165
166   // This is legal in NVPTX
167   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
168   setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
169
170   // TRAP can be lowered to PTX trap
171   setOperationAction(ISD::TRAP, MVT::Other, Legal);
172
173   // Register custom handling for vector loads/stores
174   for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE;
175        ++i) {
176     MVT VT = (MVT::SimpleValueType) i;
177     if (IsPTXVectorType(VT)) {
178       setOperationAction(ISD::LOAD, VT, Custom);
179       setOperationAction(ISD::STORE, VT, Custom);
180       setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
181     }
182   }
183
184   // Now deduce the information based on the above mentioned
185   // actions
186   computeRegisterProperties();
187 }
188
189 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
190   switch (Opcode) {
191   default:
192     return 0;
193   case NVPTXISD::CALL:
194     return "NVPTXISD::CALL";
195   case NVPTXISD::RET_FLAG:
196     return "NVPTXISD::RET_FLAG";
197   case NVPTXISD::Wrapper:
198     return "NVPTXISD::Wrapper";
199   case NVPTXISD::NVBuiltin:
200     return "NVPTXISD::NVBuiltin";
201   case NVPTXISD::DeclareParam:
202     return "NVPTXISD::DeclareParam";
203   case NVPTXISD::DeclareScalarParam:
204     return "NVPTXISD::DeclareScalarParam";
205   case NVPTXISD::DeclareRet:
206     return "NVPTXISD::DeclareRet";
207   case NVPTXISD::DeclareRetParam:
208     return "NVPTXISD::DeclareRetParam";
209   case NVPTXISD::PrintCall:
210     return "NVPTXISD::PrintCall";
211   case NVPTXISD::LoadParam:
212     return "NVPTXISD::LoadParam";
213   case NVPTXISD::StoreParam:
214     return "NVPTXISD::StoreParam";
215   case NVPTXISD::StoreParamS32:
216     return "NVPTXISD::StoreParamS32";
217   case NVPTXISD::StoreParamU32:
218     return "NVPTXISD::StoreParamU32";
219   case NVPTXISD::MoveToParam:
220     return "NVPTXISD::MoveToParam";
221   case NVPTXISD::CallArgBegin:
222     return "NVPTXISD::CallArgBegin";
223   case NVPTXISD::CallArg:
224     return "NVPTXISD::CallArg";
225   case NVPTXISD::LastCallArg:
226     return "NVPTXISD::LastCallArg";
227   case NVPTXISD::CallArgEnd:
228     return "NVPTXISD::CallArgEnd";
229   case NVPTXISD::CallVoid:
230     return "NVPTXISD::CallVoid";
231   case NVPTXISD::CallVal:
232     return "NVPTXISD::CallVal";
233   case NVPTXISD::CallSymbol:
234     return "NVPTXISD::CallSymbol";
235   case NVPTXISD::Prototype:
236     return "NVPTXISD::Prototype";
237   case NVPTXISD::MoveParam:
238     return "NVPTXISD::MoveParam";
239   case NVPTXISD::MoveRetval:
240     return "NVPTXISD::MoveRetval";
241   case NVPTXISD::MoveToRetval:
242     return "NVPTXISD::MoveToRetval";
243   case NVPTXISD::StoreRetval:
244     return "NVPTXISD::StoreRetval";
245   case NVPTXISD::PseudoUseParam:
246     return "NVPTXISD::PseudoUseParam";
247   case NVPTXISD::RETURN:
248     return "NVPTXISD::RETURN";
249   case NVPTXISD::CallSeqBegin:
250     return "NVPTXISD::CallSeqBegin";
251   case NVPTXISD::CallSeqEnd:
252     return "NVPTXISD::CallSeqEnd";
253   case NVPTXISD::LoadV2:
254     return "NVPTXISD::LoadV2";
255   case NVPTXISD::LoadV4:
256     return "NVPTXISD::LoadV4";
257   case NVPTXISD::LDGV2:
258     return "NVPTXISD::LDGV2";
259   case NVPTXISD::LDGV4:
260     return "NVPTXISD::LDGV4";
261   case NVPTXISD::LDUV2:
262     return "NVPTXISD::LDUV2";
263   case NVPTXISD::LDUV4:
264     return "NVPTXISD::LDUV4";
265   case NVPTXISD::StoreV2:
266     return "NVPTXISD::StoreV2";
267   case NVPTXISD::StoreV4:
268     return "NVPTXISD::StoreV4";
269   }
270 }
271
272 bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const {
273   return VT == MVT::i1;
274 }
275
276 SDValue
277 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
278   SDLoc dl(Op);
279   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
280   Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
281   return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
282 }
283
284 std::string NVPTXTargetLowering::getPrototype(
285     Type *retTy, const ArgListTy &Args,
286     const SmallVectorImpl<ISD::OutputArg> &Outs, unsigned retAlignment) const {
287
288   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
289
290   std::stringstream O;
291   O << "prototype_" << uniqueCallSite << " : .callprototype ";
292
293   if (retTy->getTypeID() == Type::VoidTyID)
294     O << "()";
295   else {
296     O << "(";
297     if (isABI) {
298       if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
299         unsigned size = 0;
300         if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
301           size = ITy->getBitWidth();
302           if (size < 32)
303             size = 32;
304         } else {
305           assert(retTy->isFloatingPointTy() &&
306                  "Floating point type expected here");
307           size = retTy->getPrimitiveSizeInBits();
308         }
309
310         O << ".param .b" << size << " _";
311       } else if (isa<PointerType>(retTy))
312         O << ".param .b" << getPointerTy().getSizeInBits() << " _";
313       else {
314         if ((retTy->getTypeID() == Type::StructTyID) ||
315             isa<VectorType>(retTy)) {
316           SmallVector<EVT, 16> vtparts;
317           ComputeValueVTs(*this, retTy, vtparts);
318           unsigned totalsz = 0;
319           for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
320             unsigned elems = 1;
321             EVT elemtype = vtparts[i];
322             if (vtparts[i].isVector()) {
323               elems = vtparts[i].getVectorNumElements();
324               elemtype = vtparts[i].getVectorElementType();
325             }
326             for (unsigned j = 0, je = elems; j != je; ++j) {
327               unsigned sz = elemtype.getSizeInBits();
328               if (elemtype.isInteger() && (sz < 8))
329                 sz = 8;
330               totalsz += sz / 8;
331             }
332           }
333           O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
334         } else {
335           assert(false && "Unknown return type");
336         }
337       }
338     } else {
339       SmallVector<EVT, 16> vtparts;
340       ComputeValueVTs(*this, retTy, vtparts);
341       unsigned idx = 0;
342       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
343         unsigned elems = 1;
344         EVT elemtype = vtparts[i];
345         if (vtparts[i].isVector()) {
346           elems = vtparts[i].getVectorNumElements();
347           elemtype = vtparts[i].getVectorElementType();
348         }
349
350         for (unsigned j = 0, je = elems; j != je; ++j) {
351           unsigned sz = elemtype.getSizeInBits();
352           if (elemtype.isInteger() && (sz < 32))
353             sz = 32;
354           O << ".reg .b" << sz << " _";
355           if (j < je - 1)
356             O << ", ";
357           ++idx;
358         }
359         if (i < e - 1)
360           O << ", ";
361       }
362     }
363     O << ") ";
364   }
365   O << "_ (";
366
367   bool first = true;
368   MVT thePointerTy = getPointerTy();
369
370   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
371     const Type *Ty = Args[i].Ty;
372     if (!first) {
373       O << ", ";
374     }
375     first = false;
376
377     if (Outs[i].Flags.isByVal() == false) {
378       unsigned sz = 0;
379       if (isa<IntegerType>(Ty)) {
380         sz = cast<IntegerType>(Ty)->getBitWidth();
381         if (sz < 32)
382           sz = 32;
383       } else if (isa<PointerType>(Ty))
384         sz = thePointerTy.getSizeInBits();
385       else
386         sz = Ty->getPrimitiveSizeInBits();
387       if (isABI)
388         O << ".param .b" << sz << " ";
389       else
390         O << ".reg .b" << sz << " ";
391       O << "_";
392       continue;
393     }
394     const PointerType *PTy = dyn_cast<PointerType>(Ty);
395     assert(PTy && "Param with byval attribute should be a pointer type");
396     Type *ETy = PTy->getElementType();
397
398     if (isABI) {
399       unsigned align = Outs[i].Flags.getByValAlign();
400       unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
401       O << ".param .align " << align << " .b8 ";
402       O << "_";
403       O << "[" << sz << "]";
404       continue;
405     } else {
406       SmallVector<EVT, 16> vtparts;
407       ComputeValueVTs(*this, ETy, vtparts);
408       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
409         unsigned elems = 1;
410         EVT elemtype = vtparts[i];
411         if (vtparts[i].isVector()) {
412           elems = vtparts[i].getVectorNumElements();
413           elemtype = vtparts[i].getVectorElementType();
414         }
415
416         for (unsigned j = 0, je = elems; j != je; ++j) {
417           unsigned sz = elemtype.getSizeInBits();
418           if (elemtype.isInteger() && (sz < 32))
419             sz = 32;
420           O << ".reg .b" << sz << " ";
421           O << "_";
422           if (j < je - 1)
423             O << ", ";
424         }
425         if (i < e - 1)
426           O << ", ";
427       }
428       continue;
429     }
430   }
431   O << ");";
432   return O.str();
433 }
434
435 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
436                                        SmallVectorImpl<SDValue> &InVals) const {
437   SelectionDAG &DAG = CLI.DAG;
438   SDLoc dl = CLI.DL;
439   SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
440   SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
441   SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
442   SDValue Chain = CLI.Chain;
443   SDValue Callee = CLI.Callee;
444   bool &isTailCall = CLI.IsTailCall;
445   ArgListTy &Args = CLI.Args;
446   Type *retTy = CLI.RetTy;
447   ImmutableCallSite *CS = CLI.CS;
448
449   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
450
451   SDValue tempChain = Chain;
452   Chain = DAG.getCALLSEQ_START(Chain,
453                                DAG.getIntPtrConstant(uniqueCallSite, true),
454                                dl);
455   SDValue InFlag = Chain.getValue(1);
456
457   assert((Outs.size() == Args.size()) &&
458          "Unexpected number of arguments to function call");
459   unsigned paramCount = 0;
460   // Declare the .params or .reg need to pass values
461   // to the function
462   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
463     EVT VT = Outs[i].VT;
464
465     if (Outs[i].Flags.isByVal() == false) {
466       // Plain scalar
467       // for ABI,    declare .param .b<size> .param<n>;
468       // for nonABI, declare .reg .b<size> .param<n>;
469       unsigned isReg = 1;
470       if (isABI)
471         isReg = 0;
472       unsigned sz = VT.getSizeInBits();
473       if (VT.isInteger() && (sz < 32))
474         sz = 32;
475       SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
476       SDValue DeclareParamOps[] = { Chain,
477                                     DAG.getConstant(paramCount, MVT::i32),
478                                     DAG.getConstant(sz, MVT::i32),
479                                     DAG.getConstant(isReg, MVT::i32), InFlag };
480       Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
481                           DeclareParamOps, 5);
482       InFlag = Chain.getValue(1);
483       SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
484       SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
485                                  DAG.getConstant(0, MVT::i32), OutVals[i],
486                                  InFlag };
487
488       unsigned opcode = NVPTXISD::StoreParam;
489       if (isReg)
490         opcode = NVPTXISD::MoveToParam;
491       else {
492         if (Outs[i].Flags.isZExt())
493           opcode = NVPTXISD::StoreParamU32;
494         else if (Outs[i].Flags.isSExt())
495           opcode = NVPTXISD::StoreParamS32;
496       }
497       Chain = DAG.getNode(opcode, dl, CopyParamVTs, CopyParamOps, 5);
498
499       InFlag = Chain.getValue(1);
500       ++paramCount;
501       continue;
502     }
503     // struct or vector
504     SmallVector<EVT, 16> vtparts;
505     const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
506     assert(PTy && "Type of a byval parameter should be pointer");
507     ComputeValueVTs(*this, PTy->getElementType(), vtparts);
508
509     if (isABI) {
510       // declare .param .align 16 .b8 .param<n>[<size>];
511       unsigned sz = Outs[i].Flags.getByValSize();
512       SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
513       // The ByValAlign in the Outs[i].Flags is alway set at this point, so we
514       // don't need to
515       // worry about natural alignment or not. See TargetLowering::LowerCallTo()
516       SDValue DeclareParamOps[] = {
517         Chain, DAG.getConstant(Outs[i].Flags.getByValAlign(), MVT::i32),
518         DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32),
519         InFlag
520       };
521       Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
522                           DeclareParamOps, 5);
523       InFlag = Chain.getValue(1);
524       unsigned curOffset = 0;
525       for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
526         unsigned elems = 1;
527         EVT elemtype = vtparts[j];
528         if (vtparts[j].isVector()) {
529           elems = vtparts[j].getVectorNumElements();
530           elemtype = vtparts[j].getVectorElementType();
531         }
532         for (unsigned k = 0, ke = elems; k != ke; ++k) {
533           unsigned sz = elemtype.getSizeInBits();
534           if (elemtype.isInteger() && (sz < 8))
535             sz = 8;
536           SDValue srcAddr =
537               DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[i],
538                           DAG.getConstant(curOffset, getPointerTy()));
539           SDValue theVal =
540               DAG.getLoad(elemtype, dl, tempChain, srcAddr,
541                           MachinePointerInfo(), false, false, false, 0);
542           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
543           SDValue CopyParamOps[] = { Chain,
544                                      DAG.getConstant(paramCount, MVT::i32),
545                                      DAG.getConstant(curOffset, MVT::i32),
546                                      theVal, InFlag };
547           Chain = DAG.getNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
548                               CopyParamOps, 5);
549           InFlag = Chain.getValue(1);
550           curOffset += sz / 8;
551         }
552       }
553       ++paramCount;
554       continue;
555     }
556     // Non-abi, struct or vector
557     // Declare a bunch or .reg .b<size> .param<n>
558     unsigned curOffset = 0;
559     for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
560       unsigned elems = 1;
561       EVT elemtype = vtparts[j];
562       if (vtparts[j].isVector()) {
563         elems = vtparts[j].getVectorNumElements();
564         elemtype = vtparts[j].getVectorElementType();
565       }
566       for (unsigned k = 0, ke = elems; k != ke; ++k) {
567         unsigned sz = elemtype.getSizeInBits();
568         if (elemtype.isInteger() && (sz < 32))
569           sz = 32;
570         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
571         SDValue DeclareParamOps[] = { Chain,
572                                       DAG.getConstant(paramCount, MVT::i32),
573                                       DAG.getConstant(sz, MVT::i32),
574                                       DAG.getConstant(1, MVT::i32), InFlag };
575         Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
576                             DeclareParamOps, 5);
577         InFlag = Chain.getValue(1);
578         SDValue srcAddr =
579             DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[i],
580                         DAG.getConstant(curOffset, getPointerTy()));
581         SDValue theVal =
582             DAG.getLoad(elemtype, dl, tempChain, srcAddr, MachinePointerInfo(),
583                         false, false, false, 0);
584         SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
585         SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
586                                    DAG.getConstant(0, MVT::i32), theVal,
587                                    InFlag };
588         Chain = DAG.getNode(NVPTXISD::MoveToParam, dl, CopyParamVTs,
589                             CopyParamOps, 5);
590         InFlag = Chain.getValue(1);
591         ++paramCount;
592       }
593     }
594   }
595
596   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
597   unsigned retAlignment = 0;
598
599   // Handle Result
600   unsigned retCount = 0;
601   if (Ins.size() > 0) {
602     SmallVector<EVT, 16> resvtparts;
603     ComputeValueVTs(*this, retTy, resvtparts);
604
605     // Declare one .param .align 16 .b8 func_retval0[<size>] for ABI or
606     // individual .reg .b<size> func_retval<0..> for non ABI
607     unsigned resultsz = 0;
608     for (unsigned i = 0, e = resvtparts.size(); i != e; ++i) {
609       unsigned elems = 1;
610       EVT elemtype = resvtparts[i];
611       if (resvtparts[i].isVector()) {
612         elems = resvtparts[i].getVectorNumElements();
613         elemtype = resvtparts[i].getVectorElementType();
614       }
615       for (unsigned j = 0, je = elems; j != je; ++j) {
616         unsigned sz = elemtype.getSizeInBits();
617         if (isABI == false) {
618           if (elemtype.isInteger() && (sz < 32))
619             sz = 32;
620         } else {
621           if (elemtype.isInteger() && (sz < 8))
622             sz = 8;
623         }
624         if (isABI == false) {
625           SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
626           SDValue DeclareRetOps[] = { Chain, DAG.getConstant(2, MVT::i32),
627                                       DAG.getConstant(sz, MVT::i32),
628                                       DAG.getConstant(retCount, MVT::i32),
629                                       InFlag };
630           Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
631                               DeclareRetOps, 5);
632           InFlag = Chain.getValue(1);
633           ++retCount;
634         }
635         resultsz += sz;
636       }
637     }
638     if (isABI) {
639       if (retTy->isPrimitiveType() || retTy->isIntegerTy() ||
640           retTy->isPointerTy()) {
641         // Scalar needs to be at least 32bit wide
642         if (resultsz < 32)
643           resultsz = 32;
644         SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
645         SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
646                                     DAG.getConstant(resultsz, MVT::i32),
647                                     DAG.getConstant(0, MVT::i32), InFlag };
648         Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
649                             DeclareRetOps, 5);
650         InFlag = Chain.getValue(1);
651       } else {
652         if (Func) { // direct call
653           if (!llvm::getAlign(*(CS->getCalledFunction()), 0, retAlignment))
654             retAlignment = getDataLayout()->getABITypeAlignment(retTy);
655         } else { // indirect call
656           const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction());
657           if (!llvm::getAlign(*CallI, 0, retAlignment))
658             retAlignment = getDataLayout()->getABITypeAlignment(retTy);
659         }
660         SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
661         SDValue DeclareRetOps[] = { Chain,
662                                     DAG.getConstant(retAlignment, MVT::i32),
663                                     DAG.getConstant(resultsz / 8, MVT::i32),
664                                     DAG.getConstant(0, MVT::i32), InFlag };
665         Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
666                             DeclareRetOps, 5);
667         InFlag = Chain.getValue(1);
668       }
669     }
670   }
671
672   if (!Func) {
673     // This is indirect function call case : PTX requires a prototype of the
674     // form
675     // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
676     // to be emitted, and the label has to used as the last arg of call
677     // instruction.
678     // The prototype is embedded in a string and put as the operand for an
679     // INLINEASM SDNode.
680     SDVTList InlineAsmVTs = DAG.getVTList(MVT::Other, MVT::Glue);
681     std::string proto_string = getPrototype(retTy, Args, Outs, retAlignment);
682     const char *asmstr = nvTM->getManagedStrPool()
683         ->getManagedString(proto_string.c_str())->c_str();
684     SDValue InlineAsmOps[] = {
685       Chain, DAG.getTargetExternalSymbol(asmstr, getPointerTy()),
686       DAG.getMDNode(0), DAG.getTargetConstant(0, MVT::i32), InFlag
687     };
688     Chain = DAG.getNode(ISD::INLINEASM, dl, InlineAsmVTs, InlineAsmOps, 5);
689     InFlag = Chain.getValue(1);
690   }
691   // Op to just print "call"
692   SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
693   SDValue PrintCallOps[] = {
694     Chain,
695     DAG.getConstant(isABI ? ((Ins.size() == 0) ? 0 : 1) : retCount, MVT::i32),
696     InFlag
697   };
698   Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
699                       dl, PrintCallVTs, PrintCallOps, 3);
700   InFlag = Chain.getValue(1);
701
702   // Ops to print out the function name
703   SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
704   SDValue CallVoidOps[] = { Chain, Callee, InFlag };
705   Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3);
706   InFlag = Chain.getValue(1);
707
708   // Ops to print out the param list
709   SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
710   SDValue CallArgBeginOps[] = { Chain, InFlag };
711   Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
712                       CallArgBeginOps, 2);
713   InFlag = Chain.getValue(1);
714
715   for (unsigned i = 0, e = paramCount; i != e; ++i) {
716     unsigned opcode;
717     if (i == (e - 1))
718       opcode = NVPTXISD::LastCallArg;
719     else
720       opcode = NVPTXISD::CallArg;
721     SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
722     SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
723                              DAG.getConstant(i, MVT::i32), InFlag };
724     Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4);
725     InFlag = Chain.getValue(1);
726   }
727   SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
728   SDValue CallArgEndOps[] = { Chain, DAG.getConstant(Func ? 1 : 0, MVT::i32),
729                               InFlag };
730   Chain =
731       DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps, 3);
732   InFlag = Chain.getValue(1);
733
734   if (!Func) {
735     SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
736     SDValue PrototypeOps[] = { Chain, DAG.getConstant(uniqueCallSite, MVT::i32),
737                                InFlag };
738     Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3);
739     InFlag = Chain.getValue(1);
740   }
741
742   // Generate loads from param memory/moves from registers for result
743   if (Ins.size() > 0) {
744     if (isABI) {
745       unsigned resoffset = 0;
746       for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
747         unsigned sz = Ins[i].VT.getSizeInBits();
748         if (Ins[i].VT.isInteger() && (sz < 8))
749           sz = 8;
750         EVT LoadRetVTs[] = { Ins[i].VT, MVT::Other, MVT::Glue };
751         SDValue LoadRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
752                                  DAG.getConstant(resoffset, MVT::i32), InFlag };
753         SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, LoadRetVTs,
754                                      LoadRetOps, array_lengthof(LoadRetOps));
755         Chain = retval.getValue(1);
756         InFlag = retval.getValue(2);
757         InVals.push_back(retval);
758         resoffset += sz / 8;
759       }
760     } else {
761       SmallVector<EVT, 16> resvtparts;
762       ComputeValueVTs(*this, retTy, resvtparts);
763
764       assert(Ins.size() == resvtparts.size() &&
765              "Unexpected number of return values in non-ABI case");
766       unsigned paramNum = 0;
767       for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
768         assert(EVT(Ins[i].VT) == resvtparts[i] &&
769                "Unexpected EVT type in non-ABI case");
770         unsigned numelems = 1;
771         EVT elemtype = Ins[i].VT;
772         if (Ins[i].VT.isVector()) {
773           numelems = Ins[i].VT.getVectorNumElements();
774           elemtype = Ins[i].VT.getVectorElementType();
775         }
776         std::vector<SDValue> tempRetVals;
777         for (unsigned j = 0; j < numelems; ++j) {
778           EVT MoveRetVTs[] = { elemtype, MVT::Other, MVT::Glue };
779           SDValue MoveRetOps[] = { Chain, DAG.getConstant(0, MVT::i32),
780                                    DAG.getConstant(paramNum, MVT::i32),
781                                    InFlag };
782           SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, MoveRetVTs,
783                                        MoveRetOps, array_lengthof(MoveRetOps));
784           Chain = retval.getValue(1);
785           InFlag = retval.getValue(2);
786           tempRetVals.push_back(retval);
787           ++paramNum;
788         }
789         if (Ins[i].VT.isVector())
790           InVals.push_back(DAG.getNode(ISD::BUILD_VECTOR, dl, Ins[i].VT,
791                                        &tempRetVals[0], tempRetVals.size()));
792         else
793           InVals.push_back(tempRetVals[0]);
794       }
795     }
796   }
797   Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
798                              DAG.getIntPtrConstant(uniqueCallSite + 1, true),
799                              InFlag, dl);
800   uniqueCallSite++;
801
802   // set isTailCall to false for now, until we figure out how to express
803   // tail call optimization in PTX
804   isTailCall = false;
805   return Chain;
806 }
807
808 // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
809 // (see LegalizeDAG.cpp). This is slow and uses local memory.
810 // We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
811 SDValue
812 NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
813   SDNode *Node = Op.getNode();
814   SDLoc dl(Node);
815   SmallVector<SDValue, 8> Ops;
816   unsigned NumOperands = Node->getNumOperands();
817   for (unsigned i = 0; i < NumOperands; ++i) {
818     SDValue SubOp = Node->getOperand(i);
819     EVT VVT = SubOp.getNode()->getValueType(0);
820     EVT EltVT = VVT.getVectorElementType();
821     unsigned NumSubElem = VVT.getVectorNumElements();
822     for (unsigned j = 0; j < NumSubElem; ++j) {
823       Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
824                                 DAG.getIntPtrConstant(j)));
825     }
826   }
827   return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0), &Ops[0],
828                      Ops.size());
829 }
830
831 SDValue
832 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
833   switch (Op.getOpcode()) {
834   case ISD::RETURNADDR:
835     return SDValue();
836   case ISD::FRAMEADDR:
837     return SDValue();
838   case ISD::GlobalAddress:
839     return LowerGlobalAddress(Op, DAG);
840   case ISD::INTRINSIC_W_CHAIN:
841     return Op;
842   case ISD::BUILD_VECTOR:
843   case ISD::EXTRACT_SUBVECTOR:
844     return Op;
845   case ISD::CONCAT_VECTORS:
846     return LowerCONCAT_VECTORS(Op, DAG);
847   case ISD::STORE:
848     return LowerSTORE(Op, DAG);
849   case ISD::LOAD:
850     return LowerLOAD(Op, DAG);
851   default:
852     llvm_unreachable("Custom lowering not defined for operation");
853   }
854 }
855
856 SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
857   if (Op.getValueType() == MVT::i1)
858     return LowerLOADi1(Op, DAG);
859   else
860     return SDValue();
861 }
862
863 // v = ld i1* addr
864 //   =>
865 // v1 = ld i8* addr
866 // v = trunc v1 to i1
867 SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
868   SDNode *Node = Op.getNode();
869   LoadSDNode *LD = cast<LoadSDNode>(Node);
870   SDLoc dl(Node);
871   assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
872   assert(Node->getValueType(0) == MVT::i1 &&
873          "Custom lowering for i1 load only");
874   SDValue newLD =
875       DAG.getLoad(MVT::i8, dl, LD->getChain(), LD->getBasePtr(),
876                   LD->getPointerInfo(), LD->isVolatile(), LD->isNonTemporal(),
877                   LD->isInvariant(), LD->getAlignment());
878   SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
879   // The legalizer (the caller) is expecting two values from the legalized
880   // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
881   // in LegalizeDAG.cpp which also uses MergeValues.
882   SDValue Ops[] = { result, LD->getChain() };
883   return DAG.getMergeValues(Ops, 2, dl);
884 }
885
886 SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
887   EVT ValVT = Op.getOperand(1).getValueType();
888   if (ValVT == MVT::i1)
889     return LowerSTOREi1(Op, DAG);
890   else if (ValVT.isVector())
891     return LowerSTOREVector(Op, DAG);
892   else
893     return SDValue();
894 }
895
896 SDValue
897 NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
898   SDNode *N = Op.getNode();
899   SDValue Val = N->getOperand(1);
900   SDLoc DL(N);
901   EVT ValVT = Val.getValueType();
902
903   if (ValVT.isVector()) {
904     // We only handle "native" vector sizes for now, e.g. <4 x double> is not
905     // legal.  We can (and should) split that into 2 stores of <2 x double> here
906     // but I'm leaving that as a TODO for now.
907     if (!ValVT.isSimple())
908       return SDValue();
909     switch (ValVT.getSimpleVT().SimpleTy) {
910     default:
911       return SDValue();
912     case MVT::v2i8:
913     case MVT::v2i16:
914     case MVT::v2i32:
915     case MVT::v2i64:
916     case MVT::v2f32:
917     case MVT::v2f64:
918     case MVT::v4i8:
919     case MVT::v4i16:
920     case MVT::v4i32:
921     case MVT::v4f32:
922       // This is a "native" vector type
923       break;
924     }
925
926     unsigned Opcode = 0;
927     EVT EltVT = ValVT.getVectorElementType();
928     unsigned NumElts = ValVT.getVectorNumElements();
929
930     // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
931     // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
932     // stored type to i16 and propogate the "real" type as the memory type.
933     bool NeedExt = false;
934     if (EltVT.getSizeInBits() < 16)
935       NeedExt = true;
936
937     switch (NumElts) {
938     default:
939       return SDValue();
940     case 2:
941       Opcode = NVPTXISD::StoreV2;
942       break;
943     case 4: {
944       Opcode = NVPTXISD::StoreV4;
945       break;
946     }
947     }
948
949     SmallVector<SDValue, 8> Ops;
950
951     // First is the chain
952     Ops.push_back(N->getOperand(0));
953
954     // Then the split values
955     for (unsigned i = 0; i < NumElts; ++i) {
956       SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
957                                    DAG.getIntPtrConstant(i));
958       if (NeedExt)
959         // ANY_EXTEND is correct here since the store will only look at the
960         // lower-order bits anyway.
961         ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
962       Ops.push_back(ExtVal);
963     }
964
965     // Then any remaining arguments
966     for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) {
967       Ops.push_back(N->getOperand(i));
968     }
969
970     MemSDNode *MemSD = cast<MemSDNode>(N);
971
972     SDValue NewSt = DAG.getMemIntrinsicNode(
973         Opcode, DL, DAG.getVTList(MVT::Other), &Ops[0], Ops.size(),
974         MemSD->getMemoryVT(), MemSD->getMemOperand());
975
976     //return DCI.CombineTo(N, NewSt, true);
977     return NewSt;
978   }
979
980   return SDValue();
981 }
982
983 // st i1 v, addr
984 //    =>
985 // v1 = zxt v to i8
986 // st i8, addr
987 SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
988   SDNode *Node = Op.getNode();
989   SDLoc dl(Node);
990   StoreSDNode *ST = cast<StoreSDNode>(Node);
991   SDValue Tmp1 = ST->getChain();
992   SDValue Tmp2 = ST->getBasePtr();
993   SDValue Tmp3 = ST->getValue();
994   assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
995   unsigned Alignment = ST->getAlignment();
996   bool isVolatile = ST->isVolatile();
997   bool isNonTemporal = ST->isNonTemporal();
998   Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, Tmp3);
999   SDValue Result = DAG.getStore(Tmp1, dl, Tmp3, Tmp2, ST->getPointerInfo(),
1000                                 isVolatile, isNonTemporal, Alignment);
1001   return Result;
1002 }
1003
1004 SDValue NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname,
1005                                         int idx, EVT v) const {
1006   std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
1007   std::stringstream suffix;
1008   suffix << idx;
1009   *name += suffix.str();
1010   return DAG.getTargetExternalSymbol(name->c_str(), v);
1011 }
1012
1013 SDValue
1014 NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
1015   return getExtSymb(DAG, ".PARAM", idx, v);
1016 }
1017
1018 SDValue NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
1019   return getExtSymb(DAG, ".HLPPARAM", idx);
1020 }
1021
1022 // Check to see if the kernel argument is image*_t or sampler_t
1023
1024 bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
1025   static const char *const specialTypes[] = { "struct._image2d_t",
1026                                               "struct._image3d_t",
1027                                               "struct._sampler_t" };
1028
1029   const Type *Ty = arg->getType();
1030   const PointerType *PTy = dyn_cast<PointerType>(Ty);
1031
1032   if (!PTy)
1033     return false;
1034
1035   if (!context)
1036     return false;
1037
1038   const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
1039   const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : "";
1040
1041   for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
1042     if (TypeName == specialTypes[i])
1043       return true;
1044
1045   return false;
1046 }
1047
1048 SDValue NVPTXTargetLowering::LowerFormalArguments(
1049     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
1050     const SmallVectorImpl<ISD::InputArg> &Ins, SDLoc dl, SelectionDAG &DAG,
1051     SmallVectorImpl<SDValue> &InVals) const {
1052   MachineFunction &MF = DAG.getMachineFunction();
1053   const DataLayout *TD = getDataLayout();
1054
1055   const Function *F = MF.getFunction();
1056   const AttributeSet &PAL = F->getAttributes();
1057
1058   SDValue Root = DAG.getRoot();
1059   std::vector<SDValue> OutChains;
1060
1061   bool isKernel = llvm::isKernelFunction(*F);
1062   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1063
1064   std::vector<Type *> argTypes;
1065   std::vector<const Argument *> theArgs;
1066   for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
1067        I != E; ++I) {
1068     theArgs.push_back(I);
1069     argTypes.push_back(I->getType());
1070   }
1071   //assert(argTypes.size() == Ins.size() &&
1072   //       "Ins types and function types did not match");
1073
1074   int idx = 0;
1075   for (unsigned i = 0, e = argTypes.size(); i != e; ++i, ++idx) {
1076     Type *Ty = argTypes[i];
1077     EVT ObjectVT = getValueType(Ty);
1078     //assert(ObjectVT == Ins[i].VT &&
1079     //       "Ins type did not match function type");
1080
1081     // If the kernel argument is image*_t or sampler_t, convert it to
1082     // a i32 constant holding the parameter position. This can later
1083     // matched in the AsmPrinter to output the correct mangled name.
1084     if (isImageOrSamplerVal(
1085             theArgs[i],
1086             (theArgs[i]->getParent() ? theArgs[i]->getParent()->getParent()
1087                                      : 0))) {
1088       assert(isKernel && "Only kernels can have image/sampler params");
1089       InVals.push_back(DAG.getConstant(i + 1, MVT::i32));
1090       continue;
1091     }
1092
1093     if (theArgs[i]->use_empty()) {
1094       // argument is dead
1095       if (ObjectVT.isVector()) {
1096         EVT EltVT = ObjectVT.getVectorElementType();
1097         unsigned NumElts = ObjectVT.getVectorNumElements();
1098         for (unsigned vi = 0; vi < NumElts; ++vi) {
1099           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, EltVT));
1100         }
1101       } else {
1102         InVals.push_back(DAG.getNode(ISD::UNDEF, dl, ObjectVT));
1103       }
1104       continue;
1105     }
1106
1107     // In the following cases, assign a node order of "idx+1"
1108     // to newly created nodes. The SDNOdes for params have to
1109     // appear in the same order as their order of appearance
1110     // in the original function. "idx+1" holds that order.
1111     if (PAL.hasAttribute(i + 1, Attribute::ByVal) == false) {
1112       if (ObjectVT.isVector()) {
1113         unsigned NumElts = ObjectVT.getVectorNumElements();
1114         EVT EltVT = ObjectVT.getVectorElementType();
1115         unsigned Offset = 0;
1116         for (unsigned vi = 0; vi < NumElts; ++vi) {
1117           SDValue A = getParamSymbol(DAG, idx, getPointerTy());
1118           SDValue B = DAG.getIntPtrConstant(Offset);
1119           SDValue Addr = DAG.getNode(ISD::ADD, dl, getPointerTy(),
1120                                      //getParamSymbol(DAG, idx, EltVT),
1121                                      //DAG.getConstant(Offset, getPointerTy()));
1122                                      A, B);
1123           Value *SrcValue = Constant::getNullValue(PointerType::get(
1124               EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1125           SDValue Ld = DAG.getLoad(
1126               EltVT, dl, Root, Addr, MachinePointerInfo(SrcValue), false, false,
1127               false,
1128               TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
1129           Offset += EltVT.getStoreSizeInBits() / 8;
1130           InVals.push_back(Ld);
1131         }
1132         continue;
1133       }
1134
1135       // A plain scalar.
1136       if (isABI || isKernel) {
1137         // If ABI, load from the param symbol
1138         SDValue Arg = getParamSymbol(DAG, idx);
1139         // Conjure up a value that we can get the address space from.
1140         // FIXME: Using a constant here is a hack.
1141         Value *srcValue = Constant::getNullValue(
1142             PointerType::get(ObjectVT.getTypeForEVT(F->getContext()),
1143                              llvm::ADDRESS_SPACE_PARAM));
1144         SDValue p = DAG.getLoad(
1145             ObjectVT, dl, Root, Arg, MachinePointerInfo(srcValue), false, false,
1146             false,
1147             TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
1148         if (p.getNode())
1149           p.getNode()->setIROrder(idx + 1);
1150         InVals.push_back(p);
1151       } else {
1152         // If no ABI, just move the param symbol
1153         SDValue Arg = getParamSymbol(DAG, idx, ObjectVT);
1154         SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
1155         if (p.getNode())
1156           p.getNode()->setIROrder(idx + 1);
1157         InVals.push_back(p);
1158       }
1159       continue;
1160     }
1161
1162     // Param has ByVal attribute
1163     if (isABI || isKernel) {
1164       // Return MoveParam(param symbol).
1165       // Ideally, the param symbol can be returned directly,
1166       // but when SDNode builder decides to use it in a CopyToReg(),
1167       // machine instruction fails because TargetExternalSymbol
1168       // (not lowered) is target dependent, and CopyToReg assumes
1169       // the source is lowered.
1170       SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1171       SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
1172       if (p.getNode())
1173         p.getNode()->setIROrder(idx + 1);
1174       if (isKernel)
1175         InVals.push_back(p);
1176       else {
1177         SDValue p2 = DAG.getNode(
1178             ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
1179             DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p);
1180         InVals.push_back(p2);
1181       }
1182     } else {
1183       // Have to move a set of param symbols to registers and
1184       // store them locally and return the local pointer in InVals
1185       const PointerType *elemPtrType = dyn_cast<PointerType>(argTypes[i]);
1186       assert(elemPtrType && "Byval parameter should be a pointer type");
1187       Type *elemType = elemPtrType->getElementType();
1188       // Compute the constituent parts
1189       SmallVector<EVT, 16> vtparts;
1190       SmallVector<uint64_t, 16> offsets;
1191       ComputeValueVTs(*this, elemType, vtparts, &offsets, 0);
1192       unsigned totalsize = 0;
1193       for (unsigned j = 0, je = vtparts.size(); j != je; ++j)
1194         totalsize += vtparts[j].getStoreSizeInBits();
1195       SDValue localcopy = DAG.getFrameIndex(
1196           MF.getFrameInfo()->CreateStackObject(totalsize / 8, 16, false),
1197           getPointerTy());
1198       unsigned sizesofar = 0;
1199       std::vector<SDValue> theChains;
1200       for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
1201         unsigned numElems = 1;
1202         if (vtparts[j].isVector())
1203           numElems = vtparts[j].getVectorNumElements();
1204         for (unsigned k = 0, ke = numElems; k != ke; ++k) {
1205           EVT tmpvt = vtparts[j];
1206           if (tmpvt.isVector())
1207             tmpvt = tmpvt.getVectorElementType();
1208           SDValue arg = DAG.getNode(NVPTXISD::MoveParam, dl, tmpvt,
1209                                     getParamSymbol(DAG, idx, tmpvt));
1210           SDValue addr =
1211               DAG.getNode(ISD::ADD, dl, getPointerTy(), localcopy,
1212                           DAG.getConstant(sizesofar, getPointerTy()));
1213           theChains.push_back(DAG.getStore(
1214               Chain, dl, arg, addr, MachinePointerInfo(), false, false, 0));
1215           sizesofar += tmpvt.getStoreSizeInBits() / 8;
1216           ++idx;
1217         }
1218       }
1219       --idx;
1220       Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &theChains[0],
1221                           theChains.size());
1222       InVals.push_back(localcopy);
1223     }
1224   }
1225
1226   // Clang will check explicit VarArg and issue error if any. However, Clang
1227   // will let code with
1228   // implicit var arg like f() pass.
1229   // We treat this case as if the arg list is empty.
1230   //if (F.isVarArg()) {
1231   // assert(0 && "VarArg not supported yet!");
1232   //}
1233
1234   if (!OutChains.empty())
1235     DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &OutChains[0],
1236                             OutChains.size()));
1237
1238   return Chain;
1239 }
1240
1241 SDValue NVPTXTargetLowering::LowerReturn(
1242     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
1243     const SmallVectorImpl<ISD::OutputArg> &Outs,
1244     const SmallVectorImpl<SDValue> &OutVals, SDLoc dl,
1245     SelectionDAG &DAG) const {
1246
1247   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1248
1249   unsigned sizesofar = 0;
1250   unsigned idx = 0;
1251   for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
1252     SDValue theVal = OutVals[i];
1253     EVT theValType = theVal.getValueType();
1254     unsigned numElems = 1;
1255     if (theValType.isVector())
1256       numElems = theValType.getVectorNumElements();
1257     for (unsigned j = 0, je = numElems; j != je; ++j) {
1258       SDValue tmpval = theVal;
1259       if (theValType.isVector())
1260         tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
1261                              theValType.getVectorElementType(), tmpval,
1262                              DAG.getIntPtrConstant(j));
1263       Chain = DAG.getNode(
1264           isABI ? NVPTXISD::StoreRetval : NVPTXISD::MoveToRetval, dl,
1265           MVT::Other, Chain, DAG.getConstant(isABI ? sizesofar : idx, MVT::i32),
1266           tmpval);
1267       if (theValType.isVector())
1268         sizesofar += theValType.getVectorElementType().getStoreSizeInBits() / 8;
1269       else
1270         sizesofar += theValType.getStoreSizeInBits() / 8;
1271       ++idx;
1272     }
1273   }
1274
1275   return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
1276 }
1277
1278 void NVPTXTargetLowering::LowerAsmOperandForConstraint(
1279     SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops,
1280     SelectionDAG &DAG) const {
1281   if (Constraint.length() > 1)
1282     return;
1283   else
1284     TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
1285 }
1286
1287 // NVPTX suuport vector of legal types of any length in Intrinsics because the
1288 // NVPTX specific type legalizer
1289 // will legalize them to the PTX supported length.
1290 bool NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
1291   if (isTypeLegal(VT))
1292     return true;
1293   if (VT.isVector()) {
1294     MVT eVT = VT.getVectorElementType();
1295     if (isTypeLegal(eVT))
1296       return true;
1297   }
1298   return false;
1299 }
1300
1301 // llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
1302 // TgtMemIntrinsic
1303 // because we need the information that is only available in the "Value" type
1304 // of destination
1305 // pointer. In particular, the address space information.
1306 bool NVPTXTargetLowering::getTgtMemIntrinsic(
1307     IntrinsicInfo &Info, const CallInst &I, unsigned Intrinsic) const {
1308   switch (Intrinsic) {
1309   default:
1310     return false;
1311
1312   case Intrinsic::nvvm_atomic_load_add_f32:
1313     Info.opc = ISD::INTRINSIC_W_CHAIN;
1314     Info.memVT = MVT::f32;
1315     Info.ptrVal = I.getArgOperand(0);
1316     Info.offset = 0;
1317     Info.vol = 0;
1318     Info.readMem = true;
1319     Info.writeMem = true;
1320     Info.align = 0;
1321     return true;
1322
1323   case Intrinsic::nvvm_atomic_load_inc_32:
1324   case Intrinsic::nvvm_atomic_load_dec_32:
1325     Info.opc = ISD::INTRINSIC_W_CHAIN;
1326     Info.memVT = MVT::i32;
1327     Info.ptrVal = I.getArgOperand(0);
1328     Info.offset = 0;
1329     Info.vol = 0;
1330     Info.readMem = true;
1331     Info.writeMem = true;
1332     Info.align = 0;
1333     return true;
1334
1335   case Intrinsic::nvvm_ldu_global_i:
1336   case Intrinsic::nvvm_ldu_global_f:
1337   case Intrinsic::nvvm_ldu_global_p:
1338
1339     Info.opc = ISD::INTRINSIC_W_CHAIN;
1340     if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
1341       Info.memVT = MVT::i32;
1342     else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
1343       Info.memVT = getPointerTy();
1344     else
1345       Info.memVT = MVT::f32;
1346     Info.ptrVal = I.getArgOperand(0);
1347     Info.offset = 0;
1348     Info.vol = 0;
1349     Info.readMem = true;
1350     Info.writeMem = false;
1351     Info.align = 0;
1352     return true;
1353
1354   }
1355   return false;
1356 }
1357
1358 /// isLegalAddressingMode - Return true if the addressing mode represented
1359 /// by AM is legal for this target, for a load/store of the specified type.
1360 /// Used to guide target specific optimizations, like loop strength reduction
1361 /// (LoopStrengthReduce.cpp) and memory optimization for address mode
1362 /// (CodeGenPrepare.cpp)
1363 bool NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
1364                                                 Type *Ty) const {
1365
1366   // AddrMode - This represents an addressing mode of:
1367   //    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
1368   //
1369   // The legal address modes are
1370   // - [avar]
1371   // - [areg]
1372   // - [areg+immoff]
1373   // - [immAddr]
1374
1375   if (AM.BaseGV) {
1376     if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
1377       return false;
1378     return true;
1379   }
1380
1381   switch (AM.Scale) {
1382   case 0: // "r", "r+i" or "i" is allowed
1383     break;
1384   case 1:
1385     if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
1386       return false;
1387     // Otherwise we have r+i.
1388     break;
1389   default:
1390     // No scale > 1 is allowed
1391     return false;
1392   }
1393   return true;
1394 }
1395
1396 //===----------------------------------------------------------------------===//
1397 //                         NVPTX Inline Assembly Support
1398 //===----------------------------------------------------------------------===//
1399
1400 /// getConstraintType - Given a constraint letter, return the type of
1401 /// constraint it is for this target.
1402 NVPTXTargetLowering::ConstraintType
1403 NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
1404   if (Constraint.size() == 1) {
1405     switch (Constraint[0]) {
1406     default:
1407       break;
1408     case 'r':
1409     case 'h':
1410     case 'c':
1411     case 'l':
1412     case 'f':
1413     case 'd':
1414     case '0':
1415     case 'N':
1416       return C_RegisterClass;
1417     }
1418   }
1419   return TargetLowering::getConstraintType(Constraint);
1420 }
1421
1422 std::pair<unsigned, const TargetRegisterClass *>
1423 NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
1424                                                   MVT VT) const {
1425   if (Constraint.size() == 1) {
1426     switch (Constraint[0]) {
1427     case 'c':
1428       return std::make_pair(0U, &NVPTX::Int8RegsRegClass);
1429     case 'h':
1430       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
1431     case 'r':
1432       return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
1433     case 'l':
1434     case 'N':
1435       return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
1436     case 'f':
1437       return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
1438     case 'd':
1439       return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
1440     }
1441   }
1442   return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
1443 }
1444
1445 /// getFunctionAlignment - Return the Log2 alignment of this function.
1446 unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
1447   return 4;
1448 }
1449
1450 /// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
1451 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
1452                               SmallVectorImpl<SDValue> &Results) {
1453   EVT ResVT = N->getValueType(0);
1454   SDLoc DL(N);
1455
1456   assert(ResVT.isVector() && "Vector load must have vector type");
1457
1458   // We only handle "native" vector sizes for now, e.g. <4 x double> is not
1459   // legal.  We can (and should) split that into 2 loads of <2 x double> here
1460   // but I'm leaving that as a TODO for now.
1461   assert(ResVT.isSimple() && "Can only handle simple types");
1462   switch (ResVT.getSimpleVT().SimpleTy) {
1463   default:
1464     return;
1465   case MVT::v2i8:
1466   case MVT::v2i16:
1467   case MVT::v2i32:
1468   case MVT::v2i64:
1469   case MVT::v2f32:
1470   case MVT::v2f64:
1471   case MVT::v4i8:
1472   case MVT::v4i16:
1473   case MVT::v4i32:
1474   case MVT::v4f32:
1475     // This is a "native" vector type
1476     break;
1477   }
1478
1479   EVT EltVT = ResVT.getVectorElementType();
1480   unsigned NumElts = ResVT.getVectorNumElements();
1481
1482   // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
1483   // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
1484   // loaded type to i16 and propogate the "real" type as the memory type.
1485   bool NeedTrunc = false;
1486   if (EltVT.getSizeInBits() < 16) {
1487     EltVT = MVT::i16;
1488     NeedTrunc = true;
1489   }
1490
1491   unsigned Opcode = 0;
1492   SDVTList LdResVTs;
1493
1494   switch (NumElts) {
1495   default:
1496     return;
1497   case 2:
1498     Opcode = NVPTXISD::LoadV2;
1499     LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
1500     break;
1501   case 4: {
1502     Opcode = NVPTXISD::LoadV4;
1503     EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
1504     LdResVTs = DAG.getVTList(ListVTs, 5);
1505     break;
1506   }
1507   }
1508
1509   SmallVector<SDValue, 8> OtherOps;
1510
1511   // Copy regular operands
1512   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
1513     OtherOps.push_back(N->getOperand(i));
1514
1515   LoadSDNode *LD = cast<LoadSDNode>(N);
1516
1517   // The select routine does not have access to the LoadSDNode instance, so
1518   // pass along the extension information
1519   OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType()));
1520
1521   SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, &OtherOps[0],
1522                                           OtherOps.size(), LD->getMemoryVT(),
1523                                           LD->getMemOperand());
1524
1525   SmallVector<SDValue, 4> ScalarRes;
1526
1527   for (unsigned i = 0; i < NumElts; ++i) {
1528     SDValue Res = NewLD.getValue(i);
1529     if (NeedTrunc)
1530       Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
1531     ScalarRes.push_back(Res);
1532   }
1533
1534   SDValue LoadChain = NewLD.getValue(NumElts);
1535
1536   SDValue BuildVec =
1537       DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
1538
1539   Results.push_back(BuildVec);
1540   Results.push_back(LoadChain);
1541 }
1542
1543 static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
1544                                      SmallVectorImpl<SDValue> &Results) {
1545   SDValue Chain = N->getOperand(0);
1546   SDValue Intrin = N->getOperand(1);
1547   SDLoc DL(N);
1548
1549   // Get the intrinsic ID
1550   unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
1551   switch (IntrinNo) {
1552   default:
1553     return;
1554   case Intrinsic::nvvm_ldg_global_i:
1555   case Intrinsic::nvvm_ldg_global_f:
1556   case Intrinsic::nvvm_ldg_global_p:
1557   case Intrinsic::nvvm_ldu_global_i:
1558   case Intrinsic::nvvm_ldu_global_f:
1559   case Intrinsic::nvvm_ldu_global_p: {
1560     EVT ResVT = N->getValueType(0);
1561
1562     if (ResVT.isVector()) {
1563       // Vector LDG/LDU
1564
1565       unsigned NumElts = ResVT.getVectorNumElements();
1566       EVT EltVT = ResVT.getVectorElementType();
1567
1568       // Since LDU/LDG are target nodes, we cannot rely on DAG type legalization.
1569       // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
1570       // loaded type to i16 and propogate the "real" type as the memory type.
1571       bool NeedTrunc = false;
1572       if (EltVT.getSizeInBits() < 16) {
1573         EltVT = MVT::i16;
1574         NeedTrunc = true;
1575       }
1576
1577       unsigned Opcode = 0;
1578       SDVTList LdResVTs;
1579
1580       switch (NumElts) {
1581       default:
1582         return;
1583       case 2:
1584         switch (IntrinNo) {
1585         default:
1586           return;
1587         case Intrinsic::nvvm_ldg_global_i:
1588         case Intrinsic::nvvm_ldg_global_f:
1589         case Intrinsic::nvvm_ldg_global_p:
1590           Opcode = NVPTXISD::LDGV2;
1591           break;
1592         case Intrinsic::nvvm_ldu_global_i:
1593         case Intrinsic::nvvm_ldu_global_f:
1594         case Intrinsic::nvvm_ldu_global_p:
1595           Opcode = NVPTXISD::LDUV2;
1596           break;
1597         }
1598         LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
1599         break;
1600       case 4: {
1601         switch (IntrinNo) {
1602         default:
1603           return;
1604         case Intrinsic::nvvm_ldg_global_i:
1605         case Intrinsic::nvvm_ldg_global_f:
1606         case Intrinsic::nvvm_ldg_global_p:
1607           Opcode = NVPTXISD::LDGV4;
1608           break;
1609         case Intrinsic::nvvm_ldu_global_i:
1610         case Intrinsic::nvvm_ldu_global_f:
1611         case Intrinsic::nvvm_ldu_global_p:
1612           Opcode = NVPTXISD::LDUV4;
1613           break;
1614         }
1615         EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
1616         LdResVTs = DAG.getVTList(ListVTs, 5);
1617         break;
1618       }
1619       }
1620
1621       SmallVector<SDValue, 8> OtherOps;
1622
1623       // Copy regular operands
1624
1625       OtherOps.push_back(Chain); // Chain
1626                                  // Skip operand 1 (intrinsic ID)
1627                                  // Others
1628       for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i)
1629         OtherOps.push_back(N->getOperand(i));
1630
1631       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
1632
1633       SDValue NewLD = DAG.getMemIntrinsicNode(
1634           Opcode, DL, LdResVTs, &OtherOps[0], OtherOps.size(),
1635           MemSD->getMemoryVT(), MemSD->getMemOperand());
1636
1637       SmallVector<SDValue, 4> ScalarRes;
1638
1639       for (unsigned i = 0; i < NumElts; ++i) {
1640         SDValue Res = NewLD.getValue(i);
1641         if (NeedTrunc)
1642           Res =
1643               DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
1644         ScalarRes.push_back(Res);
1645       }
1646
1647       SDValue LoadChain = NewLD.getValue(NumElts);
1648
1649       SDValue BuildVec =
1650           DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
1651
1652       Results.push_back(BuildVec);
1653       Results.push_back(LoadChain);
1654     } else {
1655       // i8 LDG/LDU
1656       assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
1657              "Custom handling of non-i8 ldu/ldg?");
1658
1659       // Just copy all operands as-is
1660       SmallVector<SDValue, 4> Ops;
1661       for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
1662         Ops.push_back(N->getOperand(i));
1663
1664       // Force output to i16
1665       SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);
1666
1667       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
1668
1669       // We make sure the memory type is i8, which will be used during isel
1670       // to select the proper instruction.
1671       SDValue NewLD =
1672           DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, &Ops[0],
1673                                   Ops.size(), MVT::i8, MemSD->getMemOperand());
1674
1675       Results.push_back(NewLD.getValue(0));
1676       Results.push_back(NewLD.getValue(1));
1677     }
1678   }
1679   }
1680 }
1681
1682 void NVPTXTargetLowering::ReplaceNodeResults(
1683     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
1684   switch (N->getOpcode()) {
1685   default:
1686     report_fatal_error("Unhandled custom legalization");
1687   case ISD::LOAD:
1688     ReplaceLoadVector(N, DAG, Results);
1689     return;
1690   case ISD::INTRINSIC_W_CHAIN:
1691     ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
1692     return;
1693   }
1694 }