ARM64: Combine shifts and uses from different basic block to bit-extract instruction
[oota-llvm.git] / lib / Target / ARM64 / ARM64ISelLowering.cpp
index d4898ae5239e882471f940fd6b89a369986f51f5..ee498e65c71c5991903e9bd10c0ccedb7a2c8747 100644 (file)
@@ -435,7 +435,11 @@ ARM64TargetLowering::ARM64TargetLowering(ARM64TargetMachine &TM)
 
   setMinFunctionAlignment(2);
 
+  setDivIsWellDefined(true);
+
   RequireStrictAlign = StrictAlign;
+
+  setHasExtractBitsInsn(true);
 }
 
 void ARM64TargetLowering::addTypeForNEON(EVT VT, EVT PromotedBitwiseVT) {
@@ -1313,10 +1317,16 @@ static SDValue LowerVectorFP_TO_INT(SDValue Op, SelectionDAG &DAG) {
   if (VT.getSizeInBits() == InVT.getSizeInBits())
     return Op;
 
-  if (InVT == MVT::v2f64) {
+  if (InVT == MVT::v2f64 || InVT == MVT::v4f32) {
     SDLoc dl(Op);
-    SDValue Cv = DAG.getNode(Op.getOpcode(), dl, MVT::v2i64, Op.getOperand(0));
+    SDValue Cv =
+        DAG.getNode(Op.getOpcode(), dl, InVT.changeVectorElementTypeToInteger(),
+                    Op.getOperand(0));
     return DAG.getNode(ISD::TRUNCATE, dl, VT, Cv);
+  } else if (InVT == MVT::v2f32) {
+    SDLoc dl(Op);
+    SDValue Ext = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v2f64, Op.getOperand(0));
+    return DAG.getNode(Op.getOpcode(), dl, VT, Ext);
   }
 
   // Type changing conversions are illegal.
@@ -3829,9 +3839,11 @@ SDValue ARM64TargetLowering::ReconstructShuffle(SDValue Op,
       VEXTOffsets[i] = 0;
       continue;
     } else if (SourceVecs[i].getValueType().getVectorNumElements() < NumElts) {
-      // It probably isn't worth padding out a smaller vector just to
-      // break it down again in a shuffle.
-      return SDValue();
+      // We can pad out the smaller vector for free, so if it's part of a
+      // shuffle...
+      ShuffleSrcs[i] = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, SourceVecs[i],
+                                   DAG.getUNDEF(SourceVecs[i].getValueType()));
+      continue;
     }
 
     // Don't attempt to extract subvectors from BUILD_VECTOR sources
@@ -3943,11 +3955,13 @@ static bool isEXTMask(ArrayRef<int> M, EVT VT, bool &ReverseEXT,
   unsigned NumElts = VT.getVectorNumElements();
   ReverseEXT = false;
 
-  // Assume that the first shuffle index is not UNDEF.  Fail if it is.
-  if (M[0] < 0)
-    return false;
-
-  Imm = M[0];
+  // Look for the first non-undef choice and count backwards from
+  // that. E.g. <-1, -1, 3, ...> means that an EXT must start at 3 - 2 = 1. This
+  // guarantees that at least one index is correct.
+  const int *FirstRealElt =
+      std::find_if(M.begin(), M.end(), [](int Elt) { return Elt >= 0; });
+  assert(FirstRealElt != M.end() && "Completely UNDEF shuffle? Why bother?");
+  Imm = *FirstRealElt - (FirstRealElt - M.begin());
 
   // If this is a VEXT shuffle, the immediate value is the index of the first
   // element.  The other shuffle indices must be the successive elements after
@@ -4093,6 +4107,94 @@ static bool isTRN_v_undef_Mask(ArrayRef<int> M, EVT VT, unsigned &WhichResult) {
   return true;
 }
 
+static bool isINSMask(ArrayRef<int> M, int NumInputElements,
+                      bool &DstIsLeft, int &Anomaly) {
+  if (M.size() != static_cast<size_t>(NumInputElements))
+    return false;
+
+  int NumLHSMatch = 0, NumRHSMatch = 0;
+  int LastLHSMismatch = -1, LastRHSMismatch = -1;
+
+  for (int i = 0; i < NumInputElements; ++i) {
+    if (M[i] == -1) {
+      ++NumLHSMatch;
+      ++NumRHSMatch;
+      continue;
+    }
+
+    if (M[i] == i)
+      ++NumLHSMatch;
+    else
+      LastLHSMismatch = i;
+
+    if (M[i] == i + NumInputElements)
+      ++NumRHSMatch;
+    else
+      LastRHSMismatch = i;
+  }
+
+  if (NumLHSMatch == NumInputElements - 1) {
+    DstIsLeft = true;
+    Anomaly = LastLHSMismatch;
+    return true;
+  } else if (NumRHSMatch == NumInputElements - 1) {
+    DstIsLeft = false;
+    Anomaly = LastRHSMismatch;
+    return true;
+  }
+
+  return false;
+}
+
+static bool isConcatMask(ArrayRef<int> Mask, EVT VT, bool SplitLHS) {
+  if (VT.getSizeInBits() != 128)
+    return false;
+
+  unsigned NumElts = VT.getVectorNumElements();
+
+  for (int I = 0, E = NumElts / 2; I != E; I++) {
+    if (Mask[I] != I)
+      return false;
+  }
+
+  int Offset = NumElts / 2;
+  for (int I = NumElts / 2, E = NumElts; I != E; I++) {
+    if (Mask[I] != I + SplitLHS * Offset)
+      return false;
+  }
+
+  return true;
+}
+
+static SDValue tryFormConcatFromShuffle(SDValue Op, SelectionDAG &DAG) {
+  SDLoc DL(Op);
+  EVT VT = Op.getValueType();
+  SDValue V0 = Op.getOperand(0);
+  SDValue V1 = Op.getOperand(1);
+  ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(Op)->getMask();
+
+  if (VT.getVectorElementType() != V0.getValueType().getVectorElementType() ||
+      VT.getVectorElementType() != V1.getValueType().getVectorElementType())
+    return SDValue();
+
+  bool SplitV0 = V0.getValueType().getSizeInBits() == 128;
+
+  if (!isConcatMask(Mask, VT, SplitV0))
+    return SDValue();
+
+  EVT CastVT = EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
+                                VT.getVectorNumElements() / 2);
+  if (SplitV0) {
+    V0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, CastVT, V0,
+                     DAG.getConstant(0, MVT::i64));
+  }
+  if (V1.getValueType().getSizeInBits() == 128) {
+    V1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, CastVT, V1,
+                     DAG.getConstant(0, MVT::i64));
+  }
+  return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, V0, V1);
+}
+
 /// GeneratePerfectShuffle - Given an entry in the perfect-shuffle table, emit
 /// the specified operations to build the shuffle.
 static SDValue GeneratePerfectShuffle(unsigned PFEntry, SDValue LHS,
@@ -4362,6 +4464,35 @@ SDValue ARM64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
     return DAG.getNode(Opc, dl, V1.getValueType(), V1, V1);
   }
 
