[DAGCombiner] Improve FMA support for interpolation patterns
[oota-llvm.git] / lib / CodeGen / SelectionDAG / DAGCombiner.cpp
index b8691bd9e100cf97c16de9fa53a6894211afdf66..c9914fa0f17ad944f35317467551ce6df7b06aad 100644 (file)
@@ -301,6 +301,7 @@ namespace {
     SDValue visitLOAD(SDNode *N);
 
     SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
+    SDValue replaceStoreOfFPConstant(StoreSDNode *ST);
 
     SDValue visitSTORE(SDNode *N);
     SDValue visitINSERT_VECTOR_ELT(SDNode *N);
@@ -320,6 +321,7 @@ namespace {
 
     SDValue visitFADDForFMACombine(SDNode *N);
     SDValue visitFSUBForFMACombine(SDNode *N);
+    SDValue visitFMULForFMACombine(SDNode *N);
 
     SDValue XformToShuffleWithZero(SDNode *N);
     SDValue ReassociateOps(unsigned Opc, SDLoc DL, SDValue LHS, SDValue RHS);
@@ -618,7 +620,7 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
   assert(Depth <= 6 && "GetNegatedExpression doesn't match isNegatibleForFree");
 
   const SDNodeFlags *Flags = Op.getNode()->getFlags();
-  
+
   switch (Op.getOpcode()) {
   default: llvm_unreachable("Unknown code");
   case ISD::ConstantFP: {
@@ -7480,25 +7482,23 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
   SDLoc SL(N);
 
   const TargetOptions &Options = DAG.getTarget().Options;
-  bool UnsafeFPMath = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
-                       Options.UnsafeFPMath);
+  bool AllowFusion =
+      (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath);
 
   // Floating-point multiply-add with intermediate rounding.
-  bool HasFMAD = (LegalOperations &&
-                  TLI.isOperationLegal(ISD::FMAD, VT));
+  bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT));
 
   // Floating-point multiply-add without intermediate rounding.
-  bool HasFMA = ((!LegalOperations ||
-                  TLI.isOperationLegalOrCustom(ISD::FMA, VT)) &&
-                 TLI.isFMAFasterThanFMulAndFAdd(VT) &&
-                 UnsafeFPMath);
+  bool HasFMA =
+      AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) &&
+      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
 
   // No valid opcode, do not combine.
   if (!HasFMAD && !HasFMA)
     return SDValue();
 
   // Always prefer FMAD to FMA for precision.
-  unsigned int PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
+  unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
   bool LookThroughFPExt = TLI.isFPExtFree(VT);
 
@@ -7526,7 +7526,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
   }
 
   // Look through FP_EXTEND nodes to do more combining.
-  if (UnsafeFPMath && LookThroughFPExt) {
+  if (AllowFusion && LookThroughFPExt) {
     // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
     if (N0.getOpcode() == ISD::FP_EXTEND) {
       SDValue N00 = N0.getOperand(0);
@@ -7552,7 +7552,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
   }
 
   // More folding opportunities when target permits.
-  if ((UnsafeFPMath || HasFMAD)  && Aggressive) {
+  if ((AllowFusion || HasFMAD)  && Aggressive) {
     // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z))
     if (N0.getOpcode() == PreferredFusedOpcode &&
         N0.getOperand(2).getOpcode() == ISD::FMUL) {
@@ -7575,7 +7575,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
                                      N0));
     }
 
-    if (UnsafeFPMath && LookThroughFPExt) {
+    if (AllowFusion && LookThroughFPExt) {
       // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
       //   -> (fma x, y, (fma (fpext u), (fpext v), z))
       auto FoldFAddFMAFPExtFMul = [&] (
@@ -7665,25 +7665,23 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
   SDLoc SL(N);
 
   const TargetOptions &Options = DAG.getTarget().Options;
-  bool UnsafeFPMath = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
-                       Options.UnsafeFPMath);
+  bool AllowFusion =
+      (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath);
 
   // Floating-point multiply-add with intermediate rounding.
-  bool HasFMAD = (LegalOperations &&
-                  TLI.isOperationLegal(ISD::FMAD, VT));
+  bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT));
 
   // Floating-point multiply-add without intermediate rounding.
-  bool HasFMA = ((!LegalOperations ||
-                  TLI.isOperationLegalOrCustom(ISD::FMA, VT)) &&
-                 TLI.isFMAFasterThanFMulAndFAdd(VT) &&
-                 UnsafeFPMath);
+  bool HasFMA =
+      AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) &&
+      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
 
   // No valid opcode, do not combine.
   if (!HasFMAD && !HasFMA)
     return SDValue();
 
   // Always prefer FMAD to FMA for precision.
-  unsigned int PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
+  unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
   bool LookThroughFPExt = TLI.isFPExtFree(VT);
 
@@ -7716,7 +7714,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
   }
 
   // Look through FP_EXTEND nodes to do more combining.
