AVX512BW: Enable packed word shift for 512bit vector. Enable lowering scalar immidiat...
[oota-llvm.git] / lib / Target / X86 / X86InstrAVX512.td
index 62f28b79ecdb1ec6e07069243eb4b20ef7cec548..f7e5d9c7b526df18cb119f36fb7b525dd6441641 100644 (file)
@@ -2176,17 +2176,19 @@ let Predicates = [HasAVX512] in {
             (EXTRACT_SUBREG
              (AND32ri (KMOVWrk (COPY_TO_REGCLASS VK1:$src, VK16)), (i32 1)),
               sub_16bit)>;
-  def : Pat<(v16i1 (scalar_to_vector VK1:$src)),
-            (COPY_TO_REGCLASS VK1:$src, VK16)>;
-  def : Pat<(v8i1 (scalar_to_vector VK1:$src)),
-            (COPY_TO_REGCLASS VK1:$src, VK8)>;
-}
-let Predicates = [HasBWI] in {
-  def : Pat<(v32i1 (scalar_to_vector VK1:$src)),
-            (COPY_TO_REGCLASS VK1:$src, VK32)>;
-  def : Pat<(v64i1 (scalar_to_vector VK1:$src)),
-            (COPY_TO_REGCLASS VK1:$src, VK64)>;
 }
+def : Pat<(v16i1 (scalar_to_vector VK1:$src)),
+          (COPY_TO_REGCLASS VK1:$src, VK16)>;
+def : Pat<(v8i1 (scalar_to_vector VK1:$src)),
+          (COPY_TO_REGCLASS VK1:$src, VK8)>;
+def : Pat<(v4i1 (scalar_to_vector VK1:$src)),
+          (COPY_TO_REGCLASS VK1:$src, VK4)>;
+def : Pat<(v2i1 (scalar_to_vector VK1:$src)),
+          (COPY_TO_REGCLASS VK1:$src, VK2)>;
+def : Pat<(v32i1 (scalar_to_vector VK1:$src)),
+          (COPY_TO_REGCLASS VK1:$src, VK32)>;
+def : Pat<(v64i1 (scalar_to_vector VK1:$src)),
+          (COPY_TO_REGCLASS VK1:$src, VK64)>;
 
 
 // With AVX-512 only, 8-bit mask is promoted to 16-bit mask.
@@ -2489,6 +2491,9 @@ def : Pat<(v8i1 (extract_subvector (v16i1 VK16:$src), (iPTR 8))),
 def : Pat<(v16i1 (extract_subvector (v32i1 VK32:$src), (iPTR 0))),
           (v16i1 (COPY_TO_REGCLASS VK32:$src, VK16))>;
 
+def : Pat<(v16i1 (extract_subvector (v32i1 VK32:$src), (iPTR 16))),
+          (v16i1 (COPY_TO_REGCLASS (KSHIFTRDri VK32:$src, (i8 16)), VK16))>;
+
 def : Pat<(v32i1 (extract_subvector (v64i1 VK64:$src), (iPTR 0))),
           (v32i1 (COPY_TO_REGCLASS VK64:$src, VK32))>;
 
@@ -2497,6 +2502,7 @@ def : Pat<(v32i1 (extract_subvector (v64i1 VK64:$src), (iPTR 32))),
 
 def : Pat<(v4i1 (extract_subvector (v8i1 VK8:$src), (iPTR 0))),
           (v4i1 (COPY_TO_REGCLASS VK8:$src, VK4))>;
+
 def : Pat<(v2i1 (extract_subvector (v8i1 VK8:$src), (iPTR 0))),
           (v2i1 (COPY_TO_REGCLASS VK8:$src, VK2))>;
 
@@ -4146,6 +4152,27 @@ multiclass avx512_var_shift_types<bits<8> opc, string OpcodeStr,
                                  avx512vl_i64_info>, VEX_W;
 }
 
