Make TargetLowering::getPointerTy() taking DataLayout as an argument
[oota-llvm.git] / lib / Target / NVPTX / NVPTXISelLowering.cpp
index e3d25f7936c61101cb1a4e60194a9174478d93fd..26f16e74f9c9c9a7c0c42c829b61ccda7139a38f 100644 (file)
@@ -885,8 +885,9 @@ SDValue
 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
   SDLoc dl(Op);
   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
-  Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
-  return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
+  auto PtrVT = getPointerTy(DAG.getDataLayout());
+  Op = DAG.getTargetGlobalAddress(GV, dl, PtrVT);
+  return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
 }
 
 std::string
@@ -894,7 +895,7 @@ NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
                                   const SmallVectorImpl<ISD::OutputArg> &Outs,
                                   unsigned retAlignment,
                                   const ImmutableCallSite *CS) const {
-
+  auto PtrVT = getPointerTy(*getDataLayout());
   bool isABI = (STI.getSmVersion() >= 20);
   assert(isABI && "Non-ABI compilation is not supported");
   if (!isABI)
@@ -921,7 +922,7 @@ NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
 
       O << ".param .b" << size << " _";
     } else if (isa<PointerType>(retTy)) {
-      O << ".param .b" << getPointerTy().getSizeInBits() << " _";
+      O << ".param .b" << PtrVT.getSizeInBits() << " _";
     } else if ((retTy->getTypeID() == Type::StructTyID) ||
                isa<VectorType>(retTy)) {
       O << ".param .align "
@@ -936,7 +937,6 @@ NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
   O << "_ (";
 
   bool first = true;
-  MVT thePointerTy = getPointerTy();
 
   unsigned OIdx = 0;
   for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
@@ -947,10 +947,10 @@ NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
     first = false;
 
     if (!Outs[OIdx].Flags.isByVal()) {
+      const DataLayout *TD = getDataLayout();
       if (Ty->isAggregateType() || Ty->isVectorTy()) {
         unsigned align = 0;
         const CallInst *CallI = cast<CallInst>(CS->getInstruction());
-        const DataLayout *TD = getDataLayout();
         // +1 because index 0 is reserved for return type alignment
         if (!llvm::getAlign(*CallI, i + 1, align))
           align = TD->getABITypeAlignment(Ty);
@@ -966,9 +966,10 @@ NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
         continue;
       }
        // i8 types in IR will be i16 types in SDAG
-      assert((getValueType(Ty) == Outs[OIdx].VT ||
-             (getValueType(Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
-             "type mismatch between callee prototype and arguments");
+      assert(
+          (getValueType(*TD, Ty) == Outs[OIdx].VT ||
+           (getValueType(*TD, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
+          "type mismatch between callee prototype and arguments");
       // scalar type
       unsigned sz = 0;
       if (isa<IntegerType>(Ty)) {
@@ -976,7 +977,7 @@ NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
         if (sz < 32)
           sz = 32;
       } else if (isa<PointerType>(Ty))
-        sz = thePointerTy.getSizeInBits();
+        sz = PtrVT.getSizeInBits();
       else
         sz = Ty->getPrimitiveSizeInBits();
       O << ".param .b" << sz << " ";
@@ -1137,7 +1138,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         continue;
       }
       if (Ty->isVectorTy()) {
-        EVT ObjectVT = getValueType(Ty);
+        EVT ObjectVT = getValueType(DL, Ty);
         unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
         // declare .param .align <align> .b8 .param<n>[<size>];
         unsigned sz = DL.getTypeAllocSize(Ty);
@@ -1342,9 +1343,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       EVT elemtype = vtparts[j];
       int curOffset = Offsets[j];
       unsigned PartAlign = GreatestCommonDivisor64(ArgAlign, curOffset);
-      SDValue srcAddr =
-          DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
-                      DAG.getConstant(curOffset, dl, getPointerTy()));
+      auto PtrVT = getPointerTy(DAG.getDataLayout());
+      SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, OutVals[OIdx],
+                                    DAG.getConstant(curOffset, dl, PtrVT));
       SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
                                    MachinePointerInfo(), false, false, false,
                                    PartAlign);
@@ -1477,7 +1478,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   // Generate loads from param memory/moves from registers for result
   if (Ins.size() > 0) {
     if (retTy && retTy->isVectorTy()) {
-      EVT ObjectVT = getValueType(retTy);
+      EVT ObjectVT = getValueType(DL, retTy);
       unsigned NumElts = ObjectVT.getVectorNumElements();
       EVT EltVT = ObjectVT.getVectorElementType();
       assert(STI.getTargetLowering()->getNumRegisters(F->getContext(),
@@ -2064,6 +2065,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     SmallVectorImpl<SDValue> &InVals) const {
   MachineFunction &MF = DAG.getMachineFunction();
   const DataLayout &DL = MF.getDataLayout();
+  auto PtrVT = getPointerTy(DL);
 
   const Function *F = MF.getFunction();
   const AttributeSet &PAL = F->getAttributes();
@@ -2129,7 +2131,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
         continue;
       }
       if (Ty->isVectorTy()) {
-        EVT ObjectVT = getValueType(Ty);
+        EVT ObjectVT = getValueType(DL, Ty);
         unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
         for (unsigned parti = 0; parti < NumRegs; ++parti) {
           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
@@ -2161,7 +2163,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
         if (StructType *STy = llvm::dyn_cast<StructType>(Ty))
           aggregateIsPacked = STy->isPacked();
 
-        SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
+        SDValue Arg = getParamSymbol(DAG, idx, PtrVT);
         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
              ++parti) {
           EVT partVT = vtparts[parti];
@@ -2169,8 +2171,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
               PointerType::get(partVT.getTypeForEVT(F->getContext()),
                                llvm::ADDRESS_SPACE_PARAM));
           SDValue srcAddr =
-              DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
-                          DAG.getConstant(offsets[parti], dl, getPointerTy()));
+              DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
+                          DAG.getConstant(offsets[parti], dl, PtrVT));
           unsigned partAlign = aggregateIsPacked
                                    ? 1
                                    : DL.getABITypeAlignment(
@@ -2197,8 +2199,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
         continue;
       }
       if (Ty->isVectorTy()) {
-        EVT ObjectVT = getValueType(Ty);
-        SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
+        EVT ObjectVT = getValueType(DL, Ty);
+        SDValue Arg = getParamSymbol(DAG, idx, PtrVT);
         unsigned NumElts = ObjectVT.getVectorNumElements();
         assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
                "Vector was not scalarized");
@@ -2268,9 +2270,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
             Value *SrcValue = Constant::getNullValue(
                 PointerType::get(VecVT.getTypeForEVT(F->getContext()),
                                  llvm::ADDRESS_SPACE_PARAM));
-            SDValue SrcAddr =
-                DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
-                            DAG.getConstant(Ofst, dl, getPointerTy()));
+            SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
+                                          DAG.getConstant(Ofst, dl, PtrVT));
             SDValue P = DAG.getLoad(
                 VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
                 false, true,
@@ -2297,9 +2298,9 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
         continue;
       }
       // A plain scalar.
-      EVT ObjectVT = getValueType(Ty);
+      EVT ObjectVT = getValueType(DL, Ty);
       // If ABI, load from the param symbol
-      SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
+      SDValue Arg = getParamSymbol(DAG, idx, PtrVT);
       Value *srcValue = Constant::getNullValue(PointerType::get(
           ObjectVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
       SDValue p;
@@ -2329,10 +2330,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     // machine instruction fails because TargetExternalSymbol
     // (not lowered) is target dependent, and CopyToReg assumes
     // the source is lowered.
-    EVT ObjectVT = getValueType(Ty);
+    EVT ObjectVT = getValueType(DL, Ty);
     assert(ObjectVT == Ins[InsIdx].VT &&
            "Ins type did not match function type");
-    SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
+    SDValue Arg = getParamSymbol(DAG, idx, PtrVT);
     SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
     if (p.getNode())
       p.getNode()->setIROrder(idx + 1);
@@ -2370,7 +2371,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   MachineFunction &MF = DAG.getMachineFunction();
   const Function *F = MF.getFunction();
   Type *RetTy = F->getReturnType();
-  const DataLayout *TD = getDataLayout();
+  const DataLayout &TD = DAG.getDataLayout();
 
   bool isABI = (STI.getSmVersion() >= 20);
   assert(isABI && "Non-ABI compilation is not supported");
@@ -2384,7 +2385,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
     assert(NumElts == Outs.size() && "Bad scalarization of return value");
 
     // const_cast can be removed in later LLVM versions
-    EVT EltVT = getValueType(RetTy).getVectorElementType();
+    EVT EltVT = getValueType(TD, RetTy).getVectorElementType();
     bool NeedExtend = false;
     if (EltVT.getSizeInBits() < 16)
       NeedExtend = true;
@@ -2435,7 +2436,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
       EVT VecVT =
           EVT::getVectorVT(F->getContext(), EltVT, VecSize);
       unsigned PerStoreOffset =
-          TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
+          TD.getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
 
       for (unsigned i = 0; i < NumElts; i += VecSize) {
         // Get values
@@ -2509,8 +2510,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                TheValType.getVectorElementType(), TmpVal,
                                DAG.getIntPtrConstant(j, dl));
         EVT TheStoreType = ValVTs[i];
-        if (RetTy->isIntegerTy() &&
-            TD->getTypeAllocSizeInBits(RetTy) < 32) {
+        if (RetTy->isIntegerTy() && TD.getTypeAllocSizeInBits(RetTy) < 32) {
           // The following zero-extension is for integer types only, and
           // specifically not for aggregates.
           TmpVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, TmpVal);
@@ -3291,14 +3291,14 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_ldu_global_i:
   case Intrinsic::nvvm_ldu_global_f:
   case Intrinsic::nvvm_ldu_global_p: {
-
+    auto &DL = I.getModule()->getDataLayout();
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
-      Info.memVT = getValueType(I.getType());
+      Info.memVT = getValueType(DL, I.getType());
     else if(Intrinsic == Intrinsic::nvvm_ldu_global_p)
-      Info.memVT = getPointerTy();
+      Info.memVT = getPointerTy(DL);
     else
-      Info.memVT = getValueType(I.getType());
+      Info.memVT = getValueType(DL, I.getType());
     Info.ptrVal = I.getArgOperand(0);
     Info.offset = 0;
     Info.vol = 0;
@@ -3311,14 +3311,15 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_ldg_global_i:
   case Intrinsic::nvvm_ldg_global_f:
   case Intrinsic::nvvm_ldg_global_p: {
+    auto &DL = I.getModule()->getDataLayout();
 
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     if (Intrinsic == Intrinsic::nvvm_ldg_global_i)
-      Info.memVT = getValueType(I.getType());
+      Info.memVT = getValueType(DL, I.getType());
     else if(Intrinsic == Intrinsic::nvvm_ldg_global_p)
-      Info.memVT = getPointerTy();
+      Info.memVT = getPointerTy(DL);
     else
-      Info.memVT = getValueType(I.getType());
+      Info.memVT = getValueType(DL, I.getType());
     Info.ptrVal = I.getArgOperand(0);
     Info.offset = 0;
     Info.vol = 0;