[NVPTX] Fix logic error in loading vector parameters of more than 4 components
[oota-llvm.git] / lib / Target / NVPTX / NVPTXISelLowering.cpp
1 //
2 //                     The LLVM Compiler Infrastructure
3 //
4 // This file is distributed under the University of Illinois Open Source
5 // License. See LICENSE.TXT for details.
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the interfaces that NVPTX uses to lower LLVM code into a
10 // selection DAG.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "NVPTXISelLowering.h"
15 #include "NVPTX.h"
16 #include "NVPTXTargetMachine.h"
17 #include "NVPTXTargetObjectFile.h"
18 #include "NVPTXUtilities.h"
19 #include "llvm/CodeGen/Analysis.h"
20 #include "llvm/CodeGen/MachineFrameInfo.h"
21 #include "llvm/CodeGen/MachineFunction.h"
22 #include "llvm/CodeGen/MachineInstrBuilder.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/GlobalValue.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/Intrinsics.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/MC/MCSectionELF.h"
32 #include "llvm/Support/CallSite.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <sstream>
38
39 #undef DEBUG_TYPE
40 #define DEBUG_TYPE "nvptx-lower"
41
42 using namespace llvm;
43
44 static unsigned int uniqueCallSite = 0;
45
46 static cl::opt<bool> sched4reg(
47     "nvptx-sched4reg",
48     cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false));
49
50 static bool IsPTXVectorType(MVT VT) {
51   switch (VT.SimpleTy) {
52   default:
53     return false;
54   case MVT::v2i1:
55   case MVT::v4i1:
56   case MVT::v2i8:
57   case MVT::v4i8:
58   case MVT::v2i16:
59   case MVT::v4i16:
60   case MVT::v2i32:
61   case MVT::v4i32:
62   case MVT::v2i64:
63   case MVT::v2f32:
64   case MVT::v4f32:
65   case MVT::v2f64:
66     return true;
67   }
68 }
69
70 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
71 /// EVTs that compose it.  Unlike ComputeValueVTs, this will break apart vectors
72 /// into their primitive components.
73 /// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
74 /// same number of types as the Ins/Outs arrays in LowerFormalArguments,
75 /// LowerCall, and LowerReturn.
76 static void ComputePTXValueVTs(const TargetLowering &TLI, Type *Ty,
77                                SmallVectorImpl<EVT> &ValueVTs,
78                                SmallVectorImpl<uint64_t> *Offsets = 0,
79                                uint64_t StartingOffset = 0) {
80   SmallVector<EVT, 16> TempVTs;
81   SmallVector<uint64_t, 16> TempOffsets;
82
83   ComputeValueVTs(TLI, Ty, TempVTs, &TempOffsets, StartingOffset);
84   for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
85     EVT VT = TempVTs[i];
86     uint64_t Off = TempOffsets[i];
87     if (VT.isVector())
88       for (unsigned j = 0, je = VT.getVectorNumElements(); j != je; ++j) {
89         ValueVTs.push_back(VT.getVectorElementType());
90         if (Offsets)
91           Offsets->push_back(Off+j*VT.getVectorElementType().getStoreSize());
92       }
93     else {
94       ValueVTs.push_back(VT);
95       if (Offsets)
96         Offsets->push_back(Off);
97     }
98   }
99 }
100
101 // NVPTXTargetLowering Constructor.
102 NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
103     : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM),
104       nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {
105
106   // always lower memset, memcpy, and memmove intrinsics to load/store
107   // instructions, rather
108   // then generating calls to memset, mempcy or memmove.
109   MaxStoresPerMemset = (unsigned) 0xFFFFFFFF;
110   MaxStoresPerMemcpy = (unsigned) 0xFFFFFFFF;
111   MaxStoresPerMemmove = (unsigned) 0xFFFFFFFF;
112
113   setBooleanContents(ZeroOrNegativeOneBooleanContent);
114
115   // Jump is Expensive. Don't create extra control flow for 'and', 'or'
116   // condition branches.
117   setJumpIsExpensive(true);
118
119   // By default, use the Source scheduling
120   if (sched4reg)
121     setSchedulingPreference(Sched::RegPressure);
122   else
123     setSchedulingPreference(Sched::Source);
124
125   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
126   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
127   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
128   addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
129   addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
130   addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
131
132   // Operations not directly supported by NVPTX.
133   setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
134   setOperationAction(ISD::BR_CC, MVT::f32, Expand);
135   setOperationAction(ISD::BR_CC, MVT::f64, Expand);
136   setOperationAction(ISD::BR_CC, MVT::i1, Expand);
137   setOperationAction(ISD::BR_CC, MVT::i8, Expand);
138   setOperationAction(ISD::BR_CC, MVT::i16, Expand);
139   setOperationAction(ISD::BR_CC, MVT::i32, Expand);
140   setOperationAction(ISD::BR_CC, MVT::i64, Expand);
141   // Some SIGN_EXTEND_INREG can be done using cvt instruction.
142   // For others we will expand to a SHL/SRA pair.
143   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal);
144   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Legal);
145   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Legal);
146   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8 , Legal);
147   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
148
149   if (nvptxSubtarget.hasROT64()) {
150     setOperationAction(ISD::ROTL, MVT::i64, Legal);
151     setOperationAction(ISD::ROTR, MVT::i64, Legal);
152   } else {
153     setOperationAction(ISD::ROTL, MVT::i64, Expand);
154     setOperationAction(ISD::ROTR, MVT::i64, Expand);
155   }
156   if (nvptxSubtarget.hasROT32()) {
157     setOperationAction(ISD::ROTL, MVT::i32, Legal);
158     setOperationAction(ISD::ROTR, MVT::i32, Legal);
159   } else {
160     setOperationAction(ISD::ROTL, MVT::i32, Expand);
161     setOperationAction(ISD::ROTR, MVT::i32, Expand);
162   }
163
164   setOperationAction(ISD::ROTL, MVT::i16, Expand);
165   setOperationAction(ISD::ROTR, MVT::i16, Expand);
166   setOperationAction(ISD::ROTL, MVT::i8, Expand);
167   setOperationAction(ISD::ROTR, MVT::i8, Expand);
168   setOperationAction(ISD::BSWAP, MVT::i16, Expand);
169   setOperationAction(ISD::BSWAP, MVT::i32, Expand);
170   setOperationAction(ISD::BSWAP, MVT::i64, Expand);
171
172   // Indirect branch is not supported.
173   // This also disables Jump Table creation.
174   setOperationAction(ISD::BR_JT, MVT::Other, Expand);
175   setOperationAction(ISD::BRIND, MVT::Other, Expand);
176
177   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
178   setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
179
180   // We want to legalize constant related memmove and memcopy
181   // intrinsics.
182   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
183
184   // Turn FP extload into load/fextend
185   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
186   // Turn FP truncstore into trunc + store.
187   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
188
189   // PTX does not support load / store predicate registers
190   setOperationAction(ISD::LOAD, MVT::i1, Custom);
191   setOperationAction(ISD::STORE, MVT::i1, Custom);
192
193   setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
194   setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
195   setTruncStoreAction(MVT::i64, MVT::i1, Expand);
196   setTruncStoreAction(MVT::i32, MVT::i1, Expand);
197   setTruncStoreAction(MVT::i16, MVT::i1, Expand);
198   setTruncStoreAction(MVT::i8, MVT::i1, Expand);
199
200   // This is legal in NVPTX
201   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
202   setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
203
204   // TRAP can be lowered to PTX trap
205   setOperationAction(ISD::TRAP, MVT::Other, Legal);
206
207   setOperationAction(ISD::ADDC, MVT::i64, Expand);
208   setOperationAction(ISD::ADDE, MVT::i64, Expand);
209
210   // Register custom handling for vector loads/stores
211   for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE;
212        ++i) {
213     MVT VT = (MVT::SimpleValueType) i;
214     if (IsPTXVectorType(VT)) {
215       setOperationAction(ISD::LOAD, VT, Custom);
216       setOperationAction(ISD::STORE, VT, Custom);
217       setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
218     }
219   }
220
221   // Custom handling for i8 intrinsics
222   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
223
224   setOperationAction(ISD::CTLZ, MVT::i16, Legal);
225   setOperationAction(ISD::CTLZ, MVT::i32, Legal);
226   setOperationAction(ISD::CTLZ, MVT::i64, Legal);
227   setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i16, Legal);
228   setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i32, Legal);
229   setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i64, Legal);
230   setOperationAction(ISD::CTTZ, MVT::i16, Expand);
231   setOperationAction(ISD::CTTZ, MVT::i32, Expand);
232   setOperationAction(ISD::CTTZ, MVT::i64, Expand);
233   setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i16, Expand);
234   setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i32, Expand);
235   setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i64, Expand);
236   setOperationAction(ISD::CTPOP, MVT::i16, Legal);
237   setOperationAction(ISD::CTPOP, MVT::i32, Legal);
238   setOperationAction(ISD::CTPOP, MVT::i64, Legal);
239
240   // Now deduce the information based on the above mentioned
241   // actions
242   computeRegisterProperties();
243 }
244
245 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
246   switch (Opcode) {
247   default:
248     return 0;
249   case NVPTXISD::CALL:
250     return "NVPTXISD::CALL";
251   case NVPTXISD::RET_FLAG:
252     return "NVPTXISD::RET_FLAG";
253   case NVPTXISD::Wrapper:
254     return "NVPTXISD::Wrapper";
255   case NVPTXISD::DeclareParam:
256     return "NVPTXISD::DeclareParam";
257   case NVPTXISD::DeclareScalarParam:
258     return "NVPTXISD::DeclareScalarParam";
259   case NVPTXISD::DeclareRet:
260     return "NVPTXISD::DeclareRet";
261   case NVPTXISD::DeclareRetParam:
262     return "NVPTXISD::DeclareRetParam";
263   case NVPTXISD::PrintCall:
264     return "NVPTXISD::PrintCall";
265   case NVPTXISD::LoadParam:
266     return "NVPTXISD::LoadParam";
267   case NVPTXISD::LoadParamV2:
268     return "NVPTXISD::LoadParamV2";
269   case NVPTXISD::LoadParamV4:
270     return "NVPTXISD::LoadParamV4";
271   case NVPTXISD::StoreParam:
272     return "NVPTXISD::StoreParam";
273   case NVPTXISD::StoreParamV2:
274     return "NVPTXISD::StoreParamV2";
275   case NVPTXISD::StoreParamV4:
276     return "NVPTXISD::StoreParamV4";
277   case NVPTXISD::StoreParamS32:
278     return "NVPTXISD::StoreParamS32";
279   case NVPTXISD::StoreParamU32:
280     return "NVPTXISD::StoreParamU32";
281   case NVPTXISD::CallArgBegin:
282     return "NVPTXISD::CallArgBegin";
283   case NVPTXISD::CallArg:
284     return "NVPTXISD::CallArg";
285   case NVPTXISD::LastCallArg:
286     return "NVPTXISD::LastCallArg";
287   case NVPTXISD::CallArgEnd:
288     return "NVPTXISD::CallArgEnd";
289   case NVPTXISD::CallVoid:
290     return "NVPTXISD::CallVoid";
291   case NVPTXISD::CallVal:
292     return "NVPTXISD::CallVal";
293   case NVPTXISD::CallSymbol:
294     return "NVPTXISD::CallSymbol";
295   case NVPTXISD::Prototype:
296     return "NVPTXISD::Prototype";
297   case NVPTXISD::MoveParam:
298     return "NVPTXISD::MoveParam";
299   case NVPTXISD::StoreRetval:
300     return "NVPTXISD::StoreRetval";
301   case NVPTXISD::StoreRetvalV2:
302     return "NVPTXISD::StoreRetvalV2";
303   case NVPTXISD::StoreRetvalV4:
304     return "NVPTXISD::StoreRetvalV4";
305   case NVPTXISD::PseudoUseParam:
306     return "NVPTXISD::PseudoUseParam";
307   case NVPTXISD::RETURN:
308     return "NVPTXISD::RETURN";
309   case NVPTXISD::CallSeqBegin:
310     return "NVPTXISD::CallSeqBegin";
311   case NVPTXISD::CallSeqEnd:
312     return "NVPTXISD::CallSeqEnd";
313   case NVPTXISD::LoadV2:
314     return "NVPTXISD::LoadV2";
315   case NVPTXISD::LoadV4:
316     return "NVPTXISD::LoadV4";
317   case NVPTXISD::LDGV2:
318     return "NVPTXISD::LDGV2";
319   case NVPTXISD::LDGV4:
320     return "NVPTXISD::LDGV4";
321   case NVPTXISD::LDUV2:
322     return "NVPTXISD::LDUV2";
323   case NVPTXISD::LDUV4:
324     return "NVPTXISD::LDUV4";
325   case NVPTXISD::StoreV2:
326     return "NVPTXISD::StoreV2";
327   case NVPTXISD::StoreV4:
328     return "NVPTXISD::StoreV4";
329   }
330 }
331
332 bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const {
333   return VT == MVT::i1;
334 }
335
336 SDValue
337 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
338   SDLoc dl(Op);
339   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
340   Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
341   return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
342 }
343
344 std::string
345 NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
346                                   const SmallVectorImpl<ISD::OutputArg> &Outs,
347                                   unsigned retAlignment,
348                                   const ImmutableCallSite *CS) const {
349
350   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
351   assert(isABI && "Non-ABI compilation is not supported");
352   if (!isABI)
353     return "";
354
355   std::stringstream O;
356   O << "prototype_" << uniqueCallSite << " : .callprototype ";
357
358   if (retTy->getTypeID() == Type::VoidTyID) {
359     O << "()";
360   } else {
361     O << "(";
362     if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
363       unsigned size = 0;
364       if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
365         size = ITy->getBitWidth();
366         if (size < 32)
367           size = 32;
368       } else {
369         assert(retTy->isFloatingPointTy() &&
370                "Floating point type expected here");
371         size = retTy->getPrimitiveSizeInBits();
372       }
373
374       O << ".param .b" << size << " _";
375     } else if (isa<PointerType>(retTy)) {
376       O << ".param .b" << getPointerTy().getSizeInBits() << " _";
377     } else {
378       if ((retTy->getTypeID() == Type::StructTyID) || isa<VectorType>(retTy)) {
379         SmallVector<EVT, 16> vtparts;
380         ComputeValueVTs(*this, retTy, vtparts);
381         unsigned totalsz = 0;
382         for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
383           unsigned elems = 1;
384           EVT elemtype = vtparts[i];
385           if (vtparts[i].isVector()) {
386             elems = vtparts[i].getVectorNumElements();
387             elemtype = vtparts[i].getVectorElementType();
388           }
389           // TODO: no need to loop
390           for (unsigned j = 0, je = elems; j != je; ++j) {
391             unsigned sz = elemtype.getSizeInBits();
392             if (elemtype.isInteger() && (sz < 8))
393               sz = 8;
394             totalsz += sz / 8;
395           }
396         }
397         O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
398       } else {
399         assert(false && "Unknown return type");
400       }
401     }
402     O << ") ";
403   }
404   O << "_ (";
405
406   bool first = true;
407   MVT thePointerTy = getPointerTy();
408
409   unsigned OIdx = 0;
410   for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
411     Type *Ty = Args[i].Ty;
412     if (!first) {
413       O << ", ";
414     }
415     first = false;
416
417     if (Outs[OIdx].Flags.isByVal() == false) {
418       if (Ty->isAggregateType() || Ty->isVectorTy()) {
419         unsigned align = 0;
420         const CallInst *CallI = cast<CallInst>(CS->getInstruction());
421         const DataLayout *TD = getDataLayout();
422         // +1 because index 0 is reserved for return type alignment
423         if (!llvm::getAlign(*CallI, i + 1, align))
424           align = TD->getABITypeAlignment(Ty);
425         unsigned sz = TD->getTypeAllocSize(Ty);
426         O << ".param .align " << align << " .b8 ";
427         O << "_";
428         O << "[" << sz << "]";
429         // update the index for Outs
430         SmallVector<EVT, 16> vtparts;
431         ComputeValueVTs(*this, Ty, vtparts);
432         if (unsigned len = vtparts.size())
433           OIdx += len - 1;
434         continue;
435       }
436        // i8 types in IR will be i16 types in SDAG
437       assert((getValueType(Ty) == Outs[OIdx].VT ||
438              (getValueType(Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
439              "type mismatch between callee prototype and arguments");
440       // scalar type
441       unsigned sz = 0;
442       if (isa<IntegerType>(Ty)) {
443         sz = cast<IntegerType>(Ty)->getBitWidth();
444         if (sz < 32)
445           sz = 32;
446       } else if (isa<PointerType>(Ty))
447         sz = thePointerTy.getSizeInBits();
448       else
449         sz = Ty->getPrimitiveSizeInBits();
450       O << ".param .b" << sz << " ";
451       O << "_";
452       continue;
453     }
454     const PointerType *PTy = dyn_cast<PointerType>(Ty);
455     assert(PTy && "Param with byval attribute should be a pointer type");
456     Type *ETy = PTy->getElementType();
457
458     unsigned align = Outs[OIdx].Flags.getByValAlign();
459     unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
460     O << ".param .align " << align << " .b8 ";
461     O << "_";
462     O << "[" << sz << "]";
463   }
464   O << ");";
465   return O.str();
466 }
467
468 unsigned
469 NVPTXTargetLowering::getArgumentAlignment(SDValue Callee,
470                                           const ImmutableCallSite *CS,
471                                           Type *Ty,
472                                           unsigned Idx) const {
473   const DataLayout *TD = getDataLayout();
474   unsigned align = 0;
475   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
476
477   if (Func) { // direct call
478     assert(CS->getCalledFunction() &&
479            "direct call cannot find callee");
480     if (!llvm::getAlign(*(CS->getCalledFunction()), Idx, align))
481       align = TD->getABITypeAlignment(Ty);
482   }
483   else { // indirect call
484     const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction());
485     if (!llvm::getAlign(*CallI, Idx, align))
486       align = TD->getABITypeAlignment(Ty);
487   }
488
489   return align;
490 }
491
492 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
493                                        SmallVectorImpl<SDValue> &InVals) const {
494   SelectionDAG &DAG = CLI.DAG;
495   SDLoc dl = CLI.DL;
496   SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
497   SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
498   SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
499   SDValue Chain = CLI.Chain;
500   SDValue Callee = CLI.Callee;
501   bool &isTailCall = CLI.IsTailCall;
502   ArgListTy &Args = CLI.Args;
503   Type *retTy = CLI.RetTy;
504   ImmutableCallSite *CS = CLI.CS;
505
506   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
507   assert(isABI && "Non-ABI compilation is not supported");
508   if (!isABI)
509     return Chain;
510   const DataLayout *TD = getDataLayout();
511   MachineFunction &MF = DAG.getMachineFunction();
512   const Function *F = MF.getFunction();
513
514   SDValue tempChain = Chain;
515   Chain =
516       DAG.getCALLSEQ_START(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
517                            dl);
518   SDValue InFlag = Chain.getValue(1);
519
520   unsigned paramCount = 0;
521   // Args.size() and Outs.size() need not match.
522   // Outs.size() will be larger
523   //   * if there is an aggregate argument with multiple fields (each field
524   //     showing up separately in Outs)
525   //   * if there is a vector argument with more than typical vector-length
526   //     elements (generally if more than 4) where each vector element is
527   //     individually present in Outs.
528   // So a different index should be used for indexing into Outs/OutVals.
529   // See similar issue in LowerFormalArguments.
530   unsigned OIdx = 0;
531   // Declare the .params or .reg need to pass values
532   // to the function
533   for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
534     EVT VT = Outs[OIdx].VT;
535     Type *Ty = Args[i].Ty;
536
537     if (Outs[OIdx].Flags.isByVal() == false) {
538       if (Ty->isAggregateType()) {
539         // aggregate
540         SmallVector<EVT, 16> vtparts;
541         ComputeValueVTs(*this, Ty, vtparts);
542
543         unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
544         // declare .param .align <align> .b8 .param<n>[<size>];
545         unsigned sz = TD->getTypeAllocSize(Ty);
546         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
547         SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
548                                       DAG.getConstant(paramCount, MVT::i32),
549                                       DAG.getConstant(sz, MVT::i32), InFlag };
550         Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
551                             DeclareParamOps, 5);
552         InFlag = Chain.getValue(1);
553         unsigned curOffset = 0;
554         for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
555           unsigned elems = 1;
556           EVT elemtype = vtparts[j];
557           if (vtparts[j].isVector()) {
558             elems = vtparts[j].getVectorNumElements();
559             elemtype = vtparts[j].getVectorElementType();
560           }
561           for (unsigned k = 0, ke = elems; k != ke; ++k) {
562             unsigned sz = elemtype.getSizeInBits();
563             if (elemtype.isInteger() && (sz < 8))
564               sz = 8;
565             SDValue StVal = OutVals[OIdx];
566             if (elemtype.getSizeInBits() < 16) {
567               StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
568             }
569             SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
570             SDValue CopyParamOps[] = { Chain,
571                                        DAG.getConstant(paramCount, MVT::i32),
572                                        DAG.getConstant(curOffset, MVT::i32),
573                                        StVal, InFlag };
574             Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
575                                             CopyParamVTs, &CopyParamOps[0], 5,
576                                             elemtype, MachinePointerInfo());
577             InFlag = Chain.getValue(1);
578             curOffset += sz / 8;
579             ++OIdx;
580           }
581         }
582         if (vtparts.size() > 0)
583           --OIdx;
584         ++paramCount;
585         continue;
586       }
587       if (Ty->isVectorTy()) {
588         EVT ObjectVT = getValueType(Ty);
589         unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
590         // declare .param .align <align> .b8 .param<n>[<size>];
591         unsigned sz = TD->getTypeAllocSize(Ty);
592         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
593         SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
594                                       DAG.getConstant(paramCount, MVT::i32),
595                                       DAG.getConstant(sz, MVT::i32), InFlag };
596         Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
597                             DeclareParamOps, 5);
598         InFlag = Chain.getValue(1);
599         unsigned NumElts = ObjectVT.getVectorNumElements();
600         EVT EltVT = ObjectVT.getVectorElementType();
601         EVT MemVT = EltVT;
602         bool NeedExtend = false;
603         if (EltVT.getSizeInBits() < 16) {
604           NeedExtend = true;
605           EltVT = MVT::i16;
606         }
607
608         // V1 store
609         if (NumElts == 1) {
610           SDValue Elt = OutVals[OIdx++];
611           if (NeedExtend)
612             Elt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt);
613
614           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
615           SDValue CopyParamOps[] = { Chain,
616                                      DAG.getConstant(paramCount, MVT::i32),
617                                      DAG.getConstant(0, MVT::i32), Elt,
618                                      InFlag };
619           Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
620                                           CopyParamVTs, &CopyParamOps[0], 5,
621                                           MemVT, MachinePointerInfo());
622           InFlag = Chain.getValue(1);
623         } else if (NumElts == 2) {
624           SDValue Elt0 = OutVals[OIdx++];
625           SDValue Elt1 = OutVals[OIdx++];
626           if (NeedExtend) {
627             Elt0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt0);
628             Elt1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt1);
629           }
630
631           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
632           SDValue CopyParamOps[] = { Chain,
633                                      DAG.getConstant(paramCount, MVT::i32),
634                                      DAG.getConstant(0, MVT::i32), Elt0, Elt1,
635                                      InFlag };
636           Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParamV2, dl,
637                                           CopyParamVTs, &CopyParamOps[0], 6,
638                                           MemVT, MachinePointerInfo());
639           InFlag = Chain.getValue(1);
640         } else {
641           unsigned curOffset = 0;
642           // V4 stores
643           // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
644           // the
645           // vector will be expanded to a power of 2 elements, so we know we can
646           // always round up to the next multiple of 4 when creating the vector
647           // stores.
648           // e.g.  4 elem => 1 st.v4
649           //       6 elem => 2 st.v4
650           //       8 elem => 2 st.v4
651           //      11 elem => 3 st.v4
652           unsigned VecSize = 4;
653           if (EltVT.getSizeInBits() == 64)
654             VecSize = 2;
655
656           // This is potentially only part of a vector, so assume all elements
657           // are packed together.
658           unsigned PerStoreOffset = MemVT.getStoreSizeInBits() / 8 * VecSize;
659
660           for (unsigned i = 0; i < NumElts; i += VecSize) {
661             // Get values
662             SDValue StoreVal;
663             SmallVector<SDValue, 8> Ops;
664             Ops.push_back(Chain);
665             Ops.push_back(DAG.getConstant(paramCount, MVT::i32));
666             Ops.push_back(DAG.getConstant(curOffset, MVT::i32));
667
668             unsigned Opc = NVPTXISD::StoreParamV2;
669
670             StoreVal = OutVals[OIdx++];
671             if (NeedExtend)
672               StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
673             Ops.push_back(StoreVal);
674
675             if (i + 1 < NumElts) {
676               StoreVal = OutVals[OIdx++];
677               if (NeedExtend)
678                 StoreVal =
679                     DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
680             } else {
681               StoreVal = DAG.getUNDEF(EltVT);
682             }
683             Ops.push_back(StoreVal);
684
685             if (VecSize == 4) {
686               Opc = NVPTXISD::StoreParamV4;
687               if (i + 2 < NumElts) {
688                 StoreVal = OutVals[OIdx++];
689                 if (NeedExtend)
690                   StoreVal =
691                       DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
692               } else {
693                 StoreVal = DAG.getUNDEF(EltVT);
694               }
695               Ops.push_back(StoreVal);
696
697               if (i + 3 < NumElts) {
698                 StoreVal = OutVals[OIdx++];
699                 if (NeedExtend)
700                   StoreVal =
701                       DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
702               } else {
703                 StoreVal = DAG.getUNDEF(EltVT);
704               }
705               Ops.push_back(StoreVal);
706             }
707
708             Ops.push_back(InFlag);
709
710             SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
711             Chain = DAG.getMemIntrinsicNode(Opc, dl, CopyParamVTs, &Ops[0],
712                                             Ops.size(), MemVT,
713                                             MachinePointerInfo());
714             InFlag = Chain.getValue(1);
715             curOffset += PerStoreOffset;
716           }
717         }
718         ++paramCount;
719         --OIdx;
720         continue;
721       }
722       // Plain scalar
723       // for ABI,    declare .param .b<size> .param<n>;
724       unsigned sz = VT.getSizeInBits();
725       bool needExtend = false;
726       if (VT.isInteger()) {
727         if (sz < 16)
728           needExtend = true;
729         if (sz < 32)
730           sz = 32;
731       }
732       SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
733       SDValue DeclareParamOps[] = { Chain,
734                                     DAG.getConstant(paramCount, MVT::i32),
735                                     DAG.getConstant(sz, MVT::i32),
736                                     DAG.getConstant(0, MVT::i32), InFlag };
737       Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
738                           DeclareParamOps, 5);
739       InFlag = Chain.getValue(1);
740       SDValue OutV = OutVals[OIdx];
741       if (needExtend) {
742         // zext/sext i1 to i16
743         unsigned opc = ISD::ZERO_EXTEND;
744         if (Outs[OIdx].Flags.isSExt())
745           opc = ISD::SIGN_EXTEND;
746         OutV = DAG.getNode(opc, dl, MVT::i16, OutV);
747       }
748       SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
749       SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
750                                  DAG.getConstant(0, MVT::i32), OutV, InFlag };
751
752       unsigned opcode = NVPTXISD::StoreParam;
753       if (Outs[OIdx].Flags.isZExt())
754         opcode = NVPTXISD::StoreParamU32;
755       else if (Outs[OIdx].Flags.isSExt())
756         opcode = NVPTXISD::StoreParamS32;
757       Chain = DAG.getMemIntrinsicNode(opcode, dl, CopyParamVTs, CopyParamOps, 5,
758                                       VT, MachinePointerInfo());
759
760       InFlag = Chain.getValue(1);
761       ++paramCount;
762       continue;
763     }
764     // struct or vector
765     SmallVector<EVT, 16> vtparts;
766     const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
767     assert(PTy && "Type of a byval parameter should be pointer");
768     ComputeValueVTs(*this, PTy->getElementType(), vtparts);
769
770     // declare .param .align <align> .b8 .param<n>[<size>];
771     unsigned sz = Outs[OIdx].Flags.getByValSize();
772     SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
773     // The ByValAlign in the Outs[OIdx].Flags is alway set at this point,
774     // so we don't need to worry about natural alignment or not.
775     // See TargetLowering::LowerCallTo().
776     SDValue DeclareParamOps[] = {
777       Chain, DAG.getConstant(Outs[OIdx].Flags.getByValAlign(), MVT::i32),
778       DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32),
779       InFlag
780     };
781     Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
782                         DeclareParamOps, 5);
783     InFlag = Chain.getValue(1);
784     unsigned curOffset = 0;
785     for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
786       unsigned elems = 1;
787       EVT elemtype = vtparts[j];
788       if (vtparts[j].isVector()) {
789         elems = vtparts[j].getVectorNumElements();
790         elemtype = vtparts[j].getVectorElementType();
791       }
792       for (unsigned k = 0, ke = elems; k != ke; ++k) {
793         unsigned sz = elemtype.getSizeInBits();
794         if (elemtype.isInteger() && (sz < 8))
795           sz = 8;
796         SDValue srcAddr =
797             DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
798                         DAG.getConstant(curOffset, getPointerTy()));
799         SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
800                                      MachinePointerInfo(), false, false, false,
801                                      0);
802         if (elemtype.getSizeInBits() < 16) {
803           theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
804         }
805         SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
806         SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
807                                    DAG.getConstant(curOffset, MVT::i32), theVal,
808                                    InFlag };
809         Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
810                                         CopyParamOps, 5, elemtype,
811                                         MachinePointerInfo());
812
813         InFlag = Chain.getValue(1);
814         curOffset += sz / 8;
815       }
816     }
817     ++paramCount;
818   }
819
820   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
821   unsigned retAlignment = 0;
822
823   // Handle Result
824   if (Ins.size() > 0) {
825     SmallVector<EVT, 16> resvtparts;
826     ComputeValueVTs(*this, retTy, resvtparts);
827
828     // Declare
829     //  .param .align 16 .b8 retval0[<size-in-bytes>], or
830     //  .param .b<size-in-bits> retval0
831     unsigned resultsz = TD->getTypeAllocSizeInBits(retTy);
832     if (retTy->isPrimitiveType() || retTy->isIntegerTy() ||
833         retTy->isPointerTy()) {
834       // Scalar needs to be at least 32bit wide
835       if (resultsz < 32)
836         resultsz = 32;
837       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
838       SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
839                                   DAG.getConstant(resultsz, MVT::i32),
840                                   DAG.getConstant(0, MVT::i32), InFlag };
841       Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
842                           DeclareRetOps, 5);
843       InFlag = Chain.getValue(1);
844     } else {
845       retAlignment = getArgumentAlignment(Callee, CS, retTy, 0);
846       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
847       SDValue DeclareRetOps[] = { Chain,
848                                   DAG.getConstant(retAlignment, MVT::i32),
849                                   DAG.getConstant(resultsz / 8, MVT::i32),
850                                   DAG.getConstant(0, MVT::i32), InFlag };
851       Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
852                           DeclareRetOps, 5);
853       InFlag = Chain.getValue(1);
854     }
855   }
856
857   if (!Func) {
858     // This is indirect function call case : PTX requires a prototype of the
859     // form
860     // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
861     // to be emitted, and the label has to used as the last arg of call
862     // instruction.
863     // The prototype is embedded in a string and put as the operand for an
864     // INLINEASM SDNode.
865     SDVTList InlineAsmVTs = DAG.getVTList(MVT::Other, MVT::Glue);
866     std::string proto_string =
867         getPrototype(retTy, Args, Outs, retAlignment, CS);
868     const char *asmstr = nvTM->getManagedStrPool()
869         ->getManagedString(proto_string.c_str())->c_str();
870     SDValue InlineAsmOps[] = {
871       Chain, DAG.getTargetExternalSymbol(asmstr, getPointerTy()),
872       DAG.getMDNode(0), DAG.getTargetConstant(0, MVT::i32), InFlag
873     };
874     Chain = DAG.getNode(ISD::INLINEASM, dl, InlineAsmVTs, InlineAsmOps, 5);
875     InFlag = Chain.getValue(1);
876   }
877   // Op to just print "call"
878   SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
879   SDValue PrintCallOps[] = {
880     Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, MVT::i32), InFlag
881   };
882   Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
883                       dl, PrintCallVTs, PrintCallOps, 3);
884   InFlag = Chain.getValue(1);
885
886   // Ops to print out the function name
887   SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
888   SDValue CallVoidOps[] = { Chain, Callee, InFlag };
889   Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3);
890   InFlag = Chain.getValue(1);
891
892   // Ops to print out the param list
893   SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
894   SDValue CallArgBeginOps[] = { Chain, InFlag };
895   Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
896                       CallArgBeginOps, 2);
897   InFlag = Chain.getValue(1);
898
899   for (unsigned i = 0, e = paramCount; i != e; ++i) {
900     unsigned opcode;
901     if (i == (e - 1))
902       opcode = NVPTXISD::LastCallArg;
903     else
904       opcode = NVPTXISD::CallArg;
905     SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
906     SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
907                              DAG.getConstant(i, MVT::i32), InFlag };
908     Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4);
909     InFlag = Chain.getValue(1);
910   }
911   SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
912   SDValue CallArgEndOps[] = { Chain, DAG.getConstant(Func ? 1 : 0, MVT::i32),
913                               InFlag };
914   Chain =
915       DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps, 3);
916   InFlag = Chain.getValue(1);
917
918   if (!Func) {
919     SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
920     SDValue PrototypeOps[] = { Chain, DAG.getConstant(uniqueCallSite, MVT::i32),
921                                InFlag };
922     Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3);
923     InFlag = Chain.getValue(1);
924   }
925
926   // Generate loads from param memory/moves from registers for result
927   if (Ins.size() > 0) {
928     unsigned resoffset = 0;
929     if (retTy && retTy->isVectorTy()) {
930       EVT ObjectVT = getValueType(retTy);
931       unsigned NumElts = ObjectVT.getVectorNumElements();
932       EVT EltVT = ObjectVT.getVectorElementType();
933       assert(nvTM->getTargetLowering()->getNumRegisters(F->getContext(),
934                                                         ObjectVT) == NumElts &&
935              "Vector was not scalarized");
936       unsigned sz = EltVT.getSizeInBits();
937       bool needTruncate = sz < 16 ? true : false;
938
939       if (NumElts == 1) {
940         // Just a simple load
941         std::vector<EVT> LoadRetVTs;
942         if (needTruncate) {
943           // If loading i1 result, generate
944           //   load i16
945           //   trunc i16 to i1
946           LoadRetVTs.push_back(MVT::i16);
947         } else
948           LoadRetVTs.push_back(EltVT);
949         LoadRetVTs.push_back(MVT::Other);
950         LoadRetVTs.push_back(MVT::Glue);
951         std::vector<SDValue> LoadRetOps;
952         LoadRetOps.push_back(Chain);
953         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
954         LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
955         LoadRetOps.push_back(InFlag);
956         SDValue retval = DAG.getMemIntrinsicNode(
957             NVPTXISD::LoadParam, dl,
958             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
959             LoadRetOps.size(), EltVT, MachinePointerInfo());
960         Chain = retval.getValue(1);
961         InFlag = retval.getValue(2);
962         SDValue Ret0 = retval;
963         if (needTruncate)
964           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Ret0);
965         InVals.push_back(Ret0);
966       } else if (NumElts == 2) {
967         // LoadV2
968         std::vector<EVT> LoadRetVTs;
969         if (needTruncate) {
970           // If loading i1 result, generate
971           //   load i16
972           //   trunc i16 to i1
973           LoadRetVTs.push_back(MVT::i16);
974           LoadRetVTs.push_back(MVT::i16);
975         } else {
976           LoadRetVTs.push_back(EltVT);
977           LoadRetVTs.push_back(EltVT);
978         }
979         LoadRetVTs.push_back(MVT::Other);
980         LoadRetVTs.push_back(MVT::Glue);
981         std::vector<SDValue> LoadRetOps;
982         LoadRetOps.push_back(Chain);
983         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
984         LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
985         LoadRetOps.push_back(InFlag);
986         SDValue retval = DAG.getMemIntrinsicNode(
987             NVPTXISD::LoadParamV2, dl,
988             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
989             LoadRetOps.size(), EltVT, MachinePointerInfo());
990         Chain = retval.getValue(2);
991         InFlag = retval.getValue(3);
992         SDValue Ret0 = retval.getValue(0);
993         SDValue Ret1 = retval.getValue(1);
994         if (needTruncate) {
995           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret0);
996           InVals.push_back(Ret0);
997           Ret1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret1);
998           InVals.push_back(Ret1);
999         } else {
1000           InVals.push_back(Ret0);
1001           InVals.push_back(Ret1);
1002         }
1003       } else {
1004         // Split into N LoadV4
1005         unsigned Ofst = 0;
1006         unsigned VecSize = 4;
1007         unsigned Opc = NVPTXISD::LoadParamV4;
1008         if (EltVT.getSizeInBits() == 64) {
1009           VecSize = 2;
1010           Opc = NVPTXISD::LoadParamV2;
1011         }
1012         EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
1013         for (unsigned i = 0; i < NumElts; i += VecSize) {
1014           SmallVector<EVT, 8> LoadRetVTs;
1015           if (needTruncate) {
1016             // If loading i1 result, generate
1017             //   load i16
1018             //   trunc i16 to i1
1019             for (unsigned j = 0; j < VecSize; ++j)
1020               LoadRetVTs.push_back(MVT::i16);
1021           } else {
1022             for (unsigned j = 0; j < VecSize; ++j)
1023               LoadRetVTs.push_back(EltVT);
1024           }
1025           LoadRetVTs.push_back(MVT::Other);
1026           LoadRetVTs.push_back(MVT::Glue);
1027           SmallVector<SDValue, 4> LoadRetOps;
1028           LoadRetOps.push_back(Chain);
1029           LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1030           LoadRetOps.push_back(DAG.getConstant(Ofst, MVT::i32));
1031           LoadRetOps.push_back(InFlag);
1032           SDValue retval = DAG.getMemIntrinsicNode(
1033               Opc, dl, DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()),
1034               &LoadRetOps[0], LoadRetOps.size(), EltVT, MachinePointerInfo());
1035           if (VecSize == 2) {
1036             Chain = retval.getValue(2);
1037             InFlag = retval.getValue(3);
1038           } else {
1039             Chain = retval.getValue(4);
1040             InFlag = retval.getValue(5);
1041           }
1042
1043           for (unsigned j = 0; j < VecSize; ++j) {
1044             if (i + j >= NumElts)
1045               break;
1046             SDValue Elt = retval.getValue(j);
1047             if (needTruncate)
1048               Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
1049             InVals.push_back(Elt);
1050           }
1051           Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1052         }
1053       }
1054     } else {
1055       SmallVector<EVT, 16> VTs;
1056       ComputePTXValueVTs(*this, retTy, VTs);
1057       assert(VTs.size() == Ins.size() && "Bad value decomposition");
1058       for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
1059         unsigned sz = VTs[i].getSizeInBits();
1060         bool needTruncate = sz < 8 ? true : false;
1061         if (VTs[i].isInteger() && (sz < 8))
1062           sz = 8;
1063
1064         SmallVector<EVT, 4> LoadRetVTs;
1065         EVT TheLoadType = VTs[i];
1066         if (retTy->isIntegerTy() &&
1067             TD->getTypeAllocSizeInBits(retTy) < 32) {
1068           // This is for integer types only, and specifically not for
1069           // aggregates.
1070           LoadRetVTs.push_back(MVT::i32);
1071           TheLoadType = MVT::i32;
1072         } else if (sz < 16) {
1073           // If loading i1/i8 result, generate
1074           //   load i8 (-> i16)
1075           //   trunc i16 to i1/i8
1076           LoadRetVTs.push_back(MVT::i16);
1077         } else
1078           LoadRetVTs.push_back(Ins[i].VT);
1079         LoadRetVTs.push_back(MVT::Other);
1080         LoadRetVTs.push_back(MVT::Glue);
1081
1082         SmallVector<SDValue, 4> LoadRetOps;
1083         LoadRetOps.push_back(Chain);
1084         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1085         LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32));
1086         LoadRetOps.push_back(InFlag);
1087         SDValue retval = DAG.getMemIntrinsicNode(
1088             NVPTXISD::LoadParam, dl,
1089             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
1090             LoadRetOps.size(), TheLoadType, MachinePointerInfo());
1091         Chain = retval.getValue(1);
1092         InFlag = retval.getValue(2);
1093         SDValue Ret0 = retval.getValue(0);
1094         if (needTruncate)
1095           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0);
1096         InVals.push_back(Ret0);
1097         resoffset += sz / 8;
1098       }
1099     }
1100   }
1101
1102   Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
1103                              DAG.getIntPtrConstant(uniqueCallSite + 1, true),
1104                              InFlag, dl);
1105   uniqueCallSite++;
1106
1107   // set isTailCall to false for now, until we figure out how to express
1108   // tail call optimization in PTX
1109   isTailCall = false;
1110   return Chain;
1111 }
1112
1113 // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
1114 // (see LegalizeDAG.cpp). This is slow and uses local memory.
1115 // We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
1116 SDValue
1117 NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
1118   SDNode *Node = Op.getNode();
1119   SDLoc dl(Node);
1120   SmallVector<SDValue, 8> Ops;
1121   unsigned NumOperands = Node->getNumOperands();
1122   for (unsigned i = 0; i < NumOperands; ++i) {
1123     SDValue SubOp = Node->getOperand(i);
1124     EVT VVT = SubOp.getNode()->getValueType(0);
1125     EVT EltVT = VVT.getVectorElementType();
1126     unsigned NumSubElem = VVT.getVectorNumElements();
1127     for (unsigned j = 0; j < NumSubElem; ++j) {
1128       Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
1129                                 DAG.getIntPtrConstant(j)));
1130     }
1131   }
1132   return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0), &Ops[0],
1133                      Ops.size());
1134 }
1135
1136 SDValue
1137 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
1138   switch (Op.getOpcode()) {
1139   case ISD::RETURNADDR:
1140     return SDValue();
1141   case ISD::FRAMEADDR:
1142     return SDValue();
1143   case ISD::GlobalAddress:
1144     return LowerGlobalAddress(Op, DAG);
1145   case ISD::INTRINSIC_W_CHAIN:
1146     return Op;
1147   case ISD::BUILD_VECTOR:
1148   case ISD::EXTRACT_SUBVECTOR:
1149     return Op;
1150   case ISD::CONCAT_VECTORS:
1151     return LowerCONCAT_VECTORS(Op, DAG);
1152   case ISD::STORE:
1153     return LowerSTORE(Op, DAG);
1154   case ISD::LOAD:
1155     return LowerLOAD(Op, DAG);
1156   default:
1157     llvm_unreachable("Custom lowering not defined for operation");
1158   }
1159 }
1160
1161 SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
1162   if (Op.getValueType() == MVT::i1)
1163     return LowerLOADi1(Op, DAG);
1164   else
1165     return SDValue();
1166 }
1167
1168 // v = ld i1* addr
1169 //   =>
1170 // v1 = ld i8* addr (-> i16)
1171 // v = trunc i16 to i1
1172 SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
1173   SDNode *Node = Op.getNode();
1174   LoadSDNode *LD = cast<LoadSDNode>(Node);
1175   SDLoc dl(Node);
1176   assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
1177   assert(Node->getValueType(0) == MVT::i1 &&
1178          "Custom lowering for i1 load only");
1179   SDValue newLD =
1180       DAG.getLoad(MVT::i16, dl, LD->getChain(), LD->getBasePtr(),
1181                   LD->getPointerInfo(), LD->isVolatile(), LD->isNonTemporal(),
1182                   LD->isInvariant(), LD->getAlignment());
1183   SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
1184   // The legalizer (the caller) is expecting two values from the legalized
1185   // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
1186   // in LegalizeDAG.cpp which also uses MergeValues.
1187   SDValue Ops[] = { result, LD->getChain() };
1188   return DAG.getMergeValues(Ops, 2, dl);
1189 }
1190
1191 SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
1192   EVT ValVT = Op.getOperand(1).getValueType();
1193   if (ValVT == MVT::i1)
1194     return LowerSTOREi1(Op, DAG);
1195   else if (ValVT.isVector())
1196     return LowerSTOREVector(Op, DAG);
1197   else
1198     return SDValue();
1199 }
1200
1201 SDValue
1202 NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
1203   SDNode *N = Op.getNode();
1204   SDValue Val = N->getOperand(1);
1205   SDLoc DL(N);
1206   EVT ValVT = Val.getValueType();
1207
1208   if (ValVT.isVector()) {
1209     // We only handle "native" vector sizes for now, e.g. <4 x double> is not
1210     // legal.  We can (and should) split that into 2 stores of <2 x double> here
1211     // but I'm leaving that as a TODO for now.
1212     if (!ValVT.isSimple())
1213       return SDValue();
1214     switch (ValVT.getSimpleVT().SimpleTy) {
1215     default:
1216       return SDValue();
1217     case MVT::v2i8:
1218     case MVT::v2i16:
1219     case MVT::v2i32:
1220     case MVT::v2i64:
1221     case MVT::v2f32:
1222     case MVT::v2f64:
1223     case MVT::v4i8:
1224     case MVT::v4i16:
1225     case MVT::v4i32:
1226     case MVT::v4f32:
1227       // This is a "native" vector type
1228       break;
1229     }
1230
1231     unsigned Opcode = 0;
1232     EVT EltVT = ValVT.getVectorElementType();
1233     unsigned NumElts = ValVT.getVectorNumElements();
1234
1235     // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
1236     // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
1237     // stored type to i16 and propogate the "real" type as the memory type.
1238     bool NeedExt = false;
1239     if (EltVT.getSizeInBits() < 16)
1240       NeedExt = true;
1241
1242     switch (NumElts) {
1243     default:
1244       return SDValue();
1245     case 2:
1246       Opcode = NVPTXISD::StoreV2;
1247       break;
1248     case 4: {
1249       Opcode = NVPTXISD::StoreV4;
1250       break;
1251     }
1252     }
1253
1254     SmallVector<SDValue, 8> Ops;
1255
1256     // First is the chain
1257     Ops.push_back(N->getOperand(0));
1258
1259     // Then the split values
1260     for (unsigned i = 0; i < NumElts; ++i) {
1261       SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
1262                                    DAG.getIntPtrConstant(i));
1263       if (NeedExt)
1264         ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
1265       Ops.push_back(ExtVal);
1266     }
1267
1268     // Then any remaining arguments
1269     for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) {
1270       Ops.push_back(N->getOperand(i));
1271     }
1272
1273     MemSDNode *MemSD = cast<MemSDNode>(N);
1274
1275     SDValue NewSt = DAG.getMemIntrinsicNode(
1276         Opcode, DL, DAG.getVTList(MVT::Other), &Ops[0], Ops.size(),
1277         MemSD->getMemoryVT(), MemSD->getMemOperand());
1278
1279     //return DCI.CombineTo(N, NewSt, true);
1280     return NewSt;
1281   }
1282
1283   return SDValue();
1284 }
1285
1286 // st i1 v, addr
1287 //    =>
1288 // v1 = zxt v to i16
1289 // st.u8 i16, addr
1290 SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
1291   SDNode *Node = Op.getNode();
1292   SDLoc dl(Node);
1293   StoreSDNode *ST = cast<StoreSDNode>(Node);
1294   SDValue Tmp1 = ST->getChain();
1295   SDValue Tmp2 = ST->getBasePtr();
1296   SDValue Tmp3 = ST->getValue();
1297   assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
1298   unsigned Alignment = ST->getAlignment();
1299   bool isVolatile = ST->isVolatile();
1300   bool isNonTemporal = ST->isNonTemporal();
1301   Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
1302   SDValue Result = DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2,
1303                                      ST->getPointerInfo(), MVT::i8, isNonTemporal,
1304                                      isVolatile, Alignment);
1305   return Result;
1306 }
1307
1308 SDValue NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname,
1309                                         int idx, EVT v) const {
1310   std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
1311   std::stringstream suffix;
1312   suffix << idx;
1313   *name += suffix.str();
1314   return DAG.getTargetExternalSymbol(name->c_str(), v);
1315 }
1316
1317 SDValue
1318 NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
1319   std::string ParamSym;
1320   raw_string_ostream ParamStr(ParamSym);
1321
1322   ParamStr << DAG.getMachineFunction().getName() << "_param_" << idx;
1323   ParamStr.flush();
1324
1325   std::string *SavedStr =
1326     nvTM->getManagedStrPool()->getManagedString(ParamSym.c_str());
1327   return DAG.getTargetExternalSymbol(SavedStr->c_str(), v);
1328 }
1329
1330 SDValue NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
1331   return getExtSymb(DAG, ".HLPPARAM", idx);
1332 }
1333
1334 // Check to see if the kernel argument is image*_t or sampler_t
1335
1336 bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
1337   static const char *const specialTypes[] = { "struct._image2d_t",
1338                                               "struct._image3d_t",
1339                                               "struct._sampler_t" };
1340
1341   const Type *Ty = arg->getType();
1342   const PointerType *PTy = dyn_cast<PointerType>(Ty);
1343
1344   if (!PTy)
1345     return false;
1346
1347   if (!context)
1348     return false;
1349
1350   const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
1351   const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : "";
1352
1353   for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
1354     if (TypeName == specialTypes[i])
1355       return true;
1356
1357   return false;
1358 }
1359
1360 SDValue NVPTXTargetLowering::LowerFormalArguments(
1361     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
1362     const SmallVectorImpl<ISD::InputArg> &Ins, SDLoc dl, SelectionDAG &DAG,
1363     SmallVectorImpl<SDValue> &InVals) const {
1364   MachineFunction &MF = DAG.getMachineFunction();
1365   const DataLayout *TD = getDataLayout();
1366
1367   const Function *F = MF.getFunction();
1368   const AttributeSet &PAL = F->getAttributes();
1369   const TargetLowering *TLI = nvTM->getTargetLowering();
1370
1371   SDValue Root = DAG.getRoot();
1372   std::vector<SDValue> OutChains;
1373
1374   bool isKernel = llvm::isKernelFunction(*F);
1375   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1376   assert(isABI && "Non-ABI compilation is not supported");
1377   if (!isABI)
1378     return Chain;
1379
1380   std::vector<Type *> argTypes;
1381   std::vector<const Argument *> theArgs;
1382   for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
1383        I != E; ++I) {
1384     theArgs.push_back(I);
1385     argTypes.push_back(I->getType());
1386   }
1387   // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
1388   // Ins.size() will be larger
1389   //   * if there is an aggregate argument with multiple fields (each field
1390   //     showing up separately in Ins)
1391   //   * if there is a vector argument with more than typical vector-length
1392   //     elements (generally if more than 4) where each vector element is
1393   //     individually present in Ins.
1394   // So a different index should be used for indexing into Ins.
1395   // See similar issue in LowerCall.
1396   unsigned InsIdx = 0;
1397
1398   int idx = 0;
1399   for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++idx, ++InsIdx) {
1400     Type *Ty = argTypes[i];
1401
1402     // If the kernel argument is image*_t or sampler_t, convert it to
1403     // a i32 constant holding the parameter position. This can later
1404     // matched in the AsmPrinter to output the correct mangled name.
1405     if (isImageOrSamplerVal(
1406             theArgs[i],
1407             (theArgs[i]->getParent() ? theArgs[i]->getParent()->getParent()
1408                                      : 0))) {
1409       assert(isKernel && "Only kernels can have image/sampler params");
1410       InVals.push_back(DAG.getConstant(i + 1, MVT::i32));
1411       continue;
1412     }
1413
1414     if (theArgs[i]->use_empty()) {
1415       // argument is dead
1416       if (Ty->isAggregateType()) {
1417         SmallVector<EVT, 16> vtparts;
1418
1419         ComputePTXValueVTs(*this, Ty, vtparts);
1420         assert(vtparts.size() > 0 && "empty aggregate type not expected");
1421         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
1422              ++parti) {
1423           EVT partVT = vtparts[parti];
1424           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, partVT));
1425           ++InsIdx;
1426         }
1427         if (vtparts.size() > 0)
1428           --InsIdx;
1429         continue;
1430       }
1431       if (Ty->isVectorTy()) {
1432         EVT ObjectVT = getValueType(Ty);
1433         unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
1434         for (unsigned parti = 0; parti < NumRegs; ++parti) {
1435           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
1436           ++InsIdx;
1437         }
1438         if (NumRegs > 0)
1439           --InsIdx;
1440         continue;
1441       }
1442       InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
1443       continue;
1444     }
1445
1446     // In the following cases, assign a node order of "idx+1"
1447     // to newly created nodes. The SDNodes for params have to
1448     // appear in the same order as their order of appearance
1449     // in the original function. "idx+1" holds that order.
1450     if (PAL.hasAttribute(i + 1, Attribute::ByVal) == false) {
1451       if (Ty->isAggregateType()) {
1452         SmallVector<EVT, 16> vtparts;
1453         SmallVector<uint64_t, 16> offsets;
1454
1455         // NOTE: Here, we lose the ability to issue vector loads for vectors
1456         // that are a part of a struct.  This should be investigated in the
1457         // future.
1458         ComputePTXValueVTs(*this, Ty, vtparts, &offsets, 0);
1459         assert(vtparts.size() > 0 && "empty aggregate type not expected");
1460         bool aggregateIsPacked = false;
1461         if (StructType *STy = llvm::dyn_cast<StructType>(Ty))
1462           aggregateIsPacked = STy->isPacked();
1463
1464         SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1465         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
1466              ++parti) {
1467           EVT partVT = vtparts[parti];
1468           Value *srcValue = Constant::getNullValue(
1469               PointerType::get(partVT.getTypeForEVT(F->getContext()),
1470                                llvm::ADDRESS_SPACE_PARAM));
1471           SDValue srcAddr =
1472               DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1473                           DAG.getConstant(offsets[parti], getPointerTy()));
1474           unsigned partAlign =
1475               aggregateIsPacked ? 1
1476                                 : TD->getABITypeAlignment(
1477                                       partVT.getTypeForEVT(F->getContext()));
1478           SDValue p;
1479           if (Ins[InsIdx].VT.getSizeInBits() > partVT.getSizeInBits()) {
1480             ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ? 
1481                                      ISD::SEXTLOAD : ISD::ZEXTLOAD;
1482             p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, srcAddr,
1483                                MachinePointerInfo(srcValue), partVT, false,
1484                                false, partAlign);
1485           } else {
1486             p = DAG.getLoad(partVT, dl, Root, srcAddr,
1487                             MachinePointerInfo(srcValue), false, false, false,
1488                             partAlign);
1489           }
1490           if (p.getNode())
1491             p.getNode()->setIROrder(idx + 1);
1492           InVals.push_back(p);
1493           ++InsIdx;
1494         }
1495         if (vtparts.size() > 0)
1496           --InsIdx;
1497         continue;
1498       }
1499       if (Ty->isVectorTy()) {
1500         EVT ObjectVT = getValueType(Ty);
1501         SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1502         unsigned NumElts = ObjectVT.getVectorNumElements();
1503         assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
1504                "Vector was not scalarized");
1505         unsigned Ofst = 0;
1506         EVT EltVT = ObjectVT.getVectorElementType();
1507
1508         // V1 load
1509         // f32 = load ...
1510         if (NumElts == 1) {
1511           // We only have one element, so just directly load it
1512           Value *SrcValue = Constant::getNullValue(PointerType::get(
1513               EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1514           SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1515                                         DAG.getConstant(Ofst, getPointerTy()));
1516           SDValue P = DAG.getLoad(
1517               EltVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1518               false, true,
1519               TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
1520           if (P.getNode())
1521             P.getNode()->setIROrder(idx + 1);
1522
1523           if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
1524             P = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, P);
1525           InVals.push_back(P);
1526           Ofst += TD->getTypeAllocSize(EltVT.getTypeForEVT(F->getContext()));
1527           ++InsIdx;
1528         } else if (NumElts == 2) {
1529           // V2 load
1530           // f32,f32 = load ...
1531           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2);
1532           Value *SrcValue = Constant::getNullValue(PointerType::get(
1533               VecVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1534           SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1535                                         DAG.getConstant(Ofst, getPointerTy()));
1536           SDValue P = DAG.getLoad(
1537               VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1538               false, true,
1539               TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
1540           if (P.getNode())
1541             P.getNode()->setIROrder(idx + 1);
1542
1543           SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1544                                      DAG.getIntPtrConstant(0));
1545           SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1546                                      DAG.getIntPtrConstant(1));
1547
1548           if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) {
1549             Elt0 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt0);
1550             Elt1 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt1);
1551           }
1552
1553           InVals.push_back(Elt0);
1554           InVals.push_back(Elt1);
1555           Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1556           InsIdx += 2;
1557         } else {
1558           // V4 loads
1559           // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
1560           // the
1561           // vector will be expanded to a power of 2 elements, so we know we can
1562           // always round up to the next multiple of 4 when creating the vector
1563           // loads.
1564           // e.g.  4 elem => 1 ld.v4
1565           //       6 elem => 2 ld.v4
1566           //       8 elem => 2 ld.v4
1567           //      11 elem => 3 ld.v4
1568           unsigned VecSize = 4;
1569           if (EltVT.getSizeInBits() == 64) {
1570             VecSize = 2;
1571           }
1572           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
1573           for (unsigned i = 0; i < NumElts; i += VecSize) {
1574             Value *SrcValue = Constant::getNullValue(
1575                 PointerType::get(VecVT.getTypeForEVT(F->getContext()),
1576                                  llvm::ADDRESS_SPACE_PARAM));
1577             SDValue SrcAddr =
1578                 DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1579                             DAG.getConstant(Ofst, getPointerTy()));
1580             SDValue P = DAG.getLoad(
1581                 VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1582                 false, true,
1583                 TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
1584             if (P.getNode())
1585               P.getNode()->setIROrder(idx + 1);
1586
1587             for (unsigned j = 0; j < VecSize; ++j) {
1588               if (i + j >= NumElts)
1589                 break;
1590               SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1591                                         DAG.getIntPtrConstant(j));
1592               if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
1593                 Elt = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt);
1594               InVals.push_back(Elt);
1595             }
1596             Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1597           }
1598           InsIdx += NumElts;
1599         }
1600
1601         if (NumElts > 0)
1602           --InsIdx;
1603         continue;
1604       }
1605       // A plain scalar.
1606       EVT ObjectVT = getValueType(Ty);
1607       // If ABI, load from the param symbol
1608       SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1609       Value *srcValue = Constant::getNullValue(PointerType::get(
1610           ObjectVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1611       SDValue p;
1612        if (ObjectVT.getSizeInBits() < Ins[InsIdx].VT.getSizeInBits()) {
1613         ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ? 
1614                                        ISD::SEXTLOAD : ISD::ZEXTLOAD;
1615         p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, Arg,
1616                            MachinePointerInfo(srcValue), ObjectVT, false, false,
1617         TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
1618       } else {
1619         p = DAG.getLoad(Ins[InsIdx].VT, dl, Root, Arg,
1620                         MachinePointerInfo(srcValue), false, false, false,
1621         TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
1622       }
1623       if (p.getNode())
1624         p.getNode()->setIROrder(idx + 1);
1625       InVals.push_back(p);
1626       continue;
1627     }
1628
1629     // Param has ByVal attribute
1630     // Return MoveParam(param symbol).
1631     // Ideally, the param symbol can be returned directly,
1632     // but when SDNode builder decides to use it in a CopyToReg(),
1633     // machine instruction fails because TargetExternalSymbol
1634     // (not lowered) is target dependent, and CopyToReg assumes
1635     // the source is lowered.
1636     EVT ObjectVT = getValueType(Ty);
1637     assert(ObjectVT == Ins[InsIdx].VT &&
1638            "Ins type did not match function type");
1639     SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1640     SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
1641     if (p.getNode())
1642       p.getNode()->setIROrder(idx + 1);
1643     if (isKernel)
1644       InVals.push_back(p);
1645     else {
1646       SDValue p2 = DAG.getNode(
1647           ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
1648           DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p);
1649       InVals.push_back(p2);
1650     }
1651   }
1652
1653   // Clang will check explicit VarArg and issue error if any. However, Clang
1654   // will let code with
1655   // implicit var arg like f() pass. See bug 617733.
1656   // We treat this case as if the arg list is empty.
1657   // if (F.isVarArg()) {
1658   // assert(0 && "VarArg not supported yet!");
1659   //}
1660
1661   if (!OutChains.empty())
1662     DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &OutChains[0],
1663                             OutChains.size()));
1664
1665   return Chain;
1666 }
1667
1668
1669 SDValue
1670 NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
1671                                  bool isVarArg,
1672                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
1673                                  const SmallVectorImpl<SDValue> &OutVals,
1674                                  SDLoc dl, SelectionDAG &DAG) const {
1675   MachineFunction &MF = DAG.getMachineFunction();
1676   const Function *F = MF.getFunction();
1677   Type *RetTy = F->getReturnType();
1678   const DataLayout *TD = getDataLayout();
1679
1680   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1681   assert(isABI && "Non-ABI compilation is not supported");
1682   if (!isABI)
1683     return Chain;
1684
1685   if (VectorType *VTy = dyn_cast<VectorType>(RetTy)) {
1686     // If we have a vector type, the OutVals array will be the scalarized
1687     // components and we have combine them into 1 or more vector stores.
1688     unsigned NumElts = VTy->getNumElements();
1689     assert(NumElts == Outs.size() && "Bad scalarization of return value");
1690
1691     // const_cast can be removed in later LLVM versions
1692     EVT EltVT = getValueType(RetTy).getVectorElementType();
1693     bool NeedExtend = false;
1694     if (EltVT.getSizeInBits() < 16)
1695       NeedExtend = true;
1696
1697     // V1 store
1698     if (NumElts == 1) {
1699       SDValue StoreVal = OutVals[0];
1700       // We only have one element, so just directly store it
1701       if (NeedExtend)
1702         StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
1703       SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal };
1704       Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
1705                                       DAG.getVTList(MVT::Other), &Ops[0], 3,
1706                                       EltVT, MachinePointerInfo());
1707
1708     } else if (NumElts == 2) {
1709       // V2 store
1710       SDValue StoreVal0 = OutVals[0];
1711       SDValue StoreVal1 = OutVals[1];
1712
1713       if (NeedExtend) {
1714         StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal0);
1715         StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal1);
1716       }
1717
1718       SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal0,
1719                         StoreVal1 };
1720       Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetvalV2, dl,
1721                                       DAG.getVTList(MVT::Other), &Ops[0], 4,
1722                                       EltVT, MachinePointerInfo());
1723     } else {
1724       // V4 stores
1725       // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the
1726       // vector will be expanded to a power of 2 elements, so we know we can
1727       // always round up to the next multiple of 4 when creating the vector
1728       // stores.
1729       // e.g.  4 elem => 1 st.v4
1730       //       6 elem => 2 st.v4
1731       //       8 elem => 2 st.v4
1732       //      11 elem => 3 st.v4
1733
1734       unsigned VecSize = 4;
1735       if (OutVals[0].getValueType().getSizeInBits() == 64)
1736         VecSize = 2;
1737
1738       unsigned Offset = 0;
1739
1740       EVT VecVT =
1741           EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize);
1742       unsigned PerStoreOffset =
1743           TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1744
1745       for (unsigned i = 0; i < NumElts; i += VecSize) {
1746         // Get values
1747         SDValue StoreVal;
1748         SmallVector<SDValue, 8> Ops;
1749         Ops.push_back(Chain);
1750         Ops.push_back(DAG.getConstant(Offset, MVT::i32));
1751         unsigned Opc = NVPTXISD::StoreRetvalV2;
1752         EVT ExtendedVT = (NeedExtend) ? MVT::i16 : OutVals[0].getValueType();
1753
1754         StoreVal = OutVals[i];
1755         if (NeedExtend)
1756           StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1757         Ops.push_back(StoreVal);
1758
1759         if (i + 1 < NumElts) {
1760           StoreVal = OutVals[i + 1];
1761           if (NeedExtend)
1762             StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1763         } else {
1764           StoreVal = DAG.getUNDEF(ExtendedVT);
1765         }
1766         Ops.push_back(StoreVal);
1767
1768         if (VecSize == 4) {
1769           Opc = NVPTXISD::StoreRetvalV4;
1770           if (i + 2 < NumElts) {
1771             StoreVal = OutVals[i + 2];
1772             if (NeedExtend)
1773               StoreVal =
1774                   DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1775           } else {
1776             StoreVal = DAG.getUNDEF(ExtendedVT);
1777           }
1778           Ops.push_back(StoreVal);
1779
1780           if (i + 3 < NumElts) {
1781             StoreVal = OutVals[i + 3];
1782             if (NeedExtend)
1783               StoreVal =
1784                   DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1785           } else {
1786             StoreVal = DAG.getUNDEF(ExtendedVT);
1787           }
1788           Ops.push_back(StoreVal);
1789         }
1790
1791         // Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size());
1792         Chain =
1793             DAG.getMemIntrinsicNode(Opc, dl, DAG.getVTList(MVT::Other), &Ops[0],
1794                                     Ops.size(), EltVT, MachinePointerInfo());
1795         Offset += PerStoreOffset;
1796       }
1797     }
1798   } else {
1799     SmallVector<EVT, 16> ValVTs;
1800     // const_cast is necessary since we are still using an LLVM version from
1801     // before the type system re-write.
1802     ComputePTXValueVTs(*this, RetTy, ValVTs);
1803     assert(ValVTs.size() == OutVals.size() && "Bad return value decomposition");
1804
1805     unsigned SizeSoFar = 0;
1806     for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
1807       SDValue theVal = OutVals[i];
1808       EVT TheValType = theVal.getValueType();
1809       unsigned numElems = 1;
1810       if (TheValType.isVector())
1811         numElems = TheValType.getVectorNumElements();
1812       for (unsigned j = 0, je = numElems; j != je; ++j) {
1813         SDValue TmpVal = theVal;
1814         if (TheValType.isVector())
1815           TmpVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
1816                                TheValType.getVectorElementType(), TmpVal,
1817                                DAG.getIntPtrConstant(j));
1818         EVT TheStoreType = ValVTs[i];
1819         if (RetTy->isIntegerTy() &&
1820             TD->getTypeAllocSizeInBits(RetTy) < 32) {
1821           // The following zero-extension is for integer types only, and
1822           // specifically not for aggregates.
1823           TmpVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, TmpVal);
1824           TheStoreType = MVT::i32;
1825         }
1826         else if (TmpVal.getValueType().getSizeInBits() < 16)
1827           TmpVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, TmpVal);
1828
1829         SDValue Ops[] = { Chain, DAG.getConstant(SizeSoFar, MVT::i32), TmpVal };
1830         Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
1831                                         DAG.getVTList(MVT::Other), &Ops[0],
1832                                         3, TheStoreType,
1833                                         MachinePointerInfo());
1834         if(TheValType.isVector())
1835           SizeSoFar += 
1836             TheStoreType.getVectorElementType().getStoreSizeInBits() / 8;
1837         else
1838           SizeSoFar += TheStoreType.getStoreSizeInBits()/8;
1839       }
1840     }
1841   }
1842
1843   return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
1844 }
1845
1846
1847 void NVPTXTargetLowering::LowerAsmOperandForConstraint(
1848     SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops,
1849     SelectionDAG &DAG) const {
1850   if (Constraint.length() > 1)
1851     return;
1852   else
1853     TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
1854 }
1855
1856 // NVPTX suuport vector of legal types of any length in Intrinsics because the
1857 // NVPTX specific type legalizer
1858 // will legalize them to the PTX supported length.
1859 bool NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
1860   if (isTypeLegal(VT))
1861     return true;
1862   if (VT.isVector()) {
1863     MVT eVT = VT.getVectorElementType();
1864     if (isTypeLegal(eVT))
1865       return true;
1866   }
1867   return false;
1868 }
1869
1870 // llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
1871 // TgtMemIntrinsic
1872 // because we need the information that is only available in the "Value" type
1873 // of destination
1874 // pointer. In particular, the address space information.
1875 bool NVPTXTargetLowering::getTgtMemIntrinsic(
1876     IntrinsicInfo &Info, const CallInst &I, unsigned Intrinsic) const {
1877   switch (Intrinsic) {
1878   default:
1879     return false;
1880
1881   case Intrinsic::nvvm_atomic_load_add_f32:
1882     Info.opc = ISD::INTRINSIC_W_CHAIN;
1883     Info.memVT = MVT::f32;
1884     Info.ptrVal = I.getArgOperand(0);
1885     Info.offset = 0;
1886     Info.vol = 0;
1887     Info.readMem = true;
1888     Info.writeMem = true;
1889     Info.align = 0;
1890     return true;
1891
1892   case Intrinsic::nvvm_atomic_load_inc_32:
1893   case Intrinsic::nvvm_atomic_load_dec_32:
1894     Info.opc = ISD::INTRINSIC_W_CHAIN;
1895     Info.memVT = MVT::i32;
1896     Info.ptrVal = I.getArgOperand(0);
1897     Info.offset = 0;
1898     Info.vol = 0;
1899     Info.readMem = true;
1900     Info.writeMem = true;
1901     Info.align = 0;
1902     return true;
1903
1904   case Intrinsic::nvvm_ldu_global_i:
1905   case Intrinsic::nvvm_ldu_global_f:
1906   case Intrinsic::nvvm_ldu_global_p:
1907
1908     Info.opc = ISD::INTRINSIC_W_CHAIN;
1909     if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
1910       Info.memVT = getValueType(I.getType());
1911     else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
1912       Info.memVT = getValueType(I.getType());
1913     else
1914       Info.memVT = MVT::f32;
1915     Info.ptrVal = I.getArgOperand(0);
1916     Info.offset = 0;
1917     Info.vol = 0;
1918     Info.readMem = true;
1919     Info.writeMem = false;
1920     Info.align = 0;
1921     return true;
1922
1923   }
1924   return false;
1925 }
1926
1927 /// isLegalAddressingMode - Return true if the addressing mode represented
1928 /// by AM is legal for this target, for a load/store of the specified type.
1929 /// Used to guide target specific optimizations, like loop strength reduction
1930 /// (LoopStrengthReduce.cpp) and memory optimization for address mode
1931 /// (CodeGenPrepare.cpp)
1932 bool NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
1933                                                 Type *Ty) const {
1934
1935   // AddrMode - This represents an addressing mode of:
1936   //    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
1937   //
1938   // The legal address modes are
1939   // - [avar]
1940   // - [areg]
1941   // - [areg+immoff]
1942   // - [immAddr]
1943
1944   if (AM.BaseGV) {
1945     if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
1946       return false;
1947     return true;
1948   }
1949
1950   switch (AM.Scale) {
1951   case 0: // "r", "r+i" or "i" is allowed
1952     break;
1953   case 1:
1954     if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
1955       return false;
1956     // Otherwise we have r+i.
1957     break;
1958   default:
1959     // No scale > 1 is allowed
1960     return false;
1961   }
1962   return true;
1963 }
1964
1965 //===----------------------------------------------------------------------===//
1966 //                         NVPTX Inline Assembly Support
1967 //===----------------------------------------------------------------------===//
1968
1969 /// getConstraintType - Given a constraint letter, return the type of
1970 /// constraint it is for this target.
1971 NVPTXTargetLowering::ConstraintType
1972 NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
1973   if (Constraint.size() == 1) {
1974     switch (Constraint[0]) {
1975     default:
1976       break;
1977     case 'r':
1978     case 'h':
1979     case 'c':
1980     case 'l':
1981     case 'f':
1982     case 'd':
1983     case '0':
1984     case 'N':
1985       return C_RegisterClass;
1986     }
1987   }
1988   return TargetLowering::getConstraintType(Constraint);
1989 }
1990
1991 std::pair<unsigned, const TargetRegisterClass *>
1992 NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
1993                                                   MVT VT) const {
1994   if (Constraint.size() == 1) {
1995     switch (Constraint[0]) {
1996     case 'c':
1997       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
1998     case 'h':
1999       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
2000     case 'r':
2001       return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
2002     case 'l':
2003     case 'N':
2004       return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
2005     case 'f':
2006       return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
2007     case 'd':
2008       return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
2009     }
2010   }
2011   return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
2012 }
2013
2014 /// getFunctionAlignment - Return the Log2 alignment of this function.
2015 unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
2016   return 4;
2017 }
2018
2019 /// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
2020 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
2021                               SmallVectorImpl<SDValue> &Results) {
2022   EVT ResVT = N->getValueType(0);
2023   SDLoc DL(N);
2024
2025   assert(ResVT.isVector() && "Vector load must have vector type");
2026
2027   // We only handle "native" vector sizes for now, e.g. <4 x double> is not
2028   // legal.  We can (and should) split that into 2 loads of <2 x double> here
2029   // but I'm leaving that as a TODO for now.
2030   assert(ResVT.isSimple() && "Can only handle simple types");
2031   switch (ResVT.getSimpleVT().SimpleTy) {
2032   default:
2033     return;
2034   case MVT::v2i8:
2035   case MVT::v2i16:
2036   case MVT::v2i32:
2037   case MVT::v2i64:
2038   case MVT::v2f32:
2039   case MVT::v2f64:
2040   case MVT::v4i8:
2041   case MVT::v4i16:
2042   case MVT::v4i32:
2043   case MVT::v4f32:
2044     // This is a "native" vector type
2045     break;
2046   }
2047
2048   EVT EltVT = ResVT.getVectorElementType();
2049   unsigned NumElts = ResVT.getVectorNumElements();
2050
2051   // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
2052   // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
2053   // loaded type to i16 and propogate the "real" type as the memory type.
2054   bool NeedTrunc = false;
2055   if (EltVT.getSizeInBits() < 16) {
2056     EltVT = MVT::i16;
2057     NeedTrunc = true;
2058   }
2059
2060   unsigned Opcode = 0;
2061   SDVTList LdResVTs;
2062
2063   switch (NumElts) {
2064   default:
2065     return;
2066   case 2:
2067     Opcode = NVPTXISD::LoadV2;
2068     LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
2069     break;
2070   case 4: {
2071     Opcode = NVPTXISD::LoadV4;
2072     EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
2073     LdResVTs = DAG.getVTList(ListVTs, 5);
2074     break;
2075   }
2076   }
2077
2078   SmallVector<SDValue, 8> OtherOps;
2079
2080   // Copy regular operands
2081   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2082     OtherOps.push_back(N->getOperand(i));
2083
2084   LoadSDNode *LD = cast<LoadSDNode>(N);
2085
2086   // The select routine does not have access to the LoadSDNode instance, so
2087   // pass along the extension information
2088   OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType()));
2089
2090   SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, &OtherOps[0],
2091                                           OtherOps.size(), LD->getMemoryVT(),
2092                                           LD->getMemOperand());
2093
2094   SmallVector<SDValue, 4> ScalarRes;
2095
2096   for (unsigned i = 0; i < NumElts; ++i) {
2097     SDValue Res = NewLD.getValue(i);
2098     if (NeedTrunc)
2099       Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
2100     ScalarRes.push_back(Res);
2101   }
2102
2103   SDValue LoadChain = NewLD.getValue(NumElts);
2104
2105   SDValue BuildVec =
2106       DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
2107
2108   Results.push_back(BuildVec);
2109   Results.push_back(LoadChain);
2110 }
2111
2112 static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
2113                                      SmallVectorImpl<SDValue> &Results) {
2114   SDValue Chain = N->getOperand(0);
2115   SDValue Intrin = N->getOperand(1);
2116   SDLoc DL(N);
2117
2118   // Get the intrinsic ID
2119   unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
2120   switch (IntrinNo) {
2121   default:
2122     return;
2123   case Intrinsic::nvvm_ldg_global_i:
2124   case Intrinsic::nvvm_ldg_global_f:
2125   case Intrinsic::nvvm_ldg_global_p:
2126   case Intrinsic::nvvm_ldu_global_i:
2127   case Intrinsic::nvvm_ldu_global_f:
2128   case Intrinsic::nvvm_ldu_global_p: {
2129     EVT ResVT = N->getValueType(0);
2130
2131     if (ResVT.isVector()) {
2132       // Vector LDG/LDU
2133
2134       unsigned NumElts = ResVT.getVectorNumElements();
2135       EVT EltVT = ResVT.getVectorElementType();
2136
2137       // Since LDU/LDG are target nodes, we cannot rely on DAG type
2138       // legalization.
2139       // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
2140       // loaded type to i16 and propogate the "real" type as the memory type.
2141       bool NeedTrunc = false;
2142       if (EltVT.getSizeInBits() < 16) {
2143         EltVT = MVT::i16;
2144         NeedTrunc = true;
2145       }
2146
2147       unsigned Opcode = 0;
2148       SDVTList LdResVTs;
2149
2150       switch (NumElts) {
2151       default:
2152         return;
2153       case 2:
2154         switch (IntrinNo) {
2155         default:
2156           return;
2157         case Intrinsic::nvvm_ldg_global_i:
2158         case Intrinsic::nvvm_ldg_global_f:
2159         case Intrinsic::nvvm_ldg_global_p:
2160           Opcode = NVPTXISD::LDGV2;
2161           break;
2162         case Intrinsic::nvvm_ldu_global_i:
2163         case Intrinsic::nvvm_ldu_global_f:
2164         case Intrinsic::nvvm_ldu_global_p:
2165           Opcode = NVPTXISD::LDUV2;
2166           break;
2167         }
2168         LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
2169         break;
2170       case 4: {
2171         switch (IntrinNo) {
2172         default:
2173           return;
2174         case Intrinsic::nvvm_ldg_global_i:
2175         case Intrinsic::nvvm_ldg_global_f:
2176         case Intrinsic::nvvm_ldg_global_p:
2177           Opcode = NVPTXISD::LDGV4;
2178           break;
2179         case Intrinsic::nvvm_ldu_global_i:
2180         case Intrinsic::nvvm_ldu_global_f:
2181         case Intrinsic::nvvm_ldu_global_p:
2182           Opcode = NVPTXISD::LDUV4;
2183           break;
2184         }
2185         EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
2186         LdResVTs = DAG.getVTList(ListVTs, 5);
2187         break;
2188       }
2189       }
2190
2191       SmallVector<SDValue, 8> OtherOps;
2192
2193       // Copy regular operands
2194
2195       OtherOps.push_back(Chain); // Chain
2196                                  // Skip operand 1 (intrinsic ID)
2197       // Others
2198       for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i)
2199         OtherOps.push_back(N->getOperand(i));
2200
2201       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
2202
2203       SDValue NewLD = DAG.getMemIntrinsicNode(
2204           Opcode, DL, LdResVTs, &OtherOps[0], OtherOps.size(),
2205           MemSD->getMemoryVT(), MemSD->getMemOperand());
2206
2207       SmallVector<SDValue, 4> ScalarRes;
2208
2209       for (unsigned i = 0; i < NumElts; ++i) {
2210         SDValue Res = NewLD.getValue(i);
2211         if (NeedTrunc)
2212           Res =
2213               DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
2214         ScalarRes.push_back(Res);
2215       }
2216
2217       SDValue LoadChain = NewLD.getValue(NumElts);
2218
2219       SDValue BuildVec =
2220           DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
2221
2222       Results.push_back(BuildVec);
2223       Results.push_back(LoadChain);
2224     } else {
2225       // i8 LDG/LDU
2226       assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
2227              "Custom handling of non-i8 ldu/ldg?");
2228
2229       // Just copy all operands as-is
2230       SmallVector<SDValue, 4> Ops;
2231       for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2232         Ops.push_back(N->getOperand(i));
2233
2234       // Force output to i16
2235       SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);
2236
2237       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
2238
2239       // We make sure the memory type is i8, which will be used during isel
2240       // to select the proper instruction.
2241       SDValue NewLD =
2242           DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, &Ops[0],
2243                                   Ops.size(), MVT::i8, MemSD->getMemOperand());
2244
2245       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
2246                                     NewLD.getValue(0)));
2247       Results.push_back(NewLD.getValue(1));
2248     }
2249   }
2250   }
2251 }
2252
2253 void NVPTXTargetLowering::ReplaceNodeResults(
2254     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
2255   switch (N->getOpcode()) {
2256   default:
2257     report_fatal_error("Unhandled custom legalization");
2258   case ISD::LOAD:
2259     ReplaceLoadVector(N, DAG, Results);
2260     return;
2261   case ISD::INTRINSIC_W_CHAIN:
2262     ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
2263     return;
2264   }
2265 }