1 # This test generates all variants of wmma intrinsics and verifies that LLVM
2 # generates correct instructions for them.
4 # RUN: python %s > %t.ll
5 # RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 | FileCheck %t.ll
7 from __future__
import print_function
9 from itertools
import product
10 from string
import Template
def make_wmma_slice_ty(abcd, itype):
    """Return the LLVM IR element types that make up one WMMA fragment.

    Args:
      abcd: fragment name as used by the generators ("a", "b", "c" or "d").
      itype: element type name; "f16" selects packed "<2 x half>"
        elements, anything else selects "float".

    Returns:
      A list of LLVM type strings: 4 elements for f16 "c"/"d" fragments,
      8 elements for every other combination.
    """
    elt_ty = "<2 x half>" if itype == "f16" else "float"
    # Exact membership, not a substring test on "cd": that would also
    # accept "" and "cd" and silently give them the 4-element shape.
    num_elts = 4 if abcd in ("c", "d") and itype == "f16" else 8
    return [elt_ty] * num_elts
def make_wmma_ld_ret_ty(abc, itype):
    """Return an LLVM literal struct type listing a fragment's element types.

    The members come from make_wmma_slice_ty(abc, itype) and are rendered
    as "{T0, T1, ...}".
    """
    members = make_wmma_slice_ty(abc, itype)
    return "{" + ", ".join(members) + "}"
20 # returns address space
# Maps a space suffix as passed by the test generators ("", ".shared",
# ".global") to the integer substituted into "addrspace(%d)" and, via
# get_pspace, into the "p%di8" pointer suffix of intrinsic names.
# NOTE(review): `space_map` is defined on lines not visible in this view;
# presumably it is a dict, so an unknown `space` raises KeyError -- confirm.
21 def get_aspace(space
):
31 return space_map
[space
];
def get_pspace(space):
    """Return the pointer-type suffix used in intrinsic names.

    The result has the form "p<N>i8", where N is get_aspace(space).
    """
    aspace = get_aspace(space)
    return "p%di8" % aspace
def _make_reg_list_check(reg, count):
    """Return a FileCheck regex matching `count` registers named like `reg`.

    The registers are comma-separated, with optional spaces after each
    comma; the whole pattern is wrapped in FileCheck's {{...}} regex syntax.
    """
    return "{{%s}}" % ", *".join([reg] * count)


# Convenient test patterns.
check_f16_8 = _make_reg_list_check("%hh[0-9]+", 8)
check_f16_4 = _make_reg_list_check("%hh[0-9]+", 4)
check_f32_8 = _make_reg_list_check("%f[0-9]+", 8)

