[x86] Restore the bitcasts I removed when refactoring this to avoid
authorChandler Carruth <chandlerc@gmail.com>
Sat, 30 May 2015 04:05:11 +0000 (04:05 +0000)
committerChandler Carruth <chandlerc@gmail.com>
Sat, 30 May 2015 04:05:11 +0000 (04:05 +0000)
shifting vectors of bytes as x86 doesn't have direct support for that.

This removes a bunch of redundant masking in the generated code for SSE2
and SSE3.

In order to avoid the really significant code size growth this would
have triggered, I also factored the completely repeatative logic for
shifting and masking into two lambdas which in turn makes all of this
much easier to read IMO.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@238637 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp
test/CodeGen/X86/vector-popcnt-128.ll

index 49be23a22043fe4b8b0bfe86737b0fa6b1a668af..6e37b7355fc74e7a417a67db85df4c69c5ee9c5d 100644 (file)
@@ -17479,7 +17479,6 @@ static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,
          "Only 128-bit vector bitmath lowering supported.");
 
   int VecSize = VT.getSizeInBits();
          "Only 128-bit vector bitmath lowering supported.");
 
   int VecSize = VT.getSizeInBits();
-  int NumElts = VT.getVectorNumElements();
   MVT EltVT = VT.getVectorElementType();
   int Len = EltVT.getSizeInBits();
 
   MVT EltVT = VT.getVectorElementType();
   int Len = EltVT.getSizeInBits();
 
@@ -17490,48 +17489,52 @@ static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,
   // this when we don't have SSSE3 which allows a LUT-based lowering that is
   // much faster, even faster than using native popcnt instructions.
 
   // this when we don't have SSSE3 which allows a LUT-based lowering that is
   // much faster, even faster than using native popcnt instructions.
 
-  SDValue Cst55 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), DL,
-                                  EltVT);
-  SDValue Cst33 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), DL,
-                                  EltVT);
-  SDValue Cst0F = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), DL,
-                                  EltVT);
+  auto GetShift = [&](unsigned OpCode, SDValue V, int Shifter) {
+    MVT VT = V.getSimpleValueType();
+    SmallVector<SDValue, 32> Shifters(
+        VT.getVectorNumElements(),
+        DAG.getConstant(Shifter, DL, VT.getVectorElementType()));
+    return DAG.getNode(OpCode, DL, VT, V,
+                       DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Shifters));
+  };
+  auto GetMask = [&](SDValue V, APInt Mask) {
+    MVT VT = V.getSimpleValueType();
+    SmallVector<SDValue, 32> Masks(
+        VT.getVectorNumElements(),
+        DAG.getConstant(Mask, DL, VT.getVectorElementType()));
+    return DAG.getNode(ISD::AND, DL, VT, V,
+                       DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Masks));
+  };
+
+  // We don't want to incur the implicit masks required to SRL vNi8 vectors on
+  // x86, so set the SRL type to have elements at least i16 wide. This is
+  // correct because all of our SRLs are followed immediately by a mask anyways
+  // that handles any bits that sneak into the high bits of the byte elements.
+  MVT SrlVT = Len > 8 ? VT : MVT::getVectorVT(MVT::i16, VecSize / 16);
 
   SDValue V = Op;
 
   // v = v - ((v >> 1) & 0x55555555...)
 
   SDValue V = Op;
 
   // v = v - ((v >> 1) & 0x55555555...)
-  SmallVector<SDValue, 8> Ones(NumElts, DAG.getConstant(1, DL, EltVT));
-  SDValue OnesV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Ones);
-  SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, V, OnesV);
-
-  SmallVector<SDValue, 8> Mask55(NumElts, Cst55);
-  SDValue M55 = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask55);
-  SDValue And = DAG.getNode(ISD::AND, DL, Srl.getValueType(), Srl, M55);
-
+  SDValue Srl = DAG.getNode(
+      ISD::BITCAST, DL, VT,
+      GetShift(ISD::SRL, DAG.getNode(ISD::BITCAST, DL, SrlVT, V), 1));
+  SDValue And = GetMask(Srl, APInt::getSplat(Len, APInt(8, 0x55)));
   V = DAG.getNode(ISD::SUB, DL, VT, V, And);
 
   // v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
   V = DAG.getNode(ISD::SUB, DL, VT, V, And);
 
   // v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
