# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them.

# Check all variants of instructions supported by PTX60 on SM70
# RUN: %python %s --ptx=60 --gpu-arch=70 > %t-ptx60-sm_70.ll
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
# RUN:           | FileCheck %t-ptx60-sm_70.ll

# Check all variants of instructions supported by PTX61 on SM70
# RUN: %python %s --ptx=61 --gpu-arch=70 > %t-ptx61-sm_70.ll
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
# RUN:           | FileCheck %t-ptx61-sm_70.ll

# Check all variants of instructions supported by PTX63 on SM72
# RUN: %python %s --ptx=63 --gpu-arch=72 > %t-ptx63-sm_72.ll
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN:           --check-prefixes=INTRINSICS,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
# RUN:           | FileCheck %t-ptx63-sm_72.ll

# Check all variants of instructions supported by PTX63 on SM75
# RUN: %python %s --ptx=63 --gpu-arch=75 > %t-ptx63-sm_75.ll
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
# RUN:           | FileCheck %t-ptx63-sm_75.ll

# Check all variants of instructions supported by PTX64 on SM70+
# RUN: %python %s --ptx=64 --gpu-arch=70 > %t-ptx64-sm_70.ll
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,MMA
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx64-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 \
# RUN:           | FileCheck %t-ptx64-sm_70.ll

# Check all variants of instructions supported by PTX65 on SM75+
# RUN: %python %s --ptx=65 --gpu-arch=75 > %t-ptx65-sm_75.ll
# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,PTX65MMA,PTX65LDMATRIX
# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS
# RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
# RUN:           | FileCheck %t-ptx65-sm_75.ll

# Check all variants of instructions supported by PTX71 on SM80+
# RUN: %python %s --ptx=71 --gpu-arch=80 > %t-ptx71-sm_80.ll
# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX65LDMATRIX,PTX71MMA
# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
# RUN:           --check-prefixes=INTRINSICS
# RUN: llc < %t-ptx71-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 \
# RUN:           | FileCheck %t-ptx71-sm_80.ll

from __future__ import print_function

import argparse
from itertools import product
from string import Template

class MMAType:
  def __init__(self, ptx_type):
    self.ptx_type = ptx_type
    self.llvm_type = {
        "f16" : "<2 x half>",
        "f32" : "float",
        "f64" : "double",
        "s32" : "i32",
        "b16" : "i32",
        "s8" : "i32",
        "u8" : "i32",
        "s4" : "i32",
        "u4" : "i32",
        "b1" : "i32",
        "bf16" : "i32",
        "tf32" : "i32",
    }[ptx_type];

    self.ptx_reg_pattern = {
        "f16" : "%hh[0-9]+",
        "f32" : "%f[0-9]+",
        "f64" : "%fd[0-9]+",
    }.get(ptx_type, "%r[0-9]+")

  def __repr__(self):
    return "%s/%s" % (self.ptx_type, self.llvm_type)
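
# Note (illustrative): MMAType only maps a PTX element type to the LLVM IR type
# of one fragment register and to the PTX register pattern FileCheck should
# expect, e.g. MMAType("f16").llvm_type == "<2 x half>" and
# MMAType("f16").ptx_reg_pattern == "%hh[0-9]+".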

class MMAFrag:
  def __init__(self, geom, frag, ptx_elt_type):
    self.geom = geom
    self.frag = frag
    self.mma_type = MMAType(ptx_elt_type);
    self.nregs = {
        # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
        "m16n16k16:a:u8" : 2,
        "m16n16k16:a:s8" : 2,
        "m16n16k16:b:u8" : 2,
        "m16n16k16:b:s8" : 2,
        "m16n16k16:c:s32" : 8,
        "m16n16k16:d:s32" : 8,

        "m8n32k16:a:u8" : 1,
        "m8n32k16:a:s8" : 1,
        "m8n32k16:b:u8" : 4,
        "m8n32k16:b:s8" : 4,
        "m8n32k16:c:s32" : 8,
        "m8n32k16:d:s32" : 8,

        "m32n8k16:a:u8" : 4,
        "m32n8k16:a:s8" : 4,
        "m32n8k16:b:u8" : 1,
        "m32n8k16:b:s8" : 1,
        "m32n8k16:c:s32" : 8,
        "m32n8k16:d:s32" : 8,

        "m8n8k16:a:u8": 1,
        "m8n8k16:a:s8": 1,
        "m8n8k16:b:u8": 1,
        "m8n8k16:b:s8": 1,
        "m8n8k16:c:s32": 2,
        "m8n8k16:d:s32": 2,

        "m16n8k16:a:u8": 2,
        "m16n8k16:a:s8": 2,
        "m16n8k16:b:u8": 1,
        "m16n8k16:b:s8": 1,
        "m16n8k16:c:s32": 4,
        "m16n8k16:d:s32": 4,

        "m16n8k32:a:u8": 4,
        "m16n8k32:a:s8": 4,
        "m16n8k32:b:u8": 2,
        "m16n8k32:b:s8": 2,
        "m16n8k32:c:s32": 4,
        "m16n8k32:d:s32": 4,

        # u4/s4 -> s32 @ m8n8k32 (u4/s4)
        "m8n8k32:a:u4" : 1,
        "m8n8k32:a:s4" : 1,
        "m8n8k32:b:u4" : 1,
        "m8n8k32:b:s4" : 1,
        "m8n8k32:c:s32" : 2,
        "m8n8k32:d:s32" : 2,

        "m16n8k32:a:u4" : 2,
        "m16n8k32:a:s4" : 2,
        "m16n8k32:b:u4" : 1,
        "m16n8k32:b:s4" : 1,
        "m16n8k32:c:s32" : 4,
        "m16n8k32:d:s32" : 4,

        "m16n8k64:a:u4" : 4,
        "m16n8k64:a:s4" : 4,
        "m16n8k64:b:u4" : 2,
        "m16n8k64:b:s4" : 2,
        "m16n8k64:c:s32" : 4,
        "m16n8k64:d:s32" : 4,

        # b1 -> s32 @ m8n8k128(b1)
        "m8n8k128:a:b1" : 1,
        "m8n8k128:b:b1" : 1,
        "m8n8k128:c:s32" : 2,
        "m8n8k128:d:s32" : 2,

        "m16n8k128:a:b1" : 2,
        "m16n8k128:b:b1" : 1,
        "m16n8k128:c:s32" : 4,
        "m16n8k128:d:s32" : 4,

        "m16n8k256:a:b1" : 4,
        "m16n8k256:b:b1" : 2,
        "m16n8k256:c:s32" : 4,
        "m16n8k256:d:s32" : 4,

        # bf16 -> f32 @ m16n16k16/m8n32k16/m32n8k16
        "m16n16k16:a:bf16" : 4,
        "m16n16k16:b:bf16" : 4,
        "m8n32k16:a:bf16" : 2,
        "m8n32k16:b:bf16" : 8,
        "m32n8k16:a:bf16" : 8,
        "m32n8k16:b:bf16" : 2,

        "m16n8k16:a:bf16" : 4,
        "m16n8k16:b:bf16" : 2,
        "m16n8k16:c:f32" : 4,
        "m16n8k16:d:f32" : 4,
        "m16n8k8:a:bf16" : 2,
        "m16n8k8:b:bf16" : 1,
        "m16n8k8:c:f32" : 4,
        "m16n8k8:d:f32" : 4,

        "m8n8k4:a:f64" : 1,
        "m8n8k4:b:f64" : 1,
        "m8n8k4:c:f64" : 2,
        "m8n8k4:d:f64" : 2,

        # tf32 -> f32 @ m16n16k8
        "m16n16k8:a:tf32" : 4,
        "m16n16k8:b:tf32" : 4,

        "m16n8k4:a:tf32" : 2,
        "m16n8k4:b:tf32" : 1,
        "m16n8k4:c:f32" : 4,
        "m16n8k4:d:f32" : 4,
        "m16n8k8:a:tf32" : 4,
        "m16n8k8:b:tf32" : 2,
        "m16n8k8:c:f32" : 4,
        "m16n8k8:d:f32" : 4,

        "m8n8k4:a:f16": 2,
        "m8n8k4:b:f16": 2,
        "m16n8k8:a:f16": 2,
        "m16n8k8:b:f16": 1,
        "m16n8k8:c:f16": 2,
        "m16n8k8:d:f16": 2,
        "m16n8k8:c:f32": 4,
        "m16n8k8:d:f32": 4,
        "m16n8k16:a:f16": 4,
        "m16n8k16:b:f16": 2,
        "m16n8k16:c:f16": 2,
        "m16n8k16:d:f16": 2,
        "m16n8k16:c:f32": 4,
        "m16n8k16:d:f32": 4,

        # ldmatrix
        "m8n8:x1:b16": 1,
        "m8n8:x2:b16": 2,
        "m8n8:x4:b16": 4,
    }.get("%s:%s:%s" % (geom, frag, ptx_elt_type), {
        # All other FP shape/fragment/type combinations have the same size
        "a:f16" : 8,
        "b:f16" : 8,
        "c:f16" : 4,
        "d:f16" : 4,
        "c:f32" : 8,
        "d:f32" : 8,
    }.get("%s:%s" % (frag, ptx_elt_type), None))
    assert(self.nregs);

  def __repr__(self):
    return "%s:%s:%s%s" % (self.geom, self.frag, self.mma_type,
                           "" if self.nregs == 1 else ("*%d" % self.nregs))

class MMAOp:
  def __init__(self, a, b, c, d):
    self.a = a
    self.b = b
    self.c = c
    self.d = d

  def __repr__(self):
    return ("{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d))

def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
  ops = []
  for geom, type_a, type_c in product(geoms, types_a, types_c):
    for type_b, type_d in product(types_b if types_b else [type_a],
                                  types_d if types_d else [type_c]):
      ops.append(MMAOp(MMAFrag(geom, "a", type_a),
                       MMAFrag(geom, "b", type_b),
                       MMAFrag(geom, "c", type_c),
                       MMAFrag(geom, "d", type_d)))
  return ops
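
# Empty types_b/types_d mean "same as types_a/types_c". For example (illustrative),
# make_mma_ops(["m16n16k16"], ["f16"], [], ["f16", "f32"], ["f16", "f32"]) produces
# the four f16 WMMA variants with every combination of f16/f32 accumulator (C) and
# result (D) types, while make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], []) yields
# a single all-f64 op.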

def make_ldst_ops(geoms, frags, types):
  return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
          in product(geoms, frags, types)]

def make_ldmatrix_ops(geoms, frags, types):
  return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
          in product(geoms, frags, types)]

def get_wmma_ops():
  return (make_mma_ops(["m16n16k8"],
                       ["tf32"], [], ["f32"], []) +
          make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                       ["bf16"], [], ["f32"], []) +
          make_mma_ops(["m8n8k4"],
                       ["f64"], [], ["f64"], []) +
          make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                       ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
          make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                       ["s8", "u8"], [], ["s32"], []) +
          make_mma_ops(["m8n8k32"],
                       ["s4", "u4"], [], ["s32"], []) +
          make_mma_ops(["m8n8k128"],
                       ["b1"], [], ["s32"], []))

def get_mma_ops():
  return (make_mma_ops(["m8n8k4"],
                       ["f64"], [], ["f64"], []) +
          make_mma_ops(["m16n8k4", "m16n8k8"],
                       ["tf32"], [], ["f32"], []) +
          make_mma_ops(["m16n8k16", "m16n8k8"],
                       ["bf16"], [], ["f32"], []) +
          make_mma_ops(["m8n8k4", "m16n8k8", "m16n8k16"],
                       ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
          make_mma_ops(["m8n8k16", "m16n8k16", "m16n8k32"],
                       ["s8", "u8"], ["s8", "u8"], ["s32"], []) +
          make_mma_ops(["m8n8k32", "m16n8k32", "m16n8k64"],
                       ["s4", "u4"], ["s4", "u4"], ["s32"], []) +
          make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"],
                       ["b1"], [], ["s32"], []))

def get_ldst_ops(kind):
  ldst_ops = (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                            ["a", "b"], ["f16", "u8", "s8", "bf16"]) +
              make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                            ["c", "d"], ["f16", "f32", "s32"]) +
              make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4","u4"]) +
              make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
              make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]) +
              make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"]) +
              make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"]) +
              make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"]))
  return [x for x in ldst_ops if (x.frag == "d") == (kind == "store")]
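
# The filter above means "load" covers the a/b/c fragments (wmma.load.{a,b,c})
# and "store" covers only the d fragment (wmma.store.d).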

def get_ldmatrix_ops():
  return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])

def is_wmma_geom_supported(geom):
  # geometries for FP and ints.
  if geom in ["m8n32k16", "m32n8k16"]:
    return ptx_version >= 61
  # geometries for sub-ints.
  if geom in ["m8n8k32", "m8n8k128"]:
    return ptx_version >= 63 and gpu_arch >= 75
  if geom == "m16n16k16":
    return ptx_version >= 60
  if geom == "m16n8k8":
    return ptx_version >= 65
  if geom in ["m16n16k8", "m8n8k4"]:
    return ptx_version >= 70
  assert(False) # Unexpected geometry.

def is_mma_geom_supported(geom):
  # geometries for FP and ints.
  if geom == "m8n8k4":
    return ptx_version >= 64
  if geom in ["m16n8k8", "m8n8k16", "m8n8k32"]:
    return ptx_version >= 65
  if geom in ["m16n8k16", "m16n8k4", "m16n8k32", "m16n8k64", "m8n8k128",
              "m16n8k128", "m16n8k256"]:
    return ptx_version >= 70
  assert(False) # Unexpected geometry.

def is_ldmatrix_geom_supported(geom):
  if geom in ["m8n8"]:
    return ptx_version >= 65 and gpu_arch >= 75
  assert(False) # Unexpected geometry.

def is_type_supported(ptx_type):
  if ptx_type in ["s8", "u8", "s32"]:
    return ptx_version >= 63 and gpu_arch >= 72
  if ptx_type in ["s4", "u4", "b1"]:
    return ptx_version >= 63 and gpu_arch >= 75
  if ptx_type == "b16":
    return ptx_version >= 65 and gpu_arch >= 75
  if ptx_type in ["bf16", "tf32", "f64"]:
    return ptx_version >= 70
  return ptx_version >= 60 and gpu_arch >= 70

def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
  if not (is_type_supported(op.a.mma_type.ptx_type)
          and is_wmma_geom_supported(op.a.geom)):
    return False

  # rnd is only supported for FP64 WMMA.
  if rnd and op.a.mma_type.ptx_type != "f64":
    return False

  if satf:
    # satfinite for floating points was removed in PTX 6.5.
    if op.a.mma_type.ptx_type == "f16" and ptx_version >= 65:
      return False
    if not op.a.mma_type.ptx_type in ["f16", "s8", "u8", "s4", "u4"]:
      return False

  # sub-integer ops require row/col layout.
  if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
    return layout_a == "row" and layout_b == "col"
  return True

def is_mma_variant_supported(op, layout_a, layout_b, satf):
  if not (is_type_supported(op.a.mma_type.ptx_type)
          and is_mma_geom_supported(op.a.geom)):
    return False

  if satf and not op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4"]:
    return False

  # For m8n8k4, if the type of C is f32 then the type of D must be f32 as well.
  if (op.a.geom == "m8n8k4" and op.c.mma_type.ptx_type == "f32"
      and op.d.mma_type.ptx_type != "f32"):
    return False

  # A and B types must be the same. C and D types must be the same.
  if (op.a.geom == "m16n8k8"
      and (op.a.mma_type.ptx_type != op.b.mma_type.ptx_type
           or op.c.mma_type.ptx_type != op.d.mma_type.ptx_type)):
    return False

  # C and D types must be the same.
  if (op.a.geom == "m16n8k16"
      and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type):
    return False

  # Require row/col layout for all MMA variants except m8n8k4 on FP16.
  if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
    return layout_a == "row" and layout_b == "col"
  return True

def is_ldst_variant_supported(frag, layout):
  if not (is_type_supported(frag.mma_type.ptx_type)
          and is_wmma_geom_supported(frag.geom)):
    return False
  if frag.mma_type.ptx_type in ["s4", "u4", "b1"]:
    # sub-integer ops require sm_75 and ptx63, row/col layout for a/b.
    return ((frag.frag == "a" and layout == "row")
            or (frag.frag == "b" and layout == "col")
            or frag.frag in ["c", "d"])
  return True

def is_ldmatrix_variant_supported(frag):
  if not (is_type_supported(frag.mma_type.ptx_type)
          and is_ldmatrix_geom_supported(frag.geom)):
    return False
  return frag.frag in ["x1", "x2", "x4"]

def make_wmma_slice_ty(frag):
  return [frag.mma_type.llvm_type] * frag.nregs

def make_wmma_ld_ret_ty(frag):
  results = make_wmma_slice_ty(frag)
  if len(results) == 1:
    return "%s" % results[0]
  return "{%s}" % ", ".join(results)

# Returns the LLVM address space number for a PTX state space.
def get_aspace(space):
  space_map = {
      ".global" : 1,
      ".shared" : 3,
      ".const" : 4,
      ".local" : 5,
      ".param" : 101,
      "" : 0,
      ".generic": 0
  }
  return space_map[space];

def get_pspace(space):
  return "p%di8" % get_aspace(space);
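
# For example (illustrative): get_aspace(".shared") == 3, so get_pspace(".shared")
# returns "p3i8", the pointer-type suffix used in the intrinsic names below.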

def check_pattern(frag):
  return "{{%s}}" % ", *".join([frag.mma_type.ptx_reg_pattern] * frag.nregs)
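
# For example (illustrative), a 2-register f32 fragment yields the FileCheck
# pattern "{{%f[0-9]+, *%f[0-9]+}}", matching the PTX operand list.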

def gen_wmma_load_tests():
  load_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
  ret ${ret_ty} %v0;
}
"""
  intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
  instruction_template = "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"

  generated_items = []

  for frag, layout, space, stride in product(
      get_ldst_ops("load"),
      ["row","col"],
      ["",".shared",".global"],
      ["", ".stride"],
      ):
    if not is_ldst_variant_supported(frag, layout):
      continue

    params = {
        "abc" : frag.frag,
        "aligned" : ".aligned" if ptx_version >= 63 else "",
        "layout" : layout,
        "space" : space,
        "stride" : stride,
        "itype" : frag.mma_type.ptx_type,
        "pspace" : get_pspace(space),
        "as" : "addrspace(%d)" % get_aspace(space),
        "geom" : frag.geom,
    }

    test_params = params
    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    test_params["function"] = test_params["intrinsic"].replace(".","_")
    test_params["instruction"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
    test_params["check_result"] = check_pattern(frag)

    if stride:
      test_params["extra_args"] = ", i32 %stride";
      test_params["stride_pattern"] = ", %r{{[0-9]+}}"
    else:
      test_params["extra_args"] = ""
      test_params["stride_pattern"] = ""

    print(Template(load_template).substitute(test_params))

    generated_items.append((test_params["intrinsic"],
                            test_params["instruction"]))

  return generated_items

def make_wmma_slice_args(frag):
  return ", ".join(["%s %%%s%d" % (t, frag.frag, i) for i,t
                    in enumerate(make_wmma_slice_ty(frag))])
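
# For example (illustrative), a 2-register f16 "c" fragment expands to the
# argument list "<2 x half> %c0, <2 x half> %c1".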

def gen_wmma_store_tests():
  store_template = """
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
  ret void
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
  ret void
}
"""
  intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
  instruction_template = "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"

  generated_items = []

  for frag, layout, space, stride in product(
      get_ldst_ops("store"),
      ["row","col"],
      ["",".shared",".global"],
      ["", ".stride"]):

    if not is_ldst_variant_supported(frag, layout):
      continue

    params = {
        "abc" : frag.frag,
        "aligned" : ".aligned" if ptx_version >= 63 else "",
        "layout" : layout,
        "space" : space,
        "stride" : stride,
        "itype" : frag.mma_type.ptx_type,
        "pspace" : get_pspace(space),
        "as" : "addrspace(%d)" % get_aspace(space),
        "geom" : frag.geom,
    }

    test_params = params
    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    test_params["function"] = test_params["intrinsic"].replace(".","_")
    test_params["instruction"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
    test_params["check_args"] = check_pattern(frag)
    if stride:
      test_params["extra_args"] = ", i32 %stride";
      test_params["stride_pattern"] = ", %r{{[0-9]+}};"
    else:
      test_params["extra_args"] = ""
      test_params["stride_pattern"] = ";"
    test_params["args"] = make_wmma_slice_args(frag);

    print(Template(store_template).substitute(test_params))
    generated_items.append((test_params["intrinsic"],
                            test_params["instruction"]))

  return generated_items

def gen_ldmatrix_tests():
  ldmatrix_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src);

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src);
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1);
  ret ${ret_ty} %v0;
}
"""
  intrinsic_template = "llvm.nvvm.ldmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
  instruction_template = "ldmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"

  generated_items = []

  for frag, space, trans in product(
      get_ldmatrix_ops(),
      ["",".shared"],
      ["",".trans"],
      ):
    if not is_ldmatrix_variant_supported(frag):
      continue

    params = {
        "frag" : frag.frag,
        "space" : space,
        "trans" : trans,
        "itype" : frag.mma_type.ptx_type,
        "pspace" : get_pspace(space),
        "as" : "addrspace(%d)" % get_aspace(space),
        "geom" : frag.geom,
    }

    test_params = params
    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    test_params["function"] = test_params["intrinsic"].replace(".","_")
    test_params["instruction"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
    test_params["check_result"] = check_pattern(frag)

    print(Template(ldmatrix_template).substitute(test_params))

    generated_items.append((test_params["intrinsic"],
                            test_params["instruction"]))

  return generated_items

def mma_signature(op):
  if op.a.mma_type.ptx_type == "f16":
    # FP16 ops are identified by accumulator & result type.
    return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
  elif op.a.mma_type.ptx_type != op.b.mma_type.ptx_type:
    # other ops are identified by input types.
    return "%s.%s" % (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
  else:
    # if input types are the same, it only appears once.
    return op.a.mma_type.ptx_type

def mma_ptx_signature(op):
  # Encode all four types as D.A.B.C
  return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))

def wmma_signature(op):
  if op.a.mma_type.ptx_type == "f16":
    # FP16 ops are identified by accumulator & result type.
    return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
  else:
    # other ops are identified by input type.
    return op.a.mma_type.ptx_type

def wmma_ptx_signature(op):
  if op.a.mma_type.ptx_type == "f16":
    # FP16 instructions use D.C
    return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
  else:
    # other instructions encode all four types as D.A.B.C
    return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
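
# Illustrative examples of the name fragments produced above, for an f16 op with
# f32 accumulator and f32 result:
#   wmma_signature(op)     -> "f32.f32"   (used in the intrinsic name)
#   wmma_ptx_signature(op) -> "f32.f32"   (used in the PTX instruction name)
# and for an s8 x u8 mma op with s32 accumulator and result:
#   mma_signature(op)      -> "s8.u8"
#   mma_ptx_signature(op)  -> "s32.s8.u8.s32"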

def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
  mma_template = """
declare ${ret_ty} @${intrinsic}(
        ${args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
        ${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
; CHECK-NEXT: ${check_a}
; CHECK-NEXT: ${check_b}
; CHECK-NEXT: ${check_c}
  %r = call ${ret_ty} @${intrinsic}(
        ${args});
  ret ${ret_ty} %r;
}
"""
  test_params = params
  test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
  test_params["function"] = test_params["intrinsic"].replace(".", "_")
  test_params["instruction"] = Template(instruction_template).substitute(params)
  test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
  test_params["check_a"] = check_pattern(op.a)
  test_params["check_b"] = check_pattern(op.b)
  test_params["check_c"] = check_pattern(op.c)
  test_params["check_d"] = check_pattern(op.d)
  args = ",\n        ".join(make_wmma_slice_args(frag)
                            for frag in (op.a, op.b, op.c))
  test_params["args"] = args
  print(Template(mma_template).substitute(test_params))
  return (test_params["intrinsic"], test_params["instruction"])

def get_b1_ops(ptx_type):
  if ptx_type != "b1":
    return [""]
  if ptx_version >= 71:
    return [".xor.popc", ".and.popc"]
  return [".xor.popc"]
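
# Only b1 (single-bit) ops carry an explicit logic+popc suffix: PTX 7.1 and newer
# generate both .xor.popc and .and.popc variants, older PTX only .xor.popc.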

def gen_wmma_mma_tests():
  wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
  wmma_instruction_template = "wmma.mma${b1op}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"

  generated_items=[]

  for op, alayout, blayout, rnd, satf in product(
      get_wmma_ops(),
      ["row","col"],
      ["row","col"],
      [".rn", ".rz", ".rm", ".rp", ""],
      [".satfinite", ""]):

    if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
      continue

    for b1op in get_b1_ops(op.a.mma_type.ptx_type):
      params = {
          "aligned" : ".aligned" if ptx_version >= 63 else "",
          "alayout" : alayout,
          "blayout" : blayout,
          "intrinsic_signature" : wmma_signature(op),
          "ptx_signature" : wmma_ptx_signature(op),
          "satf" : satf,
          "rnd" : rnd,
          "geom" : op.a.geom,
          "b1op" : b1op
      }

      intrinsic_template = wmma_intrinsic_template
      instruction_template = wmma_instruction_template

      generated_items.append(common_mma_test_gen(params, op,
                                                 intrinsic_template,
                                                 instruction_template))

  return generated_items

def gen_mma_tests():
  mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
  mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"

  generated_items=[]

  for op, alayout, blayout, satf in product(
      get_mma_ops(),
      ["row","col"],
      ["row","col"],
      [".satfinite", ""]):

    if not is_mma_variant_supported(op, alayout, blayout, satf):
      continue

    for b1op in get_b1_ops(op.a.mma_type.ptx_type):
      params = {
          "aligned" : ".aligned" if ptx_version >= 63 else "",
          "alayout" : alayout,
          "blayout" : blayout,
          "intrinsic_signature" : mma_signature(op),
          "ptx_signature" : mma_ptx_signature(op),
          "satf" : satf,
          "geom" : op.a.geom,
          "b1op" : b1op
      }

      intrinsic_template = mma_intrinsic_template
      instruction_template = mma_instruction_template

      generated_items.append(common_mma_test_gen(params, op,
                                                 intrinsic_template,
                                                 instruction_template))

  return generated_items

# Append the complete list of intrinsics and instructions we've generated tests
# for. Generate a set of checks to verify that we did generate a sensible set of
# tests for the given combination of PTX and SM variants.
def gen_check_unsupported_ops(items):
  print("; Complete list of intrinsics supported by PTX%d on sm_%d"
        % (ptx_version, gpu_arch))
  print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}")
  print("""

; NOEXTGEOM-NOT: {{m8n32|m32n8}}
; NOINT-NOT: .{{s32|s8}}
; NOSUBINT-NOT: {{s4|u4|b1}}
; NOMMA-NOT: .m8n8k4.
; NOALTFLOAT-NOT: .{{bf16|tf32}}
; NODOUBLE-NOT: .f64
; NOLDMATRIX-NOT: ldmatrix.sync.aligned

; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f32
; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f16
; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f16
; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f32

; PTX61 adds support for m32n8k16/m8n32k16 geometries.
; EXTGEOM-DAG: m32n8k16.load.{{[ab].*}}.f16.p
; EXTGEOM-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f32
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f16
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f16
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f32

; EXTGEOM-DAG: m8n32k16.load.{{[ab].*}}.f16.p
; EXTGEOM-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f32
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f16
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f16
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f32

; INT-DAG: m16n16k16.load.{{[ab].*}}.s8.p
; INT-DAG: m8n32k16.load.{{[ab].*}}.s8.p
; INT-DAG: m32n8k16.load.{{[ab].*}}.s8.p
; INT-DAG: m16n16k16.load.{{[ab].*}}.u8.p
; INT-DAG: m8n32k16.load.{{[ab].*}}.u8.p
; INT-DAG: m32n8k16.load.{{[ab].*}}.u8.p
; INT-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p
; INT-DAG: m16n16k16.mma.{{.*}}.u8
; INT-DAG: m16n16k16.mma.{{.*}}.s8
; INT-DAG: m8n32k16.mma.{{.*}}.u8
; INT-DAG: m8n32k16.mma.{{.*}}.s8
; INT-DAG: m32n8k16.mma.{{.*}}.u8
; INT-DAG: m32n8k16.mma.{{.*}}.s8

; SUBINT-DAG: m8n8k128.load.{{[ab].*}}.b1.p
; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.s4.p
; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.u4.p
; SUBINT-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p
; SUBINT-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p
; SUBINT-DAG: m8n8k32.mma.{{.*}}.u4
; SUBINT-DAG: m8n8k32.mma.{{.*}}.s4
; SUBINT-DAG: m8n8k128.mma.{{.*}}.b1

; ALTFLOAT-DAG: m16n16k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m8n32k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m32n8k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m16n16k8.load.{{[ab].*}}.tf32.p
; ALTFLOAT-DAG: m16n16k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m8n32k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m32n8k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m16n16k8.mma.{{.*}}.tf32

; DOUBLE-DAG: m8n8k4.load.{{[abc].*}}.f64.p
; DOUBLE-DAG: m8n8k4.store.d.{{.*}}.f64.p
; DOUBLE-DAG: m8n8k4.mma.{{.*}}.f64

; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f32
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f32

; PTX65MMA-DAG: mma.m16n8k8.row.col.f16.f16
; PTX65MMA-DAG: mma.m16n8k8.row.col.f32.f32
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.s8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.s8
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.s4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4

; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16

; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k16.row.col.bf16
; PTX71MMA-DAG: mma.m16n8k8.row.col.bf16
; PTX71MMA-DAG: mma.m16n8k16.row.col.f16.f16
; PTX71MMA-DAG: mma.m16n8k16.row.col.f32.f32
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
; PTX71MMA-DAG: mma.and.popc.m8n8k128.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m8n8k128.row.col.b1
; PTX71MMA-DAG: mma.and.popc.m16n8k128.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m16n8k128.row.col.b1
; PTX71MMA-DAG: mma.and.popc.m16n8k256.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m16n8k256.row.col.b1

""")

  print("; INTRINSICS_LIST_BEGIN")
  for intrinsic, instruction in sorted(items):
    print("; ", intrinsic, " -> ", instruction,"")
  print("; INTRINSICS_LIST_END")
  print("; INTRINSICS: ; INTRINSICS_LIST_END")

def gen_tests():
  items = gen_wmma_load_tests()
  items += gen_wmma_store_tests()
  items += gen_ldmatrix_tests()
  items += gen_wmma_mma_tests()
  items += gen_mma_tests()
  gen_check_unsupported_ops(items)

parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
args = parser.parse_args()
ptx_version = args.ptx
gpu_arch = args.gpu_arch

gen_tests()