X86 vector element shift-by-immediate instructions take i8 immediates. Make
authorLang Hames <lhames@gmail.com>
Mon, 21 Oct 2013 17:51:24 +0000 (17:51 +0000)
committerLang Hames <lhames@gmail.com>
Mon, 21 Oct 2013 17:51:24 +0000 (17:51 +0000)
the instruction defenitions and ISEL reflect this.

Prior to this patch these instructions took an i32i8imm, and the high bits were
dropped during encoding. This led to incorrect behavior for shifts by
immediates higher than 255. This patch fixes that issue by detecting large
immediate shifts and returning constant zero (for logical shifts) or capping
the shift amount at an encodable value (for arithmetic shifts).

Fixes <rdar://problem/14968098>

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@193096 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp
lib/Target/X86/X86InstrAVX512.td
lib/Target/X86/X86InstrSSE.td
test/CodeGen/X86/avx2-vector-shifts.ll
test/CodeGen/X86/sse2-vector-shifts.ll

index dea4a4616f38ef6ee98363d3e17c435062f475fe..e8918f4c34d2ce05ab6c9d6113cdc6113036e060 100644 (file)
@@ -10952,6 +10952,26 @@ static SDValue LowerVACOPY(SDValue Op, const X86Subtarget *Subtarget,
                        MachinePointerInfo(DstSV), MachinePointerInfo(SrcSV));
 }
 
+// getTargetVShiftByConstNode - Handle vector element shifts where the shift
+// amount is a constant. Takes immediate version of shift as input.
+static SDValue getTargetVShiftByConstNode(unsigned Opc, SDLoc dl, EVT VT,
+                                          SDValue SrcOp, uint64_t ShiftAmt,
+                                          SelectionDAG &DAG) {
+
+  // Check for ShiftAmt >= element width
+  if (ShiftAmt >= VT.getVectorElementType().getSizeInBits()) {
+    if (Opc == X86ISD::VSRAI)
+      ShiftAmt = VT.getVectorElementType().getSizeInBits() - 1;
+    else
+      return DAG.getConstant(0, VT);
+  }
+
+  assert((Opc == X86ISD::VSHLI || Opc == X86ISD::VSRLI || Opc == X86ISD::VSRAI)
+         && "Unknown target vector shift-by-constant node");
+
+  return DAG.getNode(Opc, dl, VT, SrcOp, DAG.getConstant(ShiftAmt, MVT::i8));
+}
+
 // getTargetVShiftNode - Handle vector element shifts where the shift amount
 // may or may not be a constant. Takes immediate version of shift as input.
 static SDValue getTargetVShiftNode(unsigned Opc, SDLoc dl, EVT VT,
@@ -10959,18 +10979,10 @@ static SDValue getTargetVShiftNode(unsigned Opc, SDLoc dl, EVT VT,
                                    SelectionDAG &DAG) {
   assert(ShAmt.getValueType() == MVT::i32 && "ShAmt is not i32");
 
-  if (isa<ConstantSDNode>(ShAmt)) {
-    // Constant may be a TargetConstant. Use a regular constant.
-    uint32_t ShiftAmt = cast<ConstantSDNode>(ShAmt)->getZExtValue();
-    switch (Opc) {
-      default: llvm_unreachable("Unknown target vector shift node");
-      case X86ISD::VSHLI:
-      case X86ISD::VSRLI:
-      case X86ISD::VSRAI:
-        return DAG.getNode(Opc, dl, VT, SrcOp,
-                           DAG.getConstant(ShiftAmt, MVT::i32));
-    }
-  }
+  // Catch shift-by-constant.
+  if (ConstantSDNode *CShAmt = dyn_cast<ConstantSDNode>(ShAmt))
+    return getTargetVShiftByConstNode(Opc, dl, VT, SrcOp,
+                                      CShAmt->getZExtValue(), DAG);
 
   // Change opcode to non-immediate version
   switch (Opc) {
@@ -12416,10 +12428,8 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget *Subtarget,
   //  AhiBlo = psllqi(AhiBlo, 32);
   //  return AloBlo + AloBhi + AhiBlo;
 
-  SDValue ShAmt = DAG.getConstant(32, MVT::i32);
-
-  SDValue Ahi = DAG.getNode(X86ISD::VSRLI, dl, VT, A, ShAmt);
-  SDValue Bhi = DAG.getNode(X86ISD::VSRLI, dl, VT, B, ShAmt);
+  SDValue Ahi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, A, 32, DAG);
+  SDValue Bhi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, B, 32, DAG);
 
   // Bit cast to 32-bit vectors for MULUDQ
   EVT MulVT = (VT == MVT::v2i64) ? MVT::v4i32 :
@@ -12433,8 +12443,8 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget *Subtarget,
   SDValue AloBhi = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, Bhi);
   SDValue AhiBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Ahi, B);
 
