DAGCombiner: Simplify code a bit, make more transforms work with vectors.
authorBenjamin Kramer <benny.kra@googlemail.com>
Sat, 26 Apr 2014 23:09:49 +0000 (23:09 +0000)
committerBenjamin Kramer <benny.kra@googlemail.com>
Sat, 26 Apr 2014 23:09:49 +0000 (23:09 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@207338 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/SelectionDAG/DAGCombiner.cpp
test/CodeGen/X86/vector-idiv.ll

index 0156fe1c0ec8a24f0031a89d90dbc723fdb93637..2ca3f3e452c7440d4ba80f893df54634065c3b00 100644 (file)
@@ -644,8 +644,13 @@ static ConstantSDNode *isConstOrConstSplat(SDValue N) {
   if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N))
     return CN;
 
-  if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N))
-    return BV->getConstantSplatValue();
+  if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) {
+    ConstantSDNode *CN = BV->getConstantSplatValue();
+
+    // BuildVectors can truncate their operands. Ignore that case here.
+    if (CN && CN->getValueType(0) == N.getValueType().getScalarType())
+      return CN;
+  }
 
   return nullptr;
 }
@@ -1957,8 +1962,8 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
 SDValue DAGCombiner::visitSDIV(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
-  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0.getNode());
-  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
+  ConstantSDNode *N0C = isConstOrConstSplat(N0);
+  ConstantSDNode *N1C = isConstOrConstSplat(N1);
   EVT VT = N->getValueType(0);
 
   // fold vector ops
@@ -1985,25 +1990,15 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
                          N0, N1);
   }
 
-  const APInt *Divisor = nullptr;
-  if (N1C) {
-    Divisor = &N1C->getAPIntValue();
-  } else if (N1.getValueType().isVector() &&
-             N1->getOpcode() == ISD::BUILD_VECTOR) {
-    BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N->getOperand(1));
-    if (ConstantSDNode *C = BV->getConstantSplatValue())
-      Divisor = &C->getAPIntValue();
-  }
-
   // fold (sdiv X, pow2) -> simple ops after legalize
