[x86] Rewrite the byte shift detection to not use boolean variables to
[oota-llvm.git] / lib / Target / X86 / X86ISelLowering.cpp
index cd34319203834aad55702f1a81be29084b84d3b7..6255cce0d2bc0a03483020d99f6dc2b36e5e8d88 100644 (file)
@@ -7888,40 +7888,39 @@ static SDValue lowerVectorShuffleAsByteShift(SDLoc DL, MVT VT, SDValue V1,
   // [  5, 6,  7, zz, zz, zz, zz, zz]
   // [ -1, 5,  6,  7, zz, zz, zz, zz]
   // [  1, 2, -1, -1, -1, -1, zz, zz]
-  auto MatchByteShift = [&](int Shift) -> SDValue {
-    bool MatchLeft = true, MatchRight = true;
-    for (int l = 0; l < NumElts; l += NumLaneElts) {
+
+  auto CheckZeros = [&](int Shift, bool LeftShift) {
+    for (int l = 0; l < NumElts; l += NumLaneElts)
       for (int i = 0; i < Shift; ++i)
-        MatchLeft &= Zeroable[l + i];
-      for (int i = NumLaneElts - Shift; i < NumLaneElts; ++i)
-        MatchRight &= Zeroable[l + i];
-    }
-    if (!(MatchLeft || MatchRight))
-      return SDValue();
+        if (!Zeroable[l + i + (LeftShift ? 0 : (NumLaneElts - Shift))])
+          return false;
 
-    bool MatchV1 = true, MatchV2 = true;
+    return true;
+  };
+
+  auto MatchByteShift = [&](int Shift, bool LeftShift, SDValue V) {
     for (int l = 0; l < NumElts; l += NumLaneElts) {
-      unsigned Pos = MatchLeft ? Shift + l : l;
-      unsigned Low = MatchLeft ? l : Shift + l;
+      unsigned Pos = LeftShift ? Shift + l : l;
+      unsigned Low = LeftShift ? l : Shift + l;
       unsigned Len = NumLaneElts - Shift;
-      MatchV1 &= isSequentialOrUndefInRange(Mask, Pos, Len, Low);
-      MatchV2 &= isSequentialOrUndefInRange(Mask, Pos, Len, Low + NumElts);
+      if (!isSequentialOrUndefInRange(Mask, Pos, Len,
+                                      Low + (V == V1 ? 0 : NumElts)))
+        return SDValue();
     }
-    if (!(MatchV1 || MatchV2))
-      return SDValue();
 
     int ByteShift = Shift * Scale;
-    unsigned Op = MatchRight ? X86ISD::VSRLDQ : X86ISD::VSHLDQ;
-    SDValue V = MatchV1 ? V1 : V2;
+    unsigned Op = LeftShift ? X86ISD::VSHLDQ : X86ISD::VSRLDQ;
     V = DAG.getNode(ISD::BITCAST, DL, ShiftVT, V);
-    V = DAG.getNode(Op, DL, ShiftVT, V,
-                    DAG.getConstant(ByteShift, MVT::i8));
+    V = DAG.getNode(Op, DL, ShiftVT, V, DAG.getConstant(ByteShift, MVT::i8));
     return DAG.getNode(ISD::BITCAST, DL, VT, V);
   };
 
   for (int Shift = 1; Shift < NumLaneElts; ++Shift)
-    if (SDValue S = MatchByteShift(Shift))
-      return S;
+    for (bool LeftShift : {true, false})
+      if (CheckZeros(Shift, LeftShift))
+        for (SDValue V : {V1, V2})
+          if (SDValue S = MatchByteShift(Shift, LeftShift, V))
+            return S;
 
   // no match
   return SDValue();