-  AloBhi = DAG.getNode(X86ISD::VSHLI, dl, VT, AloBhi, ShAmt);
-  AhiBlo = DAG.getNode(X86ISD::VSHLI, dl, VT, AhiBlo, ShAmt);
+  AloBhi = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AloBhi, 32, DAG);
+  AhiBlo = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AhiBlo, 32, DAG);
 
   SDValue Res = DAG.getNode(ISD::ADD, dl, VT, AloBlo, AloBhi);
   return DAG.getNode(ISD::ADD, dl, VT, Res, AhiBlo);
@@ -12462,7 +12472,7 @@ static SDValue LowerSDIV(SDValue Op, SelectionDAG &DAG) {
 
   if ((SplatValue != 0) &&
       (SplatValue.isPowerOf2() || (-SplatValue).isPowerOf2())) {
-    unsigned lg2 = SplatValue.countTrailingZeros();
+    unsigned Lg2 = SplatValue.countTrailingZeros();
     // Splat the sign bit.
     SmallVector<SDValue, 16> Sz(NumElts,
                                 DAG.getConstant(EltTy.getSizeInBits() - 1,
@@ -12472,13 +12482,13 @@ static SDValue LowerSDIV(SDValue Op, SelectionDAG &DAG) {
                                           NumElts));
     // Add (N0 < 0) ? abs2 - 1 : 0;
     SmallVector<SDValue, 16> Amt(NumElts,
-                                 DAG.getConstant(EltTy.getSizeInBits() - lg2,
+                                 DAG.getConstant(EltTy.getSizeInBits() - Lg2,
                                                  EltTy));
     SDValue SRL = DAG.getNode(ISD::SRL, dl, VT, SGN,
                               DAG.getNode(ISD::BUILD_VECTOR, dl, VT, &Amt[0],
                                           NumElts));
     SDValue ADD = DAG.getNode(ISD::ADD, dl, VT, N0, SRL);
-    SmallVector<SDValue, 16> Lg2Amt(NumElts, DAG.getConstant(lg2, EltTy));
+    SmallVector<SDValue, 16> Lg2Amt(NumElts, DAG.getConstant(Lg2, EltTy));
     SDValue SRA = DAG.getNode(ISD::SRA, dl, VT, ADD,
                               DAG.getNode(ISD::BUILD_VECTOR, dl, VT, &Lg2Amt[0],
                                           NumElts));
@@ -12514,21 +12524,22 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
           (Subtarget->hasAVX512() &&
            (VT == MVT::v8i64 || VT == MVT::v16i32))) {
         if (Op.getOpcode() == ISD::SHL)
-          return DAG.getNode(X86ISD::VSHLI, dl, VT, R,
-                             DAG.getConstant(ShiftAmt, MVT::i32));
+          return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt,
+                                            DAG);
         if (Op.getOpcode() == ISD::SRL)
-          return DAG.getNode(X86ISD::VSRLI, dl, VT, R,
-                             DAG.getConstant(ShiftAmt, MVT::i32));
+          return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt,
+                                            DAG);
         if (Op.getOpcode() == ISD::SRA && VT != MVT::v2i64 && VT != MVT::v4i64)
-          return DAG.getNode(X86ISD::VSRAI, dl, VT, R,
-                             DAG.getConstant(ShiftAmt, MVT::i32));
+          return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt,
+                                            DAG);
       }
 
       if (VT == MVT::v16i8) {
         if (Op.getOpcode() == ISD::SHL) {
           // Make a large shift.
-          SDValue SHL = DAG.getNode(X86ISD::VSHLI, dl, MVT::v8i16, R,
-                                    DAG.getConstant(ShiftAmt, MVT::i32));
+          SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl,
+                                                   MVT::v8i16, R, ShiftAmt,
+                                                   DAG); 
           SHL = DAG.getNode(ISD::BITCAST, dl, VT, SHL);
           // Zero out the rightmost bits.
           SmallVector<SDValue, 16> V(16,
@@ -12539,8 +12550,9 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
         }
         if (Op.getOpcode() == ISD::SRL) {
           // Make a large shift.
-          SDValue SRL = DAG.getNode(X86ISD::VSRLI, dl, MVT::v8i16, R,
-                                    DAG.getConstant(ShiftAmt, MVT::i32));
+          SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl,
+                                                   MVT::v8i16, R, ShiftAmt,
+                                                   DAG);
           SRL = DAG.getNode(ISD::BITCAST, dl, VT, SRL);
           // Zero out the leftmost bits.
           SmallVector<SDValue, 16> V(16,
@@ -12571,8 +12583,9 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
       if (Subtarget->hasInt256() && VT == MVT::v32i8) {
         if (Op.getOpcode() == ISD::SHL) {
           // Make a large shift.
-          SDValue SHL = DAG.getNode(X86ISD::VSHLI, dl, MVT::v16i16, R,
-                                    DAG.getConstant(ShiftAmt, MVT::i32));
+          SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl,
+                                                   MVT::v16i16, R, ShiftAmt,
+                                                   DAG);
           SHL = DAG.getNode(ISD::BITCAST, dl, VT, SHL);
           // Zero out the rightmost bits.
           SmallVector<SDValue, 32> V(32,
@@ -12583,8 +12596,9 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
         }
         if (Op.getOpcode() == ISD::SRL) {
           // Make a large shift.
-          SDValue SRL = DAG.getNode(X86ISD::VSRLI, dl, MVT::v16i16, R,
-                                    DAG.getConstant(ShiftAmt, MVT::i32));
+          SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl,
+                                                   MVT::v16i16, R, ShiftAmt,
+                                                   DAG);
           SRL = DAG.getNode(ISD::BITCAST, dl, VT, SRL);
           // Zero out the leftmost bits.
           SmallVector<SDValue, 32> V(32,
@@ -12649,14 +12663,14 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
     default:
       llvm_unreachable("Unknown shift opcode!");
     case ISD::SHL:
-      return DAG.getNode(X86ISD::VSHLI, dl, VT, R,
-                         DAG.getConstant(ShiftAmt, MVT::i32));
+      return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt,
+                                        DAG);
     case ISD::SRL:
-      return DAG.getNode(X86ISD::VSRLI, dl, VT, R,
-                         DAG.getConstant(ShiftAmt, MVT::i32));
+      return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt,
+                                        DAG);
     case ISD::SRA:
-      return DAG.getNode(X86ISD::VSRAI, dl, VT, R,
-                         DAG.getConstant(ShiftAmt, MVT::i32));
+      return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt,
+                                        DAG);
     }
   }
 
