Run DCE after a LoopFlatten test to reduce spurious output [nfc]
[llvm-project.git] / clang / test / CodeGen / builtins-nvptx-mma.py
blob 72f4f5c9f655ef2ea7ac348361b9b2bce326c0a1
1 # This script generates all variants of wmma builtins, verifies that clang calls
2 # correct LLVM intrinsics, and checks that availability of specific builtins is
3 # constrained by the correct PTX version and the target GPU variant.
5 # Dummy test run to avoid lit warnings.
6 # RUN: echo "This is not a real test. It's a generator for builtins-nvpts-mma.cu" >/dev/null
8 from __future__ import print_function
10 import argparse
11 from collections import defaultdict
12 from itertools import product
13 from string import Template
class MMAFrag:
    """One matrix fragment taking part in a WMMA load, store or MMA.

    A fragment is described by its geometry (e.g. "m16n16k16"), its role as
    seen from PTX ("a", "b", "c" or "d") and its PTX element type
    (e.g. "f16", "s8", "b1").
    """

    def __init__(self, geom, frag, ptx_elt_type):
        # Geometry string such as "m16n16k16" or "m8n8k128".
        self.geom = geom
        # Fragment role from PTX's point of view: "a", "b", "c" or "d".
        self.frag = frag
        # PTX element type of the fragment.
        self.ptx_type = ptx_elt_type

    def __repr__(self):
        # Rendered as "geom:frag:type"; also used as a sort key via str().
        return ":".join(map(str, (self.geom, self.frag, self.ptx_type)))
class MMAOp:
    """An MMA operation described by its four fragments a, b, c and d.

    ``b1op`` carries the extra operation suffix used by single-bit ops
    (e.g. ".xor.popc"); it defaults to the empty string.
    """

    def __init__(self, a, b, c, d, b1op=""):
        self.a, self.b, self.c, self.d = a, b, c, d
        self.b1op = b1op

    def __repr__(self):
        # NOTE: b1op is not part of the repr, so ops that differ only in
        # b1op print (and therefore str()-sort) identically.
        return "{{A:{}, B:{}, C:{}, D:{}}}".format(self.a, self.b, self.c, self.d)
def make_mma_ops(geoms, types_a, types_b, types_c, types_d, b1ops=None):
    """Build an MMAOp for every combination of geometry and fragment types.

    An empty ``types_b`` means B uses the same type as A, and an empty
    ``types_d`` means D uses the same type as C. ``b1ops`` lists the
    single-bit operation suffixes to combine with; None means just "".
    Iteration order is geometry, A type, C type, B type, D type, b1op.
    """
    suffixes = [""] if b1ops is None else b1ops
    return [
        MMAOp(
            MMAFrag(geom, "a", type_a),
            MMAFrag(geom, "b", type_b),
            MMAFrag(geom, "c", type_c),
            MMAFrag(geom, "d", type_d),
            suffix,
        )
        for geom, type_a, type_c in product(geoms, types_a, types_c)
        for type_b, type_d in product(types_b or [type_a], types_d or [type_c])
        for suffix in suffixes
    ]
def make_ldst_ops(geoms, frags, types):
    """Return an MMAFrag for each (geometry, fragment, type) combination."""
    fragments = []
    for combo in product(geoms, frags, types):
        fragments.append(MMAFrag(*combo))
    return fragments
def get_mma_ops():
    """Return every MMA operation variant tests are generated for."""
    shapes_16 = ["m16n16k16", "m32n8k16", "m8n32k16"]
    ops = []
    # TF32, BF16 and FP64 ops.
    ops.extend(make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], []))
    ops.extend(make_mma_ops(shapes_16, ["bf16"], [], ["f32"], []))
    ops.extend(make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], []))
    # FP16 ops: all combinations of f16/f32 for the C and D fragments.
    ops.extend(
        make_mma_ops(shapes_16, ["f16"], [], ["f16", "f32"], ["f16", "f32"])
    )
    # 8-bit integer ops.
    ops.extend(make_mma_ops(shapes_16, ["s8", "u8"], [], ["s32"], []))
    # Sub-integer ops: 4-bit and single-bit (with their b1 operation).
    ops.extend(make_mma_ops(["m8n8k32"], ["s4", "u4"], [], ["s32"], []))
    ops.extend(
        make_mma_ops(
            ["m8n8k128"], ["b1"], [], ["s32"], [], [".xor.popc", ".and.popc"]
        )
    )
    return ops
