[X86] Refactor the logic to select horizontal adds/subs to a helper function.
authorAndrea Di Biagio <Andrea_DiBiagio@sn.scee.net>
Wed, 11 Jun 2014 07:57:50 +0000 (07:57 +0000)
committerAndrea Di Biagio <Andrea_DiBiagio@sn.scee.net>
Wed, 11 Jun 2014 07:57:50 +0000 (07:57 +0000)
This patch moves part of the logic implemented by the target specific
combine rules added at r210477 to a separate helper function.
This should make easier to add more rules for matching AVX/AVX2 horizontal
adds/subs.

This patch also fixes a problem caused by a wrong check performed on indices
of extract_vector_elt dag nodes in input to the scalar adds/subs.

New tests have been added to verify that we correctly check indices of
extract_vector_elt dag nodes when selecting a horizontal operation.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@210644 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp
test/CodeGen/X86/haddsub-2.ll

index c7c8cb53fb7d9ed1e297e1f0eafc65737c391739..8cf3b53b8679c4a157a32c05a64c8a03173ca272 100644 (file)
@@ -6057,102 +6057,130 @@ X86TargetLowering::LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getNode(ISD::BITCAST, dl, VT, Select);
 }
 
-static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
-                                          const X86Subtarget *Subtarget) {
-  EVT VT = N->getValueType(0);
+/// \brief Return true if \p N implements a horizontal binop and return the
+/// operands for the horizontal binop into V0 and V1.
+/// 
+/// This is a helper function of PerformBUILD_VECTORCombine.
+/// This function checks that the build_vector \p N in input implements a
+/// horizontal operation. Parameter \p Opcode defines the kind of horizontal
+/// operation to match.
+/// For example, if \p Opcode is equal to ISD::ADD, then this function
+/// checks if \p N implements a horizontal arithmetic add; if instead \p Opcode
+/// is equal to ISD::SUB, then this function checks if this is a horizontal
+/// arithmetic sub.
+///
+/// This function only analyzes elements of \p N whose indices are
+/// in range [BaseIdx, LastIdx).
+static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode,
+                              unsigned BaseIdx, unsigned LastIdx,
+                              SDValue &V0, SDValue &V1) {
+  assert(BaseIdx * 2 <= LastIdx && "Invalid Indices in input!");
+  assert(N->getValueType(0).isVector() &&
+         N->getValueType(0).getVectorNumElements() >= LastIdx &&
+         "Invalid Vector in input!");
+  
+  bool IsCommutable = (Opcode == ISD::ADD || Opcode == ISD::FADD);
+  bool CanFold = true;
+  unsigned ExpectedVExtractIdx = BaseIdx;
+  unsigned NumElts = LastIdx - BaseIdx;
 
-  // Try to match a horizontal ADD or SUB.
-  if (((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget->hasSSE3()) ||
-      ((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget->hasSSSE3()) ||
-      ((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 ||
-        VT == MVT::v16i16) && Subtarget->hasAVX())) {
-    unsigned NumOperands = N->getNumOperands();
-    unsigned Opcode = N->getOperand(0)->getOpcode();
-    bool isCommutable = false;
-    bool CanFold = false;
-    switch (Opcode) {
-    default : break;
-    case ISD::ADD :
-    case ISD::FADD :
-      isCommutable = true;
-      // FALL-THROUGH
-    case ISD::SUB :
-    case ISD::FSUB :
-      CanFold = true;
-    }
-
-    // Verify that operands have the same opcode; also, the opcode can only
-    // be either of: ADD, FADD, SUB, FSUB.
-    SDValue InVec0, InVec1;
-    for (unsigned i = 0, e = NumOperands; i != e && CanFold; ++i) {
-      SDValue Op = N->getOperand(i);
-      CanFold = Op->getOpcode() == Opcode && Op->hasOneUse();
-
-      if (!CanFold)
-        break;
+  // Check if N implements a horizontal binop.
+  for (unsigned i = 0, e = NumElts; i != e && CanFold; ++i) {
+    SDValue Op = N->getOperand(i + BaseIdx);
+    CanFold = Op->getOpcode() == Opcode && Op->hasOneUse();
 
-      SDValue Op0 = Op.getOperand(0);
-      SDValue Op1 = Op.getOperand(1);
-
-      // Try to match the following pattern:
-      // (BINOP (extract_vector_elt A, I), (extract_vector_elt A, I+1))
-      CanFold = (Op0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
-          Op1.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
-          Op0.getOperand(0) == Op1.getOperand(0) &&
-          isa<ConstantSDNode>(Op0.getOperand(1)) &&
-          isa<ConstantSDNode>(Op1.getOperand(1)));
-      if (!CanFold)
-        break;
+    if (!CanFold)
+      break;
 
-      unsigned I0 = cast<ConstantSDNode>(Op0.getOperand(1))->getZExtValue();
-      unsigned I1 = cast<ConstantSDNode>(Op1.getOperand(1))->getZExtValue();
-      unsigned ExpectedIndex = (i * 2) % NumOperands;
-      if (i == 0)
-        InVec0 = Op0.getOperand(0);
-      else if (i * 2 == NumOperands)
-        InVec1 = Op0.getOperand(0);
-
-      SDValue Expected = (i * 2 < NumOperands) ? InVec0 : InVec1;
-      if (I0 == ExpectedIndex)
-        CanFold = I1 == I0 + 1 && Op0.getOperand(0) == Expected;
-      else if (isCommutable && I1 == ExpectedIndex) {
-        // Try to see if we can match the following dag sequence:
-        // (BINOP (extract_vector_elt A, I+1), (extract_vector_elt A, I))
-        CanFold = I0 == I1 + 1 && Op1.getOperand(0) == Expected;
-      }
-    }
+    SDValue Op0 = Op.getOperand(0);
+    SDValue Op1 = Op.getOperand(1);
+
+    // Try to match the following pattern:
+    // (BINOP (extract_vector_elt A, I), (extract_vector_elt A, I+1))
+    CanFold = (Op0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+        Op1.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+        Op0.getOperand(0) == Op1.getOperand(0) &&
+        isa<ConstantSDNode>(Op0.getOperand(1)) &&
+        isa<ConstantSDNode>(Op1.getOperand(1)));
+    if (!CanFold)
+      break;
 
-    if (CanFold) {
-      unsigned NewOpcode;
-      switch (Opcode) {
-      default : llvm_unreachable("Unexpected opcode found!");
-      case ISD::ADD : NewOpcode = X86ISD::HADD; break;
-      case ISD::FADD : NewOpcode = X86ISD::FHADD; break;
-      case ISD::SUB : NewOpcode = X86ISD::HSUB; break;
-      case ISD::FSUB : NewOpcode = X86ISD::FHSUB; break;
-      }
+    unsigned I0 = cast<ConstantSDNode>(Op0.getOperand(1))->getZExtValue();
+    unsigned I1 = cast<ConstantSDNode>(Op1.getOperand(1))->getZExtValue();
  
-      if (VT.is256BitVector()) {
-        SDLoc dl(N);
-
-        // Convert this sequence into two horizontal add/sub followed
-        // by a concat vector.
-        SDValue InVec0_LO = Extract128BitVector(InVec0, 0, DAG, dl);
-        SDValue InVec0_HI =
-          Extract128BitVector(InVec0, NumOperands/2, DAG, dl);
-        SDValue InVec1_LO = Extract128BitVector(InVec1, 0, DAG, dl);
-        SDValue InVec1_HI =
-          Extract128BitVector(InVec1, NumOperands/2, DAG, dl);
-        EVT NewVT = InVec0_LO.getValueType();
-
-        SDValue LO = DAG.getNode(NewOpcode, dl, NewVT, InVec0_LO, InVec0_HI);
-        SDValue HI = DAG.getNode(NewOpcode, dl, NewVT, InVec1_LO, InVec1_HI);
-        return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, LO, HI);
-      }
+    if (i == 0)
+      V0 = Op0.getOperand(0);
+    else if (i * 2 == NumElts) {
+      V1 = Op0.getOperand(0);
+      ExpectedVExtractIdx = BaseIdx;
+    }
+
+    SDValue Expected = (i * 2 < NumElts) ? V0 : V1;
+    if (I0 == ExpectedVExtractIdx)
+      CanFold = I1 == I0 + 1 && Op0.getOperand(0) == Expected;
+    else if (IsCommutable && I1 == ExpectedVExtractIdx) {
+      // Try to match the following dag sequence:
+      // (BINOP (extract_vector_elt A, I+1), (extract_vector_elt A, I))
+      CanFold = I0 == I1 + 1 && Op1.getOperand(0) == Expected;
+    } else
+      CanFold = false;
 
-      return DAG.getNode(NewOpcode, SDLoc(N), VT, InVec0, InVec1);
-    }
+    ExpectedVExtractIdx += 2;
+  }
+
+  return CanFold;
+}
+
+static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
+                                          const X86Subtarget *Subtarget) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  unsigned NumElts = VT.getVectorNumElements();
+  BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N);
+  SDValue InVec0, InVec1;
+
+  // Try to match horizontal ADD/SUB.
+  if ((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget->hasSSE3()) {
+    // Try to match an SSE3 float HADD/HSUB.
+    if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts, InVec0, InVec1))
+      return DAG.getNode(X86ISD::FHADD, DL, VT, InVec0, InVec1);
+    
+    if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts, InVec0, InVec1))
+      return DAG.getNode(X86ISD::FHSUB, DL, VT, InVec0, InVec1);
+  } else if ((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget->hasSSSE3()) {
+    // Try to match an SSSE3 integer HADD/HSUB.
+    if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts, InVec0, InVec1))
+      return DAG.getNode(X86ISD::HADD, DL, VT, InVec0, InVec1);
+    
+    if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts, InVec0, InVec1))
+      return DAG.getNode(X86ISD::HSUB, DL, VT, InVec0, InVec1);
+  }
+
+  if ((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 ||
+       VT == MVT::v16i16) && Subtarget->hasAVX()) {
+    unsigned X86Opcode;
+    if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts, InVec0, InVec1))
+      X86Opcode = X86ISD::HADD;
+    else if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts, InVec0, InVec1))
+      X86Opcode = X86ISD::HSUB;
+    else if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts, InVec0, InVec1))
+      X86Opcode = X86ISD::FHADD;
+    else if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts, InVec0, InVec1))
+      X86Opcode = X86ISD::FHSUB;
+    else
+      return SDValue();
+
+    // Convert this build_vector into two horizontal add/sub followed by
+    // a concat vector.
+    SDValue InVec0_LO = Extract128BitVector(InVec0, 0, DAG, DL);
+    SDValue InVec0_HI = Extract128BitVector(InVec0, NumElts/2, DAG, DL);
+    SDValue InVec1_LO = Extract128BitVector(InVec1, 0, DAG, DL);
+    SDValue InVec1_HI = Extract128BitVector(InVec1, NumElts/2, DAG, DL);
+    EVT NewVT = InVec0_LO.getValueType();
+
+    SDValue LO = DAG.getNode(X86Opcode, DL, NewVT, InVec0_LO, InVec0_HI);
+    SDValue HI = DAG.getNode(X86Opcode, DL, NewVT, InVec1_LO, InVec1_HI);
+    return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LO, HI);
   }
 
   return SDValue();
