[NVPTX] aligned byte-buffers for vector return types
[oota-llvm.git] / lib / Target / NVPTX / NVPTXISelLowering.cpp
index 645a9bb5c5e3ea4477d0c027be8fc508ea21185b..866017e49db5c6cebdf48c96d4f26f68cfd0b897 100644 (file)
@@ -106,7 +106,7 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, Type *Ty,
 }
 
 // NVPTXTargetLowering Constructor.
-NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
+NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM)
     : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM),
       nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {
 
@@ -203,8 +203,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
 
   // Turn FP extload into load/fextend
+  setLoadExtAction(ISD::EXTLOAD, MVT::f16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
   // Turn FP truncstore into trunc + store.
+  setTruncStoreAction(MVT::f32, MVT::f16, Expand);
+  setTruncStoreAction(MVT::f64, MVT::f16, Expand);
   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
 
   // PTX does not support load / store predicate registers
@@ -1352,7 +1355,12 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     //  .param .align 16 .b8 retval0[<size-in-bytes>], or
     //  .param .b<size-in-bits> retval0
     unsigned resultsz = TD->getTypeAllocSizeInBits(retTy);
-    if (retTy->isSingleValueType()) {
+    // Emit ".param .b<size-in-bits> retval0" instead of byte arrays only for
+    // these three types to match the logic in
+    // NVPTXAsmPrinter::printReturnValStr and NVPTXTargetLowering::getPrototype.
+    // Plus, this behavior is consistent with nvcc's.
+    if (retTy->isFloatingPointTy() || retTy->isIntegerTy() ||
+        retTy->isPointerTy()) {
       // Scalar needs to be at least 32bit wide
       if (resultsz < 32)
         resultsz = 32;
@@ -1448,8 +1456,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       EVT ObjectVT = getValueType(retTy);
       unsigned NumElts = ObjectVT.getVectorNumElements();
       EVT EltVT = ObjectVT.getVectorElementType();
-      assert(nvTM->getTargetLowering()->getNumRegisters(F->getContext(),
-                                                        ObjectVT) == NumElts &&
+      assert(nvTM->getSubtargetImpl()->getTargetLowering()->getNumRegisters(
+                 F->getContext(), ObjectVT) == NumElts &&
              "Vector was not scalarized");
       unsigned sz = EltVT.getSizeInBits();
       bool needTruncate = sz < 8 ? true : false;
@@ -2025,7 +2033,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
 
   const Function *F = MF.getFunction();
   const AttributeSet &PAL = F->getAttributes();
-  const TargetLowering *TLI = DAG.getTarget().getTargetLowering();
+  const TargetLowering *TLI = DAG.getSubtarget().getTargetLowering();
 
   SDValue Root = DAG.getRoot();
   std::vector<SDValue> OutChains;
@@ -2139,7 +2147,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
                                      ISD::SEXTLOAD : ISD::ZEXTLOAD;
             p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, srcAddr,
                                MachinePointerInfo(srcValue), partVT, false,
-                               false, partAlign);
+                               false, false, partAlign);
           } else {
             p = DAG.getLoad(partVT, dl, Root, srcAddr,
                             MachinePointerInfo(srcValue), false, false, false,
@@ -2160,7 +2168,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
         unsigned NumElts = ObjectVT.getVectorNumElements();
         assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
                "Vector was not scalarized");
-        unsigned Ofst = 0;
         EVT EltVT = ObjectVT.getVectorElementType();
 
         // V1 load
@@ -2169,10 +2176,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
           // We only have one element, so just directly load it
           Value *SrcValue = Constant::getNullValue(PointerType::get(
               EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
-          SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
-                                        DAG.getConstant(Ofst, getPointerTy()));
           SDValue P = DAG.getLoad(
-              EltVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
+              EltVT, dl, Root, Arg, MachinePointerInfo(SrcValue), false,
               false, true,
               TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
           if (P.getNode())
@@ -2181,7 +2186,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
           if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
             P = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, P);
           InVals.push_back(P);
-          Ofst += TD->getTypeAllocSize(EltVT.getTypeForEVT(F->getContext()));
           ++InsIdx;
         } else if (NumElts == 2) {
           // V2 load
@@ -2189,10 +2193,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2);
           Value *SrcValue = Constant::getNullValue(PointerType::get(
               VecVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
-          SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
-                                        DAG.getConstant(Ofst, getPointerTy()));
           SDValue P = DAG.getLoad(
-              VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
+              VecVT, dl, Root, Arg, MachinePointerInfo(SrcValue), false,
               false, true,
               TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
           if (P.getNode())
@@ -2210,7 +2212,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
 
           InVals.push_back(Elt0);
           InVals.push_back(Elt1);
-          Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
           InsIdx += 2;
         } else {
           // V4 loads
@@ -2228,6 +2229,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
             VecSize = 2;
           }
           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
+          unsigned Ofst = 0;
           for (unsigned i = 0; i < NumElts; i += VecSize) {
             Value *SrcValue = Constant::getNullValue(
                 PointerType::get(VecVT.getTypeForEVT(F->getContext()),
@@ -2272,6 +2274,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
                                        ISD::SEXTLOAD : ISD::ZEXTLOAD;
         p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, Arg,
                            MachinePointerInfo(srcValue), ObjectVT, false, false,
+                           false,
         TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
       } else {
         p = DAG.getLoad(Ins[InsIdx].VT, dl, Root, Arg,
@@ -3266,16 +3269,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     Info.vol = 0;
     Info.readMem = true;
     Info.writeMem = false;
-
-    // alignment is available as metadata.
-    // Grab it and set the alignment.
-    assert(I.hasMetadataOtherThanDebugLoc() && "Must have alignment metadata");
-    MDNode *AlignMD = I.getMetadata("align");
-    assert(AlignMD && "Must have a non-null MDNode");
-    assert(AlignMD->getNumOperands() == 1 && "Must have a single operand");
-    Value *Align = AlignMD->getOperand(0);
-    int64_t Alignment = cast<ConstantInt>(Align)->getZExtValue();
-    Info.align = Alignment;
+    Info.align = cast<ConstantInt>(I.getArgOperand(1))->getZExtValue();
 
     return true;
   }
@@ -3295,16 +3289,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     Info.vol = 0;
     Info.readMem = true;
     Info.writeMem = false;
-
-    // alignment is available as metadata.
-    // Grab it and set the alignment.
-    assert(I.hasMetadataOtherThanDebugLoc() && "Must have alignment metadata");
-    MDNode *AlignMD = I.getMetadata("align");
-    assert(AlignMD && "Must have a non-null MDNode");
-    assert(AlignMD->getNumOperands() == 1 && "Must have a single operand");
-    Value *Align = AlignMD->getOperand(0);
-    int64_t Alignment = cast<ConstantInt>(Align)->getZExtValue();
-    Info.align = Alignment;
+    Info.align = cast<ConstantInt>(I.getArgOperand(1))->getZExtValue();
 
     return true;
   }
@@ -3863,8 +3848,8 @@ static SDValue PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
   }
   else if (N0.getOpcode() == ISD::FMUL) {
     if (VT == MVT::f32 || VT == MVT::f64) {
-      NVPTXTargetLowering *TLI =
-        (NVPTXTargetLowering *)&DAG.getTargetLoweringInfo();
+      const auto *TLI = static_cast<const NVPTXTargetLowering *>(
+          &DAG.getTargetLoweringInfo());
       if (!TLI->allowFMA(DAG.getMachineFunction(), OptLevel))
         return SDValue();
 
@@ -4050,13 +4035,13 @@ static bool IsMulWideOperandDemotable(SDValue Op,
   if (Op.getOpcode() == ISD::SIGN_EXTEND ||
       Op.getOpcode() == ISD::SIGN_EXTEND_INREG) {
     EVT OrigVT = Op.getOperand(0).getValueType();
-    if (OrigVT.getSizeInBits() == OptSize) {
+    if (OrigVT.getSizeInBits() <= OptSize) {
       S = Signed;
       return true;
     }
   } else if (Op.getOpcode() == ISD::ZERO_EXTEND) {
     EVT OrigVT = Op.getOperand(0).getValueType();
-    if (OrigVT.getSizeInBits() == OptSize) {
+    if (OrigVT.getSizeInBits() <= OptSize) {
       S = Unsigned;
       return true;
     }
@@ -4210,8 +4195,7 @@ static SDValue PerformSHLCombine(SDNode *N,
 
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
-  // FIXME: Get this from the DAG somehow
-  CodeGenOpt::Level OptLevel = CodeGenOpt::Aggressive;
+  CodeGenOpt::Level OptLevel = getTargetMachine().getOptLevel();
   switch (N->getOpcode()) {
     default: break;
     case ISD::ADD: