Merging r259346 (with adjustments for r258867):
authorHans Wennborg <hans@hanshq.net>
Tue, 2 Feb 2016 17:41:39 +0000 (17:41 +0000)
committerHans Wennborg <hans@hanshq.net>
Tue, 2 Feb 2016 17:41:39 +0000 (17:41 +0000)
------------------------------------------------------------------------
r259346 | ibreger | 2016-02-01 01:57:15 -0800 (Mon, 01 Feb 2016) | 3 lines

AVX512: fix mask handling for gather/scatter/prefetch intrinsics.

Differential Revision: http://reviews.llvm.org/D16755
------------------------------------------------------------------------

git-svn-id: https://llvm.org/svn/llvm-project/llvm/branches/release_38@259536 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp
test/CodeGen/X86/avx512-gather-scatter-intrin.ll

index 390c4da..c12a3ed 100644 (file)
@@ -16319,6 +16319,11 @@ static SDValue getMaskNode(SDValue Mask, MVT MaskVT,
                            const X86Subtarget *Subtarget,
                            SelectionDAG &DAG, SDLoc dl) {
 
+  if (isAllOnesConstant(Mask))
+    return DAG.getTargetConstant(1, dl, MaskVT);
+  if (X86::isZeroNode(Mask))
+    return DAG.getTargetConstant(0, dl, MaskVT);
+
   if (MaskVT.bitsGT(Mask.getSimpleValueType())) {
     // Mask should be extended
     Mask = DAG.getNode(ISD::ANY_EXTEND, dl,
@@ -17207,26 +17212,14 @@ static SDValue getGatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
   SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl, MVT::i8);
   MVT MaskVT = MVT::getVectorVT(MVT::i1,
                              Index.getSimpleValueType().getVectorNumElements());
-  SDValue MaskInReg;
-  ConstantSDNode *MaskC = dyn_cast<ConstantSDNode>(Mask);
-  if (MaskC)
-    MaskInReg = DAG.getTargetConstant(MaskC->getSExtValue(), dl, MaskVT);
-  else {
-    MVT BitcastVT = MVT::getVectorVT(MVT::i1,
-                                     Mask.getSimpleValueType().getSizeInBits());
 
-    // In case when MaskVT equals v2i1 or v4i1, low 2 or 4 elements
-    // are extracted by EXTRACT_SUBVECTOR.
-    MaskInReg = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MaskVT,
-                            DAG.getBitcast(BitcastVT, Mask),
-                            DAG.getIntPtrConstant(0, dl));
-  }
+  SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
   SDVTList VTs = DAG.getVTList(Op.getValueType(), MaskVT, MVT::Other);
   SDValue Disp = DAG.getTargetConstant(0, dl, MVT::i32);
   SDValue Segment = DAG.getRegister(0, MVT::i32);
   if (Src.getOpcode() == ISD::UNDEF)
     Src = getZeroVector(Op.getSimpleValueType(), Subtarget, DAG, dl);
-  SDValue Ops[] = {Src, MaskInReg, Base, Scale, Index, Disp, Segment, Chain};
+  SDValue Ops[] = {Src, VMask, Base, Scale, Index, Disp, Segment, Chain};
   SDNode *Res = DAG.getMachineNode(Opc, dl, VTs, Ops);
   SDValue RetOps[] = { SDValue(Res, 0), SDValue(Res, 2) };
   return DAG.getMergeValues(RetOps, dl);
@@ -17234,7 +17227,8 @@ static SDValue getGatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
 
 static SDValue getScatterNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
                                SDValue Src, SDValue Mask, SDValue Base,
-                               SDValue Index, SDValue ScaleOp, SDValue Chain) {
+                               SDValue Index, SDValue ScaleOp, SDValue Chain,
+                               const X86Subtarget &Subtarget) {
   SDLoc dl(Op);
   auto *C = cast<ConstantSDNode>(ScaleOp);
   SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl, MVT::i8);
@@ -17242,29 +17236,18 @@ static SDValue getScatterNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
   SDValue Segment = DAG.getRegister(0, MVT::i32);
   MVT MaskVT = MVT::getVectorVT(MVT::i1,
                              Index.getSimpleValueType().getVectorNumElements());
-  SDValue MaskInReg;
-  ConstantSDNode *MaskC = dyn_cast<ConstantSDNode>(Mask);
-  if (MaskC)
-    MaskInReg = DAG.getTargetConstant(MaskC->getSExtValue(), dl, MaskVT);
-  else {
-    MVT BitcastVT = MVT::getVectorVT(MVT::i1,
-                                     Mask.getSimpleValueType().getSizeInBits());
 
-    // In case when MaskVT equals v2i1 or v4i1, low 2 or 4 elements
-    // are extracted by EXTRACT_SUBVECTOR.
-    MaskInReg = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MaskVT,
-                            DAG.getBitcast(BitcastVT, Mask),
-                            DAG.getIntPtrConstant(0, dl));
-  }
+  SDValue VMask = getMaskNode(Mask, MaskVT, &Subtarget, DAG, dl);
   SDVTList VTs = DAG.getVTList(MaskVT, MVT::Other);
