[NFC][Py Reformat] Reformat python files in llvm
[llvm-project.git] / llvm / utils / shuffle_select_fuzz_tester.py
blob73bac3c18db141e47c81f38cebf8dac65a3e8781
1 #!/usr/bin/env python
3 """A shuffle-select vector fuzz tester.
5 This is a python program to fuzz test the LLVM shufflevector and select
instructions. It generates a function with a random sequence of shufflevectors,
optionally following each one with a select instruction (regular or zero merge),
8 maintaining the element mapping accumulated across the function. It then
9 generates a main function which calls it with a different value in each element
10 and checks that the result matches the expected mapping.
12 Take the output IR printed to stdout, compile it to an executable using whatever
13 set of transforms you want to test, and run the program. If it crashes, it found
14 a bug (an error message with the expected and actual result is printed).
15 """
from __future__ import print_function

import argparse
import random
import sys
import uuid
# Expected number of undef elements in a generated shufflevector mask: each of
# the elt_num mask elements is made undef with probability
# SHUF_UNDEF_POS / elt_num (see gen_shuf_mask).
SHUF_UNDEF_POS = 0.15

# Same as SHUF_UNDEF_POS, but for the <N x i1> mask of a select instruction
# (see gen_sel_mask).
SEL_UNDEF_POS = 0.15

# Probability of following a generated shufflevector with a select instruction.
ADD_SEL_POS = 0.4

# When a select is added, the probability that it blends with another available
# Value (merge-select); with probability 1 - MERGE_SEL_POS it blends with an
# all-zero vector instead (zero-merge-select).
MERGE_SEL_POS = 0.5
# IR for the function under test. {ty} is the vector type, {inputs} the typed
# parameter list, {instructions} the generated shufflevector/select lines and
# {last_name} the name of the value returned. Braces belonging to the IR are
# doubled ({{ }}) because the template is filled in with str.format.
test_template = r"""
define internal fastcc {ty} @test({inputs}) noinline nounwind {{
entry:
{instructions}
  ret {ty} {last_name}
}}
"""

# Per-lane failure message stored as a private [64 x i8] global. The fixed text
# is 33 bytes, so {padding} (a run of \00 escapes computed in main) brings the
# whole string to exactly 64 bytes.
# NOTE(review): the actual value is printed with %d, which looks mismatched for
# float/double lanes and for i64 lanes -- confirm before trusting "found".
error_template = r'''@error.{lane} = private unnamed_addr global [64 x i8] c"FAIL: lane {lane}, expected {exp}, found %d\0A{padding}"'''

# The @main driver: calls @test with the constant input vectors ({inputs}) and
# then runs through the per-lane check/die blocks spliced in as {check_die}.
main_template = r"""
define i32 @main() {{
entry:
  ; Create a scratch space to print error messages.
  %str = alloca [64 x i8]
  %str.ptr = getelementptr inbounds [64 x i8], [64 x i8]* %str, i32 0, i32 0

  ; Build the input vector and call the test function.
  %v = call fastcc {ty} @test({inputs})
  br label %test.0

  {check_die}
}}

declare i32 @strlen(i8*)
declare i32 @write(i32, i8*, i32)
declare i32 @sprintf(i8*, i8*, ...)
declare void @llvm.trap() noreturn nounwind
"""

# Per-lane check: extract lane {lane} of the result and compare it with the
# expected constant {exp}. {i_f}/{ordered} yield "icmp ne" for integer types
# and "fcmp one" for floating point. A mismatch branches to die.{lane};
# otherwise control continues at test.{n_lane}.
check_template = r"""
test.{lane}:
  %v.{lane} = extractelement {ty} %v, i32 {lane}
  %cmp.{lane} = {i_f}cmp {ordered}ne {scalar_ty} %v.{lane}, {exp}
  br i1 %cmp.{lane}, label %die.{lane}, label %test.{n_lane}
"""

