Teach the DAG combiner to turn chains of FADDs (x+x+x+x+...) into FMULs by constants...
authorOwen Anderson <resistor@mac.com>
Thu, 30 Aug 2012 23:35:16 +0000 (23:35 +0000)
committerOwen Anderson <resistor@mac.com>
Thu, 30 Aug 2012 23:35:16 +0000 (23:35 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@162956 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/SelectionDAG/DAGCombiner.cpp
test/CodeGen/X86/fp-fast.ll [new file with mode: 0644]

index 4e29879bef19d77ba66d727d67b73abceac7b9be..d338ac81ed5c513890cb6c878280143280111074 100644 (file)
@@ -5681,6 +5681,127 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
                        DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
                                    N0.getOperand(1), N1));
 
+  // In unsafe math mode, we can fold chains of FADD's of the same value
+  // into multiplications.  This transform is not safe in general because
+  // we are reducing the number of rounding steps.
+  if (DAG.getTarget().Options.UnsafeFPMath &&
+      TLI.isOperationLegalOrCustom(ISD::FMUL, VT) &&
+      !N0CFP && !N1CFP) {
+    if (N0.getOpcode() == ISD::FMUL) {
+      ConstantFPSDNode *CFP00 = dyn_cast<ConstantFPSDNode>(N0.getOperand(0));
+      ConstantFPSDNode *CFP01 = dyn_cast<ConstantFPSDNode>(N0.getOperand(1));
+
+      // (fadd (fmul c, x), x) -> (fmul c+1, x)
+      if (CFP00 && !CFP01 && N0.getOperand(1) == N1) {
+        SDValue NewCFP = DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
+                                     SDValue(CFP00, 0),
+                                     DAG.getConstantFP(1.0, VT));
+        return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
+                           N1, NewCFP);
+      }
+
+      // (fadd (fmul x, c), x) -> (fmul c+1, x)
+      if (CFP01 && !CFP00 && N0.getOperand(0) == N1) {
+        SDValue NewCFP = DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
+                                     SDValue(CFP01, 0),
+                                     DAG.getConstantFP(1.0, VT));
+        return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
+                           N1, NewCFP);
+      }
+
+      // (fadd (fadd x, x), x) -> (fmul 3.0, x)
+      if (!CFP00 && !CFP01 && N0.getOperand(0) == N0.getOperand(1) &&
+          N0.getOperand(0) == N1) {
+        return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
+                           N1, DAG.getConstantFP(3.0, VT));
+      }
+
+      // (fadd (fmul c, x), (fadd x, x)) -> (fmul c+2, x)
+      if (CFP00 && !CFP01 && N1.getOpcode() == ISD::FADD &&
+          N1.getOperand(0) == N1.getOperand(1) &&
+          N0.getOperand(1) == N1.getOperand(0)) {
+        SDValue NewCFP = DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
+                                     SDValue(CFP00, 0),
+                                     DAG.getConstantFP(2.0, VT));
+        return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
+                           N0.getOperand(1), NewCFP);
+      }
+
+      // (fadd (fmul x, c), (fadd x, x)) -> (fmul c+2, x)
+      if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
+          N1.getOperand(0) == N1.getOperand(1) &&
+          N0.getOperand(0) == N1.getOperand(0)) {
+        SDValue NewCFP = DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
+                                     SDValue(CFP01, 0),
+                                     DAG.getConstantFP(2.0, VT));
+        return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
+                           N0.getOperand(0), NewCFP);
+      }
+    }
+
+    if (N1.getOpcode() == ISD::FMUL) {
+      ConstantFPSDNode *CFP10 = dyn_cast<ConstantFPSDNode>(N1.getOperand(0));
+      ConstantFPSDNode *CFP11 = dyn_cast<ConstantFPSDNode>(N1.getOperand(1));
+
+      // (fadd x, (fmul c, x)) -> (fmul c+1, x)
+      if (CFP10 && !CFP11 && N1.getOperand(1) == N0) {
+        SDValue NewCFP = DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
+                                     SDValue(CFP10, 0),
+                                     DAG.getConstantFP(1.0, VT));
+        return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
+                           N0, NewCFP);
+      }
+
+      // (fadd x, (fmul x, c)) -> (fmul c+1, x)
+      if (CFP11 && !CFP10 && N1.getOperand(0) == N0) {
+        SDValue NewCFP = DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
+                                     SDValue(CFP11, 0),
+                                     DAG.getConstantFP(1.0, VT));
+        return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
+                           N0, NewCFP);
+      }
+
+      // (fadd x, (fadd x, x)) -> (fmul 3.0, x)
+      if (!CFP10 && !CFP11 && N1.getOperand(0) == N1.getOperand(1) &&
+          N1.getOperand(0) == N0) {
+        return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
+                           N0, DAG.getConstantFP(3.0, VT));
+      }
+
+      // (fadd (fadd x, x), (fmul c, x)) -> (fmul c+2, x)
+      if (CFP10 && !CFP11 && N1.getOpcode() == ISD::FADD &&
+          N1.getOperand(0) == N1.getOperand(1) &&
+          N0.getOperand(1) == N1.getOperand(0)) {
+        SDValue NewCFP = DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
+                                     SDValue(CFP10, 0),
+                                     DAG.getConstantFP(2.0, VT));
+        return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
+                           N0.getOperand(1), NewCFP);
+      }
+
+      // (fadd (fadd x, x), (fmul x, c)) -> (fmul c+2, x)
+      if (CFP11 && !CFP10 && N1.getOpcode() == ISD::FADD &&
+          N1.getOperand(0) == N1.getOperand(1) &&
+          N0.getOperand(0) == N1.getOperand(0)) {
+        SDValue NewCFP = DAG.getNode(ISD::FADD, N->getDebugLoc(), VT,
+                                     SDValue(CFP11, 0),
+                                     DAG.getConstantFP(2.0, VT));
+        return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
+                           N0.getOperand(0), NewCFP);
+      }
+    }
+
+    // (fadd (fadd x, x), (fadd x, x)) -> (fmul 4.0, x)
+    if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
+        N0.getOperand(0) == N0.getOperand(1) &&
+        N1.getOperand(0) == N1.getOperand(1) &&
+        N0.getOperand(0) == N1.getOperand(0)) {
+      return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
+                         N0.getOperand(0),
+                         DAG.getConstantFP(4.0, VT));
+    }
+  }
+
   // FADD -> FMA combines:
   if ((DAG.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast ||
        DAG.getTarget().Options.UnsafeFPMath) &&
@@ -5692,7 +5813,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
       return DAG.getNode(ISD::FMA, N->getDebugLoc(), VT,
                          N0.getOperand(0), N0.getOperand(1), N1);
     }
