[shuffle] Stand back! I'm about to (try to) do math!
[oota-llvm.git] / utils / shuffle_fuzz.py
1 #!/usr/bin/env python
2
3 """A shuffle vector fuzz tester.
4
5 This is a python program to fuzz test the LLVM shufflevector instruction. It
6 generates a function with a random sequnece of shufflevectors, maintaining the
7 element mapping accumulated across the function. It then generates a main
8 function which calls it with a different value in each element and checks that
9 the result matches the expected mapping.
10
11 Take the output IR printed to stdout, compile it to an executable using whatever
12 set of transforms you want to test, and run the program. If it crashes, it found
13 a bug.
14 """
15
16 import argparse
17 import itertools
18 import random
19 import sys
20 import uuid
21
22 def main():
23   parser = argparse.ArgumentParser(description=__doc__)
24   parser.add_argument('-v', '--verbose', action='store_true',
25                       help='Show verbose output')
26   parser.add_argument('--seed', default=str(uuid.uuid4()),
27                       help='A string used to seed the RNG')
28   parser.add_argument('--max-shuffle-height', type=int, default=16,
29                       help='Specify a fixed height of shuffle tree to test')
30   parser.add_argument('--no-blends', dest='blends', action='store_false',
31                       help='Include blends of two input vectors')
32   parser.add_argument('--fixed-bit-width', type=int, choices=[128, 256],
33                       help='Specify a fixed bit width of vector to test')
34   parser.add_argument('--triple',
35                       help='Specify a triple string to include in the IR')
36   args = parser.parse_args()
37
38   random.seed(args.seed)
39
40   if args.fixed_bit_width is not None:
41     if args.fixed_bit_width == 128:
42       (width, element_type) = random.choice(
43           [(2, 'i64'), (4, 'i32'), (8, 'i16'), (16, 'i8'),
44            (2, 'f64'), (4, 'f32')])
45     elif args.fixed_bit_width == 256:
46       (width, element_type) = random.choice([
47           (4, 'i64'), (8, 'i32'), (16, 'i16'), (32, 'i8'),
48           (4, 'f64'), (8, 'f32')])
49     else:
50       sys.exit(1) # Checked above by argument parsing.
51   else:
52     width = random.choice([2, 4, 8, 16, 32, 64])
53     element_type = random.choice(['i8', 'i16', 'i32', 'i64', 'f32', 'f64'])
54
55   element_modulus = {
56       'i8': 1 << 8, 'i16': 1 << 16, 'i32': 1 << 32, 'i64': 1 << 64,
57       'f32': 1 << 32, 'f64': 1 << 64}[element_type]
58
59   shuffle_range = (2 * width) if args.blends else width
60
61   # Because undef (-1) saturates and is indistinguishable when testing the
62   # correctness of a shuffle, we want to bias our fuzz toward having a decent
63   # mixture of non-undef lanes in the end. With a deep shuffle tree, the
64   # probabilies aren't good so we need to bias things. The math here is that if
65   # we uniformly select between -1 and the other inputs, each element of the
66   # result will have the following probability of being undef:
67   #
68   #   1 - (shuffle_range/(shuffle_range+1))^max_shuffle_height
69   #
70   # More generally, for any probability P of selecting a defined element in
71   # a single shuffle, the end result is:
72   #
73   #   1 - P^max_shuffle_height
74   #
75   # The power of the shuffle height is the real problem, as we want:
76   #
77   #   1 - shuffle_range/(shuffle_range+1)
78   #
79   # So we bias the selection of undef at any given node based on the tree
80   # height. Below, let 'A' be 'len(shuffle_range)', 'C' be 'max_shuffle_height',
81   # and 'B' be the bias we use to compensate for
82   # C '((A+1)*A^(1/C))/(A*(A+1)^(1/C))':
83   #
84   #   1 - (B * A)/(A + 1)^C = 1 - A/(A + 1)
85   #
86   # So at each node we use:
87   #
88   #   1 - (B * A)/(A + 1)
89   # = 1 - ((A + 1) * A * A^(1/C))/(A * (A + 1) * (A + 1)^(1/C))
90   # = 1 - ((A + 1) * A^((C + 1)/C))/(A * (A + 1)^((C + 1)/C))
91   #
92   # This is the formula we use to select undef lanes in the shuffle.
93   A = float(shuffle_range)
94   C = float(args.max_shuffle_height)
95   undef_prob = 1.0 - (((A + 1.0) * pow(A, (C + 1.0)/C)) /
96                       (A * pow(A + 1.0, (C + 1.0)/C)))
97
98   shuffle_tree = [[[-1 if random.random() <= undef_prob
99                        else random.choice(range(shuffle_range))
100                     for _ in itertools.repeat(None, width)]
101                    for _ in itertools.repeat(None, args.max_shuffle_height - i)]
102                   for i in xrange(args.max_shuffle_height)]
103
104   if args.verbose:
105     # Print out the shuffle sequence in a compact form.
106     print >>sys.stderr, ('Testing shuffle sequence "%s" (v%d%s):' %
107                          (args.seed, width, element_type))
108     for i, shuffles in enumerate(shuffle_tree):
109       print >>sys.stderr, '  tree level %d:' % (i,)
110       for j, s in enumerate(shuffles):
111         print >>sys.stderr, '    shuffle %d: %s' % (j, s)
112     print >>sys.stderr, ''
113
114   # Symbolically evaluate the shuffle tree.
115   inputs = [[int(j % element_modulus)
116              for j in xrange(i * width + 1, (i + 1) * width + 1)]
117             for i in xrange(args.max_shuffle_height + 1)]
118   results = inputs
119   for shuffles in shuffle_tree:
120     results = [[((results[i] if j < width else results[i + 1])[j % width]
121                  if j != -1 else -1)
122                 for j in s]
123                for i, s in enumerate(shuffles)]
124   if len(results) != 1:
125     print >>sys.stderr, 'ERROR: Bad results: %s' % (results,)
126     sys.exit(1)
127   result = results[0]
128
129   if args.verbose:
130     print >>sys.stderr, 'Which transforms:'
131     print >>sys.stderr, '  from: %s' % (inputs,)
132     print >>sys.stderr, '  into: %s' % (result,)
133     print >>sys.stderr, ''
134
135   # The IR uses silly names for floating point types. We also need a same-size
136   # integer type.
137   integral_element_type = element_type
138   if element_type == 'f32':
139     integral_element_type = 'i32'
140     element_type = 'float'
141   elif element_type == 'f64':
142     integral_element_type = 'i64'
143     element_type = 'double'
144
145   # Now we need to generate IR for the shuffle function.
146   subst = {'N': width, 'T': element_type, 'IT': integral_element_type}
147   print """
148 define internal fastcc <%(N)d x %(T)s> @test(%(arguments)s) noinline nounwind {
149 entry:""" % dict(subst,
150                  arguments=', '.join(
151                      ['<%(N)d x %(T)s> %%s.0.%(i)d' % dict(subst, i=i)
152                       for i in xrange(args.max_shuffle_height + 1)]))
153
154   for i, shuffles in enumerate(shuffle_tree):
155    for j, s in enumerate(shuffles):
156     print """
157   %%s.%(next_i)d.%(j)d = shufflevector <%(N)d x %(T)s> %%s.%(i)d.%(j)d, <%(N)d x %(T)s> %%s.%(i)d.%(next_j)d, <%(N)d x i32> <%(S)s>
158 """.strip('\n') % dict(subst, i=i, next_i=i + 1, j=j, next_j=j + 1,
159                        S=', '.join(['i32 ' + (str(si) if si != -1 else 'undef')
160                                     for si in s]))
161
162   print """
163   ret <%(N)d x %(T)s> %%s.%(i)d.0
164 }
165 """ % dict(subst, i=len(shuffle_tree))
166
167   # Generate some string constants that we can use to report errors.
168   for i, r in enumerate(result):
169     if r != -1:
170       s = ('FAIL(%(seed)s): lane %(lane)d, expected %(result)d, found %%d\\0A' %
171            {'seed': args.seed, 'lane': i, 'result': r})
172       s += ''.join(['\\00' for _ in itertools.repeat(None, 128 - len(s) + 2)])
173       print """
174 @error.%(i)d = private unnamed_addr global [128 x i8] c"%(s)s"
175 """.strip() % {'i': i, 's': s}
176
177   # Define a wrapper function which is marked 'optnone' to prevent
178   # interprocedural optimizations from deleting the test.
179   print """
180 define internal fastcc <%(N)d x %(T)s> @test_wrapper(%(arguments)s) optnone noinline {
181   %%result = call fastcc <%(N)d x %(T)s> @test(%(arguments)s)
182   ret <%(N)d x %(T)s> %%result
183 }
184 """ % dict(subst,
185            arguments=', '.join(['<%(N)d x %(T)s> %%s.%(i)d' % dict(subst, i=i)
186                                 for i in xrange(args.max_shuffle_height + 1)]))
187
188   # Finally, generate a main function which will trap if any lanes are mapped
189   # incorrectly (in an observable way).
190   print """
191 define i32 @main() {
192 entry:
193   ; Create a scratch space to print error messages.
194   %%str = alloca [128 x i8]
195   %%str.ptr = getelementptr inbounds [128 x i8]* %%str, i32 0, i32 0
196
197   ; Build the input vector and call the test function.
198   %%v = call fastcc <%(N)d x %(T)s> @test_wrapper(%(inputs)s)
199   ; We need to cast this back to an integer type vector to easily check the
200   ; result.
201   %%v.cast = bitcast <%(N)d x %(T)s> %%v to <%(N)d x %(IT)s>
202   br label %%test.0
203 """ % dict(subst,
204            inputs=', '.join(
205                [('<%(N)d x %(T)s> bitcast '
206                  '(<%(N)d x %(IT)s> <%(input)s> to <%(N)d x %(T)s>)' %
207                  dict(subst, input=', '.join(['%(IT)s %(i)d' % dict(subst, i=i)
208                                               for i in input])))
209                 for input in inputs]))
210
211   # Test that each non-undef result lane contains the expected value.
212   for i, r in enumerate(result):
213     if r == -1:
214       print """
215 test.%(i)d:
216   ; Skip this lane, its value is undef.
217   br label %%test.%(next_i)d
218 """ % dict(subst, i=i, next_i=i + 1)
219     else:
220       print """
221 test.%(i)d:
222   %%v.%(i)d = extractelement <%(N)d x %(IT)s> %%v.cast, i32 %(i)d
223   %%cmp.%(i)d = icmp ne %(IT)s %%v.%(i)d, %(r)d
224   br i1 %%cmp.%(i)d, label %%die.%(i)d, label %%test.%(next_i)d
225
226 die.%(i)d:
227   ; Capture the actual value and print an error message.
228   %%tmp.%(i)d = zext %(IT)s %%v.%(i)d to i2048
229   %%bad.%(i)d = trunc i2048 %%tmp.%(i)d to i32
230   call i32 (i8*, i8*, ...)* @sprintf(i8* %%str.ptr, i8* getelementptr inbounds ([128 x i8]* @error.%(i)d, i32 0, i32 0), i32 %%bad.%(i)d)
231   %%length.%(i)d = call i32 @strlen(i8* %%str.ptr)
232   %%size.%(i)d = add i32 %%length.%(i)d, 1
233   call i32 @write(i32 2, i8* %%str.ptr, i32 %%size.%(i)d)
234   call void @llvm.trap()
235   unreachable
236 """ % dict(subst, i=i, next_i=i + 1, r=r)
237
238   print """
239 test.%d:
240   ret i32 0
241 }
242
243 declare i32 @strlen(i8*)
244 declare i32 @write(i32, i8*, i32)
245 declare i32 @sprintf(i8*, i8*, ...)
246 declare void @llvm.trap() noreturn nounwind
247 """ % (len(result),)
248
249 if __name__ == '__main__':
250   main()