# Used instead of check_template for lanes whose expected value is undef.
undef_check_template = r"""
test.{lane}:
; Skip this lane, its value is undef.
  br label %test.{n_lane}
"""

# Failure handler for lane {lane}: format the lane's message into the %str
# scratch buffer, write it to file descriptor 2 and trap.
die_template = r"""
die.{lane}:
; Capture the actual value and print an error message.
  call i32 (i8*, i8*, ...) @sprintf(i8* %str.ptr, i8* getelementptr inbounds ([64 x i8], [64 x i8]* @error.{lane}, i32 0, i32 0), {scalar_ty} %v.{lane})
  %length.{lane} = call i32 @strlen(i8* %str.ptr)
  call i32 @write(i32 2, i8* %str.ptr, i32 %length.{lane})
  call void @llvm.trap()
  unreachable
"""
class Type:
    """A scalar or fixed-length vector type used by the generated test."""

    def __init__(self, is_float, elt_width, elt_num):
        self.elt_num = elt_num  # number of lanes; 1 means a plain scalar
        self.elt_width = elt_width  # element width in bits
        self.is_float = is_float  # True for float/double, False for iN

    def dump(self):
        """Return the LLVM IR spelling of this type, e.g. "<4 x i32>"."""
        if not self.is_float:
            element = "i%s" % (self.elt_width,)
        elif self.elt_width == 32:
            element = "float"
        else:
            element = "double"

        if self.elt_num != 1:
            return "<%s x %s>" % (self.elt_num, element)
        return element

    def get_scalar_type(self):
        """Return the single-element Type matching this type's elements."""
        return Type(is_float=self.is_float, elt_width=self.elt_width, elt_num=1)
class Value:
    """A named IR value (variable) of a given Type with its known lane contents."""

    def __init__(self, name, ty, value=None):
        # name: IR spelling such as "%inp0" or "zeroinitializer".
        # ty: the Type of the value.
        # value: per-lane contents (ints or floats, -1 for an undef lane), or
        #        None until an instruction computes it.
        self.name, self.ty, self.value = name, ty, value
class Instruction(Value):
    """Common base of the generated IR instructions (shuffle and select).

    An instruction is itself a Value: its lane contents are filled in by
    calc_value() and its IR text is produced by dump(). Both are no-ops here
    and are overridden by the concrete subclasses.
    """

    def __init__(self, name, ty, op0, op1, mask):
        Value.__init__(self, name, ty)
        # The two operand Values and the per-lane constant mask.
        self.op0, self.op1, self.mask = op0, op1, mask

    def dump(self):
        """Return the IR text for this instruction (None in the base class)."""
        return None

    def calc_value(self):
        """Compute self.value from the operands (no-op in the base class)."""
        return None
class ShufInstr(Instruction):
    """An IR ``shufflevector`` of two same-typed vector Values.

    Mask entries in [0, elt_num) pick a lane of op0, entries in
    [elt_num, 2 * elt_num) pick lane (entry - elt_num) of op1, and -1 is
    printed as an ``undef`` mask element.
    """

    shuf_template = (
        "  {name} = shufflevector {ty} {op0}, {ty} {op1}, <{num} x i32> {mask}\n"
    )

    def __init__(self, name, ty, op0, op1, mask):
        Instruction.__init__(self, "%shuf" + name, ty, op0, op1, mask)

    def dump(self):
        """Return the IR text line for this shuffle."""
        str_mask = [
            ("i32 " + str(idx)) if idx != -1 else "i32 undef" for idx in self.mask
        ]
        str_mask = "<" + (", ").join(str_mask) + ">"
        return self.shuf_template.format(
            name=self.name,
            ty=self.ty.dump(),
            op0=self.op0.name,
            op1=self.op1.name,
            num=self.ty.elt_num,
            mask=str_mask,
        )

    def calc_value(self):
        """Compute the lane contents of this shuffle into self.value.

        Undef result lanes are recorded as -1. Calling this a second time is a
        generator bug: the program exits with status 1, and the message goes
        to stderr so that it never ends up mixed into the IR on stdout.
        """
        if self.value is not None:
            sys.exit("Trying to calculate the value of a shuffle instruction twice")

        num = self.ty.elt_num
        result = []
        for index in self.mask:
            if 0 <= index < num:
                result.append(self.op0.value[index])
            elif index >= num:
                # Indices past the first operand address the second one.
                result.append(self.op1.value[index % num])
            else:  # -1 => undef
                result.append(-1)

        self.value = result
