Refactor reciprocal square root estimate into target-independent function; NFC.
authorSanjay Patel <spatel@rotateright.com>
Sun, 21 Sep 2014 15:19:15 +0000 (15:19 +0000)
committerSanjay Patel <spatel@rotateright.com>
Sun, 21 Sep 2014 15:19:15 +0000 (15:19 +0000)
This is purely a plumbing patch. No functional changes intended.

The ultimate goal is to allow targets other than PowerPC (certainly X86 and Aarch64) to turn this:

z = y / sqrt(x)

into:

z = y * rsqrte(x)

using whatever HW magic they can use. See http://llvm.org/bugs/show_bug.cgi?id=20900 .

The first step is to add a target hook for RSQRTE, take the already target-independent code selfishly hoarded by PPC, and put it into DAGCombiner.

Next steps:

    The code in DAGCombiner::BuildRSQRTE() should be refactored further; tests that exercise that logic need to be added.
    Logic in PPCTargetLowering::BuildRSQRTE() should be hoisted into DAGCombiner.
    X86 and AArch64 overrides for TargetLowering.BuildRSQRTE() should be added.

Differential Revision: http://reviews.llvm.org/D5425

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@218219 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Target/TargetLowering.h
lib/CodeGen/SelectionDAG/DAGCombiner.cpp
lib/Target/PowerPC/PPCISelLowering.cpp
lib/Target/PowerPC/PPCISelLowering.h

index e6c4634079cc26548552dc9b84409f0f67220b32..0a7222599aac4e094f9b0ac3dfa7beeee2979b65 100644 (file)
@@ -2602,6 +2602,10 @@ public:
     return SDValue();
   }
 
     return SDValue();
   }
 
+  virtual SDValue BuildRSQRTE(SDValue Op, DAGCombinerInfo &DCI) const {
+    return SDValue();
+  }
+
   //===--------------------------------------------------------------------===//
   // Legalization utility functions
   //
   //===--------------------------------------------------------------------===//
   // Legalization utility functions
   //
index aa2f2d1f2b1a7cdbcb8481700932e11b913afd18..30ac63570ff14362c8f771210b6c8979ad3e8264 100644 (file)
@@ -326,6 +326,7 @@ namespace {
     SDValue BuildSDIV(SDNode *N);
     SDValue BuildSDIVPow2(SDNode *N);
     SDValue BuildUDIV(SDNode *N);
     SDValue BuildSDIV(SDNode *N);
     SDValue BuildSDIVPow2(SDNode *N);
     SDValue BuildUDIV(SDNode *N);
+    SDValue BuildRSQRTE(SDNode *N);
     SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                                bool DemandHighBits = true);
     SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
     SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                                bool DemandHighBits = true);
     SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
@@ -6987,23 +6988,29 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
   if (N0CFP && N1CFP)
     return DAG.getNode(ISD::FDIV, SDLoc(N), VT, N0, N1);
 
   if (N0CFP && N1CFP)
     return DAG.getNode(ISD::FDIV, SDLoc(N), VT, N0, N1);
 
-  // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable.
-  if (N1CFP && Options.UnsafeFPMath) {
-    // Compute the reciprocal 1.0 / c2.
-    APFloat N1APF = N1CFP->getValueAPF();
-    APFloat Recip(N1APF.getSemantics(), 1); // 1.0
-    APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven);
-    // Only do the transform if the reciprocal is a legal fp immediate that
-    // isn't too nasty (eg NaN, denormal, ...).
-    if ((st == APFloat::opOK || st == APFloat::opInexact) && // Not too nasty
-        (!LegalOperations ||
-         // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
-         // backend)... we should handle this gracefully after Legalize.
-         // TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT) ||
-         TLI.isOperationLegal(llvm::ISD::ConstantFP, VT) ||
-         TLI.isFPImmLegal(Recip, VT)))
-      return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N0,
-                         DAG.getConstantFP(Recip, VT));
+  if (Options.UnsafeFPMath) {
+    // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable.
+    if (N1CFP) {
+      // Compute the reciprocal 1.0 / c2.
+      APFloat N1APF = N1CFP->getValueAPF();
+      APFloat Recip(N1APF.getSemantics(), 1); // 1.0
+      APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven);
+      // Only do the transform if the reciprocal is a legal fp immediate that
+      // isn't too nasty (eg NaN, denormal, ...).
+      if ((st == APFloat::opOK || st == APFloat::opInexact) && // Not too nasty
+          (!LegalOperations ||
+           // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
+           // backend)... we should handle this gracefully after Legalize.
+           // TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT) ||
+           TLI.isOperationLegal(llvm::ISD::ConstantFP, VT) ||
+           TLI.isFPImmLegal(Recip, VT)))
+        return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N0,
+                           DAG.getConstantFP(Recip, VT));
+    }
+    // If this FDIV is part of a reciprocal square root, it may be folded
+    // into a target-specific square root estimate instruction.
+    if (SDValue SqrtOp = BuildRSQRTE(N))
+      return SqrtOp;
   }
 
   // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
   }
 
   // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
