# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them. This is the test generator only. The
# test scripts themselves are in wmma-ptx*-sm*.py files.

# RUN: true

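# Usage sketch (illustrative only; the authoritative RUN lines live in the
# wmma-ptx*-sm*.py driver scripts): a driver is expected to invoke this
# generator for one PTX/SM combination and feed the result to llc/FileCheck,
# roughly:
#   %python %s --ptx=60 --gpu-arch=70 > %t.ll
#   llc < %t.ll -mtriple=nvptx64 -mcpu=sm_70 -mattr=+ptx60 | FileCheck %t.ll
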
from __future__ import print_function

import argparse
from itertools import product
from string import Template


class MMAType:
    def __init__(self, ptx_type):
        self.ptx_type = ptx_type
        self.llvm_type = {
            "f16": "<2 x half>",
            "f32": "float",
            "f64": "double",
            "s32": "i32",
            "b16": "i32",
            "s8": "i32",
            "u8": "i32",
            "s4": "i32",
            "u4": "i32",
            "b1": "i32",
            "bf16": "i32",
            "tf32": "i32",
        }[ptx_type]

        self.ptx_reg_pattern = {
            "f16": "%r[0-9]+",
            "f32": "%f[0-9]+",
            "f64": "%fd[0-9]+",
        }.get(ptx_type, "%r[0-9]+")

    def __repr__(self):
        return "%s/%s" % (self.ptx_type, self.llvm_type)


class MMAFrag:
    def __init__(self, geom, frag, ptx_elt_type):
        self.geom = geom
        self.frag = frag
        self.mma_type = MMAType(ptx_elt_type)
        self.nregs = {
            # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
            "m16n16k16:a:u8": 2,
            "m16n16k16:a:s8": 2,
            "m16n16k16:b:u8": 2,
            "m16n16k16:b:s8": 2,
            "m16n16k16:c:s32": 8,
            "m16n16k16:d:s32": 8,
            "m8n32k16:a:u8": 1,
            "m8n32k16:a:s8": 1,
            "m8n32k16:b:u8": 4,
            "m8n32k16:b:s8": 4,
            "m8n32k16:c:s32": 8,
            "m8n32k16:d:s32": 8,
            "m32n8k16:a:u8": 4,
            "m32n8k16:a:s8": 4,
            "m32n8k16:b:u8": 1,
            "m32n8k16:b:s8": 1,
            "m32n8k16:c:s32": 8,
            "m32n8k16:d:s32": 8,
            "m8n8k16:a:u8": 1,
            "m8n8k16:a:s8": 1,
            "m8n8k16:b:u8": 1,
            "m8n8k16:b:s8": 1,
            "m8n8k16:c:s32": 2,
            "m8n8k16:d:s32": 2,
            "m16n8k16:a:u8": 2,
            "m16n8k16:a:s8": 2,
            "m16n8k16:b:u8": 1,
            "m16n8k16:b:s8": 1,
            "m16n8k16:c:s32": 4,
            "m16n8k16:d:s32": 4,
            "m16n8k32:a:u8": 4,
            "m16n8k32:a:s8": 4,
            "m16n8k32:b:u8": 2,
            "m16n8k32:b:s8": 2,
            "m16n8k32:c:s32": 4,
            "m16n8k32:d:s32": 4,
            # u4/s4 -> s32 @ m8n8k32 (u4/s4)
            "m8n8k32:a:u4": 1,
            "m8n8k32:a:s4": 1,
            "m8n8k32:b:u4": 1,
            "m8n8k32:b:s4": 1,
            "m8n8k32:c:s32": 2,
            "m8n8k32:d:s32": 2,
            "m16n8k32:a:u4": 2,
            "m16n8k32:a:s4": 2,
            "m16n8k32:b:u4": 1,
            "m16n8k32:b:s4": 1,
            "m16n8k32:c:s32": 4,
            "m16n8k32:d:s32": 4,
            "m16n8k64:a:u4": 4,
            "m16n8k64:a:s4": 4,
            "m16n8k64:b:u4": 2,
            "m16n8k64:b:s4": 2,
            "m16n8k64:c:s32": 4,
            "m16n8k64:d:s32": 4,
            # b1 -> s32 @ m8n8k128(b1)
            "m8n8k128:a:b1": 1,
            "m8n8k128:b:b1": 1,
            "m8n8k128:c:s32": 2,
            "m8n8k128:d:s32": 2,
            "m16n8k128:a:b1": 2,
            "m16n8k128:b:b1": 1,
            "m16n8k128:c:s32": 4,
            "m16n8k128:d:s32": 4,
            "m16n8k256:a:b1": 4,
            "m16n8k256:b:b1": 2,
            "m16n8k256:c:s32": 4,
            "m16n8k256:d:s32": 4,
            # bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
            "m16n16k16:a:bf16": 4,
            "m16n16k16:b:bf16": 4,
            "m8n32k16:a:bf16": 2,
            "m8n32k16:b:bf16": 8,
            "m32n8k16:a:bf16": 8,
            "m32n8k16:b:bf16": 2,
            "m16n8k16:a:bf16": 4,
            "m16n8k16:b:bf16": 2,
            "m16n8k16:c:f32": 4,
            "m16n8k16:d:f32": 4,
            "m16n8k8:a:bf16": 2,
            "m16n8k8:b:bf16": 1,
            "m16n8k8:c:f32": 4,
            "m16n8k8:d:f32": 4,
            "m8n8k4:a:f64": 1,
            "m8n8k4:b:f64": 1,
            "m8n8k4:c:f64": 2,
            "m8n8k4:d:f64": 2,
            # tf32 -> s32 @ m16n16k8
            "m16n16k8:a:tf32": 4,
            "m16n16k8:b:tf32": 4,
            "m16n8k4:a:tf32": 2,
            "m16n8k4:b:tf32": 1,
            "m16n8k4:c:f32": 4,
            "m16n8k4:d:f32": 4,
            "m16n8k8:a:tf32": 4,
            "m16n8k8:b:tf32": 2,
            "m16n8k8:c:f32": 4,
            "m16n8k8:d:f32": 4,
            "m8n8k4:a:f16": 2,
            "m8n8k4:b:f16": 2,
            "m16n8k8:a:f16": 2,
            "m16n8k8:b:f16": 1,
            "m16n8k8:c:f16": 2,
            "m16n8k8:d:f16": 2,
            "m16n8k8:c:f32": 4,
            "m16n8k8:d:f32": 4,
            "m16n8k16:a:f16": 4,
            "m16n8k16:b:f16": 2,
            "m16n8k16:c:f16": 2,
            "m16n8k16:d:f16": 2,
            "m16n8k16:c:f32": 4,
            "m16n8k16:d:f32": 4,
            # ldmatrix
            "m8n8:x1:b16": 1,
            "m8n8:x2:b16": 2,
            "m8n8:x4:b16": 4,
        }.get(
            "%s:%s:%s" % (geom, frag, ptx_elt_type),
            {
                # All other FP shape/fragment/type combinations have the same size
                "a:f16": 8,
                "b:f16": 8,
                "c:f16": 4,
                "d:f16": 4,
                "c:f32": 8,
                "d:f32": 8,
            }.get("%s:%s" % (frag, ptx_elt_type), None),
        )
        assert self.nregs

    def __repr__(self):
        return "%s:%s:%s%s" % (
            self.geom,
            self.frag,
            self.mma_type,
            "" if self.nregs == 1 else ("*%d" % self.nregs),
        )


class MMAOp:
    def __init__(self, a, b, c, d):
        self.a = a
        self.b = b
        self.c = c
        self.d = d

    def __repr__(self):
        return "{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d)


def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
    ops = []
    for geom, type_a, type_c in product(geoms, types_a, types_c):
        for type_b, type_d in product(
            types_b if types_b else [type_a], types_d if types_d else [type_c]
        ):
            ops.append(
                MMAOp(
                    MMAFrag(geom, "a", type_a),
                    MMAFrag(geom, "b", type_b),
                    MMAFrag(geom, "c", type_c),
                    MMAFrag(geom, "d", type_d),
                )
            )
    return ops


def make_ldst_ops(geoms, frags, types):
    return [
        MMAFrag(geom, frag, ptx_type)
        for (geom, frag, ptx_type) in product(geoms, frags, types)
    ]


def make_ldmatrix_ops(geoms, frags, types):
    return [
        MMAFrag(geom, frag, ptx_type)
        for (geom, frag, ptx_type) in product(geoms, frags, types)
    ]


def get_wmma_ops():
    return (
        make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
        + make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], ["bf16"], [], ["f32"], [])
        + make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
        + make_mma_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"],
            ["f16"],
            [],
            ["f16", "f32"],
            ["f16", "f32"],
        )
        + make_mma_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"], ["s8", "u8"], [], ["s32"], []
        )
        + make_mma_ops(["m8n8k32"], ["s4", "u4"], [], ["s32"], [])
        + make_mma_ops(["m8n8k128"], ["b1"], [], ["s32"], [])
    )


def get_mma_ops():
    return (
        make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
        + make_mma_ops(["m16n8k4", "m16n8k8"], ["tf32"], [], ["f32"], [])
        + make_mma_ops(["m16n8k16", "m16n8k8"], ["bf16"], [], ["f32"], [])
        + make_mma_ops(
            ["m8n8k4", "m16n8k8", "m16n8k16"],
            ["f16"],
            [],
            ["f16", "f32"],
            ["f16", "f32"],
        )
        + make_mma_ops(
            ["m8n8k16", "m16n8k16", "m16n8k32"], ["s8", "u8"], ["s8", "u8"], ["s32"], []
        )
        + make_mma_ops(
            ["m8n8k32", "m16n8k32", "m16n8k64"], ["s4", "u4"], ["s4", "u4"], ["s32"], []
        )
        + make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"], ["b1"], [], ["s32"], [])
    )


def get_ldst_ops(kind):
    ldst_ops = (
        make_ldst_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"],
            ["a", "b"],
            ["f16", "u8", "s8", "bf16"],
        )
        + make_ldst_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"], ["c", "d"], ["f16", "f32", "s32"]
        )
        + make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4", "u4"])
        + make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"])
        + make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"])
        + make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"])
        + make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"])
        + make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"])
    )
    return [x for x in ldst_ops if (x.frag == "d") == (kind == "store")]


def get_ldmatrix_ops():
    return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])


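# The is_*_supported() predicates below filter variants against the PTX ISA
# version and SM architecture under test; ptx_version and gpu_arch are
# module-level globals set from the command line in main().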
def is_wmma_geom_supported(geom):
    # geometries for FP and ints.
    if geom in ["m8n32k16", "m32n8k16"]:
        return ptx_version >= 61
    # geometries for sub-ints.
    if geom in ["m8n8k32", "m8n8k128"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if geom == "m16n16k16":
        return ptx_version >= 60
    if geom == "m16n8k8":
        return ptx_version >= 65
    if geom in ["m16n16k8", "m8n8k4"]:
        return ptx_version >= 70
    assert False  # Unexpected geometry.


def is_mma_geom_supported(geom):
    # geometries for FP and ints.
    if geom == "m8n8k4":
        return ptx_version >= 64
    if geom in ["m16n8k8", "m8n8k16", "m8n8k32"]:
        return ptx_version >= 65
    if geom in [
        "m16n8k16",
        "m16n8k4",
        "m16n8k32",
        "m16n8k64",
        "m8n8k128",
        "m16n8k128",
        "m16n8k256",
    ]:
        return ptx_version >= 70
    assert False  # Unexpected geometry.


def is_ldmatrix_geom_supported(geom):
    if geom in ["m8n8"]:
        return ptx_version >= 65 and gpu_arch >= 75
    assert False  # Unexpected geometry.


def is_type_supported(ptx_type):
    if ptx_type in ["s8", "u8", "s32"]:
        return ptx_version >= 63 and gpu_arch >= 72
    if ptx_type in ["s4", "u4", "b1"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if ptx_type == "b16":
        return ptx_version >= 65 and gpu_arch >= 75
    if ptx_type in ["bf16", "tf32", "f64"]:
        return ptx_version >= 70
    return ptx_version >= 60 and gpu_arch >= 70


def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
    if not (
        is_type_supported(op.a.mma_type.ptx_type) and is_wmma_geom_supported(op.a.geom)
    ):
        return False

    # rnd is only supported for FP64 WMMA.
    if rnd and op.a.mma_type.ptx_type != "f64":
        return False

    if satf:
        # satfinite for floating-point types was removed in PTX 6.5.
        if op.a.mma_type.ptx_type == "f16" and ptx_version >= 65:
            return False
        if op.a.mma_type.ptx_type not in ["f16", "s8", "u8", "s4", "u4"]:
            return False

    # Sub-integer types require row/col layout.
    if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
        return layout_a == "row" and layout_b == "col"
    return True


def is_mma_variant_supported(op, layout_a, layout_b, satf):
    if not (
        is_type_supported(op.a.mma_type.ptx_type) and is_mma_geom_supported(op.a.geom)
    ):
        return False

    if satf and op.a.mma_type.ptx_type not in ["s8", "u8", "s4", "u4"]:
        return False

    # If the type of C is f32, then the type of D must be f32 as well.
    if (
        op.a.geom == "m8n8k4"
        and op.c.mma_type.ptx_type == "f32"
        and op.d.mma_type.ptx_type != "f32"
    ):
        return False

    # A and B must have the same type; C and D must have the same type.
    if op.a.geom == "m16n8k8" and (
        op.a.mma_type.ptx_type != op.b.mma_type.ptx_type
        or op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
    ):
        return False

    # C and D must have the same type.
    if op.a.geom == "m16n8k16" and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type:
        return False

    # Require row/col layout for all MMA variants except m8n8k4 on FP16.
    if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
        return layout_a == "row" and layout_b == "col"
    return True


def is_ldst_variant_supported(frag, layout):
    if not (
        is_type_supported(frag.mma_type.ptx_type) and is_wmma_geom_supported(frag.geom)
    ):
        return False
    if frag.mma_type.ptx_type in ["s4", "u4", "b1"]:
        # Sub-integer types require sm_75 and PTX 6.3, and row/col layout for a/b.
        return (
            (frag.frag == "a" and layout == "row")
            or (frag.frag == "b" and layout == "col")
            or frag.frag in ["c", "d"]
        )
    return True


def is_ldmatrix_variant_supported(frag):
    if not (
        is_type_supported(frag.mma_type.ptx_type)
        and is_ldmatrix_geom_supported(frag.geom)
    ):
        return False
    return frag.frag in ["x1", "x2", "x4"]


def make_wmma_slice_ty(frag):
    return [frag.mma_type.llvm_type] * frag.nregs


def make_wmma_ld_ret_ty(frag):
    results = make_wmma_slice_ty(frag)
    if len(results) == 1:
        return "%s" % results[0]
    return "{%s}" % ", ".join(results)


# Returns the LLVM address space number for a PTX state space suffix.
def get_aspace(space):
    space_map = {
        ".global": 1,
        ".shared": 3,
        ".const": 4,
        ".local": 5,
        ".param": 101,
        "": 0,
        ".generic": 0,
    }
    return space_map[space]


def get_pspace(space):
    return "p%di8" % get_aspace(space)


def check_pattern(frag):
    return "{{%s}}" % ", *".join([frag.mma_type.ptx_reg_pattern] * frag.nregs)


def gen_wmma_load_tests():
    load_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
  ret ${ret_ty} %v0;
}
"""
    intrinsic_template = (
        "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
    )
    instruction_template = (
        "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
    )

    generated_items = []

    for frag, layout, space, stride in product(
        get_ldst_ops("load"),
        ["row", "col"],
        ["", ".shared", ".global"],
        ["", ".stride"],
    ):
        if not is_ldst_variant_supported(frag, layout):
            continue

        params = {
            "abc": frag.frag,
            "aligned": ".aligned" if ptx_version >= 63 else "",
            "layout": layout,
            "space": space,
            "stride": stride,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }

        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_result"] = check_pattern(frag)

        if stride:
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}}"
        else:
            test_params["extra_args"] = ""
            test_params["stride_pattern"] = ""

        print(Template(load_template).substitute(test_params))

        generated_items.append((test_params["intrinsic"], test_params["instruction"]))

    return generated_items


def make_wmma_slice_args(frag):
    return ", ".join(
        [
            "%s %%%s%d" % (t, frag.frag, i)
            for i, t in enumerate(make_wmma_slice_ty(frag))
        ]
    )


def gen_wmma_store_tests():
    store_template = """
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
  ret void
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
  ret void
}
"""
    intrinsic_template = (
        "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
    )
    instruction_template = (
        "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
    )

    generated_items = []

    for frag, layout, space, stride in product(
        get_ldst_ops("store"),
        ["row", "col"],
        ["", ".shared", ".global"],
        ["", ".stride"],
    ):

        if not is_ldst_variant_supported(frag, layout):
            continue

        params = {
            "abc": frag.frag,
            "aligned": ".aligned" if ptx_version >= 63 else "",
            "layout": layout,
            "space": space,
            "stride": stride,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }

        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_args"] = check_pattern(frag)
        if stride:
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}};"
        else:
            test_params["extra_args"] = ""
            test_params["stride_pattern"] = ";"
        test_params["args"] = make_wmma_slice_args(frag)

        print(Template(store_template).substitute(test_params))
        generated_items.append((test_params["intrinsic"], test_params["instruction"]))

    return generated_items


def gen_ldmatrix_tests():
    ldmatrix_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src);

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src);
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1);
  ret ${ret_ty} %v0;
}
"""
    intrinsic_template = (
        "llvm.nvvm.ldmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
    )
    instruction_template = (
        "ldmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
    )

    generated_items = []

    for frag, space, trans in product(
        get_ldmatrix_ops(),
        ["", ".shared"],
        ["", ".trans"],
    ):
        if not is_ldmatrix_variant_supported(frag):
            continue

        params = {
            "frag": frag.frag,
            "space": space,
            "trans": trans,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }

        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_result"] = check_pattern(frag)

        print(Template(ldmatrix_template).substitute(test_params))

        generated_items.append((test_params["intrinsic"], test_params["instruction"]))

    return generated_items


def mma_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 ops identified by accumulator & result type.
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    elif op.a.mma_type.ptx_type != op.b.mma_type.ptx_type:
        # other ops are identified by input types.
        return "%s.%s" % (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
    else:
        # if input types are the same, it only appears once.
        return op.a.mma_type.ptx_type


def mma_ptx_signature(op):
    # Encode all four types as D.A.B.C
    return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))


def wmma_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 ops identified by accumulator & result type.
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    else:
        # other ops are identified by input type.
        return op.a.mma_type.ptx_type


def wmma_ptx_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 instructions use D.C
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    else:
        # other instructions encode all four types as D.A.B.C
        return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))


def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
    mma_template = """
declare ${ret_ty} @${intrinsic}(
        ${args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
        ${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
; CHECK-NEXT: ${check_a}
; CHECK-NEXT: ${check_b}
; CHECK-NEXT: ${check_c}
  %r = call ${ret_ty} @${intrinsic}(
        ${args});
  ret ${ret_ty} %r;
}
"""
    test_params = params
    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    test_params["function"] = test_params["intrinsic"].replace(".", "_")
    test_params["instruction"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
    test_params["check_a"] = check_pattern(op.a)
    test_params["check_b"] = check_pattern(op.b)
    test_params["check_c"] = check_pattern(op.c)
    test_params["check_d"] = check_pattern(op.d)
    args = ",\n        ".join(make_wmma_slice_args(frag) for frag in (op.a, op.b, op.c))
    test_params["args"] = args
    print(Template(mma_template).substitute(test_params))
    return (test_params["intrinsic"], test_params["instruction"])


def get_b1_ops(ptx_type):
    if ptx_type != "b1":
        return [""]
    if ptx_version >= 71:
        return [".xor.popc", ".and.popc"]
    return [".xor.popc"]


def gen_wmma_mma_tests():
    wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
    wmma_instruction_template = "wmma.mma${b1op}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"

    generated_items = []

    for op, alayout, blayout, rnd, satf in product(
        get_wmma_ops(),
        ["row", "col"],
        ["row", "col"],
        [".rn", ".rz", ".rm", ".rp", ""],
        [".satfinite", ""],
    ):

        if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
            continue

        for b1op in get_b1_ops(op.a.mma_type.ptx_type):
            params = {
                "aligned": ".aligned" if ptx_version >= 63 else "",
                "alayout": alayout,
                "blayout": blayout,
                "intrinsic_signature": wmma_signature(op),
                "ptx_signature": wmma_ptx_signature(op),
                "satf": satf,
                "rnd": rnd,
                "geom": op.a.geom,
                "b1op": b1op,
            }

            intrinsic_template = wmma_intrinsic_template
            instruction_template = wmma_instruction_template

            generated_items.append(
                common_mma_test_gen(
                    params, op, intrinsic_template, instruction_template
                )
            )

    return generated_items


def gen_mma_tests():
    mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
    mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"

    generated_items = []

    for op, alayout, blayout, satf in product(
        get_mma_ops(), ["row", "col"], ["row", "col"], [".satfinite", ""]
    ):

        if not is_mma_variant_supported(op, alayout, blayout, satf):
            continue

        for b1op in get_b1_ops(op.a.mma_type.ptx_type):
            params = {
                "aligned": ".aligned" if ptx_version >= 63 else "",
                "alayout": alayout,
                "blayout": blayout,
                "intrinsic_signature": mma_signature(op),
                "ptx_signature": mma_ptx_signature(op),
                "satf": satf,
                "geom": op.a.geom,
                "b1op": b1op,
            }

            intrinsic_template = mma_intrinsic_template
            instruction_template = mma_instruction_template

            generated_items.append(
                common_mma_test_gen(
                    params, op, intrinsic_template, instruction_template
                )
            )

    return generated_items


# Append the complete list of intrinsics and instructions we've generated tests for.
# Generate a set of checks to verify that we did generate a sensible set of
# tests for the given combination of PTX and SM variants.
def gen_check_unsupported_ops(items):
    print(
        "; Complete list of intrinsics supported by PTX%d on sm_%d"
        % (ptx_version, gpu_arch)
    )
    print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}")
    print(
        """

; NOEXTGEOM-NOT: {{m8n32|m32n8}}
; NOINT-NOT: .{{s32|s8}}
; NOSUBINT-NOT: {{s4|u4|b1}}
; NOMMA-NOT: .m8n8k4.
; NOALTFLOAT-NOT: .{{bf16|tf32}}
; NODOUBLE-NOT: .f64
; NOLDMATRIX-NOT: ldmatrix.sync.aligned

; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\\.(f16|f32)}}.p
; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f32
; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f16
; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f16
; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f32

; PTX60 adds support for m32n8k16/m8n32k16 geometries.
; EXTGEOM-DAG: m32n8k16.load.{{[ab].*}}.f16.p
; EXTGEOM-DAG: m32n8k16.{{load|store}}.{{[cd].*\\.(f16|f32)}}.p
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f32
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f16
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f16
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f32

; EXTGEOM-DAG: m8n32k16.load.{{[ab].*}}.f16.p
; EXTGEOM-DAG: m8n32k16.{{load|store}}.{{[cd].*\\.(f16|f32)}}.p
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f32
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f16
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f16
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f32

; INT-DAG: m16n16k16.load.{{[ab].*}}.s8.p
; INT-DAG: m8n32k16.load.{{[ab].*}}.s8.p
; INT-DAG: m32n8k16.load.{{[ab].*}}.s8.p
; INT-DAG: m16n16k16.load.{{[ab].*}}.u8.p
; INT-DAG: m8n32k16.load.{{[ab].*}}.u8.p
; INT-DAG: m32n8k16.load.{{[ab].*}}.u8.p
; INT-DAG: m32n8k16.{{load|store}}.{{[cd].*\\.s32}}.p
; INT-DAG: m16n16k16.mma.{{.*}}.u8
; INT-DAG: m16n16k16.mma.{{.*}}.s8
; INT-DAG: m8n32k16.mma.{{.*}}.u8
; INT-DAG: m8n32k16.mma.{{.*}}.s8
; INT-DAG: m32n8k16.mma.{{.*}}.u8
; INT-DAG: m32n8k16.mma.{{.*}}.s8

; SUBINT-DAG: m8n8k128.load.{{[ab].*}}.b1.p
; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.s4.p
; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.u4.p
; SUBINT-DAG: m8n8k128.{{load|store}}.{{[cd].*\\.s32}}.p
; SUBINT-DAG: m8n8k32.{{load|store}}.{{[cd].*\\.s32}}.p
; SUBINT-DAG: m8n8k32.mma.{{.*}}.u4
; SUBINT-DAG: m8n8k32.mma.{{.*}}.s4
; SUBINT-DAG: m8n8k128.mma.{{.*}}.b1

; ALTFLOAT-DAG: m16n16k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m8n32k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m32n8k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m16n16k8.load.{{[ab].*}}.tf32.p
; ALTFLOAT-DAG: m16n16k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m8n32k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m32n8k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m16n16k8.mma.{{.*}}.tf32

; DOUBLE-DAG: m8n8k4.load.{{[abc].*}}.f64.p
; DOUBLE-DAG: m8n8k4.store.d.{{.*}}.f64.p
; DOUBLE-DAG: m8n8k4.mma.{{.*}}.f64

; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f32
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f32

; PTX65MMA-DAG: mma.m16n8k8.row.col.f16.f16
; PTX65MMA-DAG: mma.m16n8k8.row.col.f32.f32
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.s8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.s8
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.s4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4

; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16

; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k16.row.col.bf16
; PTX71MMA-DAG: mma.m16n8k8.row.col.bf16
; PTX71MMA-DAG: mma.m16n8k16.row.col.f16.f16
; PTX71MMA-DAG: mma.m16n8k16.row.col.f32.f32
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
; PTX71MMA-DAG: mma.and.popc.m8n8k128.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m8n8k128.row.col.b1
; PTX71MMA-DAG: mma.and.popc.m16n8k128.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m16n8k128.row.col.b1
; PTX71MMA-DAG: mma.and.popc.m16n8k256.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m16n8k256.row.col.b1
;

"""
    )

    print("; INTRINSICS_LIST_BEGIN")
    for intrinsic, instruction in sorted(items):
        print("; ", intrinsic, " -> ", instruction, "")
    print("; INTRINSICS_LIST_END")
    print("; INTRINSICS: ; INTRINSICS_LIST_END")


def gen_tests():
    items = gen_wmma_load_tests()
    items += gen_wmma_store_tests()
    items += gen_ldmatrix_tests()
    items += gen_wmma_mma_tests()
    items += gen_mma_tests()
    gen_check_unsupported_ops(items)


def main():
    global ptx_version
    global gpu_arch
    parser = argparse.ArgumentParser()
    parser.add_argument("--ptx", type=int, default=60)
    parser.add_argument("--gpu-arch", type=int, default=70)
    args = parser.parse_args()

    ptx_version = args.ptx
    gpu_arch = args.gpu_arch

    gen_tests()


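# With no arguments the generator targets PTX 6.0 / sm_70 (the argparse
# defaults above).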
if __name__ == "__main__":
    main()