[x86] Add vector @llvm.ctpop intrinsic custom lowering
[oota-llvm.git] / lib / Target / X86 / X86ISelLowering.cpp
index 80902a64994f47bedb94482d9f3174e9fc339678..c6e72c9aa6b0cde2fd3501cbc4cdb2e3795fc4c0 100644 (file)
@@ -987,6 +987,14 @@ void X86TargetLowering::resetOperationActions() {
     setOperationAction(ISD::INSERT_VECTOR_ELT,  MVT::v4i32, Custom);
     setOperationAction(ISD::INSERT_VECTOR_ELT,  MVT::v4f32, Custom);
 
+    // Only provide customized ctpop vector bit twiddling for vector types we
+    // know to perform better than using the popcnt instructions on each vector
+    // element. If popcnt isn't supported, always provide the custom version.
+    if (!Subtarget->hasPOPCNT()) {
+      setOperationAction(ISD::CTPOP,            MVT::v4i32, Custom);
+      setOperationAction(ISD::CTPOP,            MVT::v2i64, Custom);
+    }
+
     // Custom lower build_vector, vector_shuffle, and extract_vector_elt.
     for (int i = MVT::v16i8; i != MVT::v2i64; ++i) {
       MVT VT = (MVT::SimpleValueType)i;
@@ -1282,6 +1290,16 @@ void X86TargetLowering::resetOperationActions() {
       // The custom lowering for UINT_TO_FP for v8i32 becomes interesting
       // when we have a 256bit-wide blend with immediate.
       setOperationAction(ISD::UINT_TO_FP, MVT::v8i32, Custom);
+
+      // Only provide customized ctpop vector bit twiddling for vector types we
+      // know to perform better than using the popcnt instructions on each
+      // vector element. If popcnt isn't supported, always provide the custom
+      // version.
+      if (!Subtarget->hasPOPCNT())
+        setOperationAction(ISD::CTPOP,           MVT::v4i64, Custom);
+
+      // Custom CTPOP always performs better on natively supported v8i32
+      setOperationAction(ISD::CTPOP,             MVT::v8i32, Custom);
     } else {
       setOperationAction(ISD::ADD,             MVT::v4i64, Custom);
       setOperationAction(ISD::ADD,             MVT::v8i32, Custom);
@@ -19183,6 +19201,139 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget *Subtarget,
   return SDValue();
 }
 
+static SDValue LowerCTPOP(SDValue Op, const X86Subtarget *Subtarget,
+                          SelectionDAG &DAG) {
+  SDNode *Node = Op.getNode();
+  SDLoc dl(Node);
+
+  Op = Op.getOperand(0);
+  EVT VT = Op.getValueType();
+  assert((VT.is128BitVector() || VT.is256BitVector()) &&
+         "CTPOP lowering only implemented for 128/256-bit wide vector types");
+
+  unsigned NumElts = VT.getVectorNumElements();
+  EVT EltVT = VT.getVectorElementType();
+  unsigned Len = EltVT.getSizeInBits();
+
+  // This is the vectorized version of the "best" algorithm from
+  // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
+  // with a minor tweak to use a series of adds + shifts instead of vector
+  // multiplications. Implemented for the v2i64, v4i64, v4i32, v8i32 types:
+  //
+  //  v2i64, v4i64, v4i32 => Only profitable w/ popcnt disabled
+  //  v8i32 => Always profitable
+  //
+  // FIXME: There a couple of possible improvements:
+  //
+  // 1) Support for i8 and i16 vectors (needs measurements if popcnt enabled).
+  // 2) Use strategies from http://wm.ite.pl/articles/sse-popcount.html
+  //
+  assert(EltVT.isInteger() && (Len == 32 || Len == 64) && Len % 8 == 0 &&
+         "CTPOP not implemented for this vector element type.");
+
+  // X86 canonicalize ANDs to vXi64, generate the appropriate bitcasts to avoid
+  // extra legalization.
+  bool NeedsBitcast = EltVT == MVT::i32;
+  MVT BitcastVT = VT.is256BitVector() ? MVT::v4i64 : MVT::v2i64;
+
+  SDValue Cst55 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), EltVT);
+  SDValue Cst33 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), EltVT);
+  SDValue Cst0F = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), EltVT);
+
+  // v = v - ((v >> 1) & 0x55555555...)
+  SmallVector<SDValue, 8> Ones(NumElts, DAG.getConstant(1, EltVT));
+  SDValue OnesV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Ones);
+  SDValue Srl = DAG.getNode(ISD::SRL, dl, VT, Op, OnesV);
+  if (NeedsBitcast)
+    Srl = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Srl);
+
+  SmallVector<SDValue, 8> Mask55(NumElts, Cst55);
+  SDValue M55 = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Mask55);
+  if (NeedsBitcast)
+    M55 = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M55);
+
+  SDValue And = DAG.getNode(ISD::AND, dl, Srl.getValueType(), Srl, M55);
+  if (VT != And.getValueType())
+    And = DAG.getNode(ISD::BITCAST, dl, VT, And);
+  SDValue Sub = DAG.getNode(ISD::SUB, dl, VT, Op, And);
+
+  // v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
+  SmallVector<SDValue, 8> Mask33(NumElts, Cst33);
+  SDValue M33 = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Mask33);
+  SmallVector<SDValue, 8> Twos(NumElts, DAG.getConstant(2, EltVT));
+  SDValue TwosV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Twos);
+
+  Srl = DAG.getNode(ISD::SRL, dl, VT, Sub, TwosV);
+  if (NeedsBitcast) {
+    Srl = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Srl);
+    M33 = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M33);
+    Sub = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Sub);
+  }
+
+  SDValue AndRHS = DAG.getNode(ISD::AND, dl, M33.getValueType(), Srl, M33);
+  SDValue AndLHS = DAG.getNode(ISD::AND, dl, M33.getValueType(), Sub, M33);
+  if (VT != AndRHS.getValueType()) {
+    AndRHS = DAG.getNode(ISD::BITCAST, dl, VT, AndRHS);
+    AndLHS = DAG.getNode(ISD::BITCAST, dl, VT, AndLHS);
+  }
+  SDValue Add = DAG.getNode(ISD::ADD, dl, VT, AndLHS, AndRHS);
+
+  // v = (v + (v >> 4)) & 0x0F0F0F0F...
+  SmallVector<SDValue, 8> Fours(NumElts, DAG.getConstant(4, EltVT));
+  SDValue FoursV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Fours);
+  Srl = DAG.getNode(ISD::SRL, dl, VT, Add, FoursV);
+  Add = DAG.getNode(ISD::ADD, dl, VT, Add, Srl);
+
+  SmallVector<SDValue, 8> Mask0F(NumElts, Cst0F);
+  SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Mask0F);
+  if (NeedsBitcast) {
+    Add = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Add);
+    M0F = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M0F);
+  }
+  And = DAG.getNode(ISD::AND, dl, M0F.getValueType(), Add, M0F);
+  if (VT != And.getValueType())
+    And = DAG.getNode(ISD::BITCAST, dl, VT, And);
+
+  // The algorithm mentioned above uses:
+  //    v = (v * 0x01010101...) >> (Len - 8)
+  //
+  // Change it to use vector adds + vector shifts which yield faster results on
+  // Haswell than using vector integer multiplication.
+  //
+  // For i32 elements:
+  //    v = v + (v >> 8)
+  //    v = v + (v >> 16)
+  //
+  // For i64 elements:
+  //    v = v + (v >> 8)
+  //    v = v + (v >> 16)
+  //    v = v + (v >> 32)
+  //
+  Add = And;
+  SmallVector<SDValue, 8> Csts;
+  for (unsigned i = 8; i <= Len/2; i *= 2) {
+    Csts.assign(NumElts, DAG.getConstant(i, EltVT));
+    SDValue CstsV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Csts);
+    Srl = DAG.getNode(ISD::SRL, dl, VT, Add, CstsV);
+    Add = DAG.getNode(ISD::ADD, dl, VT, Add, Srl);
+    Csts.clear();
+  }
+
+  // The result is on the least significant 6-bits on i32 and 7-bits on i64.
+  SDValue Cst3F = DAG.getConstant(APInt(Len, Len == 32 ? 0x3F : 0x7F), EltVT);
+  SmallVector<SDValue, 8> Cst3FV(NumElts, Cst3F);
+  SDValue M3F = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Cst3FV);
+  if (NeedsBitcast) {
+    Add = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Add);
+    M3F = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M3F);
+  }
+  And = DAG.getNode(ISD::AND, dl, M3F.getValueType(), Add, M3F);
+  if (VT != And.getValueType())
+    And = DAG.getNode(ISD::BITCAST, dl, VT, And);
+
+  return And;
+}
+
 static SDValue LowerLOAD_SUB(SDValue Op, SelectionDAG &DAG) {
   SDNode *Node = Op.getNode();
   SDLoc dl(Node);
@@ -19310,6 +19461,7 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::ATOMIC_FENCE:       return LowerATOMIC_FENCE(Op, Subtarget, DAG);
   case ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS:
     return LowerCMP_SWAP(Op, Subtarget, DAG);
+  case ISD::CTPOP:              return LowerCTPOP(Op, Subtarget, DAG);
   case ISD::ATOMIC_LOAD_SUB:    return LowerLOAD_SUB(Op,DAG);
   case ISD::ATOMIC_STORE:       return LowerATOMIC_STORE(Op,DAG);
   case ISD::BUILD_VECTOR:       return LowerBUILD_VECTOR(Op, DAG);