-  SmallVector<SDValue, 8> Mask33(NumElts, Cst33);
-  SDValue M33 = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask33);
-  SDValue AndLHS = DAG.getNode(ISD::AND, DL, M33.getValueType(), V, M33);
-
-  SmallVector<SDValue, 8> Twos(NumElts, DAG.getConstant(2, DL, EltVT));
-  SDValue TwosV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Twos);
-  Srl = DAG.getNode(ISD::SRL, DL, VT, V, TwosV);
-  SDValue AndRHS = DAG.getNode(ISD::AND, DL, M33.getValueType(), Srl, M33);
-
+  SDValue AndLHS = GetMask(V, APInt::getSplat(Len, APInt(8, 0x33)));
+  Srl = DAG.getNode(
+      ISD::BITCAST, DL, VT,
+      GetShift(ISD::SRL, DAG.getNode(ISD::BITCAST, DL, SrlVT, V), 2));
+  SDValue AndRHS = GetMask(Srl, APInt::getSplat(Len, APInt(8, 0x33)));
   V = DAG.getNode(ISD::ADD, DL, VT, AndLHS, AndRHS);
 
   // v = (v + (v >> 4)) & 0x0F0F0F0F...
   V = DAG.getNode(ISD::ADD, DL, VT, AndLHS, AndRHS);
 
   // v = (v + (v >> 4)) & 0x0F0F0F0F...
-  SmallVector<SDValue, 8> Fours(NumElts, DAG.getConstant(4, DL, EltVT));
-  SDValue FoursV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Fours);
-  Srl = DAG.getNode(ISD::SRL, DL, VT, V, FoursV);
+  Srl = DAG.getNode(
+      ISD::BITCAST, DL, VT,
+      GetShift(ISD::SRL, DAG.getNode(ISD::BITCAST, DL, SrlVT, V), 4));
   SDValue Add = DAG.getNode(ISD::ADD, DL, VT, V, Srl);
   SDValue Add = DAG.getNode(ISD::ADD, DL, VT, V, Srl);
-
-  SmallVector<SDValue, 8> Mask0F(NumElts, Cst0F);
-  SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask0F);
-
-  V = DAG.getNode(ISD::AND, DL, M0F.getValueType(), Add, M0F);
+  V = GetMask(Add, APInt::getSplat(Len, APInt(8, 0x0F)));
 
   // At this point, V contains the byte-wise population count, and we are
   // merely doing a horizontal sum if necessary to get the wider element
 
   // At this point, V contains the byte-wise population count, and we are
   // merely doing a horizontal sum if necessary to get the wider element
@@ -17543,26 +17546,21 @@ static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,
   MVT ByteVT = MVT::getVectorVT(MVT::i8, VecSize / 8);
   MVT ShiftVT = MVT::getVectorVT(MVT::i64, VecSize / 64);
   V = DAG.getNode(ISD::BITCAST, DL, ByteVT, V);
   MVT ByteVT = MVT::getVectorVT(MVT::i8, VecSize / 8);
   MVT ShiftVT = MVT::getVectorVT(MVT::i64, VecSize / 64);
   V = DAG.getNode(ISD::BITCAST, DL, ByteVT, V);
