# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them.

# Check all variants of instructions supported by PTX60 on SM70
# RUN: %python %s --ptx=60 --gpu-arch=70 > %t-ptx60-sm_70.ll
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
# RUN: | FileCheck %t-ptx60-sm_70.ll
# RUN: %if ptxas-9.0 %{ \
# RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
# RUN: | %ptxas-verify -arch=sm_70 \
# RUN: %}

# Check all variants of instructions supported by PTX61 on SM70
# RUN: %python %s --ptx=61 --gpu-arch=70 > %t-ptx61-sm_70.ll
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
# RUN: | FileCheck %t-ptx61-sm_70.ll
# RUN: %if ptxas-9.1 %{ \
# RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
# RUN: | %ptxas-verify -arch=sm_70 \
# RUN: %}

# Check all variants of instructions supported by PTX63 on SM72
# RUN: %python %s --ptx=63 --gpu-arch=72 > %t-ptx63-sm_72.ll
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN: --check-prefixes=INTRINSICS,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
# RUN: | FileCheck %t-ptx63-sm_72.ll
# RUN: %if ptxas-10.0 %{ \
# RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
# RUN: | %ptxas-verify -arch=sm_72 \
# RUN: %}

# Check all variants of instructions supported by PTX63 on SM75
# RUN: %python %s --ptx=63 --gpu-arch=75 > %t-ptx63-sm_75.ll
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN: --check-prefixes=INTRINSICS,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
# RUN: | FileCheck %t-ptx63-sm_75.ll
# RUN: %if ptxas-10.0 %{ \
# RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
# RUN: | %ptxas-verify -arch=sm_75 \
# RUN: %}

# Check all variants of instructions supported by PTX64 on SM70+
# RUN: %python %s --ptx=64 --gpu-arch=70 > %t-ptx64-sm_70.ll
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,MMA
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx64-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 \
# RUN: | FileCheck %t-ptx64-sm_70.ll
# RUN: %if ptxas-10.1 %{ \
# RUN: llc < %t-ptx64-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 \
# RUN: | %ptxas-verify -arch=sm_70 \
# RUN: %}

# Check all variants of instructions supported by PTX65 on SM75+
# RUN: %python %s --ptx=65 --gpu-arch=75 > %t-ptx65-sm_75.ll
# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,PTX65MMA,PTX65LDMATRIX
# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
# RUN: --check-prefixes=INTRINSICS
# RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
# RUN: | FileCheck %t-ptx65-sm_75.ll
# RUN: %if ptxas-10.2 %{ \
# RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
# RUN: | %ptxas-verify -arch=sm_75 \
# RUN: %}

# Check all variants of instructions supported by PTX71 on SM80+
# RUN: %python %s --ptx=71 --gpu-arch=80 > %t-ptx71-sm_80.ll
# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX65LDMATRIX,PTX71MMA
# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
# RUN: --check-prefixes=INTRINSICS
# RUN: llc < %t-ptx71-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 \
# RUN: | FileCheck %t-ptx71-sm_80.ll
# RUN: %if ptxas-11.1 %{ \
# RUN: llc < %t-ptx71-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 \
# RUN: | %ptxas-verify -arch=sm_80 \
# RUN: %}

from __future__ import print_function

import argparse
from itertools import product
from string import Template


class MMAType:
    def __init__(self, ptx_type):
        self.ptx_type = ptx_type
        self.ptx_reg_pattern = {
        }.get(ptx_type, "%r[0-9]+")

    def __repr__(self):
        return "%s/%s" % (self.ptx_type, self.llvm_type)


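# MMAFrag describes one matrix fragment: the instruction geometry (e.g. m16n16k16),
# the fragment role (a/b/c/d), its element type, and how many registers the
# fragment occupies (nregs), which is what the check patterns below are built from.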
class MMAFrag:
    def __init__(self, geom, frag, ptx_elt_type):
        self.geom = geom
        self.frag = frag
        self.mma_type = MMAType(ptx_elt_type)
        self.nregs = {
            # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
            "m16n16k16:c:s32": 8,
            "m16n16k16:d:s32": 8,
            # u4/s4 -> s32 @ m8n8k32 (u4/s4)
            # b1 -> s32 @ m8n8k128(b1)
            "m16n8k128:c:s32": 4,
            "m16n8k128:d:s32": 4,
            "m16n8k256:c:s32": 4,
            "m16n8k256:d:s32": 4,
            # bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
            "m16n16k16:a:bf16": 4,
            "m16n16k16:b:bf16": 4,
            "m8n32k16:a:bf16": 2,
            "m8n32k16:b:bf16": 8,
            "m32n8k16:a:bf16": 8,
            "m32n8k16:b:bf16": 2,
            "m16n8k16:a:bf16": 4,
            "m16n8k16:b:bf16": 2,
            # tf32 -> s32 @ m16n16k8
            "m16n16k8:a:tf32": 4,
            "m16n16k8:b:tf32": 4,
        }.get(
            "%s:%s:%s" % (geom, frag, ptx_elt_type),
            {
                # All other FP shape/fragment/type combinations have the same size
            }.get("%s:%s" % (frag, ptx_elt_type), None),
        )

    def __repr__(self):
        return "%s:%s:%s%s" % (
            self.geom,
            self.frag,
            self.mma_type,
            "" if self.nregs == 1 else ("*%d" % self.nregs),
        )


class MMAOp:
    def __init__(self, a, b, c, d):
        self.a = a
        self.b = b
        self.c = c
        self.d = d

    def __repr__(self):
        return "{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d)


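# make_mma_ops expands the cross product of geometries and A/B/C/D element types
# into MMAOp instances; empty type lists for B/D default to the A/C types.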
def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
    ops = []
    for geom, type_a, type_c in product(geoms, types_a, types_c):
        for type_b, type_d in product(
            types_b if types_b else [type_a], types_d if types_d else [type_c]
        ):
            ops.append(
                MMAOp(
                    MMAFrag(geom, "a", type_a),
                    MMAFrag(geom, "b", type_b),
                    MMAFrag(geom, "c", type_c),
                    MMAFrag(geom, "d", type_d),
                )
            )
    return ops


def make_ldst_ops(geoms, frags, types):
    return [
        MMAFrag(geom, frag, ptx_type)
        for (geom, frag, ptx_type) in product(geoms, frags, types)
    ]


def make_ldmatrix_ops(geoms, frags, types):
    return [
        MMAFrag(geom, frag, ptx_type)
        for (geom, frag, ptx_type) in product(geoms, frags, types)
    ]


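# The get_*_ops() helpers below enumerate every operation this test may cover;
# unsupported combinations are filtered out later by the is_*_supported() checks.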
def get_wmma_ops():
    return (
        make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
        + make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], ["bf16"], [], ["f32"], [])
        + make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
        + make_mma_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"],
            ["f16"],
            [],
            ["f16", "f32"],
            ["f16", "f32"],
        )
        + make_mma_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"], ["s8", "u8"], [], ["s32"], []
        )
        + make_mma_ops(["m8n8k32"], ["s4", "u4"], [], ["s32"], [])
        + make_mma_ops(["m8n8k128"], ["b1"], [], ["s32"], [])
    )


def get_mma_ops():
    return (
        make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
        + make_mma_ops(["m16n8k4", "m16n8k8"], ["tf32"], [], ["f32"], [])
        + make_mma_ops(["m16n8k16", "m16n8k8"], ["bf16"], [], ["f32"], [])
        + make_mma_ops(
            ["m8n8k4", "m16n8k8", "m16n8k16"],
            ["f16"],
            [],
            ["f16", "f32"],
            ["f16", "f32"],
        )
        + make_mma_ops(
            ["m8n8k16", "m16n8k16", "m16n8k32"], ["s8", "u8"], ["s8", "u8"], ["s32"], []
        )
        + make_mma_ops(
            ["m8n8k32", "m16n8k32", "m16n8k64"], ["s4", "u4"], ["s4", "u4"], ["s32"], []
        )
        + make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"], ["b1"], [], ["s32"], [])
    )


def get_ldst_ops(kind):
    ldst_ops = (
        make_ldst_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"],
            ["a", "b"],
            ["f16", "u8", "s8", "bf16"],
        )
        + make_ldst_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"], ["c", "d"], ["f16", "f32", "s32"]
        )
        + make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4", "u4"])
        + make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"])
        + make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"])
        + make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"])
        + make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"])
        + make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"])
    )
    return [x for x in ldst_ops if (x.frag == "d") == (kind == "store")]


def get_ldmatrix_ops():
    return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])


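# The is_*_supported() predicates below encode the minimum PTX version and SM
# architecture required for each geometry and element type, based on the global
# ptx_version and gpu_arch parsed from the command line.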
def is_wmma_geom_supported(geom):
    # geometries for FP and ints.
    if geom in ["m8n32k16", "m32n8k16"]:
        return ptx_version >= 61
    # geometries for sub-ints.
    if geom in ["m8n8k32", "m8n8k128"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if geom == "m16n16k16":
        return ptx_version >= 60
    if geom == "m16n8k8":
        return ptx_version >= 65
    if geom in ["m16n16k8", "m8n8k4"]:
        return ptx_version >= 70
    assert False  # Unexpected geometry.


def is_mma_geom_supported(geom):
    # geometries for FP and ints.
    if geom == "m8n8k4":
        return ptx_version >= 64
    if geom in ["m16n8k8", "m8n8k16", "m8n8k32"]:
        return ptx_version >= 65
    if geom in [
        "m16n8k16",
        "m16n8k4",
        "m16n8k32",
        "m16n8k64",
        "m8n8k128",
        "m16n8k128",
        "m16n8k256",
    ]:
        return ptx_version >= 70
    assert False  # Unexpected geometry.


def is_ldmatrix_geom_supported(geom):
    if geom == "m8n8":
        return ptx_version >= 65 and gpu_arch >= 75
    assert False  # Unexpected geometry.


def is_type_supported(ptx_type):
    if ptx_type in ["s8", "u8", "s32"]:
        return ptx_version >= 63 and gpu_arch >= 72
    if ptx_type in ["s4", "u4", "b1"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if ptx_type == "b16":
        return ptx_version >= 65 and gpu_arch >= 75
    if ptx_type in ["bf16", "tf32", "f64"]:
        return ptx_version >= 70
    return ptx_version >= 60 and gpu_arch >= 70


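# Variant-level checks: beyond geometry/type support, PTX also restricts which
# layout, rounding, and satfinite combinations are valid for a given operation.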
def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
    if not (
        is_type_supported(op.a.mma_type.ptx_type) and is_wmma_geom_supported(op.a.geom)
    ):
        return False

    # rnd is only supported for FP64 WMMA
    if rnd and op.a.mma_type.ptx_type != "f64":
        return False

    if satf:
        # satfinite for floating points was removed in PTX 6.5
        if op.a.mma_type.ptx_type == "f16" and ptx_version >= 65:
            return False
        if not op.a.mma_type.ptx_type in ["f16", "s8", "u8", "s4", "u4"]:
            return False

    # sub-integer require row/col layout.
    if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
        return layout_a == "row" and layout_b == "col"
    return True


def is_mma_variant_supported(op, layout_a, layout_b, satf):
    if not (
        is_type_supported(op.a.mma_type.ptx_type) and is_mma_geom_supported(op.a.geom)
    ):
        return False

    if satf and not op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4"]:
        return False

    # If the type of C is f32 then so must the type of D
    if (
        op.a.geom == "m8n8k4"
        and op.c.mma_type.ptx_type == "f32"
        and op.d.mma_type.ptx_type != "f32"
    ):
        return False

    # A and B type must be the same. C and D type must be the same
    if op.a.geom == "m16n8k8" and (
        op.a.mma_type.ptx_type != op.b.mma_type.ptx_type
        or op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
    ):
        return False

    # C and D type must be the same
    if op.a.geom == "m16n8k16" and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type:
        return False

    # Require row/col layout for all MMA except m8n8k4 on FP16
    if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
        return layout_a == "row" and layout_b == "col"
    return True


def is_ldst_variant_supported(frag, layout):
    if not (
        is_type_supported(frag.mma_type.ptx_type) and is_wmma_geom_supported(frag.geom)
    ):
        return False
    if frag.mma_type.ptx_type in ["s4", "u4", "b1"]:
        # sub-integer require sm_75 and ptx63, row/col layout for a/b.
        return (
            (frag.frag == "a" and layout == "row")
            or (frag.frag == "b" and layout == "col")
            or frag.frag in ["c", "d"]
        )
    return True


def is_ldmatrix_variant_supported(frag):
    if not (
        is_type_supported(frag.mma_type.ptx_type)
        and is_ldmatrix_geom_supported(frag.geom)
    ):
        return False
    return frag.frag in ["x1", "x2", "x4"]


def make_wmma_slice_ty(frag):
    return [frag.mma_type.llvm_type] * frag.nregs


def make_wmma_ld_ret_ty(frag):
    results = make_wmma_slice_ty(frag)
    if len(results) == 1:
        return "%s" % results[0]
    return "{%s}" % ", ".join(results)


# returns address space
def get_aspace(space):
    space_map = {
        "": 0,
        ".global": 1,
        ".shared": 3,
        ".const": 4,
        ".local": 5,
        ".param": 101,
    }

    return space_map[space]


def get_pspace(space):
    return "p%di8" % get_aspace(space)


def check_pattern(frag):
    return "{{%s}}" % ", *".join([frag.mma_type.ptx_reg_pattern] * frag.nregs)


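# Each gen_*_tests() function below prints LLVM IR test functions generated from
# a Template, one per supported variant, and returns (intrinsic, instruction)
# pairs that feed the final INTRINSICS list emitted by gen_check_unsupported_ops().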
def gen_wmma_load_tests():
    load_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
  ret ${ret_ty} %v0;
}
"""
    intrinsic_template = (
        "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
    )
    instruction_template = (
        "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
    )

    generated_items = []

    for frag, layout, space, stride in product(
        get_ldst_ops("load"),
        ["row", "col"],
        ["", ".shared", ".global"],
        ["", ".stride"],
    ):
        if not is_ldst_variant_supported(frag, layout):
            continue

        params = {
            "abc": frag.frag,
            "aligned": ".aligned" if ptx_version >= 63 else "",
            "layout": layout,
            "space": space,
            "stride": stride,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }

        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_result"] = check_pattern(frag)

        if stride:
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}}"
        else:
            test_params["extra_args"] = ""
            test_params["stride_pattern"] = ""

        print(Template(load_template).substitute(test_params))

        generated_items.append((test_params["intrinsic"], test_params["instruction"]))

    return generated_items


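# make_wmma_slice_args renders one fragment as a comma-separated argument list of
# "<llvm type> %<frag><index>" items, matching the ${args} placeholder used in the
# store and mma templates.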
def make_wmma_slice_args(frag):
    return ", ".join(
        [
            "%s %%%s%d" % (t, frag.frag, i)
            for i, t in enumerate(make_wmma_slice_ty(frag))
        ]
    )


def gen_wmma_store_tests():
    store_template = """
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
  ret void
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
  ret void
}
"""
    intrinsic_template = (
        "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
    )
    instruction_template = (
        "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
    )

    generated_items = []

    for frag, layout, space, stride in product(
        get_ldst_ops("store"),
        ["row", "col"],
        ["", ".shared", ".global"],
        ["", ".stride"],
    ):
        if not is_ldst_variant_supported(frag, layout):
            continue

        params = {
            "abc": frag.frag,
            "aligned": ".aligned" if ptx_version >= 63 else "",
            "layout": layout,
            "space": space,
            "stride": stride,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }

        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_args"] = check_pattern(frag)
        if stride:
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}};"
        else:
            test_params["extra_args"] = ""
            test_params["stride_pattern"] = ";"
        test_params["args"] = make_wmma_slice_args(frag)

        print(Template(store_template).substitute(test_params))
        generated_items.append((test_params["intrinsic"], test_params["instruction"]))

    return generated_items


def gen_ldmatrix_tests():
    ldmatrix_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src);

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src);
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1);
  ret ${ret_ty} %v0;
}
"""
    intrinsic_template = (
        "llvm.nvvm.ldmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
    )
    instruction_template = (
        "ldmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
    )

    generated_items = []

    for frag, space, trans in product(
        get_ldmatrix_ops(),
        ["", ".shared"],
        ["", ".trans"],
    ):
        if not is_ldmatrix_variant_supported(frag):
            continue

        params = {
            "frag": frag.frag,
            "space": space,
            "trans": trans,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }

        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_result"] = check_pattern(frag)

        print(Template(ldmatrix_template).substitute(test_params))

        generated_items.append((test_params["intrinsic"], test_params["instruction"]))

    return generated_items


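# The *_signature helpers compute the type suffix used in intrinsic and
# instruction names: f16 ops are keyed by the D and C accumulator types, other
# ops by their input types (or by all four types for the PTX instruction name).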
def mma_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 ops identified by accumulator & result type.
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    elif op.a.mma_type.ptx_type != op.b.mma_type.ptx_type:
        # other ops are identified by input types.
        return "%s.%s" % (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
    else:
        # if input types are the same, it only appears once.
        return op.a.mma_type.ptx_type


def mma_ptx_signature(op):
    # Encode all four types as D.A.B.C
    return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))


def wmma_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 ops identified by accumulator & result type.
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    else:
        # other ops are identified by input type.
        return op.a.mma_type.ptx_type


def wmma_ptx_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 instructions use D.C
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    else:
        # other instructions encode all four types as D.A.B.C
        return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))


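# common_mma_test_gen fills in mma_template for one operation, prints the
# resulting IR test function, and returns the generated (intrinsic, instruction)
# pair for the INTRINSICS list.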
def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
    mma_template = """
declare ${ret_ty} @${intrinsic}(
        ${args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
        ${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
; CHECK-NEXT: ${check_a}
; CHECK-NEXT: ${check_b}
; CHECK-NEXT: ${check_c}
  %r = call ${ret_ty} @${intrinsic}(
        ${args});
  ret ${ret_ty} %r;
}
"""
    test_params = params
    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    test_params["function"] = test_params["intrinsic"].replace(".", "_")
    test_params["instruction"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
    test_params["check_a"] = check_pattern(op.a)
    test_params["check_b"] = check_pattern(op.b)
    test_params["check_c"] = check_pattern(op.c)
    test_params["check_d"] = check_pattern(op.d)
    args = ",\n ".join(make_wmma_slice_args(frag) for frag in (op.a, op.b, op.c))
    test_params["args"] = args
    print(Template(mma_template).substitute(test_params))
    return (test_params["intrinsic"], test_params["instruction"])


def get_b1_ops(ptx_type):
    if ptx_type != "b1":
        return [""]
    if ptx_version >= 71:
        return [".xor.popc", ".and.popc"]
    return [".xor.popc"]


def gen_wmma_mma_tests():
    wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
    wmma_instruction_template = "wmma.mma${b1op}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"

    generated_items = []

    for op, alayout, blayout, rnd, satf in product(
        get_wmma_ops(),
        ["row", "col"],
        ["row", "col"],
        [".rn", ".rz", ".rm", ".rp", ""],
        [".satfinite", ""],
    ):
        if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
            continue

        for b1op in get_b1_ops(op.a.mma_type.ptx_type):
            params = {
                "aligned": ".aligned" if ptx_version >= 63 else "",
                "alayout": alayout,
                "blayout": blayout,
                "intrinsic_signature": wmma_signature(op),
                "ptx_signature": wmma_ptx_signature(op),
                "satf": satf,
                "rnd": rnd,
                "geom": op.a.geom,
                "b1op": b1op,
            }

            intrinsic_template = wmma_intrinsic_template
            instruction_template = wmma_instruction_template

            generated_items.append(
                common_mma_test_gen(
                    params, op, intrinsic_template, instruction_template
                )
            )

    return generated_items


def gen_mma_tests():
    mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
    mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"

    generated_items = []

    for op, alayout, blayout, satf in product(
        get_mma_ops(), ["row", "col"], ["row", "col"], [".satfinite", ""]
    ):
        if not is_mma_variant_supported(op, alayout, blayout, satf):
            continue

        for b1op in get_b1_ops(op.a.mma_type.ptx_type):
            params = {
                "aligned": ".aligned" if ptx_version >= 63 else "",
                "alayout": alayout,
                "blayout": blayout,
                "intrinsic_signature": mma_signature(op),
                "ptx_signature": mma_ptx_signature(op),
                "satf": satf,
                "geom": op.a.geom,
                "b1op": b1op,
            }

            intrinsic_template = mma_intrinsic_template
            instruction_template = mma_instruction_template

            generated_items.append(
                common_mma_test_gen(
                    params, op, intrinsic_template, instruction_template
                )
            )

    return generated_items


# Append complete list of intrinsics and instructions we've generated tests for.
# Generate set of checks to verify that we did generate a sensible set of
# tests for the given combination of PTX and SM variants.
def gen_check_unsupported_ops(items):
    print(
        "; Complete list of intrinsics supported by PTX%d on sm_%d"
        % (ptx_version, gpu_arch)
    )
    print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}")
    print(
        """

; NOEXTGEOM-NOT: {{m8n32|m32n8}}
; NOINT-NOT: .{{s32|s8}}
; NOSUBINT-NOT: {{s4|u4|b1}}
; NOMMA-NOT: .m8n8k4.
; NOALTFLOAT-NOT: .{{bf16|tf32}}
; NODOUBLE-NOT: .f64
; NOLDMATRIX-NOT: ldmatrix.sync.aligned

; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f32
; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f16
; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f16
; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f32

; PTX61 adds support for m32n8k16/m8n32k16 geometries.
; EXTGEOM-DAG: m32n8k16.load.{{[ab].*}}.f16.p
; EXTGEOM-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f32
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f16
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f16
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f32

; EXTGEOM-DAG: m8n32k16.load.{{[ab].*}}.f16.p
; EXTGEOM-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f32
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f16
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f16
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f32

; INT-DAG: m16n16k16.load.{{[ab].*}}.s8.p
; INT-DAG: m8n32k16.load.{{[ab].*}}.s8.p
; INT-DAG: m32n8k16.load.{{[ab].*}}.s8.p
; INT-DAG: m16n16k16.load.{{[ab].*}}.u8.p
; INT-DAG: m8n32k16.load.{{[ab].*}}.u8.p
; INT-DAG: m32n8k16.load.{{[ab].*}}.u8.p
; INT-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p
; INT-DAG: m16n16k16.mma.{{.*}}.u8
; INT-DAG: m16n16k16.mma.{{.*}}.s8
; INT-DAG: m8n32k16.mma.{{.*}}.u8
; INT-DAG: m8n32k16.mma.{{.*}}.s8
; INT-DAG: m32n8k16.mma.{{.*}}.u8
; INT-DAG: m32n8k16.mma.{{.*}}.s8

; SUBINT-DAG: m8n8k128.load.{{[ab].*}}.b1.p
; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.s4.p
; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.u4.p
; SUBINT-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p
; SUBINT-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p
; SUBINT-DAG: m8n8k32.mma.{{.*}}.u4
; SUBINT-DAG: m8n8k32.mma.{{.*}}.s4
; SUBINT-DAG: m8n8k128.mma.{{.*}}.b1

; ALTFLOAT-DAG: m16n16k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m8n32k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m32n8k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m16n16k8.load.{{[ab].*}}.tf32.p
; ALTFLOAT-DAG: m16n16k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m8n32k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m32n8k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m16n16k8.mma.{{.*}}.tf32

; DOUBLE-DAG: m8n8k4.load.{{[abc].*}}.f64.p
; DOUBLE-DAG: m8n8k4.store.d.{{.*}}.f64.p
; DOUBLE-DAG: m8n8k4.mma.{{.*}}.f64

; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f32
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f32

; PTX65MMA-DAG: mma.m16n8k8.row.col.f16.f16
; PTX65MMA-DAG: mma.m16n8k8.row.col.f32.f32
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.s8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.s8
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.s4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4

; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16

; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k16.row.col.bf16
; PTX71MMA-DAG: mma.m16n8k8.row.col.bf16
; PTX71MMA-DAG: mma.m16n8k16.row.col.f16.f16
; PTX71MMA-DAG: mma.m16n8k16.row.col.f32.f32
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
; PTX71MMA-DAG: mma.and.popc.m8n8k128.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m8n8k128.row.col.b1
; PTX71MMA-DAG: mma.and.popc.m16n8k128.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m16n8k128.row.col.b1
; PTX71MMA-DAG: mma.and.popc.m16n8k256.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m16n8k256.row.col.b1
"""
    )

    print("; INTRINSICS_LIST_BEGIN")
    for intrinsic, instruction in sorted(items):
        print("; ", intrinsic, " -> ", instruction, "")
    print("; INTRINSICS_LIST_END")
    print("; INTRINSICS: ; INTRINSICS_LIST_END")


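# Generate all test cases plus the combined list of intrinsics/instructions that
# the INTRINSICS FileCheck prefix verifies.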
def main():
    items = gen_wmma_load_tests()
    items += gen_wmma_store_tests()
    items += gen_ldmatrix_tests()
    items += gen_wmma_mma_tests()
    items += gen_mma_tests()
    gen_check_unsupported_ops(items)


parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
args = parser.parse_args()
ptx_version = args.ptx
gpu_arch = args.gpu_arch

main()