AArch64: implement efficient f16 bitcasts
[oota-llvm.git] / lib / Target / AArch64 / AArch64ISelLowering.cpp
index 7b77c59ed11fe49e2f2b5908b952cf550be8966f..4921826034e9490045ff14c8fb757d4f541b59d4 100644 (file)
@@ -305,6 +305,7 @@ AArch64TargetLowering::AArch64TargetLowering(TargetMachine &TM)
 
   // AArch64 does not have floating-point extending loads, i1 sign-extending
   // load, floating-point truncating stores, or v2i32->v2i16 truncating store.
+  setLoadExtAction(ISD::EXTLOAD, MVT::f16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::f64, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::f80, Expand);
@@ -316,6 +317,10 @@ AArch64TargetLowering::AArch64TargetLowering(TargetMachine &TM)
   setTruncStoreAction(MVT::f128, MVT::f64, Expand);
   setTruncStoreAction(MVT::f128, MVT::f32, Expand);
   setTruncStoreAction(MVT::f128, MVT::f16, Expand);
+
+  setOperationAction(ISD::BITCAST, MVT::i16, Custom);
+  setOperationAction(ISD::BITCAST, MVT::f16, Custom);
+
   // Indexed loads and stores are supported.
   for (unsigned im = (unsigned)ISD::PRE_INC;
        im != (unsigned)ISD::LAST_INDEXED_MODE; ++im) {
@@ -1509,12 +1514,30 @@ SDValue AArch64TargetLowering::LowerFSINCOS(SDValue Op,
   return CallResult.first;
 }
 
+static SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) {
+  if (Op.getValueType() != MVT::f16)
+    return SDValue();
+
+  assert(Op.getOperand(0).getValueType() == MVT::i16);
+  SDLoc DL(Op);
+
+  Op = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op.getOperand(0));
+  Op = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Op);
+  return SDValue(
+      DAG.getMachineNode(TargetOpcode::EXTRACT_SUBREG, DL, MVT::f16, Op,
+                         DAG.getTargetConstant(AArch64::hsub, MVT::i32)),
+      0);
+}
+
+
 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
                                               SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
   default:
     llvm_unreachable("unimplemented operand");
     return SDValue();
+  case ISD::BITCAST:
+    return LowerBITCAST(Op, DAG);
   case ISD::GlobalAddress:
     return LowerGlobalAddress(Op, DAG);
   case ISD::GlobalTLSAddress:
@@ -6417,10 +6440,61 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+static SDValue performVectorCompareAndMaskUnaryOpCombine(SDNode *N,
+                                                         SelectionDAG &DAG) {
+  // Take advantage of vector comparisons producing 0 or -1 in each lane to
+  // optimize away operation when it's from a constant.
+  //
+  // The general transformation is:
+  //    UNARYOP(AND(VECTOR_CMP(x,y), constant)) -->
+  //       AND(VECTOR_CMP(x,y), constant2)
+  //    constant2 = UNARYOP(constant)
+
+  // Early exit if this isn't a vector operation or if the operand of the
+  // unary operation isn't a bitwise AND.
+  EVT VT = N->getValueType(0);
+  if (!VT.isVector() || N->getOperand(0)->getOpcode() != ISD::AND ||
+      N->getOperand(0)->getOperand(0)->getOpcode() != ISD::SETCC)
+    return SDValue();
+
+  // Now check that the other operand of the AND is a constant splat. We could
+  // make the transformation for non-constant splats as well, but it's unclear
+  // that would be a benefit as it would not eliminate any operations, just
+  // perform one more step in scalar code before moving to the vector unit.
+  if (BuildVectorSDNode *BV =
+          dyn_cast<BuildVectorSDNode>(N->getOperand(0)->getOperand(1))) {
+    // Bail out if the vector isn't a constant splat.
+    if (!BV->getConstantSplatNode())
+      return SDValue();
+
+    // Everything checks out. Build up the new and improved node.
+    SDLoc DL(N);
+    EVT IntVT = BV->getValueType(0);
+    // Create a new constant of the appropriate type for the transformed
+    // DAG.
+    SDValue SourceConst = DAG.getNode(N->getOpcode(), DL, VT, SDValue(BV, 0));
+    // The AND node needs bitcasts to/from an integer vector type around it.
+    SDValue MaskConst = DAG.getNode(ISD::BITCAST, DL, IntVT, SourceConst);
+    SDValue NewAnd = DAG.getNode(ISD::AND, DL, IntVT,
+                                 N->getOperand(0)->getOperand(0), MaskConst);
+    SDValue Res = DAG.getNode(ISD::BITCAST, DL, VT, NewAnd);
+    return Res;
+  }
+
+  return SDValue();
+}
+
 static SDValue performIntToFpCombine(SDNode *N, SelectionDAG &DAG) {
+  // First try to optimize away the conversion when it's conditionally from
+  // a constant. Vectors only.
+  SDValue Res = performVectorCompareAndMaskUnaryOpCombine(N, DAG);
+  if (Res != SDValue())
+    return Res;
+
   EVT VT = N->getValueType(0);
   if (VT != MVT::f32 && VT != MVT::f64)
     return SDValue();
+
   // Only optimize when the source and destination types have the same width.
   if (VT.getSizeInBits() != N->getOperand(0).getValueType().getSizeInBits())
     return SDValue();
@@ -7890,11 +7964,32 @@ bool AArch64TargetLowering::getPostIndexedAddressParts(
   return true;
 }
 
+static void ReplaceBITCASTResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
+                                  SelectionDAG &DAG) {
+  if (N->getValueType(0) != MVT::i16)
+    return;
+
+  SDLoc DL(N);
+  SDValue Op = N->getOperand(0);
+  assert(Op.getValueType() == MVT::f16 &&
+         "Inconsistent bitcast? Only 16-bit types should be i16 or f16");
+  Op = SDValue(
+      DAG.getMachineNode(TargetOpcode::INSERT_SUBREG, DL, MVT::f32,
+                         DAG.getUNDEF(MVT::i32), Op,
+                         DAG.getTargetConstant(AArch64::hsub, MVT::i32)),
+      0);
+  Op = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Op);
+  Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Op));
+}
+
 void AArch64TargetLowering::ReplaceNodeResults(
     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
   switch (N->getOpcode()) {
   default:
     llvm_unreachable("Don't know how to custom expand this");
+  case ISD::BITCAST:
+    ReplaceBITCASTResults(N, Results, DAG);
+    return;
   case ISD::FP_TO_UINT:
   case ISD::FP_TO_SINT:
     assert(N->getValueType(0) == MVT::i128 && "unexpected illegal conversion");