# This script generates all variants of wmma builtins, verifies that clang calls
# correct LLVM intrinsics, and checks that availability of specific builtins is
# constrained by the correct PTX version and the target GPU variant.

# Dummy test run to avoid lit warnings.
# RUN: echo "This is not a real test. It's a generator for builtins-nvptx-mma.cu" >/dev/null
from __future__ import print_function

import argparse
from collections import defaultdict
from itertools import product
from string import Template
class MMAFrag:
    def __init__(self, geom, frag, ptx_elt_type):
        self.geom = geom
        self.frag = frag
        self.ptx_type = ptx_elt_type

    def __str__(self):
        return "%s:%s:%s" % (self.geom, self.frag, self.ptx_type)
class MMAOp:
    def __init__(self, a, b, c, d, b1op=""):
        self.a = a
        self.b = b
        self.c = c
        self.d = d
        self.b1op = b1op

    def __str__(self):
        return "{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d)
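# Illustrative note (not used by the generator): str(MMAFrag("m16n16k16", "a", "f16"))
# prints as "m16n16k16:a:f16", and an MMAOp groups four such fragments plus an
# optional b1op suffix such as ".xor.popc".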
def make_mma_ops(geoms, types_a, types_b, types_c, types_d, b1ops=None):
    ops = []
    if b1ops is None:
        b1ops = [""]
    for geom, type_a, type_c in product(geoms, types_a, types_c):
        for type_b, type_d in product(
            types_b if types_b else [type_a], types_d if types_d else [type_c]
        ):
            ops += [
                MMAOp(
                    MMAFrag(geom, "a", type_a),
                    MMAFrag(geom, "b", type_b),
                    MMAFrag(geom, "c", type_c),
                    MMAFrag(geom, "d", type_d),
                    b1op,
                )
                for b1op in b1ops
            ]
    return ops
def make_ldst_ops(geoms, frags, types):
    return [
        MMAFrag(geom, frag, ptx_type)
        for (geom, frag, ptx_type) in product(geoms, frags, types)
    ]
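# Illustrative example (not used by the generator): make_ldst_ops(["m8n8k4"], ["a", "b"], ["f64"])
# yields the fragments m8n8k4:a:f64 and m8n8k4:b:f64, i.e. the full cross product
# of geometries, fragments, and types.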
def get_mma_ops():
    return (
        make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
        + make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], ["bf16"], [], ["f32"], [])
        + make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
        + make_mma_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"],
            ["f16"],
            [],
            ["f16", "f32"],
            ["f16", "f32"],
        )
        + make_mma_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"], ["s8", "u8"], [], ["s32"], []
        )
        + make_mma_ops(["m8n8k32"], ["s4", "u4"], [], ["s32"], [])
        + make_mma_ops(
            ["m8n8k128"], ["b1"], [], ["s32"], [], [".xor.popc", ".and.popc"]
        )
    )
def get_ldst_ops():
    # NOTE: fragments are from the point of view of PTX.
    # Fragment `d` is only for store ops; the others are for both loads and stores.
    return (
        make_ldst_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"],
            ["a", "b"],
            ["f16", "u8", "s8", "bf16"],
        )
        + make_ldst_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"], ["c", "d"], ["f16", "f32", "s32"]
        )
        + make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4", "u4"])
        + make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"])
        + make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"])
        + make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"])
        # TF32 m16n16k8 is odd.
        # For fragment 'C' it uses __mma_*tf32*_m16n16k8_ld_c
        # but 'D' calls __mma_m16n16k8_st_c_*f32*.
        + make_ldst_ops(["m16n16k8"], ["a", "b", "c"], ["tf32"])
        + make_ldst_ops(["m16n16k8"], ["d"], ["f32"])
    )
def is_geom_supported(geom):
    # geometries for FP and ints.
    if geom in ["m8n32k16", "m32n8k16"]:
        return ptx_version >= 61
    # geometries for sub-ints.
    if geom in ["m8n8k32", "m8n8k128"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if geom == "m16n16k16":
        return ptx_version >= 60
    if geom in ["m16n16k8", "m8n8k4"]:
        return ptx_version >= 70 and gpu_arch >= 80
    assert False  # Unexpected geometry.
def is_type_supported(ptx_type):
    if ptx_type in ["s8", "u8", "s32"]:
        return ptx_version >= 63 and gpu_arch >= 72
    if ptx_type in ["s4", "u4", "b1"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if ptx_type in ["bf16", "tf32", "f64"]:
        return ptx_version >= 70 and gpu_arch >= 80
    return ptx_version >= 60 and gpu_arch >= 70
def is_rnd_supported(op):
    # rnd is only supported for FP64 WMMA.
    return op.a.ptx_type == "f64"
def is_mma_variant_supported(op, layout_a, layout_b, satf):
    if not (is_type_supported(op.a.ptx_type) and is_geom_supported(op.a.geom)):
        return False

    if satf and op.a.ptx_type not in ["f16", "s8", "u8", "s4", "u4"]:
        return False

    # sub-integer types require row/col layout.
    if op.a.ptx_type in ["s4", "u4", "b1"]:
        return layout_a == "row" and layout_b == "col"
    return True
def is_ldst_variant_supported(frag, layout):
    if not (is_type_supported(frag.ptx_type) and is_geom_supported(frag.geom)):
        return False
    if frag.ptx_type in ["s4", "u4", "b1"]:
        # sub-integer types require sm_75 and ptx63, row/col layout for a/b.
        return (
            (frag.frag == "a" and layout == "row")
            or (frag.frag == "b" and layout == "col")
            or frag.frag in ["c", "d"]
        )
    return True
def get_builtin_prefix(frag):
    prefix = None
    if frag.geom in ["m16n16k16", "m32n8k16", "m8n32k16"]:
        if frag.ptx_type in ["f16", "f32"]:
            prefix = "__hmma"
        elif frag.ptx_type == "bf16":
            prefix = "__mma_bf16"
        else:
            prefix = "__imma"
    elif frag.geom == "m8n8k32":
        prefix = "__imma"  # sub-integers
    elif frag.geom == "m8n8k128":
        prefix = "__bmma"
    elif frag.geom == "m8n8k4":
        prefix = "__dmma"
    elif frag.geom == "m16n16k8":
        if frag.ptx_type == "f32":
            prefix = "__mma"
        else:
            prefix = "__mma_tf32"
    assert prefix
    return prefix
def get_ldst_builtin_name(frag):
    prefix = get_builtin_prefix(frag)

    if prefix == "__hmma":
        suffix = "" if frag.frag in ["a", "b"] else frag.ptx_type
    elif prefix in ["__dmma", "__mma_bf16", "__mma_tf32"]:
        suffix = "" if frag.frag in ["a", "b", "c"] else frag.ptx_type
    else:
        suffix = "" if frag.frag == "c" else frag.ptx_type
        if suffix == "s32":
            suffix = "i32"

    # From the builtin's point of view, fragment `d` is stored via `c`.
    if frag.frag == "d":
        ifrag = "c"
        op = "st"
    else:
        ifrag = frag.frag
        op = "ld"

    name = "%s_%s_%s_%s%s" % (
        prefix,
        frag.geom,
        op,
        ifrag,
        "_" + suffix if suffix else "",
    )
    return name
def get_mma_builtin_name(op):
    prefix = get_builtin_prefix(op.a)

    if prefix == "__hmma":
        suffix = op.d.ptx_type + op.c.ptx_type
    elif prefix in ["__mma_bf16", "__mma_tf32"]:
        suffix = op.d.ptx_type
    else:
        suffix = op.a.ptx_type

    name = "{prefix}_{geom}_mma{b1op}_{suffix}".format(
        prefix=prefix, geom=op.a.geom, b1op=op.b1op.replace(".", "_"), suffix=suffix
    )
    return name
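# Illustrative names produced here (assuming the standard Clang WMMA builtins):
# __hmma_m16n16k16_mma_f32f16 (d=f32, c=f16) and, for b1 ops,
# __bmma_m8n8k128_mma_xor_popc_b1, where ".xor.popc" becomes "_xor_popc".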
def get_required_sm(frag, b1op=""):
    if frag.ptx_type in ["f64", "bf16", "tf32"]:
        return 80
    if frag.ptx_type in ["u4", "s4", "b1"]:
        if b1op == ".and.popc":
            return 80
        return 75
    if frag.ptx_type in ["s8", "u8"]:
        return 72
    if frag.ptx_type == "s32":
        if frag.geom in ["m8n8k32", "m8n8k128"]:  # s4/u4/b1
            return 75
        else:  # s8/u8
            return 72
    if frag.ptx_type in ["f16", "f32"]:
        if frag.geom == "m16n16k8":
            return 80
        else:
            return 70
    assert False  # Unexpected type.
def get_required_ptx(frag, b1op=""):
    if frag.ptx_type == "b1" and b1op == ".and.popc":
        return 71
    if frag.ptx_type in ["f64", "bf16", "tf32"]:
        return 70
    if frag.ptx_type in ["f16", "f32"]:
        if frag.geom == "m16n16k16":
            return 60
        if frag.geom == "m16n16k8":
            return 70
        return 61
    return 63
def get_src_dst_prefix(frag):
    if frag.ptx_type == "f32":
        return "f"
    if frag.ptx_type == "f64":
        return "d"
    if frag.ptx_type == "tf32" and frag.frag in ["c", "d"]:
        return "f"
    return ""
def gen_wmma_ldst_tests(results):
    load_template = """
  // CHECK${check_suffix}: call {{.*}} @${intrinsic}
  // expected-error-re@+1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}}
  ${builtin}(${dst}, ${src}, ldm, ${blayout});
"""
    intrinsic_template = (
        "llvm.nvvm.wmma.${geom}.${op}.${frag}.${ilayout}.stride.${itype}"
    )

    for frag, layout in sorted(product(get_ldst_ops(), ["row", "col"]), key=str):

        if not is_ldst_variant_supported(frag, layout):
            continue

        src_dst_prefix = get_src_dst_prefix(frag)

        min_sm = get_required_sm(frag)
        min_ptx = get_required_ptx(frag)
        # TF32 uses f32 for accumulator loads.
        if frag.geom == "m16n16k8" and frag.frag == "c":
            assert frag.ptx_type == "tf32"
            itype = "f32"
        else:
            itype = frag.ptx_type

        params = {
            "check_suffix": "_PTX%d_SM%d" % (min_ptx, min_sm),
            "builtin": get_ldst_builtin_name(frag),
            "min_ptx": min_ptx,
            "min_sm": min_sm,
            "dst": src_dst_prefix + "dst",
            "src": src_dst_prefix + "src",
            "blayout": 0 if layout == "row" else 1,
            "intrinsic": Template(intrinsic_template).substitute(
                {
                    "frag": frag.frag,
                    "geom": frag.geom,
                    "ilayout": layout,
                    "itype": itype,
                    "op": "store" if frag.frag == "d" else "load",
                }
            ),
        }
        results[(min_ptx, min_sm)] += Template(load_template).substitute(params)

    return results
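# A generated load/store test case looks roughly like this (illustrative, for an
# f16 'a' fragment in row layout):
#   // CHECK_PTX60_SM70: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16
#   // expected-error-re@+1 {{'__hmma_m16n16k16_ld_a' needs target feature (sm_70{{.*}},(ptx60{{.*}}}}
#   __hmma_m16n16k16_ld_a(dst, src, ldm, 0);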
def mma_signature(op):
    if op.a.ptx_type == "f16":
        # FP16 ops identified by accumulator & result type.
        return "%s.%s" % (op.d.ptx_type, op.c.ptx_type)
    else:
        # Other ops are identified by input type.
        return op.a.ptx_type
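# Illustrative signatures: an f16 op with d=f32, c=f16 yields "f32.f16", while an
# s8 op yields just "s8".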
# Get numeric value for rowcol parameter of the builtin.
# AFAICT it uses the encoding accepted by NVVM intrinsics:
# https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#nvvm-intrin-warp-level-matrix-mma
def get_ilayout(a, b):
    return {"row.row": 0, "row.col": 1, "col.row": 2, "col.col": 3}[a + "." + b]
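# For example, get_ilayout("row", "col") == 1 and get_ilayout("col", "row") == 2.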
def gen_wmma_mma_tests(results):
    mma_template = """
  // CHECK${check_suffix}: call {{.*}} @${intrinsic}
  // expected-error-re@+1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}}
  ${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_satf});
"""
    intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}.${intrinsic_signature}${satf}"

    for op, alayout, blayout, satf in sorted(
        product(get_mma_ops(), ["row", "col"], ["row", "col"], [".satfinite", ""]),
        key=str,
    ):

        if not is_mma_variant_supported(op, alayout, blayout, satf):
            continue

        asrc_prefix = get_src_dst_prefix(op.a)
        csrc_prefix = get_src_dst_prefix(op.c)
        ddst_prefix = get_src_dst_prefix(op.d)
        if op.a.ptx_type == "b1":  # .b1 MMA has no satf argument.
            isatf_arg = ""
        else:
            isatf_arg = ", 1" if satf else ", 0"
        min_sm = get_required_sm(op.a, op.b1op)
        min_ptx = get_required_ptx(op.a, op.b1op)
        params = {
            "check_suffix": "_PTX%d_SM%d" % (min_ptx, min_sm),
            "builtin": get_mma_builtin_name(op),
            "min_ptx": min_ptx,
            "min_sm": min_sm,
            "dst": ddst_prefix + "dst",
            "asrc": asrc_prefix + "src",
            "csrc": csrc_prefix + "src",
            "ilayout": get_ilayout(alayout, blayout),
            "maybe_satf": isatf_arg,
            "intrinsic": Template(intrinsic_template).substitute(
                {
                    "geom": op.a.geom,
                    "alayout": alayout,
                    "blayout": blayout,
                    "intrinsic_signature": mma_signature(op),
                    "satf": satf,
                    "b1op": op.b1op,
                }
            ),
        }
        results[(min_ptx, min_sm)] += Template(mma_template).substitute(params)

    return results
def gen_tests():
    results = gen_wmma_ldst_tests(defaultdict(str))
    results = gen_wmma_mma_tests(results)

    run_template = """
//
// *** DO NOT EDIT ***
//
// This test has been automatically generated by
// builtins-nvptx-mma.py --ptx=${ptx} --gpu-arch=${sm}
//
// Make sure we can handle all builtins available on sm_${sm} with PTX${ptx}
// ${run}: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_${sm} \\
// ${run}:            -fcuda-is-device -target-feature +ptx${ptx} \\
// ${run}:            -DPTX=${ptx} -DSM=${sm} \\
// ${run}:            -S -emit-llvm -o - -x cuda %s \\
// ${run}:            | FileCheck -check-prefixes=${check_labels} %s
// Verify that all builtins have correct constraints.
// ${run}: %clang_cc1 -triple nvptx-unknown-unknown \\
// ${run}:            -target-cpu sm_60 -target-feature +ptx42 \\
// ${run}:            -DPTX=${ptx} -DSM=${sm} -fcuda-is-device -S -o /dev/null -x cuda \\
// ${run}:            -verify %s
"""
    def supported_variants(ptx, sm, results):
        return [(ptx_, sm_) for ptx_, sm_ in results if ptx_ <= ptx and sm_ <= sm]
    print(
        Template(run_template).substitute(
            {
                "run": "RUN",  # To avoid lit misinterpreting the template
                "ptx": ptx_version,
                "sm": gpu_arch,
                "check_labels": ",".join(
                    [
                        "CHECK_PTX%d_SM%d" % (ptx_, sm_)
                        for ptx_, sm_ in supported_variants(
                            ptx_version, gpu_arch, results
                        )
                    ]
                ),
            }
        )
    )
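    # Illustrative: with the default --ptx=60 --gpu-arch=70, the only supported
    # (ptx, sm) key is (60, 70), so check_labels expands to "CHECK_PTX60_SM70".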
    print(
        """
#if !defined(CUDA_VERSION)
#define __device__ __attribute__((device))
#define __global__ __attribute__((global))
#define __shared__ __attribute__((shared))
#define __constant__ __attribute__((constant))

typedef unsigned long long uint64_t;
#endif

// CHECK-LABEL: test_wmma_builtins
__device__ void test_wmma_builtins(int *src, int *dst,
                                   float *fsrc, float *fdst,
                                   double *dsrc, double *ddst, int ldm) {
"""
    )
    for (ptx, sm), tests in sorted(results.items()):
        print()
        print("#if (PTX >= %d) && (SM >= %d)" % (ptx, sm))
        print(tests)
        print("#endif // (PTX >= %d) && (SM >= %d)" % (ptx, sm))

    print("}")
parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
args = parser.parse_args()
ptx_version = args.ptx
gpu_arch = args.gpu_arch

gen_tests()