+
+ case ISD::FADD: {
+ if (DCI.getDAGCombineLevel() < AfterLegalizeDAG)
+ break;
+
+ EVT VT = N->getValueType(0);
+ if (VT != MVT::f32)
+ break;
+
+ SDValue LHS = N->getOperand(0);
+ SDValue RHS = N->getOperand(1);
+
+ // These should really be instruction patterns, but writing patterns with
+ // source modiifiers is a pain.
+
+ // fadd (fadd (a, a), b) -> mad 2.0, a, b
+ if (LHS.getOpcode() == ISD::FADD) {
+ SDValue A = LHS.getOperand(0);
+ if (A == LHS.getOperand(1)) {
+ const SDValue Two = DAG.getTargetConstantFP(2.0, MVT::f32);
+ return DAG.getNode(AMDGPUISD::MAD, DL, VT, Two, A, RHS);
+ }
+ }
+
+ // fadd (b, fadd (a, a)) -> mad 2.0, a, b
+ if (RHS.getOpcode() == ISD::FADD) {
+ SDValue A = RHS.getOperand(0);
+ if (A == RHS.getOperand(1)) {
+ const SDValue Two = DAG.getTargetConstantFP(2.0, MVT::f32);
+ return DAG.getNode(AMDGPUISD::MAD, DL, VT, Two, A, LHS);
+ }
+ }
+
+ break;
+ }
+ case ISD::FSUB: {
+ if (DCI.getDAGCombineLevel() < AfterLegalizeDAG)
+ break;
+
+ EVT VT = N->getValueType(0);
+
+ // Try to get the fneg to fold into the source modifier. This undoes generic
+ // DAG combines and folds them into the mad.
+ if (VT == MVT::f32) {
+ SDValue LHS = N->getOperand(0);
+ SDValue RHS = N->getOperand(1);
+
+ if (LHS.getOpcode() == ISD::FMUL) {
+ // (fsub (fmul a, b), c) -> mad a, b, (fneg c)
+
+ SDValue A = LHS.getOperand(0);
+ SDValue B = LHS.getOperand(1);
+ SDValue C = DAG.getNode(ISD::FNEG, DL, VT, RHS);
+
+ return DAG.getNode(AMDGPUISD::MAD, DL, VT, A, B, C);
+ }
+
+ if (RHS.getOpcode() == ISD::FMUL) {
+ // (fsub c, (fmul a, b)) -> mad (fneg a), b, c
+
+ SDValue A = DAG.getNode(ISD::FNEG, DL, VT, RHS.getOperand(0));
+ SDValue B = RHS.getOperand(1);
+ SDValue C = LHS;
+
+ return DAG.getNode(AMDGPUISD::MAD, DL, VT, A, B, C);
+ }
+
+ if (LHS.getOpcode() == ISD::FADD) {
+ // (fsub (fadd a, a), c) -> mad 2.0, a, (fneg c)
+
+ SDValue A = LHS.getOperand(0);
+ if (A == LHS.getOperand(1)) {
+ const SDValue Two = DAG.getTargetConstantFP(2.0, MVT::f32);
+ SDValue NegRHS = DAG.getNode(ISD::FNEG, DL, VT, RHS);
+
+ return DAG.getNode(AMDGPUISD::MAD, DL, VT, Two, A, NegRHS);
+ }
+ }
+
+ if (RHS.getOpcode() == ISD::FADD) {
+ // (fsub c, (fadd a, a)) -> mad -2.0, a, c
+
+ SDValue A = RHS.getOperand(0);
+ if (A == RHS.getOperand(1)) {
+ const SDValue NegTwo = DAG.getTargetConstantFP(-2.0, MVT::f32);
+ return DAG.getNode(AMDGPUISD::MAD, DL, VT, NegTwo, A, LHS);
+ }
+ }
+ }
+
+ break;
+ }