[X86][SSE] Add v16i8/v32i8 multiplication support
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Mon, 27 Apr 2015 07:55:46 +0000 (07:55 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Mon, 27 Apr 2015 07:55:46 +0000 (07:55 +0000)
Patch to allow int8 vectors to be multiplied on the SSE unit instead of being scalarized.

The patch sign extends the i8 lanes to i16, uses the SSE2 pmullw multiplication instruction, then packs the lower byte from each result.

Differential Revision: http://reviews.llvm.org/D9115

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@235837 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp
test/CodeGen/X86/avx2-arith.ll
test/CodeGen/X86/pmul.ll

index b2f06609c6bf0ad272ae6d4fbd341be63a31709b..c4d4a0f6ce73bd3778636cc43008c494a48ba3a4 100644 (file)
@@ -802,6 +802,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setOperationAction(ISD::ADD,                MVT::v8i16, Legal);
     setOperationAction(ISD::ADD,                MVT::v4i32, Legal);
     setOperationAction(ISD::ADD,                MVT::v2i64, Legal);
+    setOperationAction(ISD::MUL,                MVT::v16i8, Custom);
     setOperationAction(ISD::MUL,                MVT::v4i32, Custom);
     setOperationAction(ISD::MUL,                MVT::v2i64, Custom);
     setOperationAction(ISD::UMUL_LOHI,          MVT::v4i32, Custom);
@@ -1122,7 +1123,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
       setOperationAction(ISD::MUL,             MVT::v4i64, Custom);
       setOperationAction(ISD::MUL,             MVT::v8i32, Legal);
       setOperationAction(ISD::MUL,             MVT::v16i16, Legal);
-      // Don't lower v32i8 because there is no 128-bit byte mul
+      setOperationAction(ISD::MUL,             MVT::v32i8, Custom);
 
       setOperationAction(ISD::UMUL_LOHI,       MVT::v8i32, Custom);
       setOperationAction(ISD::SMUL_LOHI,       MVT::v8i32, Custom);
@@ -1171,7 +1172,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
       setOperationAction(ISD::MUL,             MVT::v4i64, Custom);
       setOperationAction(ISD::MUL,             MVT::v8i32, Custom);
       setOperationAction(ISD::MUL,             MVT::v16i16, Custom);
-      // Don't lower v32i8 because there is no 128-bit byte mul
+      setOperationAction(ISD::MUL,             MVT::v32i8, Custom);
     }
 
     // In the customized shift lowering, the legal cases in AVX2 will be
@@ -9894,7 +9895,7 @@ static SDValue lower256BitVectorShuffle(SDValue Op, SDValue V1, SDValue V2,
   int NumV2Elements = std::count_if(Mask.begin(), Mask.end(), [NumElts](int M) {
     return M >= NumElts;
   });
-  
+
   if (NumV2Elements == 1 && Mask[0] >= NumElts)
     if (SDValue Insertion = lowerVectorShuffleAsElementInsertion(
                               DL, VT, V1, V2, Mask, Subtarget, DAG))
@@ -10646,7 +10647,7 @@ SDValue X86TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
         return DAG.getNode(X86ISD::BLENDI, dl, VT, N0, N1Vec, N2);
       }
     }
-    
+
     // Get the desired 128-bit vector chunk.
     SDValue V = Extract128BitVector(N0, IdxVal, DAG, dl);
 
@@ -15908,6 +15909,79 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget *Subtarget,
   SDValue A = Op.getOperand(0);
   SDValue B = Op.getOperand(1);
 
