[NFC] Add libcxx python reformat SHA to .git-blame-ignore-revs
[llvm-project.git] / mlir / tools / mlir-tblgen / BytecodeDialectGen.cpp
blobf13bdd49413b09483dc2baaa8aa1e4256180db19
1 //===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Support/IndentedOstream.h"
10 #include "mlir/TableGen/GenInfo.h"
11 #include "llvm/ADT/MapVector.h"
12 #include "llvm/ADT/STLExtras.h"
13 #include "llvm/Support/CommandLine.h"
14 #include "llvm/Support/FormatVariadic.h"
15 #include "llvm/TableGen/Error.h"
16 #include "llvm/TableGen/Record.h"
17 #include <regex>
19 using namespace llvm;
21 static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
22 static llvm::cl::opt<std::string>
23 selectedBcDialect("bytecode-dialect",
24 llvm::cl::desc("The dialect to gen for"),
25 llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
27 namespace {
29 /// Helper class to generate C++ bytecode parser helpers.
30 class Generator {
31 public:
32 Generator(raw_ostream &output) : output(output) {}
34 /// Returns whether successfully emitted attribute/type parsers.
35 void emitParse(StringRef kind, Record &x);
37 /// Returns whether successfully emitted attribute/type printers.
38 void emitPrint(StringRef kind, StringRef type,
39 ArrayRef<std::pair<int64_t, Record *>> vec);
41 /// Emits parse dispatch table.
42 void emitParseDispatch(StringRef kind, ArrayRef<Record *> vec);
44 /// Emits print dispatch table.
45 void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);
47 private:
48 /// Emits parse calls to construct given kind.
49 void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
50 ArrayRef<Init *> args, ArrayRef<std::string> argNames,
51 StringRef failure, mlir::raw_indented_ostream &ios);
53 /// Emits print instructions.
54 void emitPrintHelper(Record *memberRec, StringRef kind, StringRef parent,
55 StringRef name, mlir::raw_indented_ostream &ios);
57 raw_ostream &output;
59 } // namespace
61 /// Helper to replace set of from strings to target in `s`.
62 /// Assumed: non-overlapping replacements.
63 static std::string format(StringRef templ,
64 std::map<std::string, std::string> &&map) {
65 std::string s = templ.str();
66 for (const auto &[from, to] : map)
67 // All replacements start with $, don't treat as anchor.
68 s = std::regex_replace(s, std::regex("\\" + from), to);
69 return s;
72 /// Return string with first character capitalized.
73 static std::string capitalize(StringRef str) {
74 return ((Twine)toUpper(str[0]) + str.drop_front()).str();
77 /// Return the C++ type for the given record.
78 static std::string getCType(Record *def) {
79 std::string format = "{0}";
80 if (def->isSubClassOf("Array")) {
81 def = def->getValueAsDef("elemT");
82 format = "SmallVector<{0}>";
85 StringRef cType = def->getValueAsString("cType");
86 if (cType.empty()) {
87 if (def->isAnonymous())
88 PrintFatalError(def->getLoc(), "Unable to determine cType");
90 return formatv(format.c_str(), def->getName().str());
92 return formatv(format.c_str(), cType.str());
95 void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) {
96 mlir::raw_indented_ostream os(output);
97 char const *head =
98 R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
99 os << formatv(head, capitalize(kind));
100 auto funScope = os.scope(" {\n", "}\n\n");
102 os << "uint64_t kind;\n";
103 os << "if (failed(reader.readVarInt(kind)))\n"
104 << " return " << capitalize(kind) << "();\n";
105 os << "switch (kind) ";
107 auto switchScope = os.scope("{\n", "}\n");
108 for (const auto &it : llvm::enumerate(vec)) {
109 os << formatv("case {1}:\n return read{0}(context, reader);\n",
110 it.value()->getName(), it.index());
112 os << "default:\n"
113 << " reader.emitError() << \"unknown attribute code: \" "
114 << "<< kind;\n"
115 << " return " << capitalize(kind) << "();\n";
117 os << "return " << capitalize(kind) << "();\n";
120 void Generator::emitParse(StringRef kind, Record &x) {
121 char const *head =
122 R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
123 mlir::raw_indented_ostream os(output);
124 std::string returnType = getCType(&x);
125 os << formatv(head, returnType, x.getName());
126 DagInit *members = x.getValueAsDag("members");
127 SmallVector<std::string> argNames =
128 llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) {
129 return init->getAsUnquotedString();
130 }));
131 StringRef builder = x.getValueAsString("cBuilder");
132 emitParseHelper(kind, returnType, builder, members->getArgs(), argNames,
133 returnType + "()", os);
134 os << "\n\n";
137 void printParseConditional(mlir::raw_indented_ostream &ios,
138 ArrayRef<Init *> args,
139 ArrayRef<std::string> argNames) {
140 ios << "if ";
141 auto parenScope = ios.scope("(", ") {");
142 ios.indent();
144 auto listHelperName = [](StringRef name) {
145 return formatv("read{0}", capitalize(name));
148 auto parsedArgs =
149 llvm::to_vector(make_filter_range(args, [](Init *const attr) {
150 Record *def = cast<DefInit>(attr)->getDef();
151 if (def->isSubClassOf("Array"))
152 return true;
153 return !def->getValueAsString("cParser").empty();
154 }));
156 interleave(
157 zip(parsedArgs, argNames),
158 [&](std::tuple<llvm::Init *&, const std::string &> it) {
159 Record *attr = cast<DefInit>(std::get<0>(it))->getDef();
160 std::string parser;
161 if (auto optParser = attr->getValueAsOptionalString("cParser")) {
162 parser = *optParser;
163 } else if (attr->isSubClassOf("Array")) {
164 Record *def = attr->getValueAsDef("elemT");
165 bool composite = def->isSubClassOf("CompositeBytecode");
166 if (!composite && def->isSubClassOf("AttributeKind"))
167 parser = "succeeded($_reader.readAttributes($_var))";
168 else if (!composite && def->isSubClassOf("TypeKind"))
169 parser = "succeeded($_reader.readTypes($_var))";
170 else
171 parser = ("succeeded($_reader.readList($_var, " +
172 listHelperName(std::get<1>(it)) + "))")
173 .str();
174 } else {
175 PrintFatalError(attr->getLoc(), "No parser specified");
177 std::string type = getCType(attr);
178 ios << format(parser, {{"$_reader", "reader"},
179 {"$_resultType", type},
180 {"$_var", std::get<1>(it)}});
182 [&]() { ios << " &&\n"; });
185 void Generator::emitParseHelper(StringRef kind, StringRef returnType,
186 StringRef builder, ArrayRef<Init *> args,
187 ArrayRef<std::string> argNames,
188 StringRef failure,
189 mlir::raw_indented_ostream &ios) {
190 auto funScope = ios.scope("{\n", "}");
192 if (args.empty()) {
193 ios << formatv("return get<{0}>(context);\n", returnType);
194 return;
197 // Print decls.
198 std::string lastCType = "";
199 for (auto [arg, name] : zip(args, argNames)) {
200 DefInit *first = dyn_cast<DefInit>(arg);
201 if (!first)
202 PrintFatalError("Unexpected type for " + name);
203 Record *def = first->getDef();
205 // Create variable decls, if there are a block of same type then create
206 // comma separated list of them.
207 std::string cType = getCType(def);
208 if (lastCType == cType) {
209 ios << ", ";
210 } else {
211 if (!lastCType.empty())
212 ios << ";\n";
213 ios << cType << " ";
215 ios << name;
216 lastCType = cType;
218 ios << ";\n";
220 // Returns the name of the helper used in list parsing. E.g., the name of the
221 // lambda passed to array parsing.
222 auto listHelperName = [](StringRef name) {
223 return formatv("read{0}", capitalize(name));
226 // Emit list helper functions.
227 for (auto [arg, name] : zip(args, argNames)) {
228 Record *attr = cast<DefInit>(arg)->getDef();
229 if (!attr->isSubClassOf("Array"))
230 continue;
232 // TODO: Dedupe readers.
233 Record *def = attr->getValueAsDef("elemT");
234 if (!def->isSubClassOf("CompositeBytecode") &&
235 (def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind")))
236 continue;
238 std::string returnType = getCType(def);
239 ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
240 << returnType << "> ";
241 SmallVector<Init *> args;
242 SmallVector<std::string> argNames;
243 if (def->isSubClassOf("CompositeBytecode")) {
244 DagInit *members = def->getValueAsDag("members");
245 args = llvm::to_vector(members->getArgs());
246 argNames = llvm::to_vector(
247 map_range(members->getArgNames(), [](StringInit *init) {
248 return init->getAsUnquotedString();
249 }));
250 } else {
251 args = {def->getDefInit()};
252 argNames = {"temp"};
254 StringRef builder = def->getValueAsString("cBuilder");
255 emitParseHelper(kind, returnType, builder, args, argNames, "failure()",
256 ios);
257 ios << ";\n";
260 // Print parse conditional.
261 printParseConditional(ios, args, argNames);
263 // Compute args to pass to create method.
264 auto passedArgs = llvm::to_vector(make_filter_range(
265 argNames, [](StringRef str) { return !str.starts_with("_"); }));
266 std::string argStr;
267 raw_string_ostream argStream(argStr);
268 interleaveComma(passedArgs, argStream,
269 [&](const std::string &str) { argStream << str; });
270 // Return the invoked constructor.
271 ios << "\nreturn "
272 << format(builder, {{"$_resultType", returnType.str()},
273 {"$_args", argStream.str()}})
274 << ";\n";
275 ios.unindent();
277 // TODO: Emit error in debug.
278 // This assumes the result types in error case can always be empty
279 // constructed.
280 ios << "}\nreturn " << failure << ";\n";
283 void Generator::emitPrint(StringRef kind, StringRef type,
284 ArrayRef<std::pair<int64_t, Record *>> vec) {
285 char const *head =
286 R"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
287 mlir::raw_indented_ostream os(output);
288 os << formatv(head, type, kind);
289 auto funScope = os.scope("{\n", "}\n\n");
291 // Check that predicates specified if multiple bytecode instances.
292 for (llvm::Record *rec : make_second_range(vec)) {
293 StringRef pred = rec->getValueAsString("printerPredicate");
294 if (vec.size() > 1 && pred.empty()) {
295 for (auto [index, rec] : vec) {
296 (void)index;
297 StringRef pred = rec->getValueAsString("printerPredicate");
298 if (vec.size() > 1 && pred.empty())
299 PrintError(rec->getLoc(),
300 "Requires parsing predicate given common cType");
302 PrintFatalError("Unspecified for shared cType " + type);
306 for (auto [index, rec] : vec) {
307 StringRef pred = rec->getValueAsString("printerPredicate");
308 if (!pred.empty()) {
309 os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n";
310 os.indent();
313 os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index
314 << ");\n";
316 auto *members = rec->getValueAsDag("members");
317 for (auto [arg, name] :
318 llvm::zip(members->getArgs(), members->getArgNames())) {
319 DefInit *def = dyn_cast<DefInit>(arg);
320 assert(def);
321 Record *memberRec = def->getDef();
322 emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os);
325 if (!pred.empty()) {
326 os.unindent();
327 os << "}\n";
332 void Generator::emitPrintHelper(Record *memberRec, StringRef kind,
333 StringRef parent, StringRef name,
334 mlir::raw_indented_ostream &ios) {
335 std::string getter;
336 if (auto cGetter = memberRec->getValueAsOptionalString("cGetter");
337 cGetter && !cGetter->empty()) {
338 getter = format(
339 *cGetter,
340 {{"$_attrType", parent.str()},
341 {"$_member", name.str()},
342 {"$_getMember", "get" + convertToCamelFromSnakeCase(name, true)}});
343 } else {
344 getter =
345 formatv("{0}.get{1}()", parent, convertToCamelFromSnakeCase(name, true))
346 .str();
349 if (memberRec->isSubClassOf("Array")) {
350 Record *def = memberRec->getValueAsDef("elemT");
351 if (!def->isSubClassOf("CompositeBytecode")) {
352 if (def->isSubClassOf("AttributeKind")) {
353 ios << "writer.writeAttributes(" << getter << ");\n";
354 return;
356 if (def->isSubClassOf("TypeKind")) {
357 ios << "writer.writeTypes(" << getter << ");\n";
358 return;
361 std::string returnType = getCType(def);
362 ios << "writer.writeList(" << getter << ", [&](" << returnType << " "
363 << kind << ") ";
364 auto lambdaScope = ios.scope("{\n", "});\n");
365 return emitPrintHelper(def, kind, kind, kind, ios);
367 if (memberRec->isSubClassOf("CompositeBytecode")) {
368 auto *members = memberRec->getValueAsDag("members");
369 for (auto [arg, argName] :
370 zip(members->getArgs(), members->getArgNames())) {
371 DefInit *def = dyn_cast<DefInit>(arg);
372 assert(def);
373 emitPrintHelper(def->getDef(), kind, parent,
374 argName->getAsUnquotedString(), ios);
378 if (std::string printer = memberRec->getValueAsString("cPrinter").str();
379 !printer.empty())
380 ios << format(printer, {{"$_writer", "writer"},
381 {"$_name", kind.str()},
382 {"$_getter", getter}})
383 << ";\n";
386 void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
387 mlir::raw_indented_ostream os(output);
388 char const *head = R"(static LogicalResult write{0}({0} {1},
389 DialectBytecodeWriter &writer))";
390 os << formatv(head, capitalize(kind), kind);
391 auto funScope = os.scope(" {\n", "}\n\n");
393 os << "return TypeSwitch<" << capitalize(kind) << ", LogicalResult>(" << kind
394 << ")";
395 auto switchScope = os.scope("", "");
396 for (StringRef type : vec) {
397 os << "\n.Case([&](" << type << " t)";
398 auto caseScope = os.scope(" {\n", "})");
399 os << "return write(t, writer), success();\n";
401 os << "\n.Default([&](" << capitalize(kind) << ") { return failure(); });\n";
404 namespace {
405 /// Container of Attribute or Type for Dialect.
406 struct AttrOrType {
407 std::vector<Record *> attr, type;
409 } // namespace
411 static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
412 MapVector<StringRef, AttrOrType> dialectAttrOrType;
413 for (auto &it : records.getAllDerivedDefinitions("DialectAttributes")) {
414 if (!selectedBcDialect.empty() &&
415 it->getValueAsString("dialect") != selectedBcDialect)
416 continue;
417 dialectAttrOrType[it->getValueAsString("dialect")].attr =
418 it->getValueAsListOfDefs("elems");
420 for (auto &it : records.getAllDerivedDefinitions("DialectTypes")) {
421 if (!selectedBcDialect.empty() &&
422 it->getValueAsString("dialect") != selectedBcDialect)
423 continue;
424 dialectAttrOrType[it->getValueAsString("dialect")].type =
425 it->getValueAsListOfDefs("elems");
428 if (dialectAttrOrType.size() != 1)
429 PrintFatalError("Single dialect per invocation required (either only "
430 "one in input file or specified via dialect option)");
432 auto it = dialectAttrOrType.front();
433 Generator gen(os);
435 SmallVector<std::vector<Record *> *, 2> vecs;
436 SmallVector<std::string, 2> kinds;
437 vecs.push_back(&it.second.attr);
438 kinds.push_back("attribute");
439 vecs.push_back(&it.second.type);
440 kinds.push_back("type");
441 for (auto [vec, kind] : zip(vecs, kinds)) {
442 // Handle Attribute/Type emission.
443 std::map<std::string, std::vector<std::pair<int64_t, Record *>>> perType;
444 for (auto kt : llvm::enumerate(*vec))
445 perType[getCType(kt.value())].emplace_back(kt.index(), kt.value());
446 for (const auto &jt : perType) {
447 for (auto kt : jt.second)
448 gen.emitParse(kind, *std::get<1>(kt));
449 gen.emitPrint(kind, jt.first, jt.second);
451 gen.emitParseDispatch(kind, *vec);
453 SmallVector<std::string> types;
454 for (const auto &it : perType) {
455 types.push_back(it.first);
457 gen.emitPrintDispatch(kind, types);
460 return false;
463 static mlir::GenRegistration
464 genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
465 [](const RecordKeeper &records, raw_ostream &os) {
466 return emitBCRW(records, os);