[DAGCombiner] Improved FMA combine support for vectors
[oota-llvm.git] / lib / CodeGen / SelectionDAG / DAGCombiner.cpp
index 3280060316c044163c6179406d54bd8fd8f75a39..c1f962f14069d2ca592a6bfb4d399d593d803cca 100644 (file)
@@ -8331,7 +8331,8 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
     return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);
 
   // Canonicalize (fma c, x, y) -> (fma x, c, y)
-  if (N0CFP && !N1CFP)
+  if (isConstantFPBuildVectorOrConstantFP(N0) &&
+     !isConstantFPBuildVectorOrConstantFP(N1))
     return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
 
   // TODO: FMA nodes should have flags that propagate to the created nodes.
@@ -8339,26 +8340,26 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
   SDNodeFlags Flags;
   Flags.setUnsafeAlgebra(true);
 
-  // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
-  if (Options.UnsafeFPMath && N1CFP &&
-      N2.getOpcode() == ISD::FMUL &&
-      N0 == N2.getOperand(0) &&
-      N2.getOperand(1).getOpcode() == ISD::ConstantFP) {
-    return DAG.getNode(ISD::FMUL, dl, VT, N0,
-                       DAG.getNode(ISD::FADD, dl, VT, N1, N2.getOperand(1),
-                                   &Flags), &Flags);
-  }
-
+  if (Options.UnsafeFPMath) {
+    // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
+    if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) &&
+        isConstantFPBuildVectorOrConstantFP(N1) &&
+        isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
+      return DAG.getNode(ISD::FMUL, dl, VT, N0,
+                         DAG.getNode(ISD::FADD, dl, VT, N1, N2.getOperand(1),
+                                     &Flags), &Flags);
+    }
 
-  // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
-  if (Options.UnsafeFPMath &&
-      N0.getOpcode() == ISD::FMUL && N1CFP &&
-      N0.getOperand(1).getOpcode() == ISD::ConstantFP) {
-    return DAG.getNode(ISD::FMA, dl, VT,
-                       N0.getOperand(0),
-                       DAG.getNode(ISD::FMUL, dl, VT, N1, N0.getOperand(1),
-                                   &Flags),
-                       N2);
+    // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
+    if (N0.getOpcode() == ISD::FMUL &&
+        isConstantFPBuildVectorOrConstantFP(N1) &&
+        isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
+      return DAG.getNode(ISD::FMA, dl, VT,
+                         N0.getOperand(0),
+                         DAG.getNode(ISD::FMUL, dl, VT, N1, N0.getOperand(1),
+                                     &Flags),
+                         N2);
+    }
   }
 
   // (fma x, 1, y) -> (fadd x, y)
@@ -8377,20 +8378,22 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
     }
   }
 
-  // (fma x, c, x) -> (fmul x, (c+1))
-  if (Options.UnsafeFPMath && N1CFP && N0 == N2) {
-    return DAG.getNode(ISD::FMUL, dl, VT, N0,
-                       DAG.getNode(ISD::FADD, dl, VT,
-                                   N1, DAG.getConstantFP(1.0, dl, VT),
-                                   &Flags), &Flags);
-  }
-  // (fma x, c, (fneg x)) -> (fmul x, (c-1))
-  if (Options.UnsafeFPMath && N1CFP &&
-      N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) {
+  if (Options.UnsafeFPMath) {
+    // (fma x, c, x) -> (fmul x, (c+1))
+    if (N1CFP && N0 == N2) {
     return DAG.getNode(ISD::FMUL, dl, VT, N0,
-                       DAG.getNode(ISD::FADD, dl, VT,
-                                   N1, DAG.getConstantFP(-1.0, dl, VT),
-                                   &Flags), &Flags);
+                         DAG.getNode(ISD::FADD, dl, VT,
+                                     N1, DAG.getConstantFP(1.0, dl, VT),
+                                     &Flags), &Flags);
+    }
+
+    // (fma x, c, (fneg x)) -> (fmul x, (c-1))
+    if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) {
+      return DAG.getNode(ISD::FMUL, dl, VT, N0,
+                         DAG.getNode(ISD::FADD, dl, VT,
+                                     N1, DAG.getConstantFP(-1.0, dl, VT),
+                                     &Flags), &Flags);
+    }
   }
 
   return SDValue();