[NVPTX] Add support for efficient rotate instructions on SM 3.2+
authorJustin Holewinski <jholewinski@nvidia.com>
Fri, 27 Jun 2014 18:35:33 +0000 (18:35 +0000)
committerJustin Holewinski <jholewinski@nvidia.com>
Fri, 27 Jun 2014 18:35:33 +0000 (18:35 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@211934 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/IR/IntrinsicsNVVM.td
lib/Target/NVPTX/NVPTXInstrInfo.td
lib/Target/NVPTX/NVPTXIntrinsics.td
test/CodeGen/NVPTX/rotate.ll [new file with mode: 0644]

index d6f933cfbc73fdf4a24358b3eda4d693b5ee6d90..52df102232f1c8de5c72c0897d8b5c2ac875db25 100644 (file)
@@ -1948,6 +1948,25 @@ def int_nvvm_sust_p_3d_v4i32_trap
               "llvm.nvvm.sust.p.3d.v4i32.trap">,
     GCCBuiltin<"__nvvm_sust_p_3d_v4i32_trap">;
 
+def int_nvvm_rotate_b32
+  : Intrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty],
+              [IntrNoMem], "llvm.nvvm.rotate.b32">,
+              GCCBuiltin<"__nvvm_rotate_b32">;
+
+def int_nvvm_rotate_b64
+  :Intrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty],
+             [IntrNoMem], "llvm.nvvm.rotate.b64">,
+             GCCBuiltin<"__nvvm_rotate_b64">;
+
+def int_nvvm_rotate_right_b64
+  : Intrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty],
+              [IntrNoMem], "llvm.nvvm.rotate.right.b64">,
+              GCCBuiltin<"__nvvm_rotate_right_b64">;
+
+def int_nvvm_swap_lo_hi_b64
+  : Intrinsic<[llvm_i64_ty], [llvm_i64_ty],
+              [IntrNoMem], "llvm.nvvm.swap.lo.hi.b64">,
+              GCCBuiltin<"__nvvm_swap_lo_hi_b64">;
 
 
 // Old PTX back-end intrinsics retained here for backwards-compatibility
index e94250b38da7e33fbfd4e694bbc85c77bd57f875..725d6fc91c3ea8654f28181ea9b6eab24a585afd 100644 (file)
@@ -158,6 +158,7 @@ def do_SQRTF32_APPROX : Predicate<"!usePrecSqrtF32()">;
 def do_SQRTF32_RN : Predicate<"usePrecSqrtF32()">;
 
 def hasHWROT32 : Predicate<"Subtarget.hasHWROT32()">;
+def noHWROT32 : Predicate<"!Subtarget.hasHWROT32()">;
 
 def true : Predicate<"1">;
 
@@ -1085,6 +1086,43 @@ multiclass RSHIFT_FORMAT<string OpcStr, SDNode OpNode> {
 defm SRA : RSHIFT_FORMAT<"shr.s", sra>;
 defm SRL : RSHIFT_FORMAT<"shr.u", srl>;
 
+//
+// Rotate: use ptx shf instruction if available.
+//
+
+// 32 bit r2 = rotl r1, n
+//    =>
+//        r2 = shf.l r1, r1, n
+def ROTL32imm_hw : NVPTXInst<(outs Int32Regs:$dst),
+                             (ins Int32Regs:$src, i32imm:$amt),
+              "shf.l.wrap.b32 \t$dst, $src, $src, $amt;",
+    [(set Int32Regs:$dst, (rotl Int32Regs:$src, (i32 imm:$amt)))]>,
+    Requires<[hasHWROT32]> ;
+
+def ROTL32reg_hw : NVPTXInst<(outs Int32Regs:$dst),
+                             (ins Int32Regs:$src, Int32Regs:$amt),
+              "shf.l.wrap.b32 \t$dst, $src, $src, $amt;",
+    [(set Int32Regs:$dst, (rotl Int32Regs:$src, Int32Regs:$amt))]>,
+    Requires<[hasHWROT32]>;
+
+// 32 bit r2 = rotr r1, n
+//    =>
+//        r2 = shf.r r1, r1, n
+def ROTR32imm_hw : NVPTXInst<(outs Int32Regs:$dst),
+                             (ins Int32Regs:$src, i32imm:$amt),
+              "shf.r.wrap.b32 \t$dst, $src, $src, $amt;",
+    [(set Int32Regs:$dst, (rotr Int32Regs:$src, (i32 imm:$amt)))]>,
+    Requires<[hasHWROT32]>;
+
+def ROTR32reg_hw : NVPTXInst<(outs Int32Regs:$dst),
+                             (ins Int32Regs:$src, Int32Regs:$amt),
+              "shf.r.wrap.b32 \t$dst, $src, $src, $amt;",
+    [(set Int32Regs:$dst, (rotr Int32Regs:$src, Int32Regs:$amt))]>,
+    Requires<[hasHWROT32]>;
+
+//
+// Rotate: if ptx shf instruction is not available, then use shift+add
+//
 // 32bit
 def ROT32imm_sw : NVPTXInst<(outs Int32Regs:$dst),
   (ins Int32Regs:$src, i32imm:$amt1, i32imm:$amt2),
@@ -1102,9 +1140,11 @@ def SUB_FRM_32 : SDNodeXForm<imm, [{
 }]>;
 
 def : Pat<(rotl Int32Regs:$src, (i32 imm:$amt)),
-          (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>;
+          (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>,
+      Requires<[noHWROT32]>;
 def : Pat<(rotr Int32Regs:$src, (i32 imm:$amt)),
-          (ROT32imm_sw Int32Regs:$src, (SUB_FRM_32 node:$amt), imm:$amt)>;
+          (ROT32imm_sw Int32Regs:$src, (SUB_FRM_32 node:$amt), imm:$amt)>,
+      Requires<[noHWROT32]>;
 
 def ROTL32reg_sw : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src,
     Int32Regs:$amt),
@@ -1117,7 +1157,8 @@ def ROTL32reg_sw : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src,
     !strconcat("shr.b32 \t%rhs, $src, %amt2;\n\t",
     !strconcat("add.u32 \t$dst, %lhs, %rhs;\n\t",
     !strconcat("}}", ""))))))))),
