AVX-512: Fixed masked load / store instruction selection for KNL.
authorElena Demikhovsky <elena.demikhovsky@intel.com>
Mon, 7 Dec 2015 13:39:24 +0000 (13:39 +0000)
committerElena Demikhovsky <elena.demikhovsky@intel.com>
Mon, 7 Dec 2015 13:39:24 +0000 (13:39 +0000)
Patterns were missing for KNL target for <8 x i32>, <8 x float> masked load/store.

This intrinsic comes with all legal types:
<8 x float> @llvm.masked.load.v8f32(<8 x float>* %addr, i32 align, <8 x i1> %mask, <8 x float> %passThru),
but still requires lowering, because VMASKMOVPS, VMASKMOVDQU32 work with 512-bit vectors only.

All data operands should be widened to 512-bit vector.
The mask operand should be widened to v16i1 with zeroes.

Differential Revision: http://reviews.llvm.org/D15265

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@254909 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
lib/Target/X86/X86ISelLowering.cpp
lib/Target/X86/X86InstrAVX512.td
test/CodeGen/X86/masked_memop.ll

index 1fb7b160a67177902e006b29cc96d9aa7b72b62a..8295b2a19dd2c48128112ced97b2974f9eb7e154 100644 (file)
@@ -244,7 +244,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
         Changed = true;
         return LegalizeOp(ExpandStore(Op));
       }
-  } else if (Op.getOpcode() == ISD::MSCATTER)
+  } else if (Op.getOpcode() == ISD::MSCATTER || Op.getOpcode() == ISD::MSTORE)
     HasVectorValue = true;
 
   for (SDNode::value_iterator J = Node->value_begin(), E = Node->value_end();
@@ -344,6 +344,9 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::MSCATTER:
     QueryType = cast<MaskedScatterSDNode>(Node)->getValue().getValueType();
     break;
+  case ISD::MSTORE:
+    QueryType = cast<MaskedStoreSDNode>(Node)->getValue().getValueType();
+    break;
   }
 
   switch (TLI.getOperationAction(Node->getOpcode(), QueryType)) {
index f38ca2956ff325979dc8e1f154a213f07e9f2e97..fa6f5c8be88cb004c2baf553deebedb12be82954 100644 (file)
@@ -1384,6 +1384,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
       setTruncStoreAction(MVT::v2i64, MVT::v2i32, Legal);
       setTruncStoreAction(MVT::v4i32, MVT::v4i8,  Legal);
       setTruncStoreAction(MVT::v4i32, MVT::v4i16, Legal);
+    } else {
+      setOperationAction(ISD::MLOAD,    MVT::v8i32, Custom);
+      setOperationAction(ISD::MLOAD,    MVT::v8f32, Custom);
+      setOperationAction(ISD::MSTORE,   MVT::v8i32, Custom);
+      setOperationAction(ISD::MSTORE,   MVT::v8f32, Custom);
     }
     setOperationAction(ISD::TRUNCATE,           MVT::i1, Custom);
     setOperationAction(ISD::TRUNCATE,           MVT::v16i8, Custom);
@@ -1459,6 +1464,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
 
     setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v8i1,  Custom);
     setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v16i1, Custom);
+    setOperationAction(ISD::INSERT_SUBVECTOR,   MVT::v16i1, Custom);
     setOperationAction(ISD::INSERT_VECTOR_ELT,  MVT::v16i1, Custom);
     setOperationAction(ISD::INSERT_VECTOR_ELT,  MVT::v8i1, Custom);
     setOperationAction(ISD::BUILD_VECTOR,       MVT::v8i1, Custom);
@@ -19685,6 +19691,47 @@ static SDValue LowerFSINCOS(SDValue Op, const X86Subtarget *Subtarget,
   return DAG.getNode(ISD::MERGE_VALUES, dl, Tys, SinVal, CosVal);
 }
 
