[shuffle] Make the seed an optional component and add support for
[oota-llvm.git] / utils / shuffle_fuzz.py
1 #!/usr/bin/env python
2
3 """A shuffle vector fuzz tester.
4
5 This is a python program to fuzz test the LLVM shufflevector instruction. It
6 generates a function with a random sequnece of shufflevectors, maintaining the
7 element mapping accumulated across the function. It then generates a main
8 function which calls it with a different value in each element and checks that
9 the result matches the expected mapping.
10
11 Take the output IR printed to stdout, compile it to an executable using whatever
12 set of transforms you want to test, and run the program. If it crashes, it found
13 a bug.
14 """
15
16 import argparse
17 import itertools
18 import random
19 import sys
20 import uuid
21
22 def main():
23   parser = argparse.ArgumentParser(description=__doc__)
24   parser.add_argument('-v', '--verbose', action='store_true',
25                       help='Show verbose output')
26   parser.add_argument('--seed', default=str(uuid.uuid4()),
27                       help='A string used to seed the RNG')
28   parser.add_argument('--max-shuffle-height', type=int, default=16,
29                       help='Specify a fixed height of shuffle tree to test')
30   parser.add_argument('--no-blends', dest='blends', action='store_false',
31                       help='Include blends of two input vectors')
32   parser.add_argument('--fixed-bit-width', type=int, choices=[128, 256],
33                       help='Specify a fixed bit width of vector to test')
34   parser.add_argument('--triple',
35                       help='Specify a triple string to include in the IR')
36   args = parser.parse_args()
37
38   random.seed(args.seed)
39
40   if args.fixed_bit_width is not None:
41     if args.fixed_bit_width == 128:
42       (width, element_type) = random.choice(
43           [(2, 'i64'), (4, 'i32'), (8, 'i16'), (16, 'i8'),
44            (2, 'f64'), (4, 'f32')])
45     elif args.fixed_bit_width == 256:
46       (width, element_type) = random.choice([
47           (4, 'i64'), (8, 'i32'), (16, 'i16'), (32, 'i8'),
48           (4, 'f64'), (8, 'f32')])
49     else:
50       sys.exit(1) # Checked above by argument parsing.
51   else:
52     width = random.choice([2, 4, 8, 16, 32, 64])
53     element_type = random.choice(['i8', 'i16', 'i32', 'i64', 'f32', 'f64'])
54
55   element_modulus = {
56       'i8': 1 << 8, 'i16': 1 << 16, 'i32': 1 << 32, 'i64': 1 << 64,
57       'f32': 1 << 32, 'f64': 1 << 64}[element_type]
58
59   shuffle_range = (2 * width) if args.blends else width
60   shuffle_indices = [-1] + range(shuffle_range)
61
62   shuffle_tree = [[[random.choice(shuffle_indices)
63                     for _ in itertools.repeat(None, width)]
64                    for _ in itertools.repeat(None, args.max_shuffle_height - i)]
65                   for i in xrange(args.max_shuffle_height)]
66
67   if args.verbose:
68     # Print out the shuffle sequence in a compact form.
69     print >>sys.stderr, ('Testing shuffle sequence "%s" (v%d%s):' %
70                          (args.seed, width, element_type))
71     for i, shuffles in enumerate(shuffle_tree):
72       print >>sys.stderr, '  tree level %d:' % (i,)
73       for j, s in enumerate(shuffles):
74         print >>sys.stderr, '    shuffle %d: %s' % (j, s)
75     print >>sys.stderr, ''
76
77   # Symbolically evaluate the shuffle tree.
78   inputs = [[int(j % element_modulus)
79              for j in xrange(i * width + 1, (i + 1) * width + 1)]
80             for i in xrange(args.max_shuffle_height + 1)]
81   results = inputs
82   for shuffles in shuffle_tree:
83     results = [[((results[i] if j < width else results[i + 1])[j % width]
84                  if j != -1 else -1)
85                 for j in s]
86                for i, s in enumerate(shuffles)]
87   if len(results) != 1:
88     print >>sys.stderr, 'ERROR: Bad results: %s' % (results,)
89     sys.exit(1)
90   result = results[0]
91
92   if args.verbose:
93     print >>sys.stderr, 'Which transforms:'
94     print >>sys.stderr, '  from: %s' % (inputs,)
95     print >>sys.stderr, '  into: %s' % (result,)
96     print >>sys.stderr, ''
97
98   # The IR uses silly names for floating point types. We also need a same-size
99   # integer type.
100   integral_element_type = element_type
101   if element_type == 'f32':
102     integral_element_type = 'i32'
103     element_type = 'float'
104   elif element_type == 'f64':
105     integral_element_type = 'i64'
106     element_type = 'double'
107
108   # Now we need to generate IR for the shuffle function.
109   subst = {'N': width, 'T': element_type, 'IT': integral_element_type}
110   print """
111 define internal fastcc <%(N)d x %(T)s> @test(%(arguments)s) noinline nounwind {
112 entry:""" % dict(subst,
113                  arguments=', '.join(
114                      ['<%(N)d x %(T)s> %%s.0.%(i)d' % dict(subst, i=i)
115                       for i in xrange(args.max_shuffle_height + 1)]))
116
117   for i, shuffles in enumerate(shuffle_tree):
118    for j, s in enumerate(shuffles):
119     print """
120   %%s.%(next_i)d.%(j)d = shufflevector <%(N)d x %(T)s> %%s.%(i)d.%(j)d, <%(N)d x %(T)s> %%s.%(i)d.%(next_j)d, <%(N)d x i32> <%(S)s>
121 """.strip('\n') % dict(subst, i=i, next_i=i + 1, j=j, next_j=j + 1,
122                        S=', '.join(['i32 ' + (str(si) if si != -1 else 'undef')
123                                     for si in s]))
124
125   print """
126   ret <%(N)d x %(T)s> %%s.%(i)d.0
127 }
128 """ % dict(subst, i=len(shuffle_tree))
129
130   # Generate some string constants that we can use to report errors.
131   for i, r in enumerate(result):
132     if r != -1:
133       s = ('FAIL(%(seed)s): lane %(lane)d, expected %(result)d, found %%d\\0A' %
134            {'seed': args.seed, 'lane': i, 'result': r})
135       s += ''.join(['\\00' for _ in itertools.repeat(None, 128 - len(s) + 2)])
136       print """
137 @error.%(i)d = private unnamed_addr global [128 x i8] c"%(s)s"
138 """.strip() % {'i': i, 's': s}
139
140   # Define a wrapper function which is marked 'optnone' to prevent
141   # interprocedural optimizations from deleting the test.
142   print """
143 define internal fastcc <%(N)d x %(T)s> @test_wrapper(%(arguments)s) optnone noinline {
144   %%result = call fastcc <%(N)d x %(T)s> @test(%(arguments)s)
145   ret <%(N)d x %(T)s> %%result
146 }
147 """ % dict(subst,
148            arguments=', '.join(['<%(N)d x %(T)s> %%s.%(i)d' % dict(subst, i=i)
149                                 for i in xrange(args.max_shuffle_height + 1)]))
150
151   # Finally, generate a main function which will trap if any lanes are mapped
152   # incorrectly (in an observable way).
153   print """
154 define i32 @main() {
155 entry:
156   ; Create a scratch space to print error messages.
157   %%str = alloca [128 x i8]
158   %%str.ptr = getelementptr inbounds [128 x i8]* %%str, i32 0, i32 0
159
160   ; Build the input vector and call the test function.
161   %%v = call fastcc <%(N)d x %(T)s> @test_wrapper(%(inputs)s)
162   ; We need to cast this back to an integer type vector to easily check the
163   ; result.
164   %%v.cast = bitcast <%(N)d x %(T)s> %%v to <%(N)d x %(IT)s>
165   br label %%test.0
166 """ % dict(subst,
167            inputs=', '.join(
168                [('<%(N)d x %(T)s> bitcast '
169                  '(<%(N)d x %(IT)s> <%(input)s> to <%(N)d x %(T)s>)' %
170                  dict(subst, input=', '.join(['%(IT)s %(i)d' % dict(subst, i=i)
171                                               for i in input])))
172                 for input in inputs]))
173
174   # Test that each non-undef result lane contains the expected value.
175   for i, r in enumerate(result):
176     if r == -1:
177       print """
178 test.%(i)d:
179   ; Skip this lane, its value is undef.
180   br label %%test.%(next_i)d
181 """ % dict(subst, i=i, next_i=i + 1)
182     else:
183       print """
184 test.%(i)d:
185   %%v.%(i)d = extractelement <%(N)d x %(IT)s> %%v.cast, i32 %(i)d
186   %%cmp.%(i)d = icmp ne %(IT)s %%v.%(i)d, %(r)d
187   br i1 %%cmp.%(i)d, label %%die.%(i)d, label %%test.%(next_i)d
188
189 die.%(i)d:
190   ; Capture the actual value and print an error message.
191   %%tmp.%(i)d = zext %(IT)s %%v.%(i)d to i2048
192   %%bad.%(i)d = trunc i2048 %%tmp.%(i)d to i32
193   call i32 (i8*, i8*, ...)* @sprintf(i8* %%str.ptr, i8* getelementptr inbounds ([128 x i8]* @error.%(i)d, i32 0, i32 0), i32 %%bad.%(i)d)
194   %%length.%(i)d = call i32 @strlen(i8* %%str.ptr)
195   %%size.%(i)d = add i32 %%length.%(i)d, 1
196   call i32 @write(i32 2, i8* %%str.ptr, i32 %%size.%(i)d)
197   call void @llvm.trap()
198   unreachable
199 """ % dict(subst, i=i, next_i=i + 1, r=r)
200
201   print """
202 test.%d:
203   ret i32 0
204 }
205
206 declare i32 @strlen(i8*)
207 declare i32 @write(i32, i8*, i32)
208 declare i32 @sprintf(i8*, i8*, ...)
209 declare void @llvm.trap() noreturn nounwind
210 """ % (len(result),)
211
212 if __name__ == '__main__':
213   main()