Revert r354244 "[DAGCombiner] Eliminate dead stores to stack."
[llvm-complete.git] / test / CodeGen / NVPTX / wmma.py
blob14bbfd7df094ab038b430f2e7612284bfee391b7
1 # This test generates all variants of wmma intrinsics and verifies that LLVM
2 # generates correct instructions for them.
4 # RUN: python %s > %t.ll
5 # RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 | FileCheck %t.ll
7 from __future__ import print_function
9 from itertools import product
10 from string import Template
def make_wmma_slice_ty(abcd, itype):
    """Return the list of LLVM IR element types that make up one fragment.

    abcd  -- fragment name ("a", "b", "c" or "d").
    itype -- element type name ("f16" or "f32").

    f16 fragments are built from "<2 x half>" elements, everything else from
    "float" elements.  An f16 "c"/"d" fragment has 4 elements; every other
    combination has 8.
    """
    is_f16 = itype == "f16"
    # Only the f16 c/d fragments are the short (4-element) form.
    is_short = abcd in "cd" and is_f16
    elt_ty = "<2 x half>" if is_f16 else "float"
    return [elt_ty] * (4 if is_short else 8)
def make_wmma_ld_ret_ty(abc, itype):
    """Return the IR literal struct type holding every element of a fragment.

    The result looks like "{float, float, ...}" and is built from
    make_wmma_slice_ty(abc, itype).  It is used as the return type of the
    generated load and mma intrinsic declarations.
    """
    field_types = make_wmma_slice_ty(abc, itype)
    return "{" + ", ".join(field_types) + "}"
def get_aspace(space):
    """Return the numeric LLVM address space for a PTX state-space suffix.

    space -- the suffix as it appears in instruction names, e.g. ".shared".
             Both "" and ".generic" map to address space 0.

    Raises KeyError for a suffix not listed below.
    """
    aspace_by_suffix = dict([
        (".global", 1),
        (".shared", 3),
        (".const", 4),
        (".local", 5),
        (".param", 101),
        ("", 0),
        (".generic", 0),
    ])
    return aspace_by_suffix[space]
def get_pspace(space):
    """Return the intrinsic-name pointer suffix for `space`, e.g. "p3i8".

    The number is the address space returned by get_aspace(space).  Unknown
    suffixes propagate get_aspace's KeyError.
    """
    aspace = get_aspace(space)
    return "p%di8" % aspace
def _check_reg_list(reg_regex, count):
    """Return a FileCheck {{regex}} matching `count` comma-separated regs."""
    return "{{" + ", *".join([reg_regex] * count) + "}}"


# Convenient FileCheck patterns for the register lists printed in PTX:
# eight or four %hh registers (f16 fragments) or eight %f registers
# (f32 fragments).  The surrounding braces are not part of the pattern.
check_f16_8 = _check_reg_list("%hh[0-9]+", 8)
check_f16_4 = _check_reg_list("%hh[0-9]+", 4)
check_f32_8 = _check_reg_list("%f[0-9]+", 8)

