[RISCV] Fix mgather -> riscv.masked.strided.load combine not extending indices (...
[llvm-project.git] / llvm / lib / BinaryFormat / AMDGPUMetadataVerifier.cpp
blob33eed07c46292f3e6e027b7de2d2b9e1d28c151f
1 //===- AMDGPUMetadataVerifier.cpp - MsgPack Types ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Implements a verifier for AMDGPU HSA metadata.
12 //===----------------------------------------------------------------------===//
14 #include "llvm/BinaryFormat/AMDGPUMetadataVerifier.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/StringSwitch.h"
18 #include "llvm/BinaryFormat/MsgPackDocument.h"
20 #include <utility>
22 namespace llvm {
23 namespace AMDGPU {
24 namespace HSAMD {
25 namespace V3 {
27 bool MetadataVerifier::verifyScalar(
28 msgpack::DocNode &Node, msgpack::Type SKind,
29 function_ref<bool(msgpack::DocNode &)> verifyValue) {
30 if (!Node.isScalar())
31 return false;
32 if (Node.getKind() != SKind) {
33 if (Strict)
34 return false;
35 // If we are not strict, we interpret string values as "implicitly typed"
36 // and attempt to coerce them to the expected type here.
37 if (Node.getKind() != msgpack::Type::String)
38 return false;
39 StringRef StringValue = Node.getString();
40 Node.fromString(StringValue);
41 if (Node.getKind() != SKind)
42 return false;
44 if (verifyValue)
45 return verifyValue(Node);
46 return true;
49 bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
50 if (!verifyScalar(Node, msgpack::Type::UInt))
51 if (!verifyScalar(Node, msgpack::Type::Int))
52 return false;
53 return true;
56 bool MetadataVerifier::verifyArray(
57 msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
58 std::optional<size_t> Size) {
59 if (!Node.isArray())
60 return false;
61 auto &Array = Node.getArray();
62 if (Size && Array.size() != *Size)
63 return false;
64 return llvm::all_of(Array, verifyNode);
67 bool MetadataVerifier::verifyEntry(
68 msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
69 function_ref<bool(msgpack::DocNode &)> verifyNode) {
70 auto Entry = MapNode.find(Key);
71 if (Entry == MapNode.end())
72 return !Required;
73 return verifyNode(Entry->second);
76 bool MetadataVerifier::verifyScalarEntry(
77 msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
78 msgpack::Type SKind,
79 function_ref<bool(msgpack::DocNode &)> verifyValue) {
80 return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
81 return verifyScalar(Node, SKind, verifyValue);
82 });
85 bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
86 StringRef Key, bool Required) {
87 return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
88 return verifyInteger(Node);
89 });
92 bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
93 if (!Node.isMap())
94 return false;
95 auto &ArgsMap = Node.getMap();
97 if (!verifyScalarEntry(ArgsMap, ".name", false,
98 msgpack::Type::String))
99 return false;
100 if (!verifyScalarEntry(ArgsMap, ".type_name", false,
101 msgpack::Type::String))
102 return false;
103 if (!verifyIntegerEntry(ArgsMap, ".size", true))
104 return false;
105 if (!verifyIntegerEntry(ArgsMap, ".offset", true))
106 return false;
107 if (!verifyScalarEntry(ArgsMap, ".value_kind", true, msgpack::Type::String,
108 [](msgpack::DocNode &SNode) {
109 return StringSwitch<bool>(SNode.getString())
110 .Case("by_value", true)
111 .Case("global_buffer", true)
112 .Case("dynamic_shared_pointer", true)
113 .Case("sampler", true)
114 .Case("image", true)
115 .Case("pipe", true)
116 .Case("queue", true)
117 .Case("hidden_block_count_x", true)
118 .Case("hidden_block_count_y", true)
119 .Case("hidden_block_count_z", true)
120 .Case("hidden_group_size_x", true)
121 .Case("hidden_group_size_y", true)
122 .Case("hidden_group_size_z", true)
123 .Case("hidden_remainder_x", true)
124 .Case("hidden_remainder_y", true)
125 .Case("hidden_remainder_z", true)
126 .Case("hidden_global_offset_x", true)
127 .Case("hidden_global_offset_y", true)
128 .Case("hidden_global_offset_z", true)
129 .Case("hidden_grid_dims", true)
130 .Case("hidden_none", true)
131 .Case("hidden_printf_buffer", true)
132 .Case("hidden_hostcall_buffer", true)
133 .Case("hidden_heap_v1", true)
134 .Case("hidden_default_queue", true)
135 .Case("hidden_completion_action", true)
136 .Case("hidden_multigrid_sync_arg", true)
137 .Case("hidden_dynamic_lds_size", true)
138 .Case("hidden_private_base", true)
139 .Case("hidden_shared_base", true)
140 .Case("hidden_queue_ptr", true)
141 .Default(false);
143 return false;
144 if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
145 return false;
146 if (!verifyScalarEntry(ArgsMap, ".address_space", false,
147 msgpack::Type::String,
148 [](msgpack::DocNode &SNode) {
149 return StringSwitch<bool>(SNode.getString())
150 .Case("private", true)
151 .Case("global", true)
152 .Case("constant", true)
153 .Case("local", true)
154 .Case("generic", true)
155 .Case("region", true)
156 .Default(false);
158 return false;
159 if (!verifyScalarEntry(ArgsMap, ".access", false,
160 msgpack::Type::String,
161 [](msgpack::DocNode &SNode) {
162 return StringSwitch<bool>(SNode.getString())
163 .Case("read_only", true)
164 .Case("write_only", true)
165 .Case("read_write", true)
166 .Default(false);
168 return false;
169 if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
170 msgpack::Type::String,
171 [](msgpack::DocNode &SNode) {
172 return StringSwitch<bool>(SNode.getString())
173 .Case("read_only", true)
174 .Case("write_only", true)
175 .Case("read_write", true)
176 .Default(false);
178 return false;
179 if (!verifyScalarEntry(ArgsMap, ".is_const", false,
180 msgpack::Type::Boolean))
181 return false;
182 if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
183 msgpack::Type::Boolean))
184 return false;
185 if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
186 msgpack::Type::Boolean))
187 return false;
188 if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
189 msgpack::Type::Boolean))
190 return false;
192 return true;
195 bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
196 if (!Node.isMap())
197 return false;
198 auto &KernelMap = Node.getMap();
200 if (!verifyScalarEntry(KernelMap, ".name", true,
201 msgpack::Type::String))
202 return false;
203 if (!verifyScalarEntry(KernelMap, ".symbol", true,
204 msgpack::Type::String))
205 return false;
206 if (!verifyScalarEntry(KernelMap, ".language", false,
207 msgpack::Type::String,
208 [](msgpack::DocNode &SNode) {
209 return StringSwitch<bool>(SNode.getString())
210 .Case("OpenCL C", true)
211 .Case("OpenCL C++", true)
212 .Case("HCC", true)
213 .Case("HIP", true)
214 .Case("OpenMP", true)
215 .Case("Assembler", true)
216 .Default(false);
218 return false;
219 if (!verifyEntry(
220 KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
221 return verifyArray(
222 Node,
223 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
225 return false;
226 if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
227 return verifyArray(Node, [this](msgpack::DocNode &Node) {
228 return verifyKernelArgs(Node);
231 return false;
232 if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
233 [this](msgpack::DocNode &Node) {
234 return verifyArray(Node,
235 [this](msgpack::DocNode &Node) {
236 return verifyInteger(Node);
240 return false;
241 if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
242 [this](msgpack::DocNode &Node) {
243 return verifyArray(Node,
244 [this](msgpack::DocNode &Node) {
245 return verifyInteger(Node);
249 return false;
250 if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
251 msgpack::Type::String))
252 return false;
253 if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
254 msgpack::Type::String))
255 return false;
256 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
257 return false;
258 if (!verifyIntegerEntry(KernelMap, ".group_segment_fixed_size", true))
259 return false;
260 if (!verifyIntegerEntry(KernelMap, ".private_segment_fixed_size", true))
261 return false;
262 if (!verifyScalarEntry(KernelMap, ".uses_dynamic_stack", false,
263 msgpack::Type::Boolean))
264 return false;
265 if (!verifyIntegerEntry(KernelMap, ".workgroup_processor_mode", false))
266 return false;
267 if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_align", true))
268 return false;
269 if (!verifyIntegerEntry(KernelMap, ".wavefront_size", true))
270 return false;
271 if (!verifyIntegerEntry(KernelMap, ".sgpr_count", true))
272 return false;
273 if (!verifyIntegerEntry(KernelMap, ".vgpr_count", true))
274 return false;
275 if (!verifyIntegerEntry(KernelMap, ".max_flat_workgroup_size", true))
276 return false;
277 if (!verifyIntegerEntry(KernelMap, ".sgpr_spill_count", false))
278 return false;
279 if (!verifyIntegerEntry(KernelMap, ".vgpr_spill_count", false))
280 return false;
281 if (!verifyIntegerEntry(KernelMap, ".uniform_work_group_size", false))
282 return false;
285 return true;
288 bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
289 if (!HSAMetadataRoot.isMap())
290 return false;
291 auto &RootMap = HSAMetadataRoot.getMap();
293 if (!verifyEntry(
294 RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
295 return verifyArray(
296 Node,
297 [this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
299 return false;
300 if (!verifyEntry(
301 RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
302 return verifyArray(Node, [this](msgpack::DocNode &Node) {
303 return verifyScalar(Node, msgpack::Type::String);
306 return false;
307 if (!verifyEntry(RootMap, "amdhsa.kernels", true,
308 [this](msgpack::DocNode &Node) {
309 return verifyArray(Node, [this](msgpack::DocNode &Node) {
310 return verifyKernel(Node);
313 return false;
315 return true;
318 } // end namespace V3
319 } // end namespace HSAMD
320 } // end namespace AMDGPU
321 } // end namespace llvm