From a5f45d5444fc0768e8832b447ac1f23330d790cc Mon Sep 17 00:00:00 2001 From: Matt Arsenault Date: Mon, 29 Sep 2014 14:59:34 +0000 Subject: [PATCH] R600/SI: Fix using mad with multiplies by 2 These turn into fadds, so combine them into the target mad node. fadd (fadd (a, a), b) -> mad 2.0, a, b git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@218608 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Target/R600/SIISelLowering.cpp | 35 +++++++ test/CodeGen/R600/fmuladd.ll | 158 +++++++++++++++++++++++++++-- 2 files changed, 187 insertions(+), 6 deletions(-) diff --git a/lib/Target/R600/SIISelLowering.cpp b/lib/Target/R600/SIISelLowering.cpp index 417356d800f..10e6c17f6cc 100644 --- a/lib/Target/R600/SIISelLowering.cpp +++ b/lib/Target/R600/SIISelLowering.cpp @@ -226,6 +226,7 @@ SITargetLowering::SITargetLowering(TargetMachine &TM) : setOperationAction(ISD::FDIV, MVT::f32, Custom); + setTargetDAGCombine(ISD::FADD); setTargetDAGCombine(ISD::FSUB); setTargetDAGCombine(ISD::SELECT_CC); setTargetDAGCombine(ISD::SETCC); @@ -1418,6 +1419,40 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N, case ISD::UINT_TO_FP: { return performUCharToFloatCombine(N, DCI); + case ISD::FADD: { + if (DCI.getDAGCombineLevel() < AfterLegalizeDAG) + break; + + EVT VT = N->getValueType(0); + if (VT != MVT::f32) + break; + + SDValue LHS = N->getOperand(0); + SDValue RHS = N->getOperand(1); + + // These should really be instruction patterns, but writing patterns with + // source modiifiers is a pain. + + // fadd (fadd (a, a), b) -> mad 2.0, a, b + if (LHS.getOpcode() == ISD::FADD) { + SDValue A = LHS.getOperand(0); + if (A == LHS.getOperand(1)) { + const SDValue Two = DAG.getTargetConstantFP(2.0, MVT::f32); + return DAG.getNode(AMDGPUISD::MAD, DL, VT, Two, A, RHS); + } + } + + // fadd (b, fadd (a, a)) -> mad 2.0, a, b + if (RHS.getOpcode() == ISD::FADD) { + SDValue A = RHS.getOperand(0); + if (A == RHS.getOperand(1)) { + const SDValue Two = DAG.getTargetConstantFP(2.0, MVT::f32); + return DAG.getNode(AMDGPUISD::MAD, DL, VT, Two, A, LHS); + } + } + + break; + } case ISD::FSUB: { if (DCI.getDAGCombineLevel() < AfterLegalizeDAG) break; diff --git a/test/CodeGen/R600/fmuladd.ll b/test/CodeGen/R600/fmuladd.ll index 48944f62948..6f581f21d9b 100644 --- a/test/CodeGen/R600/fmuladd.ll +++ b/test/CodeGen/R600/fmuladd.ll @@ -1,6 +1,11 @@ -; RUN: llc < %s -march=r600 -mcpu=SI -verify-machineinstrs | FileCheck %s +; RUN: llc -march=r600 -mcpu=SI -verify-machineinstrs < %s | FileCheck %s -; CHECK: @fmuladd_f32 +declare float @llvm.fmuladd.f32(float, float, float) +declare double @llvm.fmuladd.f64(double, double, double) +declare i32 @llvm.r600.read.tidig.x() nounwind readnone +declare float @llvm.fabs.f32(float) nounwind readnone + +; CHECK-LABEL: @fmuladd_f32 ; CHECK: V_MAD_F32 {{v[0-9]+, v[0-9]+, v[0-9]+, v[0-9]+}} define void @fmuladd_f32(float addrspace(1)* %out, float addrspace(1)* %in1, @@ -13,9 +18,7 @@ define void @fmuladd_f32(float addrspace(1)* %out, float addrspace(1)* %in1, ret void } -declare float @llvm.fmuladd.f32(float, float, float) - -; CHECK: @fmuladd_f64 +; CHECK-LABEL: @fmuladd_f64 ; CHECK: V_FMA_F64 {{v\[[0-9]+:[0-9]+\], v\[[0-9]+:[0-9]+\], v\[[0-9]+:[0-9]+\], v\[[0-9]+:[0-9]+\]}} define void @fmuladd_f64(double addrspace(1)* %out, double addrspace(1)* %in1, @@ -28,4 +31,147 @@ define void @fmuladd_f64(double addrspace(1)* %out, double addrspace(1)* %in1, ret void } -declare double @llvm.fmuladd.f64(double, double, double) +; CHECK-LABEL: @fmuladd_2.0_a_b_f32 +; CHECK-DAG: BUFFER_LOAD_DWORD [[R1:v[0-9]+]], {{v\[[0-9]+:[0-9]+\]}}, {{s\[[0-9]+:[0-9]+\]}}, 0 addr64{{$}} +; CHECK-DAG: BUFFER_LOAD_DWORD [[R2:v[0-9]+]], {{v\[[0-9]+:[0-9]+\]}}, {{s\[[0-9]+:[0-9]+\]}}, 0 addr64 offset:0x4 +; CHECK: V_MAD_F32 [[RESULT:v[0-9]+]], 2.0, [[R1]], [[R2]] +; CHECK: BUFFER_STORE_DWORD [[RESULT]] +define void @fmuladd_2.0_a_b_f32(float addrspace(1)* %out, float addrspace(1)* %in) { + %tid = call i32 @llvm.r600.read.tidig.x() nounwind readnone + %gep.0 = getelementptr float addrspace(1)* %out, i32 %tid + %gep.1 = getelementptr float addrspace(1)* %gep.0, i32 1 + %gep.out = getelementptr float addrspace(1)* %out, i32 %tid + + %r1 = load float addrspace(1)* %gep.0 + %r2 = load float addrspace(1)* %gep.1 + + %r3 = tail call float @llvm.fmuladd.f32(float 2.0, float %r1, float %r2) + store float %r3, float addrspace(1)* %gep.out + ret void +} + +; CHECK-LABEL: @fmuladd_a_2.0_b_f32 +; CHECK-DAG: BUFFER_LOAD_DWORD [[R1:v[0-9]+]], {{v\[[0-9]+:[0-9]+\]}}, {{s\[[0-9]+:[0-9]+\]}}, 0 addr64{{$}} +; CHECK-DAG: BUFFER_LOAD_DWORD [[R2:v[0-9]+]], {{v\[[0-9]+:[0-9]+\]}}, {{s\[[0-9]+:[0-9]+\]}}, 0 addr64 offset:0x4 +; CHECK: V_MAD_F32 [[RESULT:v[0-9]+]], 2.0, [[R1]], [[R2]] +; CHECK: BUFFER_STORE_DWORD [[RESULT]] +define void @fmuladd_a_2.0_b_f32(float addrspace(1)* %out, float addrspace(1)* %in) { + %tid = call i32 @llvm.r600.read.tidig.x() nounwind readnone + %gep.0 = getelementptr float addrspace(1)* %out, i32 %tid + %gep.1 = getelementptr float addrspace(1)* %gep.0, i32 1 + %gep.out = getelementptr float addrspace(1)* %out, i32 %tid + + %r1 = load float addrspace(1)* %gep.0 + %r2 = load float addrspace(1)* %gep.1 + + %r3 = tail call float @llvm.fmuladd.f32(float %r1, float 2.0, float %r2) + store float %r3, float addrspace(1)* %gep.out + ret void +} + +; CHECK-LABEL: @fadd_a_a_b_f32 +; CHECK-DAG: BUFFER_LOAD_DWORD [[R1:v[0-9]+]], {{v\[[0-9]+:[0-9]+\]}}, {{s\[[0-9]+:[0-9]+\]}}, 0 addr64{{$}} +; CHECK-DAG: BUFFER_LOAD_DWORD [[R2:v[0-9]+]], {{v\[[0-9]+:[0-9]+\]}}, {{s\[[0-9]+:[0-9]+\]}}, 0 addr64 offset:0x4 +; CHECK: V_MAD_F32 [[RESULT:v[0-9]+]], 2.0, [[R1]], [[R2]] +; CHECK: BUFFER_STORE_DWORD [[RESULT]] +define void @fadd_a_a_b_f32(float addrspace(1)* %out, + float addrspace(1)* %in1, + float addrspace(1)* %in2) { + %tid = call i32 @llvm.r600.read.tidig.x() nounwind readnone + %gep.0 = getelementptr float addrspace(1)* %out, i32 %tid + %gep.1 = getelementptr float addrspace(1)* %gep.0, i32 1 + %gep.out = getelementptr float addrspace(1)* %out, i32 %tid + + %r0 = load float addrspace(1)* %gep.0 + %r1 = load float addrspace(1)* %gep.1 + + %add.0 = fadd float %r0, %r0 + %add.1 = fadd float %add.0, %r1 + store float %add.1, float addrspace(1)* %out + ret void +} + +; CHECK-LABEL: @fadd_b_a_a_f32 +; CHECK-DAG: BUFFER_LOAD_DWORD [[R1:v[0-9]+]], {{v\[[0-9]+:[0-9]+\]}}, {{s\[[0-9]+:[0-9]+\]}}, 0 addr64{{$}} +; CHECK-DAG: BUFFER_LOAD_DWORD [[R2:v[0-9]+]], {{v\[[0-9]+:[0-9]+\]}}, {{s\[[0-9]+:[0-9]+\]}}, 0 addr64 offset:0x4 +; CHECK: V_MAD_F32 [[RESULT:v[0-9]+]], 2.0, [[R1]], [[R2]] +; CHECK: BUFFER_STORE_DWORD [[RESULT]] +define void @fadd_b_a_a_f32(float addrspace(1)* %out, + float addrspace(1)* %in1, + float addrspace(1)* %in2) { + %tid = call i32 @llvm.r600.read.tidig.x() nounwind readnone + %gep.0 = getelementptr float addrspace(1)* %out, i32 %tid + %gep.1 = getelementptr float addrspace(1)* %gep.0, i32 1 + %gep.out = getelementptr float addrspace(1)* %out, i32 %tid + + %r0 = load float addrspace(1)* %gep.0 + %r1 = load float addrspace(1)* %gep.1 + + %add.0 = fadd float %r0, %r0 + %add.1 = fadd float %r1, %add.0 + store float %add.1, float addrspace(1)* %out + ret void +} + +; CHECK-LABEL: @fmuladd_neg_2.0_a_b_f32 +; CHECK-DAG: BUFFER_LOAD_DWORD [[R1:v[0-9]+]], {{v\[[0-9]+:[0-9]+\]}}, {{s\[[0-9]+:[0-9]+\]}}, 0 addr64{{$}} +; CHECK-DAG: BUFFER_LOAD_DWORD [[R2:v[0-9]+]], {{v\[[0-9]+:[0-9]+\]}}, {{s\[[0-9]+:[0-9]+\]}}, 0 addr64 offset:0x4 +; CHECK: V_MAD_F32 [[RESULT:v[0-9]+]], [[R1]], -2.0, [[R2]] +; CHECK: BUFFER_STORE_DWORD [[RESULT]] +define void @fmuladd_neg_2.0_a_b_f32(float addrspace(1)* %out, float addrspace(1)* %in) { + %tid = call i32 @llvm.r600.read.tidig.x() nounwind readnone + %gep.0 = getelementptr float addrspace(1)* %out, i32 %tid + %gep.1 = getelementptr float addrspace(1)* %gep.0, i32 1 + %gep.out = getelementptr float addrspace(1)* %out, i32 %tid + + %r1 = load float addrspace(1)* %gep.0 + %r2 = load float addrspace(1)* %gep.1 + + %r3 = tail call float @llvm.fmuladd.f32(float -2.0, float %r1, float %r2) + store float %r3, float addrspace(1)* %gep.out + ret void +} + + +; CHECK-LABEL: @fmuladd_neg_2.0_neg_a_b_f32 +; CHECK-DAG: BUFFER_LOAD_DWORD [[R1:v[0-9]+]], {{v\[[0-9]+:[0-9]+\]}}, {{s\[[0-9]+:[0-9]+\]}}, 0 addr64{{$}} +; CHECK-DAG: BUFFER_LOAD_DWORD [[R2:v[0-9]+]], {{v\[[0-9]+:[0-9]+\]}}, {{s\[[0-9]+:[0-9]+\]}}, 0 addr64 offset:0x4 +; CHECK: V_MAD_F32 [[RESULT:v[0-9]+]], 2.0, [[R1]], [[R2]] +; CHECK: BUFFER_STORE_DWORD [[RESULT]] +define void @fmuladd_neg_2.0_neg_a_b_f32(float addrspace(1)* %out, float addrspace(1)* %in) { + %tid = call i32 @llvm.r600.read.tidig.x() nounwind readnone + %gep.0 = getelementptr float addrspace(1)* %out, i32 %tid + %gep.1 = getelementptr float addrspace(1)* %gep.0, i32 1 + %gep.out = getelementptr float addrspace(1)* %out, i32 %tid + + %r1 = load float addrspace(1)* %gep.0 + %r2 = load float addrspace(1)* %gep.1 + + %r1.fneg = fsub float -0.000000e+00, %r1 + + %r3 = tail call float @llvm.fmuladd.f32(float -2.0, float %r1.fneg, float %r2) + store float %r3, float addrspace(1)* %gep.out + ret void +} + + +; CHECK-LABEL: @fmuladd_2.0_neg_a_b_f32 +; CHECK-DAG: BUFFER_LOAD_DWORD [[R1:v[0-9]+]], {{v\[[0-9]+:[0-9]+\]}}, {{s\[[0-9]+:[0-9]+\]}}, 0 addr64{{$}} +; CHECK-DAG: BUFFER_LOAD_DWORD [[R2:v[0-9]+]], {{v\[[0-9]+:[0-9]+\]}}, {{s\[[0-9]+:[0-9]+\]}}, 0 addr64 offset:0x4 +; CHECK: V_MAD_F32 [[RESULT:v[0-9]+]], [[R1]], -2.0, [[R2]] +; CHECK: BUFFER_STORE_DWORD [[RESULT]] +define void @fmuladd_2.0_neg_a_b_f32(float addrspace(1)* %out, float addrspace(1)* %in) { + %tid = call i32 @llvm.r600.read.tidig.x() nounwind readnone + %gep.0 = getelementptr float addrspace(1)* %out, i32 %tid + %gep.1 = getelementptr float addrspace(1)* %gep.0, i32 1 + %gep.out = getelementptr float addrspace(1)* %out, i32 %tid + + %r1 = load float addrspace(1)* %gep.0 + %r2 = load float addrspace(1)* %gep.1 + + %r1.fneg = fsub float -0.000000e+00, %r1 + + %r3 = tail call float @llvm.fmuladd.f32(float 2.0, float %r1.fneg, float %r2) + store float %r3, float addrspace(1)* %gep.out + ret void +} -- 2.34.1