[x86] Split out the horizontal byte sum lowering component of the LUT
authorChandler Carruth <chandlerc@gmail.com>
Sat, 30 May 2015 09:46:16 +0000 (09:46 +0000)
committerChandler Carruth <chandlerc@gmail.com>
Sat, 30 May 2015 09:46:16 +0000 (09:46 +0000)
lowering into a helper function.

NFC.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@238650 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp

index 6c34efc..ef409d3 100644 (file)
@@ -17290,74 +17290,34 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget *Subtarget,
   return SDValue();
 }
 
-static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, SDLoc DL,
-                                  const X86Subtarget *Subtarget,
-                                  SelectionDAG &DAG) {
-  EVT VT = Op.getValueType();
-  MVT EltVT = VT.getVectorElementType().getSimpleVT();
+/// Compute the horizontal sum of bytes in V for the elements of VT.
+///
+/// Requires V to be a byte vector and VT to be an integer vector type with
+/// wider elements than V's type. The width of the elements of VT determines
+/// how many bytes of V are summed horizontally to produce each element of the
+/// result.
+static SDValue LowerHorizontalByteSum(SDValue V, MVT VT,
+                                      const X86Subtarget *Subtarget,
+                                      SelectionDAG &DAG) {
+  SDLoc DL(V);
+  MVT ByteVecVT = V.getSimpleValueType();
+  MVT EltVT = VT.getVectorElementType();
+  int NumElts = VT.getVectorNumElements();
+  assert(ByteVecVT.getVectorElementType() == MVT::i8 &&
+         "Expected value to have byte element type.");
+  assert(EltVT != MVT::i8 &&
+         "Horizontal byte sum only makes sense for wider elements!");
   unsigned VecSize = VT.getSizeInBits();
-
-  // Implement a lookup table in register by using an algorithm based on:
-  // http://wm.ite.pl/articles/sse-popcount.html
-  //
-  // The general idea is that every lower byte nibble in the input vector is an
-  // index into a in-register pre-computed pop count table. We then split up the
-  // input vector in two new ones: (1) a vector with only the shifted-right
-  // higher nibbles for each byte and (2) a vector with the lower nibbles (and
-  // masked out higher ones) for each byte. PSHUB is used separately with both
-  // to index the in-register table. Next, both are added and the result is a
-  // i8 vector where each element contains the pop count for input byte.
-  //
-  // To obtain the pop count for elements != i8, we follow up with the same
-  // approach and use additional tricks as described below.
-  //
-  const int LUT[16] = {/* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
-                       /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
-                       /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
-                       /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4};
-
-  int NumByteElts = VecSize / 8;
-  MVT ByteVecVT = MVT::getVectorVT(MVT::i8, NumByteElts);
-  SDValue In = DAG.getBitcast(ByteVecVT, Op);
-  SmallVector<SDValue, 16> LUTVec;
-  for (int i = 0; i < NumByteElts; ++i)
-    LUTVec.push_back(DAG.getConstant(LUT[i % 16], DL, MVT::i8));
-  SDValue InRegLUT = DAG.getNode(ISD::BUILD_VECTOR, DL, ByteVecVT, LUTVec);
-  SmallVector<SDValue, 16> Mask0F(NumByteElts,
-                                  DAG.getConstant(0x0F, DL, MVT::i8));
-  SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, DL, ByteVecVT, Mask0F);
-
-  // High nibbles
-  SmallVector<SDValue, 16> Four(NumByteElts, DAG.getConstant(4, DL, MVT::i8));
-  SDValue FourV = DAG.getNode(ISD::BUILD_VECTOR, DL, ByteVecVT, Four);
-  SDValue HighNibbles = DAG.getNode(ISD::SRL, DL, ByteVecVT, In, FourV);
-
-  // Low nibbles
-  SDValue LowNibbles = DAG.getNode(ISD::AND, DL, ByteVecVT, In, M0F);
-
-  // The input vector is used as the shuffle mask that index elements into the
-  // LUT. After counting low and high nibbles, add the vector to obtain the
-  // final pop count per i8 element.
-  SDValue HighPopCnt =
-      DAG.getNode(X86ISD::PSHUFB, DL, ByteVecVT, InRegLUT, HighNibbles);
-  SDValue LowPopCnt =
-      DAG.getNode(X86ISD::PSHUFB, DL, ByteVecVT, InRegLUT, LowNibbles);
-  SDValue PopCnt = DAG.getNode(ISD::ADD, DL, ByteVecVT, HighPopCnt, LowPopCnt);
-
-  if (EltVT == MVT::i8)
-    return PopCnt;
+  assert(ByteVecVT.getSizeInBits() == VecSize && "Cannot change vector size!");
 
   // PSADBW instruction horizontally add all bytes and leave the result in i64
   // chunks, thus directly computes the pop count for v2i64 and v4i64.
   if (EltVT == MVT::i64) {
     SDValue Zeros = getZeroVector(ByteVecVT, Subtarget, DAG, DL);
-    PopCnt = DAG.getNode(X86ISD::PSADBW, DL, ByteVecVT, PopCnt, Zeros);
-    return DAG.getBitcast(VT, PopCnt);
+    V = DAG.getNode(X86ISD::PSADBW, DL, ByteVecVT, V, Zeros);
+    return DAG.getBitcast(VT, V);
   }
 
-  int NumI64Elts = VecSize / 64;
-  MVT VecI64VT = MVT::getVectorVT(MVT::i64, NumI64Elts);
-
   if (EltVT == MVT::i32) {
     // We unpack the low half and high half into i32s interleaved with zeros so
     // that we can use PSADBW to horizontally sum them. The most useful part of
@@ -17365,8 +17325,8 @@ static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, SDLoc DL,
     // two v2i64 vectors which concatenated are the 4 population counts. We can
     // then use PACKUSWB to shrink and concatenate them into a v4i32 again.
     SDValue Zeros = getZeroVector(VT, Subtarget, DAG, DL);
-    SDValue Low = DAG.getNode(X86ISD::UNPCKL, DL, VT, PopCnt, Zeros);
-    SDValue High = DAG.getNode(X86ISD::UNPCKH, DL, VT, PopCnt, Zeros);
+    SDValue Low = DAG.getNode(X86ISD::UNPCKL, DL, VT, V, Zeros);
+    SDValue High = DAG.getNode(X86ISD::UNPCKH, DL, VT, V, Zeros);
 
     // Do the horizontal sums into two v2i64s.
     Zeros = getZeroVector(ByteVecVT, Subtarget, DAG, DL);
@@ -17377,11 +17337,11 @@ static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, SDLoc DL,
 
     // Merge them together.
     MVT ShortVecVT = MVT::getVectorVT(MVT::i16, VecSize / 16);
-    PopCnt = DAG.getNode(X86ISD::PACKUS, DL, ByteVecVT,
-                         DAG.getBitcast(ShortVecVT, Low),
-                         DAG.getBitcast(ShortVecVT, High));
+    V = DAG.getNode(X86ISD::PACKUS, DL, ByteVecVT,
+                    DAG.getBitcast(ShortVecVT, Low),
+                    DAG.getBitcast(ShortVecVT, High));
 
-    return DAG.getBitcast(VT, PopCnt);
+    return DAG.getBitcast(VT, V);
   }
 
   // To obtain pop count for each i16 element, shuffle the byte pop count to get
@@ -17403,8 +17363,8 @@ static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, SDLoc DL,
 
   // We can't use PSHUFB across lanes, so do the shuffle and sum inside each
   // 128-bit lane, and then collapse the result.
-  int NumLanes = NumByteElts / 16;
-  assert(NumByteElts % 16 == 0 && "Must have 16-byte multiple vectors!");
+  int NumLanes = VecSize / 128;
+  assert(VecSize % 128 == 0 && "Must have 16-byte multiple vectors!");
   for (int i = 0; i < NumLanes; ++i) {
     for (int j = 0; j < 8; ++j) {
       MaskA.push_back(i * 16 + j * 2);
@@ -17414,33 +17374,95 @@ static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, SDLoc DL,
     MaskB.append((size_t)8, -1);
   }
 
-  SDValue ShuffA = DAG.getVectorShuffle(ByteVecVT, DL, PopCnt, Undef, MaskA);
-  SDValue ShuffB = DAG.getVectorShuffle(ByteVecVT, DL, PopCnt, Undef, MaskB);
-  PopCnt = DAG.getNode(ISD::ADD, DL, ByteVecVT, ShuffA, ShuffB);
+  SDValue ShuffA = DAG.getVectorShuffle(ByteVecVT, DL, V, Undef, MaskA);
+  SDValue ShuffB = DAG.getVectorShuffle(ByteVecVT, DL, V, Undef, MaskB);
+  V = DAG.getNode(ISD::ADD, DL, ByteVecVT, ShuffA, ShuffB);
 
   SmallVector<int, 4> Mask;
   for (int i = 0; i < NumLanes; ++i)
     Mask.push_back(2 * i);
   Mask.append((size_t)NumLanes, -1);
 
-  PopCnt = DAG.getBitcast(VecI64VT, PopCnt);
-  PopCnt =
-      DAG.getVectorShuffle(VecI64VT, DL, PopCnt, DAG.getUNDEF(VecI64VT), Mask);
-  PopCnt = DAG.getBitcast(ByteVecVT, PopCnt);
+  int NumI64Elts = VecSize / 64;
+  MVT VecI64VT = MVT::getVectorVT(MVT::i64, NumI64Elts);
+
+  V = DAG.getBitcast(VecI64VT, V);
+  V = DAG.getVectorShuffle(VecI64VT, DL, V, DAG.getUNDEF(VecI64VT), Mask);
+  V = DAG.getBitcast(ByteVecVT, V);
 
   // Zero extend i8s into i16 elts
   SmallVector<int, 16> ZExtInRegMask;
-  for (int i = 0; i < NumByteElts / 2; ++i) {
+  for (int i = 0; i < NumElts; ++i) {
     ZExtInRegMask.push_back(i);
-    ZExtInRegMask.push_back(NumByteElts);
+    ZExtInRegMask.push_back(2 * NumElts);
   }
 
   return DAG.getBitcast(
-      VT, DAG.getVectorShuffle(ByteVecVT, DL, PopCnt,
+      VT, DAG.getVectorShuffle(ByteVecVT, DL, V,
                                getZeroVector(ByteVecVT, Subtarget, DAG, DL),
                                ZExtInRegMask));
 }
 
+static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, SDLoc DL,
+                                        const X86Subtarget *Subtarget,
+                                        SelectionDAG &DAG) {
+  MVT VT = Op.getSimpleValueType();
+  MVT EltVT = VT.getVectorElementType();
+  unsigned VecSize = VT.getSizeInBits();
+
+  // Implement a lookup table in register by using an algorithm based on:
+  // http://wm.ite.pl/articles/sse-popcount.html
+  //
+  // The general idea is that every lower byte nibble in the input vector is an
+  // index into a in-register pre-computed pop count table. We then split up the
+  // input vector in two new ones: (1) a vector with only the shifted-right
+  // higher nibbles for each byte and (2) a vector with the lower nibbles (and
+  // masked out higher ones) for each byte. PSHUB is used separately with both
+  // to index the in-register table. Next, both are added and the result is a
+  // i8 vector where each element contains the pop count for input byte.
+  //
+  // To obtain the pop count for elements != i8, we follow up with the same
+  // approach and use additional tricks as described below.
+  //
+  const int LUT[16] = {/* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
+                       /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
+                       /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
+                       /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4};
+
+  int NumByteElts = VecSize / 8;
+  MVT ByteVecVT = MVT::getVectorVT(MVT::i8, NumByteElts);
+  SDValue In = DAG.getBitcast(ByteVecVT, Op);
+  SmallVector<SDValue, 16> LUTVec;
+  for (int i = 0; i < NumByteElts; ++i)
+    LUTVec.push_back(DAG.getConstant(LUT[i % 16], DL, MVT::i8));
+  SDValue InRegLUT = DAG.getNode(ISD::BUILD_VECTOR, DL, ByteVecVT, LUTVec);
+  SmallVector<SDValue, 16> Mask0F(NumByteElts,
+                                  DAG.getConstant(0x0F, DL, MVT::i8));
+  SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, DL, ByteVecVT, Mask0F);
+
+  // High nibbles
+  SmallVector<SDValue, 16> Four(NumByteElts, DAG.getConstant(4, DL, MVT::i8));
+  SDValue FourV = DAG.getNode(ISD::BUILD_VECTOR, DL, ByteVecVT, Four);
+  SDValue HighNibbles = DAG.getNode(ISD::SRL, DL, ByteVecVT, In, FourV);
+
+  // Low nibbles
+  SDValue LowNibbles = DAG.getNode(ISD::AND, DL, ByteVecVT, In, M0F);
+
+  // The input vector is used as the shuffle mask that index elements into the
+  // LUT. After counting low and high nibbles, add the vector to obtain the
+  // final pop count per i8 element.
+  SDValue HighPopCnt =
+      DAG.getNode(X86ISD::PSHUFB, DL, ByteVecVT, InRegLUT, HighNibbles);
+  SDValue LowPopCnt =
+      DAG.getNode(X86ISD::PSHUFB, DL, ByteVecVT, InRegLUT, LowNibbles);
+  SDValue PopCnt = DAG.getNode(ISD::ADD, DL, ByteVecVT, HighPopCnt, LowPopCnt);
+
+  if (EltVT == MVT::i8)
+    return PopCnt;
+
+  return LowerHorizontalByteSum(PopCnt, VT, Subtarget, DAG);
+}
+
 static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,
                                        const X86Subtarget *Subtarget,
                                        SelectionDAG &DAG) {