[X86][SSE] Shuffle blends with zero
[oota-llvm.git] / lib / Target / X86 / X86ISelLowering.cpp
index 08db5b89e96e97c6af8843f0fb65cd6e57308909..494121e540c22171b04f0359b8e486bf03a9a171 100644 (file)
@@ -6859,22 +6859,62 @@ static SDValue lowerVectorShuffleAsBitBlend(SDLoc DL, MVT VT, SDValue V1,
 /// This doesn't do any checks for the availability of instructions for blending
 /// these values. It relies on the availability of the X86ISD::BLENDI pattern to
 /// be matched in the backend with the type given. What it does check for is
-/// that the shuffle mask is in fact a blend.
+/// that the shuffle mask is a blend, or convertible into a blend with zero.
 static SDValue lowerVectorShuffleAsBlend(SDLoc DL, MVT VT, SDValue V1,
-                                         SDValue V2, ArrayRef<int> Mask,
+                                         SDValue V2, ArrayRef<int> Original,
                                          const X86Subtarget *Subtarget,
                                          SelectionDAG &DAG) {
+  bool V1IsZero = ISD::isBuildVectorAllZeros(V1.getNode());
+  bool V2IsZero = ISD::isBuildVectorAllZeros(V2.getNode());
+  SmallVector<int, 8> Mask(Original.begin(), Original.end());
+  SmallBitVector Zeroable = computeZeroableShuffleElements(Mask, V1, V2);
+  bool ForceV1Zero = false, ForceV2Zero = false;
+
+  // Attempt to generate the binary blend mask. If an input is zero then
+  // we can use any lane.
+  // TODO: generalize the zero matching to any scalar like isShuffleEquivalent.
   unsigned BlendMask = 0;
   for (int i = 0, Size = Mask.size(); i < Size; ++i) {
-    if (Mask[i] >= Size) {
-      if (Mask[i] != i + Size)
-        return SDValue(); // Shuffled V2 input!
+    int M = Mask[i];
+    if (M < 0)
+      continue;
+    if (M == i)
+      continue;
+    if (M == i + Size) {
       BlendMask |= 1u << i;
       continue;
     }
-    if (Mask[i] >= 0 && Mask[i] != i)
-      return SDValue(); // Shuffled V1 input!
+    if (Zeroable[i]) {
+      if (V1IsZero) {
+        ForceV1Zero = true;
+        Mask[i] = i;
+        continue;
+      }
+      if (V2IsZero) {
+        ForceV2Zero = true;
+        BlendMask |= 1u << i;
+        Mask[i] = i + Size;
+        continue;
+      }
+    }
+    return SDValue(); // Shuffled input!
   }
+
+  // Create a REAL zero vector - ISD::isBuildVectorAllZeros allows UNDEFs.
+  if (ForceV1Zero)
+    V1 = getZeroVector(VT, Subtarget, DAG, DL);
+  if (ForceV2Zero)
+    V2 = getZeroVector(VT, Subtarget, DAG, DL);
+
+  auto ScaleBlendMask = [](unsigned BlendMask, int Size, int Scale) {
+    unsigned ScaledMask = 0;
+    for (int i = 0; i != Size; ++i)
+      if (BlendMask & (1u << i))
+        for (int j = 0; j != Scale; ++j)
+          ScaledMask |= 1u << (i * Scale + j);
+    return ScaledMask;
+  };
+
   switch (VT.SimpleTy) {
   case MVT::v2f64:
   case MVT::v4f32:
@@ -6894,12 +6934,7 @@ static SDValue lowerVectorShuffleAsBlend(SDLoc DL, MVT VT, SDValue V1,
     if (Subtarget->hasAVX2()) {
       // Scale the blend by the number of 32-bit dwords per element.
       int Scale =  VT.getScalarSizeInBits() / 32;
-      BlendMask = 0;
-      for (int i = 0, Size = Mask.size(); i < Size; ++i)
-        if (Mask[i] >= Size)
-          for (int j = 0; j < Scale; ++j)
-            BlendMask |= 1u << (i * Scale + j);
-
+      BlendMask = ScaleBlendMask(BlendMask, Mask.size(), Scale);
       MVT BlendVT = VT.getSizeInBits() > 128 ? MVT::v8i32 : MVT::v4i32;
       V1 = DAG.getBitcast(BlendVT, V1);
       V2 = DAG.getBitcast(BlendVT, V2);
@@ -6912,12 +6947,7 @@ static SDValue lowerVectorShuffleAsBlend(SDLoc DL, MVT VT, SDValue V1,
     // For integer shuffles we need to expand the mask and cast the inputs to
     // v8i16s prior to blending.
     int Scale = 8 / VT.getVectorNumElements();
-    BlendMask = 0;
-    for (int i = 0, Size = Mask.size(); i < Size; ++i)
-      if (Mask[i] >= Size)
-        for (int j = 0; j < Scale; ++j)
-          BlendMask |= 1u << (i * Scale + j);
-
+    BlendMask = ScaleBlendMask(BlendMask, Mask.size(), Scale);
     V1 = DAG.getBitcast(MVT::v8i16, V1);
     V2 = DAG.getBitcast(MVT::v8i16, V2);
     return DAG.getBitcast(VT,