+/// Detect the unsigned rounding-average idiom on vectors of i8/i16, i.e.
+/// c = (a + b + 1) / 2 evaluated in a wider integer type and then truncated,
+/// and replace it with a single X86ISD::AVG node.
+///
+/// \param In        Operand of the truncate; the pattern requires an ISD::SRL.
+/// \param VT        Result type of the truncate (a vector of i8 or i16).
+/// \param DAG       SelectionDAG used to build the replacement nodes.
+/// \param Subtarget Queried for the vector features that bound the usable
+///                  vector width.
+/// \param DL        Location attached to the newly created nodes.
+/// \returns an X86ISD::AVG node of type \p VT when the pattern matches, or an
+/// empty SDValue otherwise.
+static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
+                                const X86Subtarget *Subtarget, SDLoc DL) {
+  if (!VT.isVector() || !VT.isSimple())
+    return SDValue();
+  EVT InVT = In.getValueType();
+  unsigned NumElems = VT.getVectorNumElements();
+
+  EVT ScalarVT = VT.getVectorElementType();
+  if (!((ScalarVT == MVT::i8 || ScalarVT == MVT::i16) &&
+        isPowerOf2_32(NumElems)))
+    return SDValue();
+
+  // InScalarVT is the intermediate type in AVG pattern and it should be greater
+  // than the original input type (i8/i16).
+  EVT InScalarVT = InVT.getVectorElementType();
+  if (InScalarVT.getSizeInBits() <= ScalarVT.getSizeInBits())
+    return SDValue();
+
+  // The packed byte/word average on 128-bit vectors needs SSE2. Without it
+  // there is nothing to select the AVG node to, so leave the DAG untouched.
+  if (!Subtarget->hasSSE2())
+    return SDValue();
+
+  // Bound the vector width by what the subtarget can operate on for i8/i16
+  // elements. 512-bit byte/word operations need AVX-512BW (plain AVX-512F is
+  // not enough); 256-bit ones need AVX2; otherwise only 128 bits are usable.
+  if (Subtarget->hasBWI()) {
+    if (VT.getSizeInBits() > 512)
+      return SDValue();
+  } else if (Subtarget->hasAVX2()) {
+    if (VT.getSizeInBits() > 256)
+      return SDValue();
+  } else {
+    if (VT.getSizeInBits() > 128)
+      return SDValue();
+  }
+
+  // Detect the following pattern:
+  //
+  //   %1 = zext <N x i8> %a to <N x i32>
+  //   %2 = zext <N x i8> %b to <N x i32>
+  //   %3 = add nuw nsw <N x i32> %1, <i32 1 x N>
+  //   %4 = add nuw nsw <N x i32> %3, %2
+  //   %5 = lshr <N x i32> %4, <i32 1 x N>
+  //   %6 = trunc <N x i32> %5 to <N x i8>
+  //
+  // In AVX512, the last instruction can also be a trunc store.
+
+  if (In.getOpcode() != ISD::SRL)
+    return SDValue();
+
+  // A lambda checking the given SDValue is a constant vector and each element
+  // is in the range [Min, Max].
+  auto IsConstVectorInRange = [](SDValue V, unsigned Min, unsigned Max) {
+    BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(V);
+    if (!BV || !BV->isConstant())
+      return false;
+    for (unsigned i = 0, e = V.getNumOperands(); i < e; i++) {
+      // isConstant() also accepts undef elements; those are rejected here.
+      ConstantSDNode *C = dyn_cast<ConstantSDNode>(V.getOperand(i));
+      if (!C)
+        return false;
+      uint64_t Val = C->getZExtValue();
+      if (Val < Min || Val > Max)
+        return false;
+    }
+    return true;
+  };
+
+  // Check if each element of the vector is logically right-shifted by one.
+  SDValue LHS = In.getOperand(0);
+  SDValue RHS = In.getOperand(1);
+  if (!IsConstVectorInRange(RHS, 1, 1))
+    return SDValue();
+  if (LHS.getOpcode() != ISD::ADD)
+    return SDValue();
+
+  // Detect a pattern of a + b + 1 where the order doesn't matter.
+  SDValue Operands[3];
+  Operands[0] = LHS.getOperand(0);
+  Operands[1] = LHS.getOperand(1);
+
+  // Take care of the case when one of the operands is a constant vector whose
+  // elements are in the range [1, 256] for i8 or [1, 65536] for i16. Such a
+  // constant C already contains the "+ 1" of the pattern, and C - 1 fits in
+  // the narrow element type, so (zext(a) + C) >> 1 equals AVG(a, C - 1).
+  if (IsConstVectorInRange(Operands[1], 1, ScalarVT == MVT::i8 ? 256 : 65536) &&
+      Operands[0].getOpcode() == ISD::ZERO_EXTEND &&
+      Operands[0].getOperand(0).getValueType() == VT) {
+    // The pattern is detected. Subtract one from the constant vector, then
+    // demote it and emit X86ISD::AVG instruction.
+    SDValue One = DAG.getConstant(1, DL, InScalarVT);
+    SDValue Ones = DAG.getNode(ISD::BUILD_VECTOR, DL, InVT,
+                               SmallVector<SDValue, 8>(NumElems, One));
+    Operands[1] = DAG.getNode(ISD::SUB, DL, InVT, Operands[1], Ones);
+    Operands[1] = DAG.getNode(ISD::TRUNCATE, DL, VT, Operands[1]);
+    return DAG.getNode(X86ISD::AVG, DL, VT, Operands[0].getOperand(0),
+                       Operands[1]);
+  }
+
+  // Otherwise one operand of the outer add must itself be an add. Move it to
+  // Operands[1] and split it so that Operands[0..2] hold the three addends.
+  if (Operands[0].getOpcode() == ISD::ADD)
+    std::swap(Operands[0], Operands[1]);
+  else if (Operands[1].getOpcode() != ISD::ADD)
+    return SDValue();
+  Operands[2] = Operands[1].getOperand(0);
+  Operands[1] = Operands[1].getOperand(1);
+
+  // Now we have three operands of two additions. Check that one of them is a
+  // constant vector with ones, and the other two are promoted from i8/i16.
+  for (int i = 0; i < 3; ++i) {
+    if (!IsConstVectorInRange(Operands[i], 1, 1))
+      continue;
+    // Park the all-ones vector in the last slot; the remaining two addends
+    // are then Operands[0] and Operands[1].
+    std::swap(Operands[i], Operands[2]);
+
+    // Check if Operands[0] and Operands[1] are results of type promotion.
+    for (int j = 0; j < 2; ++j)
+      if (Operands[j].getOpcode() != ISD::ZERO_EXTEND ||
+          Operands[j].getOperand(0).getValueType() != VT)
+        return SDValue();
+
+    // The pattern is detected, emit X86ISD::AVG instruction.
+    return DAG.getNode(X86ISD::AVG, DL, VT, Operands[0].getOperand(0),
+                       Operands[1].getOperand(0));
+  }
+
+  return SDValue();
+}
+
+/// Try to simplify the node \p N by handing its first operand and its result
+/// type to detectAVGPattern. Returns whatever detectAVGPattern returns: the
+/// replacement value, or an empty SDValue when nothing matched.
+/// NOTE(review): presumably called for ISD::TRUNCATE nodes from the target's
+/// DAG-combine dispatch; the caller is not visible here, so confirm.
+static SDValue PerformTRUNCATECombine(SDNode *N, SelectionDAG &DAG,
+                                      const X86Subtarget *Subtarget) {
+  SDValue Src = N->getOperand(0);
+  EVT ResultVT = N->getValueType(0);
+  SDLoc Loc(N);
+  return detectAVGPattern(Src, ResultVT, DAG, Subtarget, Loc);
+}
+