From: Matt Arsenault Date: Thu, 24 Jul 2014 06:59:20 +0000 (+0000) Subject: R600: Fix LowerSDIV24 X-Git-Url: http://plrg.eecs.uci.edu/git/?a=commitdiff_plain;h=f303d037f248109e01eda9fd95d4e92d2efb8116;p=oota-llvm.git R600: Fix LowerSDIV24 Use ComputeNumSignBits instead of checking for i8 / i16 which only worked when AMDIL was lying about having legal i8 / i16. If an integer is known to fit in 24-bits, we can do division faster with float ops. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@213843 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Target/R600/AMDGPUISelLowering.cpp b/lib/Target/R600/AMDGPUISelLowering.cpp index ffd6357a9a0..1eccaafe8b9 100644 --- a/lib/Target/R600/AMDGPUISelLowering.cpp +++ b/lib/Target/R600/AMDGPUISelLowering.cpp @@ -250,7 +250,7 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(TargetMachine &TM) : const MVT ScalarIntVTs[] = { MVT::i32, MVT::i64 }; for (MVT VT : ScalarIntVTs) { setOperationAction(ISD::SREM, VT, Expand); - setOperationAction(ISD::SDIV, VT, Expand); + setOperationAction(ISD::SDIV, VT, Custom); // GPU does not have divrem function for signed or unsigned. setOperationAction(ISD::SDIVREM, VT, Custom); @@ -1272,85 +1272,83 @@ SDValue AMDGPUTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const { return SDValue(); } +// This is a shortcut for integer division because we have fast i32<->f32 +// conversions, and fast f32 reciprocal instructions. The fractional part of a +// float is enough to accurately represent up to a 24-bit integer. SDValue AMDGPUTargetLowering::LowerSDIV24(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); - EVT OVT = Op.getValueType(); + EVT VT = Op.getValueType(); SDValue LHS = Op.getOperand(0); SDValue RHS = Op.getOperand(1); - MVT INTTY; - MVT FLTTY; - if (!OVT.isVector()) { - INTTY = MVT::i32; - FLTTY = MVT::f32; - } else if (OVT.getVectorNumElements() == 2) { - INTTY = MVT::v2i32; - FLTTY = MVT::v2f32; - } else if (OVT.getVectorNumElements() == 4) { - INTTY = MVT::v4i32; - FLTTY = MVT::v4f32; + MVT IntVT = MVT::i32; + MVT FltVT = MVT::f32; + + if (VT.isVector()) { + unsigned NElts = VT.getVectorNumElements(); + IntVT = MVT::getVectorVT(MVT::i32, NElts); + FltVT = MVT::getVectorVT(MVT::f32, NElts); } - unsigned bitsize = OVT.getScalarType().getSizeInBits(); + + unsigned BitSize = VT.getScalarType().getSizeInBits(); + // char|short jq = ia ^ ib; - SDValue jq = DAG.getNode(ISD::XOR, DL, OVT, LHS, RHS); + SDValue jq = DAG.getNode(ISD::XOR, DL, VT, LHS, RHS); // jq = jq >> (bitsize - 2) - jq = DAG.getNode(ISD::SRA, DL, OVT, jq, DAG.getConstant(bitsize - 2, OVT)); + jq = DAG.getNode(ISD::SRA, DL, VT, jq, DAG.getConstant(BitSize - 2, VT)); // jq = jq | 0x1 - jq = DAG.getNode(ISD::OR, DL, OVT, jq, DAG.getConstant(1, OVT)); + jq = DAG.getNode(ISD::OR, DL, VT, jq, DAG.getConstant(1, VT)); // jq = (int)jq - jq = DAG.getSExtOrTrunc(jq, DL, INTTY); + jq = DAG.getSExtOrTrunc(jq, DL, IntVT); // int ia = (int)LHS; - SDValue ia = DAG.getSExtOrTrunc(LHS, DL, INTTY); + SDValue ia = DAG.getSExtOrTrunc(LHS, DL, IntVT); // int ib, (int)RHS; - SDValue ib = DAG.getSExtOrTrunc(RHS, DL, INTTY); + SDValue ib = DAG.getSExtOrTrunc(RHS, DL, IntVT); // float fa = (float)ia; - SDValue fa = DAG.getNode(ISD::SINT_TO_FP, DL, FLTTY, ia); + SDValue fa = DAG.getNode(ISD::SINT_TO_FP, DL, FltVT, ia); // float fb = (float)ib; - SDValue fb = DAG.getNode(ISD::SINT_TO_FP, DL, FLTTY, ib); + SDValue fb = DAG.getNode(ISD::SINT_TO_FP, DL, FltVT, ib); // float fq = native_divide(fa, fb); - SDValue fq = DAG.getNode(ISD::FMUL, DL, FLTTY, - fa, DAG.getNode(AMDGPUISD::RCP, DL, FLTTY, fb)); + SDValue fq = DAG.getNode(ISD::FMUL, DL, FltVT, + fa, DAG.getNode(AMDGPUISD::RCP, DL, FltVT, fb)); // fq = trunc(fq); - fq = DAG.getNode(ISD::FTRUNC, DL, FLTTY, fq); + fq = DAG.getNode(ISD::FTRUNC, DL, FltVT, fq); // float fqneg = -fq; - SDValue fqneg = DAG.getNode(ISD::FNEG, DL, FLTTY, fq); + SDValue fqneg = DAG.getNode(ISD::FNEG, DL, FltVT, fq); // float fr = mad(fqneg, fb, fa); - SDValue fr = DAG.getNode(ISD::FADD, DL, FLTTY, - DAG.getNode(ISD::MUL, DL, FLTTY, fqneg, fb), fa); + SDValue fr = DAG.getNode(ISD::FADD, DL, FltVT, + DAG.getNode(ISD::FMUL, DL, FltVT, fqneg, fb), fa); // int iq = (int)fq; - SDValue iq = DAG.getNode(ISD::FP_TO_SINT, DL, INTTY, fq); + SDValue iq = DAG.getNode(ISD::FP_TO_SINT, DL, IntVT, fq); // fr = fabs(fr); - fr = DAG.getNode(ISD::FABS, DL, FLTTY, fr); + fr = DAG.getNode(ISD::FABS, DL, FltVT, fr); // fb = fabs(fb); - fb = DAG.getNode(ISD::FABS, DL, FLTTY, fb); + fb = DAG.getNode(ISD::FABS, DL, FltVT, fb); + + EVT SetCCVT = getSetCCResultType(*DAG.getContext(), VT); // int cv = fr >= fb; - SDValue cv; - if (INTTY == MVT::i32) { - cv = DAG.getSetCC(DL, INTTY, fr, fb, ISD::SETOGE); - } else { - cv = DAG.getSetCC(DL, INTTY, fr, fb, ISD::SETOGE); - } + SDValue cv = DAG.getSetCC(DL, SetCCVT, fr, fb, ISD::SETOGE); + // jq = (cv ? jq : 0); - jq = DAG.getNode(ISD::SELECT, DL, OVT, cv, jq, - DAG.getConstant(0, OVT)); + jq = DAG.getNode(ISD::SELECT, DL, VT, cv, jq, DAG.getConstant(0, VT)); + // dst = iq + jq; - iq = DAG.getSExtOrTrunc(iq, DL, OVT); - iq = DAG.getNode(ISD::ADD, DL, OVT, iq, jq); - return iq; + iq = DAG.getSExtOrTrunc(iq, DL, VT); + return DAG.getNode(ISD::ADD, DL, VT, iq, jq); } SDValue AMDGPUTargetLowering::LowerSDIV32(SDValue Op, SelectionDAG &DAG) const { @@ -1425,19 +1423,20 @@ SDValue AMDGPUTargetLowering::LowerSDIV64(SDValue Op, SelectionDAG &DAG) const { SDValue AMDGPUTargetLowering::LowerSDIV(SDValue Op, SelectionDAG &DAG) const { EVT OVT = Op.getValueType().getScalarType(); - if (OVT == MVT::i64) - return LowerSDIV64(Op, DAG); + if (OVT == MVT::i32) { + if (DAG.ComputeNumSignBits(Op.getOperand(0)) > 8 && + DAG.ComputeNumSignBits(Op.getOperand(1)) > 8) { + // TODO: We technically could do this for i64, but shouldn't that just be + // handled by something generally reducing 64-bit division on 32-bit + // values to 32-bit? + return LowerSDIV24(Op, DAG); + } - if (OVT.getScalarType() == MVT::i32) return LowerSDIV32(Op, DAG); - - if (OVT == MVT::i16 || OVT == MVT::i8) { - // FIXME: We should be checking for the masked bits. This isn't reached - // because i8 and i16 are not legal types. - return LowerSDIV24(Op, DAG); } - return SDValue(Op.getNode(), 0); + assert(OVT == MVT::i64); + return LowerSDIV64(Op, DAG); } SDValue AMDGPUTargetLowering::LowerSREM32(SDValue Op, SelectionDAG &DAG) const { diff --git a/test/CodeGen/R600/sdiv24.ll b/test/CodeGen/R600/sdiv24.ll new file mode 100644 index 00000000000..84c9ecbbfda --- /dev/null +++ b/test/CodeGen/R600/sdiv24.ll @@ -0,0 +1,120 @@ +; RUN: llc -march=r600 -mcpu=SI < %s | FileCheck -check-prefix=SI -check-prefix=FUNC %s +; RUN: llc -march=r600 -mcpu=redwood < %s | FileCheck -check-prefix=EG -check-prefix=FUNC %s + +; FUNC-LABEL: @sdiv24_i8 +; SI: V_CVT_F32_I32 +; SI: V_CVT_F32_I32 +; SI: V_RCP_F32 +; SI: V_CVT_I32_F32 + +; EG: INT_TO_FLT +; EG-DAG: INT_TO_FLT +; EG-DAG: RECIP_IEEE +; EG: FLT_TO_INT +define void @sdiv24_i8(i8 addrspace(1)* %out, i8 addrspace(1)* %in) { + %den_ptr = getelementptr i8 addrspace(1)* %in, i8 1 + %num = load i8 addrspace(1) * %in + %den = load i8 addrspace(1) * %den_ptr + %result = sdiv i8 %num, %den + store i8 %result, i8 addrspace(1)* %out + ret void +} + +; FUNC-LABEL: @sdiv24_i16 +; SI: V_CVT_F32_I32 +; SI: V_CVT_F32_I32 +; SI: V_RCP_F32 +; SI: V_CVT_I32_F32 + +; EG: INT_TO_FLT +; EG-DAG: INT_TO_FLT +; EG-DAG: RECIP_IEEE +; EG: FLT_TO_INT +define void @sdiv24_i16(i16 addrspace(1)* %out, i16 addrspace(1)* %in) { + %den_ptr = getelementptr i16 addrspace(1)* %in, i16 1 + %num = load i16 addrspace(1) * %in, align 2 + %den = load i16 addrspace(1) * %den_ptr, align 2 + %result = sdiv i16 %num, %den + store i16 %result, i16 addrspace(1)* %out, align 2 + ret void +} + +; FUNC-LABEL: @sdiv24_i32 +; SI: V_CVT_F32_I32 +; SI: V_CVT_F32_I32 +; SI: V_RCP_F32 +; SI: V_CVT_I32_F32 + +; EG: INT_TO_FLT +; EG-DAG: INT_TO_FLT +; EG-DAG: RECIP_IEEE +; EG: FLT_TO_INT +define void @sdiv24_i32(i32 addrspace(1)* %out, i32 addrspace(1)* %in) { + %den_ptr = getelementptr i32 addrspace(1)* %in, i32 1 + %num = load i32 addrspace(1) * %in, align 4 + %den = load i32 addrspace(1) * %den_ptr, align 4 + %num.i24.0 = shl i32 %num, 8 + %den.i24.0 = shl i32 %den, 8 + %num.i24 = ashr i32 %num.i24.0, 8 + %den.i24 = ashr i32 %den.i24.0, 8 + %result = sdiv i32 %num.i24, %den.i24 + store i32 %result, i32 addrspace(1)* %out, align 4 + ret void +} + +; FUNC-LABEL: @sdiv25_i32 +; SI-NOT: V_CVT_F32_I32 +; SI-NOT: V_RCP_F32 + +; EG-NOT: INT_TO_FLT +; EG-NOT: RECIP_IEEE +define void @sdiv25_i32(i32 addrspace(1)* %out, i32 addrspace(1)* %in) { + %den_ptr = getelementptr i32 addrspace(1)* %in, i32 1 + %num = load i32 addrspace(1) * %in, align 4 + %den = load i32 addrspace(1) * %den_ptr, align 4 + %num.i24.0 = shl i32 %num, 7 + %den.i24.0 = shl i32 %den, 7 + %num.i24 = ashr i32 %num.i24.0, 7 + %den.i24 = ashr i32 %den.i24.0, 7 + %result = sdiv i32 %num.i24, %den.i24 + store i32 %result, i32 addrspace(1)* %out, align 4 + ret void +} + +; FUNC-LABEL: @test_no_sdiv24_i32_1 +; SI-NOT: V_CVT_F32_I32 +; SI-NOT: V_RCP_F32 + +; EG-NOT: INT_TO_FLT +; EG-NOT: RECIP_IEEE +define void @test_no_sdiv24_i32_1(i32 addrspace(1)* %out, i32 addrspace(1)* %in) { + %den_ptr = getelementptr i32 addrspace(1)* %in, i32 1 + %num = load i32 addrspace(1) * %in, align 4 + %den = load i32 addrspace(1) * %den_ptr, align 4 + %num.i24.0 = shl i32 %num, 8 + %den.i24.0 = shl i32 %den, 7 + %num.i24 = ashr i32 %num.i24.0, 8 + %den.i24 = ashr i32 %den.i24.0, 7 + %result = sdiv i32 %num.i24, %den.i24 + store i32 %result, i32 addrspace(1)* %out, align 4 + ret void +} + +; FUNC-LABEL: @test_no_sdiv24_i32_2 +; SI-NOT: V_CVT_F32_I32 +; SI-NOT: V_RCP_F32 + +; EG-NOT: INT_TO_FLT +; EG-NOT: RECIP_IEEE +define void @test_no_sdiv24_i32_2(i32 addrspace(1)* %out, i32 addrspace(1)* %in) { + %den_ptr = getelementptr i32 addrspace(1)* %in, i32 1 + %num = load i32 addrspace(1) * %in, align 4 + %den = load i32 addrspace(1) * %den_ptr, align 4 + %num.i24.0 = shl i32 %num, 7 + %den.i24.0 = shl i32 %den, 8 + %num.i24 = ashr i32 %num.i24.0, 7 + %den.i24 = ashr i32 %den.i24.0, 8 + %result = sdiv i32 %num.i24, %den.i24 + store i32 %result, i32 addrspace(1)* %out, align 4 + ret void +}