@@ -12869,8 +12883,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget,
 
     // r = VSELECT(r, psllw(r & (char16)15, 4), a);
     SDValue M = DAG.getNode(ISD::AND, dl, VT, R, CM1);
-    M = getTargetVShiftNode(X86ISD::VSHLI, dl, MVT::v8i16, M,
-                            DAG.getConstant(4, MVT::i32), DAG);
+    M = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, MVT::v8i16, M, 4, DAG);
     M = DAG.getNode(ISD::BITCAST, dl, VT, M);
     R = DAG.getNode(ISD::VSELECT, dl, VT, OpVSel, M, R);
 
@@ -12881,8 +12894,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget,
 
     // r = VSELECT(r, psllw(r & (char16)63, 2), a);
     M = DAG.getNode(ISD::AND, dl, VT, R, CM2);
-    M = getTargetVShiftNode(X86ISD::VSHLI, dl, MVT::v8i16, M,
-                            DAG.getConstant(2, MVT::i32), DAG);
+    M = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, MVT::v8i16, M, 2, DAG);
     M = DAG.getNode(ISD::BITCAST, dl, VT, M);
     R = DAG.getNode(ISD::VSELECT, dl, VT, OpVSel, M, R);
 
@@ -13025,7 +13037,6 @@ SDValue X86TargetLowering::LowerSIGN_EXTEND_INREG(SDValue Op,
 
   unsigned BitsDiff = VT.getScalarType().getSizeInBits() -
                       ExtraVT.getScalarType().getSizeInBits();
-  SDValue ShAmt = DAG.getConstant(BitsDiff, MVT::i32);
 
   switch (VT.getSimpleVT().SimpleTy) {
     default: return SDValue();
@@ -13075,8 +13086,10 @@ SDValue X86TargetLowering::LowerSIGN_EXTEND_INREG(SDValue Op,
       }
 
       // If the above didn't work, then just use Shift-Left + Shift-Right.
-      Tmp1 = getTargetVShiftNode(X86ISD::VSHLI, dl, VT, Op0, ShAmt, DAG);
-      return getTargetVShiftNode(X86ISD::VSRAI, dl, VT, Tmp1, ShAmt, DAG);
+      Tmp1 = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, Op0, BitsDiff,
+                                        DAG);
+      return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, Tmp1, BitsDiff,
+                                        DAG);
     }
   }
 }