class SelectInstr(Instruction):
    """An IR ``select`` between two Values driven by a constant <N x i1> mask.

    A mask entry of 1 takes the lane from op0, 0 takes it from op1, and -1 is
    printed as ``undef``.
    """

    sel_template = "  {name} = select <{num} x i1> {mask}, {ty} {op0}, {ty} {op1}\n"

    def __init__(self, name, ty, op0, op1, mask):
        Instruction.__init__(self, "%sel" + name, ty, op0, op1, mask)

    def dump(self):
        """Return the IR text line for this select."""
        str_mask = [
            ("i1 " + str(idx)) if idx != -1 else "i1 undef" for idx in self.mask
        ]
        str_mask = "<" + (", ").join(str_mask) + ">"
        return self.sel_template.format(
            name=self.name,
            ty=self.ty.dump(),
            op0=self.op0.name,
            op1=self.op1.name,
            num=self.ty.elt_num,
            mask=str_mask,
        )

    def calc_value(self):
        """Compute the lane contents of this select into self.value.

        A lane whose condition is undef (-1) is recorded as -1. Calling this a
        second time is a generator bug: the program exits with status 1, and
        the message goes to stderr instead of the IR stream on stdout.
        """
        if self.value is not None:
            sys.exit("Trying to calculate the value of a select instruction twice")

        result = []
        for lane, cond in enumerate(self.mask):
            if cond == 1:
                result.append(self.op0.value[lane])
            elif cond == 0:
                result.append(self.op1.value[lane])
            else:  # -1 => undef
                result.append(-1)

        self.value = result
def gen_inputs(ty, num):
    """Return ``num`` input Values of type ``ty`` filled with known lane data.

    Lane j of input i holds i * elt_num + j: as a float for floating-point
    types, or reduced modulo 2**elt_width for integer types (so narrow integer
    types can repeat values across inputs). Inputs are named %inp0, %inp1, ...
    """
    lanes = ty.elt_num
    result = []
    for idx in range(num):
        first = idx * lanes
        numbers = range(first, first + lanes)
        if ty.is_float:
            contents = [float(x) for x in numbers]
        else:
            contents = [x % (1 << ty.elt_width) for x in numbers]
        result.append(Value("%inp" + str(idx), ty, contents))
    return result
def get_random_type(ty, num_elts):
    """Return the vector Type to test, randomizing whatever was not specified.

    Args:
        ty: element type name ("i8", "i16", "i32", "i64", "f32" or "f64"), or
            None to pick a random integer or floating-point element type.
        num_elts: number of vector lanes, or None to pick one in [2, 64].

    Returns:
        A Type with the chosen element kind, width and lane count.

    Raises:
        ValueError: if ``ty`` is not one of the supported names.
    """
    # (is_float, element width in bits) for every supported type name.
    known_types = {
        "i8": (False, 8),
        "i16": (False, 16),
        "i32": (False, 32),
        "i64": (False, 64),
        "f32": (True, 32),
        "f64": (True, 64),
    }
    if ty is not None:
        if ty not in known_types:
            raise ValueError("Unsupported element type: %r" % (ty,))
        is_float, width = known_types[ty]

    int_elt_widths = [8, 16, 32, 64]
    float_elt_widths = [32, 64]

    # Keep the RNG calls in exactly this order so that a given --seed keeps
    # reproducing the same test.
    if num_elts is None:
        num_elts = random.choice(range(2, 65))

    if ty is None:
        # 1 for integer type, 0 for floating-point
        if random.randint(0, 1):
            is_float = False
            width = random.choice(int_elt_widths)
        else:
            is_float = True
            width = random.choice(float_elt_widths)

    return Type(is_float, width, num_elts)