+  SDValue Concat = tryFormConcatFromShuffle(Op, DAG);
+  if (Concat.getNode())
+    return Concat;
+
+  bool DstIsLeft;
+  int Anomaly;
+  int NumInputElements = V1.getValueType().getVectorNumElements();
+  if (isINSMask(ShuffleMask, NumInputElements, DstIsLeft, Anomaly)) {
+    SDValue DstVec = DstIsLeft ? V1 : V2;
+    SDValue DstLaneV = DAG.getConstant(Anomaly, MVT::i64);
+
+    SDValue SrcVec = V1;
+    int SrcLane = ShuffleMask[Anomaly];
+    if (SrcLane >= NumInputElements) {
+      SrcVec = V2;
+      SrcLane -= VT.getVectorNumElements();
+    }
+    SDValue SrcLaneV = DAG.getConstant(SrcLane, MVT::i64);
+
+    EVT ScalarVT = VT.getVectorElementType();
+    if (ScalarVT.getSizeInBits() < 32)
+      ScalarVT = MVT::i32;
+
+    return DAG.getNode(
+        ISD::INSERT_VECTOR_ELT, dl, VT, DstVec,
+        DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ScalarVT, SrcVec, SrcLaneV),
+        DstLaneV);
+  }
+
   // If the shuffle is not directly supported and it has 4 elements, use
   // the PerfectShuffle-generated table to synthesize it from other shuffles.
   unsigned NumElts = VT.getVectorNumElements();
@@ -5200,18 +5331,21 @@ bool ARM64TargetLowering::isShuffleMaskLegal(const SmallVectorImpl<int> &M,
       return true;
   }
 
-  bool ReverseVEXT;
-  unsigned Imm, WhichResult;
+  bool DummyBool;
+  int DummyInt;
+  unsigned DummyUnsigned;
 
   return (ShuffleVectorSDNode::isSplatMask(&M[0], VT) || isREVMask(M, VT, 64) ||
           isREVMask(M, VT, 32) || isREVMask(M, VT, 16) ||
-          isEXTMask(M, VT, ReverseVEXT, Imm) ||
+          isEXTMask(M, VT, DummyBool, DummyUnsigned) ||
           // isTBLMask(M, VT) || // FIXME: Port TBL support from ARM.
-          isTRNMask(M, VT, WhichResult) || isUZPMask(M, VT, WhichResult) ||
-          isZIPMask(M, VT, WhichResult) ||
-          isTRN_v_undef_Mask(M, VT, WhichResult) ||
-          isUZP_v_undef_Mask(M, VT, WhichResult) ||
-          isZIP_v_undef_Mask(M, VT, WhichResult));
+          isTRNMask(M, VT, DummyUnsigned) || isUZPMask(M, VT, DummyUnsigned) ||
+          isZIPMask(M, VT, DummyUnsigned) ||
+          isTRN_v_undef_Mask(M, VT, DummyUnsigned) ||
+          isUZP_v_undef_Mask(M, VT, DummyUnsigned) ||
+          isZIP_v_undef_Mask(M, VT, DummyUnsigned) ||
+          isINSMask(M, VT.getVectorNumElements(), DummyBool, DummyInt) ||
+          isConcatMask(M, VT, VT.getSizeInBits() == 128));
 }
 
 /// getVShiftImm - Check if this is a valid build_vector for the immediate