# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them.

# Check all variants of instructions supported by PTX60 on SM70
# RUN: %python %s --ptx=60 --gpu-arch=70 > %t-ptx60-sm_70.ll
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
# RUN:           | FileCheck %t-ptx60-sm_70.ll
# RUN: %if ptxas %{ \
# RUN:   llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
# RUN:   | %ptxas-verify -arch=sm_70 \
# RUN: %}

# Check all variants of instructions supported by PTX61 on SM70
# RUN: %python %s --ptx=61 --gpu-arch=70 > %t-ptx61-sm_70.ll
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
# RUN:           | FileCheck %t-ptx61-sm_70.ll
# RUN: %if ptxas-9.1 %{ \
# RUN:   llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
# RUN:   | %ptxas-verify -arch=sm_70 \
# RUN: %}

# Check all variants of instructions supported by PTX63 on SM72
# RUN: %python %s --ptx=63 --gpu-arch=72 > %t-ptx63-sm_72.ll
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN:           --check-prefixes=INTRINSICS,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
# RUN:           | FileCheck %t-ptx63-sm_72.ll
# RUN: %if ptxas-10.0 %{ \
# RUN:   llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
# RUN:   | %ptxas-verify -arch=sm_72 \
# RUN: %}

# Check all variants of instructions supported by PTX63 on SM75
# RUN: %python %s --ptx=63 --gpu-arch=75 > %t-ptx63-sm_75.ll
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
# RUN:           | FileCheck %t-ptx63-sm_75.ll
# RUN: %if ptxas-10.0 %{ \
# RUN:   llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
# RUN:   | %ptxas-verify -arch=sm_75 \
# RUN: %}

# Check all variants of instructions supported by PTX64 on SM70+
# RUN: %python %s --ptx=64 --gpu-arch=70 > %t-ptx64-sm_70.ll
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,MMA
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx64-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 \
# RUN:           | FileCheck %t-ptx64-sm_70.ll
# RUN: %if ptxas-10.1 %{ \
# RUN:   llc < %t-ptx64-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 \
# RUN:   | %ptxas-verify -arch=sm_70 \
# RUN: %}

# Check all variants of instructions supported by PTX65 on SM75+
# RUN: %python %s --ptx=65 --gpu-arch=75 > %t-ptx65-sm_75.ll
# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,PTX65MMA,PTX65LDMATRIX
# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
# RUN:           --check-prefixes=INTRINSICS
# RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
# RUN:           | FileCheck %t-ptx65-sm_75.ll
# RUN: %if ptxas-10.2 %{ \
# RUN:   llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
# RUN:   | %ptxas-verify -arch=sm_75 \
# RUN: %}

# Check all variants of instructions supported by PTX71 on SM80+
# RUN: %python %s --ptx=71 --gpu-arch=80 > %t-ptx71-sm_80.ll
# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX65LDMATRIX,PTX71MMA
# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
# RUN:           --check-prefixes=INTRINSICS
# RUN: llc < %t-ptx71-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 \
# RUN:           | FileCheck %t-ptx71-sm_80.ll
# RUN: %if ptxas-11.1 %{ \
# RUN:   llc < %t-ptx71-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 \
# RUN:   | %ptxas-verify -arch=sm_80 \
# RUN: %}
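
# A hypothetical manual invocation mirroring the RUN lines above (the file
# and output names here are illustrative only, not part of the test):
#
#   python wmma.py --ptx=71 --gpu-arch=80 > wmma-ptx71-sm_80.ll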

from __future__ import print_function

import argparse
from itertools import product
from string import Template


class MMAType:
    def __init__(self, ptx_type):
        self.ptx_type = ptx_type
        self.llvm_type = {
            "f16": "<2 x half>",
            "f32": "float",
            "f64": "double",
            "s32": "i32",
            "b16": "i32",
            "s8": "i32",
            "u8": "i32",
            "s4": "i32",
            "u4": "i32",
            "b1": "i32",
            "bf16": "i32",
            "tf32": "i32",
        }[ptx_type]

        self.ptx_reg_pattern = {
            "f16": "%r[0-9]+",
            "f32": "%f[0-9]+",
            "f64": "%fd[0-9]+",
        }.get(ptx_type, "%r[0-9]+")

    def __repr__(self):
        return "%s/%s" % (self.ptx_type, self.llvm_type)
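
# For illustration (values follow directly from the tables above):
#   MMAType("f16") -> f16/<2 x half>  (a pair of halves per 32-bit register)
#   MMAType("s8")  -> s8/i32          (sub-word types are packed into i32)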


class MMAFrag:
    def __init__(self, geom, frag, ptx_elt_type):
        self.geom = geom
        self.frag = frag
        self.mma_type = MMAType(ptx_elt_type)
        self.nregs = {
            # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
            "m16n16k16:a:u8": 2,
            "m16n16k16:a:s8": 2,
            "m16n16k16:b:u8": 2,
            "m16n16k16:b:s8": 2,
            "m16n16k16:c:s32": 8,
            "m16n16k16:d:s32": 8,
            "m8n32k16:a:u8": 1,
            "m8n32k16:a:s8": 1,
            "m8n32k16:b:u8": 4,
            "m8n32k16:b:s8": 4,
            "m8n32k16:c:s32": 8,
            "m8n32k16:d:s32": 8,
            "m32n8k16:a:u8": 4,
            "m32n8k16:a:s8": 4,
            "m32n8k16:b:u8": 1,
            "m32n8k16:b:s8": 1,
            "m32n8k16:c:s32": 8,
            "m32n8k16:d:s32": 8,
            "m8n8k16:a:u8": 1,
            "m8n8k16:a:s8": 1,
            "m8n8k16:b:u8": 1,
            "m8n8k16:b:s8": 1,
            "m8n8k16:c:s32": 2,
            "m8n8k16:d:s32": 2,
            "m16n8k16:a:u8": 2,
            "m16n8k16:a:s8": 2,
            "m16n8k16:b:u8": 1,
            "m16n8k16:b:s8": 1,
            "m16n8k16:c:s32": 4,
            "m16n8k16:d:s32": 4,
            "m16n8k32:a:u8": 4,
            "m16n8k32:a:s8": 4,
            "m16n8k32:b:u8": 2,
            "m16n8k32:b:s8": 2,
            "m16n8k32:c:s32": 4,
            "m16n8k32:d:s32": 4,
            # u4/s4 -> s32 @ m8n8k32 (u4/s4)
            "m8n8k32:a:u4": 1,
            "m8n8k32:a:s4": 1,
            "m8n8k32:b:u4": 1,
            "m8n8k32:b:s4": 1,
            "m8n8k32:c:s32": 2,
            "m8n8k32:d:s32": 2,
            "m16n8k32:a:u4": 2,
            "m16n8k32:a:s4": 2,
            "m16n8k32:b:u4": 1,
            "m16n8k32:b:s4": 1,
            "m16n8k32:c:s32": 4,
            "m16n8k32:d:s32": 4,
            "m16n8k64:a:u4": 4,
            "m16n8k64:a:s4": 4,
            "m16n8k64:b:u4": 2,
            "m16n8k64:b:s4": 2,
            "m16n8k64:c:s32": 4,
            "m16n8k64:d:s32": 4,
            # b1 -> s32 @ m8n8k128(b1)
            "m8n8k128:a:b1": 1,
            "m8n8k128:b:b1": 1,
            "m8n8k128:c:s32": 2,
            "m8n8k128:d:s32": 2,
            "m16n8k128:a:b1": 2,
            "m16n8k128:b:b1": 1,
            "m16n8k128:c:s32": 4,
            "m16n8k128:d:s32": 4,
            "m16n8k256:a:b1": 4,
            "m16n8k256:b:b1": 2,
            "m16n8k256:c:s32": 4,
            "m16n8k256:d:s32": 4,
            # bf16 -> f32 @ m16n16k16/m8n32k16/m32n8k16
            "m16n16k16:a:bf16": 4,
            "m16n16k16:b:bf16": 4,
            "m8n32k16:a:bf16": 2,
            "m8n32k16:b:bf16": 8,
            "m32n8k16:a:bf16": 8,
            "m32n8k16:b:bf16": 2,
            "m16n8k16:a:bf16": 4,
            "m16n8k16:b:bf16": 2,
            "m16n8k16:c:f32": 4,
            "m16n8k16:d:f32": 4,
            "m16n8k8:a:bf16": 2,
            "m16n8k8:b:bf16": 1,
            "m16n8k8:c:f32": 4,
            "m16n8k8:d:f32": 4,
            "m8n8k4:a:f64": 1,
            "m8n8k4:b:f64": 1,
            "m8n8k4:c:f64": 2,
            "m8n8k4:d:f64": 2,
            # tf32 -> f32 @ m16n16k8
            "m16n16k8:a:tf32": 4,
            "m16n16k8:b:tf32": 4,
            "m16n8k4:a:tf32": 2,
            "m16n8k4:b:tf32": 1,
            "m16n8k4:c:f32": 4,
            "m16n8k4:d:f32": 4,
            "m16n8k8:a:tf32": 4,
            "m16n8k8:b:tf32": 2,
            "m16n8k8:c:f32": 4,
            "m16n8k8:d:f32": 4,
            "m8n8k4:a:f16": 2,
            "m8n8k4:b:f16": 2,
            "m16n8k8:a:f16": 2,
            "m16n8k8:b:f16": 1,
            "m16n8k8:c:f16": 2,
            "m16n8k8:d:f16": 2,
            "m16n8k8:c:f32": 4,
            "m16n8k8:d:f32": 4,
            "m16n8k16:a:f16": 4,
            "m16n8k16:b:f16": 2,
            "m16n8k16:c:f16": 2,
            "m16n8k16:d:f16": 2,
            "m16n8k16:c:f32": 4,
            "m16n8k16:d:f32": 4,
            # ldmatrix
            "m8n8:x1:b16": 1,
            "m8n8:x2:b16": 2,
            "m8n8:x4:b16": 4,
        }.get(
            "%s:%s:%s" % (geom, frag, ptx_elt_type),
            {
                # All other FP shape/fragment/type combinations have the same size
                "a:f16": 8,
                "b:f16": 8,
                "c:f16": 4,
                "d:f16": 4,
                "c:f32": 8,
                "d:f32": 8,
            }.get("%s:%s" % (frag, ptx_elt_type), None),
        )
        assert self.nregs

    def __repr__(self):
        return "%s:%s:%s%s" % (
            self.geom,
            self.frag,
            self.mma_type,
            "" if self.nregs == 1 else ("*%d" % self.nregs),
        )


class MMAOp:
    def __init__(self, a, b, c, d):
        self.a = a
        self.b = b
        self.c = c
        self.d = d

    def __repr__(self):
        return "{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d)


def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
    ops = []
    for geom, type_a, type_c in product(geoms, types_a, types_c):
        for type_b, type_d in product(
            types_b if types_b else [type_a], types_d if types_d else [type_c]
        ):
            ops.append(
                MMAOp(
                    MMAFrag(geom, "a", type_a),
                    MMAFrag(geom, "b", type_b),
                    MMAFrag(geom, "c", type_c),
                    MMAFrag(geom, "d", type_d),
                )
            )
    return ops
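
# For illustration: empty types_b/types_d fall back to types_a/types_c, so
#   make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
# yields a single op,
#   {A:m8n8k4:a:f64/double, B:m8n8k4:b:f64/double,
#    C:m8n8k4:c:f64/double*2, D:m8n8k4:d:f64/double*2}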


def make_ldst_ops(geoms, frags, types):
    return [
        MMAFrag(geom, frag, ptx_type)
        for (geom, frag, ptx_type) in product(geoms, frags, types)
    ]


def make_ldmatrix_ops(geoms, frags, types):
    return [
        MMAFrag(geom, frag, ptx_type)
        for (geom, frag, ptx_type) in product(geoms, frags, types)
    ]


def get_wmma_ops():
    return (
        make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
        + make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], ["bf16"], [], ["f32"], [])
        + make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
        + make_mma_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"],
            ["f16"],
            [],
            ["f16", "f32"],
            ["f16", "f32"],
        )
        + make_mma_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"], ["s8", "u8"], [], ["s32"], []
        )
        + make_mma_ops(["m8n8k32"], ["s4", "u4"], [], ["s32"], [])
        + make_mma_ops(["m8n8k128"], ["b1"], [], ["s32"], [])
    )


def get_mma_ops():
    return (
        make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
        + make_mma_ops(["m16n8k4", "m16n8k8"], ["tf32"], [], ["f32"], [])
        + make_mma_ops(["m16n8k16", "m16n8k8"], ["bf16"], [], ["f32"], [])
        + make_mma_ops(
            ["m8n8k4", "m16n8k8", "m16n8k16"],
            ["f16"],
            [],
            ["f16", "f32"],
            ["f16", "f32"],
        )
        + make_mma_ops(
            ["m8n8k16", "m16n8k16", "m16n8k32"], ["s8", "u8"], ["s8", "u8"], ["s32"], []
        )
        + make_mma_ops(
            ["m8n8k32", "m16n8k32", "m16n8k64"], ["s4", "u4"], ["s4", "u4"], ["s32"], []
        )
        + make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"], ["b1"], [], ["s32"], [])
    )


def get_ldst_ops(kind):
    ldst_ops = (
        make_ldst_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"],
            ["a", "b"],
            ["f16", "u8", "s8", "bf16"],
        )
        + make_ldst_ops(
            ["m16n16k16", "m32n8k16", "m8n32k16"], ["c", "d"], ["f16", "f32", "s32"]
        )
        + make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4", "u4"])
        + make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"])
        + make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"])
        + make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"])
        + make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"])
        + make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"])
    )
    return [x for x in ldst_ops if (x.frag == "d") == (kind == "store")]


def get_ldmatrix_ops():
    return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])


def is_wmma_geom_supported(geom):
    # geometries for FP and ints.
    if geom in ["m8n32k16", "m32n8k16"]:
        return ptx_version >= 61
    # geometries for sub-ints.
    if geom in ["m8n8k32", "m8n8k128"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if geom == "m16n16k16":
        return ptx_version >= 60
    if geom == "m16n8k8":
        return ptx_version >= 65
    if geom in ["m16n16k8", "m8n8k4"]:
        return ptx_version >= 70
    assert False  # Unexpected geometry.


def is_mma_geom_supported(geom):
    # geometries for FP and ints.
    if geom == "m8n8k4":
        return ptx_version >= 64
    if geom in ["m16n8k8", "m8n8k16", "m8n8k32"]:
        return ptx_version >= 65
    if geom in [
        "m16n8k16",
        "m16n8k4",
        "m16n8k32",
        "m16n8k64",
        "m8n8k128",
        "m16n8k128",
        "m16n8k256",
    ]:
        return ptx_version >= 70
    assert False  # Unexpected geometry.


def is_ldmatrix_geom_supported(geom):
    if geom in ["m8n8"]:
        return ptx_version >= 65 and gpu_arch >= 75
    assert False  # Unexpected geometry.


def is_type_supported(ptx_type):
    if ptx_type in ["s8", "u8", "s32"]:
        return ptx_version >= 63 and gpu_arch >= 72
    if ptx_type in ["s4", "u4", "b1"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if ptx_type == "b16":
        return ptx_version >= 65 and gpu_arch >= 75
    if ptx_type in ["bf16", "tf32", "f64"]:
        return ptx_version >= 70
    return ptx_version >= 60 and gpu_arch >= 70
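
# For illustration (follows from the gating above): with --ptx=63 --gpu-arch=72,
# is_type_supported("s8") is True, but is_type_supported("s4") is False because
# sub-integer types additionally require sm_75.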


def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
    if not (
        is_type_supported(op.a.mma_type.ptx_type) and is_wmma_geom_supported(op.a.geom)
    ):
        return False

    # rnd is only supported for FP64 WMMA
    if rnd and op.a.mma_type.ptx_type != "f64":
        return False

    if satf:
        # satfinite for floating points was removed in PTX 6.5
        if op.a.mma_type.ptx_type == "f16" and ptx_version >= 65:
            return False
        if op.a.mma_type.ptx_type not in ["f16", "s8", "u8", "s4", "u4"]:
            return False

    # sub-integer ops require row/col layout.
    if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
        return layout_a == "row" and layout_b == "col"
    return True


def is_mma_variant_supported(op, layout_a, layout_b, satf):
    if not (
        is_type_supported(op.a.mma_type.ptx_type) and is_mma_geom_supported(op.a.geom)
    ):
        return False

    if satf and op.a.mma_type.ptx_type not in ["s8", "u8", "s4", "u4"]:
        return False

    # If the type of C is f32, then so must be the type of D.
    if (
        op.a.geom == "m8n8k4"
        and op.c.mma_type.ptx_type == "f32"
        and op.d.mma_type.ptx_type != "f32"
    ):
        return False

    # A and B must have the same type; so must C and D.
    if op.a.geom == "m16n8k8" and (
        op.a.mma_type.ptx_type != op.b.mma_type.ptx_type
        or op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
    ):
        return False

    # C and D must have the same type.
    if op.a.geom == "m16n8k16" and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type:
        return False

    # Require row/col layout for all MMA variants except m8n8k4 on FP16.
    if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
        return layout_a == "row" and layout_b == "col"
    return True


def is_ldst_variant_supported(frag, layout):
    if not (
        is_type_supported(frag.mma_type.ptx_type) and is_wmma_geom_supported(frag.geom)
    ):
        return False
    if frag.mma_type.ptx_type in ["s4", "u4", "b1"]:
        # Sub-integer types require sm_75/PTX 6.3 and a row/col layout for a/b.
        return (
            (frag.frag == "a" and layout == "row")
            or (frag.frag == "b" and layout == "col")
            or frag.frag in ["c", "d"]
        )
    return True


def is_ldmatrix_variant_supported(frag):
    if not (
        is_type_supported(frag.mma_type.ptx_type)
        and is_ldmatrix_geom_supported(frag.geom)
    ):
        return False
    return frag.frag in ["x1", "x2", "x4"]


def make_wmma_slice_ty(frag):
    return [frag.mma_type.llvm_type] * frag.nregs


def make_wmma_ld_ret_ty(frag):
    results = make_wmma_slice_ty(frag)
    if len(results) == 1:
        return "%s" % results[0]
    return "{%s}" % ", ".join(results)
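
# For illustration:
#   make_wmma_ld_ret_ty(MMAFrag("m8n8k4", "c", "f64")) -> "{double, double}"
# A single-register fragment returns just the scalar type, e.g. "i32".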


# Returns the LLVM address-space number for a PTX state space.
def get_aspace(space):
    space_map = {
        ".global": 1,
        ".shared": 3,
        ".const": 4,
        ".local": 5,
        ".param": 101,
        "": 0,
        ".generic": 0,
    }
    return space_map[space]


def get_pspace(space):
    return "p%di8" % get_aspace(space)
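
# For illustration: get_pspace("") -> "p0i8" (generic), get_pspace(".shared")
# -> "p3i8"; the suffix selects the pointer-type variant of the intrinsic name.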


def check_pattern(frag):
    return "{{%s}}" % ", *".join([frag.mma_type.ptx_reg_pattern] * frag.nregs)
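
# For illustration, a two-register f64 fragment produces the FileCheck pattern
#   {{%fd[0-9]+, *%fd[0-9]+}}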


def gen_wmma_load_tests():
    load_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
  ret ${ret_ty} %v0;
}
"""
    intrinsic_template = (
        "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
    )
    instruction_template = (
        "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
    )

    generated_items = []

    for frag, layout, space, stride in product(
        get_ldst_ops("load"),
        ["row", "col"],
        ["", ".shared", ".global"],
        ["", ".stride"],
    ):
        if not is_ldst_variant_supported(frag, layout):
            continue

        params = {
            "abc": frag.frag,
            "aligned": ".aligned" if ptx_version >= 63 else "",
            "layout": layout,
            "space": space,
            "stride": stride,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }

        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_result"] = check_pattern(frag)

        if stride:
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}}"
        else:
            test_params["extra_args"] = ""
            test_params["stride_pattern"] = ""

        print(Template(load_template).substitute(test_params))

        generated_items.append((test_params["intrinsic"], test_params["instruction"]))

    return generated_items
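
# For illustration, one (intrinsic, instruction) pair this loop generates on a
# PTX 6.3+ run (where .aligned is emitted) is
#   ("llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p0i8",
#    "wmma.load.a.sync.aligned.row.m16n16k16.f16")
# for the generic address space with a stride argument.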


def make_wmma_slice_args(frag):
    return ", ".join(
        [
            "%s %%%s%d" % (t, frag.frag, i)
            for i, t in enumerate(make_wmma_slice_ty(frag))
        ]
    )
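
# For illustration, a two-register f64 "c" fragment yields the argument list
#   "double %c0, double %c1"
# (the doubled %% escapes the literal % of an LLVM IR value name).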


def gen_wmma_store_tests():
    store_template = """
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
  ret void
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
  ret void
}
"""
    intrinsic_template = (
        "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
    )
    instruction_template = (
        "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
    )

    generated_items = []

    for frag, layout, space, stride in product(
        get_ldst_ops("store"),
        ["row", "col"],
        ["", ".shared", ".global"],
        ["", ".stride"],
    ):
        if not is_ldst_variant_supported(frag, layout):
            continue

        params = {
            "abc": frag.frag,
            "aligned": ".aligned" if ptx_version >= 63 else "",
            "layout": layout,
            "space": space,
            "stride": stride,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }

        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_args"] = check_pattern(frag)
        if stride:
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}};"
        else:
            test_params["extra_args"] = ""
            test_params["stride_pattern"] = ";"
        test_params["args"] = make_wmma_slice_args(frag)

        print(Template(store_template).substitute(test_params))
        generated_items.append((test_params["intrinsic"], test_params["instruction"]))

    return generated_items


def gen_ldmatrix_tests():
    ldmatrix_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src);

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src);
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1);
  ret ${ret_ty} %v0;
}
"""
    intrinsic_template = (
        "llvm.nvvm.ldmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
    )
    instruction_template = (
        "ldmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
    )

    generated_items = []

    for frag, space, trans in product(
        get_ldmatrix_ops(),
        ["", ".shared"],
        ["", ".trans"],
    ):
        if not is_ldmatrix_variant_supported(frag):
            continue

        params = {
            "frag": frag.frag,
            "space": space,
            "trans": trans,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }

        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
        test_params["check_result"] = check_pattern(frag)

        print(Template(ldmatrix_template).substitute(test_params))

        generated_items.append((test_params["intrinsic"], test_params["instruction"]))

    return generated_items


def mma_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 ops are identified by accumulator & result type.
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    elif op.a.mma_type.ptx_type != op.b.mma_type.ptx_type:
        # other ops are identified by input types.
        return "%s.%s" % (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
    else:
        # if the input types are the same, the type appears only once.
        return op.a.mma_type.ptx_type


def mma_ptx_signature(op):
    # Encode all four types as D.A.B.C
    return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))


def wmma_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 ops are identified by accumulator & result type.
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    else:
        # other ops are identified by input type.
        return op.a.mma_type.ptx_type


def wmma_ptx_signature(op):
    if op.a.mma_type.ptx_type == "f16":
        # FP16 instructions use D.C
        return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
    else:
        # other instructions encode all four types as D.A.B.C
        return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
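
# For illustration, an s8 x u8 m16n8k16 MMA op yields
#   mma_signature(op)     -> "s8.u8"          (intrinsic name suffix)
#   mma_ptx_signature(op) -> "s32.s8.u8.s32"  (PTX instruction suffix, D.A.B.C)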


def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
    mma_template = """
declare ${ret_ty} @${intrinsic}(
        ${args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
        ${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
; CHECK-NEXT: ${check_a}
; CHECK-NEXT: ${check_b}
; CHECK-NEXT: ${check_c}
  %r = call ${ret_ty} @${intrinsic}(
        ${args});
  ret ${ret_ty} %r;
}
"""

    test_params = params
    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    test_params["function"] = test_params["intrinsic"].replace(".", "_")
    test_params["instruction"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
    test_params["check_a"] = check_pattern(op.a)
    test_params["check_b"] = check_pattern(op.b)
    test_params["check_c"] = check_pattern(op.c)
    test_params["check_d"] = check_pattern(op.d)
    args = ",\n        ".join(make_wmma_slice_args(frag) for frag in (op.a, op.b, op.c))
    test_params["args"] = args
    print(Template(mma_template).substitute(test_params))
    return (test_params["intrinsic"], test_params["instruction"])


def get_b1_ops(ptx_type):
    if ptx_type != "b1":
        return [""]
    if ptx_version >= 71:
        return [".xor.popc", ".and.popc"]
    return [".xor.popc"]
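
# For illustration: non-b1 types get a single empty suffix, while b1 ops expand
# to [".xor.popc", ".and.popc"] on PTX 7.1+; earlier versions only get .xor.popc.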


def gen_wmma_mma_tests():
    wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
    wmma_instruction_template = "wmma.mma${b1op}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"

    generated_items = []

    for op, alayout, blayout, rnd, satf in product(
        get_wmma_ops(),
        ["row", "col"],
        ["row", "col"],
        [".rn", ".rz", ".rm", ".rp", ""],
        [".satfinite", ""],
    ):
        if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
            continue

        for b1op in get_b1_ops(op.a.mma_type.ptx_type):
            params = {
                "aligned": ".aligned" if ptx_version >= 63 else "",
                "alayout": alayout,
                "blayout": blayout,
                "intrinsic_signature": wmma_signature(op),
                "ptx_signature": wmma_ptx_signature(op),
                "satf": satf,
                "rnd": rnd,
                "geom": op.a.geom,
                "b1op": b1op,
            }

            intrinsic_template = wmma_intrinsic_template
            instruction_template = wmma_instruction_template

            generated_items.append(
                common_mma_test_gen(
                    params, op, intrinsic_template, instruction_template
                )
            )

    return generated_items


def gen_mma_tests():
    mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
    mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"

    generated_items = []

    for op, alayout, blayout, satf in product(
        get_mma_ops(), ["row", "col"], ["row", "col"], [".satfinite", ""]
    ):
        if not is_mma_variant_supported(op, alayout, blayout, satf):
            continue

        for b1op in get_b1_ops(op.a.mma_type.ptx_type):
            params = {
                "aligned": ".aligned" if ptx_version >= 63 else "",
                "alayout": alayout,
                "blayout": blayout,
                "intrinsic_signature": mma_signature(op),
                "ptx_signature": mma_ptx_signature(op),
                "satf": satf,
                "geom": op.a.geom,
                "b1op": b1op,
            }

            intrinsic_template = mma_intrinsic_template
            instruction_template = mma_instruction_template

            generated_items.append(
                common_mma_test_gen(
                    params, op, intrinsic_template, instruction_template
                )
            )

    return generated_items


# Append the complete list of intrinsics and instructions we've generated tests
# for, and generate a set of checks to verify that we generated a sensible set
# of tests for the given combination of PTX and SM variants.
def gen_check_unsupported_ops(items):
    print(
        "; Complete list of intrinsics supported by PTX%d on sm_%d"
        % (ptx_version, gpu_arch)
    )
    print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}")
    print(
        """

; NOEXTGEOM-NOT: {{m8n32|m32n8}}
; NOINT-NOT: .{{s32|s8}}
; NOSUBINT-NOT: {{s4|u4|b1}}
; NOMMA-NOT: .m8n8k4.
; NOALTFLOAT-NOT: .{{bf16|tf32}}
; NODOUBLE-NOT: .f64
; NOLDMATRIX-NOT: ldmatrix.sync.aligned

; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f32
; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f16
; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f16
; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f32

; PTX61 adds support for m32n8k16/m8n32k16 geometries.
; EXTGEOM-DAG: m32n8k16.load.{{[ab].*}}.f16.p
; EXTGEOM-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f32
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f16
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f16
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f32

; EXTGEOM-DAG: m8n32k16.load.{{[ab].*}}.f16.p
; EXTGEOM-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f32
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f16
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f16
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f32

; INT-DAG: m16n16k16.load.{{[ab].*}}.s8.p
; INT-DAG: m8n32k16.load.{{[ab].*}}.s8.p
; INT-DAG: m32n8k16.load.{{[ab].*}}.s8.p
; INT-DAG: m16n16k16.load.{{[ab].*}}.u8.p
; INT-DAG: m8n32k16.load.{{[ab].*}}.u8.p
; INT-DAG: m32n8k16.load.{{[ab].*}}.u8.p
; INT-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p
; INT-DAG: m16n16k16.mma.{{.*}}.u8
; INT-DAG: m16n16k16.mma.{{.*}}.s8
; INT-DAG: m8n32k16.mma.{{.*}}.u8
; INT-DAG: m8n32k16.mma.{{.*}}.s8
; INT-DAG: m32n8k16.mma.{{.*}}.u8
; INT-DAG: m32n8k16.mma.{{.*}}.s8

; SUBINT-DAG: m8n8k128.load.{{[ab].*}}.b1.p
; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.s4.p
; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.u4.p
; SUBINT-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p
; SUBINT-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p
; SUBINT-DAG: m8n8k32.mma.{{.*}}.u4
; SUBINT-DAG: m8n8k32.mma.{{.*}}.s4
; SUBINT-DAG: m8n8k128.mma.{{.*}}.b1

; ALTFLOAT-DAG: m16n16k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m8n32k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m32n8k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m16n16k8.load.{{[ab].*}}.tf32.p
; ALTFLOAT-DAG: m16n16k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m8n32k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m32n8k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m16n16k8.mma.{{.*}}.tf32

; DOUBLE-DAG: m8n8k4.load.{{[abc].*}}.f64.p
; DOUBLE-DAG: m8n8k4.store.d.{{.*}}.f64.p
; DOUBLE-DAG: m8n8k4.mma.{{.*}}.f64

; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f32
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f32

; PTX65MMA-DAG: mma.m16n8k8.row.col.f16.f16
; PTX65MMA-DAG: mma.m16n8k8.row.col.f32.f32
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.s8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.s8
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.s4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4

; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16

; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k16.row.col.bf16
; PTX71MMA-DAG: mma.m16n8k8.row.col.bf16
; PTX71MMA-DAG: mma.m16n8k16.row.col.f16.f16
; PTX71MMA-DAG: mma.m16n8k16.row.col.f32.f32
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
; PTX71MMA-DAG: mma.and.popc.m8n8k128.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m8n8k128.row.col.b1
; PTX71MMA-DAG: mma.and.popc.m16n8k128.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m16n8k128.row.col.b1
; PTX71MMA-DAG: mma.and.popc.m16n8k256.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m16n8k256.row.col.b1

"""
    )

    print("; INTRINSICS_LIST_BEGIN")
    for intrinsic, instruction in sorted(items):
        print("; ", intrinsic, " -> ", instruction, "")
    print("; INTRINSICS_LIST_END")
    print("; INTRINSICS: ; INTRINSICS_LIST_END")


def gen_tests():
    items = gen_wmma_load_tests()
    items += gen_wmma_store_tests()
    items += gen_ldmatrix_tests()
    items += gen_wmma_mma_tests()
    items += gen_mma_tests()
    gen_check_unsupported_ops(items)


parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
args = parser.parse_args()
ptx_version = args.ptx
gpu_arch = args.gpu_arch

gen_tests()