// [ 5, 6, 7, zz, zz, zz, zz, zz]
// [ -1, 5, 6, 7, zz, zz, zz, zz]
// [ 1, 2, -1, -1, -1, -1, zz, zz]
- auto MatchByteShift = [&](int Shift) -> SDValue {
- bool MatchLeft = true, MatchRight = true;
- for (int l = 0; l < NumElts; l += NumLaneElts) {
+
+ auto CheckZeros = [&](int Shift, bool LeftShift) {
+ for (int l = 0; l < NumElts; l += NumLaneElts)
for (int i = 0; i < Shift; ++i)
- MatchLeft &= Zeroable[l + i];
- for (int i = NumLaneElts - Shift; i < NumLaneElts; ++i)
- MatchRight &= Zeroable[l + i];
- }
- if (!(MatchLeft || MatchRight))
- return SDValue();
+ if (!Zeroable[l + i + (LeftShift ? 0 : (NumLaneElts - Shift))])
+ return false;
- bool MatchV1 = true, MatchV2 = true;
+ return true;
+ };
+
+ auto MatchByteShift = [&](int Shift, bool LeftShift, SDValue V) {
for (int l = 0; l < NumElts; l += NumLaneElts) {
- unsigned Pos = MatchLeft ? Shift + l : l;
- unsigned Low = MatchLeft ? l : Shift + l;
+ unsigned Pos = LeftShift ? Shift + l : l;
+ unsigned Low = LeftShift ? l : Shift + l;
unsigned Len = NumLaneElts - Shift;
- MatchV1 &= isSequentialOrUndefInRange(Mask, Pos, Len, Low);
- MatchV2 &= isSequentialOrUndefInRange(Mask, Pos, Len, Low + NumElts);
+ if (!isSequentialOrUndefInRange(Mask, Pos, Len,
+ Low + (V == V1 ? 0 : NumElts)))
+ return SDValue();
}
- if (!(MatchV1 || MatchV2))
- return SDValue();
int ByteShift = Shift * Scale;
- unsigned Op = MatchRight ? X86ISD::VSRLDQ : X86ISD::VSHLDQ;
- SDValue V = MatchV1 ? V1 : V2;
+ unsigned Op = LeftShift ? X86ISD::VSHLDQ : X86ISD::VSRLDQ;
V = DAG.getNode(ISD::BITCAST, DL, ShiftVT, V);
- V = DAG.getNode(Op, DL, ShiftVT, V,
- DAG.getConstant(ByteShift, MVT::i8));
+ V = DAG.getNode(Op, DL, ShiftVT, V, DAG.getConstant(ByteShift, MVT::i8));
return DAG.getNode(ISD::BITCAST, DL, VT, V);
};
for (int Shift = 1; Shift < NumLaneElts; ++Shift)
- if (SDValue S = MatchByteShift(Shift))
- return S;
+ for (bool LeftShift : {true, false})
+ if (CheckZeros(Shift, LeftShift))
+ for (SDValue V : {V1, V2})
+ if (SDValue S = MatchByteShift(Shift, LeftShift, V))
+ return S;
// no match
return SDValue();