DAGCombiner: Turn divs of vector splats into vectorized multiplications.
authorBenjamin Kramer <benny.kra@googlemail.com>
Sat, 26 Apr 2014 12:06:28 +0000 (12:06 +0000)
committerBenjamin Kramer <benny.kra@googlemail.com>
Sat, 26 Apr 2014 12:06:28 +0000 (12:06 +0000)
Otherwise the legalizer would just scalarize everything. Support for
mulhi in the targets isn't that great yet so on most targets we get
exactly the same scalarized output. Add a test for x86 vector udiv.

I had to disable the mulhi nodes on ARM because there aren't any patterns
for it. As far as I know ARM has instructions for getting the high part of
a multiply so this should be fixed.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@207315 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Target/TargetLowering.h
lib/CodeGen/SelectionDAG/DAGCombiner.cpp
lib/CodeGen/SelectionDAG/TargetLowering.cpp
lib/Target/ARM/ARMISelLowering.cpp
lib/Target/ARM64/ARM64ISelLowering.cpp
test/CodeGen/X86/vector-idiv.ll [new file with mode: 0644]

index 8dba94fc272c7a53b1e77970559c634720b681a9..aadfca964aafab849d541b5bc43ca36d2941861b 100644 (file)
@@ -2417,10 +2417,12 @@ public:
   //
   SDValue BuildExactSDIV(SDValue Op1, SDValue Op2, SDLoc dl,
                          SelectionDAG &DAG) const;
-  SDValue BuildSDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
-                      std::vector<SDNode*> *Created) const;
-  SDValue BuildUDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
-                      std::vector<SDNode*> *Created) const;
+  SDValue BuildSDIV(SDNode *N, const APInt &Divisor, SelectionDAG &DAG,
+                    bool IsAfterLegalization,
+                    std::vector<SDNode *> *Created) const;
+  SDValue BuildUDIV(SDNode *N, const APInt &Divisor, SelectionDAG &DAG,
+                    bool IsAfterLegalization,
+                    std::vector<SDNode *> *Created) const;
 
   //===--------------------------------------------------------------------===//
   // Legalization utility functions
index a52dacf5216b293e217d2865e35f4ead35186953..7850bc25892034240a74ae40e33c19dc32956f34 100644 (file)
@@ -2024,7 +2024,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
 
   // if integer divide is expensive and we satisfy the requirements, emit an
   // alternate sequence.
-  if (N1C && !N1C->isNullValue() && !TLI.isIntDivCheap()) {
+  if ((N1C || N1->getOpcode() == ISD::BUILD_VECTOR) && !TLI.isIntDivCheap()) {
     SDValue Op = BuildSDIV(N);
     if (Op.getNode()) return Op;
   }
@@ -2076,7 +2076,7 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
     }
   }
   // fold (udiv x, c) -> alternate
-  if (N1C && !N1C->isNullValue() && !TLI.isIntDivCheap()) {
+  if ((N1C || N1->getOpcode() == ISD::BUILD_VECTOR) && !TLI.isIntDivCheap()) {
     SDValue Op = BuildUDIV(N);
     if (Op.getNode()) return Op;
   }
@@ -11191,8 +11191,24 @@ SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0,
 /// multiplying by a magic number.  See:
 /// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html>
 SDValue DAGCombiner::BuildSDIV(SDNode *N) {
+  const APInt *Divisor;
+  if (N->getValueType(0).isVector()) {
+    // Handle splat vectors.
+    BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N->getOperand(1));
+    if (ConstantSDNode *C = BV->getConstantSplatValue())
+      Divisor = &C->getAPIntValue();
+    else
+      return SDValue();
+  } else {
+    Divisor = &cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue();
+  }
+
+  // Avoid division by zero.
+  if (!*Divisor)
+    return SDValue();
+
   std::vector<SDNode*> Built;
-  SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, &Built);
+  SDValue S = TLI.BuildSDIV(N, *Divisor, DAG, LegalOperations, &Built);
 
   for (std::vector<SDNode*>::iterator ii = Built.begin(), ee = Built.end();
        ii != ee; ++ii)
@@ -11200,13 +11216,29 @@ SDValue DAGCombiner::BuildSDIV(SDNode *N) {
   return S;
 }
 
-/// BuildUDIVSequence - Given an ISD::UDIV node expressing a divide by constant,
+/// BuildUDIV - Given an ISD::UDIV node expressing a divide by constant,
 /// return a DAG expression to select that will generate the same value by
 /// multiplying by a magic number.  See:
 /// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html>
 SDValue DAGCombiner::BuildUDIV(SDNode *N) {
+  const APInt *Divisor;
+  if (N->getValueType(0).isVector()) {
+    // Handle splat vectors.
+    BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N->getOperand(1));
+    if (ConstantSDNode *C = BV->getConstantSplatValue())
+      Divisor = &C->getAPIntValue();
+    else
+      return SDValue();
+  } else {
+    Divisor = &cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue();
+  }
+
+  // Avoid division by zero.
+  if (!*Divisor)
+    return SDValue();
+
   std::vector<SDNode*> Built;