def get_ldst_ops():
    """Return the fragments for which load/store builtin tests are generated.

    NOTE: fragments are named from the point of view of PTX. Fragment "d"
    is only used for store ops; "a", "b" and "c" are only used for loads.
    """
    shapes_16 = ["m16n16k16", "m32n8k16", "m8n32k16"]
    groups = [
        (shapes_16, ["a", "b"], ["f16", "u8", "s8", "bf16"]),
        (shapes_16, ["c", "d"], ["f16", "f32", "s32"]),
        (["m8n8k32"], ["a", "b"], ["s4", "u4"]),
        (["m8n8k128"], ["a", "b"], ["b1"]),
        (["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]),
        (["m8n8k4"], ["a", "b", "c", "d"], ["f64"]),
        # TF32 m16n16k8 is odd: fragment "c" is loaded with
        # __mma_tf32_m16n16k8_ld_c, but "d" is stored with
        # __mma_m16n16k8_st_c_f32, so "d" is listed with type f32.
        (["m16n16k8"], ["a", "b", "c"], ["tf32"]),
        (["m16n16k8"], ["d"], ["f32"]),
    ]
    fragments = []
    for geoms, roles, elt_types in groups:
        fragments.extend(make_ldst_ops(geoms, roles, elt_types))
    return fragments
def is_geom_supported(geom):
    """Return True if WMMA geometry ``geom`` is available for this target.

    Reads the module-level ``ptx_version`` and ``gpu_arch`` settings.

    Raises:
        ValueError: if ``geom`` is not a geometry this generator knows about.
    """
    # geometries for FP and ints.
    if geom in ["m8n32k16", "m32n8k16"]:
        return ptx_version >= 61
    # geometries for sub-ints.
    if geom in ["m8n8k32", "m8n8k128"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if geom == "m16n16k16":
        return ptx_version >= 60
    # TF32 (m16n16k8) and FP64 (m8n8k4) geometries.
    if geom in ["m16n16k8", "m8n8k4"]:
        return ptx_version >= 70 and gpu_arch >= 80
    # An explicit raise, unlike `assert False`, is not stripped by `python -O`
    # (which would make this silently return None, i.e. "unsupported").
    raise ValueError("Unexpected geometry: %r" % (geom,))
def is_type_supported(ptx_type):
    """Return True if element type ``ptx_type`` is available for this target.

    Compares against the module-level ``ptx_version`` and ``gpu_arch``.
    Types not listed explicitly (f16/f32) need PTX 6.0 and sm_70.
    """
    # Minimum (PTX version, SM) for each group of element types.
    if ptx_type in ("s8", "u8", "s32"):
        min_ptx, min_sm = 63, 72
    elif ptx_type in ("s4", "u4", "b1"):
        min_ptx, min_sm = 63, 75
    elif ptx_type in ("bf16", "tf32", "f64"):
        min_ptx, min_sm = 70, 80
    else:
        min_ptx, min_sm = 60, 70
    return ptx_version >= min_ptx and gpu_arch >= min_sm
def is_rnd_supported(op):
    """Tell whether the ``rnd`` modifier applies to MMA operation ``op``.

    Only FP64 WMMA (an A fragment of type f64) supports it.
    """
    a_type = op.a.ptx_type
    return a_type == "f64"
def is_mma_variant_supported(op, layout_a, layout_b, satf):
    """Return True if ``op`` with the given A/B layouts and satf is testable.

    ``layout_a``/``layout_b`` are "row" or "col"; ``satf`` is the satfinite
    suffix string (empty when saturation is not requested).
    """
    a_type = op.a.ptx_type
    if not is_type_supported(a_type) or not is_geom_supported(op.a.geom):
        return False

    # Only these input types accept the satfinite variant.
    if satf and a_type not in ("f16", "s8", "u8", "s4", "u4"):
        return False

    # Sub-integer types only come with A in row and B in col layout.
    if a_type in ("s4", "u4", "b1"):
        return layout_a == "row" and layout_b == "col"
    return True
def is_ldst_variant_supported(frag, layout):
    """Return True if loading/storing ``frag`` with ``layout`` is testable.

    Besides the type and geometry checks, sub-integer fragments require A in
    "row" layout and B in "col" layout; their C/D fragments accept either.
    """
    if not is_type_supported(frag.ptx_type) or not is_geom_supported(frag.geom):
        return False
    if frag.ptx_type not in ("s4", "u4", "b1"):
        return True
    # Sub-integer fragment: the allowed layout depends on the fragment role.
    if frag.frag == "a":
        return layout == "row"
    if frag.frag == "b":
        return layout == "col"
    return frag.frag in ("c", "d")
def get_builtin_prefix(frag):
    """Return the clang builtin name prefix for fragment ``frag``.

    The prefix depends on the fragment's geometry and, for some geometries,
    its element type (e.g. "__hmma" for f16/f32, "__imma" for integers).

    Raises:
        ValueError: if ``frag.geom`` is not a geometry this generator knows.
    """
    if frag.geom in ["m16n16k16", "m32n8k16", "m8n32k16"]:
        if frag.ptx_type in ["f16", "f32"]:
            return "__hmma"
        if frag.ptx_type == "bf16":
            return "__mma_bf16"
        return "__imma"
    if frag.geom == "m8n8k32":
        return "__imma"  # sub-integers
    if frag.geom == "m8n8k128":
        return "__bmma"
    if frag.geom == "m8n8k4":
        return "__dmma"
    if frag.geom == "m16n16k8":
        # The f32 fragment of the TF32 geometry uses the plain "__mma" prefix.
        return "__mma" if frag.ptx_type == "f32" else "__mma_tf32"
    # An explicit raise, unlike `assert prefix`, survives `python -O`
    # instead of silently returning None.
    raise ValueError("Unexpected geometry: %r" % (frag.geom,))
def get_ldst_builtin_name(frag):
    """Return the clang load/store builtin name for fragment ``frag``.

    The name has the form ``<prefix>_<geom>_<ld|st>_<frag>[_<type>]``, e.g.
    "__hmma_m16n16k16_ld_a". Fragment "d" is expressed as a store of the
    "c" fragment; every other fragment is a load.
    """
    prefix = get_builtin_prefix(frag)
    role = frag.frag
    elt = frag.ptx_type

    # Which fragments carry no element-type suffix depends on the family.
    if prefix == "__hmma":
        suffix = elt if role not in ("a", "b") else ""
    elif prefix in ("__dmma", "__mma_bf16", "__mma_tf32"):
        suffix = elt if role not in ("a", "b", "c") else ""
    else:
        suffix = elt if role != "c" else ""
        # Integer builtins spell the s32 type as "i32".
        if suffix == "s32":
            suffix = "i32"

    op, ifrag = ("st", "c") if role == "d" else ("ld", role)
    tail = "_" + suffix if suffix else ""
    return "{}_{}_{}_{}{}".format(prefix, frag.geom, op, ifrag, tail)
def get_mma_builtin_name(op):
    """Return the clang MMA builtin name for operation ``op``.

    E.g. "__hmma_m16n16k16_mma_f32f16" for an FP16 op with D=f32, C=f16, or
    "__bmma_m8n8k128_mma_xor_popc_b1" for a single-bit op.
    """
    prefix = get_builtin_prefix(op.a)
    if prefix == "__hmma":
        # FP16 builtins are named after the D and C types.
        type_suffix = op.d.ptx_type + op.c.ptx_type
    elif prefix in ("__mma_bf16", "__mma_tf32"):
        type_suffix = op.d.ptx_type
    else:
        type_suffix = op.a.ptx_type
    # ".xor.popc" -> "_xor_popc"; empty for ops without a b1 operation.
    b1_part = op.b1op.replace(".", "_")
    return "%s_%s_mma%s_%s" % (prefix, op.a.geom, b1_part, type_suffix)
def get_required_sm(frag, b1op=""):
    """Return the minimum SM version needed by builtins operating on ``frag``.

    ``b1op`` is the single-bit operation suffix of an MMA op; ".and.popc"
    raises the requirement from sm_75 to sm_80.

    Raises:
        ValueError: for an element type this generator does not know.
    """
    if frag.ptx_type in ["f64", "bf16", "tf32"]:
        return 80
    if frag.ptx_type in ["u4", "s4", "b1"]:
        if b1op == ".and.popc":
            return 80
        return 75
    if frag.ptx_type in ["s8", "u8"]:
        return 72
    if frag.ptx_type == "s32":
        if frag.geom in ["m8n8k32", "m8n8k128"]:  # s4/u4/b1
            return 75
        return 72  # s8/u8
    if frag.ptx_type in ["f16", "f32"]:
        if frag.geom == "m16n16k8":
            return 80
        return 70
    # An explicit raise, unlike `assert False`, survives `python -O`
    # instead of silently returning None.
    raise ValueError("Unexpected type: %r" % (frag.ptx_type,))
def get_required_ptx(frag, b1op=""):
    """Return the minimum PTX ISA version for builtins operating on ``frag``.

    ``b1op`` is the single-bit operation suffix of an MMA op; ".and.popc"
    on b1 needs PTX 7.1, later than ".xor.popc".
    """
    elt = frag.ptx_type
    if elt == "b1" and b1op == ".and.popc":
        return 71
    if elt in ("f64", "bf16", "tf32"):
        return 70
    if elt not in ("f16", "f32"):
        # Integer and sub-integer element types.
        return 63
    # FP16/FP32: the requirement depends on the geometry.
    shape = frag.geom
    if shape == "m16n16k16":
        return 60
    return 70 if shape == "m16n16k8" else 61
def get_src_dst_prefix(frag):
    """Return the prefix of the C pointer variable that holds ``frag``'s data.

    "f" selects the float pointers (fsrc/fdst), "d" the double pointers
    (dsrc/ddst) and "" the int pointers (src/dst). TF32 C/D fragments are
    accessed through the float pointers.
    """
    elt = frag.ptx_type
    if elt == "f32" or (elt == "tf32" and frag.frag in ("c", "d")):
        return "f"
    return "d" if elt == "f64" else ""
def gen_wmma_ldst_tests(results):
    """Generate test snippets for all supported WMMA load/store builtins.

    For every (fragment, layout) pair accepted by is_ldst_variant_supported(),
    a snippet is appended to ``results[(min_ptx, min_sm)]``. Each snippet has
    a FileCheck CHECK line for the expected LLVM intrinsic call, an
    ``expected-error-re`` line for the missing-target-feature diagnostic,
    and the builtin call itself.

    Args:
        results: mapping from (minimum PTX version, minimum SM) to the
            accumulated test text. Entries are extended with ``+=``, so
            missing keys must default to "" (gen_tests() passes
            ``defaultdict(str)``).

    Returns:
        ``results``, updated in place.
    """
    # Snippet for one builtin call; also used for stores (fragment "d").
    # ${blayout} is the builtin's layout argument: 0 for row, 1 for col.
    load_template = """
  // CHECK${check_suffix}: call {{.*}} @${intrinsic}
  // expected-error-re@+1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}}
  ${builtin}(${dst}, ${src}, ldm, ${blayout});
""".rstrip()
    intrinsic_template = (
        "llvm.nvvm.wmma.${geom}.${op}.${frag}.${ilayout}.stride.${itype}"
    )

    # Sorting by str() fixes the order in which snippets are emitted.
    for frag, layout in sorted(product(get_ldst_ops(), ["row", "col"]), key=str):

        if not is_ldst_variant_supported(frag, layout):
            continue

        # Picks the pointer variables (src/fsrc/dsrc, ...) the call uses.
        src_dst_prefix = get_src_dst_prefix(frag)

        min_sm = get_required_sm(frag)
        min_ptx = get_required_ptx(frag)
        # TF32 uses f32 for accumulator loads.
        if frag.geom == "m16n16k8" and frag.frag == "c":
            assert frag.ptx_type == "tf32"
            itype = "f32"
        else:
            itype = frag.ptx_type

        params = {
            # Names the (min PTX, min SM) group, e.g. "_PTX60_SM70".
            "check_suffix": "_PTX%d_SM%d" % (min_ptx, min_sm),
            "builtin": get_ldst_builtin_name(frag),
            "min_ptx": min_ptx,
            "min_sm": min_sm,
            "dst": src_dst_prefix + "dst",
            "src": src_dst_prefix + "src",
            "blayout": 0 if layout == "row" else 1,
            "intrinsic": Template(intrinsic_template).substitute(
                {
                    "frag": frag.frag,
                    "geom": frag.geom,
                    "ilayout": layout,
                    "itype": itype,
                    # Fragment "d" is the only one that is stored.
                    "op": "store" if frag.frag == "d" else "load",
                }
            ),
        }
        results[(min_ptx, min_sm)] += Template(load_template).substitute(params)

    return results
def mma_signature(op):
    """Return the type part of the LLVM MMA intrinsic name for ``op``.

    FP16 ops are identified by their D and C types ("<d>.<c>"); every other
    op is identified by the type of its A (input) fragment.
    """
    a_type = op.a.ptx_type
    if a_type != "f16":
        return a_type
    return "{}.{}".format(op.d.ptx_type, op.c.ptx_type)
# Numeric value of the builtin's rowcol argument for the A/B layouts.
# As far as we can tell it follows the encoding accepted by NVVM intrinsics:
# https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#nvvm-intrin-warp-level-matrix-mma
def get_ilayout(a, b):
    """Map A layout ``a`` and B layout ``b`` ("row"/"col") to 0..3.

    Raises KeyError for any layout other than "row" or "col".
    """
    encoding = dict(zip(("row.row", "row.col", "col.row", "col.col"), range(4)))
    key = a + "." + b
    return encoding[key]
def gen_wmma_mma_tests(results):
    """Generate test snippets for all supported WMMA MMA builtins.

    Every combination of MMA op, A layout, B layout and satfinite setting
    accepted by is_mma_variant_supported() contributes one snippet. Each
    snippet has a CHECK line, an ``expected-error-re`` line and the builtin
    call, and is appended to ``results[(min_ptx, min_sm)]``.

    Args:
        results: mapping from (minimum PTX version, minimum SM) to the
            accumulated test text; entries are extended with ``+=``.

    Returns:
        ``results``, updated in place.
    """
    # Note: ${asrc} is passed for both the A and the B operand.
    mma_template = """
  // CHECK${check_suffix}: call {{.*}} @${intrinsic}
  // expected-error-re@+1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}}
  ${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_satf});
""".rstrip()
    intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}.${intrinsic_signature}${satf}"

    # Sorting by str() fixes the emission order. MMAOp's repr omits b1op, so
    # ops differing only in b1op keep their product() order (stable sort).
    for op, alayout, blayout, satf in sorted(
        product(get_mma_ops(), ["row", "col"], ["row", "col"], [".satfinite", ""]),
        key=str,
    ):

        if not is_mma_variant_supported(op, alayout, blayout, satf):
            continue

        asrc_prefix = get_src_dst_prefix(op.a)
        csrc_prefix = get_src_dst_prefix(op.c)
        ddst_prefix = get_src_dst_prefix(op.d)
        if op.a.ptx_type == "b1":  # .b1 MMA has no satf argument.
            isatf_arg = ""
        else:
            isatf_arg = ", 1" if satf else ", 0"
        min_sm = get_required_sm(op.a, op.b1op)
        min_ptx = get_required_ptx(op.a, op.b1op)
        params = {
            "check_suffix": "_PTX%d_SM%d" % (min_ptx, min_sm),
            "builtin": get_mma_builtin_name(op),
            "min_ptx": min_ptx,
            "min_sm": min_sm,
            "dst": ddst_prefix + "dst",
            "asrc": asrc_prefix + "src",
            "csrc": csrc_prefix + "src",
            "ilayout": get_ilayout(alayout, blayout),
            "maybe_satf": isatf_arg,
            "intrinsic": Template(intrinsic_template).substitute(
                {
                    "geom": op.a.geom,
                    "alayout": alayout,
                    "blayout": blayout,
                    "intrinsic_signature": mma_signature(op),
                    "satf": satf,
                    "b1op": op.b1op,
                }
            ),
        }
        results[(min_ptx, min_sm)] += Template(mma_template).substitute(params)

    return results
def gen_tests():
    """Print the complete generated CUDA test file to stdout.

    Uses the module-level ``ptx_version`` and ``gpu_arch`` settings for the
    RUN lines and to choose which CHECK prefixes FileCheck verifies. The
    generated snippets are grouped under ``#if (PTX >= N) && (SM >= M)``.
    """
    results = gen_wmma_ldst_tests(defaultdict(str))
    results = gen_wmma_mma_tests(results)

    # File header with the lit RUN lines. The first RUN compiles for the
    # selected sm/PTX and FileChecks the intrinsic calls. The second compiles
    # with -target-cpu sm_60 and +ptx42 under -verify, which checks the
    # expected-error-re lines emitted next to each builtin call.
    run_template = r"""
//
// *** DO NOT EDIT ***
//
//  This test has been automatically generated by
//  builtins-nvtx-mma.py --ptx=${ptx} --gpu-arch=${sm}
//
// Make sure we can handle all builtins available on sm_${sm} with PTX${ptx}
// ${run}: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_${sm} \
// ${run}:            -fcuda-is-device -target-feature +ptx${ptx} \
// ${run}:            -DPTX=${ptx} -DSM=${sm} \
// ${run}:            -S -emit-llvm -o - -x cuda %s \
// ${run}:   | FileCheck -check-prefixes=${check_labels} %s
// Verify that all builtins have correct constraints.
// ${run}: %clang_cc1 -triple nvptx-unknown-unknown \
// ${run}:   -target-cpu sm_60 -target-feature +ptx42 \
// ${run}:   -DPTX=${ptx} -DSM=${sm} -fcuda-is-device -S -o /dev/null -x cuda \
// ${run}:   -verify %s
"""

    def supported_variants(ptx, sm, results):
        # (ptx, sm) keys of `results` whose requirements are met by ptx/sm.
        return [(ptx_, sm_) for ptx_, sm_ in results if ptx_ <= ptx and sm_ <= sm]

    print(
        Template(run_template).substitute(
            {
                "run": "RUN",  # To avoid lit misinterpreting the template
                "ptx": ptx_version,
                "sm": gpu_arch,
                "check_labels": ",".join(
                    [
                        "CHECK_PTX%d_SM%d" % (ptx_, sm_)
                        for ptx_, sm_ in supported_variants(
                            ptx_version, gpu_arch, results
                        )
                    ]
                ),
            }
        )
    )

    # Preamble and the opening of the test function. The pointer parameters
    # are the src/dst, fsrc/fdst and dsrc/ddst variables the snippets use.
    print(
        """
#if !defined(CUDA_VERSION)
#define __device__ __attribute__((device))
#define __global__ __attribute__((global))
#define __shared__ __attribute__((shared))
#define __constant__ __attribute__((constant))

typedef unsigned long long uint64_t;
#endif

// CHECK-LABEL: test_wmma_buitins
__device__ void test_wmma_buitins(int *src, int *dst,
                                  float *fsrc, float *fdst,
                                  double *dsrc, double *ddst, int ldm) {
"""
    )

    # Emit each group under a guard on its minimum PTX/SM. The RUN lines set
    # -DPTX/-DSM, so only builtins available for that target are compiled.
    for (ptx, sm), tests in sorted(results.items()):
        print()
        print("#if (PTX >= %d) && (SM >= %d)" % (ptx, sm))
        print(tests)
        print("#endif // (PTX >= %d) && (SM >= %d)" % (ptx, sm))

    print("}")
if __name__ == "__main__":
    # Script entry point, guarded so that importing the module has no side
    # effects. ptx_version and gpu_arch are still assigned at module scope
    # (an `if` block does not create a new scope), which is where the
    # is_*_supported() helpers and gen_tests() read them from.
    parser = argparse.ArgumentParser(
        description="Generate the CUDA test for the NVPTX WMMA/MMA builtins."
    )
    parser.add_argument(
        "--ptx", type=int, default=60, help="PTX ISA version, e.g. 70 (default: 60)"
    )
    parser.add_argument(
        "--gpu-arch", type=int, default=70, help="SM version, e.g. 80 (default: 70)"
    )
    args = parser.parse_args()
    ptx_version = args.ptx
    gpu_arch = args.gpu_arch

    gen_tests()