# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them.

# Check all variants of instructions supported by PTX60 on SM70
# RUN: python %s --ptx=60 --gpu-arch=70 > %t-ptx60-sm_70.ll
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,PTX60,SM70
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,PTX60U,SM70U
# RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
# RUN:           | FileCheck %t-ptx60-sm_70.ll

# Check all variants of instructions supported by PTX61 on SM70
# RUN: python %s --ptx=61 --gpu-arch=70 > %t-ptx61-sm_70.ll
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,PTX60,PTX61,SM70
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,PTX61U,SM70U
# RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
# RUN:           | FileCheck %t-ptx61-sm_70.ll

# Check all variants of instructions supported by PTX63 on SM72
# RUN: python %s --ptx=63 --gpu-arch=72 > %t-ptx63-sm_72.ll
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN:           --check-prefixes=INTRINSICS,PTX60,PTX61,PTX63,SM70,SM72
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN:           --check-prefixes=INTRINSICS,PTX63U,SM72U
# RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
# RUN:           | FileCheck %t-ptx63-sm_72.ll

# Check all variants of instructions supported by PTX63 on SM75
# RUN: python %s --ptx=63 --gpu-arch=75 > %t-ptx63-sm_75.ll
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS,PTX60,PTX61,PTX63,SM70,SM72,SM75
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS,PTX63U,SM75U
# RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
# RUN:           | FileCheck %t-ptx63-sm_75.ll

from __future__ import print_function

import argparse
from itertools import product
from string import Template

class MMAType:
  def __init__(self, ptx_type):
    self.ptx_type = ptx_type
    self.llvm_type = {
        "f16" : "<2 x half>",
        "f32" : "float",
        "s32" : "i32",
        "s8"  : "i32",
        "u8"  : "i32",
        "s4"  : "i32",
        "u4"  : "i32",
        "b1"  : "i32",
    }[ptx_type]

    self.ptx_reg_pattern = {
        "f16" : "%hh[0-9]+",
        "f32" : "%f[0-9]+",
    }.get(ptx_type, "%r[0-9]+")

  def __repr__(self):
    return "%s/%s" % (self.ptx_type, self.llvm_type)

class MMAFrag:
  def __init__(self, geom, frag, ptx_elt_type):
    self.geom = geom
    self.frag = frag
    self.mma_type = MMAType(ptx_elt_type)
    self.nregs = {
        # fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
        "a:f16" : 8,
        "b:f16" : 8,
        "c:f16" : 4,
        "d:f16" : 4,
        "c:f32" : 8,
        "d:f32" : 8,
    }.get("%s:%s" % (frag, ptx_elt_type), {
        # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
        "m16n16k16:a:u8" : 2,
        "m16n16k16:a:s8" : 2,
        "m16n16k16:b:u8" : 2,
        "m16n16k16:b:s8" : 2,
        "m16n16k16:c:s32" : 8,
        "m16n16k16:d:s32" : 8,

        "m8n32k16:a:u8" : 1,
        "m8n32k16:a:s8" : 1,
        "m8n32k16:b:u8" : 4,
        "m8n32k16:b:s8" : 4,
        "m8n32k16:c:s32" : 8,
        "m8n32k16:d:s32" : 8,

        "m32n8k16:a:u8" : 4,
        "m32n8k16:a:s8" : 4,
        "m32n8k16:b:u8" : 1,
        "m32n8k16:b:s8" : 1,
        "m32n8k16:c:s32" : 8,
        "m32n8k16:d:s32" : 8,

        # u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128 (b1)
        "m8n8k128:a:b1" : 1,
        "m8n8k32:a:u4" : 1,
        "m8n8k32:a:s4" : 1,
        "m8n8k128:b:b1" : 1,
        "m8n8k32:b:u4" : 1,
        "m8n8k32:b:s4" : 1,
        "m8n8k128:c:s32" : 2,
        "m8n8k128:d:s32" : 2,
        "m8n8k32:c:s32" : 2,
        "m8n8k32:d:s32" : 2,
    }.get("%s:%s:%s" % (geom, frag, ptx_elt_type), None))
    assert(self.nregs)

  def __repr__(self):
    return "%s:%s:%s%s" % (self.geom, self.frag, self.mma_type,
                           "" if self.nregs == 1 else ("*%d" % self.nregs))

class MMAOp:
  def __init__(self, a, b, c, d):
    self.a = a
    self.b = b
    self.c = c
    self.d = d

  def __repr__(self):
    return ("{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d))

def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
  ops = []
  for geom, type_a, type_c in product(geoms, types_a, types_c):
    for type_b, type_d in product(types_b if types_b else [type_a],
                                  types_d if types_d else [type_c]):
      ops.append(MMAOp(MMAFrag(geom, "a", type_a),
                       MMAFrag(geom, "b", type_b),
                       MMAFrag(geom, "c", type_c),
                       MMAFrag(geom, "d", type_d)))
  return ops

def make_ldst_ops(geoms, frags, types):
  return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
          in product(geoms, frags, types)]

def get_mma_ops():
  return (make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                       ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
          make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                       ["s8", "u8"], [], ["s32"], []) +
          make_mma_ops(["m8n8k32"],
                       ["s4", "u4"], [], ["s32"], []) +
          make_mma_ops(["m8n8k128"],
                       ["b1"], [], ["s32"], []))

def get_ldst_ops(kind):
  ldst_ops = (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                            ["a", "b"], ["f16", "u8", "s8"]) +
              make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                            ["c", "d"], ["f16", "f32", "s32"]) +
              make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4", "u4"]) +
              make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
              make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]))
  return [x for x in ldst_ops if (x.frag == "d") == (kind == "store")]

def is_geom_supported(geom):
  # geometries for FP and ints.
  if geom in ["m8n32k16", "m32n8k16"]:
    return ptx_version >= 61
  # geometries for sub-ints.
  if geom in ["m8n8k32", "m8n8k128"]:
    return ptx_version >= 63 and gpu_arch >= 75
  if geom == "m16n16k16":
    return ptx_version >= 60
  assert(False)  # Unexpected geometry.

def is_type_supported(ptx_type):
  if ptx_type in ["s8", "u8", "s32"]:
    return ptx_version >= 63 and gpu_arch >= 72
  if ptx_type in ["s4", "u4", "b1"]:
    return ptx_version >= 63 and gpu_arch >= 75
  return ptx_version >= 60 and gpu_arch >= 70

def is_mma_variant_supported(op, layout_a, layout_b, satf):
  if not (is_type_supported(op.a.mma_type.ptx_type)
          and is_geom_supported(op.a.geom)):
    return False
  # sub-integer ops require row/col layout; b1 does not support satf.
  if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
    if op.a.mma_type.ptx_type == "b1" and satf:
      return False
    return layout_a == "row" and layout_b == "col"
  return True

def is_ldst_variant_supported(frag, layout):
  if not (is_type_supported(frag.mma_type.ptx_type)
          and is_geom_supported(frag.geom)):
    return False
  if frag.mma_type.ptx_type in ["s4", "u4", "b1"]:
    # sub-integer types require sm_75 and ptx63, and row/col layout for a/b.
    return ((frag.frag == "a" and layout == "row")
            or (frag.frag == "b" and layout == "col")
            or frag.frag in ["c", "d"])
  return True

def make_wmma_slice_ty(frag):
  return [frag.mma_type.llvm_type] * frag.nregs

def make_wmma_ld_ret_ty(frag):
  results = make_wmma_slice_ty(frag)
  if len(results) == 1:
    return "%s" % results[0]
  return "{%s}" % ", ".join(results)
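# E.g. an s32 c/d fragment of m8n8k128 has 2 registers, so (assuming s32 maps
# to the i32 LLVM type, as in the table above) this yields "{i32, i32}"; a
# single-register fragment is returned as a bare scalar type instead.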

# returns the numeric address space for a given space suffix.
def get_aspace(space):
  space_map = {
      ""        : 0,
      ".global" : 1,
      ".shared" : 3,
  }
  return space_map[space]

def get_pspace(space):
  return "p%di8" % get_aspace(space)

def check_pattern(frag):
  return "{{%s}}" % ", *".join([frag.mma_type.ptx_reg_pattern] * frag.nregs)
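# E.g. for the same 2-register s32 fragment this produces the FileCheck
# pattern "{{%r[0-9]+, *%r[0-9]+}}", matching the register tuple printed by llc.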

known_geoms = ["m16n16k16", "m8n32k16", "m32n8k16"]

def gen_wmma_load_tests():
  load_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
  ret ${ret_ty} %v0;
}
"""
  intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
  instruction_template = "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"

  generated_items = []

  for frag, layout, space, stride in product(
      get_ldst_ops("load"),
      ["row", "col"],
      ["", ".shared", ".global"],
      ["", ".stride"]):
    if not is_ldst_variant_supported(frag, layout):
      continue
276 "aligned" : ".aligned" if ptx_version
>= 63 else "",
280 "itype" : frag
.mma_type
.ptx_type
,
281 "pspace" : get_pspace(space
),
282 "as" : "addrspace(%d)" % get_aspace(space
),
    test_params = params
    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    test_params["function"] = test_params["intrinsic"].replace(".", "_")
    test_params["instruction"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
    test_params["check_result"] = check_pattern(frag)

    if stride:
      test_params["extra_args"] = ", i32 %stride"
      test_params["stride_pattern"] = ", %r{{[0-9]+}}"
    else:
      test_params["extra_args"] = ""
      test_params["stride_pattern"] = ""

    print(Template(load_template).substitute(test_params))

    generated_items.append((test_params["intrinsic"],
                            test_params["instruction"]))

  return generated_items

def make_wmma_slice_args(frag):
  return ", ".join(["%s %%%s%d" % (t, frag.frag, i) for i, t
                    in enumerate(make_wmma_slice_ty(frag))])

def gen_wmma_store_tests():
  store_template = """
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
  ret void
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
  ret void
}
"""
  intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
  instruction_template = "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"

  generated_items = []

  for frag, layout, space, stride in product(
      get_ldst_ops("store"),
      ["row", "col"],
      ["", ".shared", ".global"],
      ["", ".stride"]):
    if not is_ldst_variant_supported(frag, layout):
      continue
350 "aligned" : ".aligned" if ptx_version
>= 63 else "",
354 "itype" : frag
.mma_type
.ptx_type
,
355 "pspace" : get_pspace(space
),
356 "as" : "addrspace(%d)" % get_aspace(space
),
    test_params = params
    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    test_params["function"] = test_params["intrinsic"].replace(".", "_")
    test_params["instruction"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
    test_params["check_args"] = check_pattern(frag)

    if stride:
      test_params["extra_args"] = ", i32 %stride"
      test_params["stride_pattern"] = ", %r{{[0-9]+}};"
    else:
      test_params["extra_args"] = ""
      test_params["stride_pattern"] = ";"
    test_params["args"] = make_wmma_slice_args(frag)

    print(Template(store_template).substitute(test_params))
    generated_items.append((test_params["intrinsic"],
                            test_params["instruction"]))

  return generated_items

def mma_signature(op):
  if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
    # int and sub-int ops are identified by input type.
    return op.a.mma_type.ptx_type
  # the rest are FP ops, identified by accumulator & result type.
  return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)

def mma_ptx_signature(op):
  if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
    # int and sub-int instructions encode all four types as D.A.B.C.
    return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
  # the rest are FP instructions, which encode only D.C.
  return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
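# E.g. an s8 MMA op is encoded as "s32.s8.s8.s32" (D.A.B.C), while an
# f16 x f16 + f32 -> f32 op is encoded as just "f32.f32" (D.C).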

def gen_wmma_mma_tests():
  mma_template = """
declare ${ret_ty} @${intrinsic}(
        ${args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
        ${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
; CHECK-NEXT: ${check_a}
; CHECK-NEXT: ${check_b}
; CHECK-NEXT: ${check_c}
  %r = call ${ret_ty} @${intrinsic}(
        ${args});
  ret ${ret_ty} %r;
}
"""
  intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
  instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}.${ptx_signature}${satf}"

  generated_items = []

  for op, alayout, blayout, satf in product(
      get_mma_ops(),
      ["row", "col"],
      ["row", "col"],
      [".satfinite", ""]):

    if not is_mma_variant_supported(op, alayout, blayout, satf):
      continue

    params = {
        "aligned" : ".aligned" if ptx_version >= 63 else "",
        "alayout" : alayout,
        "blayout" : blayout,
        "intrinsic_signature" : mma_signature(op),
        "ptx_signature" : mma_ptx_signature(op),
        "satf" : satf,
        "geom" : op.a.geom,
        "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
    }
    test_params = params
    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    test_params["function"] = test_params["intrinsic"].replace(".", "_")
    test_params["instruction"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
    test_params["check_a"] = check_pattern(op.a)
    test_params["check_b"] = check_pattern(op.b)
    test_params["check_c"] = check_pattern(op.c)
    test_params["check_d"] = check_pattern(op.d)
    args = ",\n        ".join(make_wmma_slice_args(frag)
                              for frag in (op.a, op.b, op.c))
    test_params["args"] = args
    print(Template(mma_template).substitute(test_params))
    generated_items.append((test_params["intrinsic"],
                            test_params["instruction"]))

  return generated_items

# Append the complete list of intrinsics and instructions we've generated tests
# for. Generate a set of checks to verify that we did generate a sensible set of
# tests for the given combination of PTX and SM variants.
#
# PTX<N>:  verifies that we did generate tests for the correct classes of
#          intrinsics.
# PTX<N>U: verifies that we did not generate intrinsics unsupported by that
#          PTX version.
# SM<N>:   verifies that we did generate the correct classes of instructions
#          for the SM.
# SM<N>U:  verifies that we did not generate instructions unsupported by the SM.
#
# Note that SM/PTX constraints overlap, but DAG checks do not allow overlapping
# matches. We implicitly rely on the fact that we generate multiple variants of
# most instructions and usually have enough input data to find more than one
# match of the same kind, if necessary. When that's not possible (e.g. there's
# only one m8n8k128.mma.row.col.b1), we may need to match the PTX instruction
# instead.
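#
# A generated pair in the list below looks roughly like this (for a plain
# PTX60-style load, where ${aligned} and ${space} are empty):
#   llvm.nvvm.wmma.m16n16k16.load.a.row.f16.p0i8 -> wmma.load.a.sync.row.m16n16k16.f16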

def gen_check_unsupported_ops(items):
  print("; Complete list of intrinsics supported by PTX%d on sm_%d"
        % (ptx_version, gpu_arch))
  print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}")
  print("""
; PTX60-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; PTX60-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; PTX60U-NOT: m32n8k16
; PTX60U-NOT: m8n32k16
; PTX60U-NOT: .{{s32|s[48]|u[48]|b1}}

; All features of PTX60, plus m32n8k16/m8n32k16 geometries.
; PTX61-DAG: m32n8k16.load.{{[ab].*}}.f16.p
; PTX61-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; PTX61-DAG: m8n32k16.load.{{[ab].*}}.f16.p
; PTX61-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; PTX61U-NOT: .{{s32|s[48]|u[48]|b1}}

; SM70U-NOT: .{{s32|s[48]|u[48]|b1}}

; PTX63 supports all features of PTX60+PTX61, plus support for integers.
; Alas we can't just use PTX<N> checks for that, as the available instructions
; depend on the SM: integers need sm72+ and sub-integer ops need sm75, so we
; transition to SM<N> checks.
; SM72-DAG: m16n16k16.load.{{[ab].*}}.s8.p
; SM72-DAG: m8n32k16.load.{{[ab].*}}.s8.p
; SM72-DAG: m32n8k16.load.{{[ab].*}}.s8.p
; SM72-DAG: m16n16k16.load.{{[ab].*}}.u8.p
; SM72-DAG: m8n32k16.load.{{[ab].*}}.u8.p
; SM72-DAG: m32n8k16.load.{{[ab].*}}.u8.p
; SM72-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p
; SM72U-NOT: .{{s4|u4|b1}}

; SM75-DAG: m8n8k128.load.{{[ab].*}}.b1.p
; SM75-DAG: m8n8k32.load.{{[ab].*}}.s4.p
; SM75-DAG: m8n8k32.load.{{[ab].*}}.u4.p
; SM75-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p
; SM75-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p
""")
512 print("; INTRINSICS_LIST_BEGIN")
513 for intrinsic
, instruction
in sorted(items
):
514 print("; ", intrinsic
, " -> ", instruction
,"")
515 print("; INTRINSICS_LIST_END")
516 print("; INTRINSICS: ; INTRINSICS_LIST_END")

def gen_tests():
  items = gen_wmma_load_tests()
  items += gen_wmma_store_tests()
  items += gen_wmma_mma_tests()
  gen_check_unsupported_ops(items)

parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
args = parser.parse_args()

ptx_version = args.ptx
gpu_arch = args.gpu_arch

gen_tests()
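
# Example standalone invocation (mirroring the RUN lines at the top of the file):
#   python <this script> --ptx=63 --gpu-arch=75 > wmma-ptx63-sm_75.ll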