index 7b875c0b5ed03b64acae985976d77ddf456e5127..72217b329fa3b113541f8e498fe8dfdd368db630 100644 (file)
@@ -86,12 +86,12 @@ define <4 x float> @hsub_ps_test2(<4 x float> %A, <4 x float> %B) {
   %vecext3 = extractelement <4 x float> %A, i32 1
   %sub4 = fsub float %vecext2, %vecext3
   %vecinit5 = insertelement <4 x float> %vecinit, float %sub4, i32 0
-  %vecext6 = extractelement <4 x float> %B, i32 3
-  %vecext7 = extractelement <4 x float> %B, i32 2
+  %vecext6 = extractelement <4 x float> %B, i32 2
+  %vecext7 = extractelement <4 x float> %B, i32 3
   %sub8 = fsub float %vecext6, %vecext7
   %vecinit9 = insertelement <4 x float> %vecinit5, float %sub8, i32 3
-  %vecext10 = extractelement <4 x float> %B, i32 1
-  %vecext11 = extractelement <4 x float> %B, i32 0
+  %vecext10 = extractelement <4 x float> %B, i32 0
+  %vecext11 = extractelement <4 x float> %B, i32 1
   %sub12 = fsub float %vecext10, %vecext11
   %vecinit13 = insertelement <4 x float> %vecinit9, float %sub12, i32 2
   ret <4 x float> %vecinit13
@@ -137,12 +137,12 @@ define <4 x i32> @phadd_d_test2(<4 x i32> %A, <4 x i32> %B) {
   %vecext3 = extractelement <4 x i32> %A, i32 1
   %add4 = add i32 %vecext2, %vecext3
   %vecinit5 = insertelement <4 x i32> %vecinit, i32 %add4, i32 0
-  %vecext6 = extractelement <4 x i32> %B, i32 2
-  %vecext7 = extractelement <4 x i32> %B, i32 3
+  %vecext6 = extractelement <4 x i32> %B, i32 3
+  %vecext7 = extractelement <4 x i32> %B, i32 2
   %add8 = add i32 %vecext6, %vecext7
   %vecinit9 = insertelement <4 x i32> %vecinit5, i32 %add8, i32 3
-  %vecext10 = extractelement <4 x i32> %B, i32 0
-  %vecext11 = extractelement <4 x i32> %B, i32 1
+  %vecext10 = extractelement <4 x i32> %B, i32 1
+  %vecext11 = extractelement <4 x i32> %B, i32 0
   %add12 = add i32 %vecext10, %vecext11
   %vecinit13 = insertelement <4 x i32> %vecinit9, i32 %add12, i32 2
   ret <4 x i32> %vecinit13
@@ -191,12 +191,12 @@ define <4 x i32> @phsub_d_test2(<4 x i32> %A, <4 x i32> %B) {
   %vecext3 = extractelement <4 x i32> %A, i32 1
   %sub4 = sub i32 %vecext2, %vecext3
   %vecinit5 = insertelement <4 x i32> %vecinit, i32 %sub4, i32 0
-  %vecext6 = extractelement <4 x i32> %B, i32 3
-  %vecext7 = extractelement <4 x i32> %B, i32 2
+  %vecext6 = extractelement <4 x i32> %B, i32 2
+  %vecext7 = extractelement <4 x i32> %B, i32 3
   %sub8 = sub i32 %vecext6, %vecext7
   %vecinit9 = insertelement <4 x i32> %vecinit5, i32 %sub8, i32 3
-  %vecext10 = extractelement <4 x i32> %B, i32 1
-  %vecext11 = extractelement <4 x i32> %B, i32 0
+  %vecext10 = extractelement <4 x i32> %B, i32 0
+  %vecext11 = extractelement <4 x i32> %B, i32 1
   %sub12 = sub i32 %vecext10, %vecext11
   %vecinit13 = insertelement <4 x i32> %vecinit9, i32 %sub12, i32 2
   ret <4 x i32> %vecinit13
@@ -258,14 +258,14 @@ define <2 x double> @hsub_pd_test1(<2 x double> %A, <2 x double> %B) {
 
 
 define <2 x double> @hsub_pd_test2(<2 x double> %A, <2 x double> %B) {
-  %vecext = extractelement <2 x double> %A, i32 1
-  %vecext1 = extractelement <2 x double> %A, i32 0
+  %vecext = extractelement <2 x double> %B, i32 0
+  %vecext1 = extractelement <2 x double> %B, i32 1
   %sub = fsub double %vecext, %vecext1
-  %vecinit = insertelement <2 x double> undef, double %sub, i32 0
-  %vecext2 = extractelement <2 x double> %B, i32 1
-  %vecext3 = extractelement <2 x double> %B, i32 0
+  %vecinit = insertelement <2 x double> undef, double %sub, i32 1
+  %vecext2 = extractelement <2 x double> %A, i32 0
+  %vecext3 = extractelement <2 x double> %A, i32 1
   %sub2 = fsub double %vecext2, %vecext3
-  %vecinit2 = insertelement <2 x double> %vecinit, double %sub2, i32 1
+  %vecinit2 = insertelement <2 x double> %vecinit, double %sub2, i32 0
   ret <2 x double> %vecinit2
 }
 ; CHECK-LABEL: hsub_pd_test2
@@ -458,3 +458,68 @@ define <16 x i16> @avx2_vphadd_w_test(<16 x i16> %a, <16 x i16> %b) {
 ; CHECK: ret
 
 
+; Verify that we don't select horizontal subs in the following functions.
+
+define <4 x i32> @not_a_hsub_1(<4 x i32> %A, <4 x i32> %B) {
+  %vecext = extractelement <4 x i32> %A, i32 0
+  %vecext1 = extractelement <4 x i32> %A, i32 1
+  %sub = sub i32 %vecext, %vecext1
+  %vecinit = insertelement <4 x i32> undef, i32 %sub, i32 0
+  %vecext2 = extractelement <4 x i32> %A, i32 2
+  %vecext3 = extractelement <4 x i32> %A, i32 3
+  %sub4 = sub i32 %vecext2, %vecext3
+  %vecinit5 = insertelement <4 x i32> %vecinit, i32 %sub4, i32 1
+  %vecext6 = extractelement <4 x i32> %B, i32 1
+  %vecext7 = extractelement <4 x i32> %B, i32 0
+  %sub8 = sub i32 %vecext6, %vecext7
+  %vecinit9 = insertelement <4 x i32> %vecinit5, i32 %sub8, i32 2
+  %vecext10 = extractelement <4 x i32> %B, i32 3
+  %vecext11 = extractelement <4 x i32> %B, i32 2
+  %sub12 = sub i32 %vecext10, %vecext11
+  %vecinit13 = insertelement <4 x i32> %vecinit9, i32 %sub12, i32 3
+  ret <4 x i32> %vecinit13
+}
+; CHECK-LABEL: not_a_hsub_1
+; CHECK-NOT: phsubd
+; CHECK: ret
+
+
+define <4 x float> @not_a_hsub_2(<4 x float> %A, <4 x float> %B) {
+  %vecext = extractelement <4 x float> %A, i32 2
+  %vecext1 = extractelement <4 x float> %A, i32 3
+  %sub = fsub float %vecext, %vecext1
+  %vecinit = insertelement <4 x float> undef, float %sub, i32 1
+  %vecext2 = extractelement <4 x float> %A, i32 0
+  %vecext3 = extractelement <4 x float> %A, i32 1
+  %sub4 = fsub float %vecext2, %vecext3
+  %vecinit5 = insertelement <4 x float> %vecinit, float %sub4, i32 0
+  %vecext6 = extractelement <4 x float> %B, i32 3
+  %vecext7 = extractelement <4 x float> %B, i32 2
+  %sub8 = fsub float %vecext6, %vecext7
+  %vecinit9 = insertelement <4 x float> %vecinit5, float %sub8, i32 3
+  %vecext10 = extractelement <4 x float> %B, i32 0
+  %vecext11 = extractelement <4 x float> %B, i32 1
+  %sub12 = fsub float %vecext10, %vecext11
+  %vecinit13 = insertelement <4 x float> %vecinit9, float %sub12, i32 2
+  ret <4 x float> %vecinit13
+}
+; CHECK-LABEL: not_a_hsub_2
+; CHECK-NOT: hsubps
+; CHECK: ret
+
+
+define <2 x double> @not_a_hsub_3(<2 x double> %A, <2 x double> %B) {
+  %vecext = extractelement <2 x double> %B, i32 0
+  %vecext1 = extractelement <2 x double> %B, i32 1
+  %sub = fsub double %vecext, %vecext1
+  %vecinit = insertelement <2 x double> undef, double %sub, i32 1
+  %vecext2 = extractelement <2 x double> %A, i32 1
+  %vecext3 = extractelement <2 x double> %A, i32 0
+  %sub2 = fsub double %vecext2, %vecext3
+  %vecinit2 = insertelement <2 x double> %vecinit, double %sub2, i32 0
+  ret <2 x double> %vecinit2
+}
+; CHECK-LABEL: not_a_hsub_3
+; CHECK-NOT: hsubpd
+; CHECK: ret
+