From: Andrea Di Biagio
Date: Sat, 26 Apr 2014 01:03:22 +0000 (+0000)
Subject: [InstCombine][X86] Teach how to fold calls to SSE2/AVX2 packed logical shift
X-Git-Url: http://plrg.eecs.uci.edu/git/?a=commitdiff_plain;h=96db9b8ed87e502801e3dda7d13896acd17d8128;p=oota-llvm.git

[InstCombine][X86] Teach how to fold calls to SSE2/AVX2 packed logical shift
right intrinsics.

A packed logical shift right with a shift count greater than or equal to the
element size always produces a zero vector. In all other cases, it can be
safely replaced by an 'lshr' instruction.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@207299 91177308-0d34-0410-b5e6-96231b3b80d8
---

diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 17ada47d2be..df217f19acd 100644
--- a/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -570,8 +570,21 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
   case Intrinsic::x86_avx2_psll_w:
   case Intrinsic::x86_avx2_pslli_d:
   case Intrinsic::x86_avx2_pslli_q:
-  case Intrinsic::x86_avx2_pslli_w: {
-    // Simplify if count is constant. To 0 if > BitWidth, otherwise to shl.
+  case Intrinsic::x86_avx2_pslli_w:
+  case Intrinsic::x86_sse2_psrl_d:
+  case Intrinsic::x86_sse2_psrl_q:
+  case Intrinsic::x86_sse2_psrl_w:
+  case Intrinsic::x86_sse2_psrli_d:
+  case Intrinsic::x86_sse2_psrli_q:
+  case Intrinsic::x86_sse2_psrli_w:
+  case Intrinsic::x86_avx2_psrl_d:
+  case Intrinsic::x86_avx2_psrl_q:
+  case Intrinsic::x86_avx2_psrl_w:
+  case Intrinsic::x86_avx2_psrli_d:
+  case Intrinsic::x86_avx2_psrli_q:
+  case Intrinsic::x86_avx2_psrli_w: {
+    // Simplify if count is constant. To 0 if >= BitWidth,
+    // otherwise to shl/lshr.
     auto CDV = dyn_cast<ConstantDataVector>(II->getArgOperand(1));
     auto CInt = dyn_cast<ConstantInt>(II->getArgOperand(1));
     if (!CDV && !CInt)
@@ -588,14 +601,33 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
         VT->getElementType()->getPrimitiveSizeInBits() - 1)
       return ReplaceInstUsesWith(
           CI, ConstantAggregateZero::get(Vec->getType()));
-    else {
-      unsigned VWidth = VT->getNumElements();
-      // Get a constant vector of the same type as the first operand.
-      auto VTCI = ConstantInt::get(VT->getElementType(), Count->getZExtValue());
-      return BinaryOperator::CreateShl(
-          Vec, Builder->CreateVectorSplat(VWidth, VTCI));
+
+    bool isPackedShiftLeft = true;
+    switch (II->getIntrinsicID()) {
+    default : break;
+    case Intrinsic::x86_sse2_psrl_d:
+    case Intrinsic::x86_sse2_psrl_q:
+    case Intrinsic::x86_sse2_psrl_w:
+    case Intrinsic::x86_sse2_psrli_d:
+    case Intrinsic::x86_sse2_psrli_q:
+    case Intrinsic::x86_sse2_psrli_w:
+    case Intrinsic::x86_avx2_psrl_d:
+    case Intrinsic::x86_avx2_psrl_q:
+    case Intrinsic::x86_avx2_psrl_w:
+    case Intrinsic::x86_avx2_psrli_d:
+    case Intrinsic::x86_avx2_psrli_q:
+    case Intrinsic::x86_avx2_psrli_w: isPackedShiftLeft = false; break;
     }
-    break;
+
+    unsigned VWidth = VT->getNumElements();
+    // Get a constant vector of the same type as the first operand.
+    auto VTCI = ConstantInt::get(VT->getElementType(), Count->getZExtValue());
+    if (isPackedShiftLeft)
+      return BinaryOperator::CreateShl(Vec,
+          Builder->CreateVectorSplat(VWidth, VTCI));
+
+    return BinaryOperator::CreateLShr(Vec,
+        Builder->CreateVectorSplat(VWidth, VTCI));
   }
   case Intrinsic::x86_sse41_pmovsxbw:
diff --git a/test/Transforms/InstCombine/vec_demanded_elts.ll b/test/Transforms/InstCombine/vec_demanded_elts.ll
index 35ba09313d6..a3e978141ad 100644
--- a/test/Transforms/InstCombine/vec_demanded_elts.ll
+++ b/test/Transforms/InstCombine/vec_demanded_elts.ll
@@ -358,7 +358,6 @@ define <2 x i64> @test_sse2_1() nounwind readnone uwtable {
   %15 = bitcast <4 x i32> %14 to <2 x i64>
   %16 = tail call <2 x i64> @llvm.x86.sse2.pslli.q(<2 x i64> %15, i32 %S)
   ret <2 x i64> %16
-; CHECK: test_sse2_1
 ; CHECK: ret <2 x i64>
 }
 
@@ -405,7 +404,6 @@ define <2 x i64> @test_sse2_0() nounwind readnone uwtable {
   %15 = bitcast <4 x i32> %14 to <2 x i64>
   %16 = tail call <2 x i64> @llvm.x86.sse2.pslli.q(<2 x i64> %15, i32 %S)
   ret <2 x i64> %16
-; CHECK: test_sse2_0
 ; CHECK: ret <2 x i64> zeroinitializer
 }
 
@@ -432,6 +430,97 @@ define <4 x i64> @test_avx2_0() nounwind readnone uwtable {
 ; CHECK: test_avx2_0
 ; CHECK: ret <4 x i64> zeroinitializer
 }
+define <2 x i64> @test_sse2_psrl_1() nounwind readnone uwtable {
+  %S = bitcast i32 1 to i32
+  %1 = zext i32 %S to i64
+  %2 = insertelement <2 x i64> undef, i64 %1, i32 0
+  %3 = insertelement <2 x i64> %2, i64 0, i32 1
+  %4 = bitcast <2 x i64> %3 to <8 x i16>
+  %5 = tail call <8 x i16> @llvm.x86.sse2.psrl.w(<8 x i16> , <8 x i16> %4)
+  %6 = bitcast <8 x i16> %5 to <4 x i32>
+  %7 = bitcast <2 x i64> %3 to <4 x i32>
+  %8 = tail call <4 x i32> @llvm.x86.sse2.psrl.d(<4 x i32> %6, <4 x i32> %7)
+  %9 = bitcast <4 x i32> %8 to <2 x i64>
+  %10 = tail call <2 x i64> @llvm.x86.sse2.psrl.q(<2 x i64> %9, <2 x i64> %3)
+  %11 = bitcast <2 x i64> %10 to <8 x i16>
+  %12 = tail call <8 x i16> @llvm.x86.sse2.psrli.w(<8 x i16> %11, i32 %S)
+  %13 = bitcast <8 x i16> %12 to <4 x i32>
+  %14 = tail call <4 x i32> @llvm.x86.sse2.psrli.d(<4 x i32> %13, i32 %S)
+  %15 = bitcast <4 x i32> %14 to <2 x i64>
+  %16 = tail call <2 x i64> @llvm.x86.sse2.psrli.q(<2 x i64> %15, i32 %S)
+  ret <2 x i64> %16
+; CHECK: test_sse2_psrl_1
+; CHECK: ret <2 x i64>
+}
+
+define <4 x i64> @test_avx2_psrl_1() nounwind readnone uwtable {
+  %S = bitcast i32 1 to i32
+  %1 = zext i32 %S to i64
+  %2 = insertelement <2 x i64> undef, i64 %1, i32 0
+  %3 = insertelement <2 x i64> %2, i64 0, i32 1
+  %4 = bitcast <2 x i64> %3 to <8 x i16>
+  %5 = tail call <16 x i16> @llvm.x86.avx2.psrl.w(<16 x i16> , <8 x i16> %4)
+  %6 = bitcast <16 x i16> %5 to <8 x i32>
+  %7 = bitcast <2 x i64> %3 to <4 x i32>
+  %8 = tail call <8 x i32> @llvm.x86.avx2.psrl.d(<8 x i32> %6, <4 x i32> %7)
+  %9 = bitcast <8 x i32> %8 to <4 x i64>
+  %10 = tail call <4 x i64> @llvm.x86.avx2.psrl.q(<4 x i64> %9, <2 x i64> %3)
+  %11 = bitcast <4 x i64> %10 to <16 x i16>
+  %12 = tail call <16 x i16> @llvm.x86.avx2.psrli.w(<16 x i16> %11, i32 %S)
+  %13 = bitcast <16 x i16> %12 to <8 x i32>
+  %14 = tail call <8 x i32> @llvm.x86.avx2.psrli.d(<8 x i32> %13, i32 %S)
+  %15 = bitcast <8 x i32> %14 to <4 x i64>
+  %16 = tail call <4 x i64> @llvm.x86.avx2.psrli.q(<4 x i64> %15, i32 %S)
+  ret <4 x i64> %16
+; CHECK: test_avx2_psrl_1
+; CHECK: ret <4 x i64>
+}
+
+define <2 x i64> @test_sse2_psrl_0() nounwind readnone uwtable {
+  %S = bitcast i32 128 to i32
+  %1 = zext i32 %S to i64
+  %2 = insertelement <2 x i64> undef, i64 %1, i32 0
+  %3 = insertelement <2 x i64> %2, i64 0, i32 1
+  %4 = bitcast <2 x i64> %3 to <8 x i16>
+  %5 = tail call <8 x i16> @llvm.x86.sse2.psrl.w(<8 x i16> , <8 x i16> %4)
+  %6 = bitcast <8 x i16> %5 to <4 x i32>
+  %7 = bitcast <2 x i64> %3 to <4 x i32>
+  %8 = tail call <4 x i32> @llvm.x86.sse2.psrl.d(<4 x i32> %6, <4 x i32> %7)
+  %9 = bitcast <4 x i32> %8 to <2 x i64>
+  %10 = tail call <2 x i64> @llvm.x86.sse2.psrl.q(<2 x i64> %9, <2 x i64> %3)
+  %11 = bitcast <2 x i64> %10 to <8 x i16>
+  %12 = tail call <8 x i16> @llvm.x86.sse2.psrli.w(<8 x i16> %11, i32 %S)
+  %13 = bitcast <8 x i16> %12 to <4 x i32>
+  %14 = tail call <4 x i32> @llvm.x86.sse2.psrli.d(<4 x i32> %13, i32 %S)
+  %15 = bitcast <4 x i32> %14 to <2 x i64>
+  %16 = tail call <2 x i64> @llvm.x86.sse2.psrli.q(<2 x i64> %15, i32 %S)
+  ret <2 x i64> %16
+; CHECK: test_sse2_psrl_0
+; CHECK: ret <2 x i64> zeroinitializer
+}
+
+define <4 x i64> @test_avx2_psrl_0() nounwind readnone uwtable {
+  %S = bitcast i32 128 to i32
+  %1 = zext i32 %S to i64
+  %2 = insertelement <2 x i64> undef, i64 %1, i32 0
+  %3 = insertelement <2 x i64> %2, i64 0, i32 1
+  %4 = bitcast <2 x i64> %3 to <8 x i16>
+  %5 = tail call <16 x i16> @llvm.x86.avx2.psrl.w(<16 x i16> , <8 x i16> %4)
+  %6 = bitcast <16 x i16> %5 to <8 x i32>
+  %7 = bitcast <2 x i64> %3 to <4 x i32>
+  %8 = tail call <8 x i32> @llvm.x86.avx2.psrl.d(<8 x i32> %6, <4 x i32> %7)
+  %9 = bitcast <8 x i32> %8 to <4 x i64>
+  %10 = tail call <4 x i64> @llvm.x86.avx2.psrl.q(<4 x i64> %9, <2 x i64> %3)
+  %11 = bitcast <4 x i64> %10 to <16 x i16>
+  %12 = tail call <16 x i16> @llvm.x86.avx2.psrli.w(<16 x i16> %11, i32 %S)
+  %13 = bitcast <16 x i16> %12 to <8 x i32>
+  %14 = tail call <8 x i32> @llvm.x86.avx2.psrli.d(<8 x i32> %13, i32 %S)
+  %15 = bitcast <8 x i32> %14 to <4 x i64>
+  %16 = tail call <4 x i64> @llvm.x86.avx2.psrli.q(<4 x i64> %15, i32 %S)
+  ret <4 x i64> %16
+; CHECK: test_avx2_psrl_0
+; CHECK: ret <4 x i64> zeroinitializer
+}
 
 declare <4 x i64> @llvm.x86.avx2.pslli.q(<4 x i64>, i32) #1
 declare <8 x i32> @llvm.x86.avx2.pslli.d(<8 x i32>, i32) #1
@@ -445,5 +534,17 @@ declare <8 x i16> @llvm.x86.sse2.pslli.w(<8 x i16>, i32) #1
 declare <2 x i64> @llvm.x86.sse2.psll.q(<2 x i64>, <2 x i64>) #1
 declare <4 x i32> @llvm.x86.sse2.psll.d(<4 x i32>, <4 x i32>) #1
 declare <8 x i16> @llvm.x86.sse2.psll.w(<8 x i16>, <8 x i16>) #1
+declare <4 x i64> @llvm.x86.avx2.psrli.q(<4 x i64>, i32) #1
+declare <8 x i32> @llvm.x86.avx2.psrli.d(<8 x i32>, i32) #1
+declare <16 x i16> @llvm.x86.avx2.psrli.w(<16 x i16>, i32) #1
+declare <4 x i64> @llvm.x86.avx2.psrl.q(<4 x i64>, <2 x i64>) #1
+declare <8 x i32> @llvm.x86.avx2.psrl.d(<8 x i32>, <4 x i32>) #1
+declare <16 x i16> @llvm.x86.avx2.psrl.w(<16 x i16>, <8 x i16>) #1
+declare <2 x i64> @llvm.x86.sse2.psrli.q(<2 x i64>, i32) #1
+declare <4 x i32> @llvm.x86.sse2.psrli.d(<4 x i32>, i32) #1
+declare <8 x i16> @llvm.x86.sse2.psrli.w(<8 x i16>, i32) #1
+declare <2 x i64> @llvm.x86.sse2.psrl.q(<2 x i64>, <2 x i64>) #1
+declare <4 x i32> @llvm.x86.sse2.psrl.d(<4 x i32>, <4 x i32>) #1
+declare <8 x i16> @llvm.x86.sse2.psrl.w(<8 x i16>, <8 x i16>) #1
 
 attributes #1 = { nounwind readnone }
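
For reference, a minimal IR sketch of the fold the patch introduces; the function name @example_psrli and the shift count 3 are made up for illustration, while the 'lshr'/zeroinitializer results follow directly from the patch above:

define <4 x i32> @example_psrli(<4 x i32> %v) {
  ; The immediate count 3 is smaller than the 32-bit element size, so with
  ; this patch instcombine rewrites the intrinsic call as a plain vector lshr:
  ;   %r = lshr <4 x i32> %v, <i32 3, i32 3, i32 3, i32 3>
  %r = tail call <4 x i32> @llvm.x86.sse2.psrli.d(<4 x i32> %v, i32 3)
  ret <4 x i32> %r
  ; A constant count of 32 or more would instead fold the call result to
  ; <4 x i32> zeroinitializer.
}

declare <4 x i32> @llvm.x86.sse2.psrli.d(<4 x i32>, i32)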