Add DAG combine for shl + add of constants.
[oota-llvm.git] / lib / CodeGen / SelectionDAG / DAGCombiner.cpp
index 4c414005080d020cc3705a668cee613c56ad2afe..8a09838873a598c3b59d0159f6eb4ddf46de176e 100644 (file)
@@ -1517,28 +1517,6 @@ SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
   return SDValue(N, 0);   // Return N so it doesn't get rechecked!
 }
 
-static
-SDValue combineShlAddConstant(SDLoc DL, SDValue N0, SDValue N1,
-                              SelectionDAG &DAG) {
-  EVT VT = N0.getValueType();
-  SDValue N00 = N0.getOperand(0);
-  SDValue N01 = N0.getOperand(1);
-  ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N01);
-
-  if (N01C && N00.getOpcode() == ISD::ADD && N00.getNode()->hasOneUse() &&
-      isa<ConstantSDNode>(N00.getOperand(1))) {
-    // fold (add (shl (add x, c1), c2), ) -> (add (add (shl x, c2), c1<<c2), )
-    N0 = DAG.getNode(ISD::ADD, SDLoc(N0), VT,
-                     DAG.getNode(ISD::SHL, SDLoc(N00), VT,
-                                 N00.getOperand(0), N01),
-                     DAG.getNode(ISD::SHL, SDLoc(N01), VT,
-                                 N00.getOperand(1), N01));
-    return DAG.getNode(ISD::ADD, DL, VT, N0, N1);
-  }
-
-  return SDValue();
-}
-
 SDValue DAGCombiner::visitADD(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -1655,16 +1633,6 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
     }
   }
 
-  // fold (add (shl (add x, c1), c2), ) -> (add (add (shl x, c2), c1<<c2), )
-  if (N0.getOpcode() == ISD::SHL && N0.getNode()->hasOneUse()) {
-    SDValue Result = combineShlAddConstant(SDLoc(N), N0, N1, DAG);
-    if (Result.getNode()) return Result;
-  }
-  if (N1.getOpcode() == ISD::SHL && N1.getNode()->hasOneUse()) {
-    SDValue Result = combineShlAddConstant(SDLoc(N), N1, N0, DAG);
-    if (Result.getNode()) return Result;
-  }
-
   // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
   if (N1.getOpcode() == ISD::SHL &&
       N1.getOperand(0).getOpcode() == ISD::SUB)
@@ -2664,9 +2632,17 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
 
     // fold (and x, 0) -> 0, vector edition
     if (ISD::isBuildVectorAllZeros(N0.getNode()))
-      return N0;
+      // do not return N0, because undef node may exist in N0
+      return DAG.getConstant(
+          APInt::getNullValue(
+              N0.getValueType().getScalarType().getSizeInBits()),
+          N0.getValueType());
     if (ISD::isBuildVectorAllZeros(N1.getNode()))
-      return N1;
+      // do not return N1, because undef node may exist in N1
+      return DAG.getConstant(
+          APInt::getNullValue(
+              N1.getValueType().getScalarType().getSizeInBits()),
+          N1.getValueType());
 
     // fold (and x, -1) -> x, vector edition
     if (ISD::isBuildVectorAllOnes(N0.getNode()))
@@ -3312,9 +3288,17 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
 
     // fold (or x, -1) -> -1, vector edition
     if (ISD::isBuildVectorAllOnes(N0.getNode()))
-      return N0;
+      // do not return N0, because undef node may exist in N0
+      return DAG.getConstant(
+          APInt::getAllOnesValue(
+              N0.getValueType().getScalarType().getSizeInBits()),
+          N0.getValueType());
     if (ISD::isBuildVectorAllOnes(N1.getNode()))
-      return N1;
+      // do not return N1, because undef node may exist in N1
+      return DAG.getConstant(
+          APInt::getAllOnesValue(
+              N1.getValueType().getScalarType().getSizeInBits()),
+          N1.getValueType());
 
     // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask1)
     // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf B, A, Mask2)
@@ -4169,6 +4153,18 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
                        HiBitsMask);
   }
 
