[NVPTX] Clean up argument lowering code and properly handle alignment for structs...
authorJustin Holewinski <jholewinski@nvidia.com>
Fri, 27 Jun 2014 18:35:44 +0000 (18:35 +0000)
committerJustin Holewinski <jholewinski@nvidia.com>
Fri, 27 Jun 2014 18:35:44 +0000 (18:35 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@211938 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/NVPTX/NVPTXISelLowering.cpp
test/CodeGen/NVPTX/arg-lowering.ll [new file with mode: 0644]

index b324cdb7d667d159a094cee6c96709c0d57f1bae..292e8e173b08a5c2ecbd4c44f9d0ea054b193ae6 100644 (file)
@@ -67,6 +67,17 @@ static bool IsPTXVectorType(MVT VT) {
   }
 }
 
+static uint64_t GCD( int a, int b)
+{
+  if (a < b) std::swap(a,b);
+  while (b > 0) {
+    uint64_t c = b;
+    b = a % b;
+    a = c;
+  }
+  return a;
+}
+
 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
 /// EVTs that compose it.  Unlike ComputeValueVTs, this will break apart vectors
 /// into their primitive components.
@@ -518,26 +529,12 @@ NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
     } else if (isa<PointerType>(retTy)) {
       O << ".param .b" << getPointerTy().getSizeInBits() << " _";
     } else {
-      if ((retTy->getTypeID() == Type::StructTyID) || isa<VectorType>(retTy)) {
-        SmallVector<EVT, 16> vtparts;
-        ComputeValueVTs(*this, retTy, vtparts);
-        unsigned totalsz = 0;
-        for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
-          unsigned elems = 1;
-          EVT elemtype = vtparts[i];
-          if (vtparts[i].isVector()) {
-            elems = vtparts[i].getVectorNumElements();
-            elemtype = vtparts[i].getVectorElementType();
-          }
-          // TODO: no need to loop
-          for (unsigned j = 0, je = elems; j != je; ++j) {
-            unsigned sz = elemtype.getSizeInBits();
-            if (elemtype.isInteger() && (sz < 8))
-              sz = 8;
-            totalsz += sz / 8;
-          }
-        }
-        O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
+      if((retTy->getTypeID() == Type::StructTyID) ||
+         isa<VectorType>(retTy)) {
+        O << ".param .align "
+          << retAlignment
+          << " .b8 _["
+          << getDataLayout()->getTypeAllocSize(retTy) << "]";
       } else {
         assert(false && "Unknown return type");
       }
@@ -706,7 +703,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       if (Ty->isAggregateType()) {
         // aggregate
         SmallVector<EVT, 16> vtparts;
-        ComputeValueVTs(*this, Ty, vtparts);
+        SmallVector<uint64_t, 16> Offsets;
+        ComputePTXValueVTs(*this, Ty, vtparts, &Offsets, 0);
 
         unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
         // declare .param .align <align> .b8 .param<n>[<size>];
@@ -718,34 +716,26 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
                             DeclareParamOps);
         InFlag = Chain.getValue(1);
-        unsigned curOffset = 0;
         for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
-          unsigned elems = 1;
           EVT elemtype = vtparts[j];
-          if (vtparts[j].isVector()) {
-            elems = vtparts[j].getVectorNumElements();
-            elemtype = vtparts[j].getVectorElementType();
-          }
-          for (unsigned k = 0, ke = elems; k != ke; ++k) {
-            unsigned sz = elemtype.getSizeInBits();
-            if (elemtype.isInteger() && (sz < 8))
-              sz = 8;
-            SDValue StVal = OutVals[OIdx];
-            if (elemtype.getSizeInBits() < 16) {
-              StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
-            }
-            SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-            SDValue CopyParamOps[] = { Chain,
-                                       DAG.getConstant(paramCount, MVT::i32),
-                                       DAG.getConstant(curOffset, MVT::i32),
-                                       StVal, InFlag };
-            Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
-                                            CopyParamVTs, CopyParamOps,
-                                            elemtype, MachinePointerInfo());
-            InFlag = Chain.getValue(1);
-            curOffset += sz / 8;
-            ++OIdx;
+          unsigned ArgAlign = GCD(align, Offsets[j]);
+          if (elemtype.isInteger() && (sz < 8))
+            sz = 8;
+          SDValue StVal = OutVals[OIdx];
+          if (elemtype.getSizeInBits() < 16) {
+            StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
           }
+          SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
+          SDValue CopyParamOps[] = { Chain,
+                                     DAG.getConstant(paramCount, MVT::i32),
+                                     DAG.getConstant(Offsets[j], MVT::i32),
+                                     StVal, InFlag };
+          Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
+                                          CopyParamVTs, CopyParamOps,
+                                          elemtype, MachinePointerInfo(),
+                                          ArgAlign);
+          InFlag = Chain.getValue(1);
+          ++OIdx;
         }
         if (vtparts.size() > 0)
           --OIdx;