index 3fd725cd95f9499dab3f031d58fe0d0983fe4128..05e346dec5ac1c6d27b44537d39b839bc8339d2a 100644 (file)
@@ -1845,22 +1845,22 @@ multiclass avx512_shift_rmi<bits<8> opc, Format ImmFormR, Format ImmFormM,
                          ValueType vt, X86MemOperand x86memop, PatFrag mem_frag,
                          RegisterClass KRC> {
   def ri : AVX512BIi8<opc, ImmFormR, (outs RC:$dst),
-       (ins RC:$src1, i32i8imm:$src2),
+       (ins RC:$src1, i8imm:$src2),
            !strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
-       [(set RC:$dst, (vt (OpNode RC:$src1, (i32 imm:$src2))))],
+       [(set RC:$dst, (vt (OpNode RC:$src1, (i8 imm:$src2))))],
         SSE_INTSHIFT_ITINS_P.rr>, EVEX_4V;
   def rik : AVX512BIi8<opc, ImmFormR, (outs RC:$dst),
-       (ins KRC:$mask, RC:$src1, i32i8imm:$src2),
+       (ins KRC:$mask, RC:$src1, i8imm:$src2),
            !strconcat(OpcodeStr,
                 "\t{$src2, $src1, $dst {${mask}}|$dst {${mask}}, $src1, $src2}"),
        [], SSE_INTSHIFT_ITINS_P.rr>, EVEX_4V, EVEX_K;
   def mi: AVX512BIi8<opc, ImmFormM, (outs RC:$dst),
-       (ins x86memop:$src1, i32i8imm:$src2),
+       (ins x86memop:$src1, i8imm:$src2),
            !strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
        [(set RC:$dst, (OpNode (mem_frag addr:$src1),
-                     (i32 imm:$src2)))], SSE_INTSHIFT_ITINS_P.rm>, EVEX_4V;
+                     (i8 imm:$src2)))], SSE_INTSHIFT_ITINS_P.rm>, EVEX_4V;
   def mik: AVX512BIi8<opc, ImmFormM, (outs RC:$dst),
-       (ins KRC:$mask, x86memop:$src1, i32i8imm:$src2),
+       (ins KRC:$mask, x86memop:$src1, i8imm:$src2),
            !strconcat(OpcodeStr,
                 "\t{$src2, $src1, $dst {${mask}}|$dst {${mask}}, $src1, $src2}"),
        [], SSE_INTSHIFT_ITINS_P.rm>, EVEX_4V, EVEX_K;
index c2f319704d296aa5e39c72aad6b7ac1a5bc427b5..f1bb9f84a720b895edb20d7aebf40d32e0657e77 100644 (file)
@@ -3744,11 +3744,11 @@ multiclass PDI_binop_rmi<bits<8> opc, bits<8> opc2, Format ImmForm,
                        (bc_frag (memopv2i64 addr:$src2)))))], itins.rm>,
       Sched<[WriteVecShiftLd, ReadAfterLd]>;
   def ri : PDIi8<opc2, ImmForm, (outs RC:$dst),