-  SmallVector<SDValue, 8> Csts;
   assert(Len <= 64 && "We don't support element sizes of more than 64 bits!");
   assert(isPowerOf2_32(Len) && "Only power of two element sizes supported!");
   for (int i = Len; i > 8; i /= 2) {
   assert(Len <= 64 && "We don't support element sizes of more than 64 bits!");
   assert(isPowerOf2_32(Len) && "Only power of two element sizes supported!");
   for (int i = Len; i > 8; i /= 2) {
-    Csts.assign(VecSize / 64, DAG.getConstant(i / 2, DL, MVT::i64));
     SDValue Shl = DAG.getNode(
     SDValue Shl = DAG.getNode(
-        ISD::SHL, DL, ShiftVT, DAG.getNode(ISD::BITCAST, DL, ShiftVT, V),
-        DAG.getNode(ISD::BUILD_VECTOR, DL, ShiftVT, Csts));
-    V = DAG.getNode(ISD::ADD, DL, ByteVT, V,
-                    DAG.getNode(ISD::BITCAST, DL, ByteVT, Shl));
+        ISD::BITCAST, DL, ByteVT,
+        GetShift(ISD::SHL, DAG.getNode(ISD::BITCAST, DL, ShiftVT, V), i / 2));
+    V = DAG.getNode(ISD::ADD, DL, ByteVT, V, Shl);
   }
 
   // The high byte now contains the sum of the element bytes. Shift it right
   // (if needed) to make it the low byte.
   V = DAG.getNode(ISD::BITCAST, DL, VT, V);
   }
 
   // The high byte now contains the sum of the element bytes. Shift it right
   // (if needed) to make it the low byte.
   V = DAG.getNode(ISD::BITCAST, DL, VT, V);
-  if (Len > 8) {
-    Csts.assign(NumElts, DAG.getConstant(Len - 8, DL, EltVT));
-    V = DAG.getNode(ISD::SRL, DL, VT, V,
-                    DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Csts));
-  }
+  if (Len > 8)
+    V = GetShift(ISD::SRL, V, Len - 8);
+
   return V;
 }
 
   return V;
 }
 