+// Use 512bit version to implement 128/256 bit in case NoVLX.  
+multiclass avx512_var_shift_w_lowering<AVX512VLVectorVTInfo _, SDNode OpNode> {
+  let Predicates = [HasBWI, NoVLX] in {
+  def : Pat<(_.info256.VT (OpNode (_.info256.VT _.info256.RC:$src1), 
+                                  (_.info256.VT _.info256.RC:$src2))),
+            (EXTRACT_SUBREG                
+                (!cast<Instruction>(NAME#"WZrr")
+                    (INSERT_SUBREG (_.info512.VT (IMPLICIT_DEF)), VR256X:$src1, sub_ymm),
+                    (INSERT_SUBREG (_.info512.VT (IMPLICIT_DEF)), VR256X:$src2, sub_ymm)),
+             sub_ymm)>;
+
+  def : Pat<(_.info128.VT (OpNode (_.info128.VT _.info128.RC:$src1), 
+                                  (_.info128.VT _.info128.RC:$src2))),
+            (EXTRACT_SUBREG                
+                (!cast<Instruction>(NAME#"WZrr")
+                    (INSERT_SUBREG (_.info512.VT (IMPLICIT_DEF)), VR128X:$src1, sub_xmm),
+                    (INSERT_SUBREG (_.info512.VT (IMPLICIT_DEF)), VR128X:$src2, sub_xmm)),
+             sub_xmm)>;
+  }
+}
+
 multiclass avx512_var_shift_w<bits<8> opc, string OpcodeStr,
                                  SDNode OpNode> {
   let Predicates = [HasBWI] in
@@ -4161,11 +4188,14 @@ multiclass avx512_var_shift_w<bits<8> opc, string OpcodeStr,
 }
 
 defm VPSLLV : avx512_var_shift_types<0x47, "vpsllv", shl>,
-              avx512_var_shift_w<0x12, "vpsllvw", shl>;
+              avx512_var_shift_w<0x12, "vpsllvw", shl>,
+              avx512_var_shift_w_lowering<avx512vl_i16_info, shl>;
 defm VPSRAV : avx512_var_shift_types<0x46, "vpsrav", sra>,
-              avx512_var_shift_w<0x11, "vpsravw", sra>;
+              avx512_var_shift_w<0x11, "vpsravw", sra>,
+              avx512_var_shift_w_lowering<avx512vl_i16_info, sra>;
 defm VPSRLV : avx512_var_shift_types<0x45, "vpsrlv", srl>,
-              avx512_var_shift_w<0x10, "vpsrlvw", srl>;
+              avx512_var_shift_w<0x10, "vpsrlvw", srl>,
+              avx512_var_shift_w_lowering<avx512vl_i16_info, srl>;
 defm VPRORV : avx512_var_shift_types<0x14, "vprorv", rotr>;
 defm VPROLV : avx512_var_shift_types<0x15, "vprolv", rotl>;
 
@@ -5706,20 +5736,6 @@ multiclass avx512_fp14_p_vl_all<bits<8> opc, string OpcodeStr, SDNode OpNode> {
 defm VRSQRT14 : avx512_fp14_p_vl_all<0x4E, "vrsqrt14", X86frsqrt>;
 defm VRCP14 : avx512_fp14_p_vl_all<0x4C, "vrcp14", X86frcp>;
 
-def : Pat <(v16f32 (int_x86_avx512_rsqrt14_ps_512 (v16f32 VR512:$src),
-              (bc_v16f32 (v16i32 immAllZerosV)), (i16 -1))),
-           (VRSQRT14PSZr VR512:$src)>;
-def : Pat <(v8f64 (int_x86_avx512_rsqrt14_pd_512 (v8f64 VR512:$src),
-              (bc_v8f64 (v16i32 immAllZerosV)), (i8 -1))),
-           (VRSQRT14PDZr VR512:$src)>;
-
-def : Pat <(v16f32 (int_x86_avx512_rcp14_ps_512 (v16f32 VR512:$src),
-              (bc_v16f32 (v16i32 immAllZerosV)), (i16 -1))),
-           (VRCP14PSZr VR512:$src)>;
-def : Pat <(v8f64 (int_x86_avx512_rcp14_pd_512 (v8f64 VR512:$src),
-              (bc_v8f64 (v16i32 immAllZerosV)), (i8 -1))),
-           (VRCP14PDZr VR512:$src)>;
-
 /// avx512_fp28_s rcp28ss, rcp28sd, rsqrt28ss, rsqrt28sd
 multiclass avx512_fp28_s<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
                          SDNode OpNode> {