+/// Widen a vector input to a vector of NVT.  The
+/// input vector must have the same element type as NVT.
+static SDValue ExtendToType(SDValue InOp, MVT NVT, SelectionDAG &DAG,
+                            bool FillWithZeroes = false) {
+  // Check if InOp already has the right width.
+  MVT InVT = InOp.getSimpleValueType();
+  if (InVT == NVT)
+    return InOp;
+
+  if (InOp.isUndef())
+    return DAG.getUNDEF(NVT);
+
+  assert(InVT.getVectorElementType() == NVT.getVectorElementType() &&
+         "input and widen element type must match");
+
+  unsigned InNumElts = InVT.getVectorNumElements();
+  unsigned WidenNumElts = NVT.getVectorNumElements();
+  assert(WidenNumElts > InNumElts && WidenNumElts % InNumElts == 0 &&
+         "Unexpected request for vector widening");
+
+  EVT EltVT = NVT.getVectorElementType();
+
+  SDLoc dl(InOp);
+  if (ISD::isBuildVectorOfConstantSDNodes(InOp.getNode()) ||
+      ISD::isBuildVectorOfConstantFPSDNodes(InOp.getNode())) {
+    SmallVector<SDValue, 16> Ops;
+    for (unsigned i = 0; i < InNumElts; ++i)
+      Ops.push_back(InOp.getOperand(i));
+
+    SDValue FillVal = FillWithZeroes ? DAG.getConstant(0, dl, EltVT) :
+      DAG.getUNDEF(EltVT);
+    for (unsigned i = 0; i < WidenNumElts - InNumElts; ++i)
+      Ops.push_back(FillVal);
+    return DAG.getNode(ISD::BUILD_VECTOR, dl, NVT, Ops);
+  }
+  SDValue FillVal = FillWithZeroes ? DAG.getConstant(0, dl, NVT) :
+    DAG.getUNDEF(NVT);
+  return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, NVT, FillVal,
+                     InOp, DAG.getIntPtrConstant(0, dl));
+}
+
 static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget *Subtarget,
                              SelectionDAG &DAG) {
   assert(Subtarget->hasAVX512() &&
@@ -19714,6 +19761,62 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget *Subtarget,
   return Op;
 }
 
+static SDValue LowerMLOAD(SDValue Op, const X86Subtarget *Subtarget,
+                          SelectionDAG &DAG) {
+
+  MaskedLoadSDNode *N = cast<MaskedLoadSDNode>(Op.getNode());
+  MVT VT = Op.getSimpleValueType();
+  SDValue Mask = N->getMask();
+  SDLoc dl(Op);
+
+  if (Subtarget->hasAVX512() && !Subtarget->hasVLX() &&
+      !VT.is512BitVector() && Mask.getValueType() == MVT::v8i1) {
+    // This operation is legal for targets with VLX, but without
+    // VLX the vector should be widened to 512 bit
+    unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits();
+    MVT WideDataVT = MVT::getVectorVT(VT.getScalarType(), NumEltsInWideVec);
+    MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
+    SDValue Src0 = N->getSrc0();
+    Src0 = ExtendToType(Src0, WideDataVT, DAG);
+    Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
+    SDValue NewLoad = DAG.getMaskedLoad(WideDataVT, dl, N->getChain(),
+                                        N->getBasePtr(), Mask, Src0,
+                                        N->getMemoryVT(), N->getMemOperand(),
+                                        N->getExtensionType());
+
+    SDValue Exract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT,
+                                 NewLoad.getValue(0),
+                                 DAG.getIntPtrConstant(0, dl));
+    SDValue RetOps[] = {Exract, NewLoad.getValue(1)};
+    return DAG.getMergeValues(RetOps, dl);
+  }
+  return Op;
+}
+
+static SDValue LowerMSTORE(SDValue Op, const X86Subtarget *Subtarget,
+                           SelectionDAG &DAG) {
+  MaskedStoreSDNode *N = cast<MaskedStoreSDNode>(Op.getNode());
+  SDValue DataToStore = N->getValue();
+  MVT VT = DataToStore.getSimpleValueType();
+  SDValue Mask = N->getMask();
+  SDLoc dl(Op);
+
+  if (Subtarget->hasAVX512() && !Subtarget->hasVLX() &&
+      !VT.is512BitVector() && Mask.getValueType() == MVT::v8i1) {
+    // This operation is legal for targets with VLX, but without
+    // VLX the vector should be widened to 512 bit
+    unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits();
+    MVT WideDataVT = MVT::getVectorVT(VT.getScalarType(), NumEltsInWideVec);
+    MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
+    DataToStore = ExtendToType(DataToStore, WideDataVT, DAG);
+    Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
+    return DAG.getMaskedStore(N->getChain(), dl, DataToStore, N->getBasePtr(),
+                              Mask, N->getMemoryVT(), N->getMemOperand(),
+                              N->isTruncatingStore());
+  }
+  return Op;
+}
+
 static SDValue LowerMGATHER(SDValue Op, const X86Subtarget *Subtarget,
                             SelectionDAG &DAG) {
   assert(Subtarget->hasAVX512() &&
@@ -19873,6 +19976,8 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::UMAX:
   case ISD::UMIN:               return LowerMINMAX(Op, DAG);
   case ISD::FSINCOS:            return LowerFSINCOS(Op, Subtarget, DAG);
+  case ISD::MLOAD:              return LowerMLOAD(Op, Subtarget, DAG);
+  case ISD::MSTORE:             return LowerMSTORE(Op, Subtarget, DAG);
   case ISD::MGATHER:            return LowerMGATHER(Op, Subtarget, DAG);
   case ISD::MSCATTER:           return LowerMSCATTER(Op, Subtarget, DAG);
   case ISD::GC_TRANSITION_START:
index 60238f6ab23d2b3fb8b5fdc0e9cb3516f0675d98..58206c6acaa6c8375d88d42ba56d45849d1571b9 100644 (file)
@@ -2766,22 +2766,6 @@ def: Pat<(int_x86_avx512_mask_store_pd_512 addr:$ptr, (v8f64 VR512:$src),
          (VMOVAPDZmrk addr:$ptr, (v8i1 (COPY_TO_REGCLASS GR8:$mask, VK8WM)),
             VR512:$src)>;
 
-let Predicates = [HasAVX512, NoVLX] in {
-def: Pat<(X86mstore addr:$ptr, VK8WM:$mask, (v8f32 VR256:$src)),
-         (VMOVUPSZmrk addr:$ptr,
-         (v16i1 (COPY_TO_REGCLASS VK8WM:$mask, VK16WM)),
-         (INSERT_SUBREG (v16f32 (IMPLICIT_DEF)), VR256:$src, sub_ymm))>;
-
-def: Pat<(v8f32 (masked_load addr:$ptr, VK8WM:$mask, undef)),
-         (v8f32 (EXTRACT_SUBREG (v16f32 (VMOVUPSZrmkz
-          (v16i1 (COPY_TO_REGCLASS VK8WM:$mask, VK16WM)), addr:$ptr)), sub_ymm))>;
-
-def: Pat<(v8f32 (masked_load addr:$ptr, VK8WM:$mask, (v8f32 VR256:$src0))),
-         (v8f32 (EXTRACT_SUBREG (v16f32 (VMOVUPSZrmk
-         (INSERT_SUBREG (v16f32 (IMPLICIT_DEF)), VR256:$src0, sub_ymm),
-          (v16i1 (COPY_TO_REGCLASS VK8WM:$mask, VK16WM)), addr:$ptr)), sub_ymm))>;
-}
-
 defm VMOVDQA32 : avx512_alignedload_vl<0x6F, "vmovdqa32", avx512vl_i32_info,
                                        HasAVX512>,
                  avx512_alignedstore_vl<0x7F, "vmovdqa32", avx512vl_i32_info,
@@ -2843,17 +2827,6 @@ def : Pat<(v16i32 (vselect VK16WM:$mask, (v16i32 immAllZerosV),
                            (v16i32 VR512:$src))),
                   (VMOVDQU32Zrrkz (KNOTWrr VK16WM:$mask), VR512:$src)>;
 }
-// NoVLX patterns
-let Predicates = [HasAVX512, NoVLX] in {
-def: Pat<(X86mstore addr:$ptr, VK8WM:$mask, (v8i32 VR256:$src)),
-         (VMOVDQU32Zmrk addr:$ptr,
-         (v16i1 (COPY_TO_REGCLASS VK8WM:$mask, VK16WM)),
-         (INSERT_SUBREG (v16i32 (IMPLICIT_DEF)), VR256:$src, sub_ymm))>;
-
-def: Pat<(v8i32 (masked_load addr:$ptr, VK8WM:$mask, undef)),
-         (v8i32 (EXTRACT_SUBREG (v16i32 (VMOVDQU32Zrmkz
-          (v16i1 (COPY_TO_REGCLASS VK8WM:$mask, VK16WM)), addr:$ptr)), sub_ymm))>;
-}
 
 // Move Int Doubleword to Packed Double Int
 //
index a720054c167ca6d74e9b45c57ab87006b8875ed6..1a9cf008e86900a6d0e60442a381f8c160316de2 100644 (file)
@@ -139,18 +139,55 @@ define <4 x double> @test10(<4 x i32> %trigger, <4 x double>* %addr, <4 x double
   ret <4 x double> %res
 }
 
-; AVX2-LABEL: test11
+; AVX2-LABEL: test11a
 ; AVX2: vmaskmovps
 ; AVX2: vblendvps
 
-; SKX-LABEL: test11
-; SKX: vmovaps {{.*}}{%k1}
-define <8 x float> @test11(<8 x i32> %trigger, <8 x float>* %addr, <8 x float> %dst) {
+; SKX-LABEL: test11a
+; SKX: vmovaps (%rdi), %ymm1 {%k1}
+; AVX512-LABEL: test11a
+; AVX512: kshiftlw $8
+; AVX512: kshiftrw $8
+; AVX512: vmovups (%rdi), %zmm1 {%k1}
+define <8 x float> @test11a(<8 x i32> %trigger, <8 x float>* %addr, <8 x float> %dst) {
   %mask = icmp eq <8 x i32> %trigger, zeroinitializer
   %res = call <8 x float> @llvm.masked.load.v8f32(<8 x float>* %addr, i32 32, <8 x i1>%mask, <8 x float>%dst)
   ret <8 x float> %res
 }
 
+; SKX-LABEL: test11b
+; SKX: vmovdqu32 (%rdi), %ymm1 {%k1}
+; AVX512-LABEL: test11b
+; AVX512: kshiftlw        $8
+; AVX512: kshiftrw        $8
+; AVX512: vmovdqu32 (%rdi), %zmm1 {%k1}
+define <8 x i32> @test11b(<8 x i1> %mask, <8 x i32>* %addr, <8 x i32> %dst) {
+  %res = call <8 x i32> @llvm.masked.load.v8i32(<8 x i32>* %addr, i32 4, <8 x i1>%mask, <8 x i32>%dst)
+  ret <8 x i32> %res
+}
+
+; SKX-LABEL: test11c
+; SKX: vmovaps (%rdi), %ymm0 {%k1} {z}
+; AVX512-LABEL: test11c
+; AVX512: kshiftlw  $8
+; AVX512: kshiftrw  $8
+; AVX512: vmovups (%rdi), %zmm0 {%k1} {z}
+define <8 x float> @test11c(<8 x i1> %mask, <8 x float>* %addr) {
+  %res = call <8 x float> @llvm.masked.load.v8f32(<8 x float>* %addr, i32 32, <8 x i1> %mask, <8 x float> zeroinitializer)
+  ret <8 x float> %res
+}
+
+; SKX-LABEL: test11d
+; SKX: vmovdqu32 (%rdi), %ymm0 {%k1} {z}
+; AVX512-LABEL: test11d
+; AVX512: kshiftlw  $8
+; AVX512: kshiftrw  $8
+; AVX512: vmovdqu32 (%rdi), %zmm0 {%k1} {z}
+define <8 x i32> @test11d(<8 x i1> %mask, <8 x i32>* %addr) {
+  %res = call <8 x i32> @llvm.masked.load.v8i32(<8 x i32>* %addr, i32 4, <8 x i1> %mask, <8 x i32> zeroinitializer)
+  ret <8 x i32> %res
+}
+
 ; AVX2-LABEL: test12
 ; AVX2: vpmaskmovd %ymm
 
@@ -291,6 +328,7 @@ declare void @llvm.masked.store.v16f32(<16 x float>, <16 x float>*, i32, <16 x i
 declare void @llvm.masked.store.v16f32p(<16 x float>*, <16 x float>**, i32, <16 x i1>)
 declare <16 x float> @llvm.masked.load.v16f32(<16 x float>*, i32, <16 x i1>, <16 x float>)
 declare <8 x float> @llvm.masked.load.v8f32(<8 x float>*, i32, <8 x i1>, <8 x float>)
+declare <8 x i32> @llvm.masked.load.v8i32(<8 x i32>*, i32, <8 x i1>, <8 x i32>)
 declare <4 x float> @llvm.masked.load.v4f32(<4 x float>*, i32, <4 x i1>, <4 x float>)
 declare <2 x float> @llvm.masked.load.v2f32(<2 x float>*, i32, <2 x i1>, <2 x float>)
 declare <8 x double> @llvm.masked.load.v8f64(<8 x double>*, i32, <8 x i1>, <8 x double>)