def gen_shuf_mask(ty):
    """Return a random shufflevector mask for two operands of type ``ty``.

    Each entry is an index into the concatenation of both operands
    (0 .. 2 * elt_num - 1) or -1 for undef. Every entry is undef with
    probability SHUF_UNDEF_POS / elt_num, so a mask holds SHUF_UNDEF_POS undef
    entries on average.
    """
    width = ty.elt_num
    return [
        -1
        if SHUF_UNDEF_POS / width > random.random()
        else random.randint(0, 2 * width - 1)
        for _ in range(width)
    ]
def gen_sel_mask(ty):
    """Return a random select condition mask for type ``ty``.

    Entries are 1 (take op0), 0 (take op1) or -1 (undef). Every entry is undef
    with probability SEL_UNDEF_POS / elt_num.
    """
    lanes = ty.elt_num
    mask = []
    for _ in range(lanes):
        is_undef = SEL_UNDEF_POS / lanes > random.random()
        mask.append(-1 if is_undef else random.randint(0, 1))
    return mask
# Generate shuffle instructions with optional select instruction after.
def gen_insts(inputs, ty):
    """Reduce ``inputs`` to a single Value through random shuffles/selects.

    Each iteration combines two available Values with a shufflevector and
    sometimes feeds the result through a select. The new Value replaces the
    two operands in ``inputs``.

    Args:
        inputs: list of available Values; mutated in place. While it holds at
            least one Value it ends up with exactly one, the final result.
        ty: the vector Type shared by all Values.

    Returns:
        The list of generated instructions in emission order.

    The order of the random.* calls below (directly and via gen_shuf_mask /
    gen_sel_mask) determines which test a given seed produces, so statements
    here should not be reordered.
    """
    int_zero_init = Value("zeroinitializer", ty, [0] * ty.elt_num)
    float_zero_init = Value("zeroinitializer", ty, [0.0] * ty.elt_num)

    insts = []
    name_idx = 0
    while len(inputs) > 1:
        # Choose 2 available Values - remove them from inputs list.
        [idx0, idx1] = sorted(random.sample(range(len(inputs)), 2))
        op0 = inputs[idx0]
        op1 = inputs[idx1]

        # Create the shuffle instruction.
        shuf_mask = gen_shuf_mask(ty)
        shuf_inst = ShufInstr(str(name_idx), ty, op0, op1, shuf_mask)
        shuf_inst.calc_value()

        # Add the new shuffle instruction to the list of instructions.
        insts.append(shuf_inst)

        # Optionally, add select instruction with the result of the previous shuffle.
        if random.random() < ADD_SEL_POS:
            # Either blending with a random Value or with an all-zero vector.
            # The random Value is drawn before op0/op1 are removed, so it may
            # be one of them.
            if random.random() < MERGE_SEL_POS:
                op2 = random.choice(inputs)
            else:
                op2 = float_zero_init if ty.is_float else int_zero_init

            select_mask = gen_sel_mask(ty)
            # Shares name_idx with the shuffle; the %shuf/%sel prefixes keep
            # the IR names distinct.
            select_inst = SelectInstr(str(name_idx), ty, shuf_inst, op2, select_mask)
            select_inst.calc_value()

            # Add the select instructions to the list of instructions and to the available Values.
            insts.append(select_inst)
            inputs.append(select_inst)
        else:
            # If the shuffle instruction is not followed by select, add it to the available Values.
            inputs.append(shuf_inst)

        # The new Value was appended at the end, and idx0 < idx1 (sorted), so
        # deleting the higher index first keeps idx0 pointing at op0.
        del inputs[idx1]
        del inputs[idx0]
        name_idx += 1

    return insts