-    [(set Int32Regs:$dst, (rotl Int32Regs:$src, Int32Regs:$amt))]>;
+    [(set Int32Regs:$dst, (rotl Int32Regs:$src, Int32Regs:$amt))]>,
+    Requires<[noHWROT32]>;
 
 def ROTR32reg_sw : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src,
     Int32Regs:$amt),
@@ -1130,7 +1171,8 @@ def ROTR32reg_sw : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src,
     !strconcat("shl.b32 \t%rhs, $src, %amt2;\n\t",
     !strconcat("add.u32 \t$dst, %lhs, %rhs;\n\t",
     !strconcat("}}", ""))))))))),
-    [(set Int32Regs:$dst, (rotr Int32Regs:$src, Int32Regs:$amt))]>;
+    [(set Int32Regs:$dst, (rotr Int32Regs:$src, Int32Regs:$amt))]>,
+    Requires<[noHWROT32]>;
 
 // 64bit
 def ROT64imm_sw : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src,
index 00c315c94e23fbef7c9896b91701f9ce2d619221..0617e7d4e17cbca08da8be166f19246cbcf1acff 100644 (file)
@@ -1864,6 +1864,130 @@ def : Pat<(int_nvvm_read_ptx_sreg_envreg30), (MOV_SPECIAL ENVREG30)>;
 def : Pat<(int_nvvm_read_ptx_sreg_envreg31), (MOV_SPECIAL ENVREG31)>;
 
 
