ARM: Fix cttz expansion on vector types.
[oota-llvm.git] / lib / Target / ARM / ARMISelLowering.cpp
index e747ab0daca335d06c9a619e74621b8c41c8001d..e335784f6d876cf742e6da081b1bd049cce8da3e 100644 (file)
@@ -543,6 +543,27 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::CTPOP,      MVT::v4i16, Custom);
     setOperationAction(ISD::CTPOP,      MVT::v8i16, Custom);
 
+    // NEON does not have single instruction CTTZ for vectors.
+    setOperationAction(ISD::CTTZ, MVT::v8i8, Custom);
+    setOperationAction(ISD::CTTZ, MVT::v4i16, Custom);
+    setOperationAction(ISD::CTTZ, MVT::v2i32, Custom);
+    setOperationAction(ISD::CTTZ, MVT::v1i64, Custom);
+
+    setOperationAction(ISD::CTTZ, MVT::v16i8, Custom);
+    setOperationAction(ISD::CTTZ, MVT::v8i16, Custom);
+    setOperationAction(ISD::CTTZ, MVT::v4i32, Custom);
+    setOperationAction(ISD::CTTZ, MVT::v2i64, Custom);
+
+    setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::v8i8, Custom);
+    setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::v4i16, Custom);
+    setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::v2i32, Custom);
+    setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::v1i64, Custom);
+
+    setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::v16i8, Custom);
+    setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::v8i16, Custom);
+    setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::v4i32, Custom);
+    setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::v2i64, Custom);
+
     // NEON only has FMA instructions as of VFP4.
     if (!Subtarget->hasVFP4()) {
       setOperationAction(ISD::FMA, MVT::v2f32, Expand);
@@ -4282,8 +4303,82 @@ SDValue ARMTargetLowering::LowerFLT_ROUNDS_(SDValue Op,
 
 static SDValue LowerCTTZ(SDNode *N, SelectionDAG &DAG,
                          const ARMSubtarget *ST) {
-  EVT VT = N->getValueType(0);
   SDLoc dl(N);
+  EVT VT = N->getValueType(0);
+  if (VT.isVector()) {
+    assert(ST->hasNEON());
+
+    // Compute the least significant set bit: LSB = X & -X
+    SDValue X = N->getOperand(0);
+    SDValue NX = DAG.getNode(ISD::SUB, dl, VT, getZeroVector(VT, DAG, dl), X);
+    SDValue LSB = DAG.getNode(ISD::AND, dl, VT, X, NX);
+
+    EVT ElemTy = VT.getVectorElementType();
+
+    if (ElemTy == MVT::i8) {
+      // Compute with: cttz(x) = ctpop(lsb - 1)
+      SDValue One = DAG.getNode(ARMISD::VMOVIMM, dl, VT,
+                                DAG.getTargetConstant(1, dl, ElemTy));
+      SDValue Bits = DAG.getNode(ISD::SUB, dl, VT, LSB, One);
+      return DAG.getNode(ISD::CTPOP, dl, VT, Bits);
+    }
+
+    if ((ElemTy == MVT::i16 || ElemTy == MVT::i32) &&
+        (N->getOpcode() == ISD::CTTZ_ZERO_UNDEF)) {
+      // Compute with: cttz(x) = (width - 1) - ctlz(lsb), if x != 0
+      unsigned NumBits = ElemTy.getSizeInBits();
+      SDValue WidthMinus1 =
+          DAG.getNode(ARMISD::VMOVIMM, dl, VT,
+                      DAG.getTargetConstant(NumBits - 1, dl, ElemTy));
+      SDValue CTLZ = DAG.getNode(ISD::CTLZ, dl, VT, LSB);
+      return DAG.getNode(ISD::SUB, dl, VT, WidthMinus1, CTLZ);
+    }
+
+    // Compute with: cttz(x) = ctpop(lsb - 1)
+
+    // Since we can only compute the number of bits in a byte with vcnt.8, we
+    // have to gather the result with pairwise addition (vpaddl) for i16, i32,
+    // and i64.
+
+    // Compute LSB - 1.
+    SDValue Bits;
+    if (ElemTy == MVT::i64) {
+      // Load constant 0xffff'ffff'ffff'ffff to register.
+      SDValue FF = DAG.getNode(ARMISD::VMOVIMM, dl, VT,
+                               DAG.getTargetConstant(0x1eff, dl, MVT::i32));
+      Bits = DAG.getNode(ISD::ADD, dl, VT, LSB, FF);
+    } else {
+      SDValue One = DAG.getNode(ARMISD::VMOVIMM, dl, VT,
+                                DAG.getTargetConstant(1, dl, ElemTy));
+      Bits = DAG.getNode(ISD::SUB, dl, VT, LSB, One);
+    }
+
+    // Count #bits with vcnt.8.
+    EVT VT8Bit = VT.is64BitVector() ? MVT::v8i8 : MVT::v16i8;
+    SDValue BitsVT8 = DAG.getNode(ISD::BITCAST, dl, VT8Bit, Bits);
+    SDValue Cnt8 = DAG.getNode(ISD::CTPOP, dl, VT8Bit, BitsVT8);
+
+    // Gather the #bits with vpaddl (pairwise add.)
+    EVT VT16Bit = VT.is64BitVector() ? MVT::v4i16 : MVT::v8i16;
+    SDValue Cnt16 = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT16Bit,
+        DAG.getTargetConstant(Intrinsic::arm_neon_vpaddlu, dl, MVT::i32),
+        Cnt8);
+    if (ElemTy == MVT::i16)
+      return Cnt16;
+
+    EVT VT32Bit = VT.is64BitVector() ? MVT::v2i32 : MVT::v4i32;
+    SDValue Cnt32 = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT32Bit,
+        DAG.getTargetConstant(Intrinsic::arm_neon_vpaddlu, dl, MVT::i32),
+        Cnt16);
+    if (ElemTy == MVT::i32)
+      return Cnt32;
+
+    assert(ElemTy == MVT::i64);
+    SDValue Cnt64 = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+        DAG.getTargetConstant(Intrinsic::arm_neon_vpaddlu, dl, MVT::i32),
+        Cnt32);
+    return Cnt64;
+  }
 
   if (!ST->hasV6T2Ops())
     return SDValue();
@@ -6487,7 +6582,8 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::SHL_PARTS:     return LowerShiftLeftParts(Op, DAG);
   case ISD::SRL_PARTS:
   case ISD::SRA_PARTS:     return LowerShiftRightParts(Op, DAG);
-  case ISD::CTTZ:          return LowerCTTZ(Op.getNode(), DAG, Subtarget);
+  case ISD::CTTZ:
+  case ISD::CTTZ_ZERO_UNDEF: return LowerCTTZ(Op.getNode(), DAG, Subtarget);
   case ISD::CTPOP:         return LowerCTPOP(Op.getNode(), DAG, Subtarget);
   case ISD::SETCC:         return LowerVSETCC(Op, DAG);
   case ISD::ConstantFP:    return LowerConstantFP(Op, DAG, Subtarget);