Make LowerSIGN_EXTEND_INREG split 256-bit vectors when AVX1 is enabled and use AVX2...
authorCraig Topper <craig.topper@gmail.com>
Mon, 21 Nov 2011 01:12:36 +0000 (01:12 +0000)
committerCraig Topper <craig.topper@gmail.com>
Mon, 21 Nov 2011 01:12:36 +0000 (01:12 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@145022 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp
test/CodeGen/X86/avx-shift.ll
test/CodeGen/X86/avx2-shift.ll

index 470a115747895f9f27e26d357c8b85717f9589c2..4ba4b9357184231bfb13790432fc5874d50de85a 100644 (file)
@@ -10571,9 +10571,9 @@ SDValue X86TargetLowering::LowerXALUO(SDValue Op, SelectionDAG &DAG) const {
 
 SDValue X86TargetLowering::LowerSIGN_EXTEND_INREG(SDValue Op, SelectionDAG &DAG) const{
   DebugLoc dl = Op.getDebugLoc();
-  SDNode* Node = Op.getNode();
-  EVT ExtraVT = cast<VTSDNode>(Node->getOperand(1))->getVT();
-  EVT VT = Node->getValueType(0);
+  EVT ExtraVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
+  EVT VT = Op.getValueType();
+
   if (Subtarget->hasXMMInt() && VT.isVector()) {
     unsigned BitsDiff = VT.getScalarType().getSizeInBits() -
                         ExtraVT.getScalarType().getSizeInBits();
@@ -10584,21 +10584,55 @@ SDValue X86TargetLowering::LowerSIGN_EXTEND_INREG(SDValue Op, SelectionDAG &DAG)
     switch (VT.getSimpleVT().SimpleTy) {
       default:
         return SDValue();
-      case MVT::v4i32: {
+      case MVT::v4i32:
         SHLIntrinsicsID = Intrinsic::x86_sse2_pslli_d;
         SRAIntrinsicsID = Intrinsic::x86_sse2_psrai_d;
         break;
-      }
-      case MVT::v8i16: {
+      case MVT::v8i16:
         SHLIntrinsicsID = Intrinsic::x86_sse2_pslli_w;
         SRAIntrinsicsID = Intrinsic::x86_sse2_psrai_w;
         break;
-      }
+      case MVT::v8i32:
+      case MVT::v16i16:
+        if (!Subtarget->hasAVX())
+          return SDValue();
+        if (!Subtarget->hasAVX2()) {
+          // needs to be split
+          int NumElems = VT.getVectorNumElements();
+          SDValue Idx0 = DAG.getConstant(0, MVT::i32);
+          SDValue Idx1 = DAG.getConstant(NumElems/2, MVT::i32);
+
+          // Extract the LHS vectors
+          SDValue LHS = Op.getOperand(0);
+          SDValue LHS1 = Extract128BitVector(LHS, Idx0, DAG, dl);
+          SDValue LHS2 = Extract128BitVector(LHS, Idx1, DAG, dl);
+
+          MVT EltVT = VT.getVectorElementType().getSimpleVT();
+          EVT NewVT = MVT::getVectorVT(EltVT, NumElems/2);
+
+          EVT ExtraEltVT = ExtraVT.getVectorElementType();
+          int ExtraNumElems = ExtraVT.getVectorNumElements();
+          ExtraVT = EVT::getVectorVT(*DAG.getContext(), ExtraEltVT,
+                                     ExtraNumElems/2);
+          SDValue Extra = DAG.getValueType(ExtraVT);
+
+          LHS1 = DAG.getNode(Op.getOpcode(), dl, NewVT, LHS1, Extra);
+          LHS2 = DAG.getNode(Op.getOpcode(), dl, NewVT, LHS2, Extra);
+
+          return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, LHS1, LHS2);;
+        }
+        if (VT == MVT::v8i32) {
+          SHLIntrinsicsID = Intrinsic::x86_avx2_pslli_d;
+          SRAIntrinsicsID = Intrinsic::x86_avx2_psrai_d;
+        } else {
+          SHLIntrinsicsID = Intrinsic::x86_avx2_pslli_w;
+          SRAIntrinsicsID = Intrinsic::x86_avx2_psrai_w;
+        }
     }
 
     SDValue Tmp1 = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
                          DAG.getConstant(SHLIntrinsicsID, MVT::i32),
-                         Node->getOperand(0), ShAmt);
+                         Op.getOperand(0), ShAmt);
 
     return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
                        DAG.getConstant(SRAIntrinsicsID, MVT::i32),
index a33423d7c553432e41432933ea74ceb90ce233be..681747b844a03491b9e1bbb83b1aa8091522c642 100644 (file)
@@ -112,3 +112,27 @@ define <8 x i32> @vshift08(<8 x i32> %a) nounwind {
   ret <8 x i32> %bitop
 }
 
+;;; Uses shifts for sign extension
+; CHECK: _sext_v16i16
+; CHECK: vpsllw
+; CHECK: vpsraw
+; CHECK: vpsllw
+; CHECK: vpsraw
+; CHECK: vinsertf128
+define <16 x i16> @sext_v16i16(<16 x i16> %a) nounwind {
+  %b = trunc <16 x i16> %a to <16 x i8>
+  %c = sext <16 x i8> %b to <16 x i16>
+  ret <16 x i16> %c
+}
+
+; CHECK: _sext_v8i32
+; CHECK: vpslld
+; CHECK: vpsrad
+; CHECK: vpslld
+; CHECK: vpsrad
+; CHECK: vinsertf128
+define <8 x i32> @sext_v8i32(<8 x i32> %a) nounwind {
+  %b = trunc <8 x i32> %a to <8 x i16>
+  %c = sext <8 x i16> %b to <8 x i32>
+  ret <8 x i32> %c
+}
index b9d1edcb139e0e8fe4b621004bbf29ea2bf6c27d..b6cf54ebe8fada1a416d1d90b0f0ef9c388a3007 100644 (file)
@@ -246,3 +246,23 @@ define <32 x i8> @sra_v32i8(<32 x i8> %A) nounwind {
 ; CHECK: vpsubb
 ; CHECK: ret
 }
+
+; CHECK: _sext_v16i16
+; CHECK: vpsllw
+; CHECK: vpsraw
+; CHECK-NOT: vinsertf128
+define <16 x i16> @sext_v16i16(<16 x i16> %a) nounwind {
+  %b = trunc <16 x i16> %a to <16 x i8>
+  %c = sext <16 x i8> %b to <16 x i16>
+  ret <16 x i16> %c
+}
+
+; CHECK: _sext_v8i32
+; CHECK: vpslld
+; CHECK: vpsrad
+; CHECK-NOT: vinsertf128
+define <8 x i32> @sext_v8i32(<8 x i32> %a) nounwind {
+  %b = trunc <8 x i32> %a to <8 x i16>
+  %c = sext <8 x i16> %b to <8 x i32>
+  ret <8 x i32> %c
+}