[x86] Restructure the parallel bitmath lowering of popcount into
[oota-llvm.git] / lib / Target / X86 / X86ISelLowering.cpp
index fdaeed71330e17d98493167509e5231ef5d3b204..c834be3c1a7c3c469b084ede01a630f7d55a2406 100644 (file)
@@ -846,8 +846,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     // know to perform better than using the popcnt instructions on each vector
     // element. If popcnt isn't supported, always provide the custom version.
     if (!Subtarget->hasPOPCNT()) {
-      setOperationAction(ISD::CTPOP,            MVT::v4i32, Custom);
       setOperationAction(ISD::CTPOP,            MVT::v2i64, Custom);
+      setOperationAction(ISD::CTPOP,            MVT::v4i32, Custom);
+      setOperationAction(ISD::CTPOP,            MVT::v8i16, Custom);
+      setOperationAction(ISD::CTPOP,            MVT::v16i8, Custom);
     }
 
     // Custom lower build_vector, vector_shuffle, and extract_vector_elt.
@@ -17327,141 +17329,131 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget *Subtarget,
   return SDValue();
 }
 
-static SDValue LowerCTPOP(SDValue Op, const X86Subtarget *Subtarget,
-                          SelectionDAG &DAG) {
-  SDNode *Node = Op.getNode();
-  SDLoc dl(Node);
-
-  Op = Op.getOperand(0);
-  EVT VT = Op.getValueType();
+static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,
+                                       const X86Subtarget *Subtarget,
+                                       SelectionDAG &DAG) {
+  MVT VT = Op.getSimpleValueType();
   assert((VT.is128BitVector() || VT.is256BitVector()) &&
          "CTPOP lowering only implemented for 128/256-bit wide vector types");
 
-  unsigned NumElts = VT.getVectorNumElements();
-  EVT EltVT = VT.getVectorElementType();
-  unsigned Len = EltVT.getSizeInBits();
+  int VecSize = VT.getSizeInBits();
+  int NumElts = VT.getVectorNumElements();
+  MVT EltVT = VT.getVectorElementType();
+  int Len = EltVT.getSizeInBits();
 
   // This is the vectorized version of the "best" algorithm from
   // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
   // with a minor tweak to use a series of adds + shifts instead of vector
-  // multiplications. Implemented for the v2i64, v4i64, v4i32, v8i32 types:
-  //
-  //  v2i64, v4i64, v4i32 => Only profitable w/ popcnt disabled
-  //  v8i32 => Always profitable
+  // multiplications. Implemented for all integer vector types.
   //
-  // FIXME: There a couple of possible improvements:
-  //
-  // 1) Support for i8 and i16 vectors (needs measurements if popcnt enabled).
-  // 2) Use strategies from http://wm.ite.pl/articles/sse-popcount.html
-  //
-  assert(EltVT.isInteger() && (Len == 32 || Len == 64) && Len % 8 == 0 &&
-         "CTPOP not implemented for this vector element type.");
-
-  // X86 canonicalize ANDs to vXi64, generate the appropriate bitcasts to avoid
-  // extra legalization.
-  bool NeedsBitcast = EltVT == MVT::i32;
-  MVT BitcastVT = VT.is256BitVector() ? MVT::v4i64 : MVT::v2i64;
+  // FIXME: Use strategies from http://wm.ite.pl/articles/sse-popcount.html
 
-  SDValue Cst55 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), dl,
+  SDValue Cst55 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), DL,
                                   EltVT);
-  SDValue Cst33 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), dl,
+  SDValue Cst33 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), DL,
                                   EltVT);
-  SDValue Cst0F = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), dl,
+  SDValue Cst0F = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), DL,
                                   EltVT);
 
+  SDValue V = Op;
+
   // v = v - ((v >> 1) & 0x55555555...)
-  SmallVector<SDValue, 8> Ones(NumElts, DAG.getConstant(1, dl, EltVT));
-  SDValue OnesV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Ones);
-  SDValue Srl = DAG.getNode(ISD::SRL, dl, VT, Op, OnesV);
-  if (NeedsBitcast)
-    Srl = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Srl);
+  SmallVector<SDValue, 8> Ones(NumElts, DAG.getConstant(1, DL, EltVT));
+  SDValue OnesV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Ones);
+  SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, V, OnesV);
 
   SmallVector<SDValue, 8> Mask55(NumElts, Cst55);