-  if (UnsafeFPMath && LookThroughFPExt) {
+  if (AllowFusion && LookThroughFPExt) {
     // fold (fsub (fpext (fmul x, y)), z)
     //   -> (fma (fpext x), (fpext y), (fneg z))
     if (N0.getOpcode() == ISD::FP_EXTEND) {
@@ -7792,7 +7790,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
   }
 
   // More folding opportunities when target permits.
-  if ((UnsafeFPMath || HasFMAD) && Aggressive) {
+  if ((AllowFusion || HasFMAD) && Aggressive) {
     // fold (fsub (fma x, y, (fmul u, v)), z)
     //   -> (fma x, y (fma u, v, (fneg z)))
     if (N0.getOpcode() == PreferredFusedOpcode &&
@@ -7822,7 +7820,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
                                      N21, N0));
     }
 
-    if (UnsafeFPMath && LookThroughFPExt) {
+    if (AllowFusion && LookThroughFPExt) {
       // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
       //   -> (fma x, y (fma (fpext u), (fpext v), (fneg z)))
       if (N0.getOpcode() == PreferredFusedOpcode) {
@@ -7923,6 +7921,88 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
   return SDValue();
 }
 
+/// Try to perform FMA combining on a given FMUL node.
+SDValue DAGCombiner::visitFMULForFMACombine(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  SDLoc SL(N);
+
+  assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
+
+  const TargetOptions &Options = DAG.getTarget().Options;
+  bool AllowFusion =
+      (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath);
+
+  // Floating-point multiply-add with intermediate rounding.
+  bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT));
+
+  // Floating-point multiply-add without intermediate rounding.
+  bool HasFMA =
+      AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) &&
+      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
+
+  // No valid opcode, do not combine.
+  if (!HasFMAD && !HasFMA)
+    return SDValue();
+
+  // Always prefer FMAD to FMA for precision.
+  unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
+  bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
+
+  // fold (fmul (fadd x, +1.0), y) -> (fma x, y, y)
+  // fold (fmul (fadd x, -1.0), y) -> (fma x, y, (fneg y))
+  auto FuseFADD = [&](SDValue X, SDValue Y) {
+    if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
+      auto XC1 = isConstOrConstSplatFP(X.getOperand(1));
+      if (XC1 && XC1->isExactlyValue(+1.0))
+        return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, Y);
+      if (XC1 && XC1->isExactlyValue(-1.0))
+        return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
+                           DAG.getNode(ISD::FNEG, SL, VT, Y));
+    }
+    return SDValue();
+  };
+
+  if (SDValue FMA = FuseFADD(N0, N1))
+    return FMA;
+  if (SDValue FMA = FuseFADD(N1, N0))
+    return FMA;
+
+  // fold (fmul (fsub +1.0, x), y) -> (fma (fneg x), y, y)
+  // fold (fmul (fsub -1.0, x), y) -> (fma (fneg x), y, (fneg y))
+  // fold (fmul (fsub x, +1.0), y) -> (fma x, y, (fneg y))
+  // fold (fmul (fsub x, -1.0), y) -> (fma x, y, y)
+  auto FuseFSUB = [&](SDValue X, SDValue Y) {
+    if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
+      auto XC0 = isConstOrConstSplatFP(X.getOperand(0));
+      if (XC0 && XC0->isExactlyValue(+1.0))
+        return DAG.getNode(PreferredFusedOpcode, SL, VT,
+                           DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
+                           Y);
+      if (XC0 && XC0->isExactlyValue(-1.0))
+        return DAG.getNode(PreferredFusedOpcode, SL, VT,
+                           DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
+                           DAG.getNode(ISD::FNEG, SL, VT, Y));
+
+      auto XC1 = isConstOrConstSplatFP(X.getOperand(1));
+      if (XC1 && XC1->isExactlyValue(+1.0))
+        return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
+                           DAG.getNode(ISD::FNEG, SL, VT, Y));
+      if (XC1 && XC1->isExactlyValue(-1.0))
+        return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, Y);
+    }
+    return SDValue();
+  };
+
+  if (SDValue FMA = FuseFSUB(N0, N1))
+    return FMA;
+  if (SDValue FMA = FuseFSUB(N1, N0))
+    return FMA;
+
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitFADD(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -8230,6 +8310,12 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
     }
   }
 