+  // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
+  // Variant of version done on multiply, except mul by a power of 2 is turned
+  // into a shift.
+  APInt Val;
+  if (N1C && N0.getOpcode() == ISD::ADD && N0.getNode()->hasOneUse() &&
+      (isa<ConstantSDNode>(N0.getOperand(1)) ||
+       isConstantSplatVector(N0.getOperand(1).getNode(), Val))) {
+    SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
+    SDValue Shl1 = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1);
+    return DAG.getNode(ISD::ADD, SDLoc(N), VT, Shl0, Shl1);
+  }
+
   if (N1C) {
     SDValue NewSHL = visitShiftByConstant(N, N1C);
     if (NewSHL.getNode())
@@ -6602,22 +6598,11 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
     // We can fold chains of FADD's of the same value into multiplications.
     // This transform is not safe in general because we are reducing the number
     // of rounding steps.
-
-    // FIXME: There are 4 redundant folds below where fmul canonicalization
-    // should have moved the constant out of Op0?
     if (TLI.isOperationLegalOrCustom(ISD::FMUL, VT) && !N0CFP && !N1CFP) {
       if (N0.getOpcode() == ISD::FMUL) {
         ConstantFPSDNode *CFP00 = dyn_cast<ConstantFPSDNode>(N0.getOperand(0));
         ConstantFPSDNode *CFP01 = dyn_cast<ConstantFPSDNode>(N0.getOperand(1));
         
-        // (fadd (fmul c, x), x) -> (fmul x, c+1)
-        if (CFP00 && !CFP01 && N0.getOperand(1) == N1) {
-          SDValue NewCFP = DAG.getNode(ISD::FADD, SDLoc(N), VT,
-                                       SDValue(CFP00, 0),
-                                       DAG.getConstantFP(1.0, VT));
-          return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N1, NewCFP);
-        }
-        
         // (fadd (fmul x, c), x) -> (fmul x, c+1)
         if (CFP01 && !CFP00 && N0.getOperand(0) == N1) {
           SDValue NewCFP = DAG.getNode(ISD::FADD, SDLoc(N), VT,
@@ -6625,17 +6610,6 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
                                        DAG.getConstantFP(1.0, VT));
           return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N1, NewCFP);
         }
-
-        // (fadd (fmul c, x), (fadd x, x)) -> (fmul x, c+2)
-        if (CFP00 && !CFP01 && N1.getOpcode() == ISD::FADD &&
-            N1.getOperand(0) == N1.getOperand(1) &&
-            N0.getOperand(1) == N1.getOperand(0)) {
-          SDValue NewCFP = DAG.getNode(ISD::FADD, SDLoc(N), VT,
-                                       SDValue(CFP00, 0),
-                                       DAG.getConstantFP(2.0, VT));
-          return DAG.getNode(ISD::FMUL, SDLoc(N), VT,
-                             N0.getOperand(1), NewCFP);
-        }
         
         // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
         if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
@@ -6653,14 +6627,6 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
         ConstantFPSDNode *CFP10 = dyn_cast<ConstantFPSDNode>(N1.getOperand(0));
         ConstantFPSDNode *CFP11 = dyn_cast<ConstantFPSDNode>(N1.getOperand(1));
         
-        // (fadd x, (fmul c, x)) -> (fmul x, c+1)
-        if (CFP10 && !CFP11 && N1.getOperand(1) == N0) {
-          SDValue NewCFP = DAG.getNode(ISD::FADD, SDLoc(N), VT,
-                                       SDValue(CFP10, 0),
-                                       DAG.getConstantFP(1.0, VT));
-          return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N0, NewCFP);
-        }
-        
         // (fadd x, (fmul x, c)) -> (fmul x, c+1)
         if (CFP11 && !CFP10 && N1.getOperand(0) == N0) {
           SDValue NewCFP = DAG.getNode(ISD::FADD, SDLoc(N), VT,
@@ -6668,18 +6634,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
                                        DAG.getConstantFP(1.0, VT));
           return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N0, NewCFP);
         }
-        
-        
-        // (fadd (fadd x, x), (fmul c, x)) -> (fmul x, c+2)
-        if (CFP10 && !CFP11 && N0.getOpcode() == ISD::FADD &&
-            N0.getOperand(0) == N0.getOperand(1) &&
-            N1.getOperand(1) == N0.getOperand(0)) {
-          SDValue NewCFP = DAG.getNode(ISD::FADD, SDLoc(N), VT,
-                                       SDValue(CFP10, 0),
-                                       DAG.getConstantFP(2.0, VT));
-          return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N1.getOperand(1), NewCFP);
-        }
-        
+
         // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
         if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
             N0.getOperand(0) == N0.getOperand(1) &&