-  SDValue Ops[] = {Base, Scale, Index, Disp, Segment, MaskInReg, Src, Chain};
+  SDValue Ops[] = {Base, Scale, Index, Disp, Segment, VMask, Src, Chain};
   SDNode *Res = DAG.getMachineNode(Opc, dl, VTs, Ops);
   return SDValue(Res, 1);
 }
 
 static SDValue getPrefetchNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
                                SDValue Mask, SDValue Base, SDValue Index,
-                               SDValue ScaleOp, SDValue Chain) {
+                               SDValue ScaleOp, SDValue Chain,
+                               const X86Subtarget &Subtarget) {
   SDLoc dl(Op);
   auto *C = cast<ConstantSDNode>(ScaleOp);
   SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl, MVT::i8);
@@ -17272,14 +17255,9 @@ static SDValue getPrefetchNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
   SDValue Segment = DAG.getRegister(0, MVT::i32);
   MVT MaskVT =
     MVT::getVectorVT(MVT::i1, Index.getSimpleValueType().getVectorNumElements());
-  SDValue MaskInReg;
-  ConstantSDNode *MaskC = dyn_cast<ConstantSDNode>(Mask);
-  if (MaskC)
-    MaskInReg = DAG.getTargetConstant(MaskC->getSExtValue(), dl, MaskVT);
-  else
-    MaskInReg = DAG.getBitcast(MaskVT, Mask);
+  SDValue VMask = getMaskNode(Mask, MaskVT, &Subtarget, DAG, dl);
   //SDVTList VTs = DAG.getVTList(MVT::Other);
-  SDValue Ops[] = {MaskInReg, Base, Scale, Index, Disp, Segment, Chain};
+  SDValue Ops[] = {VMask, Base, Scale, Index, Disp, Segment, Chain};
   SDNode *Res = DAG.getMachineNode(Opc, dl, MVT::Other, Ops);
   return SDValue(Res, 0);
 }