-       (ins RC:$src1, i32i8imm:$src2),
+       (ins RC:$src1, i8imm:$src2),
        !if(Is2Addr,
            !strconcat(OpcodeStr, "\t{$src2, $dst|$dst, $src2}"),
            !strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}")),
-       [(set RC:$dst, (DstVT (OpNode2 RC:$src1, (i32 imm:$src2))))], itins.ri>,
+       [(set RC:$dst, (DstVT (OpNode2 RC:$src1, (i8 imm:$src2))))], itins.ri>,
        Sched<[WriteVecShift]>;
 }
 
@@ -5064,12 +5064,12 @@ multiclass SS3I_unop_rm_int_y<bits<8> opc, string OpcodeStr,
 // Helper fragments to match sext vXi1 to vXiY.
 def v16i1sextv16i8 : PatLeaf<(v16i8 (X86pcmpgt (bc_v16i8 (v4i32 immAllZerosV)),
                                                VR128:$src))>;
-def v8i1sextv8i16  : PatLeaf<(v8i16 (X86vsrai VR128:$src, (i32 15)))>;
-def v4i1sextv4i32  : PatLeaf<(v4i32 (X86vsrai VR128:$src, (i32 31)))>;
+def v8i1sextv8i16  : PatLeaf<(v8i16 (X86vsrai VR128:$src, (i8 15)))>;
+def v4i1sextv4i32  : PatLeaf<(v4i32 (X86vsrai VR128:$src, (i8 31)))>;
 def v32i1sextv32i8 : PatLeaf<(v32i8 (X86pcmpgt (bc_v32i8 (v8i32 immAllZerosV)),
                                                VR256:$src))>;
-def v16i1sextv16i16: PatLeaf<(v16i16 (X86vsrai VR256:$src, (i32 15)))>;
-def v8i1sextv8i32  : PatLeaf<(v8i32 (X86vsrai VR256:$src, (i32 31)))>;
+def v16i1sextv16i16: PatLeaf<(v16i16 (X86vsrai VR256:$src, (i8 15)))>;
+def v8i1sextv8i32  : PatLeaf<(v8i32 (X86vsrai VR256:$src, (i8 31)))>;
 
 let Predicates = [HasAVX] in {
   defm VPABSB  : SS3I_unop_rm_int<0x1C, "vpabsb",
index a978d93fc558b87e69ad62cea98af048c8dff533..5592e6c8a5f7356c5d225695e32571579fd54502 100644 (file)
@@ -121,7 +121,7 @@ entry:
 }
 
 ; CHECK-LABEL: test_sraw_3:
-; CHECK: vpsraw  $16, %ymm0, %ymm0
+; CHECK: vpsraw  $15, %ymm0, %ymm0
 ; CHECK: ret
 
 define <8 x i32> @test_srad_1(<8 x i32> %InVec) {
@@ -151,7 +151,7 @@ entry:
 }
 
 ; CHECK-LABEL: test_srad_3:
-; CHECK: vpsrad  $32, %ymm0, %ymm0
+; CHECK: vpsrad  $31, %ymm0, %ymm0
 ; CHECK: ret
 
 ; SSE Logical Shift Right
index e2d612567a1cb4ff55f1523f9eaf6a54e87cdd0c..462def980a91eb95e2bd53370d47a6bd1439a6f2 100644 (file)
@@ -121,7 +121,7 @@ entry:
 }
 
 ; CHECK-LABEL: test_sraw_3:
-; CHECK: psraw   $16, %xmm0
+; CHECK: psraw   $15, %xmm0
 ; CHECK-NEXT: ret
 
 define <4 x i32> @test_srad_1(<4 x i32> %InVec) {
@@ -151,7 +151,7 @@ entry:
 }
 
 ; CHECK-LABEL: test_srad_3:
-; CHECK: psrad   $32, %xmm0
+; CHECK: psrad   $31, %xmm0
 ; CHECK-NEXT: ret
 
 ; SSE Logical Shift Right