# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them.

# Check all variants of instructions supported by PTX60 on SM70
# RUN: %python %s --ptx=60 --gpu-arch=70 > %t-ptx60-sm_70.ll
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
# RUN:           | FileCheck %t-ptx60-sm_70.ll

# Check all variants of instructions supported by PTX61 on SM70
# RUN: %python %s --ptx=61 --gpu-arch=70 > %t-ptx61-sm_70.ll
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
# RUN:           | FileCheck %t-ptx61-sm_70.ll

# Check all variants of instructions supported by PTX63 on SM72
# RUN: %python %s --ptx=63 --gpu-arch=72 > %t-ptx63-sm_72.ll
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN:           --check-prefixes=INTRINSICS,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
# RUN:           | FileCheck %t-ptx63-sm_72.ll

# Check all variants of instructions supported by PTX63 on SM75
# RUN: %python %s --ptx=63 --gpu-arch=75 > %t-ptx63-sm_75.ll
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
# RUN:           | FileCheck %t-ptx63-sm_75.ll

# Check all variants of instructions supported by PTX64 on SM70+
# RUN: %python %s --ptx=64 --gpu-arch=70 > %t-ptx64-sm_70.ll
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,MMA
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx64-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 \
# RUN:           | FileCheck %t-ptx64-sm_70.ll

# Check all variants of instructions supported by PTX65 on SM75+
# RUN: %python %s --ptx=65 --gpu-arch=75 > %t-ptx65-sm_75.ll
# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,PTX65MMA,PTX65LDMATRIX
# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS
# RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
# RUN:           | FileCheck %t-ptx65-sm_75.ll

# Check all variants of instructions supported by PTX71 on SM80+
# RUN: %python %s --ptx=71 --gpu-arch=80 > %t-ptx71-sm_80.ll
# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX65LDMATRIX,PTX71MMA
# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
# RUN:           --check-prefixes=INTRINSICS
# RUN: llc < %t-ptx71-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 \
# RUN:           | FileCheck %t-ptx71-sm_80.ll
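
# To regenerate the IR for a single configuration by hand (an illustrative
# invocation based on the RUN lines above and the argparse options at the
# bottom of this file; the RUN lines remain authoritative):
#
#   python wmma.py --ptx=71 --gpu-arch=80 > wmma-ptx71-sm_80.ll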

from __future__ import print_function

import argparse
from itertools import product
from string import Template


class MMAType:
    def __init__(self, ptx_type):
        self.ptx_type = ptx_type
        # LLVM IR type of a single element of the fragment value.
        self.llvm_type = {
            "f16": "<2 x half>",
            "f32": "float",
            "f64": "double",
            "s32": "i32",
            "b16": "i32",
            "s8": "i32",
            "u8": "i32",
            "s4": "i32",
            "u4": "i32",
            "b1": "i32",
            "bf16": "i32",
            "tf32": "i32",
        }[ptx_type]

        # PTX register-name pattern used in the FileCheck lines below.
        self.ptx_reg_pattern = {
            "f16": "%hh[0-9]+",
            "f32": "%f[0-9]+",
            "f64": "%fd[0-9]+",
        }.get(ptx_type, "%r[0-9]+")

    def __repr__(self):
        return "%s/%s" % (self.ptx_type, self.llvm_type)


class MMAFrag:
    def __init__(self, geom, frag, ptx_elt_type):
        self.geom = geom
        self.frag = frag
        self.mma_type = MMAType(ptx_elt_type)
        # Number of registers each fragment occupies; keyed by geom:frag:type,
        # with a frag:type fallback for the common FP cases.
        self.nregs = {
            # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
            "m16n16k16:a:u8": 2,
            "m16n16k16:a:s8": 2,
            "m16n16k16:b:u8": 2,
            "m16n16k16:b:s8": 2,
            "m16n16k16:c:s32": 8,
            "m16n16k16:d:s32": 8,

            "m8n32k16:c:s32": 8,
            "m8n32k16:d:s32": 8,

            "m32n8k16:c:s32": 8,
            "m32n8k16:d:s32": 8,

            # u4/s4 -> s32 @ m8n8k32 (u4/s4)
            "m16n8k32:c:s32": 4,
            "m16n8k32:d:s32": 4,

            "m16n8k64:c:s32": 4,
            "m16n8k64:d:s32": 4,

            # b1 -> s32 @ m8n8k128(b1)
            "m8n8k128:c:s32": 2,
            "m8n8k128:d:s32": 2,

            "m16n8k128:a:b1": 2,
            "m16n8k128:b:b1": 1,
            "m16n8k128:c:s32": 4,
            "m16n8k128:d:s32": 4,

            "m16n8k256:a:b1": 4,
            "m16n8k256:b:b1": 2,
            "m16n8k256:c:s32": 4,
            "m16n8k256:d:s32": 4,

            # bf16 -> f32 @ m16n16k16/m8n32k16/m32n8k16
            "m16n16k16:a:bf16": 4,
            "m16n16k16:b:bf16": 4,
            "m8n32k16:a:bf16": 2,
            "m8n32k16:b:bf16": 8,
            "m32n8k16:a:bf16": 8,
            "m32n8k16:b:bf16": 2,

            "m16n8k16:a:bf16": 4,
            "m16n8k16:b:bf16": 2,
            "m16n8k16:c:f32": 4,
            "m16n8k16:d:f32": 4,
            "m16n8k8:a:bf16": 2,
            "m16n8k8:b:bf16": 1,

            # tf32 -> f32 @ m16n16k8
            "m16n16k8:a:tf32": 4,
            "m16n16k8:b:tf32": 4,

            "m16n8k4:a:tf32": 2,
            "m16n8k4:b:tf32": 1,

            "m16n8k8:a:tf32": 4,
            "m16n8k8:b:tf32": 2,
        }.get("%s:%s:%s" % (geom, frag, ptx_elt_type), {
            # All other FP shape/fragment/type combinations have the same size
        }.get("%s:%s" % (frag, ptx_elt_type), None))

    def __repr__(self):
        return "%s:%s:%s%s" % (self.geom, self.frag, self.mma_type,
                               "" if self.nregs == 1 else ("*%d" % self.nregs))
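
# For example, with the table above, MMAFrag("m16n16k16", "c", "s32") occupies
# 8 registers, so its repr ends in "*8"; single-register fragments omit the
# "*N" suffix.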


class MMAOp:
    def __init__(self, a, b, c, d):
        self.a = a
        self.b = b
        self.c = c
        self.d = d

    def __repr__(self):
        return "{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d)


def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
    ops = []
    for geom, type_a, type_c in product(geoms, types_a, types_c):
        for type_b, type_d in product(types_b if types_b else [type_a],
                                      types_d if types_d else [type_c]):
            ops.append(MMAOp(MMAFrag(geom, "a", type_a),
                             MMAFrag(geom, "b", type_b),
                             MMAFrag(geom, "c", type_c),
                             MMAFrag(geom, "d", type_d)))
    return ops
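
# For example, make_mma_ops(["m8n8k32"], ["s4", "u4"], [], ["s32"], []) yields
# two ops (s4 and u4 inputs, s32 accumulators): empty types_b/types_d lists
# default to the corresponding A/C type.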


def make_ldst_ops(geoms, frags, types):
    return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
            in product(geoms, frags, types)]


def make_ldmatrix_ops(geoms, frags, types):
    return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
            in product(geoms, frags, types)]


def get_wmma_ops():
    return (make_mma_ops(["m16n16k8"],
                         ["tf32"], [], ["f32"], []) +
            make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                         ["bf16"], [], ["f32"], []) +
            make_mma_ops(["m8n8k4"],
                         ["f64"], [], ["f64"], []) +
            make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                         ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
            make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                         ["s8", "u8"], [], ["s32"], []) +
            make_mma_ops(["m8n8k32"],
                         ["s4", "u4"], [], ["s32"], []) +
            make_mma_ops(["m8n8k128"],
                         ["b1"], [], ["s32"], []))


def get_mma_ops():
    return (make_mma_ops(["m8n8k4"],
                         ["f64"], [], ["f64"], []) +
            make_mma_ops(["m16n8k4", "m16n8k8"],
                         ["tf32"], [], ["f32"], []) +
            make_mma_ops(["m16n8k16", "m16n8k8"],
                         ["bf16"], [], ["f32"], []) +
            make_mma_ops(["m8n8k4", "m16n8k8", "m16n8k16"],
                         ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
            make_mma_ops(["m8n8k16", "m16n8k16", "m16n8k32"],
                         ["s8", "u8"], ["s8", "u8"], ["s32"], []) +
            make_mma_ops(["m8n8k32", "m16n8k32", "m16n8k64"],
                         ["s4", "u4"], ["s4", "u4"], ["s32"], []) +
            make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"],
                         ["b1"], [], ["s32"], []))


def get_ldst_ops(kind):
    ldst_ops = (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                              ["a", "b"], ["f16", "u8", "s8", "bf16"]) +
                make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                              ["c", "d"], ["f16", "f32", "s32"]) +
                make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4", "u4"]) +
                make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
                make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]) +
                make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"]) +
                make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"]) +
                make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"]))
    return [x for x in ldst_ops if (x.frag == "d") == (kind == "store")]
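
# The filter above means kind == "store" keeps only the "d" fragments (the
# only ones that can be stored), while "load" keeps the a/b/c fragments.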


def get_ldmatrix_ops():
    return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])


def is_wmma_geom_supported(geom):
    # geometries for FP and ints.
    if geom in ["m8n32k16", "m32n8k16"]:
        return ptx_version >= 61
    # geometries for sub-ints.
    if geom in ["m8n8k32", "m8n8k128"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if geom == "m16n16k16":
        return ptx_version >= 60
    if geom == "m16n8k8":
        return ptx_version >= 65
    if geom in ["m16n16k8", "m8n8k4"]:
        return ptx_version >= 70
    assert(False)  # Unexpected geometry.


def is_mma_geom_supported(geom):
    # geometries for FP and ints.
    if geom == "m8n8k4":
        return ptx_version >= 64
    if geom in ["m16n8k8", "m8n8k16", "m8n8k32"]:
        return ptx_version >= 65
    if geom in ["m16n8k16", "m16n8k4", "m16n8k32", "m16n8k64", "m8n8k128",
                "m16n8k128", "m16n8k256"]:
        return ptx_version >= 70
    assert(False)  # Unexpected geometry.


def is_ldmatrix_geom_supported(geom):
    if geom == "m8n8":
        return ptx_version >= 65 and gpu_arch >= 75
    assert(False)  # Unexpected geometry.


def is_type_supported(ptx_type):
    if ptx_type in ["s8", "u8", "s32"]:
        return ptx_version >= 63 and gpu_arch >= 72
    if ptx_type in ["s4", "u4", "b1"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if ptx_type == "b16":
        return ptx_version >= 65 and gpu_arch >= 75
    if ptx_type in ["bf16", "tf32", "f64"]:
        return ptx_version >= 70
    return ptx_version >= 60 and gpu_arch >= 70


def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
    if not (is_type_supported(op.a.mma_type.ptx_type)
            and is_wmma_geom_supported(op.a.geom)):
        return False

    # rnd is only supported for FP64 WMMA
    if rnd and op.a.mma_type.ptx_type != "f64":
        return False

    if satf:
        # satfinite for floating points was removed in PTX 6.5
        if op.a.mma_type.ptx_type == "f16" and ptx_version >= 65:
            return False
        if not op.a.mma_type.ptx_type in ["f16", "s8", "u8", "s4", "u4"]:
            return False

    # sub-integer types require row/col layout.
    if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
        return layout_a == "row" and layout_b == "col"
    return True


def is_mma_variant_supported(op, layout_a, layout_b, satf):
    if not (is_type_supported(op.a.mma_type.ptx_type)
            and is_mma_geom_supported(op.a.geom)):
        return False

    if satf and not op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4"]:
        return False

    # If the type of C is f32 then so must the type of D
    if (op.a.geom == "m8n8k4" and op.c.mma_type.ptx_type == "f32"
            and op.d.mma_type.ptx_type != "f32"):
        return False

    # A and B type must be the same. C and D type must be the same
    if (op.a.geom == "m16n8k8"
            and (op.a.mma_type.ptx_type != op.b.mma_type.ptx_type
                 or op.c.mma_type.ptx_type != op.d.mma_type.ptx_type)):
        return False

    # C and D type must be the same
    if (op.a.geom == "m16n8k16"
            and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type):
        return False

    # Require row/col layout for all MMA variants except m8n8k4 on FP16
    if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
        return layout_a == "row" and layout_b == "col"
    return True


def is_ldst_variant_supported(frag, layout):
    if not (is_type_supported(frag.mma_type.ptx_type)
            and is_wmma_geom_supported(frag.geom)):
        return False
    if frag.mma_type.ptx_type in ["s4", "u4", "b1"]:
        # sub-integer types require sm_75 and ptx63, and row/col layout
        # for a/b fragments.
        return ((frag.frag == "a" and layout == "row")
                or (frag.frag == "b" and layout == "col")
                or frag.frag in ["c", "d"])
    return True


def is_ldmatrix_variant_supported(frag):
    if not (is_type_supported(frag.mma_type.ptx_type)
            and is_ldmatrix_geom_supported(frag.geom)):
        return False
    return frag.frag in ["x1", "x2", "x4"]


def make_wmma_slice_ty(frag):
    return [frag.mma_type.llvm_type] * frag.nregs


def make_wmma_ld_ret_ty(frag):
    results = make_wmma_slice_ty(frag)
    if len(results) == 1:
        return "%s" % results[0]
    return "{%s}" % ", ".join(results)


# Returns the numeric address space for a PTX state space string.
def get_aspace(space):
    space_map = {
        "": 0,
        ".global": 1,
        ".shared": 3,
    }
    return space_map[space]


def get_pspace(space):
    return "p%di8" % get_aspace(space)
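
# For example, get_pspace(".shared") == "p3i8"; this is the ${pspace} suffix
# spliced into the intrinsic names below.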


def check_pattern(frag):
    return "{{%s}}" % ", *".join([frag.mma_type.ptx_reg_pattern] * frag.nregs)
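
# For example, an "m8n8k128:c:s32" fragment has two registers and the default
# "%r[0-9]+" pattern, so check_pattern() yields "{{%r[0-9]+, *%r[0-9]+}}".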


def gen_wmma_load_tests():
    load_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
  ret ${ret_ty} %v0;
}
"""
    intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
    instruction_template = "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"

    generated_items = []

    for frag, layout, space, stride in product(
            get_ldst_ops("load"),
            ["row", "col"],
            ["", ".shared", ".global"],
            ["", ".stride"]):
        if not is_ldst_variant_supported(frag, layout):
            continue

        params = {
            "abc": frag.frag,
            "aligned": ".aligned" if ptx_version >= 63 else "",
            "layout": layout,
            "space": space,
            "stride": stride,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }

        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_result"] = check_pattern(frag)

        if stride:
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}}"
        else:
            test_params["extra_args"] = ""
            test_params["stride_pattern"] = ""

        print(Template(load_template).substitute(test_params))

        generated_items.append((test_params["intrinsic"],
                                test_params["instruction"]))

    return generated_items
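
# As an illustration of the two templates above: for an "a" fragment of
# m16n16k16/f16, row layout, ".shared" space and no stride (with PTX >= 6.3,
# so ".aligned" is added), the generated pair would be roughly
#   llvm.nvvm.wmma.m16n16k16.load.a.row.f16.p3i8
#   wmma.load.a.sync.aligned.row.m16n16k16.shared.f16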


def make_wmma_slice_args(frag):
    return ", ".join(["%s %%%s%d" % (t, frag.frag, i) for i, t
                      in enumerate(make_wmma_slice_ty(frag))])


def gen_wmma_store_tests():
    store_template = """
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
  ret void
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
  ret void
}
"""
    intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
    instruction_template = "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"

    generated_items = []

    for frag, layout, space, stride in product(
            get_ldst_ops("store"),
            ["row", "col"],
            ["", ".shared", ".global"],
            ["", ".stride"]):
        if not is_ldst_variant_supported(frag, layout):
            continue

        params = {
            "abc": frag.frag,
            "aligned": ".aligned" if ptx_version >= 63 else "",
            "layout": layout,
            "space": space,
            "stride": stride,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }

        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_args"] = check_pattern(frag)

        if stride:
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}};"
        else:
            test_params["extra_args"] = ""
            test_params["stride_pattern"] = ";"
        test_params["args"] = make_wmma_slice_args(frag)

        print(Template(store_template).substitute(test_params))
        generated_items.append((test_params["intrinsic"],
                                test_params["instruction"]))

    return generated_items


def gen_ldmatrix_tests():
    ldmatrix_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src);

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src);
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1);
  ret ${ret_ty} %v0;
}
"""
    intrinsic_template = "llvm.nvvm.ldmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
    instruction_template = "ldmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"

    generated_items = []

    for frag, space, trans in product(
            get_ldmatrix_ops(),
            ["", ".shared"],
            ["", ".trans"]):
        if not is_ldmatrix_variant_supported(frag):
            continue

        params = {
            "frag": frag.frag,
            "space": space,
            "trans": trans,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }

        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_result"] = check_pattern(frag)

        print(Template(ldmatrix_template).substitute(test_params))

        generated_items.append((test_params["intrinsic"],
                                test_params["instruction"]))

    return generated_items


def mma_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 ops identified by accumulator & result type.
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    elif op.a.mma_type.ptx_type != op.b.mma_type.ptx_type:
        # other ops are identified by input types.
        return "%s.%s" % (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
    else:
        # if input types are the same, it only appears once.
        return op.a.mma_type.ptx_type


def mma_ptx_signature(op):
    # Encode all four types as D.A.B.C
    return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
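
# For example, an m16n8k16 op with s8 A, u8 B and s32 accumulators has
# mma_signature "s8.u8" and mma_ptx_signature "s32.s8.u8.s32".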


def wmma_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 ops identified by accumulator & result type.
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    else:
        # other ops are identified by input type.
        return op.a.mma_type.ptx_type


def wmma_ptx_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 instructions use D.C
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    else:
        # other instructions encode all four types as D.A.B.C
        return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))


def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
    mma_template = """
declare ${ret_ty} @${intrinsic}(
        ${args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
        ${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
; CHECK-NEXT: ${check_a}
; CHECK-NEXT: ${check_b}
; CHECK-NEXT: ${check_c}
  %r = call ${ret_ty} @${intrinsic}(
        ${args});
  ret ${ret_ty} %r;
}
"""
    test_params = params
    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    test_params["function"] = test_params["intrinsic"].replace(".", "_")
    test_params["instruction"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
    test_params["check_a"] = check_pattern(op.a)
    test_params["check_b"] = check_pattern(op.b)
    test_params["check_c"] = check_pattern(op.c)
    test_params["check_d"] = check_pattern(op.d)
    args = ",\n        ".join(make_wmma_slice_args(frag)
                              for frag in (op.a, op.b, op.c))
    test_params["args"] = args
    print(Template(mma_template).substitute(test_params))
    return (test_params["intrinsic"], test_params["instruction"])


def get_b1_ops(ptx_type):
    if ptx_type != "b1":
        return [""]
    if ptx_version >= 71:
        return [".xor.popc", ".and.popc"]
    return [".xor.popc"]
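
# The returned suffixes are spliced into the ${b1op} slot of the intrinsic and
# instruction templates below, e.g. "llvm.nvvm.mma.xor.popc.m8n8k128..." for
# b1 operands; non-b1 ops get an empty suffix.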


def gen_wmma_mma_tests():
    wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
    wmma_instruction_template = "wmma.mma${b1op}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"

    generated_items = []

    for op, alayout, blayout, rnd, satf in product(
            get_wmma_ops(),
            ["row", "col"],
            ["row", "col"],
            [".rn", ".rz", ".rm", ".rp", ""],
            [".satfinite", ""]):
        if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
            continue

        for b1op in get_b1_ops(op.a.mma_type.ptx_type):
            params = {
                "aligned": ".aligned" if ptx_version >= 63 else "",
                "alayout": alayout,
                "blayout": blayout,
                "intrinsic_signature": wmma_signature(op),
                "ptx_signature": wmma_ptx_signature(op),
                "satf": satf,
                "rnd": rnd,
                "geom": op.a.geom,
                "b1op": b1op,
            }

            intrinsic_template = wmma_intrinsic_template
            instruction_template = wmma_instruction_template

            generated_items.append(common_mma_test_gen(params, op,
                                   intrinsic_template, instruction_template))

    return generated_items


def gen_mma_tests():
    mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
    mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"

    generated_items = []

    for op, alayout, blayout, satf in product(
            get_mma_ops(),
            ["row", "col"],
            ["row", "col"],
            [".satfinite", ""]):
        if not is_mma_variant_supported(op, alayout, blayout, satf):
            continue

        for b1op in get_b1_ops(op.a.mma_type.ptx_type):
            params = {
                "aligned": ".aligned" if ptx_version >= 63 else "",
                "alayout": alayout,
                "blayout": blayout,
                "intrinsic_signature": mma_signature(op),
                "ptx_signature": mma_ptx_signature(op),
                "satf": satf,
                "geom": op.a.geom,
                "b1op": b1op,
            }

            intrinsic_template = mma_intrinsic_template
            instruction_template = mma_instruction_template

            generated_items.append(common_mma_test_gen(params, op,
                                   intrinsic_template, instruction_template))

    return generated_items


# Append a complete list of intrinsics and instructions we've generated tests
# for. Generate a set of checks to verify that we did generate a sensible set
# of tests for the given combination of PTX and SM variants.
#
def gen_check_unsupported_ops(items):
    print("; Complete list of intrinsics supported by PTX%d on sm_%d"
          % (ptx_version, gpu_arch))
    print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}")
    print("""
; NOEXTGEOM-NOT: {{m8n32|m32n8}}
; NOINT-NOT: .{{s32|s8}}
; NOSUBINT-NOT: {{s4|u4|b1}}
; NOMMA-NOT: .m8n8k4.
; NOALTFLOAT-NOT: .{{bf16|tf32}}
; NODOUBLE-NOT: .f64

; NOLDMATRIX-NOT: ldmatrix.sync.aligned

; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f32
; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f16
; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f16
; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f32

; PTX61 adds support for m32n8k16/m8n32k16 geometries.
; EXTGEOM-DAG: m32n8k16.load.{{[ab].*}}.f16.p
; EXTGEOM-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f32
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f16
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f16
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f32

; EXTGEOM-DAG: m8n32k16.load.{{[ab].*}}.f16.p
; EXTGEOM-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f32
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f16
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f16
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f32

; INT-DAG: m16n16k16.load.{{[ab].*}}.s8.p
; INT-DAG: m8n32k16.load.{{[ab].*}}.s8.p
; INT-DAG: m32n8k16.load.{{[ab].*}}.s8.p
; INT-DAG: m16n16k16.load.{{[ab].*}}.u8.p
; INT-DAG: m8n32k16.load.{{[ab].*}}.u8.p
; INT-DAG: m32n8k16.load.{{[ab].*}}.u8.p
; INT-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p
; INT-DAG: m16n16k16.mma.{{.*}}.u8
; INT-DAG: m16n16k16.mma.{{.*}}.s8
; INT-DAG: m8n32k16.mma.{{.*}}.u8
; INT-DAG: m8n32k16.mma.{{.*}}.s8
; INT-DAG: m32n8k16.mma.{{.*}}.u8
; INT-DAG: m32n8k16.mma.{{.*}}.s8

; SUBINT-DAG: m8n8k128.load.{{[ab].*}}.b1.p
; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.s4.p
; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.u4.p
; SUBINT-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p
; SUBINT-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p
; SUBINT-DAG: m8n8k32.mma.{{.*}}.u4
; SUBINT-DAG: m8n8k32.mma.{{.*}}.s4
; SUBINT-DAG: m8n8k128.mma.{{.*}}.b1

; ALTFLOAT-DAG: m16n16k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m8n32k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m32n8k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m16n16k8.load.{{[ab].*}}.tf32.p
; ALTFLOAT-DAG: m16n16k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m8n32k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m32n8k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m16n16k8.mma.{{.*}}.tf32

; DOUBLE-DAG: m8n8k4.load.{{[abc].*}}.f64.p
; DOUBLE-DAG: m8n8k4.store.d.{{.*}}.f64.p
; DOUBLE-DAG: m8n8k4.mma.{{.*}}.f64

; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f32
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f32

; PTX65MMA-DAG: mma.m16n8k8.row.col.f16.f16
; PTX65MMA-DAG: mma.m16n8k8.row.col.f32.f32
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.s8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.s8
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.s4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4

; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16

; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k16.row.col.bf16
; PTX71MMA-DAG: mma.m16n8k8.row.col.bf16
; PTX71MMA-DAG: mma.m16n8k16.row.col.f16.f16
; PTX71MMA-DAG: mma.m16n8k16.row.col.f32.f32
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
; PTX71MMA-DAG: mma.and.popc.m8n8k128.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m8n8k128.row.col.b1
; PTX71MMA-DAG: mma.and.popc.m16n8k128.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m16n8k128.row.col.b1
; PTX71MMA-DAG: mma.and.popc.m16n8k256.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m16n8k256.row.col.b1

""")
    print("; INTRINSICS_LIST_BEGIN")
    for intrinsic, instruction in sorted(items):
        print("; ", intrinsic, " -> ", instruction, "")
    print("; INTRINSICS_LIST_END")
    print("; INTRINSICS: ; INTRINSICS_LIST_END")


# Test generation is wrapped in a function so it only runs after ptx_version
# and gpu_arch have been set from the command line below.
def gen_tests():
    items = gen_wmma_load_tests()
    items += gen_wmma_store_tests()
    items += gen_ldmatrix_tests()
    items += gen_wmma_mma_tests()
    items += gen_mma_tests()
    gen_check_unsupported_ops(items)


parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
args = parser.parse_args()
ptx_version = args.ptx
gpu_arch = args.gpu_arch

gen_tests()