-  if (Divisor && !!*Divisor &&
-      (Divisor->isPowerOf2() || (-*Divisor).isPowerOf2())) {
+  if (N1C && !N1C->isNullValue() && (N1C->getAPIntValue().isPowerOf2() ||
+                                     (-N1C->getAPIntValue()).isPowerOf2())) {
     // If dividing by powers of two is cheap, then don't perform the following
     // fold.
     if (TLI.isPow2DivCheap())
       return SDValue();
 
-    unsigned lg2 = Divisor->countTrailingZeros();
+    unsigned lg2 = N1C->getAPIntValue().countTrailingZeros();
 
     // Splat the sign bit into the register
     SDValue SGN =
@@ -2025,7 +2020,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
 
     // If we're dividing by a positive value, we're done.  Otherwise, we must
     // negate the result.
-    if (Divisor->isNonNegative())
+    if (N1C->getAPIntValue().isNonNegative())
       return SRA;
 
     AddToWorkList(SRA.getNode());
@@ -2034,7 +2029,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
 
   // if integer divide is expensive and we satisfy the requirements, emit an
   // alternate sequence.
-  if ((N1C || N1->getOpcode() == ISD::BUILD_VECTOR) && !TLI.isIntDivCheap()) {
+  if (N1C && !TLI.isIntDivCheap()) {
     SDValue Op = BuildSDIV(N);
     if (Op.getNode()) return Op;
   }
@@ -2052,8 +2047,8 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
 SDValue DAGCombiner::visitUDIV(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
-  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0.getNode());
-  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
+  ConstantSDNode *N0C = isConstOrConstSplat(N0);
+  ConstantSDNode *N1C = isConstOrConstSplat(N1);
   EVT VT = N->getValueType(0);
 
   // fold vector ops
@@ -2086,7 +2081,7 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
     }
   }
   // fold (udiv x, c) -> alternate
-  if ((N1C || N1->getOpcode() == ISD::BUILD_VECTOR) && !TLI.isIntDivCheap()) {
+  if (N1C && !TLI.isIntDivCheap()) {
     SDValue Op = BuildUDIV(N);
     if (Op.getNode()) return Op;
   }
@@ -2104,8 +2099,8 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
 SDValue DAGCombiner::visitSREM(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
-  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
-  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
+  ConstantSDNode *N0C = isConstOrConstSplat(N0);
+  ConstantSDNode *N1C = isConstOrConstSplat(N1);
   EVT VT = N->getValueType(0);
 
   // fold (srem c1, c2) -> c1%c2
@@ -2146,8 +2141,8 @@ SDValue DAGCombiner::visitSREM(SDNode *N) {
 SDValue DAGCombiner::visitUREM(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
-  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
-  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
+  ConstantSDNode *N0C = isConstOrConstSplat(N0);
+  ConstantSDNode *N1C = isConstOrConstSplat(N1);
   EVT VT = N->getValueType(0);
 
   // fold (urem c1, c2) -> c1%c2
@@ -11187,28 +11182,20 @@ SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0,
 /// multiplying by a magic number.  See:
 /// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html>
 SDValue DAGCombiner::BuildSDIV(SDNode *N) {
-  const APInt *Divisor;
-  if (N->getValueType(0).isVector()) {
-    // Handle splat vectors.
-    BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N->getOperand(1));
-    if (ConstantSDNode *C = BV->getConstantSplatValue())
-      Divisor = &C->getAPIntValue();
-    else
-      return SDValue();
-  } else {
-    Divisor = &cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue();
-  }
+  ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
+  if (!C)
+    return SDValue();
 
   // Avoid division by zero.
-  if (!*Divisor)
+  if (!C->getAPIntValue())
     return SDValue();
 
   std::vector<SDNode*> Built;
-  SDValue S = TLI.BuildSDIV(N, *Divisor, DAG, LegalOperations, &Built);
+  SDValue S =
+      TLI.BuildSDIV(N, C->getAPIntValue(), DAG, LegalOperations, &Built);
 
-  for (std::vector<SDNode*>::iterator ii = Built.begin(), ee = Built.end();
-       ii != ee; ++ii)
-    AddToWorkList(*ii);
+  for (SDNode *N : Built)
+    AddToWorkList(N);
   return S;
 }
 
@@ -11217,28 +11204,20 @@ SDValue DAGCombiner::BuildSDIV(SDNode *N) {
 /// multiplying by a magic number.  See:
 /// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html>
 SDValue DAGCombiner::BuildUDIV(SDNode *N) {
-  const APInt *Divisor;
-  if (N->getValueType(0).isVector()) {
-    // Handle splat vectors.
-    BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N->getOperand(1));
-    if (ConstantSDNode *C = BV->getConstantSplatValue())
-      Divisor = &C->getAPIntValue();
-    else
-      return SDValue();
-  } else {
-    Divisor = &cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue();
-  }
+  ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
+  if (!C)
+    return SDValue();
 
   // Avoid division by zero.
-  if (!*Divisor)
+  if (!C->getAPIntValue())
     return SDValue();
 
   std::vector<SDNode*> Built;
-  SDValue S = TLI.BuildUDIV(N, *Divisor, DAG, LegalOperations, &Built);
+  SDValue S =
+      TLI.BuildUDIV(N, C->getAPIntValue(), DAG, LegalOperations, &Built);
 
-  for (std::vector<SDNode*>::iterator ii = Built.begin(), ee = Built.end();
-       ii != ee; ++ii)
-    AddToWorkList(*ii);
+  for (SDNode *N : Built)
+    AddToWorkList(N);
   return S;
 }
 
index 06af3434b1a4eafa33d9a1cb1f192f03a012d17d..3b300f74061784684e52559cfb8c9b3f4e4cd01e 100644 (file)
@@ -151,3 +151,38 @@ define <8 x i32> @test9(<8 x i32> %a) {
 ; AVX: vpsrad $2
 ; AVX: vpadd
 }
+
+define <8 x i32> @test10(<8 x i32> %a) {
+  %rem = urem <8 x i32> %a, <i32 7, i32 7, i32 7, i32 7,i32 7, i32 7, i32 7, i32 7>
+  ret <8 x i32> %rem
+
+; AVX-LABEL: test10:
+; AVX: vpermd
+; AVX: vpmuludq
+; AVX: vshufps $-35
+; AVX: vpmuludq
+; AVX: vshufps $-35
+; AVX: vpsubd
+; AVX: vpsrld $1
+; AVX: vpadd
+; AVX: vpsrld $2
+; AVX: vpmulld
+}
+
+define <8 x i32> @test11(<8 x i32> %a) {
+  %rem = srem <8 x i32> %a, <i32 7, i32 7, i32 7, i32 7,i32 7, i32 7, i32 7, i32 7>
+  ret <8 x i32> %rem
+
+; AVX-LABEL: test11:
+; AVX: vpermd
+; AVX: vpmuldq
+; AVX: vshufps $-35
+; AVX: vpmuldq
+; AVX: vshufps $-35
+; AVX: vpshufd $-40
+; AVX: vpadd
+; AVX: vpsrld $31
+; AVX: vpsrad $2
+; AVX: vpadd
+; AVX: vpmulld
+}