+// rotate builtin support
+
+def ROTATE_B32_HW_IMM
+  : NVPTXInst<(outs Int32Regs:$dst),
+              (ins  Int32Regs:$src, i32imm:$amt),
+              "shf.l.wrap.b32 \t$dst, $src, $src, $amt;",
+              [(set Int32Regs:$dst,
+                 (int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)))]>,
+              Requires<[hasHWROT32]> ;
+
+def ROTATE_B32_HW_REG
+  : NVPTXInst<(outs Int32Regs:$dst),
+              (ins  Int32Regs:$src, Int32Regs:$amt),
+              "shf.l.wrap.b32 \t$dst, $src, $src, $amt;",
+              [(set Int32Regs:$dst,
+                 (int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt))]>,
+              Requires<[hasHWROT32]> ;
+
+def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)),
+          (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>,
+      Requires<[noHWROT32]> ;
+
+def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt),
+          (ROTL32reg_sw Int32Regs:$src, Int32Regs:$amt)>,
+      Requires<[noHWROT32]> ;
+
+def GET_LO_INT64
+  : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src),
+              !strconcat("{{\n\t",
+              !strconcat(".reg .b32 %dummy;\n\t",
+              !strconcat("mov.b64 \t{$dst,%dummy}, $src;\n\t",
+        !strconcat("}}", "")))),
+        []> ;
+
+def GET_HI_INT64
+  : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src),
+              !strconcat("{{\n\t",
+              !strconcat(".reg .b32 %dummy;\n\t",
+              !strconcat("mov.b64 \t{%dummy,$dst}, $src;\n\t",
+        !strconcat("}}", "")))),
+        []> ;
+
+def PACK_TWO_INT32
+  : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$lo, Int32Regs:$hi),
+              "mov.b64 \t$dst, {{$lo, $hi}};", []> ;
+
+def : Pat<(int_nvvm_swap_lo_hi_b64 Int64Regs:$src),
+          (PACK_TWO_INT32 (GET_HI_INT64 Int64Regs:$src),
+                          (GET_LO_INT64 Int64Regs:$src))> ;
+
+// funnel shift, requires >= sm_32
+def SHF_L_WRAP_B32_IMM
+  : NVPTXInst<(outs Int32Regs:$dst),
+              (ins  Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt),
+              "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>,
+    Requires<[hasHWROT32]>;
+
+def SHF_L_WRAP_B32_REG
+  : NVPTXInst<(outs Int32Regs:$dst),
+              (ins  Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
+              "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>,
+    Requires<[hasHWROT32]>;
+
+def SHF_R_WRAP_B32_IMM
+  : NVPTXInst<(outs Int32Regs:$dst),
+              (ins  Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt),
+              "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>,
+    Requires<[hasHWROT32]>;
+
+def SHF_R_WRAP_B32_REG
+  : NVPTXInst<(outs Int32Regs:$dst),
+              (ins  Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
+              "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>,
+    Requires<[hasHWROT32]>;
+
+// HW version of rotate 64
+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)),
+          (PACK_TWO_INT32
+            (SHF_L_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src),
+                                (GET_LO_INT64 Int64Regs:$src), imm:$amt),
+            (SHF_L_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src),
+                                (GET_HI_INT64 Int64Regs:$src), imm:$amt))>,
+      Requires<[hasHWROT32]>;
+
+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt),
+          (PACK_TWO_INT32
+            (SHF_L_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src),
+                                (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt),
+            (SHF_L_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src),
+                               (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt))>,
+      Requires<[hasHWROT32]>;
+
+
+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)),
+          (PACK_TWO_INT32
+            (SHF_R_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src),
+                                (GET_HI_INT64 Int64Regs:$src), imm:$amt),
+            (SHF_R_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src),
+                                (GET_LO_INT64 Int64Regs:$src), imm:$amt))>,
+      Requires<[hasHWROT32]>;
+
+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt),
+          (PACK_TWO_INT32
+            (SHF_R_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src),
+                                (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt),
+            (SHF_R_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src),
+                               (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt))>,
+      Requires<[hasHWROT32]>;
+
+// SW version of rotate 64
+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)),
+          (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>,
+      Requires<[noHWROT32]>;
+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt),
+          (ROTL64reg_sw Int64Regs:$src, Int32Regs:$amt)>,
+      Requires<[noHWROT32]>;
+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)),
+          (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>,
+      Requires<[noHWROT32]>;
+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt),
+          (ROTR64reg_sw Int64Regs:$src, Int32Regs:$amt)>,
+      Requires<[noHWROT32]>;
+
+
 //-----------------------------------
 // Texture Intrinsics
 //-----------------------------------
diff --git a/test/CodeGen/NVPTX/rotate.ll b/test/CodeGen/NVPTX/rotate.ll
new file mode 100644 (file)
index 0000000..dfc8b4f
--- /dev/null
@@ -0,0 +1,58 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck --check-prefix=SM20 %s
+; RUN: llc < %s -march=nvptx -mcpu=sm_35 | FileCheck --check-prefix=SM35 %s
+
+
+declare i32 @llvm.nvvm.rotate.b32(i32, i32)
+declare i64 @llvm.nvvm.rotate.b64(i64, i32)
+declare i64 @llvm.nvvm.rotate.right.b64(i64, i32)
+
+; SM20: rotate32
+; SM35: rotate32
+define i32 @rotate32(i32 %a, i32 %b) {
+; SM20: shl.b32
+; SM20: sub.s32
+; SM20: shr.b32
+; SM20: add.u32
+; SM35: shf.l.wrap.b32
+  %val = tail call i32 @llvm.nvvm.rotate.b32(i32 %a, i32 %b)
+  ret i32 %val
+}
+
+; SM20: rotate64
+; SM35: rotate64
+define i64 @rotate64(i64 %a, i32 %b) {
+; SM20: shl.b64
+; SM20: sub.u32
+; SM20: shr.b64
+; SM20: add.u64
+; SM35: shf.l.wrap.b32
+; SM35: shf.l.wrap.b32
+  %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 %b)
+  ret i64 %val
+}
+
+; SM20: rotateright64
+; SM35: rotateright64
+define i64 @rotateright64(i64 %a, i32 %b) {
+; SM20: shr.b64
+; SM20: sub.u32
+; SM20: shl.b64
+; SM20: add.u64
+; SM35: shf.r.wrap.b32
+; SM35: shf.r.wrap.b32
+  %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 %b)
+  ret i64 %val
+}
+
+; SM20: rotl0
+; SM35: rotl0
+define i32 @rotl0(i32 %x) {
+; SM20: shl.b32
+; SM20: shr.b32
+; SM20: add.u32
+; SM35: shf.l.wrap.b32
+  %t0 = shl i32 %x, 8
+  %t1 = lshr i32 %x, 24
+  %t2 = or i32 %t0, %t1
+  ret i32 %t2
+}