-  SDValue M55 = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Mask55);
-  if (NeedsBitcast)
-    M55 = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M55);
+  SDValue M55 = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask55);
+  SDValue And = DAG.getNode(ISD::AND, DL, Srl.getValueType(), Srl, M55);
 
-  SDValue And = DAG.getNode(ISD::AND, dl, Srl.getValueType(), Srl, M55);
-  if (VT != And.getValueType())
-    And = DAG.getNode(ISD::BITCAST, dl, VT, And);
-  SDValue Sub = DAG.getNode(ISD::SUB, dl, VT, Op, And);
+  V = DAG.getNode(ISD::SUB, DL, VT, V, And);
 
   // v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
   SmallVector<SDValue, 8> Mask33(NumElts, Cst33);
-  SDValue M33 = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Mask33);
-  SmallVector<SDValue, 8> Twos(NumElts, DAG.getConstant(2, dl, EltVT));
-  SDValue TwosV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Twos);
+  SDValue M33 = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask33);
+  SDValue AndLHS = DAG.getNode(ISD::AND, DL, M33.getValueType(), V, M33);
 
-  Srl = DAG.getNode(ISD::SRL, dl, VT, Sub, TwosV);
-  if (NeedsBitcast) {
-    Srl = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Srl);
-    M33 = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M33);
-    Sub = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Sub);
-  }
+  SmallVector<SDValue, 8> Twos(NumElts, DAG.getConstant(2, DL, EltVT));
+  SDValue TwosV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Twos);
+  Srl = DAG.getNode(ISD::SRL, DL, VT, V, TwosV);
+  SDValue AndRHS = DAG.getNode(ISD::AND, DL, M33.getValueType(), Srl, M33);
 
-  SDValue AndRHS = DAG.getNode(ISD::AND, dl, M33.getValueType(), Srl, M33);
-  SDValue AndLHS = DAG.getNode(ISD::AND, dl, M33.getValueType(), Sub, M33);
-  if (VT != AndRHS.getValueType()) {
-    AndRHS = DAG.getNode(ISD::BITCAST, dl, VT, AndRHS);
-    AndLHS = DAG.getNode(ISD::BITCAST, dl, VT, AndLHS);
-  }
-  SDValue Add = DAG.getNode(ISD::ADD, dl, VT, AndLHS, AndRHS);
+  V = DAG.getNode(ISD::ADD, DL, VT, AndLHS, AndRHS);
 
   // v = (v + (v >> 4)) & 0x0F0F0F0F...
-  SmallVector<SDValue, 8> Fours(NumElts, DAG.getConstant(4, dl, EltVT));
-  SDValue FoursV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Fours);
-  Srl = DAG.getNode(ISD::SRL, dl, VT, Add, FoursV);
-  Add = DAG.getNode(ISD::ADD, dl, VT, Add, Srl);
+  SmallVector<SDValue, 8> Fours(NumElts, DAG.getConstant(4, DL, EltVT));
+  SDValue FoursV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Fours);
+  Srl = DAG.getNode(ISD::SRL, DL, VT, V, FoursV);
+  SDValue Add = DAG.getNode(ISD::ADD, DL, VT, V, Srl);
 
   SmallVector<SDValue, 8> Mask0F(NumElts, Cst0F);
-  SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Mask0F);
-  if (NeedsBitcast) {
-    Add = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Add);
-    M0F = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M0F);
-  }
-  And = DAG.getNode(ISD::AND, dl, M0F.getValueType(), Add, M0F);
-  if (VT != And.getValueType())
-    And = DAG.getNode(ISD::BITCAST, dl, VT, And);
+  SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask0F);
 
-  // The algorithm mentioned above uses:
-  //    v = (v * 0x01010101...) >> (Len - 8)
-  //
-  // Change it to use vector adds + vector shifts which yield faster results on
-  // Haswell than using vector integer multiplication.
-  //
-  // For i32 elements:
-  //    v = v + (v >> 8)
-  //    v = v + (v >> 16)
-  //
-  // For i64 elements:
-  //    v = v + (v >> 8)
-  //    v = v + (v >> 16)
-  //    v = v + (v >> 32)
+  V = DAG.getNode(ISD::AND, DL, M0F.getValueType(), Add, M0F);
+
+  // At this point, V contains the byte-wise population count, and we are
+  // merely doing a horizontal sum if necessary to get the wider element
+  // counts.
   //