@@ -930,13 +920,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     }
     // struct or vector
     SmallVector<EVT, 16> vtparts;
+    SmallVector<uint64_t, 16> Offsets;
     const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
     assert(PTy && "Type of a byval parameter should be pointer");
-    ComputeValueVTs(*this, PTy->getElementType(), vtparts);
+    ComputePTXValueVTs(*this, PTy->getElementType(), vtparts, &Offsets, 0);
 
     // declare .param .align <align> .b8 .param<n>[<size>];
     unsigned sz = Outs[OIdx].Flags.getByValSize();
     SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
+    unsigned ArgAlign = Outs[OIdx].Flags.getByValAlign();
     // The ByValAlign in the Outs[OIdx].Flags is alway set at this point,
     // so we don't need to worry about natural alignment or not.
     // See TargetLowering::LowerCallTo().
@@ -948,38 +940,28 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
                         DeclareParamOps);
     InFlag = Chain.getValue(1);
-    unsigned curOffset = 0;
     for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
-      unsigned elems = 1;
       EVT elemtype = vtparts[j];
-      if (vtparts[j].isVector()) {
-        elems = vtparts[j].getVectorNumElements();
-        elemtype = vtparts[j].getVectorElementType();
+      int curOffset = Offsets[j];
+      unsigned PartAlign = GCD(ArgAlign, curOffset);
+      SDValue srcAddr =
+          DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
+                      DAG.getConstant(curOffset, getPointerTy()));
+      SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
+                                   MachinePointerInfo(), false, false, false,
+                                   PartAlign);
+      if (elemtype.getSizeInBits() < 16) {
+        theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
       }
-      for (unsigned k = 0, ke = elems; k != ke; ++k) {
-        unsigned sz = elemtype.getSizeInBits();
-        if (elemtype.isInteger() && (sz < 8))
-          sz = 8;
-        SDValue srcAddr =
-            DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
-                        DAG.getConstant(curOffset, getPointerTy()));
-        SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
-                                     MachinePointerInfo(), false, false, false,
-                                     0);
-        if (elemtype.getSizeInBits() < 16) {
-          theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
-        }
-        SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-        SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
-                                   DAG.getConstant(curOffset, MVT::i32), theVal,
-                                   InFlag };
-        Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
-                                        CopyParamOps, elemtype,
-                                        MachinePointerInfo());
+      SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
+      SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
+                                 DAG.getConstant(curOffset, MVT::i32), theVal,
+                                 InFlag };
+      Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
+                                      CopyParamOps, elemtype,
+                                      MachinePointerInfo());
 
-        InFlag = Chain.getValue(1);
-        curOffset += sz / 8;
-      }
+      InFlag = Chain.getValue(1);
     }
     ++paramCount;
   }
@@ -1088,7 +1070,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
   // Generate loads from param memory/moves from registers for result
   if (Ins.size() > 0) {
-    unsigned resoffset = 0;
     if (retTy && retTy->isVectorTy()) {
       EVT ObjectVT = getValueType(retTy);
       unsigned NumElts = ObjectVT.getVectorNumElements();
@@ -1097,14 +1078,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                                         ObjectVT) == NumElts &&
              "Vector was not scalarized");
       unsigned sz = EltVT.getSizeInBits();
