From: Chandler Carruth Date: Sat, 30 May 2015 09:46:16 +0000 (+0000) Subject: [x86] Split out the horizontal byte sum lowering component of the LUT X-Git-Url: http://plrg.eecs.uci.edu/git/?p=oota-llvm.git;a=commitdiff_plain;h=da8bb20158469544bab61b24a5123639d8ee3e09 [x86] Split out the horizontal byte sum lowering component of the LUT lowering into a helper function. NFC. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@238650 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index 6c34efc49f7..ef409d3d453 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -17290,74 +17290,34 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget *Subtarget, return SDValue(); } -static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, SDLoc DL, - const X86Subtarget *Subtarget, - SelectionDAG &DAG) { - EVT VT = Op.getValueType(); - MVT EltVT = VT.getVectorElementType().getSimpleVT(); +/// Compute the horizontal sum of bytes in V for the elements of VT. +/// +/// Requires V to be a byte vector and VT to be an integer vector type with +/// wider elements than V's type. The width of the elements of VT determines +/// how many bytes of V are summed horizontally to produce each element of the +/// result. +static SDValue LowerHorizontalByteSum(SDValue V, MVT VT, + const X86Subtarget *Subtarget, + SelectionDAG &DAG) { + SDLoc DL(V); + MVT ByteVecVT = V.getSimpleValueType(); + MVT EltVT = VT.getVectorElementType(); + int NumElts = VT.getVectorNumElements(); + assert(ByteVecVT.getVectorElementType() == MVT::i8 && + "Expected value to have byte element type."); + assert(EltVT != MVT::i8 && + "Horizontal byte sum only makes sense for wider elements!"); unsigned VecSize = VT.getSizeInBits(); - - // Implement a lookup table in register by using an algorithm based on: - // http://wm.ite.pl/articles/sse-popcount.html - // - // The general idea is that every lower byte nibble in the input vector is an - // index into a in-register pre-computed pop count table. We then split up the - // input vector in two new ones: (1) a vector with only the shifted-right - // higher nibbles for each byte and (2) a vector with the lower nibbles (and - // masked out higher ones) for each byte. PSHUB is used separately with both - // to index the in-register table. Next, both are added and the result is a - // i8 vector where each element contains the pop count for input byte. - // - // To obtain the pop count for elements != i8, we follow up with the same - // approach and use additional tricks as described below. - // - const int LUT[16] = {/* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2, - /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3, - /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3, - /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4}; - - int NumByteElts = VecSize / 8; - MVT ByteVecVT = MVT::getVectorVT(MVT::i8, NumByteElts); - SDValue In = DAG.getBitcast(ByteVecVT, Op); - SmallVector LUTVec; - for (int i = 0; i < NumByteElts; ++i) - LUTVec.push_back(DAG.getConstant(LUT[i % 16], DL, MVT::i8)); - SDValue InRegLUT = DAG.getNode(ISD::BUILD_VECTOR, DL, ByteVecVT, LUTVec); - SmallVector Mask0F(NumByteElts, - DAG.getConstant(0x0F, DL, MVT::i8)); - SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, DL, ByteVecVT, Mask0F); - - // High nibbles - SmallVector Four(NumByteElts, DAG.getConstant(4, DL, MVT::i8)); - SDValue FourV = DAG.getNode(ISD::BUILD_VECTOR, DL, ByteVecVT, Four); - SDValue HighNibbles = DAG.getNode(ISD::SRL, DL, ByteVecVT, In, FourV); - - // Low nibbles - SDValue LowNibbles = DAG.getNode(ISD::AND, DL, ByteVecVT, In, M0F); - - // The input vector is used as the shuffle mask that index elements into the - // LUT. After counting low and high nibbles, add the vector to obtain the - // final pop count per i8 element. - SDValue HighPopCnt = - DAG.getNode(X86ISD::PSHUFB, DL, ByteVecVT, InRegLUT, HighNibbles); - SDValue LowPopCnt = - DAG.getNode(X86ISD::PSHUFB, DL, ByteVecVT, InRegLUT, LowNibbles); - SDValue PopCnt = DAG.getNode(ISD::ADD, DL, ByteVecVT, HighPopCnt, LowPopCnt); - - if (EltVT == MVT::i8) - return PopCnt; + assert(ByteVecVT.getSizeInBits() == VecSize && "Cannot change vector size!"); // PSADBW instruction horizontally add all bytes and leave the result in i64 // chunks, thus directly computes the pop count for v2i64 and v4i64. if (EltVT == MVT::i64) { SDValue Zeros = getZeroVector(ByteVecVT, Subtarget, DAG, DL); - PopCnt = DAG.getNode(X86ISD::PSADBW, DL, ByteVecVT, PopCnt, Zeros); - return DAG.getBitcast(VT, PopCnt); + V = DAG.getNode(X86ISD::PSADBW, DL, ByteVecVT, V, Zeros); + return DAG.getBitcast(VT, V); } - int NumI64Elts = VecSize / 64; - MVT VecI64VT = MVT::getVectorVT(MVT::i64, NumI64Elts); - if (EltVT == MVT::i32) { // We unpack the low half and high half into i32s interleaved with zeros so // that we can use PSADBW to horizontally sum them. The most useful part of @@ -17365,8 +17325,8 @@ static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, SDLoc DL, // two v2i64 vectors which concatenated are the 4 population counts. We can // then use PACKUSWB to shrink and concatenate them into a v4i32 again. SDValue Zeros = getZeroVector(VT, Subtarget, DAG, DL); - SDValue Low = DAG.getNode(X86ISD::UNPCKL, DL, VT, PopCnt, Zeros); - SDValue High = DAG.getNode(X86ISD::UNPCKH, DL, VT, PopCnt, Zeros); + SDValue Low = DAG.getNode(X86ISD::UNPCKL, DL, VT, V, Zeros); + SDValue High = DAG.getNode(X86ISD::UNPCKH, DL, VT, V, Zeros); // Do the horizontal sums into two v2i64s. Zeros = getZeroVector(ByteVecVT, Subtarget, DAG, DL); @@ -17377,11 +17337,11 @@ static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, SDLoc DL, // Merge them together. MVT ShortVecVT = MVT::getVectorVT(MVT::i16, VecSize / 16); - PopCnt = DAG.getNode(X86ISD::PACKUS, DL, ByteVecVT, - DAG.getBitcast(ShortVecVT, Low), - DAG.getBitcast(ShortVecVT, High)); + V = DAG.getNode(X86ISD::PACKUS, DL, ByteVecVT, + DAG.getBitcast(ShortVecVT, Low), + DAG.getBitcast(ShortVecVT, High)); - return DAG.getBitcast(VT, PopCnt); + return DAG.getBitcast(VT, V); } // To obtain pop count for each i16 element, shuffle the byte pop count to get @@ -17403,8 +17363,8 @@ static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, SDLoc DL, // We can't use PSHUFB across lanes, so do the shuffle and sum inside each // 128-bit lane, and then collapse the result. - int NumLanes = NumByteElts / 16; - assert(NumByteElts % 16 == 0 && "Must have 16-byte multiple vectors!"); + int NumLanes = VecSize / 128; + assert(VecSize % 128 == 0 && "Must have 16-byte multiple vectors!"); for (int i = 0; i < NumLanes; ++i) { for (int j = 0; j < 8; ++j) { MaskA.push_back(i * 16 + j * 2); @@ -17414,33 +17374,95 @@ static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, SDLoc DL, MaskB.append((size_t)8, -1); } - SDValue ShuffA = DAG.getVectorShuffle(ByteVecVT, DL, PopCnt, Undef, MaskA); - SDValue ShuffB = DAG.getVectorShuffle(ByteVecVT, DL, PopCnt, Undef, MaskB); - PopCnt = DAG.getNode(ISD::ADD, DL, ByteVecVT, ShuffA, ShuffB); + SDValue ShuffA = DAG.getVectorShuffle(ByteVecVT, DL, V, Undef, MaskA); + SDValue ShuffB = DAG.getVectorShuffle(ByteVecVT, DL, V, Undef, MaskB); + V = DAG.getNode(ISD::ADD, DL, ByteVecVT, ShuffA, ShuffB); SmallVector Mask; for (int i = 0; i < NumLanes; ++i) Mask.push_back(2 * i); Mask.append((size_t)NumLanes, -1); - PopCnt = DAG.getBitcast(VecI64VT, PopCnt); - PopCnt = - DAG.getVectorShuffle(VecI64VT, DL, PopCnt, DAG.getUNDEF(VecI64VT), Mask); - PopCnt = DAG.getBitcast(ByteVecVT, PopCnt); + int NumI64Elts = VecSize / 64; + MVT VecI64VT = MVT::getVectorVT(MVT::i64, NumI64Elts); + + V = DAG.getBitcast(VecI64VT, V); + V = DAG.getVectorShuffle(VecI64VT, DL, V, DAG.getUNDEF(VecI64VT), Mask); + V = DAG.getBitcast(ByteVecVT, V); // Zero extend i8s into i16 elts SmallVector ZExtInRegMask; - for (int i = 0; i < NumByteElts / 2; ++i) { + for (int i = 0; i < NumElts; ++i) { ZExtInRegMask.push_back(i); - ZExtInRegMask.push_back(NumByteElts); + ZExtInRegMask.push_back(2 * NumElts); } return DAG.getBitcast( - VT, DAG.getVectorShuffle(ByteVecVT, DL, PopCnt, + VT, DAG.getVectorShuffle(ByteVecVT, DL, V, getZeroVector(ByteVecVT, Subtarget, DAG, DL), ZExtInRegMask)); } +static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, SDLoc DL, + const X86Subtarget *Subtarget, + SelectionDAG &DAG) { + MVT VT = Op.getSimpleValueType(); + MVT EltVT = VT.getVectorElementType(); + unsigned VecSize = VT.getSizeInBits(); + + // Implement a lookup table in register by using an algorithm based on: + // http://wm.ite.pl/articles/sse-popcount.html + // + // The general idea is that every lower byte nibble in the input vector is an + // index into a in-register pre-computed pop count table. We then split up the + // input vector in two new ones: (1) a vector with only the shifted-right + // higher nibbles for each byte and (2) a vector with the lower nibbles (and + // masked out higher ones) for each byte. PSHUB is used separately with both + // to index the in-register table. Next, both are added and the result is a + // i8 vector where each element contains the pop count for input byte. + // + // To obtain the pop count for elements != i8, we follow up with the same + // approach and use additional tricks as described below. + // + const int LUT[16] = {/* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2, + /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3, + /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3, + /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4}; + + int NumByteElts = VecSize / 8; + MVT ByteVecVT = MVT::getVectorVT(MVT::i8, NumByteElts); + SDValue In = DAG.getBitcast(ByteVecVT, Op); + SmallVector LUTVec; + for (int i = 0; i < NumByteElts; ++i) + LUTVec.push_back(DAG.getConstant(LUT[i % 16], DL, MVT::i8)); + SDValue InRegLUT = DAG.getNode(ISD::BUILD_VECTOR, DL, ByteVecVT, LUTVec); + SmallVector Mask0F(NumByteElts, + DAG.getConstant(0x0F, DL, MVT::i8)); + SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, DL, ByteVecVT, Mask0F); + + // High nibbles + SmallVector Four(NumByteElts, DAG.getConstant(4, DL, MVT::i8)); + SDValue FourV = DAG.getNode(ISD::BUILD_VECTOR, DL, ByteVecVT, Four); + SDValue HighNibbles = DAG.getNode(ISD::SRL, DL, ByteVecVT, In, FourV); + + // Low nibbles + SDValue LowNibbles = DAG.getNode(ISD::AND, DL, ByteVecVT, In, M0F); + + // The input vector is used as the shuffle mask that index elements into the + // LUT. After counting low and high nibbles, add the vector to obtain the + // final pop count per i8 element. + SDValue HighPopCnt = + DAG.getNode(X86ISD::PSHUFB, DL, ByteVecVT, InRegLUT, HighNibbles); + SDValue LowPopCnt = + DAG.getNode(X86ISD::PSHUFB, DL, ByteVecVT, InRegLUT, LowNibbles); + SDValue PopCnt = DAG.getNode(ISD::ADD, DL, ByteVecVT, HighPopCnt, LowPopCnt); + + if (EltVT == MVT::i8) + return PopCnt; + + return LowerHorizontalByteSum(PopCnt, VT, Subtarget, DAG); +} + static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL, const X86Subtarget *Subtarget, SelectionDAG &DAG) {