Combine fmul vector FP constants when unsafe math is allowed.
[oota-llvm.git] / lib / CodeGen / SelectionDAG / DAGCombiner.cpp
index 156d0a369305108cd7072f257be4dbccfe4c87a4..c29200a549ecdc37523cd06d5dd866a8d49e2ca5 100644 (file)
@@ -6820,8 +6820,16 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
 
   // fold vector ops
   if (VT.isVector()) {
+    // This just handles C1 * C2 for vectors. Other vector folds are below.
     SDValue FoldedVOp = SimplifyVBinOp(N);
-    if (FoldedVOp.getNode()) return FoldedVOp;
+    if (FoldedVOp.getNode())
+      return FoldedVOp;
+    // Canonicalize vector constant to RHS.
+    if (N0.getOpcode() == ISD::BUILD_VECTOR &&
+        N1.getOpcode() != ISD::BUILD_VECTOR)
+      if (auto *BV0 = dyn_cast<BuildVectorSDNode>(N0))
+        if (BV0->isConstant())
+          return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
   }
 
   // fold (fmul c1, c2) -> c1*c2
@@ -6842,11 +6850,19 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
       return N1;
 
     // fold (fmul (fmul x, c1), c2) -> (fmul x, (fmul c1, c2))
-    if (N1CFP && N0.getOpcode() == ISD::FMUL &&
-        N0.getNode()->hasOneUse() && isConstOrConstSplatFP(N0.getOperand(1))) {
-      SDLoc SL(N);
-      SDValue MulConsts = DAG.getNode(ISD::FMUL, SL, VT, N0.getOperand(1), N1);
-      return DAG.getNode(ISD::FMUL, SL, VT, N0.getOperand(0), MulConsts);
+    if (N0.getOpcode() == ISD::FMUL) {
+      // Fold scalars or any vector constants (not just splats).
+      // This fold is done in general by InstCombine, but extra fmul insts
+      // may have been generated during lowering.
+      SDValue N01 = N0.getOperand(1);
+      auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
+      auto *BV01 = dyn_cast<BuildVectorSDNode>(N01);
+      if ((N1CFP && isConstOrConstSplatFP(N01)) ||
+          (BV1 && BV01 && BV1->isConstant() && BV01->isConstant())) {
+        SDLoc SL(N);
+        SDValue MulConsts = DAG.getNode(ISD::FMUL, SL, VT, N01, N1);
+        return DAG.getNode(ISD::FMUL, SL, VT, N0.getOperand(0), MulConsts);
+      }
     }
 
     // fold (fmul (fadd x, x), c) -> (fmul x, (fmul 2.0, c))