+  // Lower v16i8/v32i8 mul as promotion to v8i16/v16i16 vector
+  // pairs, multiply and truncate.
+  if (VT == MVT::v16i8 || VT == MVT::v32i8) {
+    if (Subtarget->hasInt256()) {
+      if (VT == MVT::v32i8) {
+        MVT SubVT = MVT::getVectorVT(MVT::i8, VT.getVectorNumElements() / 2);
+        SDValue Lo = DAG.getIntPtrConstant(0);
+        SDValue Hi = DAG.getIntPtrConstant(VT.getVectorNumElements() / 2);
+        SDValue ALo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVT, A, Lo);
+        SDValue BLo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVT, B, Lo);
+        SDValue AHi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVT, A, Hi);
+        SDValue BHi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVT, B, Hi);
+        return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
+                           DAG.getNode(ISD::MUL, dl, SubVT, ALo, BLo),
+                           DAG.getNode(ISD::MUL, dl, SubVT, AHi, BHi));
+      }
+
+      MVT ExVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements());
+      return DAG.getNode(
+          ISD::TRUNCATE, dl, VT,
+          DAG.getNode(ISD::MUL, dl, ExVT,
+                      DAG.getNode(ISD::SIGN_EXTEND, dl, ExVT, A),
+                      DAG.getNode(ISD::SIGN_EXTEND, dl, ExVT, B)));
+    }
+
+    assert(VT == MVT::v16i8 &&
+           "Pre-AVX2 support only supports v16i8 multiplication");
+    MVT ExVT = MVT::v8i16;
+
+    // Extract the lo parts and sign extend to i16
+    SDValue ALo, BLo;
+    if (Subtarget->hasSSE41()) {
+      ALo = DAG.getNode(X86ISD::VSEXT, dl, ExVT, A);
+      BLo = DAG.getNode(X86ISD::VSEXT, dl, ExVT, B);
+    } else {
+      const int ShufMask[] = {0, -1, 1, -1, 2, -1, 3, -1,
+                              4, -1, 5, -1, 6, -1, 7, -1};
+      ALo = DAG.getVectorShuffle(VT, dl, A, A, ShufMask);
+      BLo = DAG.getVectorShuffle(VT, dl, B, B, ShufMask);
+      ALo = DAG.getNode(ISD::BITCAST, dl, ExVT, ALo);
+      BLo = DAG.getNode(ISD::BITCAST, dl, ExVT, BLo);
+      ALo = DAG.getNode(ISD::SRA, dl, ExVT, ALo, DAG.getConstant(8, ExVT));
+      BLo = DAG.getNode(ISD::SRA, dl, ExVT, BLo, DAG.getConstant(8, ExVT));
+    }
+
+    // Extract the hi parts and sign extend to i16
+    SDValue AHi, BHi;
+    if (Subtarget->hasSSE41()) {
+      const int ShufMask[] = {8,  9,  10, 11, 12, 13, 14, 15,
+                              -1, -1, -1, -1, -1, -1, -1, -1};
+      AHi = DAG.getVectorShuffle(VT, dl, A, A, ShufMask);
+      BHi = DAG.getVectorShuffle(VT, dl, B, B, ShufMask);
+      AHi = DAG.getNode(X86ISD::VSEXT, dl, ExVT, AHi);
+      BHi = DAG.getNode(X86ISD::VSEXT, dl, ExVT, BHi);
+    } else {
+      const int ShufMask[] = {8,  -1, 9,  -1, 10, -1, 11, -1,
+                              12, -1, 13, -1, 14, -1, 15, -1};
+      AHi = DAG.getVectorShuffle(VT, dl, A, A, ShufMask);
+      BHi = DAG.getVectorShuffle(VT, dl, B, B, ShufMask);
+      AHi = DAG.getNode(ISD::BITCAST, dl, ExVT, AHi);
+      BHi = DAG.getNode(ISD::BITCAST, dl, ExVT, BHi);
+      AHi = DAG.getNode(ISD::SRA, dl, ExVT, AHi, DAG.getConstant(8, ExVT));
+      BHi = DAG.getNode(ISD::SRA, dl, ExVT, BHi, DAG.getConstant(8, ExVT));
+    }
+
+    // Multiply, mask the lower 8bits of the lo/hi results and pack
+    SDValue RLo = DAG.getNode(ISD::MUL, dl, ExVT, ALo, BLo);
+    SDValue RHi = DAG.getNode(ISD::MUL, dl, ExVT, AHi, BHi);
+    RLo = DAG.getNode(ISD::AND, dl, ExVT, RLo, DAG.getConstant(255, ExVT));
+    RHi = DAG.getNode(ISD::AND, dl, ExVT, RHi, DAG.getConstant(255, ExVT));
+    return DAG.getNode(X86ISD::PACKUS, dl, VT, RLo, RHi);
+  }
+
   // Lower v4i32 mul as 2x shuffle, 2x pmuludq, 2x shuffle.
   if (VT == MVT::v4i32) {
     assert(Subtarget->hasSSE2() && !Subtarget->hasSSE41() &&
index 72bdd9d04729a9fc67eefb1f134c70244ee4e9a3..a205be1c0cd6f4b21756376a362ecf3aa5fad171 100644 (file)
@@ -60,6 +60,49 @@ define <16 x i16> @test_vpmullw(<16 x i16> %i, <16 x i16> %j) nounwind readnone
   ret <16 x i16> %x
 }
 
+; CHECK: mul-v16i8
+; CHECK:       # BB#0:
+; CHECK-NEXT:  vpmovsxbw %xmm1, %ymm1
+; CHECK-NEXT:  vpmovsxbw %xmm0, %ymm0
+; CHECK-NEXT:  vpmullw %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:  vextracti128 $1, %ymm0, %xmm1
+; CHECK-NEXT:  vmovdqa {{.*#+}} xmm2 = <0,2,4,6,8,10,12,14,u,u,u,u,u,u,u,u>
+; CHECK-NEXT:  vpshufb %xmm2, %xmm1, %xmm1
+; CHECK-NEXT:  vpshufb %xmm2, %xmm0, %xmm0
+; CHECK-NEXT:  vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; CHECK-NEXT:  vzeroupper
+; CHECK-NEXT:  retq
+define <16 x i8> @mul-v16i8(<16 x i8> %i, <16 x i8> %j) nounwind readnone {
+  %x = mul <16 x i8> %i, %j
+  ret <16 x i8> %x
+}
+
+; CHECK: mul-v32i8
+; CHECK:       # BB#0:
+; CHECK-NEXT:  vextracti128 $1, %ymm1, %xmm2
+; CHECK-NEXT:  vpmovsxbw %xmm2, %ymm2
+; CHECK-NEXT:  vextracti128 $1, %ymm0, %xmm3
+; CHECK-NEXT:  vpmovsxbw %xmm3, %ymm3
+; CHECK-NEXT:  vpmullw %ymm2, %ymm3, %ymm2
+; CHECK-NEXT:  vextracti128 $1, %ymm2, %xmm3
+; CHECK-NEXT:  vmovdqa {{.*#+}} xmm4 = <0,2,4,6,8,10,12,14,u,u,u,u,u,u,u,u>
+; CHECK-NEXT:  vpshufb %xmm4, %xmm3, %xmm3
+; CHECK-NEXT:  vpshufb %xmm4, %xmm2, %xmm2
+; CHECK-NEXT:  vpunpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm3[0]
+; CHECK-NEXT:  vpmovsxbw %xmm1, %ymm1
+; CHECK-NEXT:  vpmovsxbw %xmm0, %ymm0
+; CHECK-NEXT:  vpmullw %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:  vextracti128 $1, %ymm0, %xmm1
+; CHECK-NEXT:  vpshufb %xmm4, %xmm1, %xmm1
+; CHECK-NEXT:  vpshufb %xmm4, %xmm0, %xmm0
+; CHECK-NEXT:  vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; CHECK-NEXT:  vinserti128 $1, %xmm2, %ymm0, %ymm0
+; CHECK-NEXT:  retq
+define <32 x i8> @mul-v32i8(<32 x i8> %i, <32 x i8> %j) nounwind readnone {
+  %x = mul <32 x i8> %i, %j
+  ret <32 x i8> %x
+}
+
 ; CHECK: mul-v4i64
 ; CHECK: vpmuludq %ymm
 ; CHECK-NEXT: vpsrlq $32, %ymm
index 6bfa656b6094dad840bb7fa80545847983dd07c1..83668754d4bf667e8b509dad0f1379bf6f2e6f58 100644 (file)
@@ -1,6 +1,53 @@
 ; RUN: llc < %s -mtriple=x86_64-unknown-unknown | FileCheck %s --check-prefix=ALL --check-prefix=SSE2
 ; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=sse4.1 | FileCheck %s --check-prefix=ALL --check-prefix=SSE41
 
+define <16 x i8> @mul8c(<16 x i8> %i) nounwind  {
+; SSE2-LABEL: mul8c:
+; SSE2:       # BB#0: # %entry
+; SSE2-NEXT:    movdqa {{.*#+}} xmm1 = [117,117,117,117,117,117,117,117,117,117,117,117,117,117,117,117]
+; SSE2-NEXT:    psraw $8, %xmm1
+; SSE2-NEXT:    movdqa %xmm0, %xmm2
+; SSE2-NEXT:    punpckhbw {{.*#+}} xmm2 = xmm2[8,8,9,9,10,10,11,11,12,12,13,13,14,14,15,15]
+; SSE2-NEXT:    psraw $8, %xmm2
+; SSE2-NEXT:    pmullw %xmm1, %xmm2
+; SSE2-NEXT:    movdqa {{.*#+}} xmm3 = [255,255,255,255,255,255,255,255]
+; SSE2-NEXT:    pand %xmm3, %xmm2
+; SSE2-NEXT:    punpcklbw {{.*#+}} xmm0 = xmm0[0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7]
+; SSE2-NEXT:    psraw $8, %xmm0
+; SSE2-NEXT:    pmullw %xmm1, %xmm0
+; SSE2-NEXT:    pand %xmm3, %xmm0
+; SSE2-NEXT:    packuswb %xmm2, %xmm0
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: mul8c:
+; SSE41:       # BB#0: # %entry
+; SSE41-NEXT:    pmovsxbw %xmm0, %xmm1
+; SSE41-NEXT:    pmovsxbw {{.*}}(%rip), %xmm2
+; SSE41-NEXT:    pmullw %xmm2, %xmm1
+; SSE41-NEXT:    movdqa {{.*#+}} xmm3 = [255,255,255,255,255,255,255,255]
+; SSE41-NEXT:    pand %xmm3, %xmm1
+; SSE41-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[2,3,0,1]
+; SSE41-NEXT:    pmovsxbw %xmm0, %xmm0
+; SSE41-NEXT:    pmullw %xmm2, %xmm0
+; SSE41-NEXT:    pand %xmm3, %xmm0
+; SSE41-NEXT:    packuswb %xmm0, %xmm1
+; SSE41-NEXT:    movdqa %xmm1, %xmm0
+; SSE41-NEXT:    retq
+entry:
+  %A = mul <16 x i8> %i, < i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117 >
+  ret <16 x i8> %A
+}
+
+define <8 x i16> @mul16c(<8 x i16> %i) nounwind  {
+; ALL-LABEL: mul16c:
+; ALL:       # BB#0: # %entry
+; ALL-NEXT:    pmullw {{.*}}(%rip), %xmm0
+; ALL-NEXT:    retq
+entry:
+  %A = mul <8 x i16> %i, < i16 117, i16 117, i16 117, i16 117, i16 117, i16 117, i16 117, i16 117 >
+  ret <8 x i16> %A
+}
+
 define <4 x i32> @a(<4 x i32> %i) nounwind  {
 ; SSE2-LABEL: a:
 ; SSE2:       # BB#0: # %entry
@@ -42,6 +89,59 @@ entry:
   ret <2 x i64> %A
 }
 
+define <16 x i8> @mul8(<16 x i8> %i, <16 x i8> %j) nounwind  {
+; SSE2-LABEL: mul8:
+; SSE2:       # BB#0: # %entry
+; SSE2-NEXT:    movdqa %xmm1, %xmm3
+; SSE2-NEXT:    punpcklbw {{.*#+}} xmm3 = xmm3[0],xmm0[0],xmm3[1],xmm0[1],xmm3[2],xmm0[2],xmm3[3],xmm0[3],xmm3[4],xmm0[4],xmm3[5],xmm0[5],xmm3[6],xmm0[6],xmm3[7],xmm0[7]
+; SSE2-NEXT:    psraw $8, %xmm3
+; SSE2-NEXT:    movdqa %xmm0, %xmm2
+; SSE2-NEXT:    punpcklbw {{.*#+}} xmm2 = xmm2[0],xmm0[0],xmm2[1],xmm0[1],xmm2[2],xmm0[2],xmm2[3],xmm0[3],xmm2[4],xmm0[4],xmm2[5],xmm0[5],xmm2[6],xmm0[6],xmm2[7],xmm0[7]
+; SSE2-NEXT:    psraw $8, %xmm2
+; SSE2-NEXT:    pmullw %xmm3, %xmm2
+; SSE2-NEXT:    movdqa {{.*#+}} xmm3 = [255,255,255,255,255,255,255,255]
+; SSE2-NEXT:    pand %xmm3, %xmm2
+; SSE2-NEXT:    punpckhbw {{.*#+}} xmm1 = xmm1[8,8,9,9,10,10,11,11,12,12,13,13,14,14,15,15]
+; SSE2-NEXT:    psraw $8, %xmm1
+; SSE2-NEXT:    punpckhbw {{.*#+}} xmm0 = xmm0[8,8,9,9,10,10,11,11,12,12,13,13,14,14,15,15]
+; SSE2-NEXT:    psraw $8, %xmm0
+; SSE2-NEXT:    pmullw %xmm1, %xmm0
+; SSE2-NEXT:    pand %xmm3, %xmm0
+; SSE2-NEXT:    packuswb %xmm0, %xmm2
+; SSE2-NEXT:    movdqa %xmm2, %xmm0
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: mul8:
+; SSE41:       # BB#0: # %entry
+; SSE41-NEXT:    pmovsxbw %xmm1, %xmm3
+; SSE41-NEXT:    pmovsxbw %xmm0, %xmm2
+; SSE41-NEXT:    pmullw %xmm3, %xmm2
+; SSE41-NEXT:    movdqa {{.*#+}} xmm3 = [255,255,255,255,255,255,255,255]
+; SSE41-NEXT:    pand %xmm3, %xmm2
+; SSE41-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[2,3,0,1]
+; SSE41-NEXT:    pmovsxbw %xmm1, %xmm1
+; SSE41-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[2,3,0,1]
+; SSE41-NEXT:    pmovsxbw %xmm0, %xmm0
+; SSE41-NEXT:    pmullw %xmm1, %xmm0
+; SSE41-NEXT:    pand %xmm3, %xmm0
+; SSE41-NEXT:    packuswb %xmm0, %xmm2
+; SSE41-NEXT:    movdqa %xmm2, %xmm0
+; SSE41-NEXT:    retq
+entry:
+  %A = mul <16 x i8> %i, %j
+  ret <16 x i8> %A
+}
+
+define <8 x i16> @mul16(<8 x i16> %i, <8 x i16> %j) nounwind  {
+; ALL-LABEL: mul16:
+; ALL:       # BB#0: # %entry
+; ALL-NEXT:    pmullw %xmm1, %xmm0
+; ALL-NEXT:    retq
+entry:
+  %A = mul <8 x i16> %i, %j
+  ret <8 x i16> %A
+}
+
 define <4 x i32> @c(<4 x i32> %i, <4 x i32> %j) nounwind  {
 ; SSE2-LABEL: c:
 ; SSE2:       # BB#0: # %entry