# Fragment geometries (PTX "mMnNkK" naming) for which tests are generated.
known_geoms = ["m16n16k16", "m8n32k16", "m32n8k16"]
def gen_wmma_load_tests():
    """Print IR test functions for every wmma.load intrinsic variant.

    Covers every combination of geometry, fragment ("a", "b", "c"), layout,
    state space, optional stride and element type.  f32 variants are only
    generated for the "c" fragment.  Each variant gets two functions: one
    loading from %src and one loading from %src + 128 bytes.  Both carry
    FileCheck directives for the expected PTX instruction, result registers
    and address operand.
    """
    ir_template = Template("""
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
  ret ${ret_ty} %v0;
}
""")
    intrinsic_tmpl = Template(
        "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}")
    instruction_tmpl = Template(
        "wmma.load.${abc}.sync.${layout}.${geom}${space}.${itype}")

    variants = product(known_geoms,
                       "abc",
                       ["row", "col"],
                       ["", ".shared", ".global"],
                       ["", ".stride"],
                       ["f16", "f32"])
    for geom, abc, layout, space, stride, itype in variants:
        # The "a" and "b" fragments are only generated as f16.
        if abc != "c" and itype == "f32":
            continue

        subst = {
            "abc": abc,
            "layout": layout,
            "space": space,
            "stride": stride,
            "itype": itype,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": geom,
        }
        intrinsic = intrinsic_tmpl.substitute(subst)
        subst["intrinsic"] = intrinsic
        subst["function"] = intrinsic.replace(".", "_")
        subst["instruction"] = instruction_tmpl.substitute(subst)
        subst["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)

        if abc != "c":
            subst["check_result"] = check_f16_8
        elif itype == "f16":
            subst["check_result"] = check_f16_4
        else:
            subst["check_result"] = check_f32_8

        # Strided variants take an extra i32 operand that PTX prints after
        # the address.
        if stride:
            subst["extra_args"] = ", i32 %stride"
            subst["stride_pattern"] = ", %r{{[0-9]+}}"
        else:
            subst["extra_args"] = ""
            subst["stride_pattern"] = ""

        print(ir_template.substitute(subst))
def make_wmma_slice_args(itype, abcd, prefix="v"):
    """Return the typed IR argument list for one fragment.

    Produces "<ty> %<prefix>0, <ty> %<prefix>1, ...", with one argument per
    element of make_wmma_slice_ty(abcd, itype).
    """
    pieces = []
    for idx, elt_ty in enumerate(make_wmma_slice_ty(abcd, itype)):
        pieces.append("%s %%%s%d" % (elt_ty, prefix, idx))
    return ", ".join(pieces)
def gen_wmma_store_tests():
    """Print IR test functions for every wmma.store intrinsic variant.

    Covers every combination of geometry, layout, state space, optional
    stride and element type for the "d" fragment.  Each variant gets two
    functions: one storing to %src and one storing to %src + 128 bytes.
    Both carry FileCheck directives for the expected PTX instruction,
    address operand, data registers and optional stride operand.
    """
    # Register numbers are matched with {{[0-9]+}}.  A one-character class
    # such as [0-9+] would reject multi-digit registers in the "+128]" check.
    # Both variants check the data registers inside literal braces.
    store_template = """
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}]
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  call void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});
  ret void
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}+128]
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
  ret void
}
"""
    intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
    instruction_template = "wmma.store.${abc}.sync.${layout}.${geom}${space}.${itype}"

    for geom, abc, layout, space, stride, itype in product(
            known_geoms,
            "d",
            ["row", "col"],
            ["", ".shared", ".global"],
            ["", ".stride"],
            ["f16", "f32"]):

        test_params = {
            "abc": abc,
            "layout": layout,
            "space": space,
            "stride": stride,
            "itype": itype,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": geom,
        }
        test_params["intrinsic"] = Template(intrinsic_template).substitute(test_params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(test_params)
        test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8
        # The stride operand, when present, is printed after the data
        # registers and right before the terminating ';'.
        if stride:
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}};"
        else:
            test_params["extra_args"] = ""
            test_params["stride_pattern"] = ";"
        test_params["args"] = make_wmma_slice_args(itype, abc)

        print(Template(store_template).substitute(test_params))
def gen_wmma_mma_tests():
    """Print IR test functions for every wmma.mma intrinsic variant.

    Covers every combination of geometry, A/B layouts, C and D element
    types, and presence of ".satfinite".  Each function calls the intrinsic
    with the a, b (always f16) and c fragments.  The FileCheck directives
    expect the instruction followed, on consecutive lines, by the d, a, b
    and c register lists.
    """
    mma_tmpl = Template("""
declare ${ret_ty} @${intrinsic}(
        ${args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
        ${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
; CHECK-NEXT: ${check_ab}
; CHECK-NEXT: ${check_ab}
; CHECK-NEXT: ${check_c}
  %r = call ${ret_ty} @${intrinsic}(
        ${args});
  ret ${ret_ty} %r;
}
""")
    intrinsic_tmpl = Template(
        "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}")
    instruction_tmpl = Template(
        "wmma.mma.sync.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}")

    layouts = ["row", "col"]
    elt_types = ["f16", "f32"]
    for geom, alayout, blayout, ctype, dtype, satf in product(
            known_geoms, layouts, layouts, elt_types, elt_types,
            [".satfinite", ""]):
        subst = {
            "alayout": alayout,
            "blayout": blayout,
            "ctype": ctype,
            "dtype": dtype,
            "satf": satf,
            "geom": geom,
        }
        subst["intrinsic"] = intrinsic_tmpl.substitute(subst)
        subst["function"] = subst["intrinsic"].replace(".", "_")
        subst["instruction"] = instruction_tmpl.substitute(subst)
        subst["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)
        subst["check_ab"] = check_f16_8
        subst["check_c"] = check_f16_4 if ctype == "f16" else check_f32_8
        subst["check_d"] = check_f16_4 if dtype == "f16" else check_f32_8

        # Arguments are named %a0.., %b0.., %c0.., one list per fragment.
        fragment_args = []
        for frag, frag_ty in (("a", "f16"), ("b", "f16"), ("c", ctype)):
            fragment_args.append(
                make_wmma_slice_args(frag_ty, frag, prefix=frag))
        subst["args"] = ",\n        ".join(fragment_args)

        print(mma_tmpl.substitute(subst))
def main():
    """Print the load, store and mma tests to stdout, in that order."""
    for generate in (gen_wmma_load_tests,
                     gen_wmma_store_tests,
                     gen_wmma_mma_tests):
        generate()
# Generate the tests only when run as a script (the RUN line does
# `python %s`), so importing this module has no side effects.
if __name__ == "__main__":
    main()