Move all of the header files which are involved in modelling the LLVM IR
[oota-llvm.git] / lib / Target / NVPTX / NVPTXISelLowering.cpp
1 //
2 //                     The LLVM Compiler Infrastructure
3 //
4 // This file is distributed under the University of Illinois Open Source
5 // License. See LICENSE.TXT for details.
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the interfaces that NVPTX uses to lower LLVM code into a
10 // selection DAG.
11 //
12 //===----------------------------------------------------------------------===//
13
14
15 #include "NVPTXISelLowering.h"
16 #include "NVPTX.h"
17 #include "NVPTXTargetMachine.h"
18 #include "NVPTXTargetObjectFile.h"
19 #include "NVPTXUtilities.h"
20 #include "llvm/CodeGen/Analysis.h"
21 #include "llvm/CodeGen/MachineFrameInfo.h"
22 #include "llvm/CodeGen/MachineFunction.h"
23 #include "llvm/CodeGen/MachineInstrBuilder.h"
24 #include "llvm/CodeGen/MachineRegisterInfo.h"
25 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
26 #include "llvm/IR/DerivedTypes.h"
27 #include "llvm/IR/Function.h"
28 #include "llvm/IR/GlobalValue.h"
29 #include "llvm/IR/IntrinsicInst.h"
30 #include "llvm/IR/Intrinsics.h"
31 #include "llvm/IR/Module.h"
32 #include "llvm/MC/MCSectionELF.h"
33 #include "llvm/Support/CallSite.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/ErrorHandling.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <sstream>
39
40 #undef DEBUG_TYPE
41 #define DEBUG_TYPE "nvptx-lower"
42
43 using namespace llvm;
44
45 static unsigned int uniqueCallSite = 0;
46
47 static cl::opt<bool>
48 RetainVectorOperands("nvptx-codegen-vectors",
49      cl::desc("NVPTX Specific: Retain LLVM's vectors and generate PTX vectors"),
50                      cl::init(true));
51
52 static cl::opt<bool>
53 sched4reg("nvptx-sched4reg",
54           cl::desc("NVPTX Specific: schedule for register pressue"),
55           cl::init(false));
56
57 // NVPTXTargetLowering Constructor.
58 NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
59 : TargetLowering(TM, new NVPTXTargetObjectFile()),
60   nvTM(&TM),
61   nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {
62
63   // always lower memset, memcpy, and memmove intrinsics to load/store
64   // instructions, rather
65   // then generating calls to memset, mempcy or memmove.
66   maxStoresPerMemset = (unsigned)0xFFFFFFFF;
67   maxStoresPerMemcpy = (unsigned)0xFFFFFFFF;
68   maxStoresPerMemmove = (unsigned)0xFFFFFFFF;
69
70   setBooleanContents(ZeroOrNegativeOneBooleanContent);
71
72   // Jump is Expensive. Don't create extra control flow for 'and', 'or'
73   // condition branches.
74   setJumpIsExpensive(true);
75
76   // By default, use the Source scheduling
77   if (sched4reg)
78     setSchedulingPreference(Sched::RegPressure);
79   else
80     setSchedulingPreference(Sched::Source);
81
82   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
83   addRegisterClass(MVT::i8, &NVPTX::Int8RegsRegClass);
84   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
85   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
86   addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
87   addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
88   addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
89
90   if (RetainVectorOperands) {
91     addRegisterClass(MVT::v2f32, &NVPTX::V2F32RegsRegClass);
92     addRegisterClass(MVT::v4f32, &NVPTX::V4F32RegsRegClass);
93     addRegisterClass(MVT::v2i32, &NVPTX::V2I32RegsRegClass);
94     addRegisterClass(MVT::v4i32, &NVPTX::V4I32RegsRegClass);
95     addRegisterClass(MVT::v2f64, &NVPTX::V2F64RegsRegClass);
96     addRegisterClass(MVT::v2i64, &NVPTX::V2I64RegsRegClass);
97     addRegisterClass(MVT::v2i16, &NVPTX::V2I16RegsRegClass);
98     addRegisterClass(MVT::v4i16, &NVPTX::V4I16RegsRegClass);
99     addRegisterClass(MVT::v2i8, &NVPTX::V2I8RegsRegClass);
100     addRegisterClass(MVT::v4i8, &NVPTX::V4I8RegsRegClass);
101
102     setOperationAction(ISD::BUILD_VECTOR, MVT::v4i32  , Custom);
103     setOperationAction(ISD::BUILD_VECTOR, MVT::v4f32  , Custom);
104     setOperationAction(ISD::BUILD_VECTOR, MVT::v4i16  , Custom);
105     setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8   , Custom);
106     setOperationAction(ISD::BUILD_VECTOR, MVT::v2i64  , Custom);
107     setOperationAction(ISD::BUILD_VECTOR, MVT::v2f64  , Custom);
108     setOperationAction(ISD::BUILD_VECTOR, MVT::v2i32  , Custom);
109     setOperationAction(ISD::BUILD_VECTOR, MVT::v2f32  , Custom);
110     setOperationAction(ISD::BUILD_VECTOR, MVT::v2i16  , Custom);
111     setOperationAction(ISD::BUILD_VECTOR, MVT::v2i8   , Custom);
112
113     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v4i32  , Custom);
114     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v4f32  , Custom);
115     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v4i16  , Custom);
116     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v4i8   , Custom);
117     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v2i64  , Custom);
118     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v2f64  , Custom);
119     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v2i32  , Custom);
120     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v2f32  , Custom);
121     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v2i16  , Custom);
122     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v2i8   , Custom);
123   }
124
125   // Operations not directly supported by NVPTX.
126   setOperationAction(ISD::SELECT_CC,         MVT::Other, Expand);
127   setOperationAction(ISD::BR_CC,             MVT::Other, Expand);
128   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Expand);
129   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Expand);
130   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Expand);
131   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8 , Expand);
132   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1 , Expand);
133
134   if (nvptxSubtarget.hasROT64()) {
135     setOperationAction(ISD::ROTL , MVT::i64, Legal);
136     setOperationAction(ISD::ROTR , MVT::i64, Legal);
137   }
138   else {
139     setOperationAction(ISD::ROTL , MVT::i64, Expand);
140     setOperationAction(ISD::ROTR , MVT::i64, Expand);
141   }
142   if (nvptxSubtarget.hasROT32()) {
143     setOperationAction(ISD::ROTL , MVT::i32, Legal);
144     setOperationAction(ISD::ROTR , MVT::i32, Legal);
145   }
146   else {
147     setOperationAction(ISD::ROTL , MVT::i32, Expand);
148     setOperationAction(ISD::ROTR , MVT::i32, Expand);
149   }
150
151   setOperationAction(ISD::ROTL , MVT::i16, Expand);
152   setOperationAction(ISD::ROTR , MVT::i16, Expand);
153   setOperationAction(ISD::ROTL , MVT::i8, Expand);
154   setOperationAction(ISD::ROTR , MVT::i8, Expand);
155   setOperationAction(ISD::BSWAP , MVT::i16, Expand);
156   setOperationAction(ISD::BSWAP , MVT::i32, Expand);
157   setOperationAction(ISD::BSWAP , MVT::i64, Expand);
158
159   // Indirect branch is not supported.
160   // This also disables Jump Table creation.
161   setOperationAction(ISD::BR_JT,             MVT::Other, Expand);
162   setOperationAction(ISD::BRIND,             MVT::Other, Expand);
163
164   setOperationAction(ISD::GlobalAddress   , MVT::i32  , Custom);
165   setOperationAction(ISD::GlobalAddress   , MVT::i64  , Custom);
166
167   // We want to legalize constant related memmove and memcopy
168   // intrinsics.
169   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
170
171   // Turn FP extload into load/fextend
172   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
173   // Turn FP truncstore into trunc + store.
174   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
175
176   // PTX does not support load / store predicate registers
177   setOperationAction(ISD::LOAD, MVT::i1, Custom);
178   setOperationAction(ISD::STORE, MVT::i1, Custom);
179
180   setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
181   setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
182   setTruncStoreAction(MVT::i64, MVT::i1, Expand);
183   setTruncStoreAction(MVT::i32, MVT::i1, Expand);
184   setTruncStoreAction(MVT::i16, MVT::i1, Expand);
185   setTruncStoreAction(MVT::i8, MVT::i1, Expand);
186
187   // This is legal in NVPTX
188   setOperationAction(ISD::ConstantFP,         MVT::f64, Legal);
189   setOperationAction(ISD::ConstantFP,         MVT::f32, Legal);
190
191   // TRAP can be lowered to PTX trap
192   setOperationAction(ISD::TRAP,               MVT::Other, Legal);
193
194   // By default, CONCAT_VECTORS is implemented via store/load
195   // through stack. It is slow and uses local memory. We need
196   // to custom-lowering them.
197   setOperationAction(ISD::CONCAT_VECTORS, MVT::v4i32  , Custom);
198   setOperationAction(ISD::CONCAT_VECTORS, MVT::v4f32  , Custom);
199   setOperationAction(ISD::CONCAT_VECTORS, MVT::v4i16  , Custom);
200   setOperationAction(ISD::CONCAT_VECTORS, MVT::v4i8   , Custom);
201   setOperationAction(ISD::CONCAT_VECTORS, MVT::v2i64  , Custom);
202   setOperationAction(ISD::CONCAT_VECTORS, MVT::v2f64  , Custom);
203   setOperationAction(ISD::CONCAT_VECTORS, MVT::v2i32  , Custom);
204   setOperationAction(ISD::CONCAT_VECTORS, MVT::v2f32  , Custom);
205   setOperationAction(ISD::CONCAT_VECTORS, MVT::v2i16  , Custom);
206   setOperationAction(ISD::CONCAT_VECTORS, MVT::v2i8   , Custom);
207
208   // Expand vector int to float and float to int conversions
209   // - For SINT_TO_FP and UINT_TO_FP, the src type
210   //   (Node->getOperand(0).getValueType())
211   //   is used to determine the action, while for FP_TO_UINT and FP_TO_SINT,
212   //   the dest type (Node->getValueType(0)) is used.
213   //
214   //   See VectorLegalizer::LegalizeOp() (LegalizeVectorOps.cpp) for the vector
215   //   case, and
216   //   SelectionDAGLegalize::LegalizeOp() (LegalizeDAG.cpp) for the scalar case.
217   //
218   //   That is why v4i32 or v2i32 are used here.
219   //
220   //   The expansion for vectors happens in VectorLegalizer::LegalizeOp()
221   //   (LegalizeVectorOps.cpp).
222   setOperationAction(ISD::SINT_TO_FP, MVT::v4i32, Expand);
223   setOperationAction(ISD::SINT_TO_FP, MVT::v2i32, Expand);
224   setOperationAction(ISD::UINT_TO_FP, MVT::v4i32, Expand);
225   setOperationAction(ISD::UINT_TO_FP, MVT::v2i32, Expand);
226   setOperationAction(ISD::FP_TO_SINT, MVT::v2i32, Expand);
227   setOperationAction(ISD::FP_TO_SINT, MVT::v4i32, Expand);
228   setOperationAction(ISD::FP_TO_UINT, MVT::v2i32, Expand);
229   setOperationAction(ISD::FP_TO_UINT, MVT::v4i32, Expand);
230
231   // Now deduce the information based on the above mentioned
232   // actions
233   computeRegisterProperties();
234 }
235
236
237 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
238   switch (Opcode) {
239   default: return 0;
240   case NVPTXISD::CALL:            return "NVPTXISD::CALL";
241   case NVPTXISD::RET_FLAG:        return "NVPTXISD::RET_FLAG";
242   case NVPTXISD::Wrapper:         return "NVPTXISD::Wrapper";
243   case NVPTXISD::NVBuiltin:       return "NVPTXISD::NVBuiltin";
244   case NVPTXISD::DeclareParam:    return "NVPTXISD::DeclareParam";
245   case NVPTXISD::DeclareScalarParam:
246     return "NVPTXISD::DeclareScalarParam";
247   case NVPTXISD::DeclareRet:      return "NVPTXISD::DeclareRet";
248   case NVPTXISD::DeclareRetParam: return "NVPTXISD::DeclareRetParam";
249   case NVPTXISD::PrintCall:       return "NVPTXISD::PrintCall";
250   case NVPTXISD::LoadParam:       return "NVPTXISD::LoadParam";
251   case NVPTXISD::StoreParam:      return "NVPTXISD::StoreParam";
252   case NVPTXISD::StoreParamS32:   return "NVPTXISD::StoreParamS32";
253   case NVPTXISD::StoreParamU32:   return "NVPTXISD::StoreParamU32";
254   case NVPTXISD::MoveToParam:     return "NVPTXISD::MoveToParam";
255   case NVPTXISD::CallArgBegin:    return "NVPTXISD::CallArgBegin";
256   case NVPTXISD::CallArg:         return "NVPTXISD::CallArg";
257   case NVPTXISD::LastCallArg:     return "NVPTXISD::LastCallArg";
258   case NVPTXISD::CallArgEnd:      return "NVPTXISD::CallArgEnd";
259   case NVPTXISD::CallVoid:        return "NVPTXISD::CallVoid";
260   case NVPTXISD::CallVal:         return "NVPTXISD::CallVal";
261   case NVPTXISD::CallSymbol:      return "NVPTXISD::CallSymbol";
262   case NVPTXISD::Prototype:       return "NVPTXISD::Prototype";
263   case NVPTXISD::MoveParam:       return "NVPTXISD::MoveParam";
264   case NVPTXISD::MoveRetval:      return "NVPTXISD::MoveRetval";
265   case NVPTXISD::MoveToRetval:    return "NVPTXISD::MoveToRetval";
266   case NVPTXISD::StoreRetval:     return "NVPTXISD::StoreRetval";
267   case NVPTXISD::PseudoUseParam:  return "NVPTXISD::PseudoUseParam";
268   case NVPTXISD::RETURN:          return "NVPTXISD::RETURN";
269   case NVPTXISD::CallSeqBegin:    return "NVPTXISD::CallSeqBegin";
270   case NVPTXISD::CallSeqEnd:      return "NVPTXISD::CallSeqEnd";
271   }
272 }
273
274 bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const {
275   return VT == MVT::i1;
276 }
277
278 SDValue
279 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
280   DebugLoc dl = Op.getDebugLoc();
281   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
282   Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
283   return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
284 }
285
286 std::string NVPTXTargetLowering::getPrototype(Type *retTy,
287                                               const ArgListTy &Args,
288                                     const SmallVectorImpl<ISD::OutputArg> &Outs,
289                                               unsigned retAlignment) const {
290
291   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
292
293   std::stringstream O;
294   O << "prototype_" << uniqueCallSite << " : .callprototype ";
295
296   if (retTy->getTypeID() == Type::VoidTyID)
297     O << "()";
298   else {
299     O << "(";
300     if (isABI) {
301       if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
302         unsigned size = 0;
303         if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
304           size = ITy->getBitWidth();
305           if (size < 32) size = 32;
306         }
307         else {
308           assert(retTy->isFloatingPointTy() &&
309                  "Floating point type expected here");
310           size = retTy->getPrimitiveSizeInBits();
311         }
312
313         O << ".param .b" << size << " _";
314       }
315       else if (isa<PointerType>(retTy))
316         O << ".param .b" << getPointerTy().getSizeInBits()
317         << " _";
318       else {
319         if ((retTy->getTypeID() == Type::StructTyID) ||
320             isa<VectorType>(retTy)) {
321           SmallVector<EVT, 16> vtparts;
322           ComputeValueVTs(*this, retTy, vtparts);
323           unsigned totalsz = 0;
324           for (unsigned i=0,e=vtparts.size(); i!=e; ++i) {
325             unsigned elems = 1;
326             EVT elemtype = vtparts[i];
327             if (vtparts[i].isVector()) {
328               elems = vtparts[i].getVectorNumElements();
329               elemtype = vtparts[i].getVectorElementType();
330             }
331             for (unsigned j=0, je=elems; j!=je; ++j) {
332               unsigned sz = elemtype.getSizeInBits();
333               if (elemtype.isInteger() && (sz < 8)) sz = 8;
334               totalsz += sz/8;
335             }
336           }
337           O << ".param .align "
338               << retAlignment
339               << " .b8 _["
340               << totalsz << "]";
341         }
342         else {
343           assert(false &&
344                  "Unknown return type");
345         }
346       }
347     }
348     else {
349       SmallVector<EVT, 16> vtparts;
350       ComputeValueVTs(*this, retTy, vtparts);
351       unsigned idx = 0;
352       for (unsigned i=0,e=vtparts.size(); i!=e; ++i) {
353         unsigned elems = 1;
354         EVT elemtype = vtparts[i];
355         if (vtparts[i].isVector()) {
356           elems = vtparts[i].getVectorNumElements();
357           elemtype = vtparts[i].getVectorElementType();
358         }
359
360         for (unsigned j=0, je=elems; j!=je; ++j) {
361           unsigned sz = elemtype.getSizeInBits();
362           if (elemtype.isInteger() && (sz < 32)) sz = 32;
363           O << ".reg .b" << sz << " _";
364           if (j<je-1) O << ", ";
365           ++idx;
366         }
367         if (i < e-1)
368           O << ", ";
369       }
370     }
371     O << ") ";
372   }
373   O << "_ (";
374
375   bool first = true;
376   MVT thePointerTy = getPointerTy();
377
378   for (unsigned i=0,e=Args.size(); i!=e; ++i) {
379     const Type *Ty = Args[i].Ty;
380     if (!first) {
381       O << ", ";
382     }
383     first = false;
384
385     if (Outs[i].Flags.isByVal() == false) {
386       unsigned sz = 0;
387       if (isa<IntegerType>(Ty)) {
388         sz = cast<IntegerType>(Ty)->getBitWidth();
389         if (sz < 32) sz = 32;
390       }
391       else if (isa<PointerType>(Ty))
392         sz = thePointerTy.getSizeInBits();
393       else
394         sz = Ty->getPrimitiveSizeInBits();
395       if (isABI)
396         O << ".param .b" << sz << " ";
397       else
398         O << ".reg .b" << sz << " ";
399       O << "_";
400       continue;
401     }
402     const PointerType *PTy = dyn_cast<PointerType>(Ty);
403     assert(PTy &&
404            "Param with byval attribute should be a pointer type");
405     Type *ETy = PTy->getElementType();
406
407     if (isABI) {
408       unsigned align = Outs[i].Flags.getByValAlign();
409       unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
410       O << ".param .align " << align
411           << " .b8 ";
412       O << "_";
413       O << "[" << sz << "]";
414       continue;
415     }
416     else {
417       SmallVector<EVT, 16> vtparts;
418       ComputeValueVTs(*this, ETy, vtparts);
419       for (unsigned i=0,e=vtparts.size(); i!=e; ++i) {
420         unsigned elems = 1;
421         EVT elemtype = vtparts[i];
422         if (vtparts[i].isVector()) {
423           elems = vtparts[i].getVectorNumElements();
424           elemtype = vtparts[i].getVectorElementType();
425         }
426
427         for (unsigned j=0,je=elems; j!=je; ++j) {
428           unsigned sz = elemtype.getSizeInBits();
429           if (elemtype.isInteger() && (sz < 32)) sz = 32;
430           O << ".reg .b" << sz << " ";
431           O << "_";
432           if (j<je-1) O << ", ";
433         }
434         if (i<e-1)
435           O << ", ";
436       }
437       continue;
438     }
439   }
440   O << ");";
441   return O.str();
442 }
443
444
445 SDValue
446 NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
447                                SmallVectorImpl<SDValue> &InVals) const {
448   SelectionDAG &DAG                     = CLI.DAG;
449   DebugLoc &dl                          = CLI.DL;
450   SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
451   SmallVector<SDValue, 32> &OutVals     = CLI.OutVals;
452   SmallVector<ISD::InputArg, 32> &Ins   = CLI.Ins;
453   SDValue Chain                         = CLI.Chain;
454   SDValue Callee                        = CLI.Callee;
455   bool &isTailCall                      = CLI.IsTailCall;
456   ArgListTy &Args                       = CLI.Args;
457   Type *retTy                           = CLI.RetTy;
458   ImmutableCallSite *CS                 = CLI.CS;
459
460   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
461
462   SDValue tempChain = Chain;
463   Chain = DAG.getCALLSEQ_START(Chain,
464                                DAG.getIntPtrConstant(uniqueCallSite, true));
465   SDValue InFlag = Chain.getValue(1);
466
467   assert((Outs.size() == Args.size()) &&
468          "Unexpected number of arguments to function call");
469   unsigned paramCount = 0;
470   // Declare the .params or .reg need to pass values
471   // to the function
472   for (unsigned i=0, e=Outs.size(); i!=e; ++i) {
473     EVT VT = Outs[i].VT;
474
475     if (Outs[i].Flags.isByVal() == false) {
476       // Plain scalar
477       // for ABI,    declare .param .b<size> .param<n>;
478       // for nonABI, declare .reg .b<size> .param<n>;
479       unsigned isReg = 1;
480       if (isABI)
481         isReg = 0;
482       unsigned sz = VT.getSizeInBits();
483       if (VT.isInteger() && (sz < 32)) sz = 32;
484       SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
485       SDValue DeclareParamOps[] = { Chain,
486                                     DAG.getConstant(paramCount, MVT::i32),
487                                     DAG.getConstant(sz, MVT::i32),
488                                     DAG.getConstant(isReg, MVT::i32),
489                                     InFlag };
490       Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
491                           DeclareParamOps, 5);
492       InFlag = Chain.getValue(1);
493       SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
494       SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
495                              DAG.getConstant(0, MVT::i32), OutVals[i], InFlag };
496
497       unsigned opcode = NVPTXISD::StoreParam;
498       if (isReg)
499         opcode = NVPTXISD::MoveToParam;
500       else {
501         if (Outs[i].Flags.isZExt())
502           opcode = NVPTXISD::StoreParamU32;
503         else if (Outs[i].Flags.isSExt())
504           opcode = NVPTXISD::StoreParamS32;
505       }
506       Chain = DAG.getNode(opcode, dl, CopyParamVTs, CopyParamOps, 5);
507
508       InFlag = Chain.getValue(1);
509       ++paramCount;
510       continue;
511     }
512     // struct or vector
513     SmallVector<EVT, 16> vtparts;
514     const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
515     assert(PTy &&
516            "Type of a byval parameter should be pointer");
517     ComputeValueVTs(*this, PTy->getElementType(), vtparts);
518
519     if (isABI) {
520       // declare .param .align 16 .b8 .param<n>[<size>];
521       unsigned sz = Outs[i].Flags.getByValSize();
522       SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
523       // The ByValAlign in the Outs[i].Flags is alway set at this point, so we
524       // don't need to
525       // worry about natural alignment or not. See TargetLowering::LowerCallTo()
526       SDValue DeclareParamOps[] = { Chain,
527                        DAG.getConstant(Outs[i].Flags.getByValAlign(), MVT::i32),
528                                     DAG.getConstant(paramCount, MVT::i32),
529                                     DAG.getConstant(sz, MVT::i32),
530                                     InFlag };
531       Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
532                           DeclareParamOps, 5);
533       InFlag = Chain.getValue(1);
534       unsigned curOffset = 0;
535       for (unsigned j=0,je=vtparts.size(); j!=je; ++j) {
536         unsigned elems = 1;
537         EVT elemtype = vtparts[j];
538         if (vtparts[j].isVector()) {
539           elems = vtparts[j].getVectorNumElements();
540           elemtype = vtparts[j].getVectorElementType();
541         }
542         for (unsigned k=0,ke=elems; k!=ke; ++k) {
543           unsigned sz = elemtype.getSizeInBits();
544           if (elemtype.isInteger() && (sz < 8)) sz = 8;
545           SDValue srcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(),
546                                         OutVals[i],
547                                         DAG.getConstant(curOffset,
548                                                         getPointerTy()));
549           SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
550                                 MachinePointerInfo(), false, false, false, 0);
551           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
552           SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount,
553                                                             MVT::i32),
554                                            DAG.getConstant(curOffset, MVT::i32),
555                                                             theVal, InFlag };
556           Chain = DAG.getNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
557                               CopyParamOps, 5);
558           InFlag = Chain.getValue(1);
559           curOffset += sz/8;
560         }
561       }
562       ++paramCount;
563       continue;
564     }
565     // Non-abi, struct or vector
566     // Declare a bunch or .reg .b<size> .param<n>
567     unsigned curOffset = 0;
568     for (unsigned j=0,je=vtparts.size(); j!=je; ++j) {
569       unsigned elems = 1;
570       EVT elemtype = vtparts[j];
571       if (vtparts[j].isVector()) {
572         elems = vtparts[j].getVectorNumElements();
573         elemtype = vtparts[j].getVectorElementType();
574       }
575       for (unsigned k=0,ke=elems; k!=ke; ++k) {
576         unsigned sz = elemtype.getSizeInBits();
577         if (elemtype.isInteger() && (sz < 32)) sz = 32;
578         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
579         SDValue DeclareParamOps[] = { Chain, DAG.getConstant(paramCount,
580                                                              MVT::i32),
581                                                   DAG.getConstant(sz, MVT::i32),
582                                                    DAG.getConstant(1, MVT::i32),
583                                                              InFlag };
584         Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
585                             DeclareParamOps, 5);
586         InFlag = Chain.getValue(1);
587         SDValue srcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[i],
588                                       DAG.getConstant(curOffset,
589                                                       getPointerTy()));
590         SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
591                                   MachinePointerInfo(), false, false, false, 0);
592         SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
593         SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
594                                    DAG.getConstant(0, MVT::i32), theVal,
595                                    InFlag };
596         Chain = DAG.getNode(NVPTXISD::MoveToParam, dl, CopyParamVTs,
597                             CopyParamOps, 5);
598         InFlag = Chain.getValue(1);
599         ++paramCount;
600       }
601     }
602   }
603
604   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
605   unsigned retAlignment = 0;
606
607   // Handle Result
608   unsigned retCount = 0;
609   if (Ins.size() > 0) {
610     SmallVector<EVT, 16> resvtparts;
611     ComputeValueVTs(*this, retTy, resvtparts);
612
613     // Declare one .param .align 16 .b8 func_retval0[<size>] for ABI or
614     // individual .reg .b<size> func_retval<0..> for non ABI
615     unsigned resultsz = 0;
616     for (unsigned i=0,e=resvtparts.size(); i!=e; ++i) {
617       unsigned elems = 1;
618       EVT elemtype = resvtparts[i];
619       if (resvtparts[i].isVector()) {
620         elems = resvtparts[i].getVectorNumElements();
621         elemtype = resvtparts[i].getVectorElementType();
622       }
623       for (unsigned j=0,je=elems; j!=je; ++j) {
624         unsigned sz = elemtype.getSizeInBits();
625         if (isABI == false) {
626           if (elemtype.isInteger() && (sz < 32)) sz = 32;
627         }
628         else {
629           if (elemtype.isInteger() && (sz < 8)) sz = 8;
630         }
631         if (isABI == false) {
632           SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
633           SDValue DeclareRetOps[] = { Chain, DAG.getConstant(2, MVT::i32),
634                                       DAG.getConstant(sz, MVT::i32),
635                                       DAG.getConstant(retCount, MVT::i32),
636                                       InFlag };
637           Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
638                               DeclareRetOps, 5);
639           InFlag = Chain.getValue(1);
640           ++retCount;
641         }
642         resultsz += sz;
643       }
644     }
645     if (isABI) {
646       if (retTy->isPrimitiveType() || retTy->isIntegerTy() ||
647           retTy->isPointerTy() ) {
648         // Scalar needs to be at least 32bit wide
649         if (resultsz < 32)
650           resultsz = 32;
651         SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
652         SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
653                                     DAG.getConstant(resultsz, MVT::i32),
654                                     DAG.getConstant(0, MVT::i32), InFlag };
655         Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
656                             DeclareRetOps, 5);
657         InFlag = Chain.getValue(1);
658       }
659       else {
660         if (Func) { // direct call
661           if (!llvm::getAlign(*(CS->getCalledFunction()), 0, retAlignment))
662             retAlignment = getDataLayout()->getABITypeAlignment(retTy);
663         } else { // indirect call
664           const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction());
665           if (!llvm::getAlign(*CallI, 0, retAlignment))
666             retAlignment = getDataLayout()->getABITypeAlignment(retTy);
667         }
668         SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
669         SDValue DeclareRetOps[] = { Chain, DAG.getConstant(retAlignment,
670                                                            MVT::i32),
671                                           DAG.getConstant(resultsz/8, MVT::i32),
672                                          DAG.getConstant(0, MVT::i32), InFlag };
673         Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
674                             DeclareRetOps, 5);
675         InFlag = Chain.getValue(1);
676       }
677     }
678   }
679
680   if (!Func) {
681     // This is indirect function call case : PTX requires a prototype of the
682     // form
683     // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
684     // to be emitted, and the label has to used as the last arg of call
685     // instruction.
686     // The prototype is embedded in a string and put as the operand for an
687     // INLINEASM SDNode.
688     SDVTList InlineAsmVTs = DAG.getVTList(MVT::Other, MVT::Glue);
689     std::string proto_string = getPrototype(retTy, Args, Outs, retAlignment);
690     const char *asmstr = nvTM->getManagedStrPool()->
691         getManagedString(proto_string.c_str())->c_str();
692     SDValue InlineAsmOps[] = { Chain,
693                                DAG.getTargetExternalSymbol(asmstr,
694                                                            getPointerTy()),
695                                                            DAG.getMDNode(0),
696                                    DAG.getTargetConstant(0, MVT::i32), InFlag };
697     Chain = DAG.getNode(ISD::INLINEASM, dl, InlineAsmVTs, InlineAsmOps, 5);
698     InFlag = Chain.getValue(1);
699   }
700   // Op to just print "call"
701   SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
702   SDValue PrintCallOps[] = { Chain,
703                              DAG.getConstant(isABI ? ((Ins.size()==0) ? 0 : 1)
704                                  : retCount, MVT::i32),
705                                    InFlag };
706   Chain = DAG.getNode(Func?(NVPTXISD::PrintCallUni):(NVPTXISD::PrintCall), dl,
707       PrintCallVTs, PrintCallOps, 3);
708   InFlag = Chain.getValue(1);
709
710   // Ops to print out the function name
711   SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
712   SDValue CallVoidOps[] = { Chain, Callee, InFlag };
713   Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3);
714   InFlag = Chain.getValue(1);
715
716   // Ops to print out the param list
717   SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
718   SDValue CallArgBeginOps[] = { Chain, InFlag };
719   Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
720                       CallArgBeginOps, 2);
721   InFlag = Chain.getValue(1);
722
723   for (unsigned i=0, e=paramCount; i!=e; ++i) {
724     unsigned opcode;
725     if (i==(e-1))
726       opcode = NVPTXISD::LastCallArg;
727     else
728       opcode = NVPTXISD::CallArg;
729     SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
730     SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
731                              DAG.getConstant(i, MVT::i32),
732                              InFlag };
733     Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4);
734     InFlag = Chain.getValue(1);
735   }
736   SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
737   SDValue CallArgEndOps[] = { Chain,
738                               DAG.getConstant(Func ? 1 : 0, MVT::i32),
739                               InFlag };
740   Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps,
741                       3);
742   InFlag = Chain.getValue(1);
743
744   if (!Func) {
745     SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
746     SDValue PrototypeOps[] = { Chain,
747                                DAG.getConstant(uniqueCallSite, MVT::i32),
748                                InFlag };
749     Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3);
750     InFlag = Chain.getValue(1);
751   }
752
753   // Generate loads from param memory/moves from registers for result
754   if (Ins.size() > 0) {
755     if (isABI) {
756       unsigned resoffset = 0;
757       for (unsigned i=0,e=Ins.size(); i!=e; ++i) {
758         unsigned sz = Ins[i].VT.getSizeInBits();
759         if (Ins[i].VT.isInteger() && (sz < 8)) sz = 8;
760         std::vector<EVT> LoadRetVTs;
761         LoadRetVTs.push_back(Ins[i].VT);
762         LoadRetVTs.push_back(MVT::Other); LoadRetVTs.push_back(MVT::Glue);
763         std::vector<SDValue> LoadRetOps;
764         LoadRetOps.push_back(Chain);
765         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
766         LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32));
767         LoadRetOps.push_back(InFlag);
768         SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, LoadRetVTs,
769                                      &LoadRetOps[0], LoadRetOps.size());
770         Chain = retval.getValue(1);
771         InFlag = retval.getValue(2);
772         InVals.push_back(retval);
773         resoffset += sz/8;
774       }
775     }
776     else {
777       SmallVector<EVT, 16> resvtparts;
778       ComputeValueVTs(*this, retTy, resvtparts);
779
780       assert(Ins.size() == resvtparts.size() &&
781              "Unexpected number of return values in non-ABI case");
782       unsigned paramNum = 0;
783       for (unsigned i=0,e=Ins.size(); i!=e; ++i) {
784         assert(EVT(Ins[i].VT) == resvtparts[i] &&
785                "Unexpected EVT type in non-ABI case");
786         unsigned numelems = 1;
787         EVT elemtype = Ins[i].VT;
788         if (Ins[i].VT.isVector()) {
789           numelems = Ins[i].VT.getVectorNumElements();
790           elemtype = Ins[i].VT.getVectorElementType();
791         }
792         std::vector<SDValue> tempRetVals;
793         for (unsigned j=0; j<numelems; ++j) {
794           std::vector<EVT> MoveRetVTs;
795           MoveRetVTs.push_back(elemtype);
796           MoveRetVTs.push_back(MVT::Other); MoveRetVTs.push_back(MVT::Glue);
797           std::vector<SDValue> MoveRetOps;
798           MoveRetOps.push_back(Chain);
799           MoveRetOps.push_back(DAG.getConstant(0, MVT::i32));
800           MoveRetOps.push_back(DAG.getConstant(paramNum, MVT::i32));
801           MoveRetOps.push_back(InFlag);
802           SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, MoveRetVTs,
803                                        &MoveRetOps[0], MoveRetOps.size());
804           Chain = retval.getValue(1);
805           InFlag = retval.getValue(2);
806           tempRetVals.push_back(retval);
807           ++paramNum;
808         }
809         if (Ins[i].VT.isVector())
810           InVals.push_back(DAG.getNode(ISD::BUILD_VECTOR, dl, Ins[i].VT,
811                                        &tempRetVals[0], tempRetVals.size()));
812         else
813           InVals.push_back(tempRetVals[0]);
814       }
815     }
816   }
817   Chain = DAG.getCALLSEQ_END(Chain,
818                              DAG.getIntPtrConstant(uniqueCallSite, true),
819                              DAG.getIntPtrConstant(uniqueCallSite+1, true),
820                              InFlag);
821   uniqueCallSite++;
822
823   // set isTailCall to false for now, until we figure out how to express
824   // tail call optimization in PTX
825   isTailCall = false;
826   return Chain;
827 }
828
829 // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
830 // (see LegalizeDAG.cpp). This is slow and uses local memory.
831 // We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
832 SDValue NVPTXTargetLowering::
833 LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
834   SDNode *Node = Op.getNode();
835   DebugLoc dl = Node->getDebugLoc();
836   SmallVector<SDValue, 8> Ops;
837   unsigned NumOperands = Node->getNumOperands();
838   for (unsigned i=0; i < NumOperands; ++i) {
839     SDValue SubOp = Node->getOperand(i);
840     EVT VVT = SubOp.getNode()->getValueType(0);
841     EVT EltVT = VVT.getVectorElementType();
842     unsigned NumSubElem = VVT.getVectorNumElements();
843     for (unsigned j=0; j < NumSubElem; ++j) {
844       Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
845                                 DAG.getIntPtrConstant(j)));
846     }
847   }
848   return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0),
849                      &Ops[0], Ops.size());
850 }
851
852 SDValue NVPTXTargetLowering::
853 LowerOperation(SDValue Op, SelectionDAG &DAG) const {
854   switch (Op.getOpcode()) {
855   case ISD::RETURNADDR: return SDValue();
856   case ISD::FRAMEADDR:  return SDValue();
857   case ISD::GlobalAddress:      return LowerGlobalAddress(Op, DAG);
858   case ISD::INTRINSIC_W_CHAIN: return Op;
859   case ISD::BUILD_VECTOR:
860   case ISD::EXTRACT_SUBVECTOR:
861     return Op;
862   case ISD::CONCAT_VECTORS: return LowerCONCAT_VECTORS(Op, DAG);
863   case ISD::STORE: return LowerSTORE(Op, DAG);
864   case ISD::LOAD: return LowerLOAD(Op, DAG);
865   default:
866     llvm_unreachable("Custom lowering not defined for operation");
867   }
868 }
869
870
871 // v = ld i1* addr
872 //   =>
873 // v1 = ld i8* addr
874 // v = trunc v1 to i1
875 SDValue NVPTXTargetLowering::
876 LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
877   SDNode *Node = Op.getNode();
878   LoadSDNode *LD = cast<LoadSDNode>(Node);
879   DebugLoc dl = Node->getDebugLoc();
880   assert(LD->getExtensionType() == ISD::NON_EXTLOAD) ;
881   assert(Node->getValueType(0) == MVT::i1 &&
882          "Custom lowering for i1 load only");
883   SDValue newLD = DAG.getLoad(MVT::i8, dl, LD->getChain(), LD->getBasePtr(),
884                               LD->getPointerInfo(),
885                               LD->isVolatile(), LD->isNonTemporal(),
886                               LD->isInvariant(),
887                               LD->getAlignment());
888   SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
889   // The legalizer (the caller) is expecting two values from the legalized
890   // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
891   // in LegalizeDAG.cpp which also uses MergeValues.
892   SDValue Ops[] = {result, LD->getChain()};
893   return DAG.getMergeValues(Ops, 2, dl);
894 }
895
896 // st i1 v, addr
897 //    =>
898 // v1 = zxt v to i8
899 // st i8, addr
900 SDValue NVPTXTargetLowering::
901 LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
902   SDNode *Node = Op.getNode();
903   DebugLoc dl = Node->getDebugLoc();
904   StoreSDNode *ST = cast<StoreSDNode>(Node);
905   SDValue Tmp1 = ST->getChain();
906   SDValue Tmp2 = ST->getBasePtr();
907   SDValue Tmp3 = ST->getValue();
908   assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
909   unsigned Alignment = ST->getAlignment();
910   bool isVolatile = ST->isVolatile();
911   bool isNonTemporal = ST->isNonTemporal();
912   Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl,
913                      MVT::i8, Tmp3);
914   SDValue Result = DAG.getStore(Tmp1, dl, Tmp3, Tmp2,
915                                 ST->getPointerInfo(), isVolatile,
916                                 isNonTemporal, Alignment);
917   return Result;
918 }
919
920
921 SDValue
922 NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname, int idx,
923                                 EVT v) const {
924   std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
925   std::stringstream suffix;
926   suffix << idx;
927   *name += suffix.str();
928   return DAG.getTargetExternalSymbol(name->c_str(), v);
929 }
930
931 SDValue
932 NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
933   return getExtSymb(DAG, ".PARAM", idx, v);
934 }
935
936 SDValue
937 NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
938   return getExtSymb(DAG, ".HLPPARAM", idx);
939 }
940
941 // Check to see if the kernel argument is image*_t or sampler_t
942
943 bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
944   static const char *const specialTypes[] = {
945                                              "struct._image2d_t",
946                                              "struct._image3d_t",
947                                              "struct._sampler_t"
948   };
949
950   const Type *Ty = arg->getType();
951   const PointerType *PTy = dyn_cast<PointerType>(Ty);
952
953   if (!PTy)
954     return false;
955
956   if (!context)
957     return false;
958
959   const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
960   const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : "";
961
962   for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
963     if (TypeName == specialTypes[i])
964       return true;
965
966   return false;
967 }
968
969 SDValue
970 NVPTXTargetLowering::LowerFormalArguments(SDValue Chain,
971                                         CallingConv::ID CallConv, bool isVarArg,
972                                       const SmallVectorImpl<ISD::InputArg> &Ins,
973                                           DebugLoc dl, SelectionDAG &DAG,
974                                        SmallVectorImpl<SDValue> &InVals) const {
975   MachineFunction &MF = DAG.getMachineFunction();
976   const DataLayout *TD = getDataLayout();
977
978   const Function *F = MF.getFunction();
979   const AttributeSet &PAL = F->getAttributes();
980
981   SDValue Root = DAG.getRoot();
982   std::vector<SDValue> OutChains;
983
984   bool isKernel = llvm::isKernelFunction(*F);
985   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
986
987   std::vector<Type *> argTypes;
988   std::vector<const Argument *> theArgs;
989   for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
990       I != E; ++I) {
991     theArgs.push_back(I);
992     argTypes.push_back(I->getType());
993   }
994   assert(argTypes.size() == Ins.size() &&
995          "Ins types and function types did not match");
996
997   int idx = 0;
998   for (unsigned i=0, e=Ins.size(); i!=e; ++i, ++idx) {
999     Type *Ty = argTypes[i];
1000     EVT ObjectVT = getValueType(Ty);
1001     assert(ObjectVT == Ins[i].VT &&
1002            "Ins type did not match function type");
1003
1004     // If the kernel argument is image*_t or sampler_t, convert it to
1005     // a i32 constant holding the parameter position. This can later
1006     // matched in the AsmPrinter to output the correct mangled name.
1007     if (isImageOrSamplerVal(theArgs[i],
1008                            (theArgs[i]->getParent() ?
1009                                theArgs[i]->getParent()->getParent() : 0))) {
1010       assert(isKernel && "Only kernels can have image/sampler params");
1011       InVals.push_back(DAG.getConstant(i+1, MVT::i32));
1012       continue;
1013     }
1014
1015     if (theArgs[i]->use_empty()) {
1016       // argument is dead
1017       InVals.push_back(DAG.getNode(ISD::UNDEF, dl, ObjectVT));
1018       continue;
1019     }
1020
1021     // In the following cases, assign a node order of "idx+1"
1022     // to newly created nodes. The SDNOdes for params have to
1023     // appear in the same order as their order of appearance
1024     // in the original function. "idx+1" holds that order.
1025     if (PAL.hasAttribute(i+1, Attribute::ByVal) == false) {
1026       // A plain scalar.
1027       if (isABI || isKernel) {
1028         // If ABI, load from the param symbol
1029         SDValue Arg = getParamSymbol(DAG, idx);
1030         Value *srcValue = new Argument(PointerType::get(ObjectVT.getTypeForEVT(
1031             F->getContext()),
1032             llvm::ADDRESS_SPACE_PARAM));
1033         SDValue p = DAG.getLoad(ObjectVT, dl, Root, Arg,
1034                                 MachinePointerInfo(srcValue), false, false,
1035                                 false,
1036                                 TD->getABITypeAlignment(ObjectVT.getTypeForEVT(
1037                                   F->getContext())));
1038         if (p.getNode())
1039           DAG.AssignOrdering(p.getNode(), idx+1);
1040         InVals.push_back(p);
1041       }
1042       else {
1043         // If no ABI, just move the param symbol
1044         SDValue Arg = getParamSymbol(DAG, idx, ObjectVT);
1045         SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
1046         if (p.getNode())
1047           DAG.AssignOrdering(p.getNode(), idx+1);
1048         InVals.push_back(p);
1049       }
1050       continue;
1051     }
1052
1053     // Param has ByVal attribute
1054     if (isABI || isKernel) {
1055       // Return MoveParam(param symbol).
1056       // Ideally, the param symbol can be returned directly,
1057       // but when SDNode builder decides to use it in a CopyToReg(),
1058       // machine instruction fails because TargetExternalSymbol
1059       // (not lowered) is target dependent, and CopyToReg assumes
1060       // the source is lowered.
1061       SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1062       SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
1063       if (p.getNode())
1064         DAG.AssignOrdering(p.getNode(), idx+1);
1065       if (isKernel)
1066         InVals.push_back(p);
1067       else {
1068         SDValue p2 = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
1069                     DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32),
1070                                  p);
1071         InVals.push_back(p2);
1072       }
1073     } else {
1074       // Have to move a set of param symbols to registers and
1075       // store them locally and return the local pointer in InVals
1076       const PointerType *elemPtrType = dyn_cast<PointerType>(argTypes[i]);
1077       assert(elemPtrType &&
1078              "Byval parameter should be a pointer type");
1079       Type *elemType = elemPtrType->getElementType();
1080       // Compute the constituent parts
1081       SmallVector<EVT, 16> vtparts;
1082       SmallVector<uint64_t, 16> offsets;
1083       ComputeValueVTs(*this, elemType, vtparts, &offsets, 0);
1084       unsigned totalsize = 0;
1085       for (unsigned j=0, je=vtparts.size(); j!=je; ++j)
1086         totalsize += vtparts[j].getStoreSizeInBits();
1087       SDValue localcopy =  DAG.getFrameIndex(MF.getFrameInfo()->
1088                                       CreateStackObject(totalsize/8, 16, false),
1089                                              getPointerTy());
1090       unsigned sizesofar = 0;
1091       std::vector<SDValue> theChains;
1092       for (unsigned j=0, je=vtparts.size(); j!=je; ++j) {
1093         unsigned numElems = 1;
1094         if (vtparts[j].isVector()) numElems = vtparts[j].getVectorNumElements();
1095         for (unsigned k=0, ke=numElems; k!=ke; ++k) {
1096           EVT tmpvt = vtparts[j];
1097           if (tmpvt.isVector()) tmpvt = tmpvt.getVectorElementType();
1098           SDValue arg = DAG.getNode(NVPTXISD::MoveParam, dl, tmpvt,
1099                                     getParamSymbol(DAG, idx, tmpvt));
1100           SDValue addr = DAG.getNode(ISD::ADD, dl, getPointerTy(), localcopy,
1101                                     DAG.getConstant(sizesofar, getPointerTy()));
1102           theChains.push_back(DAG.getStore(Chain, dl, arg, addr,
1103                                         MachinePointerInfo(), false, false, 0));
1104           sizesofar += tmpvt.getStoreSizeInBits()/8;
1105           ++idx;
1106         }
1107       }
1108       --idx;
1109       Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &theChains[0],
1110                           theChains.size());
1111       InVals.push_back(localcopy);
1112     }
1113   }
1114
1115   // Clang will check explicit VarArg and issue error if any. However, Clang
1116   // will let code with
1117   // implicit var arg like f() pass.
1118   // We treat this case as if the arg list is empty.
1119   //if (F.isVarArg()) {
1120   // assert(0 && "VarArg not supported yet!");
1121   //}
1122
1123   if (!OutChains.empty())
1124     DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
1125                             &OutChains[0], OutChains.size()));
1126
1127   return Chain;
1128 }
1129
1130 SDValue
1131 NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
1132                                  bool isVarArg,
1133                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
1134                                  const SmallVectorImpl<SDValue> &OutVals,
1135                                  DebugLoc dl, SelectionDAG &DAG) const {
1136
1137   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1138
1139   unsigned sizesofar = 0;
1140   unsigned idx = 0;
1141   for (unsigned i=0, e=Outs.size(); i!=e; ++i) {
1142     SDValue theVal = OutVals[i];
1143     EVT theValType = theVal.getValueType();
1144     unsigned numElems = 1;
1145     if (theValType.isVector()) numElems = theValType.getVectorNumElements();
1146     for (unsigned j=0,je=numElems; j!=je; ++j) {
1147       SDValue tmpval = theVal;
1148       if (theValType.isVector())
1149         tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
1150                              theValType.getVectorElementType(),
1151                              tmpval, DAG.getIntPtrConstant(j));
1152       Chain = DAG.getNode(isABI ? NVPTXISD::StoreRetval :NVPTXISD::MoveToRetval,
1153           dl, MVT::Other,
1154           Chain,
1155           DAG.getConstant(isABI ? sizesofar : idx, MVT::i32),
1156           tmpval);
1157       if (theValType.isVector())
1158         sizesofar += theValType.getVectorElementType().getStoreSizeInBits()/8;
1159       else
1160         sizesofar += theValType.getStoreSizeInBits()/8;
1161       ++idx;
1162     }
1163   }
1164
1165   return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
1166 }
1167
1168 void
1169 NVPTXTargetLowering::LowerAsmOperandForConstraint(SDValue Op,
1170                                                   std::string &Constraint,
1171                                                   std::vector<SDValue> &Ops,
1172                                                   SelectionDAG &DAG) const
1173 {
1174   if (Constraint.length() > 1)
1175     return;
1176   else
1177     TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
1178 }
1179
1180 // NVPTX suuport vector of legal types of any length in Intrinsics because the
1181 // NVPTX specific type legalizer
1182 // will legalize them to the PTX supported length.
1183 bool
1184 NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
1185   if (isTypeLegal(VT))
1186     return true;
1187   if (VT.isVector()) {
1188     MVT eVT = VT.getVectorElementType();
1189     if (isTypeLegal(eVT))
1190       return true;
1191   }
1192   return false;
1193 }
1194
1195
1196 // llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
1197 // TgtMemIntrinsic
1198 // because we need the information that is only available in the "Value" type
1199 // of destination
1200 // pointer. In particular, the address space information.
1201 bool
1202 NVPTXTargetLowering::getTgtMemIntrinsic(IntrinsicInfo& Info, const CallInst &I,
1203                                         unsigned Intrinsic) const {
1204   switch (Intrinsic) {
1205   default:
1206     return false;
1207
1208   case Intrinsic::nvvm_atomic_load_add_f32:
1209     Info.opc = ISD::INTRINSIC_W_CHAIN;
1210     Info.memVT = MVT::f32;
1211     Info.ptrVal = I.getArgOperand(0);
1212     Info.offset = 0;
1213     Info.vol = 0;
1214     Info.readMem = true;
1215     Info.writeMem = true;
1216     Info.align = 0;
1217     return true;
1218
1219   case Intrinsic::nvvm_atomic_load_inc_32:
1220   case Intrinsic::nvvm_atomic_load_dec_32:
1221     Info.opc = ISD::INTRINSIC_W_CHAIN;
1222     Info.memVT = MVT::i32;
1223     Info.ptrVal = I.getArgOperand(0);
1224     Info.offset = 0;
1225     Info.vol = 0;
1226     Info.readMem = true;
1227     Info.writeMem = true;
1228     Info.align = 0;
1229     return true;
1230
1231   case Intrinsic::nvvm_ldu_global_i:
1232   case Intrinsic::nvvm_ldu_global_f:
1233   case Intrinsic::nvvm_ldu_global_p:
1234
1235     Info.opc = ISD::INTRINSIC_W_CHAIN;
1236     if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
1237       Info.memVT = MVT::i32;
1238     else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
1239       Info.memVT = getPointerTy();
1240     else
1241       Info.memVT = MVT::f32;
1242     Info.ptrVal = I.getArgOperand(0);
1243     Info.offset = 0;
1244     Info.vol = 0;
1245     Info.readMem = true;
1246     Info.writeMem = false;
1247     Info.align = 0;
1248     return true;
1249
1250   }
1251   return false;
1252 }
1253
1254 /// isLegalAddressingMode - Return true if the addressing mode represented
1255 /// by AM is legal for this target, for a load/store of the specified type.
1256 /// Used to guide target specific optimizations, like loop strength reduction
1257 /// (LoopStrengthReduce.cpp) and memory optimization for address mode
1258 /// (CodeGenPrepare.cpp)
1259 bool
1260 NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
1261                                            Type *Ty) const {
1262
1263   // AddrMode - This represents an addressing mode of:
1264   //    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
1265   //
1266   // The legal address modes are
1267   // - [avar]
1268   // - [areg]
1269   // - [areg+immoff]
1270   // - [immAddr]
1271
1272   if (AM.BaseGV) {
1273     if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
1274       return false;
1275     return true;
1276   }
1277
1278   switch (AM.Scale) {
1279   case 0:  // "r", "r+i" or "i" is allowed
1280     break;
1281   case 1:
1282     if (AM.HasBaseReg)  // "r+r+i" or "r+r" is not allowed.
1283       return false;
1284     // Otherwise we have r+i.
1285     break;
1286   default:
1287     // No scale > 1 is allowed
1288     return false;
1289   }
1290   return true;
1291 }
1292
1293 //===----------------------------------------------------------------------===//
1294 //                         NVPTX Inline Assembly Support
1295 //===----------------------------------------------------------------------===//
1296
1297 /// getConstraintType - Given a constraint letter, return the type of
1298 /// constraint it is for this target.
1299 NVPTXTargetLowering::ConstraintType
1300 NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
1301   if (Constraint.size() == 1) {
1302     switch (Constraint[0]) {
1303     default:
1304       break;
1305     case 'r':
1306     case 'h':
1307     case 'c':
1308     case 'l':
1309     case 'f':
1310     case 'd':
1311     case '0':
1312     case 'N':
1313       return C_RegisterClass;
1314     }
1315   }
1316   return TargetLowering::getConstraintType(Constraint);
1317 }
1318
1319
1320 std::pair<unsigned, const TargetRegisterClass*>
1321 NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
1322                                                   EVT VT) const {
1323   if (Constraint.size() == 1) {
1324     switch (Constraint[0]) {
1325     case 'c':
1326       return std::make_pair(0U, &NVPTX::Int8RegsRegClass);
1327     case 'h':
1328       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
1329     case 'r':
1330       return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
1331     case 'l':
1332     case 'N':
1333       return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
1334     case 'f':
1335       return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
1336     case 'd':
1337       return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
1338     }
1339   }
1340   return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
1341 }
1342
1343
1344
1345 /// getFunctionAlignment - Return the Log2 alignment of this function.
1346 unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
1347   return 4;
1348 }