-  Add = And;
+  // FIXME: There is a different lowering strategy above for the horizontal sum
+  // of byte-wise population counts. This one and that one should be merged,
+  // using the fastest of the two for each size.
+  MVT ByteVT = MVT::getVectorVT(MVT::i8, VecSize / 8);
+  MVT ShiftVT = MVT::getVectorVT(MVT::i64, VecSize / 64);
+  V = DAG.getNode(ISD::BITCAST, DL, ByteVT, V);
   SmallVector<SDValue, 8> Csts;
-  for (unsigned i = 8; i <= Len/2; i *= 2) {
-    Csts.assign(NumElts, DAG.getConstant(i, dl, EltVT));
-    SDValue CstsV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Csts);
-    Srl = DAG.getNode(ISD::SRL, dl, VT, Add, CstsV);
-    Add = DAG.getNode(ISD::ADD, dl, VT, Add, Srl);
-    Csts.clear();
+  assert(Len <= 64 && "We don't support element sizes of more than 64 bits!");
+  assert(isPowerOf2_32(Len) && "Only power of two element sizes supported!");
+  for (int i = Len; i > 8; i /= 2) {
+    Csts.assign(VecSize / 64, DAG.getConstant(i / 2, DL, MVT::i64));
+    SDValue Shl = DAG.getNode(
+        ISD::SHL, DL, ShiftVT, DAG.getNode(ISD::BITCAST, DL, ShiftVT, V),
+        DAG.getNode(ISD::BUILD_VECTOR, DL, ShiftVT, Csts));
+    V = DAG.getNode(ISD::ADD, DL, ByteVT, V,
+                    DAG.getNode(ISD::BITCAST, DL, ByteVT, Shl));
+  }
+
+  // The high byte now contains the sum of the element bytes. Shift it right
+  // (if needed) to make it the low byte.
+  V = DAG.getNode(ISD::BITCAST, DL, VT, V);
+  if (Len > 8) {
+    Csts.assign(NumElts, DAG.getConstant(Len - 8, DL, EltVT));
+    V = DAG.getNode(ISD::SRL, DL, VT, V,
+                    DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Csts));
   }
+  return V;
+}
 
-  // The result is on the least significant 6-bits on i32 and 7-bits on i64.
-  SDValue Cst3F = DAG.getConstant(APInt(Len, Len == 32 ? 0x3F : 0x7F), dl,
-                                  EltVT);
-  SmallVector<SDValue, 8> Cst3FV(NumElts, Cst3F);
-  SDValue M3F = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Cst3FV);
-  if (NeedsBitcast) {
-    Add = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Add);
-    M3F = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M3F);
+
+static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget *Subtarget,
+                                SelectionDAG &DAG) {
+  MVT VT = Op.getSimpleValueType();
+  // FIXME: Need to add AVX-512 support here!
+  assert((VT.is256BitVector() || VT.is128BitVector()) &&
+         "Unknown CTPOP type to handle");
+  SDLoc DL(Op.getNode());
+  SDValue Op0 = Op.getOperand(0);
+
+  if (VT.is256BitVector() && !Subtarget->hasInt256()) {
+    unsigned NumElems = VT.getVectorNumElements();
+
+    // Extract each 128-bit vector, compute pop count and concat the result.
+    SDValue LHS = Extract128BitVector(Op0, 0, DAG, DL);
+    SDValue RHS = Extract128BitVector(Op0, NumElems/2, DAG, DL);
+
+    return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT,
+                       LowerVectorCTPOPBitmath(LHS, DL, Subtarget, DAG),
+                       LowerVectorCTPOPBitmath(RHS, DL, Subtarget, DAG));
   }
-  And = DAG.getNode(ISD::AND, dl, M3F.getValueType(), Add, M3F);
-  if (VT != And.getValueType())
-    And = DAG.getNode(ISD::BITCAST, dl, VT, And);
 
-  return And;
+  return LowerVectorCTPOPBitmath(Op0, DL, Subtarget, DAG);
+}
+
+static SDValue LowerCTPOP(SDValue Op, const X86Subtarget *Subtarget,
+                          SelectionDAG &DAG) {
+  assert(Op.getValueType().isVector() &&
+         "We only do custom lowering for vector population count.");
+  return LowerVectorCTPOP(Op, Subtarget, DAG);
 }
 
 static SDValue LowerLOAD_SUB(SDValue Op, SelectionDAG &DAG) {