From 7be5dfd1a164707fbfc9bb49de23d68b6e15df44 Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Sat, 12 Nov 2011 09:58:49 +0000
Subject: [PATCH] Add more AVX2 shift lowering support. Move AVX2 variable
 shift to use patterns instead of custom lowering code.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@144457 91177308-0d34-0410-b5e6-96231b3b80d8
---
 lib/Target/X86/X86ISelLowering.cpp | 166 ++++++++++++++++++-----------
 lib/Target/X86/X86InstrSSE.td      |  49 +++++++++
 test/CodeGen/X86/avx-shift.ll      |  39 +++++++
 3 files changed, 192 insertions(+), 62 deletions(-)

diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index e77b1dfcd58..f1c80a2f4f4 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -924,10 +924,6 @@ X86TargetLowering::X86TargetLowering(X86TargetMachine &TM)
     // FIXME: Do we need to handle scalar-to-vector here?
     setOperationAction(ISD::MUL, MVT::v4i32, Legal);

-    // Can turn SHL into an integer multiply.
-    setOperationAction(ISD::SHL, MVT::v4i32, Custom);
-    setOperationAction(ISD::SHL, MVT::v16i8, Custom);
-
     setOperationAction(ISD::VSELECT, MVT::v2f64, Legal);
     setOperationAction(ISD::VSELECT, MVT::v2i64, Legal);
     setOperationAction(ISD::VSELECT, MVT::v16i8, Legal);
@@ -955,18 +951,32 @@ X86TargetLowering::X86TargetLowering(X86TargetMachine &TM)
   }

   if (Subtarget->hasXMMInt()) {
-    setOperationAction(ISD::SRL, MVT::v2i64, Custom);
-    setOperationAction(ISD::SRL, MVT::v4i32, Custom);
-    setOperationAction(ISD::SRL, MVT::v16i8, Custom);
     setOperationAction(ISD::SRL, MVT::v8i16, Custom);
+    setOperationAction(ISD::SRL, MVT::v16i8, Custom);

-    setOperationAction(ISD::SHL, MVT::v2i64, Custom);
-    setOperationAction(ISD::SHL, MVT::v4i32, Custom);
     setOperationAction(ISD::SHL, MVT::v8i16, Custom);
+    setOperationAction(ISD::SHL, MVT::v16i8, Custom);

-    setOperationAction(ISD::SRA, MVT::v4i32, Custom);
     setOperationAction(ISD::SRA, MVT::v8i16, Custom);
     setOperationAction(ISD::SRA, MVT::v16i8, Custom);
+
+    if (Subtarget->hasAVX2()) {
+      setOperationAction(ISD::SRL, MVT::v2i64, Legal);
+      setOperationAction(ISD::SRL, MVT::v4i32, Legal);
+
+      setOperationAction(ISD::SHL, MVT::v2i64, Legal);
+      setOperationAction(ISD::SHL, MVT::v4i32, Legal);
+
+      setOperationAction(ISD::SRA, MVT::v4i32, Legal);
+    } else {
+      setOperationAction(ISD::SRL, MVT::v2i64, Custom);
+      setOperationAction(ISD::SRL, MVT::v4i32, Custom);
+
+      setOperationAction(ISD::SHL, MVT::v2i64, Custom);
+      setOperationAction(ISD::SHL, MVT::v4i32, Custom);
+
+      setOperationAction(ISD::SRA, MVT::v4i32, Custom);
+    }
   }

   if (Subtarget->hasSSE42() || Subtarget->hasAVX())
@@ -1009,18 +1019,14 @@ X86TargetLowering::X86TargetLowering(X86TargetMachine &TM)
     setOperationAction(ISD::CONCAT_VECTORS, MVT::v32i8, Custom);
     setOperationAction(ISD::CONCAT_VECTORS, MVT::v16i16, Custom);

-    setOperationAction(ISD::SRL, MVT::v4i64, Custom);
-    setOperationAction(ISD::SRL, MVT::v8i32, Custom);
     setOperationAction(ISD::SRL, MVT::v16i16, Custom);
     setOperationAction(ISD::SRL, MVT::v32i8, Custom);

-    setOperationAction(ISD::SHL, MVT::v4i64, Custom);
-    setOperationAction(ISD::SHL, MVT::v8i32, Custom);
     setOperationAction(ISD::SHL, MVT::v16i16, Custom);
     setOperationAction(ISD::SHL, MVT::v32i8, Custom);

-    setOperationAction(ISD::SRA, MVT::v8i32, Custom);
     setOperationAction(ISD::SRA, MVT::v16i16, Custom);
+    setOperationAction(ISD::SRA, MVT::v32i8, Custom);

     setOperationAction(ISD::SETCC, MVT::v32i8, Custom);
     setOperationAction(ISD::SETCC, MVT::v16i16, Custom);
@@ -1053,6 +1059,14 @@ X86TargetLowering::X86TargetLowering(X86TargetMachine &TM)
     // Don't lower v32i8 because there is no 128-bit byte mul

     setOperationAction(ISD::VSELECT, MVT::v32i8, Legal);
+
+    setOperationAction(ISD::SRL, MVT::v4i64, Legal);
+    setOperationAction(ISD::SRL, MVT::v8i32, Legal);
+
+    setOperationAction(ISD::SHL, MVT::v4i64, Legal);
+    setOperationAction(ISD::SHL, MVT::v8i32, Legal);
+
+    setOperationAction(ISD::SRA, MVT::v8i32, Legal);
   } else {
     setOperationAction(ISD::ADD, MVT::v4i64, Custom);
     setOperationAction(ISD::ADD, MVT::v8i32, Custom);
@@ -1068,6 +1082,14 @@ X86TargetLowering::X86TargetLowering(X86TargetMachine &TM)
     setOperationAction(ISD::MUL, MVT::v8i32, Custom);
     setOperationAction(ISD::MUL, MVT::v16i16, Custom);
     // Don't lower v32i8 because there is no 128-bit byte mul
+
+    setOperationAction(ISD::SRL, MVT::v4i64, Custom);
+    setOperationAction(ISD::SRL, MVT::v8i32, Custom);
+
+    setOperationAction(ISD::SHL, MVT::v4i64, Custom);
+    setOperationAction(ISD::SHL, MVT::v8i32, Custom);
+
+    setOperationAction(ISD::SRA, MVT::v8i32, Custom);
   }

   // Custom lower several nodes for 256-bit types.
@@ -9510,6 +9532,14 @@ X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const

   // Fix vector shift instructions where the last operand is a non-immediate
   // i32 value.
+  case Intrinsic::x86_avx2_pslli_w:
+  case Intrinsic::x86_avx2_pslli_d:
+  case Intrinsic::x86_avx2_pslli_q:
+  case Intrinsic::x86_avx2_psrli_w:
+  case Intrinsic::x86_avx2_psrli_d:
+  case Intrinsic::x86_avx2_psrli_q:
+  case Intrinsic::x86_avx2_psrai_w:
+  case Intrinsic::x86_avx2_psrai_d:
   case Intrinsic::x86_sse2_pslli_w:
   case Intrinsic::x86_sse2_pslli_d:
   case Intrinsic::x86_sse2_pslli_q:
@@ -9557,6 +9587,30 @@ X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const
     case Intrinsic::x86_sse2_psrai_d:
       NewIntNo = Intrinsic::x86_sse2_psra_d;
       break;
+    case Intrinsic::x86_avx2_pslli_w:
+      NewIntNo = Intrinsic::x86_avx2_psll_w;
+      break;
+    case Intrinsic::x86_avx2_pslli_d:
+      NewIntNo = Intrinsic::x86_avx2_psll_d;
+      break;
+    case Intrinsic::x86_avx2_pslli_q:
+      NewIntNo = Intrinsic::x86_avx2_psll_q;
+      break;
+    case Intrinsic::x86_avx2_psrli_w:
+      NewIntNo = Intrinsic::x86_avx2_psrl_w;
+      break;
+    case Intrinsic::x86_avx2_psrli_d:
+      NewIntNo = Intrinsic::x86_avx2_psrl_d;
+      break;
+    case Intrinsic::x86_avx2_psrli_q:
+      NewIntNo = Intrinsic::x86_avx2_psrl_q;
+      break;
+    case Intrinsic::x86_avx2_psrai_w:
+      NewIntNo = Intrinsic::x86_avx2_psra_w;
+      break;
+    case Intrinsic::x86_avx2_psrai_d:
+      NewIntNo = Intrinsic::x86_avx2_psra_d;
+      break;
     default: {
       ShAmtVT = MVT::v2i32;
       switch (IntNo) {
@@ -10251,52 +10305,6 @@ SDValue X86TargetLowering::LowerShift(SDValue Op, SelectionDAG &DAG) const {
     }
   }

-  // AVX2 variable shifts
-  if (Subtarget->hasAVX2()) {
-    if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SHL)
-      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                         DAG.getConstant(Intrinsic::x86_avx2_psllv_d, MVT::i32),
-                         R, Amt);
-    if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SHL)
-      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                         DAG.getConstant(Intrinsic::x86_avx2_psllv_d_256, MVT::i32),
-                         R, Amt);
-    if (VT == MVT::v2i64 && Op->getOpcode() == ISD::SHL)
-      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                         DAG.getConstant(Intrinsic::x86_avx2_psllv_q, MVT::i32),
-                         R, Amt);
-    if (VT == MVT::v4i64 && Op->getOpcode() == ISD::SHL)
-      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                         DAG.getConstant(Intrinsic::x86_avx2_psllv_q_256, MVT::i32),
-                         R, Amt);
-
-    if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SRL)
-      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                         DAG.getConstant(Intrinsic::x86_avx2_psrlv_d, MVT::i32),
-                         R, Amt);
-    if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SRL)
-      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                         DAG.getConstant(Intrinsic::x86_avx2_psrlv_d_256, MVT::i32),
-                         R, Amt);
-    if (VT == MVT::v2i64 && Op->getOpcode() == ISD::SRL)
-      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                         DAG.getConstant(Intrinsic::x86_avx2_psrlv_q, MVT::i32),
-                         R, Amt);
-    if (VT == MVT::v4i64 && Op->getOpcode() == ISD::SRL)
-      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                         DAG.getConstant(Intrinsic::x86_avx2_psrlv_q_256, MVT::i32),
-                         R, Amt);
-
-    if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SRA)
-      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                         DAG.getConstant(Intrinsic::x86_avx2_psrav_d, MVT::i32),
-                         R, Amt);
-    if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SRA)
-      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                         DAG.getConstant(Intrinsic::x86_avx2_psrav_d_256, MVT::i32),
-                         R, Amt);
-  }
-
   // Lower SHL with variable shift amount.
   if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SHL) {
     Op = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
@@ -13464,7 +13472,9 @@ static SDValue PerformShiftCombine(SDNode* N, SelectionDAG &DAG,
   if (!Subtarget->hasXMMInt())
     return SDValue();

-  if (VT != MVT::v2i64 && VT != MVT::v4i32 && VT != MVT::v8i16)
+  if (VT != MVT::v2i64 && VT != MVT::v4i32 && VT != MVT::v8i16 &&
+      (!Subtarget->hasAVX2() ||
+       (VT != MVT::v4i64 && VT != MVT::v8i32 && VT != MVT::v16i16)))
     return SDValue();

   SDValue ShAmtOp = N->getOperand(1);
@@ -13537,6 +13547,18 @@ static SDValue PerformShiftCombine(SDNode* N, SelectionDAG &DAG,
       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
                          DAG.getConstant(Intrinsic::x86_sse2_pslli_w, MVT::i32),
                          ValOp, BaseShAmt);
+    if (VT == MVT::v4i64)
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(Intrinsic::x86_avx2_pslli_q, MVT::i32),
+                         ValOp, BaseShAmt);
+    if (VT == MVT::v8i32)
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(Intrinsic::x86_avx2_pslli_d, MVT::i32),
+                         ValOp, BaseShAmt);
+    if (VT == MVT::v16i16)
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(Intrinsic::x86_avx2_pslli_w, MVT::i32),
+                         ValOp, BaseShAmt);
     break;
   case ISD::SRA:
     if (VT == MVT::v4i32)
@@ -13547,6 +13569,14 @@ static SDValue PerformShiftCombine(SDNode* N, SelectionDAG &DAG,
       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
                          DAG.getConstant(Intrinsic::x86_sse2_psrai_w, MVT::i32),
                          ValOp, BaseShAmt);
+    if (VT == MVT::v8i32)
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(Intrinsic::x86_avx2_psrai_d, MVT::i32),
+                         ValOp, BaseShAmt);
+    if (VT == MVT::v16i16)
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(Intrinsic::x86_avx2_psrai_w, MVT::i32),
+                         ValOp, BaseShAmt);
     break;
   case ISD::SRL:
     if (VT == MVT::v2i64)
@@ -13561,6 +13591,18 @@ static SDValue PerformShiftCombine(SDNode* N, SelectionDAG &DAG,
       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
                          DAG.getConstant(Intrinsic::x86_sse2_psrli_w, MVT::i32),
                          ValOp, BaseShAmt);
+    if (VT == MVT::v4i64)
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(Intrinsic::x86_avx2_psrli_q, MVT::i32),
+                         ValOp, BaseShAmt);
+    if (VT == MVT::v8i32)
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(Intrinsic::x86_avx2_psrli_d, MVT::i32),
+                         ValOp, BaseShAmt);
+    if (VT == MVT::v16i16)
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(Intrinsic::x86_avx2_psrli_w, MVT::i32),
+                         ValOp, BaseShAmt);
     break;
   }
   return SDValue();
diff --git a/lib/Target/X86/X86InstrSSE.td b/lib/Target/X86/X86InstrSSE.td
index 10f527cdeda..735a30f0dde 100644
--- a/lib/Target/X86/X86InstrSSE.td
+++ b/lib/Target/X86/X86InstrSSE.td
@@ -7713,3 +7713,52 @@ defm VPSRLVQ : avx2_var_shift_i64<0x45, "vpsrlvq", int_x86_avx2_psrlv_q,
                                  int_x86_avx2_psrlv_q_256>;
 defm VPSRAVD : avx2_var_shift<0x46, "vpsravd", int_x86_avx2_psrav_d,
                               int_x86_avx2_psrav_d_256>;
+let Predicates = [HasAVX2] in {
+  def : Pat<(v4i32 (shl (v4i32 VR128:$src1), (v4i32 VR128:$src2))),
+            (VPSLLVDrr VR128:$src1, VR128:$src2)>;
+  def : Pat<(v2i64 (shl (v2i64 VR128:$src1), (v2i64 VR128:$src2))),
+            (VPSLLVQrr VR128:$src1, VR128:$src2)>;
+  def : Pat<(v4i32 (srl (v4i32 VR128:$src1), (v4i32 VR128:$src2))),
+            (VPSRLVDrr VR128:$src1, VR128:$src2)>;
+  def : Pat<(v2i64 (srl (v2i64 VR128:$src1), (v2i64 VR128:$src2))),
+            (VPSRLVQrr VR128:$src1, VR128:$src2)>;
+  def : Pat<(v4i32 (sra (v4i32 VR128:$src1), (v4i32 VR128:$src2))),
+            (VPSRAVDrr VR128:$src1, VR128:$src2)>;
+  def : Pat<(v8i32 (shl (v8i32 VR256:$src1), (v8i32 VR256:$src2))),
+            (VPSLLVDYrr VR256:$src1, VR256:$src2)>;
+  def : Pat<(v4i64 (shl (v4i64 VR256:$src1), (v4i64 VR256:$src2))),
+            (VPSLLVQYrr VR256:$src1, VR256:$src2)>;
+  def : Pat<(v8i32 (srl (v8i32 VR256:$src1), (v8i32 VR256:$src2))),
+            (VPSRLVDYrr VR256:$src1, VR256:$src2)>;
+  def : Pat<(v4i64 (srl (v4i64 VR256:$src1), (v4i64 VR256:$src2))),
+            (VPSRLVQYrr VR256:$src1, VR256:$src2)>;
+  def : Pat<(v8i32 (sra (v8i32 VR256:$src1), (v8i32 VR256:$src2))),
+            (VPSRAVDYrr VR256:$src1, VR256:$src2)>;
+
+  def : Pat<(v4i32 (shl (v4i32 VR128:$src1),
+                   (v4i32 (bitconvert (memopv2i64 addr:$src2))))),
+            (VPSLLVDrm VR128:$src1, addr:$src2)>;
+  def : Pat<(v2i64 (shl (v2i64 VR128:$src1), (memopv2i64 addr:$src2))),
+            (VPSLLVQrm VR128:$src1, addr:$src2)>;
+  def : Pat<(v4i32 (srl (v4i32 VR128:$src1),
+                   (v4i32 (bitconvert (memopv2i64 addr:$src2))))),
+            (VPSRLVDrm VR128:$src1, addr:$src2)>;
+  def : Pat<(v2i64 (srl (v2i64 VR128:$src1), (memopv2i64 addr:$src2))),
+            (VPSRLVQrm VR128:$src1, addr:$src2)>;
+  def : Pat<(v4i32 (sra (v4i32 VR128:$src1),
+                   (v4i32 (bitconvert (memopv2i64 addr:$src2))))),
+            (VPSRAVDrm VR128:$src1, addr:$src2)>;
+  def : Pat<(v8i32 (shl (v8i32 VR256:$src1),
+                   (v8i32 (bitconvert (memopv4i64 addr:$src2))))),
+            (VPSLLVDYrm VR256:$src1, addr:$src2)>;
+  def : Pat<(v4i64 (shl (v4i64 VR256:$src1), (memopv4i64 addr:$src2))),
+            (VPSLLVQYrm VR256:$src1, addr:$src2)>;
+  def : Pat<(v8i32 (srl (v8i32 VR256:$src1),
+                   (v8i32 (bitconvert (memopv4i64 addr:$src2))))),
+            (VPSRLVDYrm VR256:$src1, addr:$src2)>;
+  def : Pat<(v4i64 (srl (v4i64 VR256:$src1), (memopv4i64 addr:$src2))),
+            (VPSRLVQYrm VR256:$src1, addr:$src2)>;
+  def : Pat<(v8i32 (sra (v8i32 VR256:$src1),
+                   (v8i32 (bitconvert (memopv4i64 addr:$src2))))),
+            (VPSRAVDYrm VR256:$src1, addr:$src2)>;
+}
diff --git a/test/CodeGen/X86/avx-shift.ll b/test/CodeGen/X86/avx-shift.ll
index 3ea39a2358e..a33423d7c55 100644
--- a/test/CodeGen/X86/avx-shift.ll
+++ b/test/CodeGen/X86/avx-shift.ll
@@ -62,6 +62,45 @@ define <16 x i16> @vshift07(<16 x i16> %a) nounwind readnone {
   ret <16 x i16> %s
 }

+; CHECK: vpsrlw
+; CHECK: pand
+; CHECK: pxor
+; CHECK: psubb
+; CHECK: vpsrlw
+; CHECK: pand
+; CHECK: pxor
+; CHECK: psubb
+define <32 x i8> @vshift09(<32 x i8> %a) nounwind readnone {
+  %s = ashr <32 x i8> %a, <i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2>
+  ret <32 x i8> %s
+}
+
+; CHECK: pxor
+; CHECK: pcmpgtb
+; CHECK: pcmpgtb
+define <32 x i8> @vshift10(<32 x i8> %a) nounwind readnone {
+  %s = ashr <32 x i8> %a, <i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7>
+  ret <32 x i8> %s
+}
+
+; CHECK: vpsrlw
+; CHECK: pand
+; CHECK: vpsrlw
+; CHECK: pand
+define <32 x i8> @vshift11(<32 x i8> %a) nounwind readnone {
+  %s = lshr <32 x i8> %a, <i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2>
+  ret <32 x i8> %s
+}
+
+; CHECK: vpsllw
+; CHECK: pand
+; CHECK: vpsllw
+; CHECK: pand
+define <32 x i8> @vshift12(<32 x i8> %a) nounwind readnone {
+  %s = shl <32 x i8> %a, <i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2>
+  ret <32 x i8> %s
+}
+
 ;;; Support variable shifts
 ; CHECK: _vshift08
 ; CHECK: vextractf128 $1
-- 
2.34.1
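
To see what the moved lowering buys in practice: with the Pat definitions above, a fully variable AVX2 shift is selected straight from the generic shl/srl/sra DAG nodes instead of going through the intrinsic-based custom code deleted from LowerShift. A minimal sketch of such a check follows (illustrative only, not one of the tests in this commit; the RUN line, cpu, and function name are assumptions modeled on the existing avx-shift.ll tests):

; RUN: llc < %s -mtriple=x86_64-apple-darwin -mcpu=core-avx2 | FileCheck %s

; A fully variable 256-bit left shift: the VPSLLVDYrr pattern should match
; the generic shl node directly and emit the AVX2 variable-shift instruction.
; CHECK: vpsllvd
define <8 x i32> @variable_shl(<8 x i32> %a, <8 x i32> %b) nounwind readnone {
  %s = shl <8 x i32> %a, %b
  ret <8 x i32> %s
}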