-  SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, &Built);
+  SDValue S = TLI.BuildUDIV(N, *Divisor, DAG, LegalOperations, &Built);
 
   for (std::vector<SDNode*>::iterator ii = Built.begin(), ee = Built.end();
        ii != ee; ++ii)
index b5e43587a33bc3ba5c28c1d80d76ff37a44e8335..dc92795d623ac36393b0d3029edbd7630f68c8d7 100644 (file)
@@ -2602,9 +2602,9 @@ SDValue TargetLowering::BuildExactSDIV(SDValue Op1, SDValue Op2, SDLoc dl,
 /// return a DAG expression to select that will generate the same value by
 /// multiplying by a magic number.  See:
 /// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html>
-SDValue TargetLowering::
-BuildSDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
-          std::vector<SDNode*> *Created) const {
+SDValue TargetLowering::BuildSDIV(SDNode *N, const APInt &Divisor,
+                                  SelectionDAG &DAG, bool IsAfterLegalization,
+                                  std::vector<SDNode *> *Created) const {
   EVT VT = N->getValueType(0);
   SDLoc dl(N);
 
@@ -2613,8 +2613,7 @@ BuildSDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
   if (!isTypeLegal(VT))
     return SDValue();
 
-  APInt d = cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue();
-  APInt::ms magics = d.magic();
+  APInt::ms magics = Divisor.magic();
 
   // Multiply the numerator (operand 0) by the magic value
   // FIXME: We should support doing a MUL in a wider type
@@ -2631,13 +2630,13 @@ BuildSDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
   else
     return SDValue();       // No mulhs or equvialent
   // If d > 0 and m < 0, add the numerator
-  if (d.isStrictlyPositive() && magics.m.isNegative()) {
+  if (Divisor.isStrictlyPositive() && magics.m.isNegative()) {
     Q = DAG.getNode(ISD::ADD, dl, VT, Q, N->getOperand(0));
     if (Created)
       Created->push_back(Q.getNode());
   }
   // If d < 0 and m > 0, subtract the numerator.
-  if (d.isNegative() && magics.m.isStrictlyPositive()) {
+  if (Divisor.isNegative() && magics.m.isStrictlyPositive()) {
     Q = DAG.getNode(ISD::SUB, dl, VT, Q, N->getOperand(0));
     if (Created)
       Created->push_back(Q.getNode());
@@ -2650,9 +2649,9 @@ BuildSDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
       Created->push_back(Q.getNode());
   }
   // Extract the sign bit and add it to the quotient
-  SDValue T =
-    DAG.getNode(ISD::SRL, dl, VT, Q, DAG.getConstant(VT.getSizeInBits()-1,
-                                           getShiftAmountTy(Q.getValueType())));
+  SDValue T = DAG.getNode(ISD::SRL, dl, VT, Q,
+                          DAG.getConstant(VT.getScalarSizeInBits() - 1,
+                                          getShiftAmountTy(Q.getValueType())));
   if (Created)
     Created->push_back(T.getNode());
   return DAG.getNode(ISD::ADD, dl, VT, Q, T);
@@ -2662,9 +2661,9 @@ BuildSDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
 /// return a DAG expression to select that will generate the same value by
 /// multiplying by a magic number.  See:
 /// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html>
-SDValue TargetLowering::
-BuildUDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
-          std::vector<SDNode*> *Created) const {
+SDValue TargetLowering::BuildUDIV(SDNode *N, const APInt &Divisor,
+                                  SelectionDAG &DAG, bool IsAfterLegalization,
+                                  std::vector<SDNode *> *Created) const {
   EVT VT = N->getValueType(0);
   SDLoc dl(N);
 
@@ -2675,22 +2674,21 @@ BuildUDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
 
   // FIXME: We should use a narrower constant when the upper
   // bits are known to be zero.
-  const APInt &N1C = cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue();
-  APInt::mu magics = N1C.magicu();
+  APInt::mu magics = Divisor.magicu();
 
   SDValue Q = N->getOperand(0);
 
   // If the divisor is even, we can avoid using the expensive fixup by shifting
   // the divided value upfront.
-  if (magics.a != 0 && !N1C[0]) {
-    unsigned Shift = N1C.countTrailingZeros();
+  if (magics.a != 0 && !Divisor[0]) {
+    unsigned Shift = Divisor.countTrailingZeros();
     Q = DAG.getNode(ISD::SRL, dl, VT, Q,
                     DAG.getConstant(Shift, getShiftAmountTy(Q.getValueType())));
     if (Created)
       Created->push_back(Q.getNode());
 
     // Get magic number for the shifted divisor.
-    magics = N1C.lshr(Shift).magicu(Shift);
+    magics = Divisor.lshr(Shift).magicu(Shift);
     assert(magics.a == 0 && "Should use cheap fixup now");
   }
 
@@ -2709,7 +2707,7 @@ BuildUDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
     Created->push_back(Q.getNode());
 
   if (magics.a == 0) {
-    assert(magics.s < N1C.getBitWidth() &&
+    assert(magics.s < Divisor.getBitWidth() &&
            "We shouldn't generate an undefined shift!");
     return DAG.getNode(ISD::SRL, dl, VT, Q,
                  DAG.getConstant(magics.s, getShiftAmountTy(Q.getValueType())));
index 7c302ca2e4d377adc242711517a4f2f8a429c65d..d966397880fd40dfda84309134f173d160292c49 100644 (file)
@@ -446,6 +446,11 @@ ARMTargetLowering::ARMTargetLowering(TargetMachine &TM)
     setLoadExtAction(ISD::SEXTLOAD, (MVT::SimpleValueType)VT, Expand);
     setLoadExtAction(ISD::ZEXTLOAD, (MVT::SimpleValueType)VT, Expand);
     setLoadExtAction(ISD::EXTLOAD, (MVT::SimpleValueType)VT, Expand);
+
+    setOperationAction(ISD::MULHS, (MVT::SimpleValueType)VT, Expand);
+    setOperationAction(ISD::SMUL_LOHI, (MVT::SimpleValueType)VT, Expand);
+    setOperationAction(ISD::MULHU, (MVT::SimpleValueType)VT, Expand);
+    setOperationAction(ISD::UMUL_LOHI, (MVT::SimpleValueType)VT, Expand);
   }
 
   setOperationAction(ISD::ConstantFP, MVT::f32, Custom);
index d295dd1812705b17527c2d335af112045c333650..1881e8833c28fe45a125e5513b65b95333d3c169 100644 (file)
@@ -435,6 +435,11 @@ ARM64TargetLowering::ARM64TargetLowering(ARM64TargetMachine &TM)
       setOperationAction(ISD::SIGN_EXTEND_INREG, (MVT::SimpleValueType)VT,
                          Expand);
 
+      setOperationAction(ISD::MULHS, (MVT::SimpleValueType)VT, Expand);
+      setOperationAction(ISD::SMUL_LOHI, (MVT::SimpleValueType)VT, Expand);
+      setOperationAction(ISD::MULHU, (MVT::SimpleValueType)VT, Expand);
+      setOperationAction(ISD::UMUL_LOHI, (MVT::SimpleValueType)VT, Expand);
+
       for (unsigned InnerVT = (unsigned)MVT::FIRST_VECTOR_VALUETYPE;
            InnerVT <= (unsigned)MVT::LAST_VECTOR_VALUETYPE; ++InnerVT)
         setTruncStoreAction((MVT::SimpleValueType)VT,
diff --git a/test/CodeGen/X86/vector-idiv.ll b/test/CodeGen/X86/vector-idiv.ll
new file mode 100644 (file)
index 0000000..5b8153a
--- /dev/null
@@ -0,0 +1,45 @@
+; RUN: llc -march=x86-64 -mcpu=core2 < %s | FileCheck %s -check-prefix=SSE
+; RUN: llc -march=x86-64 -mcpu=core-avx2 < %s | FileCheck %s -check-prefix=AVX
+
+define <4 x i32> @test1(<4 x i32> %a) {
+  %div = udiv <4 x i32> %a, <i32 7, i32 7, i32 7, i32 7>
+  ret <4 x i32> %div
+
+; SSE-LABEL: test1:
+; SSE: pmuludq
+; SSE: pshufd  $57
+; SSE: pmuludq
+; SSE: shufps  $-35
+; SSE: psubd
+; SSE: psrld $1
+; SSE: padd
+; SSE: psrld $2
+
+; AVX-LABEL: test1:
+; AVX: vpmuludq
+; AVX: vpshufd $57
+; AVX: vpmuludq
+; AVX: vshufps $-35
+; AVX: vpsubd
+; AVX: vpsrld $1
+; AVX: vpadd
+; AVX: vpsrld $2
+}
+
+define <8 x i32> @test2(<8 x i32> %a) {
+  %div = udiv <8 x i32> %a, <i32 7, i32 7, i32 7, i32 7,i32 7, i32 7, i32 7, i32 7>
+  ret <8 x i32> %div
+
+; AVX-LABEL: test2:
+; AVX: vpermd
+; AVX: vpmuludq
+; AVX: vshufps $-35
+; AVX: vpmuludq
+; AVX: vshufps $-35
+; AVX: vpsubd
+; AVX: vpsrld $1
+; AVX: vpadd
+; AVX: vpsrld $2
+}
+
+; TODO: sdiv -> pmuldq