[NVPTX] Add support for vectorized function return values
authorJustin Holewinski <jholewinski@nvidia.com>
Fri, 28 Jun 2013 17:57:55 +0000 (17:57 +0000)
committerJustin Holewinski <jholewinski@nvidia.com>
Fri, 28 Jun 2013 17:57:55 +0000 (17:57 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@185173 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/NVPTX/NVPTXISelLowering.cpp
test/CodeGen/NVPTX/vector-args.ll

index 42bfab148c92dfd8b43530108c578a48f8ff9bd6..9679b05ab7b2061d31ddbc7a116755bb05dbbf6e 100644 (file)
@@ -1338,37 +1338,147 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
 }
 
 
-SDValue NVPTXTargetLowering::LowerReturn(
-    SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
-    const SmallVectorImpl<ISD::OutputArg> &Outs,
-    const SmallVectorImpl<SDValue> &OutVals, SDLoc dl,
-    SelectionDAG &DAG) const {
+SDValue
+NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
+                                 bool isVarArg,
+                                 const SmallVectorImpl<ISD::OutputArg> &Outs,
+                                 const SmallVectorImpl<SDValue> &OutVals,
+                                 SDLoc dl, SelectionDAG &DAG) const {
+  MachineFunction &MF = DAG.getMachineFunction();
+  const Function *F = MF.getFunction();
+  const Type *RetTy = F->getReturnType();
+  const DataLayout *TD = getDataLayout();
 
   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
+  assert(isABI && "Non-ABI compilation is not supported");
+  if (!isABI)
+    return Chain;
 
-  unsigned sizesofar = 0;
-  unsigned idx = 0;
-  for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
-    SDValue theVal = OutVals[i];
-    EVT theValType = theVal.getValueType();
-    unsigned numElems = 1;
-    if (theValType.isVector())
-      numElems = theValType.getVectorNumElements();
-    for (unsigned j = 0, je = numElems; j != je; ++j) {
-      SDValue tmpval = theVal;
-      if (theValType.isVector())
-        tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
-                             theValType.getVectorElementType(), tmpval,
-                             DAG.getIntPtrConstant(j));
-      Chain = DAG.getNode(
-          isABI ? NVPTXISD::StoreRetval : NVPTXISD::MoveToRetval, dl,
-          MVT::Other, Chain, DAG.getConstant(isABI ? sizesofar : idx, MVT::i32),
-          tmpval);
+  if (const VectorType *VTy = dyn_cast<const VectorType>(RetTy)) {
+    // If we have a vector type, the OutVals array will be the scalarized
+    // components and we have combine them into 1 or more vector stores.
+    unsigned NumElts = VTy->getNumElements();
+    assert(NumElts == Outs.size() && "Bad scalarization of return value");
+
+    // V1 store
+    if (NumElts == 1) {
+      SDValue StoreVal = OutVals[0];
+      // We only have one element, so just directly store it
+      if (StoreVal.getValueType().getSizeInBits() < 8)
+        StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
+      Chain = DAG.getNode(NVPTXISD::StoreRetval, dl, MVT::Other, Chain,
+                          DAG.getConstant(0, MVT::i32), StoreVal);
+    } else if (NumElts == 2) {
+      // V2 store
+      SDValue StoreVal0 = OutVals[0];
+      SDValue StoreVal1 = OutVals[1];
+
+      if (StoreVal0.getValueType().getSizeInBits() < 8) {
+        StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal0);
+        StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal1);
+      }
+
+      Chain = DAG.getNode(NVPTXISD::StoreRetvalV2, dl, MVT::Other, Chain,
+                          DAG.getConstant(0, MVT::i32), StoreVal0, StoreVal1);
+    } else {
+      // V4 stores
+      // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the
+      // vector will be expanded to a power of 2 elements, so we know we can
+      // always round up to the next multiple of 4 when creating the vector
+      // stores.
+      // e.g.  4 elem => 1 st.v4
+      //       6 elem => 2 st.v4
+      //       8 elem => 2 st.v4
+      //      11 elem => 3 st.v4
+
+      unsigned VecSize = 4;
+      if (OutVals[0].getValueType().getSizeInBits() == 64)
+        VecSize = 2;
+
+      unsigned Offset = 0;
+
+      EVT VecVT =
+          EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize);
+      unsigned PerStoreOffset =
+          TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
+
+      bool Extend = false;
+      if (OutVals[0].getValueType().getSizeInBits() < 8)
+        Extend = true;
+
+      for (unsigned i = 0; i < NumElts; i += VecSize) {
+        // Get values
+        SDValue StoreVal;
+        SmallVector<SDValue, 8> Ops;
+        Ops.push_back(Chain);
+        Ops.push_back(DAG.getConstant(Offset, MVT::i32));
+        unsigned Opc = NVPTXISD::StoreRetvalV2;
+        EVT ExtendedVT = (Extend) ? MVT::i8 : OutVals[0].getValueType();
+
+        StoreVal = OutVals[i];
+        if (Extend)
+          StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
+        Ops.push_back(StoreVal);
+
+        if (i + 1 < NumElts) {
+          StoreVal = OutVals[i + 1];
+          if (Extend)
+            StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
+        } else {
+          StoreVal = DAG.getUNDEF(ExtendedVT);
+        }
+        Ops.push_back(StoreVal);
+
+        if (VecSize == 4) {
+          Opc = NVPTXISD::StoreRetvalV4;
+          if (i + 2 < NumElts) {
+            StoreVal = OutVals[i + 2];
+            if (Extend)
+              StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
+          } else {
+            StoreVal = DAG.getUNDEF(ExtendedVT);
+          }
+          Ops.push_back(StoreVal);
+
+          if (i + 3 < NumElts) {
+            StoreVal = OutVals[i + 3];
+            if (Extend)
+              StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
+          } else {
+            StoreVal = DAG.getUNDEF(ExtendedVT);
+          }
+          Ops.push_back(StoreVal);
+        }
+
+        Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size());
+        Offset += PerStoreOffset;
+      }
+    }
+  } else {
+    unsigned sizesofar = 0;
+    for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
+      SDValue theVal = OutVals[i];
+      EVT theValType = theVal.getValueType();
+      unsigned numElems = 1;
       if (theValType.isVector())
-        sizesofar += theValType.getVectorElementType().getStoreSizeInBits() / 8;
-      else
-        sizesofar += theValType.getStoreSizeInBits() / 8;
-      ++idx;
+        numElems = theValType.getVectorNumElements();
+      for (unsigned j = 0, je = numElems; j != je; ++j) {
+        SDValue tmpval = theVal;
+        if (theValType.isVector())
+          tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
+                               theValType.getVectorElementType(), tmpval,
+                               DAG.getIntPtrConstant(j));
+        EVT theStoreType = tmpval.getValueType();
+        if (theStoreType.getSizeInBits() < 8)
+          tmpval = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, tmpval);
+        Chain = DAG.getNode(NVPTXISD::StoreRetval, dl, MVT::Other, Chain,
+                            DAG.getConstant(sizesofar, MVT::i32), tmpval);
+        if (theValType.isVector())
+          sizesofar +=
+              theValType.getVectorElementType().getStoreSizeInBits() / 8;
+        else
+          sizesofar += theValType.getStoreSizeInBits() / 8;
+      }
     }
   }
 
index e480b086075d5546facc081979bd7b4da59ca98f..c6c8e73bf83ece0e3edf0893b823ed2ed193d479 100644 (file)
@@ -23,3 +23,13 @@ define float @bar(<4 x float> %a) {
   %t4 = fadd float %t2, %t3
   ret float %t4
 }
+
+
+define <4 x float> @baz(<4 x float> %a) {
+; CHECK: .func  (.param .align 16 .b8 func_retval0[16]) baz
+; CHECK: .param .align 16 .b8 baz_param_0[16]
+; CHECK: ld.param.v4.f32 {%f{{[0-9]+}}, %f{{[0-9]+}}, %f{{[0-9]+}}, %f{{[0-9]+}}}
+; CHECK: st.param.v4.f32 [func_retval0+0], {%f{{[0-9]+}}, %f{{[0-9]+}}, %f{{[0-9]+}}, %f{{[0-9]+}}}
+  %t1 = fmul <4 x float> %a, %a
+  ret <4 x float> %t1
+}