[X86][SSE] Vectorize CTTZ + CTTZ_ZERO_UNDEF
[oota-llvm.git] / lib / Target / X86 / X86ISelLowering.cpp
index cda055438a11c8dcf6b5f63450a9934fef3862d3..9c39c26aba5134121405b562c36c69831a96e7e3 100644 (file)
@@ -847,6 +847,15 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setOperationAction(ISD::CTPOP,              MVT::v4i32, Custom);
     setOperationAction(ISD::CTPOP,              MVT::v2i64, Custom);
 
+    setOperationAction(ISD::CTTZ,               MVT::v16i8, Custom);
+    setOperationAction(ISD::CTTZ,               MVT::v8i16, Custom);
+    setOperationAction(ISD::CTTZ,               MVT::v4i32, Custom);
+    // ISD::CTTZ v2i64 - scalarization is faster.
+    setOperationAction(ISD::CTTZ_ZERO_UNDEF,    MVT::v16i8, Custom);
+    setOperationAction(ISD::CTTZ_ZERO_UNDEF,    MVT::v8i16, Custom);
+    setOperationAction(ISD::CTTZ_ZERO_UNDEF,    MVT::v4i32, Custom);
+    // ISD::CTTZ_ZERO_UNDEF v2i64 - scalarization is faster.
+
     // Custom lower build_vector, vector_shuffle, and extract_vector_elt.
     for (int i = MVT::v16i8; i != MVT::v2i64; ++i) {
       MVT VT = (MVT::SimpleValueType)i;
@@ -1127,6 +1136,15 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setOperationAction(ISD::CTPOP,             MVT::v8i32, Custom);
     setOperationAction(ISD::CTPOP,             MVT::v4i64, Custom);
 
+    setOperationAction(ISD::CTTZ,              MVT::v32i8, Custom);
+    setOperationAction(ISD::CTTZ,              MVT::v16i16, Custom);
+    setOperationAction(ISD::CTTZ,              MVT::v8i32, Custom);
+    setOperationAction(ISD::CTTZ,              MVT::v4i64, Custom);
+    setOperationAction(ISD::CTTZ_ZERO_UNDEF,   MVT::v32i8, Custom);
+    setOperationAction(ISD::CTTZ_ZERO_UNDEF,   MVT::v16i16, Custom);
+    setOperationAction(ISD::CTTZ_ZERO_UNDEF,   MVT::v8i32, Custom);
+    setOperationAction(ISD::CTTZ_ZERO_UNDEF,   MVT::v4i64, Custom);
+
     if (Subtarget->hasFMA() || Subtarget->hasFMA4() || Subtarget->hasAVX512()) {
       setOperationAction(ISD::FMA,             MVT::v8f32, Legal);
       setOperationAction(ISD::FMA,             MVT::v4f64, Legal);
@@ -1499,6 +1517,9 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
       setOperationAction(ISD::CTLZ,             MVT::v16i32, Legal);
       setOperationAction(ISD::CTLZ_ZERO_UNDEF,  MVT::v8i64, Legal);
       setOperationAction(ISD::CTLZ_ZERO_UNDEF,  MVT::v16i32, Legal);
+
+      setOperationAction(ISD::CTTZ_ZERO_UNDEF,  MVT::v8i64, Custom);
+      setOperationAction(ISD::CTTZ_ZERO_UNDEF,  MVT::v16i32, Custom);
     }
     if (Subtarget->hasVLX() && Subtarget->hasCDI()) {
       setOperationAction(ISD::CTLZ,             MVT::v4i64, Legal);
@@ -1509,6 +1530,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
       setOperationAction(ISD::CTLZ_ZERO_UNDEF,  MVT::v8i32, Legal);
       setOperationAction(ISD::CTLZ_ZERO_UNDEF,  MVT::v2i64, Legal);
       setOperationAction(ISD::CTLZ_ZERO_UNDEF,  MVT::v4i32, Legal);
+
+      setOperationAction(ISD::CTTZ_ZERO_UNDEF,  MVT::v4i64, Custom);
+      setOperationAction(ISD::CTTZ_ZERO_UNDEF,  MVT::v8i32, Custom);
+      setOperationAction(ISD::CTTZ_ZERO_UNDEF,  MVT::v2i64, Custom);
+      setOperationAction(ISD::CTTZ_ZERO_UNDEF,  MVT::v4i32, Custom);
     }
     if (Subtarget->hasDQI()) {
       setOperationAction(ISD::MUL,             MVT::v2i64, Legal);
@@ -17222,13 +17248,39 @@ static SDValue LowerCTLZ_ZERO_UNDEF(SDValue Op, SelectionDAG &DAG) {
 
 static SDValue LowerCTTZ(SDValue Op, SelectionDAG &DAG) {
   MVT VT = Op.getSimpleValueType();
-  unsigned NumBits = VT.getSizeInBits();
+  unsigned NumBits = VT.getScalarSizeInBits();
   SDLoc dl(Op);
-  Op = Op.getOperand(0);
+
+  if (VT.isVector()) {
+    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+    SDValue N0 = Op.getOperand(0);
+    SDValue Zero = DAG.getConstant(0, dl, VT);
+
+    // lsb(x) = (x & -x)
+    SDValue LSB = DAG.getNode(ISD::AND, dl, VT, N0,
+                              DAG.getNode(ISD::SUB, dl, VT, Zero, N0));
+
+    // cttz_undef(x) = (width - 1) - ctlz(lsb)
+    if (Op.getOpcode() == ISD::CTTZ_ZERO_UNDEF &&
+        TLI.isOperationLegal(ISD::CTLZ, VT)) {
+      SDValue WidthMinusOne = DAG.getConstant(NumBits - 1, dl, VT);
+      return DAG.getNode(ISD::SUB, dl, VT, WidthMinusOne,
+                         DAG.getNode(ISD::CTLZ, dl, VT, LSB));
+    }
+
+    // cttz(x) = ctpop(lsb - 1)
+    SDValue One = DAG.getConstant(1, dl, VT);
+    return DAG.getNode(ISD::CTPOP, dl, VT,
+                       DAG.getNode(ISD::SUB, dl, VT, LSB, One));
+  }
+
+  assert(Op.getOpcode() == ISD::CTTZ &&
+         "Only scalar CTTZ requires custom lowering");
 
   // Issue a bsf (scan bits forward) which also sets EFLAGS.
   SDVTList VTs = DAG.getVTList(VT, MVT::i32);
-  Op = DAG.getNode(X86ISD::BSF, dl, VTs, Op);
+  Op = DAG.getNode(X86ISD::BSF, dl, VTs, Op.getOperand(0));
 
   // If src is zero (i.e. bsf sets ZF), returns NumBits.
   SDValue Ops[] = {
@@ -19168,7 +19220,8 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::FLT_ROUNDS_:        return LowerFLT_ROUNDS_(Op, DAG);
   case ISD::CTLZ:               return LowerCTLZ(Op, DAG);
   case ISD::CTLZ_ZERO_UNDEF:    return LowerCTLZ_ZERO_UNDEF(Op, DAG);
-  case ISD::CTTZ:               return LowerCTTZ(Op, DAG);
+  case ISD::CTTZ:
+  case ISD::CTTZ_ZERO_UNDEF:    return LowerCTTZ(Op, DAG);
   case ISD::MUL:                return LowerMUL(Op, Subtarget, DAG);
   case ISD::UMUL_LOHI:
   case ISD::SMUL_LOHI:          return LowerMUL_LOHI(Op, Subtarget, DAG);