+  // FMUL -> FMA combines:
+  if (SDValue Fused = visitFMULForFMACombine(N)) {
+    AddToWorklist(Fused.getNode());
+    return Fused;
+  }
+
   return SDValue();
 }
 
@@ -11361,6 +11447,89 @@ SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
   return CombineTo(ST, Token, false);
 }
 
+SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
+  SDValue Value = ST->getValue();
+  if (Value.getOpcode() == ISD::TargetConstantFP)
+    return SDValue();
+
+  SDLoc DL(ST);
+
+  SDValue Chain = ST->getChain();
+  SDValue Ptr = ST->getBasePtr();
+
+  const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Value);
+
+  // NOTE: If the original store is volatile, this transform must not increase
+  // the number of stores.  For example, on x86-32 an f64 can be stored in one
+  // processor operation but an i64 (which is not legal) requires two.  So the
+  // transform should not be done in this case.
+
+  SDValue Tmp;
+  switch (CFP->getSimpleValueType(0).SimpleTy) {
+  default:
+    llvm_unreachable("Unknown FP type");
+  case MVT::f16:    // We don't do this for these yet.
+  case MVT::f80:
+  case MVT::f128:
+  case MVT::ppcf128:
+    return SDValue();
+  case MVT::f32:
+    if ((isTypeLegal(MVT::i32) && !LegalOperations && !ST->isVolatile()) ||
+        TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
+      ;
+      Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
+                            bitcastToAPInt().getZExtValue(), SDLoc(CFP),
+                            MVT::i32);
+      return DAG.getStore(Chain, DL, Tmp, Ptr, ST->getMemOperand());
+    }
+
+    return SDValue();
+  case MVT::f64:
+    if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations &&
+         !ST->isVolatile()) ||
+        TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) {
+      ;
+      Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
+                            getZExtValue(), SDLoc(CFP), MVT::i64);
+      return DAG.getStore(Chain, DL, Tmp,
+                          Ptr, ST->getMemOperand());
+    }
+
+    if (!ST->isVolatile() &&
+        TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
+      // Many FP stores are not made apparent until after legalize, e.g. for
+      // argument passing.  Since this is so common, custom legalize the
+      // 64-bit integer store into two 32-bit stores.
+      uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
+      SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32);
+      SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32);
+      if (DAG.getDataLayout().isBigEndian())
+        std::swap(Lo, Hi);
+
+      unsigned Alignment = ST->getAlignment();
+      bool isVolatile = ST->isVolatile();
+      bool isNonTemporal = ST->isNonTemporal();
+      AAMDNodes AAInfo = ST->getAAInfo();
+
+      SDValue St0 = DAG.getStore(Chain, DL, Lo,
+                                 Ptr, ST->getPointerInfo(),
+                                 isVolatile, isNonTemporal,
+                                 ST->getAlignment(), AAInfo);
+      Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
+                        DAG.getConstant(4, DL, Ptr.getValueType()));
+      Alignment = MinAlign(Alignment, 4U);
+      SDValue St1 = DAG.getStore(Chain, DL, Hi,
+                                 Ptr, ST->getPointerInfo().getWithOffset(4),
+                                 isVolatile, isNonTemporal,
+                                 Alignment, AAInfo);
+      return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
+                         St0, St1);
+    }
+
+    return SDValue();
+  }
+}
+
 SDValue DAGCombiner::visitSTORE(SDNode *N) {
   StoreSDNode *ST  = cast<StoreSDNode>(N);
   SDValue Chain = ST->getChain();
@@ -11388,81 +11557,6 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
   if (Value.getOpcode() == ISD::UNDEF && ST->isUnindexed())
     return Chain;
 
-  // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
-  if (ConstantFPSDNode *CFP = dyn_cast<ConstantFPSDNode>(Value)) {
-    // NOTE: If the original store is volatile, this transform must not increase
-    // the number of stores.  For example, on x86-32 an f64 can be stored in one
-    // processor operation but an i64 (which is not legal) requires two.  So the
-    // transform should not be done in this case.
-    if (Value.getOpcode() != ISD::TargetConstantFP) {
-      SDValue Tmp;
-      switch (CFP->getSimpleValueType(0).SimpleTy) {
-      default: llvm_unreachable("Unknown FP type");
-      case MVT::f16:    // We don't do this for these yet.
-      case MVT::f80:
-      case MVT::f128:
-      case MVT::ppcf128:
-        break;
-      case MVT::f32:
-        if ((isTypeLegal(MVT::i32) && !LegalOperations && !ST->isVolatile()) ||
-            TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
-          ;
-          Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
-                              bitcastToAPInt().getZExtValue(), SDLoc(CFP),
-                              MVT::i32);
-          return DAG.getStore(Chain, SDLoc(N), Tmp,
-                              Ptr, ST->getMemOperand());
-        }
-        break;
-      case MVT::f64:
-        if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations &&
-             !ST->isVolatile()) ||
-            TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) {
-          ;
-          Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
-                                getZExtValue(), SDLoc(CFP), MVT::i64);
-          return DAG.getStore(Chain, SDLoc(N), Tmp,
-                              Ptr, ST->getMemOperand());
-        }
-
-        if (!ST->isVolatile() &&
-            TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
-          // Many FP stores are not made apparent until after legalize, e.g. for
-          // argument passing.  Since this is so common, custom legalize the
-          // 64-bit integer store into two 32-bit stores.
-          uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
-          SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32);
-          SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32);
-          if (DAG.getDataLayout().isBigEndian())
-            std::swap(Lo, Hi);
-
-          unsigned Alignment = ST->getAlignment();
-          bool isVolatile = ST->isVolatile();
-          bool isNonTemporal = ST->isNonTemporal();
-          AAMDNodes AAInfo = ST->getAAInfo();
-
-          SDLoc DL(N);
-
-          SDValue St0 = DAG.getStore(Chain, SDLoc(ST), Lo,
-                                     Ptr, ST->getPointerInfo(),
-                                     isVolatile, isNonTemporal,
-                                     ST->getAlignment(), AAInfo);
-          Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
-                            DAG.getConstant(4, DL, Ptr.getValueType()));
-          Alignment = MinAlign(Alignment, 4U);
-          SDValue St1 = DAG.getStore(Chain, SDLoc(ST), Hi,
-                                     Ptr, ST->getPointerInfo().getWithOffset(4),
-                                     isVolatile, isNonTemporal,
-                                     Alignment, AAInfo);
-          return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
-                             St0, St1);
-        }
-
-        break;
-      }
-    }
-  }
-
   // Try to infer better alignment information than the store already has.
   if (OptLevel != CodeGenOpt::None && ST->isUnindexed()) {
     if (unsigned Align = DAG.InferPtrAlignment(Ptr)) {
@@ -11586,6 +11680,16 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
       return SDValue(N, 0);
   }
 
+  // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
+  //
+  // Make sure to do this only after attempting to merge stores in order to
+  //  avoid changing the types of some subset of stores due to visit order,
+  //  preventing their merging.
+  if (isa<ConstantFPSDNode>(Value)) {
+    if (SDValue NewSt = replaceStoreOfFPConstant(ST))
+      return NewSt;
+  }
+
   return ReduceLoadOpStoreWidth(N);
 }