index dc99fec3d475b649baaf1f5d3e565129681f627f..f55b054deb0694340b7c43c18381b19354938cfe 100644 (file)
@@ -339,21 +339,17 @@ define <16 x i8> @testv16i8(<16 x i8> %in) {
 ; SSE2-NEXT:    movdqa %xmm0, %xmm1
 ; SSE2-NEXT:    psrlw $1, %xmm1
 ; SSE2-NEXT:    pand {{.*}}(%rip), %xmm1
 ; SSE2-NEXT:    movdqa %xmm0, %xmm1
 ; SSE2-NEXT:    psrlw $1, %xmm1
 ; SSE2-NEXT:    pand {{.*}}(%rip), %xmm1
-; SSE2-NEXT:    pand {{.*}}(%rip), %xmm1
 ; SSE2-NEXT:    psubb %xmm1, %xmm0
 ; SSE2-NEXT:    movdqa {{.*#+}} xmm1 = [51,51,51,51,51,51,51,51,51,51,51,51,51,51,51,51]
 ; SSE2-NEXT:    movdqa %xmm0, %xmm2
 ; SSE2-NEXT:    pand %xmm1, %xmm2
 ; SSE2-NEXT:    psrlw $2, %xmm0
 ; SSE2-NEXT:    psubb %xmm1, %xmm0
 ; SSE2-NEXT:    movdqa {{.*#+}} xmm1 = [51,51,51,51,51,51,51,51,51,51,51,51,51,51,51,51]
 ; SSE2-NEXT:    movdqa %xmm0, %xmm2
 ; SSE2-NEXT:    pand %xmm1, %xmm2
 ; SSE2-NEXT:    psrlw $2, %xmm0
-; SSE2-NEXT:    pand {{.*}}(%rip), %xmm0
 ; SSE2-NEXT:    pand %xmm1, %xmm0
 ; SSE2-NEXT:    paddb %xmm2, %xmm0
 ; SSE2-NEXT:    movdqa %xmm0, %xmm1
 ; SSE2-NEXT:    psrlw $4, %xmm1
 ; SSE2-NEXT:    pand %xmm1, %xmm0
 ; SSE2-NEXT:    paddb %xmm2, %xmm0
 ; SSE2-NEXT:    movdqa %xmm0, %xmm1
 ; SSE2-NEXT:    psrlw $4, %xmm1
-; SSE2-NEXT:    movdqa {{.*#+}} xmm2 = [15,15,15,15,15,15,15,15,15,15,15,15,15,15,15,15]
-; SSE2-NEXT:    pand %xmm2, %xmm1
 ; SSE2-NEXT:    paddb %xmm0, %xmm1
 ; SSE2-NEXT:    paddb %xmm0, %xmm1
-; SSE2-NEXT:    pand %xmm2, %xmm1
+; SSE2-NEXT:    pand {{.*}}(%rip), %xmm1
 ; SSE2-NEXT:    movdqa %xmm1, %xmm0
 ; SSE2-NEXT:    retq
 ;
 ; SSE2-NEXT:    movdqa %xmm1, %xmm0
 ; SSE2-NEXT:    retq
 ;
@@ -362,21 +358,17 @@ define <16 x i8> @testv16i8(<16 x i8> %in) {
 ; SSE3-NEXT:    movdqa %xmm0, %xmm1
 ; SSE3-NEXT:    psrlw $1, %xmm1
 ; SSE3-NEXT:    pand {{.*}}(%rip), %xmm1
 ; SSE3-NEXT:    movdqa %xmm0, %xmm1
 ; SSE3-NEXT:    psrlw $1, %xmm1
 ; SSE3-NEXT:    pand {{.*}}(%rip), %xmm1
-; SSE3-NEXT:    pand {{.*}}(%rip), %xmm1
 ; SSE3-NEXT:    psubb %xmm1, %xmm0
 ; SSE3-NEXT:    movdqa {{.*#+}} xmm1 = [51,51,51,51,51,51,51,51,51,51,51,51,51,51,51,51]
 ; SSE3-NEXT:    movdqa %xmm0, %xmm2
 ; SSE3-NEXT:    pand %xmm1, %xmm2
 ; SSE3-NEXT:    psrlw $2, %xmm0
 ; SSE3-NEXT:    psubb %xmm1, %xmm0
 ; SSE3-NEXT:    movdqa {{.*#+}} xmm1 = [51,51,51,51,51,51,51,51,51,51,51,51,51,51,51,51]
 ; SSE3-NEXT:    movdqa %xmm0, %xmm2
 ; SSE3-NEXT:    pand %xmm1, %xmm2
 ; SSE3-NEXT:    psrlw $2, %xmm0
-; SSE3-NEXT:    pand {{.*}}(%rip), %xmm0
 ; SSE3-NEXT:    pand %xmm1, %xmm0
 ; SSE3-NEXT:    paddb %xmm2, %xmm0
 ; SSE3-NEXT:    movdqa %xmm0, %xmm1
 ; SSE3-NEXT:    psrlw $4, %xmm1
 ; SSE3-NEXT:    pand %xmm1, %xmm0
 ; SSE3-NEXT:    paddb %xmm2, %xmm0
 ; SSE3-NEXT:    movdqa %xmm0, %xmm1
 ; SSE3-NEXT:    psrlw $4, %xmm1
-; SSE3-NEXT:    movdqa {{.*#+}} xmm2 = [15,15,15,15,15,15,15,15,15,15,15,15,15,15,15,15]
-; SSE3-NEXT:    pand %xmm2, %xmm1
 ; SSE3-NEXT:    paddb %xmm0, %xmm1
 ; SSE3-NEXT:    paddb %xmm0, %xmm1
-; SSE3-NEXT:    pand %xmm2, %xmm1
+; SSE3-NEXT:    pand {{.*}}(%rip), %xmm1
 ; SSE3-NEXT:    movdqa %xmm1, %xmm0
 ; SSE3-NEXT:    retq
 ;
 ; SSE3-NEXT:    movdqa %xmm1, %xmm0
 ; SSE3-NEXT:    retq
 ;