-      bool needTruncate = sz < 16 ? true : false;
+      bool needTruncate = sz < 8 ? true : false;
 
       if (NumElts == 1) {
         // Just a simple load
         SmallVector<EVT, 4> LoadRetVTs;
-        if (needTruncate) {
-          // If loading i1 result, generate
-          //   load i16
+        if (EltVT == MVT::i1 || EltVT == MVT::i8) {
+          // If loading i1/i8 result, generate
+          //   load.b8 i16
+          //   if i1
           //   trunc i16 to i1
           LoadRetVTs.push_back(MVT::i16);
         } else
@@ -1128,9 +1110,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       } else if (NumElts == 2) {
         // LoadV2
         SmallVector<EVT, 4> LoadRetVTs;
-        if (needTruncate) {
-          // If loading i1 result, generate
-          //   load i16
+        if (EltVT == MVT::i1 || EltVT == MVT::i8) {
+          // If loading i1/i8 result, generate
+          //   load.b8 i16
+          //   if i1
           //   trunc i16 to i1
           LoadRetVTs.push_back(MVT::i16);
           LoadRetVTs.push_back(MVT::i16);
@@ -1173,9 +1156,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
         for (unsigned i = 0; i < NumElts; i += VecSize) {
           SmallVector<EVT, 8> LoadRetVTs;
-          if (needTruncate) {
-            // If loading i1 result, generate
-            //   load i16
+          if (EltVT == MVT::i1 || EltVT == MVT::i8) {
+            // If loading i1/i8 result, generate
+            //   load.b8 i16
+            //   if i1
             //   trunc i16 to i1
             for (unsigned j = 0; j < VecSize; ++j)
               LoadRetVTs.push_back(MVT::i16);
@@ -1214,10 +1198,13 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       }
     } else {
       SmallVector<EVT, 16> VTs;
-      ComputePTXValueVTs(*this, retTy, VTs);
+      SmallVector<uint64_t, 16> Offsets;
+      ComputePTXValueVTs(*this, retTy, VTs, &Offsets, 0);
       assert(VTs.size() == Ins.size() && "Bad value decomposition");
+      unsigned RetAlign = getArgumentAlignment(Callee, CS, retTy, 0);
       for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
         unsigned sz = VTs[i].getSizeInBits();
+        unsigned AlignI = GCD(RetAlign, Offsets[i]);
         bool needTruncate = sz < 8 ? true : false;
         if (VTs[i].isInteger() && (sz < 8))
           sz = 8;
@@ -1243,19 +1230,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         SmallVector<SDValue, 4> LoadRetOps;
         LoadRetOps.push_back(Chain);
         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
-        LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32));
+        LoadRetOps.push_back(DAG.getConstant(Offsets[i], MVT::i32));
         LoadRetOps.push_back(InFlag);
         SDValue retval = DAG.getMemIntrinsicNode(
             NVPTXISD::LoadParam, dl,
             DAG.getVTList(LoadRetVTs), LoadRetOps,
-            TheLoadType, MachinePointerInfo());
+            TheLoadType, MachinePointerInfo(), AlignI);
         Chain = retval.getValue(1);
         InFlag = retval.getValue(2);
         SDValue Ret0 = retval.getValue(0);
         if (needTruncate)
           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0);
         InVals.push_back(Ret0);
-        resoffset += sz / 8;
       }
     }
   }
diff --git a/test/CodeGen/NVPTX/arg-lowering.ll b/test/CodeGen/NVPTX/arg-lowering.ll
new file mode 100644 (file)
index 0000000..f7b8a14
--- /dev/null
@@ -0,0 +1,13 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
+
+; CHECK: .visible .func  (.param .align 16 .b8 func_retval0[16]) foo0(
+; CHECK:          .param .align 4 .b8 foo0_param_0[8]
+define <4 x float> @foo0({float, float} %arg0) {
+  ret <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>
+}
+
+; CHECK: .visible .func  (.param .align 8 .b8 func_retval0[8]) foo1(
+; CHECK:          .param .align 8 .b8 foo1_param_0[16]
+define <2 x float> @foo1({float, float, i64} %arg0) {
+  ret <2 x float> <float 1.0, float 1.0>
+}