[X86][SSE] Merge PerformBLENDICombine into PerformShuffleCombine
[oota-llvm.git] / lib / Target / AArch64 / AArch64ISelLowering.cpp
index 2ede39eb7f22dc252387943ea81559d95612563b..e24a31f29e16b501e3f1d9e9746c5be4aaba70ae 100644 (file)
@@ -8450,25 +8450,6 @@ static SDValue performExtendCombine(SDNode *N,
     }
   }
 
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-
-  // If we see (any_extend (extract_vector_element v, i)), we can potentially
-  // remove the extend and promote the extract. We can do this if the vector
-  // type is legal and if the result is sign extended from the element type.
-  if (DCI.isAfterLegalizeVectorOps() && N->getOpcode() == ISD::ANY_EXTEND &&
-      N->hasOneUse() && N->use_begin()->getOpcode() == ISD::SIGN_EXTEND_INREG) {
-    const SDValue &M = N->getOperand(0);
-    if (M.getNode()->hasOneUse() && M.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
-      EVT DstTy = N->getValueType(0);
-      EVT SrcTy = cast<VTSDNode>(N->use_begin()->getOperand(1))->getVT();
-      EVT VecTy = M.getOperand(0).getValueType();
-      EVT ElmTy = VecTy.getScalarType();
-      if (TLI.isTypeLegal(VecTy) && SrcTy == ElmTy)
-        return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(N), DstTy,
-                           M.getOperand(0), M.getOperand(1));
-    }
-  }
-
   // This is effectively a custom type legalization for AArch64.
   //
   // Type legalization will split an extend of a small, legal, type to a larger
@@ -8499,6 +8480,7 @@ static SDValue performExtendCombine(SDNode *N,
   // We're only interested in cleaning things up for non-legal vector types
   // here. If both the source and destination are legal, things will just
   // work naturally without any fiddling.
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   EVT ResVT = N->getValueType(0);
   if (!ResVT.isVector() || TLI.isTypeLegal(ResVT))
     return SDValue();
@@ -9509,6 +9491,103 @@ static SDValue performBRCONDCombine(SDNode *N,
   return SDValue();
 }
 
+// Optimize some simple tbz/tbnz cases.  Returns the new operand and bit to test
+// as well as whether the test should be inverted.  This code is required to
+// catch these cases (as opposed to standard dag combines) because
+// AArch64ISD::TBZ is matched during legalization.
+static SDValue getTestBitOperand(SDValue Op, unsigned &Bit, bool &Invert,
+                                 SelectionDAG &DAG) {
+
+  if (!Op->hasOneUse())
+    return Op;
+
+  // We don't handle undef/constant-fold cases below, as they should have
+  // already been taken care of (e.g. and of 0, test of undefined shifted bits,
+  // etc.)
+
+  // (tbz (trunc x), b) -> (tbz x, b)
+  // This case is just here to enable more of the below cases to be caught.
+  if (Op->getOpcode() == ISD::TRUNCATE &&
+      Bit < Op->getValueType(0).getSizeInBits()) {
+    return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+  }
+
+  if (Op->getNumOperands() != 2)
+    return Op;
+
+  auto *C = dyn_cast<ConstantSDNode>(Op->getOperand(1));
+  if (!C)
+    return Op;
+
+  switch (Op->getOpcode()) {
+  default:
+    return Op;
+
+  // (tbz (and x, m), b) -> (tbz x, b)
+  case ISD::AND:
+    if ((C->getZExtValue() >> Bit) & 1)
+      return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+    return Op;
+
+  // (tbz (shl x, c), b) -> (tbz x, b-c)
+  case ISD::SHL:
+    if (C->getZExtValue() <= Bit &&
+        (Bit - C->getZExtValue()) < Op->getValueType(0).getSizeInBits()) {
+      Bit = Bit - C->getZExtValue();
+      return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+    }
+    return Op;
+
+  // (tbz (sra x, c), b) -> (tbz x, b+c) or (tbz x, msb) if b+c is > # bits in x
+  case ISD::SRA:
+    Bit = Bit + C->getZExtValue();
+    if (Bit >= Op->getValueType(0).getSizeInBits())
+      Bit = Op->getValueType(0).getSizeInBits() - 1;
+    return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+
+  // (tbz (srl x, c), b) -> (tbz x, b+c)
+  case ISD::SRL:
+    if ((Bit + C->getZExtValue()) < Op->getValueType(0).getSizeInBits()) {
+      Bit = Bit + C->getZExtValue();
+      return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+    }
+    return Op;
+
+  // (tbz (xor x, -1), b) -> (tbnz x, b)
+  case ISD::XOR:
+    if ((C->getZExtValue() >> Bit) & 1)
+      Invert = !Invert;
+    return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+  }
+}
+
+// Optimize test single bit zero/non-zero and branch.
+static SDValue performTBZCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 SelectionDAG &DAG) {
+  unsigned Bit = cast<ConstantSDNode>(N->getOperand(2))->getZExtValue();
+  bool Invert = false;
+  SDValue TestSrc = N->getOperand(1);
+  SDValue NewTestSrc = getTestBitOperand(TestSrc, Bit, Invert, DAG);
+
+  if (TestSrc == NewTestSrc)
+    return SDValue();
+
+  unsigned NewOpc = N->getOpcode();
+  if (Invert) {
+    if (NewOpc == AArch64ISD::TBZ)
+      NewOpc = AArch64ISD::TBNZ;
+    else {
+      assert(NewOpc == AArch64ISD::TBNZ);
+      NewOpc = AArch64ISD::TBZ;
+    }
+  }
+
+  SDLoc DL(N);
+  return DAG.getNode(NewOpc, DL, MVT::Other, N->getOperand(0), NewTestSrc,
+                     DAG.getConstant(Bit, DL, MVT::i64), N->getOperand(3));
+}
+
 // vselect (v1i1 setcc) ->
 //     vselect (v1iXX setcc)  (XX is the size of the compared operand type)
 // FIXME: Currently the type legalizer can't handle VSELECT having v1i1 as
@@ -9660,6 +9739,9 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performSTORECombine(N, DCI, DAG, Subtarget);
   case AArch64ISD::BRCOND:
     return performBRCONDCombine(N, DCI, DAG);
+  case AArch64ISD::TBNZ:
+  case AArch64ISD::TBZ:
+    return performTBZCombine(N, DCI, DAG);
   case AArch64ISD::CSEL:
     return performCONDCombine(N, DCI, DAG, 2, 3);
   case AArch64ISD::DUP: