[DAGCombiner] Improve FMA support for interpolation patterns
[oota-llvm.git] / lib / CodeGen / SelectionDAG / DAGCombiner.cpp
index 5c4110276cd9c150e9b3df159525180ccee8ca0c..c9914fa0f17ad944f35317467551ce6df7b06aad 100644 (file)
@@ -321,6 +321,7 @@ namespace {
 
     SDValue visitFADDForFMACombine(SDNode *N);
     SDValue visitFSUBForFMACombine(SDNode *N);
+    SDValue visitFMULForFMACombine(SDNode *N);
 
     SDValue XformToShuffleWithZero(SDNode *N);
     SDValue ReassociateOps(unsigned Opc, SDLoc DL, SDValue LHS, SDValue RHS);
@@ -619,7 +620,7 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
   assert(Depth <= 6 && "GetNegatedExpression doesn't match isNegatibleForFree");
 
   const SDNodeFlags *Flags = Op.getNode()->getFlags();
-  
+
   switch (Op.getOpcode()) {
   default: llvm_unreachable("Unknown code");
   case ISD::ConstantFP: {
@@ -7481,25 +7482,23 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
   SDLoc SL(N);
 
   const TargetOptions &Options = DAG.getTarget().Options;
-  bool UnsafeFPMath = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
-                       Options.UnsafeFPMath);
+  bool AllowFusion =
+      (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath);
 
   // Floating-point multiply-add with intermediate rounding.
-  bool HasFMAD = (LegalOperations &&
-                  TLI.isOperationLegal(ISD::FMAD, VT));
+  bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT));
 
   // Floating-point multiply-add without intermediate rounding.
-  bool HasFMA = ((!LegalOperations ||
-                  TLI.isOperationLegalOrCustom(ISD::FMA, VT)) &&
-                 TLI.isFMAFasterThanFMulAndFAdd(VT) &&
-                 UnsafeFPMath);
+  bool HasFMA =
+      AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) &&
+      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
 
   // No valid opcode, do not combine.
   if (!HasFMAD && !HasFMA)
     return SDValue();
 
   // Always prefer FMAD to FMA for precision.
-  unsigned int PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
+  unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
   bool LookThroughFPExt = TLI.isFPExtFree(VT);
 
@@ -7527,7 +7526,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
   }
 
   // Look through FP_EXTEND nodes to do more combining.
-  if (UnsafeFPMath && LookThroughFPExt) {
+  if (AllowFusion && LookThroughFPExt) {
     // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
     if (N0.getOpcode() == ISD::FP_EXTEND) {
       SDValue N00 = N0.getOperand(0);
@@ -7553,7 +7552,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
   }
 
   // More folding opportunities when target permits.
-  if ((UnsafeFPMath || HasFMAD)  && Aggressive) {
+  if ((AllowFusion || HasFMAD)  && Aggressive) {
     // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z))
     if (N0.getOpcode() == PreferredFusedOpcode &&
         N0.getOperand(2).getOpcode() == ISD::FMUL) {
@@ -7576,7 +7575,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
                                      N0));
     }
 
-    if (UnsafeFPMath && LookThroughFPExt) {
+    if (AllowFusion && LookThroughFPExt) {
       // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
       //   -> (fma x, y, (fma (fpext u), (fpext v), z))
       auto FoldFAddFMAFPExtFMul = [&] (
@@ -7666,25 +7665,23 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
   SDLoc SL(N);
 
   const TargetOptions &Options = DAG.getTarget().Options;
-  bool UnsafeFPMath = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
-                       Options.UnsafeFPMath);
+  bool AllowFusion =
+      (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath);
 
   // Floating-point multiply-add with intermediate rounding.
-  bool HasFMAD = (LegalOperations &&
-                  TLI.isOperationLegal(ISD::FMAD, VT));
+  bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT));
 
   // Floating-point multiply-add without intermediate rounding.
-  bool HasFMA = ((!LegalOperations ||
-                  TLI.isOperationLegalOrCustom(ISD::FMA, VT)) &&
-                 TLI.isFMAFasterThanFMulAndFAdd(VT) &&
-                 UnsafeFPMath);
+  bool HasFMA =
+      AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) &&
+      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
 
   // No valid opcode, do not combine.
   if (!HasFMAD && !HasFMA)
     return SDValue();
 
   // Always prefer FMAD to FMA for precision.
-  unsigned int PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
+  unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
   bool LookThroughFPExt = TLI.isFPExtFree(VT);
 
@@ -7717,7 +7714,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
   }
 
   // Look through FP_EXTEND nodes to do more combining.
-  if (UnsafeFPMath && LookThroughFPExt) {
+  if (AllowFusion && LookThroughFPExt) {
     // fold (fsub (fpext (fmul x, y)), z)
     //   -> (fma (fpext x), (fpext y), (fneg z))
     if (N0.getOpcode() == ISD::FP_EXTEND) {
@@ -7793,7 +7790,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
   }
 
   // More folding opportunities when target permits.
-  if ((UnsafeFPMath || HasFMAD) && Aggressive) {
+  if ((AllowFusion || HasFMAD) && Aggressive) {
     // fold (fsub (fma x, y, (fmul u, v)), z)
     //   -> (fma x, y (fma u, v, (fneg z)))
     if (N0.getOpcode() == PreferredFusedOpcode &&
@@ -7823,7 +7820,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
                                      N21, N0));
     }
 
-    if (UnsafeFPMath && LookThroughFPExt) {
+    if (AllowFusion && LookThroughFPExt) {
       // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
       //   -> (fma x, y (fma (fpext u), (fpext v), (fneg z)))
       if (N0.getOpcode() == PreferredFusedOpcode) {
@@ -7924,6 +7921,88 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
   return SDValue();
 }
 
+/// Try to perform FMA combining on a given FMUL node.
+SDValue DAGCombiner::visitFMULForFMACombine(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  SDLoc SL(N);
+
+  assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
+
+  const TargetOptions &Options = DAG.getTarget().Options;
+  bool AllowFusion =
+      (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath);
+
+  // Floating-point multiply-add with intermediate rounding.
+  bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT));
+
+  // Floating-point multiply-add without intermediate rounding.
+  bool HasFMA =
+      AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) &&
+      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
+
+  // No valid opcode, do not combine.
+  if (!HasFMAD && !HasFMA)
+    return SDValue();
+
+  // Always prefer FMAD to FMA for precision.
+  unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
+  bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
+
+  // fold (fmul (fadd x, +1.0), y) -> (fma x, y, y)
+  // fold (fmul (fadd x, -1.0), y) -> (fma x, y, (fneg y))
+  auto FuseFADD = [&](SDValue X, SDValue Y) {
+    if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
+      auto XC1 = isConstOrConstSplatFP(X.getOperand(1));
+      if (XC1 && XC1->isExactlyValue(+1.0))
+        return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, Y);
+      if (XC1 && XC1->isExactlyValue(-1.0))
+        return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
+                           DAG.getNode(ISD::FNEG, SL, VT, Y));
+    }
+    return SDValue();
+  };
+
+  if (SDValue FMA = FuseFADD(N0, N1))
+    return FMA;
+  if (SDValue FMA = FuseFADD(N1, N0))
+    return FMA;
+
+  // fold (fmul (fsub +1.0, x), y) -> (fma (fneg x), y, y)
+  // fold (fmul (fsub -1.0, x), y) -> (fma (fneg x), y, (fneg y))
+  // fold (fmul (fsub x, +1.0), y) -> (fma x, y, (fneg y))
+  // fold (fmul (fsub x, -1.0), y) -> (fma x, y, y)
+  auto FuseFSUB = [&](SDValue X, SDValue Y) {
+    if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
+      auto XC0 = isConstOrConstSplatFP(X.getOperand(0));
+      if (XC0 && XC0->isExactlyValue(+1.0))
+        return DAG.getNode(PreferredFusedOpcode, SL, VT,
+                           DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
+                           Y);
+      if (XC0 && XC0->isExactlyValue(-1.0))
+        return DAG.getNode(PreferredFusedOpcode, SL, VT,
+                           DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
+                           DAG.getNode(ISD::FNEG, SL, VT, Y));
+
+      auto XC1 = isConstOrConstSplatFP(X.getOperand(1));
+      if (XC1 && XC1->isExactlyValue(+1.0))
+        return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
+                           DAG.getNode(ISD::FNEG, SL, VT, Y));
+      if (XC1 && XC1->isExactlyValue(-1.0))
+        return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, Y);
+    }
+    return SDValue();
+  };
+
+  if (SDValue FMA = FuseFSUB(N0, N1))
+    return FMA;
+  if (SDValue FMA = FuseFSUB(N1, N0))
+    return FMA;
+
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitFADD(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -8231,6 +8310,12 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
     }
   }
 
+  // FMUL -> FMA combines:
+  if (SDValue Fused = visitFMULForFMACombine(N)) {
+    AddToWorklist(Fused.getNode());
+    return Fused;
+  }
+
   return SDValue();
 }
 
@@ -11395,20 +11480,7 @@ SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
       Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
                             bitcastToAPInt().getZExtValue(), SDLoc(CFP),
                             MVT::i32);
-      SDValue NewSt = DAG.getStore(Chain, DL, Tmp,
-                                   Ptr, ST->getMemOperand());
-
-      dbgs() << "Replacing FP constant: ";
-      Value->dump(&DAG);
-
-      if (cast<StoreSDNode>(NewSt)->getMemoryVT() != MVT::i32) {
-        dbgs() << "Different memoryvt\n";
-      } else {
-        dbgs() << "same memoryvt\n";
-      }
-
-
-      return NewSt;
+      return DAG.getStore(Chain, DL, Tmp, Ptr, ST->getMemOperand());
     }
 
     return SDValue();
@@ -11485,16 +11557,6 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
   if (Value.getOpcode() == ISD::UNDEF && ST->isUnindexed())
     return Chain;
 
-  // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
-  //
-  // Make sure to do this only after attempting to merge stores in order to
-  //  avoid changing the types of some subset of stores due to visit order,
-  //  preventing their merging.
-  if (isa<ConstantFPSDNode>(Value)) {
-    if (SDValue NewSt = replaceStoreOfFPConstant(ST))
-      return NewSt;
-  }
-
   // Try to infer better alignment information than the store already has.
   if (OptLevel != CodeGenOpt::None && ST->isUnindexed()) {
     if (unsigned Align = DAG.InferPtrAlignment(Ptr)) {
@@ -11618,6 +11680,16 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
       return SDValue(N, 0);
   }
 
+  // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
+  //
+  // Make sure to do this only after attempting to merge stores in order to
+  //  avoid changing the types of some subset of stores due to visit order,
+  //  preventing their merging.
+  if (isa<ConstantFPSDNode>(Value)) {
+    if (SDValue NewSt = replaceStoreOfFPConstant(ST))
+      return NewSt;
+  }
+
   return ReduceLoadOpStoreWidth(N);
 }