-  
+
     // fold (fadd x, (fmul y, z)) -> (fma x, y, z)
     // Note: Commutes FADD operands.
     if (N1.getOpcode() == ISD::FMUL && N1->hasOneUse()) {
diff --git a/test/CodeGen/X86/fp-fast.ll b/test/CodeGen/X86/fp-fast.ll
new file mode 100644 (file)
index 0000000..61f59b4
--- /dev/null
@@ -0,0 +1,37 @@
+; RUN: llc -march=x86-64 -mtriple=x86_64-apple-darwin -enable-unsafe-fp-math < %s | FileCheck %s
+
+; CHECK: test1
+define float @test1(float %a) {
+; CHECK-NOT: vaddss
+; CHECK: vmulss
+; CHECK-NOT: vaddss
+; CHECK: ret
+  %t1 = fadd float %a, %a
+  %r = fadd float %t1, %t1
+  ret float %r
+}
+
+; CHECK: test2
+define float @test2(float %a) {
+; CHECK-NOT: vaddss
+; CHECK: vmulss
+; CHECK-NOT: vaddss
+; CHECK: ret
+  %t1 = fmul float 4.0, %a
+  %t2 = fadd float %a, %a
+  %r = fadd float %t1, %t2
+  ret float %r
+}
+
+; CHECK: test3
+define float @test3(float %a) {
+; CHECK-NOT: vaddss
+; CHECK: vxorps
+; CHECK-NOT: vaddss
+; CHECK: ret
+  %t1 = fmul float 2.0, %a
+  %t2 = fadd float %a, %a
+  %r = fsub float %t1, %t2
+  ret float %r
+}
+