@@ -17513,7 +17491,7 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget *Subtarget,
     SDValue Src   = Op.getOperand(5);
     SDValue Scale = Op.getOperand(6);
     return getScatterNode(IntrData->Opc0, Op, DAG, Src, Mask, Base, Index,
-                          Scale, Chain);
+                          Scale, Chain, *Subtarget);
   }
   case PREFETCH: {
     SDValue Hint = Op.getOperand(6);
@@ -17525,7 +17503,8 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget *Subtarget,
     SDValue Index = Op.getOperand(3);
     SDValue Base  = Op.getOperand(4);
     SDValue Scale = Op.getOperand(5);
-    return getPrefetchNode(Opcode, Op, DAG, Mask, Base, Index, Scale, Chain);
+    return getPrefetchNode(Opcode, Op, DAG, Mask, Base, Index, Scale, Chain,
+                           *Subtarget);
   }
   // Read Time Stamp Counter (RDTSC) and Processor ID (RDTSCP).
   case RDTSC: {
index 3bc67cc..9ba1819 100644 (file)
@@ -259,18 +259,22 @@ define void @prefetch(<8 x i64> %ind, i8* %base) {
 ; CHECK:       ## BB#0:
 ; CHECK-NEXT:    kxnorw %k0, %k0, %k1
 ; CHECK-NEXT:    vgatherpf0qps (%rdi,%zmm0,4) {%k1}
+; CHECK-NEXT:    kxorw %k0, %k0, %k1
 ; CHECK-NEXT:    vgatherpf1qps (%rdi,%zmm0,4) {%k1}
+; CHECK-NEXT:    movb $1, %al
+; CHECK-NEXT:    kmovb %eax, %k1
 ; CHECK-NEXT:    vscatterpf0qps (%rdi,%zmm0,2) {%k1}
+; CHECK-NEXT:    movb $120, %al
+; CHECK-NEXT:    kmovb %eax, %k1
 ; CHECK-NEXT:    vscatterpf1qps (%rdi,%zmm0,2) {%k1}
 ; CHECK-NEXT:    retq
   call void @llvm.x86.avx512.gatherpf.qps.512(i8 -1, <8 x i64> %ind, i8* %base, i32 4, i32 0)
-  call void @llvm.x86.avx512.gatherpf.qps.512(i8 -1, <8 x i64> %ind, i8* %base, i32 4, i32 1)
-  call void @llvm.x86.avx512.scatterpf.qps.512(i8 -1, <8 x i64> %ind, i8* %base, i32 2, i32 0)
-  call void @llvm.x86.avx512.scatterpf.qps.512(i8 -1, <8 x i64> %ind, i8* %base, i32 2, i32 1)
+  call void @llvm.x86.avx512.gatherpf.qps.512(i8 0, <8 x i64> %ind, i8* %base, i32 4, i32 1)
+  call void @llvm.x86.avx512.scatterpf.qps.512(i8 1, <8 x i64> %ind, i8* %base, i32 2, i32 0)
+  call void @llvm.x86.avx512.scatterpf.qps.512(i8 120, <8 x i64> %ind, i8* %base, i32 2, i32 1)
   ret void
 }
 
-
 declare <2 x double> @llvm.x86.avx512.gather3div2.df(<2 x double>, i8*, <2 x i64>, i8, i32)
 
 define <2 x double>@test_int_x86_avx512_gather3div2_df(<2 x double> %x0, i8* %x1, <2 x i64> %x2, i8 %x3) {
@@ -790,3 +794,54 @@ define void@test_int_x86_avx512_scattersiv8_si(i8* %x0, i8 %x1, <8 x i32> %x2, <
   ret void
 }
 
+define void @scatter_mask_test(i8* %x0, <8 x i32> %x2, <8 x i32> %x3) {
+; CHECK-LABEL: scatter_mask_test:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    kxnorw %k0, %k0, %k1
+; CHECK-NEXT:    vpscatterdd %ymm1, (%rdi,%ymm0,2) {%k1}
+; CHECK-NEXT:    kxorw %k0, %k0, %k1
+; CHECK-NEXT:    vpscatterdd %ymm1, (%rdi,%ymm0,4) {%k1}
+; CHECK-NEXT:    movb $1, %al
+; CHECK-NEXT:    kmovb %eax, %k1
+; CHECK-NEXT:    vpscatterdd %ymm1, (%rdi,%ymm0,2) {%k1}
+; CHECK-NEXT:    movb $96, %al
+; CHECK-NEXT:    kmovb %eax, %k1
+; CHECK-NEXT:    vpscatterdd %ymm1, (%rdi,%ymm0,4) {%k1}
+; CHECK-NEXT:    retq
+  call void @llvm.x86.avx512.scattersiv8.si(i8* %x0, i8 -1, <8 x i32> %x2, <8 x i32> %x3, i32 2)
+  call void @llvm.x86.avx512.scattersiv8.si(i8* %x0, i8 0, <8 x i32> %x2, <8 x i32> %x3, i32 4)
+  call void @llvm.x86.avx512.scattersiv8.si(i8* %x0, i8 1, <8 x i32> %x2, <8 x i32> %x3, i32 2)
+  call void @llvm.x86.avx512.scattersiv8.si(i8* %x0, i8 96, <8 x i32> %x2, <8 x i32> %x3, i32 4)
+  ret void
+}
+
+define <16 x float> @gather_mask_test(<16 x i32> %ind, <16 x float> %src, i8* %base)  {
+; CHECK-LABEL: gather_mask_test:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    kxnorw %k0, %k0, %k1
+; CHECK-NEXT:    vmovaps %zmm1, %zmm2
+; CHECK-NEXT:    vgatherdps (%rdi,%zmm0,4), %zmm2 {%k1}
+; CHECK-NEXT:    kxorw %k0, %k0, %k1
+; CHECK-NEXT:    vmovaps %zmm1, %zmm3
+; CHECK-NEXT:    vgatherdps (%rdi,%zmm0,4), %zmm3 {%k1}
+; CHECK-NEXT:    movw $1, %ax
+; CHECK-NEXT:    kmovw %eax, %k1
+; CHECK-NEXT:    vmovaps %zmm1, %zmm4
+; CHECK-NEXT:    vgatherdps (%rdi,%zmm0,4), %zmm4 {%k1}
+; CHECK-NEXT:    movw $220, %ax
+; CHECK-NEXT:    kmovw %eax, %k1
+; CHECK-NEXT:    vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
+; CHECK-NEXT:    vaddps %zmm3, %zmm2, %zmm0
+; CHECK-NEXT:    vaddps %zmm4, %zmm1, %zmm1
+; CHECK-NEXT:    vaddps %zmm0, %zmm1, %zmm0
+; CHECK-NEXT:    retq
+  %res = call <16 x float> @llvm.x86.avx512.gather.dps.512 (<16 x float> %src, i8* %base, <16 x i32>%ind, i16 -1, i32 4)
+  %res1 = call <16 x float> @llvm.x86.avx512.gather.dps.512 (<16 x float> %src, i8* %base, <16 x i32>%ind, i16 0, i32 4)
+  %res2 = call <16 x float> @llvm.x86.avx512.gather.dps.512 (<16 x float> %src, i8* %base, <16 x i32>%ind, i16 1, i32 4)
+  %res3 = call <16 x float> @llvm.x86.avx512.gather.dps.512 (<16 x float> %src, i8* %base, <16 x i32>%ind, i16 220, i32 4)
+
+  %res4 = fadd <16 x float> %res, %res1
+  %res5 = fadd <16 x float> %res3, %res2
+  %res6 = fadd <16 x float> %res5, %res4
+  ret <16 x float> %res6
+}