@@ -11695,6 +11702,44 @@ SDValue DAGCombiner::BuildUDIV(SDNode *N) {
   return S;
 }
 
   return S;
 }
 
+/// Given an ISD::FDIV node with either a direct or indirect ISD::FSQRT operand,
+/// generate a DAG expression using a reciprocal square root estimate op.
+SDValue DAGCombiner::BuildRSQRTE(SDNode *N) {
+  // Expose the DAG combiner to the target combiner implementations.
+  TargetLowering::DAGCombinerInfo DCI(DAG, Level, false, this);
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  SDValue N1 = N->getOperand(1);
+
+  if (N1.getOpcode() == ISD::FSQRT) {
+    SDValue RV = TLI.BuildRSQRTE(N1.getOperand(0), DCI);
+    if (RV.getNode()) {
+      DCI.AddToWorklist(RV.getNode());
+      return DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV);
+    }
+  } else if (N1.getOpcode() == ISD::FP_EXTEND &&
+             N1.getOperand(0).getOpcode() == ISD::FSQRT) {
+    SDValue RV = TLI.BuildRSQRTE(N1.getOperand(0).getOperand(0), DCI);
+    if (RV.getNode()) {
+      DCI.AddToWorklist(RV.getNode());
+      RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
+      DCI.AddToWorklist(RV.getNode());
+      return DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV);
+    }
+  } else if (N1.getOpcode() == ISD::FP_ROUND &&
+             N1.getOperand(0).getOpcode() == ISD::FSQRT) {
+    SDValue RV = TLI.BuildRSQRTE(N1.getOperand(0).getOperand(0), DCI);
+    if (RV.getNode()) {
+      DCI.AddToWorklist(RV.getNode());
+      RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
+      DCI.AddToWorklist(RV.getNode());
+      return DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV);
+    }
+  }
+
+  return SDValue();
+}
+
 /// Return true if base is a frame index, which is known not to alias with
 /// anything but itself.  Provides base object and offset as results.
 static bool FindBaseOffset(SDValue Ptr, SDValue &Base, int64_t &Offset,
 /// Return true if base is a frame index, which is known not to alias with
 /// anything but itself.  Provides base object and offset as results.
 static bool FindBaseOffset(SDValue Ptr, SDValue &Base, int64_t &Offset,
index fd188fe37e2e6214c721b040d2d708a2372bdd9b..d96cdab604faf7616c4cb816c663b25644343253 100644 (file)
@@ -7489,8 +7489,7 @@ SDValue PPCTargetLowering::DAGCombineFastRecip(SDValue Op,
   return SDValue();
 }
 
   return SDValue();
 }
 
-SDValue PPCTargetLowering::DAGCombineFastRecipFSQRT(SDValue Op,
-                                             DAGCombinerInfo &DCI) const {
+SDValue PPCTargetLowering::BuildRSQRTE(SDValue Op, DAGCombinerInfo &DCI) const {
   if (DCI.isAfterLegalizeVectorOps())
     return SDValue();
 
   if (DCI.isAfterLegalizeVectorOps())
     return SDValue();
 
@@ -8289,43 +8288,6 @@ SDValue PPCTargetLowering::PerformDAGCombine(SDNode *N,
     assert(TM.Options.UnsafeFPMath &&
            "Reciprocal estimates require UnsafeFPMath");
 
     assert(TM.Options.UnsafeFPMath &&
            "Reciprocal estimates require UnsafeFPMath");
 
-    if (N->getOperand(1).getOpcode() == ISD::FSQRT) {
-      SDValue RV =
-        DAGCombineFastRecipFSQRT(N->getOperand(1).getOperand(0), DCI);
-      if (RV.getNode()) {
-        DCI.AddToWorklist(RV.getNode());
-        return DAG.getNode(ISD::FMUL, dl, N->getValueType(0),
-                           N->getOperand(0), RV);
-      }
-    } else if (N->getOperand(1).getOpcode() == ISD::FP_EXTEND &&
-               N->getOperand(1).getOperand(0).getOpcode() == ISD::FSQRT) {
-      SDValue RV =
-        DAGCombineFastRecipFSQRT(N->getOperand(1).getOperand(0).getOperand(0),
-                                 DCI);
-      if (RV.getNode()) {
-        DCI.AddToWorklist(RV.getNode());
-        RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N->getOperand(1)),
-                         N->getValueType(0), RV);
-        DCI.AddToWorklist(RV.getNode());
-        return DAG.getNode(ISD::FMUL, dl, N->getValueType(0),
-                           N->getOperand(0), RV);
-      }
-    } else if (N->getOperand(1).getOpcode() == ISD::FP_ROUND &&
-               N->getOperand(1).getOperand(0).getOpcode() == ISD::FSQRT) {
-      SDValue RV =
-        DAGCombineFastRecipFSQRT(N->getOperand(1).getOperand(0).getOperand(0),
-                                 DCI);
-      if (RV.getNode()) {
-        DCI.AddToWorklist(RV.getNode());
-        RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N->getOperand(1)),
-                         N->getValueType(0), RV,
-                         N->getOperand(1).getOperand(1));
-        DCI.AddToWorklist(RV.getNode());
-        return DAG.getNode(ISD::FMUL, dl, N->getValueType(0),
-                           N->getOperand(0), RV);
-      }
-    }
-
     SDValue RV = DAGCombineFastRecip(N->getOperand(1), DCI);
     if (RV.getNode()) {
       DCI.AddToWorklist(RV.getNode());
     SDValue RV = DAGCombineFastRecip(N->getOperand(1), DCI);
     if (RV.getNode()) {
       DCI.AddToWorklist(RV.getNode());
@@ -8341,7 +8303,7 @@ SDValue PPCTargetLowering::PerformDAGCombine(SDNode *N,
 
     // Compute this as 1/(1/sqrt(X)), which is the reciprocal of the
     // reciprocal sqrt.
 
     // Compute this as 1/(1/sqrt(X)), which is the reciprocal of the
     // reciprocal sqrt.
-    SDValue RV = DAGCombineFastRecipFSQRT(N->getOperand(0), DCI);
+    SDValue RV = BuildRSQRTE(N->getOperand(0), DCI);
     if (RV.getNode()) {
       DCI.AddToWorklist(RV.getNode());
       RV = DAGCombineFastRecip(RV, DCI);
     if (RV.getNode()) {
       DCI.AddToWorklist(RV.getNode());
       RV = DAGCombineFastRecip(RV, DCI);
index c53dc83fa8a9f1dd5e8c9d79ee28479775cca63f..5628bc79342de218b18ebc735495b094511b0377 100644 (file)
@@ -696,7 +696,7 @@ namespace llvm {
     SDValue DAGCombineExtBoolTrunc(SDNode *N, DAGCombinerInfo &DCI) const;
     SDValue DAGCombineTruncBoolExt(SDNode *N, DAGCombinerInfo &DCI) const;
     SDValue DAGCombineFastRecip(SDValue Op, DAGCombinerInfo &DCI) const;
     SDValue DAGCombineExtBoolTrunc(SDNode *N, DAGCombinerInfo &DCI) const;
     SDValue DAGCombineTruncBoolExt(SDNode *N, DAGCombinerInfo &DCI) const;
     SDValue DAGCombineFastRecip(SDValue Op, DAGCombinerInfo &DCI) const;
-    SDValue DAGCombineFastRecipFSQRT(SDValue Op, DAGCombinerInfo &DCI) const;
+    SDValue BuildRSQRTE(SDValue Op, DAGCombinerInfo &DCI) const;
 
     CCAssignFn *useFastISelCCs(unsigned Flag) const;
   };
 
     CCAssignFn *useFastISelCCs(unsigned Flag) const;
   };