# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them. This is the test generator only; the
# test scripts themselves are in wmma-ptx*-sm*.py files.
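# A driver script is expected to invoke this generator for one PTX/SM pair and
# feed the output to llc/FileCheck. An assumed invocation, for illustration only:
#   python wmma.py --ptx=63 --gpu-arch=75 > wmma-ptx63-sm75.ll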

from __future__ import print_function

import argparse
from itertools import product
from string import Template


class MMAType:
    def __init__(self, ptx_type):
        self.ptx_type = ptx_type

        self.ptx_reg_pattern = {
        }.get(ptx_type, "%r[0-9]+")

    def __repr__(self):
        return "%s/%s" % (self.ptx_type, self.llvm_type)


class MMAFrag:
    def __init__(self, geom, frag, ptx_elt_type):
        self.geom = geom
        self.frag = frag
        self.mma_type = MMAType(ptx_elt_type)
        self.nregs = {
            # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
            # u4/s4 -> s32 @ m8n8k32 (u4/s4)
            # b1 -> s32 @ m8n8k128(b1)
            "m16n8k128:c:s32": 4,
            "m16n8k128:d:s32": 4,
            "m16n8k256:c:s32": 4,
            "m16n8k256:d:s32": 4,
            # bf16 -> f32 @ m16n16k16/m8n32k16/m32n8k16
            "m16n16k16:a:bf16": 4,
            "m16n16k16:b:bf16": 4,
            "m8n32k16:a:bf16": 2,
            "m8n32k16:b:bf16": 8,
            "m32n8k16:a:bf16": 8,
            "m32n8k16:b:bf16": 2,
            "m16n8k16:a:bf16": 4,
            "m16n8k16:b:bf16": 2,
            # tf32 -> f32 @ m16n16k8
            "m16n16k8:a:tf32": 4,
            "m16n16k8:b:tf32": 4,
        }.get(
            "%s:%s:%s" % (geom, frag, ptx_elt_type),
            # All other FP shape/fragment/type combinations have the same size
            {
            }.get("%s:%s" % (frag, ptx_elt_type), None),
        )

    def __repr__(self):
        return "%s:%s:%s%s" % (
            self.geom,
            self.frag,
            self.mma_type,
            "" if self.nregs == 1 else ("*%d" % self.nregs),
        )


class MMAOp:
    def __init__(self, a, b, c, d):
        self.a = a
        self.b = b
        self.c = c
        self.d = d

    def __repr__(self):
        return "{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d)


def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
    ops = []
    for geom, type_a, type_c in product(geoms, types_a, types_c):
        for type_b, type_d in product(
            types_b if types_b else [type_a], types_d if types_d else [type_c]
        ):
            ops.append(
                MMAOp(
                    MMAFrag(geom, "a", type_a),
                    MMAFrag(geom, "b", type_b),
                    MMAFrag(geom, "c", type_c),
                    MMAFrag(geom, "d", type_d),
                )
            )
    return ops


def make_ldst_ops(geoms, frags, types):
    return [
        MMAFrag(geom, frag, ptx_type)
        for (geom, frag, ptx_type) in product(geoms, frags, types)
    ]


def make_ldmatrix_ops(geoms, frags, types):
    return [
        MMAFrag(geom, frag, ptx_type)
        for (geom, frag, ptx_type) in product(geoms, frags, types)
    ]


def get_wmma_ops():
    return (
        make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
        + make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], ["bf16"], [], ["f32"], [])
        + make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
        + make_mma_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"],
            ["f16"],
            [],
            ["f16", "f32"],
            ["f16", "f32"],
        )
        + make_mma_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"], ["s8", "u8"], [], ["s32"], []
        )
        + make_mma_ops(["m8n8k32"], ["s4", "u4"], [], ["s32"], [])
        + make_mma_ops(["m8n8k128"], ["b1"], [], ["s32"], [])
    )


def get_mma_ops():
    return (
        make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
        + make_mma_ops(["m16n8k4", "m16n8k8"], ["tf32"], [], ["f32"], [])
        + make_mma_ops(["m16n8k16", "m16n8k8"], ["bf16"], [], ["f32"], [])
        + make_mma_ops(
            ["m8n8k4", "m16n8k8", "m16n8k16"],
            ["f16"],
            [],
            ["f16", "f32"],
            ["f16", "f32"],
        )
        + make_mma_ops(
            ["m8n8k16", "m16n8k16", "m16n8k32"], ["s8", "u8"], ["s8", "u8"], ["s32"], []
        )
        + make_mma_ops(
            ["m8n8k32", "m16n8k32", "m16n8k64"], ["s4", "u4"], ["s4", "u4"], ["s32"], []
        )
        + make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"], ["b1"], [], ["s32"], [])
    )


def get_ldst_ops(kind):
    ldst_ops = (
        make_ldst_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"],
            ["a", "b"],
            ["f16", "u8", "s8", "bf16"],
        )
        + make_ldst_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"], ["c", "d"], ["f16", "f32", "s32"]
        )
        + make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4", "u4"])
        + make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"])
        + make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"])
        + make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"])
        + make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"])
        + make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"])
    )
    return [x for x in ldst_ops if (x.frag == "d") == (kind == "store")]


def get_ldmatrix_ops():
    return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])


def is_wmma_geom_supported(geom):
    # geometries for FP and ints.
    if geom in ["m8n32k16", "m32n8k16"]:
        return ptx_version >= 61
    # geometries for sub-ints.
    if geom in ["m8n8k32", "m8n8k128"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if geom == "m16n16k16":
        return ptx_version >= 60
    if geom == "m16n8k8":
        return ptx_version >= 65
    if geom in ["m16n16k8", "m8n8k4"]:
        return ptx_version >= 70
    assert False  # Unexpected geometry.


def is_mma_geom_supported(geom):
    # geometries for FP and ints.
    if geom == "m8n8k4":
        return ptx_version >= 64
    if geom in ["m16n8k8", "m8n8k16", "m8n8k32"]:
        return ptx_version >= 65
    if geom in [
        "m16n8k16",
        "m16n8k4",
        "m16n8k32",
        "m16n8k64",
        "m8n8k128",
        "m16n8k128",
        "m16n8k256",
    ]:
        return ptx_version >= 70
    assert False  # Unexpected geometry.


def is_ldmatrix_geom_supported(geom):
    if geom == "m8n8":
        return ptx_version >= 65 and gpu_arch >= 75
    assert False  # Unexpected geometry.


def is_type_supported(ptx_type):
    if ptx_type in ["s8", "u8", "s32"]:
        return ptx_version >= 63 and gpu_arch >= 72
    if ptx_type in ["s4", "u4", "b1"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if ptx_type == "b16":
        return ptx_version >= 65 and gpu_arch >= 75
    if ptx_type in ["bf16", "tf32", "f64"]:
        return ptx_version >= 70
    return ptx_version >= 60 and gpu_arch >= 70


def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
    if not (
        is_type_supported(op.a.mma_type.ptx_type) and is_wmma_geom_supported(op.a.geom)
    ):
        return False

    # rnd is only supported for FP64 WMMA
    if rnd and op.a.mma_type.ptx_type != "f64":
        return False

    if satf:
        # satfinite for floating points was removed in PTX 6.5
        if op.a.mma_type.ptx_type == "f16" and ptx_version >= 65:
            return False
        if op.a.mma_type.ptx_type not in ["f16", "s8", "u8", "s4", "u4"]:
            return False

    # sub-integers require row/col layout.
    if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
        return layout_a == "row" and layout_b == "col"
    return True


def is_mma_variant_supported(op, layout_a, layout_b, satf):
    if not (
        is_type_supported(op.a.mma_type.ptx_type) and is_mma_geom_supported(op.a.geom)
    ):
        return False

    if satf and op.a.mma_type.ptx_type not in ["s8", "u8", "s4", "u4"]:
        return False

    # If the type of C is f32 then so must the type of D
    if (
        op.a.geom == "m8n8k4"
        and op.c.mma_type.ptx_type == "f32"
        and op.d.mma_type.ptx_type != "f32"
    ):
        return False

    # A and B types must be the same. C and D types must be the same.
    if op.a.geom == "m16n8k8" and (
        op.a.mma_type.ptx_type != op.b.mma_type.ptx_type
        or op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
    ):
        return False

    # C and D types must be the same.
    if op.a.geom == "m16n8k16" and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type:
        return False

    # Require row/col layout for all MMA except m8n8k4 on FP16
    if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
        return layout_a == "row" and layout_b == "col"
    return True


def is_ldst_variant_supported(frag, layout):
    if not (
        is_type_supported(frag.mma_type.ptx_type) and is_wmma_geom_supported(frag.geom)
    ):
        return False
    if frag.mma_type.ptx_type in ["s4", "u4", "b1"]:
        # sub-integers require sm_75 and ptx63; row/col layout for a/b.
        return (
            (frag.frag == "a" and layout == "row")
            or (frag.frag == "b" and layout == "col")
            or frag.frag in ["c", "d"]
        )
    return True


def is_ldmatrix_variant_supported(frag):
    if not (
        is_type_supported(frag.mma_type.ptx_type)
        and is_ldmatrix_geom_supported(frag.geom)
    ):
        return False
    return frag.frag in ["x1", "x2", "x4"]


def make_wmma_slice_ty(frag):
    return [frag.mma_type.llvm_type] * frag.nregs


def make_wmma_ld_ret_ty(frag):
    results = make_wmma_slice_ty(frag)
    if len(results) == 1:
        return "%s" % results[0]
    return "{%s}" % ", ".join(results)
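# For example, a four-register fragment whose llvm_type is "float" (an assumed
# mapping, for illustration) yields "{float, float, float, float}"; a
# single-register fragment yields just its scalar type.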


# returns address space
def get_aspace(space):
    # NVPTX address-space numbers for the spaces used by the generated tests.
    space_map = {
        "": 0,  # generic
        ".global": 1,
        ".shared": 3,
    }
    return space_map[space]


def get_pspace(space):
    return "p%di8" % get_aspace(space)
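# For example, get_pspace(".shared") yields "p3i8" and get_pspace("") yields
# "p0i8", matching the pointer-type suffix used in the intrinsic names below.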


def check_pattern(frag):
    return "{{%s}}" % ", *".join([frag.mma_type.ptx_reg_pattern] * frag.nregs)
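# For example, a fragment with nregs == 2 and the default ptx_reg_pattern
# "%r[0-9]+" expands to the FileCheck pattern "{{%r[0-9]+, *%r[0-9]+}}".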


def gen_wmma_load_tests():
    load_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
  ret ${ret_ty} %v0;
}
"""
    intrinsic_template = (
        "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
    )
    instruction_template = (
        "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
    )

    generated_items = []

    for frag, layout, space, stride in product(
        get_ldst_ops("load"),
        ["row", "col"],
        ["", ".shared", ".global"],
        ["", ".stride"],
    ):
        if not is_ldst_variant_supported(frag, layout):
            continue

        params = {
            "abc": frag.frag,
            "aligned": ".aligned" if ptx_version >= 63 else "",
            "layout": layout,
            "space": space,
            "stride": stride,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }
        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_result"] = check_pattern(frag)

        if stride:
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}}"
        else:
            test_params["extra_args"] = ""
            test_params["stride_pattern"] = ""

        print(Template(load_template).substitute(test_params))

        generated_items.append((test_params["intrinsic"], test_params["instruction"]))

    return generated_items


def make_wmma_slice_args(frag):
    return ", ".join(
        "%s %%%s%d" % (t, frag.frag, i)
        for i, t in enumerate(make_wmma_slice_ty(frag))
    )
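# For example, a two-register "c" fragment whose llvm_type is "float" (an
# assumed mapping, for illustration) becomes the argument list
# "float %c0, float %c1".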


def gen_wmma_store_tests():
    store_template = """
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
  ret void
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
  ret void
}
"""
    intrinsic_template = (
        "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
    )
    instruction_template = (
        "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
    )

    generated_items = []

    for frag, layout, space, stride in product(
        get_ldst_ops("store"),
        ["row", "col"],
        ["", ".shared", ".global"],
        ["", ".stride"],
    ):
        if not is_ldst_variant_supported(frag, layout):
            continue

        params = {
            "abc": frag.frag,
            "aligned": ".aligned" if ptx_version >= 63 else "",
            "layout": layout,
            "space": space,
            "stride": stride,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }
        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_args"] = check_pattern(frag)

        if stride:
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}};"
        else:
            test_params["extra_args"] = ""
            test_params["stride_pattern"] = ";"
        test_params["args"] = make_wmma_slice_args(frag)

        print(Template(store_template).substitute(test_params))
        generated_items.append((test_params["intrinsic"], test_params["instruction"]))

    return generated_items


def gen_ldmatrix_tests():
    ldmatrix_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src);

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src);
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1);
  ret ${ret_ty} %v0;
}
"""
    intrinsic_template = (
        "llvm.nvvm.ldmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
    )
    instruction_template = (
        "ldmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
    )

    generated_items = []

    for frag, space, trans in product(
        get_ldmatrix_ops(),
        ["", ".shared"],
        ["", ".trans"],
    ):
        if not is_ldmatrix_variant_supported(frag):
            continue

        params = {
            "frag": frag.frag,
            "space": space,
            "trans": trans,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }
        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_result"] = check_pattern(frag)

        print(Template(ldmatrix_template).substitute(test_params))

        generated_items.append((test_params["intrinsic"], test_params["instruction"]))

    return generated_items


def mma_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 ops identified by accumulator & result type.
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    elif op.a.mma_type.ptx_type != op.b.mma_type.ptx_type:
        # other ops are identified by input types.
        return "%s.%s" % (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
    else:
        # if input types are the same, the type appears only once.
        return op.a.mma_type.ptx_type


def mma_ptx_signature(op):
    # Encode all four types as D.A.B.C
    return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
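# For example, an op with D/A/B/C types s32/s8/u8/s32 yields "s32.s8.u8.s32"
# here, while mma_signature() above identifies the same op as just "s8.u8".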


def wmma_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 ops identified by accumulator & result type.
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    else:
        # other ops are identified by input type.
        return op.a.mma_type.ptx_type


def wmma_ptx_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 instructions use D.C
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    else:
        # other instructions encode all four types as D.A.B.C
        return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))


def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
    mma_template = """
declare ${ret_ty} @${intrinsic}(
  ${args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
  ${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
; CHECK-NEXT: ${check_a}
; CHECK-NEXT: ${check_b}
; CHECK-NEXT: ${check_c}
  %r = call ${ret_ty} @${intrinsic}(
  ${args});
  ret ${ret_ty} %r;
}
"""

    test_params = params
    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    test_params["function"] = test_params["intrinsic"].replace(".", "_")
    test_params["instruction"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
    test_params["check_a"] = check_pattern(op.a)
    test_params["check_b"] = check_pattern(op.b)
    test_params["check_c"] = check_pattern(op.c)
    test_params["check_d"] = check_pattern(op.d)
    args = ",\n ".join(make_wmma_slice_args(frag) for frag in (op.a, op.b, op.c))
    test_params["args"] = args
    print(Template(mma_template).substitute(test_params))
    return (test_params["intrinsic"], test_params["instruction"])


def get_b1_ops(ptx_type):
    if ptx_type != "b1":
        return [""]
    if ptx_version >= 71:
        return [".xor.popc", ".and.popc"]
    return [".xor.popc"]


def gen_wmma_mma_tests():
    wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
    wmma_instruction_template = "wmma.mma${b1op}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"

    generated_items = []

    for op, alayout, blayout, rnd, satf in product(
        get_wmma_ops(),
        ["row", "col"],
        ["row", "col"],
        [".rn", ".rz", ".rm", ".rp", ""],
        [".satfinite", ""],
    ):
        if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
            continue

        for b1op in get_b1_ops(op.a.mma_type.ptx_type):
            params = {
                "aligned": ".aligned" if ptx_version >= 63 else "",
                "alayout": alayout,
                "blayout": blayout,
                "intrinsic_signature": wmma_signature(op),
                "ptx_signature": wmma_ptx_signature(op),
                "satf": satf,
                "rnd": rnd,
                "geom": op.a.geom,
                "b1op": b1op,
            }

            intrinsic_template = wmma_intrinsic_template
            instruction_template = wmma_instruction_template

            generated_items.append(
                common_mma_test_gen(
                    params, op, intrinsic_template, instruction_template
                )
            )

    return generated_items


def gen_mma_tests():
    mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
    mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"

    generated_items = []

    for op, alayout, blayout, satf in product(
        get_mma_ops(), ["row", "col"], ["row", "col"], [".satfinite", ""]
    ):
        if not is_mma_variant_supported(op, alayout, blayout, satf):
            continue

        for b1op in get_b1_ops(op.a.mma_type.ptx_type):
            params = {
                "aligned": ".aligned" if ptx_version >= 63 else "",
                "alayout": alayout,
                "blayout": blayout,
                "intrinsic_signature": mma_signature(op),
                "ptx_signature": mma_ptx_signature(op),
                "satf": satf,
                "geom": op.a.geom,
                "b1op": b1op,
            }

            intrinsic_template = mma_intrinsic_template
            instruction_template = mma_instruction_template

            generated_items.append(
                common_mma_test_gen(
                    params, op, intrinsic_template, instruction_template
                )
            )

    return generated_items


# Append the complete list of intrinsics and instructions we've generated tests
# for. Generate a set of checks to verify that we did generate a sensible set
# of tests for the given combination of PTX and SM variants.
def gen_check_unsupported_ops(items):
    print(
        "; Complete list of intrinsics supported by PTX%d on sm_%d"
        % (ptx_version, gpu_arch)
    )
    print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}")
    print(
        """

; NOEXTGEOM-NOT: {{m8n32|m32n8}}
; NOINT-NOT: .{{s32|s8}}
; NOSUBINT-NOT: {{s4|u4|b1}}
; NOMMA-NOT: .m8n8k4.
; NOALTFLOAT-NOT: .{{bf16|tf32}}

; NOLDMATRIX-NOT: ldmatrix.sync.aligned

; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f32
; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f16
; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f16
; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f32

; PTX60 adds support for m32n8k16/m8n32k16 geometries.
; EXTGEOM-DAG: m32n8k16.load.{{[ab].*}}.f16.p
; EXTGEOM-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f32
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f16
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f16
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f32

; EXTGEOM-DAG: m8n32k16.load.{{[ab].*}}.f16.p
; EXTGEOM-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f32
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f16
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f16
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f32

; INT-DAG: m16n16k16.load.{{[ab].*}}.s8.p
; INT-DAG: m8n32k16.load.{{[ab].*}}.s8.p
; INT-DAG: m32n8k16.load.{{[ab].*}}.s8.p
; INT-DAG: m16n16k16.load.{{[ab].*}}.u8.p
; INT-DAG: m8n32k16.load.{{[ab].*}}.u8.p
; INT-DAG: m32n8k16.load.{{[ab].*}}.u8.p
; INT-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p
; INT-DAG: m16n16k16.mma.{{.*}}.u8
; INT-DAG: m16n16k16.mma.{{.*}}.s8
; INT-DAG: m8n32k16.mma.{{.*}}.u8
; INT-DAG: m8n32k16.mma.{{.*}}.s8
; INT-DAG: m32n8k16.mma.{{.*}}.u8
; INT-DAG: m32n8k16.mma.{{.*}}.s8

; SUBINT-DAG: m8n8k128.load.{{[ab].*}}.b1.p
; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.s4.p
; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.u4.p
; SUBINT-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p
; SUBINT-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p
; SUBINT-DAG: m8n8k32.mma.{{.*}}.u4
; SUBINT-DAG: m8n8k32.mma.{{.*}}.s4
; SUBINT-DAG: m8n8k128.mma.{{.*}}.b1

; ALTFLOAT-DAG: m16n16k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m8n32k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m32n8k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m16n16k8.load.{{[ab].*}}.tf32.p
; ALTFLOAT-DAG: m16n16k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m8n32k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m32n8k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m16n16k8.mma.{{.*}}.tf32

; DOUBLE-DAG: m8n8k4.load.{{[abc].*}}.f64.p
; DOUBLE-DAG: m8n8k4.store.d.{{.*}}.f64.p
; DOUBLE-DAG: m8n8k4.mma.{{.*}}.f64

; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f32
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f32

; PTX65MMA-DAG: mma.m16n8k8.row.col.f16.f16
; PTX65MMA-DAG: mma.m16n8k8.row.col.f32.f32
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.s8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.s8
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.s4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4

; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16

; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k16.row.col.bf16
; PTX71MMA-DAG: mma.m16n8k8.row.col.bf16
; PTX71MMA-DAG: mma.m16n8k16.row.col.f16.f16
; PTX71MMA-DAG: mma.m16n8k16.row.col.f32.f32
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
; PTX71MMA-DAG: mma.and.popc.m8n8k128.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m8n8k128.row.col.b1
; PTX71MMA-DAG: mma.and.popc.m16n8k128.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m16n8k128.row.col.b1
; PTX71MMA-DAG: mma.and.popc.m16n8k256.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m16n8k256.row.col.b1
"""
    )
    print("; INTRINSICS_LIST_BEGIN")
    for intrinsic, instruction in sorted(items):
        print("; ", intrinsic, " -> ", instruction, "")
    print("; INTRINSICS_LIST_END")
    print("; INTRINSICS: ; INTRINSICS_LIST_END")
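# Each entry of the emitted list pairs an intrinsic with its instruction, e.g.
# (illustrative, assuming --ptx=63 or newer so ".aligned" is present):
#   ;  llvm.nvvm.wmma.m16n16k16.load.a.row.f16.p0i8  ->  wmma.load.a.sync.aligned.row.m16n16k16.f16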


def gen_tests():
    items = gen_wmma_load_tests()
    items += gen_wmma_store_tests()
    items += gen_ldmatrix_tests()
    items += gen_wmma_mma_tests()
    items += gen_mma_tests()
    gen_check_unsupported_ops(items)


def main():
    global ptx_version
    global gpu_arch

    parser = argparse.ArgumentParser()
    parser.add_argument("--ptx", type=int, default=60)
    parser.add_argument("--gpu-arch", type=int, default=70)
    args = parser.parse_args()

    ptx_version = args.ptx
    gpu_arch = args.gpu_arch

    gen_tests()


if __name__ == "__main__":
    main()