@@ -6690,7 +6645,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
           return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N1.getOperand(0), NewCFP);
         }
       }
-      
+
       if (N0.getOpcode() == ISD::FADD && AllowNewConst) {
         ConstantFPSDNode *CFP = dyn_cast<ConstantFPSDNode>(N0.getOperand(0));
         // (fadd (fadd x, x), x) -> (fmul x, 3.0)
@@ -6845,33 +6800,52 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
 
   // fold vector ops
   if (VT.isVector()) {
+    // This just handles C1 * C2 for vectors. Other vector folds are below.
     SDValue FoldedVOp = SimplifyVBinOp(N);
-    if (FoldedVOp.getNode()) return FoldedVOp;
+    if (FoldedVOp.getNode())
+      return FoldedVOp;
+    // Canonicalize vector constant to RHS.
+    if (N0.getOpcode() == ISD::BUILD_VECTOR &&
+        N1.getOpcode() != ISD::BUILD_VECTOR)
+      if (auto *BV0 = dyn_cast<BuildVectorSDNode>(N0))
+        if (BV0->isConstant())
+          return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
   }
 
   // fold (fmul c1, c2) -> c1*c2
   if (N0CFP && N1CFP)
     return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N0, N1);
+
   // canonicalize constant to RHS
   if (N0CFP && !N1CFP)
     return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N1, N0);
-  // fold (fmul A, 0) -> 0
-  if (Options.UnsafeFPMath && N1CFP && N1CFP->getValueAPF().isZero())
-    return N1;
+
   // fold (fmul A, 1.0) -> A
   if (N1CFP && N1CFP->isExactlyValue(1.0))
     return N0;
 
-  if (DAG.getTarget().Options.UnsafeFPMath) {
-    // If allowed, fold (fmul (fmul x, c1), c2) -> (fmul x, (fmul c1, c2))
-    if (N1CFP && N0.getOpcode() == ISD::FMUL &&
-        N0.getNode()->hasOneUse() && isConstOrConstSplatFP(N0.getOperand(1))) {
-      SDLoc SL(N);
-      SDValue MulConsts = DAG.getNode(ISD::FMUL, SL, VT, N0.getOperand(1), N1);
-      return DAG.getNode(ISD::FMUL, SL, VT, N0.getOperand(0), MulConsts);
+  if (Options.UnsafeFPMath) {
+    // fold (fmul A, 0) -> 0
+    if (N1CFP && N1CFP->getValueAPF().isZero())
+      return N1;
+
+    // fold (fmul (fmul x, c1), c2) -> (fmul x, (fmul c1, c2))
+    if (N0.getOpcode() == ISD::FMUL) {
+      // Fold scalars or any vector constants (not just splats).
+      // This fold is done in general by InstCombine, but extra fmul insts
+      // may have been generated during lowering.
+      SDValue N01 = N0.getOperand(1);
+      auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
+      auto *BV01 = dyn_cast<BuildVectorSDNode>(N01);
+      if ((N1CFP && isConstOrConstSplatFP(N01)) ||
+          (BV1 && BV01 && BV1->isConstant() && BV01->isConstant())) {
+        SDLoc SL(N);
+        SDValue MulConsts = DAG.getNode(ISD::FMUL, SL, VT, N01, N1);
+        return DAG.getNode(ISD::FMUL, SL, VT, N0.getOperand(0), MulConsts);
+      }
     }
 
-    // If allowed, fold (fmul (fadd x, x), c) -> (fmul x, (fmul 2.0, c))
+    // fold (fmul (fadd x, x), c) -> (fmul x, (fmul 2.0, c))
     // Undo the fmul 2.0, x -> fadd x, x transformation, since if it occurs
     // during an early run of DAGCombiner can prevent folding with fmuls
     // inserted during lowering.
@@ -6886,6 +6860,7 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
   // fold (fmul X, 2.0) -> (fadd X, X)
   if (N1CFP && N1CFP->isExactlyValue(+2.0))
     return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N0);
+
   // fold (fmul X, -1.0) -> (fneg X)
   if (N1CFP && N1CFP->isExactlyValue(-1.0))
     if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))