[AVX512] Bring back vector-shuffle lowering support through broadcasts
[oota-llvm.git] / lib / Target / X86 / X86ISelLowering.cpp
index 1bfacfe65a2f4492ee5b376134d6d810069fc1c7..656c1dea1f7ebbd98fe9aeab15fc9decda273b96 100644 (file)
@@ -811,6 +811,8 @@ void X86TargetLowering::resetOperationActions() {
   setOperationAction(ISD::FLOG10, MVT::f80, Expand);
   setOperationAction(ISD::FEXP, MVT::f80, Expand);
   setOperationAction(ISD::FEXP2, MVT::f80, Expand);
+  setOperationAction(ISD::FMINNUM, MVT::f80, Expand);
+  setOperationAction(ISD::FMAXNUM, MVT::f80, Expand);
 
   // First set operation action for all vector types to either promote
   // (for widening) or expand (for scalarization). Then we will selectively
@@ -1595,9 +1597,6 @@ void X86TargetLowering::resetOperationActions() {
     setOperationAction(ISD::UMULO, VT, Custom);
   }
 
-  // There are no 8-bit 3-address imul/mul instructions
-  setOperationAction(ISD::SMULO, MVT::i8, Expand);
-  setOperationAction(ISD::UMULO, MVT::i8, Expand);
 
   if (!Subtarget->is64Bit()) {
     // These libcalls are not available in 32-bit.
@@ -7449,7 +7448,7 @@ static SDValue lowerVectorShuffleAsDecomposedShuffleBlend(SDLoc DL, MVT VT,
 /// \brief Try to lower a vector shuffle as a byte rotation.
 ///
 /// We have a generic PALIGNR instruction in x86 that will do an arbitrary
-/// byte-rotation of a the concatentation of two vectors. This routine will
+/// byte-rotation of the concatenation of two vectors. This routine will
 /// try to generically lower a vector shuffle through such an instruction. It
 /// does not check for the availability of PALIGNR-based lowerings, only the
 /// applicability of this strategy to the given mask. This matches shuffle
@@ -7895,10 +7894,42 @@ static SDValue lowerVectorShuffleAsBroadcast(MVT VT, SDLoc DL, SDValue V,
                                             "a sorted mask where the broadcast "
                                             "comes from V1.");
 
-  // Check if this is a broadcast of a scalar. We special case lowering for
-  // scalars so that we can more effectively fold with loads.
+  // Go up the chain of (vector) values to try and find a scalar load that
+  // we can combine with the broadcast.
+  for (;;) {
+    switch (V.getOpcode()) {
+    case ISD::CONCAT_VECTORS: {
+      int OperandSize = Mask.size() / V.getNumOperands();
+      V = V.getOperand(BroadcastIdx / OperandSize);
+      BroadcastIdx %= OperandSize;
+      continue;
+    }
+
+    case ISD::INSERT_SUBVECTOR: {
+      SDValue VOuter = V.getOperand(0), VInner = V.getOperand(1);
+      auto ConstantIdx = dyn_cast<ConstantSDNode>(V.getOperand(2));
+      if (!ConstantIdx)
+        break;
+
+      int BeginIdx = (int)ConstantIdx->getZExtValue();
+      int EndIdx =
+          BeginIdx + (int)VInner.getValueType().getVectorNumElements();
+      if (BroadcastIdx >= BeginIdx && BroadcastIdx < EndIdx) {
+        BroadcastIdx -= BeginIdx;
+        V = VInner;
+      } else {
+        V = VOuter;
+      }
+      continue;
+    }
+    }
+    break;
+  }
+
+  // Check if this is a broadcast of a scalar. We special case lowering
+  // for scalars so that we can more effectively fold with loads.
   if (V.getOpcode() == ISD::BUILD_VECTOR ||
-        (V.getOpcode() == ISD::SCALAR_TO_VECTOR && BroadcastIdx == 0)) {
+      (V.getOpcode() == ISD::SCALAR_TO_VECTOR && BroadcastIdx == 0)) {
     V = V.getOperand(BroadcastIdx);
 
     // If the scalar isn't a load we can't broadcast from it in AVX1, only with
@@ -9017,7 +9048,8 @@ static SDValue lowerV8I16VectorShuffle(SDValue Op, SDValue V1, SDValue V2,
 
   // Try to use rotation instructions if available.
   if (Subtarget->hasSSSE3())
-    if (SDValue Rotate = lowerVectorShuffleAsByteRotate(DL, MVT::v8i16, V1, V2, Mask, DAG))
+    if (SDValue Rotate = lowerVectorShuffleAsByteRotate(
+            DL, MVT::v8i16, V1, V2, Mask, DAG))
       return Rotate;
 
   if (NumV1Inputs + NumV2Inputs <= 4)
@@ -9151,8 +9183,8 @@ static SDValue lowerV16I8VectorShuffle(SDValue Op, SDValue V1, SDValue V2,
 
   // Try to use rotation instructions if available.
   if (Subtarget->hasSSSE3())
-    if (SDValue Rotate = lowerVectorShuffleAsByteRotate(DL, MVT::v16i8, V1, V2,
-                                                        OrigMask, DAG))
+    if (SDValue Rotate = lowerVectorShuffleAsByteRotate(
+            DL, MVT::v16i8, V1, V2, OrigMask, DAG))
       return Rotate;
 
   // Try to use a zext lowering.
@@ -9278,21 +9310,29 @@ static SDValue lowerV16I8VectorShuffle(SDValue Op, SDValue V1, SDValue V2,
   //
   // FIXME: We need to handle other interleaving widths (i16, i32, ...).
   if (shouldLowerAsInterleaving(Mask)) {
-    // FIXME: Figure out whether we should pack these into the low or high
-    // halves.
-
-    int EMask[16], OMask[16];
+    int NumLoHalf = std::count_if(Mask.begin(), Mask.end(), [](int M) {
+      return (M >= 0 && M < 8) || (M >= 16 && M < 24);
+    });
+    int NumHiHalf = std::count_if(Mask.begin(), Mask.end(), [](int M) {
+      return (M >= 8 && M < 16) || M >= 24;
+    });
+    int EMask[16] = {-1, -1, -1, -1, -1, -1, -1, -1,
+                     -1, -1, -1, -1, -1, -1, -1, -1};
+    int OMask[16] = {-1, -1, -1, -1, -1, -1, -1, -1,
+                     -1, -1, -1, -1, -1, -1, -1, -1};
+    bool UnpackLo = NumLoHalf >= NumHiHalf;
+    MutableArrayRef<int> TargetEMask(UnpackLo ? EMask : EMask + 8, 8);
+    MutableArrayRef<int> TargetOMask(UnpackLo ? OMask : OMask + 8, 8);
     for (int i = 0; i < 8; ++i) {
-      EMask[i] = Mask[2*i];
-      OMask[i] = Mask[2*i + 1];
-      EMask[i + 8] = -1;
-      OMask[i + 8] = -1;
+      TargetEMask[i] = Mask[2 * i];
+      TargetOMask[i] = Mask[2 * i + 1];
     }
 
     SDValue Evens = DAG.getVectorShuffle(MVT::v16i8, DL, V1, V2, EMask);
     SDValue Odds = DAG.getVectorShuffle(MVT::v16i8, DL, V1, V2, OMask);
 
-    return DAG.getNode(X86ISD::UNPCKL, DL, MVT::v16i8, Evens, Odds);
+    return DAG.getNode(UnpackLo ? X86ISD::UNPCKL : X86ISD::UNPCKH, DL,
+                       MVT::v16i8, Evens, Odds);
   }
 
   // Check for SSSE3 which lets us lower all v16i8 shuffles much more directly
@@ -10191,7 +10231,6 @@ static SDValue lowerV8I64VectorShuffle(SDValue Op, SDValue V1, SDValue V2,
   ShuffleVectorSDNode *SVOp = cast<ShuffleVectorSDNode>(Op);
   ArrayRef<int> Mask = SVOp->getMask();
   assert(Mask.size() == 8 && "Unexpected mask size for v8 shuffle!");
-  assert(Subtarget->hasDQI() && "We can only lower v8i64 with AVX-512-DQI");
 
   // FIXME: Implement direct support for this type!
   return splitAndLowerVectorShuffle(DL, MVT::v8i64, V1, V2, Mask, DAG);
@@ -10207,7 +10246,6 @@ static SDValue lowerV16I32VectorShuffle(SDValue Op, SDValue V1, SDValue V2,
   ShuffleVectorSDNode *SVOp = cast<ShuffleVectorSDNode>(Op);
   ArrayRef<int> Mask = SVOp->getMask();
   assert(Mask.size() == 16 && "Unexpected mask size for v16 shuffle!");
-  assert(Subtarget->hasDQI() && "We can only lower v16i32 with AVX-512-DQI!");
 
   // FIXME: Implement direct support for this type!
   return splitAndLowerVectorShuffle(DL, MVT::v16i32, V1, V2, Mask, DAG);
@@ -10259,6 +10297,11 @@ static SDValue lower512BitVectorShuffle(SDValue Op, SDValue V1, SDValue V2,
   assert(Subtarget->hasAVX512() &&
          "Cannot lower 512-bit vectors w/ basic ISA!");
 
+  // Check for being able to broadcast a single element.
+  if (SDValue Broadcast = lowerVectorShuffleAsBroadcast(VT.SimpleTy, DL, V1,
+                                                        Mask, Subtarget, DAG))
+    return Broadcast;
+
   // Dispatch to each element type for lowering. If we don't have supprot for
   // specific element type shuffles at 512 bits, immediately split them and
   // lower them. Each lowering routine of a given type is allowed to assume that
@@ -10269,13 +10312,9 @@ static SDValue lower512BitVectorShuffle(SDValue Op, SDValue V1, SDValue V2,
   case MVT::v16f32:
     return lowerV16F32VectorShuffle(Op, V1, V2, Subtarget, DAG);
   case MVT::v8i64:
-    if (Subtarget->hasDQI())
-      return lowerV8I64VectorShuffle(Op, V1, V2, Subtarget, DAG);
-    break;
+    return lowerV8I64VectorShuffle(Op, V1, V2, Subtarget, DAG);
   case MVT::v16i32:
-    if (Subtarget->hasDQI())
-      return lowerV16I32VectorShuffle(Op, V1, V2, Subtarget, DAG);
-    break;
+    return lowerV16I32VectorShuffle(Op, V1, V2, Subtarget, DAG);
   case MVT::v32i16:
     if (Subtarget->hasBWI())
       return lowerV32I16VectorShuffle(Op, V1, V2, Subtarget, DAG);
@@ -14327,6 +14366,37 @@ SDValue X86TargetLowering::ConvertCmpIfNecessary(SDValue Cmp,
   return DAG.getNode(X86ISD::SAHF, dl, MVT::i32, TruncSrl);
 }
 
+/// The minimum architected relative accuracy is 2^-12. We need one
+/// Newton-Raphson step to have a good float result (24 bits of precision).
+SDValue X86TargetLowering::getRsqrtEstimate(SDValue Op,
+                                            DAGCombinerInfo &DCI,
+                                            unsigned &RefinementSteps,
+                                            bool &UseOneConstNR) const {
+  // FIXME: We should use instruction latency models to calculate the cost of
+  // each potential sequence, but this is very hard to do reliably because
+  // at least Intel's Core* chips have variable timing based on the number of
+  // significant digits in the divisor and/or sqrt operand.
+  if (!Subtarget->useSqrtEst())
+    return SDValue();
+
+  EVT VT = Op.getValueType();
+  
+  // SSE1 has rsqrtss and rsqrtps.
+  // TODO: Add support for AVX512 (v16f32).
+  // It is likely not profitable to do this for f64 because a double-precision
+  // rsqrt estimate with refinement on x86 prior to FMA requires at least 16
+  // instructions: convert to single, rsqrtss, convert back to double, refine
+  // (3 steps = at least 13 insts). If an 'rsqrtsd' variant was added to the ISA
+  // along with FMA, this could be a throughput win.
+  if ((Subtarget->hasSSE1() && (VT == MVT::f32 || VT == MVT::v4f32)) ||
+      (Subtarget->hasAVX() && VT == MVT::v8f32)) {
+    RefinementSteps = 1;
+    UseOneConstNR = false;
+    return DCI.DAG.getNode(X86ISD::FRSQRT, SDLoc(Op), VT, Op);
+  }
+  return SDValue();
+}
+
 static bool isAllOnes(SDValue V) {
   ConstantSDNode *C = dyn_cast<ConstantSDNode>(V);
   return C && C->isAllOnesValue();
@@ -15111,13 +15181,32 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getNode(X86ISD::CMOV, DL, VTs, Ops);
 }
 
-static SDValue LowerSIGN_EXTEND_AVX512(SDValue Op, SelectionDAG &DAG) {
+static SDValue LowerSIGN_EXTEND_AVX512(SDValue Op, const X86Subtarget *Subtarget,
+                                       SelectionDAG &DAG) {
   MVT VT = Op->getSimpleValueType(0);
   SDValue In = Op->getOperand(0);
   MVT InVT = In.getSimpleValueType();
+  MVT VTElt = VT.getVectorElementType();
+  MVT InVTElt = InVT.getVectorElementType();
   SDLoc dl(Op);
 
+  // SKX processor
+  if ((InVTElt == MVT::i1) &&
+      (((Subtarget->hasBWI() && Subtarget->hasVLX() &&
+        VT.getSizeInBits() <= 256 && VTElt.getSizeInBits() <= 16)) ||
+
+       ((Subtarget->hasBWI() && VT.is512BitVector() &&
+        VTElt.getSizeInBits() <= 16)) ||
+
+       ((Subtarget->hasDQI() && Subtarget->hasVLX() &&
+        VT.getSizeInBits() <= 256 && VTElt.getSizeInBits() >= 32)) ||
+    
+       ((Subtarget->hasDQI() && VT.is512BitVector() &&
+        VTElt.getSizeInBits() >= 32))))
+    return DAG.getNode(X86ISD::VSEXT, dl, VT, In);
+    
   unsigned int NumElts = VT.getVectorNumElements();
+
   if (NumElts != 8 && NumElts != 16)
     return SDValue();
 
@@ -15150,7 +15239,7 @@ static SDValue LowerSIGN_EXTEND(SDValue Op, const X86Subtarget *Subtarget,
   SDLoc dl(Op);
 
   if (VT.is512BitVector() || InVT.getVectorElementType() == MVT::i1)
-    return LowerSIGN_EXTEND_AVX512(Op, DAG);
+    return LowerSIGN_EXTEND_AVX512(Op, Subtarget, DAG);
 
   if ((VT != MVT::v4i64 || InVT != MVT::v4i32) &&
       (VT != MVT::v8i32 || InVT != MVT::v8i16) &&
@@ -16167,7 +16256,8 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) {
     case INTR_TYPE_3OP:
       return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Op.getOperand(1),
         Op.getOperand(2), Op.getOperand(3));
-    case CMP_MASK: {
+    case CMP_MASK:
+    case CMP_MASK_CC: {
       // Comparison intrinsics with masks.
       // Example of transformation:
       // (i8 (int_x86_avx512_mask_pcmpeq_q_128
@@ -16180,12 +16270,19 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) {
       EVT VT = Op.getOperand(1).getValueType();
       EVT MaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
                                     VT.getVectorNumElements());
-      SDValue Mask = Op.getOperand(3);
+      SDValue Mask = Op.getOperand((IntrData->Type == CMP_MASK_CC) ? 4 : 3);
       EVT BitcastVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
                                        Mask.getValueType().getSizeInBits());
-      SDValue Cmp = DAG.getNode(IntrData->Opc0, dl, MaskVT,
-                                Op.getOperand(1), Op.getOperand(2));
-      SDValue CmpMask = getVectorMaskingNode(Cmp, Op.getOperand(3),
+      SDValue Cmp;
+      if (IntrData->Type == CMP_MASK_CC) {
+        Cmp = DAG.getNode(IntrData->Opc0, dl, MaskVT, Op.getOperand(1),
+                    Op.getOperand(2), Op.getOperand(3));
+      } else {
+        assert(IntrData->Type == CMP_MASK && "Unexpected intrinsic type!");
+        Cmp = DAG.getNode(IntrData->Opc0, dl, MaskVT, Op.getOperand(1),
+                    Op.getOperand(2));
+      }
+      SDValue CmpMask = getVectorMaskingNode(Cmp, Mask,
                                         DAG.getTargetConstant(0, MaskVT), DAG);
       SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, BitcastVT,
                                 DAG.getUNDEF(BitcastVT), CmpMask,
@@ -18120,10 +18217,15 @@ static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) {
     Cond = X86::COND_B;
     break;
   case ISD::SMULO:
-    BaseOp = X86ISD::SMUL;
+    BaseOp = N->getValueType(0) == MVT::i8 ? X86ISD::SMUL8 : X86ISD::SMUL;
     Cond = X86::COND_O;
     break;
   case ISD::UMULO: { // i64, i8 = umulo lhs, rhs --> i64, i64, i32 umul lhs,rhs
+    if (N->getValueType(0) == MVT::i8) {
+      BaseOp = X86ISD::UMUL8;
+      Cond = X86::COND_O;
+      break;
+    }
     SDVTList VTs = DAG.getVTList(N->getValueType(0), N->getValueType(0),
                                  MVT::i32);
     SDValue Sum = DAG.getNode(X86ISD::UMUL, DL, VTs, LHS, RHS);
@@ -21673,7 +21775,7 @@ static SDValue PerformTruncateCombine(SDNode *N, SelectionDAG &DAG,
 /// XFormVExtractWithShuffleIntoLoad - Check if a vector extract from a target
 /// specific shuffle of a load can be folded into a single element load.
 /// Similar handling for VECTOR_SHUFFLE is performed by DAGCombiner, but
-/// shuffles have been customed lowered so we need to handle those here.
+/// shuffles have been custom lowered so we need to handle those here.
 static SDValue XFormVExtractWithShuffleIntoLoad(SDNode *N, SelectionDAG &DAG,
                                          TargetLowering::DAGCombinerInfo &DCI) {
   if (DCI.isBeforeLegalizeOps())
@@ -21685,18 +21787,20 @@ static SDValue XFormVExtractWithShuffleIntoLoad(SDNode *N, SelectionDAG &DAG,
   if (!isa<ConstantSDNode>(EltNo))
     return SDValue();
 
-  EVT VT = InVec.getValueType();
+  EVT OriginalVT = InVec.getValueType();
 
   if (InVec.getOpcode() == ISD::BITCAST) {
     // Don't duplicate a load with other uses.
     if (!InVec.hasOneUse())
       return SDValue();
     EVT BCVT = InVec.getOperand(0).getValueType();
-    if (BCVT.getVectorNumElements() != VT.getVectorNumElements())
+    if (BCVT.getVectorNumElements() != OriginalVT.getVectorNumElements())
       return SDValue();
     InVec = InVec.getOperand(0);
   }
 
+  EVT CurrentVT = InVec.getValueType();
+
   if (!isTargetShuffle(InVec.getOpcode()))
     return SDValue();
 
@@ -21706,12 +21810,12 @@ static SDValue XFormVExtractWithShuffleIntoLoad(SDNode *N, SelectionDAG &DAG,
 
   SmallVector<int, 16> ShuffleMask;
   bool UnaryShuffle;
-  if (!getTargetShuffleMask(InVec.getNode(), VT.getSimpleVT(), ShuffleMask,
-                            UnaryShuffle))
+  if (!getTargetShuffleMask(InVec.getNode(), CurrentVT.getSimpleVT(),
+                            ShuffleMask, UnaryShuffle))
     return SDValue();
 
   // Select the input vector, guarding against out of range extract vector.
-  unsigned NumElems = VT.getVectorNumElements();
+  unsigned NumElems = CurrentVT.getVectorNumElements();
   int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue();
   int Idx = (Elt > (int)NumElems) ? -1 : ShuffleMask[Elt];
   SDValue LdNode = (Idx < (int)NumElems) ? InVec.getOperand(0)
@@ -21753,11 +21857,12 @@ static SDValue XFormVExtractWithShuffleIntoLoad(SDNode *N, SelectionDAG &DAG,
   SDLoc dl(N);
 
   // Create shuffle node taking into account the case that its a unary shuffle
-  SDValue Shuffle = (UnaryShuffle) ? DAG.getUNDEF(VT) : InVec.getOperand(1);
-  Shuffle = DAG.getVectorShuffle(InVec.getValueType(), dl,
+  SDValue Shuffle = (UnaryShuffle) ? DAG.getUNDEF(CurrentVT)
+                                   : InVec.getOperand(1);
+  Shuffle = DAG.getVectorShuffle(CurrentVT, dl,
                                  InVec.getOperand(0), Shuffle,
                                  &ShuffleMask[0]);
-  Shuffle = DAG.getNode(ISD::BITCAST, dl, VT, Shuffle);
+  Shuffle = DAG.getNode(ISD::BITCAST, dl, OriginalVT, Shuffle);
   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, N->getValueType(0), Shuffle,
                      EltNo);
 }
@@ -22530,7 +22635,12 @@ static SDValue PerformSELECTCombine(SDNode *N, SelectionDAG &DAG,
     TargetLowering::TargetLoweringOpt TLO(DAG, DCI.isBeforeLegalize(),
                                           DCI.isBeforeLegalizeOps());
     if (TLO.ShrinkDemandedConstant(Cond, DemandedMask) ||
-        TLI.SimplifyDemandedBits(Cond, DemandedMask, KnownZero, KnownOne, TLO))
+        (TLI.SimplifyDemandedBits(Cond, DemandedMask, KnownZero, KnownOne,
+                                  TLO) &&
+         // Don't optimize vector of constants. Those are handled by
+         // the generic code and all the bits must be properly set for
+         // the generic optimizer.
+         !ISD::isBuildVectorOfConstantSDNodes(TLO.New.getNode())))
       DCI.CommitTargetLoweringOpt(TLO);
   }