known_geoms = ["m16n16k16", "m8n32k16", "m32n8k16"]
# Print LLVM IR tests for the llvm.nvvm.wmma.*.load.* intrinsics to stdout.
# For every combination yielded by itertools.product over (geom, abc,
# layout, space, stride, itype) -- space ranging over "", ".shared" and
# ".global" -- this fills `load_template` via string.Template with:
#   * the intrinsic name built from `intrinsic_template`, and a test
#     function name derived from it by replacing "." with "_";
#   * the PTX instruction FileCheck must find (`instruction_template`);
#   * the returned struct type (make_wmma_ld_ret_ty(abc, itype));
#   * a register-list pattern for the loaded values: check_f16_4 /
#     check_f32_8 in one branch and check_f16_8 in the other;
#   * either ", i32 %stride" plus a stride-register check, or nothing.
# Each printed variant declares the intrinsic and defines two tests: one
# loading from %src, and a "_o" variant loading from %src + 128 whose
# check expects the address operand to appear as [%rdN+128].
# NOTE(review): several lines of this function are not visible here (the
# template delimiters, the rest of the product() arguments, the params
# dict, and the guards of the check_result / stride if-else branches --
# presumably `abc == "c"` and `stride`). The body of
# `if itype == "f32" and abc != "c":` is presumably `continue` -- confirm.
43 def gen_wmma_load_tests():
45 declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
47 ; CHECK-LABEL: .func {{.*}}test_${function}(
48 define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
49 ; CHECK: ${instruction}
50 ; CHECK: {${check_result}}
51 ; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
52 %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
56 ; CHECK-LABEL: .func{{.*}}test_${function}_o(
57 define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
58 ; CHECK: ${instruction}
59 ; CHECK: {${check_result}}
60 ; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
61 %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
62 %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
66 intrinsic_template
= "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
67 instruction_template
= "wmma.load.${abc}.sync.${layout}.${geom}${space}.${itype}"
69 for geom
, abc
, layout
, space
, stride
, itype
in product(
73 ["",".shared",".global"],
83 "pspace" : get_pspace(space
),
84 "as" : "addrspace(%d)" % get_aspace(space
),
88 if itype
== "f32" and abc
!= "c":
92 test_params
["intrinsic"] = Template(intrinsic_template
).substitute(params
)
93 test_params
["function"] = test_params
["intrinsic"].replace(".","_")
94 test_params
["instruction"] = Template(instruction_template
).substitute(params
)
95 test_params
["ret_ty"] = make_wmma_ld_ret_ty(abc
, itype
)
97 test_params
["check_result"] = check_f16_4
if itype
== "f16" else check_f32_8
99 test_params
["check_result"] = check_f16_8
102 test_params
["extra_args"] = ", i32 %stride";
103 test_params
["stride_pattern"] = ", %r{{[0-9]+}}"
105 test_params
["extra_args"] = ""
106 test_params
["stride_pattern"] = ""
108 print(Template(load_template
).substitute(test_params
))
def make_wmma_slice_args(itype, abcd, prefix="v"):
    """Return an LLVM IR argument list covering one WMMA fragment.

    Emits one "<type> %<prefix><index>" entry per element type reported by
    make_wmma_slice_ty(abcd, itype), joined with ", " -- for example
    "float %v0, float %v1, ..." with the default prefix.
    """
    slice_types = make_wmma_slice_ty(abcd, itype)
    return ", ".join("%s %%%s%d" % (elt_ty, prefix, idx)
                     for idx, elt_ty in enumerate(slice_types))
# Print LLVM IR tests for the llvm.nvvm.wmma.*.store.* intrinsics to stdout.
# For every combination yielded by itertools.product over (geom, abc,
# layout, space, stride, itype) -- space ranging over "", ".shared" and
# ".global" -- this fills `store_template` via string.Template with the
# intrinsic name, a test function name derived from it ("." -> "_"), the
# expected PTX instruction, the "d" fragment argument list
# (make_wmma_slice_args(itype, "d")), its register-list pattern
# (check_f16_4 for f16, check_f32_8 otherwise) and either an
# ", i32 %stride" argument with a stride check or none.
# Each printed variant declares the intrinsic and defines a test storing
# to %src plus a "_o" variant storing to %src + 128.
# NOTE(review): several lines are not visible here (template delimiters,
# the rest of the product() arguments, the params dict, the stride if-else
# guards). test_params["ret_ty"] is set, but none of the visible template
# lines reference ${ret_ty}.
# NOTE(review): the address checks use `%rd{{[0-9+]}}`, i.e. ONE character
# from the class [0-9+], while the load tests use `%rd{{[0-9]+}}`. In
# `[%rd{{[0-9+]}}+128]` this cannot match a multi-digit register such as
# %rd12, so it is likely a typo for `[0-9]+`. The "_o" test also checks
# `${check_args}` without the braces the first test wraps around it.
# Confirm both before relying on these checks.
114 def gen_wmma_store_tests():
116 declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});
118 ; CHECK-LABEL: .func {{.*}}test_${function}(
119 define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
120 ; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}
121 ; CHECK: {${check_args}}
122 ; CHECK: ${stride_pattern}
123 call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
127 ; CHECK-LABEL: .func{{.*}}test_${function}_o(
128 define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
129 ; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
130 ; CHECK: ${check_args}
131 ; CHECK: ${stride_pattern}
132 %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
133 call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
137 intrinsic_template
= "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
138 instruction_template
= "wmma.store.${abc}.sync.${layout}.${geom}${space}.${itype}"
140 for geom
, abc
, layout
, space
, stride
, itype
in product(
144 ["",".shared",".global"],
154 "pspace" : get_pspace(space
),
155 "as" : "addrspace(%d)" % get_aspace(space
),
160 test_params
["intrinsic"] = Template(intrinsic_template
).substitute(params
)
161 test_params
["function"] = test_params
["intrinsic"].replace(".","_")
162 test_params
["instruction"] = Template(instruction_template
).substitute(params
)
163 test_params
["ret_ty"] = make_wmma_ld_ret_ty(abc
, itype
)
164 test_params
["check_args"] = check_f16_4
if itype
== "f16" else check_f32_8
166 test_params
["extra_args"] = ", i32 %stride";
167 test_params
["stride_pattern"] = ", %r{{[0-9]+}};"
169 test_params
["extra_args"] = ""
170 test_params
["stride_pattern"] = ";"
171 test_params
["args"] = make_wmma_slice_args(itype
, "d");
173 print(Template(store_template
).substitute(test_params
))
# Print LLVM IR tests for the llvm.nvvm.wmma.*.mma.* intrinsics to stdout.
# For every combination yielded by itertools.product over (geom, alayout,
# blayout, ctype, dtype, satf) this fills `mma_template` via
# string.Template with:
#   * the intrinsic name (`intrinsic_template`) and a test function name
#     derived from it ("." -> "_");
#   * the expected PTX instruction (`instruction_template`);
#   * the returned struct type of the d fragment
#     (make_wmma_ld_ret_ty("d", dtype));
#   * the argument list.
# The checks expect the instruction to be followed, on consecutive lines
# (CHECK-NEXT), by the d register list, the a and b register lists (both
# check_f16_8) and the c register list. d and c each use check_f16_4 for
# f16 and check_f32_8 otherwise.
# Arguments are built per fragment by make_wmma_slice_args with the
# fragment name as the register prefix (e.g. %a0, %a1, ...), and the
# per-fragment lists are joined with a comma and newline.
# NOTE(review): parts of this function are not visible here: the template
# delimiters and argument lines, the product() arguments, the params dict,
# and the (fragment, type) pairs after ("a", "f16").
175 def gen_wmma_mma_tests():
177 declare ${ret_ty} @${intrinsic}(
180 ; CHECK-LABEL: .func {{.*}}test_${function}(
181 define ${ret_ty} @test_${function}(
183 ; CHECK: ${instruction}
184 ; CHECK-NEXT: ${check_d}
185 ; CHECK-NEXT: ${check_ab}
186 ; CHECK-NEXT: ${check_ab}
187 ; CHECK-NEXT: ${check_c}
188 %r = call ${ret_ty} @${intrinsic}(
193 intrinsic_template
= "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}"
194 instruction_template
= "wmma.mma.sync.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}"
196 for geom
, alayout
, blayout
, ctype
, dtype
, satf
in product(
214 test_params
["intrinsic"] = Template(intrinsic_template
).substitute(params
)
215 test_params
["function"] = test_params
["intrinsic"].replace(".", "_")
216 test_params
["instruction"] = Template(instruction_template
).substitute(params
)
217 test_params
["ret_ty"] = make_wmma_ld_ret_ty("d", dtype
)
218 test_params
["check_ab"] = check_f16_8
219 test_params
["check_c"] = check_f16_4
if ctype
== "f16" else check_f32_8
220 test_params
["check_d"] = check_f16_4
if dtype
== "f16" else check_f32_8
221 args
= ",\n ".join(make_wmma_slice_args(t
, abcd
, prefix
=abcd
)
222 for abcd
, t
in (("a", "f16"),
225 test_params
["args"] = args
226 print(Template(mma_template
).substitute(test_params
))
# Script entry: the RUN lines pipe this script's stdout into llc and then
# FileCheck, so each generator prints its tests in turn (loads, then
# stores).
# NOTE(review): gen_wmma_mma_tests is defined in this file but is not
# called on these lines; confirm it is invoked after them.
for _emit_tests in (gen_wmma_load_tests, gen_wmma_store_tests):
    _emit_tests()