351 def main():
352 parser = argparse.ArgumentParser(description=__doc__)
353 parser.add_argument(
354 "--seed", default=str(uuid.uuid4()), help="A string used to seed the RNG"
356 parser.add_argument(
357 "--max-num-inputs",
358 type=int,
359 default=20,
360 help="Specify the maximum number of vector inputs for the test. (default: 20)",
362 parser.add_argument(
363 "--min-num-inputs",
364 type=int,
365 default=10,
366 help="Specify the minimum number of vector inputs for the test. (default: 10)",
368 parser.add_argument(
369 "--type",
370 default=None,
371 help="""
372 Choose specific type to be tested.
373 i8, i16, i32, i64, f32 or f64.
374 (default: random)""",
376 parser.add_argument(
377 "--num-elts",
378 default=None,
379 type=int,
380 help="Choose specific number of vector elements to be tested. (default: random)",
382 args = parser.parse_args()
384 print("; The seed used for this test is " + args.seed)
386 assert (
387 args.min_num_inputs < args.max_num_inputs
388 ), "Minimum value greater than maximum."
389 assert args.type in [None, "i8", "i16", "i32", "i64", "f32", "f64"], "Illegal type."
390 assert (
391 args.num_elts == None or args.num_elts > 0
392 ), "num_elts must be a positive integer."
394 random.seed(args.seed)
395 ty = get_random_type(args.type, args.num_elts)
396 inputs = gen_inputs(ty, random.randint(args.min_num_inputs, args.max_num_inputs))
397 inputs_str = (", ").join([inp.ty.dump() + " " + inp.name for inp in inputs])
398 inputs_values = [inp.value for inp in inputs]
400 insts = gen_insts(inputs, ty)
402 assert len(inputs) == 1, "Only one value should be left after generating phase"
403 res = inputs[0]
405 # print the actual test function by dumping the generated instructions.
406 insts_str = "".join([inst.dump() for inst in insts])
407 print(
408 test_template.format(
409 ty=ty.dump(), inputs=inputs_str, instructions=insts_str, last_name=res.name
413 # Print the error message templates as global strings
414 for i in range(len(res.value)):
415 pad = "".join(["\\00"] * (31 - len(str(i)) - len(str(res.value[i]))))
416 print(error_template.format(lane=str(i), exp=str(res.value[i]), padding=pad))
418 # Prepare the runtime checks and failure handlers.
419 scalar_ty = ty.get_scalar_type()
420 check_die = ""
421 i_f = "f" if ty.is_float else "i"
422 ordered = "o" if ty.is_float else ""
423 for i in range(len(res.value)):
424 if res.value[i] != -1:
425 # Emit runtime check for each non-undef expected value.
426 check_die += check_template.format(
427 lane=str(i),
428 n_lane=str(i + 1),
429 ty=ty.dump(),
430 i_f=i_f,
431 scalar_ty=scalar_ty.dump(),
432 exp=str(res.value[i]),
433 ordered=ordered,
435 # Emit failure handler for each runtime check with proper error message
436 check_die += die_template.format(lane=str(i), scalar_ty=scalar_ty.dump())
437 else:
438 # Ignore lanes with undef result
439 check_die += undef_check_template.format(lane=str(i), n_lane=str(i + 1))
441 check_die += "\ntest." + str(len(res.value)) + ":\n"
442 check_die += " ret i32 0"
444 # Prepare the input values passed to the test function.
445 inputs_values = [
446 ", ".join([scalar_ty.dump() + " " + str(i) for i in inp])
447 for inp in inputs_values
449 inputs = ", ".join([ty.dump() + " <" + inp + ">" for inp in inputs_values])
451 print(main_template.format(ty=ty.dump(), inputs=inputs, check_die=check_die))
454 if __name__ == "__main__":
455 main()