1 //===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include <map>
#include <string>
#include <utility>
#include <vector>

using namespace llvm;
21 static cl::OptionCategory
dialectGenCat("Options for -gen-bytecode");
22 static cl::opt
<std::string
>
23 selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"),
24 cl::cat(dialectGenCat
), cl::CommaSeparated
);
28 /// Helper class to generate C++ bytecode parser helpers.
31 Generator(raw_ostream
&output
) : output(output
) {}
33 /// Returns whether successfully emitted attribute/type parsers.
34 void emitParse(StringRef kind
, const Record
&x
);
36 /// Returns whether successfully emitted attribute/type printers.
37 void emitPrint(StringRef kind
, StringRef type
,
38 ArrayRef
<std::pair
<int64_t, const Record
*>> vec
);
40 /// Emits parse dispatch table.
41 void emitParseDispatch(StringRef kind
, ArrayRef
<const Record
*> vec
);
43 /// Emits print dispatch table.
44 void emitPrintDispatch(StringRef kind
, ArrayRef
<std::string
> vec
);
47 /// Emits parse calls to construct given kind.
48 void emitParseHelper(StringRef kind
, StringRef returnType
, StringRef builder
,
49 ArrayRef
<const Init
*> args
,
50 ArrayRef
<std::string
> argNames
, StringRef failure
,
51 mlir::raw_indented_ostream
&ios
);
53 /// Emits print instructions.
54 void emitPrintHelper(const Record
*memberRec
, StringRef kind
,
55 StringRef parent
, StringRef name
,
56 mlir::raw_indented_ostream
&ios
);
62 /// Helper to replace set of from strings to target in `s`.
63 /// Assumed: non-overlapping replacements.
64 static std::string
format(StringRef templ
,
65 std::map
<std::string
, std::string
> &&map
) {
66 std::string s
= templ
.str();
67 for (const auto &[from
, to
] : map
)
68 // All replacements start with $, don't treat as anchor.
69 s
= std::regex_replace(s
, std::regex("\\" + from
), to
);
73 /// Return string with first character capitalized.
74 static std::string
capitalize(StringRef str
) {
75 return ((Twine
)toUpper(str
[0]) + str
.drop_front()).str();
78 /// Return the C++ type for the given record.
79 static std::string
getCType(const Record
*def
) {
80 std::string format
= "{0}";
81 if (def
->isSubClassOf("Array")) {
82 def
= def
->getValueAsDef("elemT");
83 format
= "SmallVector<{0}>";
86 StringRef cType
= def
->getValueAsString("cType");
88 if (def
->isAnonymous())
89 PrintFatalError(def
->getLoc(), "Unable to determine cType");
91 return formatv(format
.c_str(), def
->getName().str());
93 return formatv(format
.c_str(), cType
.str());
96 void Generator::emitParseDispatch(StringRef kind
,
97 ArrayRef
<const Record
*> vec
) {
98 mlir::raw_indented_ostream
os(output
);
100 R
"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
101 os
<< formatv(head
, capitalize(kind
));
102 auto funScope
= os
.scope(" {\n", "}\n\n");
105 os
<< "return reader.emitError() << \"unknown attribute\", "
106 << capitalize(kind
) << "();\n";
110 os
<< "uint64_t kind;\n";
111 os
<< "if (failed(reader.readVarInt(kind)))\n"
112 << " return " << capitalize(kind
) << "();\n";
113 os
<< "switch (kind) ";
115 auto switchScope
= os
.scope("{\n", "}\n");
116 for (const auto &it
: llvm::enumerate(vec
)) {
117 if (it
.value()->getName() == "ReservedOrDead")
120 os
<< formatv("case {1}:\n return read{0}(context, reader);\n",
121 it
.value()->getName(), it
.index());
124 << " reader.emitError() << \"unknown attribute code: \" "
126 << " return " << capitalize(kind
) << "();\n";
128 os
<< "return " << capitalize(kind
) << "();\n";
131 void Generator::emitParse(StringRef kind
, const Record
&x
) {
132 if (x
.getNameInitAsString() == "ReservedOrDead")
136 R
"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
137 mlir::raw_indented_ostream
os(output
);
138 std::string returnType
= getCType(&x
);
140 kind
== "attribute" ? "::mlir::Attribute" : "::mlir::Type",
142 const DagInit
*members
= x
.getValueAsDag("members");
143 SmallVector
<std::string
> argNames
= llvm::to_vector(
144 map_range(members
->getArgNames(), [](const StringInit
*init
) {
145 return init
->getAsUnquotedString();
147 StringRef builder
= x
.getValueAsString("cBuilder").trim();
148 emitParseHelper(kind
, returnType
, builder
, members
->getArgs(), argNames
,
149 returnType
+ "()", os
);
153 void printParseConditional(mlir::raw_indented_ostream
&ios
,
154 ArrayRef
<const Init
*> args
,
155 ArrayRef
<std::string
> argNames
) {
157 auto parenScope
= ios
.scope("(", ") {");
160 auto listHelperName
= [](StringRef name
) {
161 return formatv("read{0}", capitalize(name
));
165 llvm::to_vector(make_filter_range(args
, [](const Init
*const attr
) {
166 const Record
*def
= cast
<DefInit
>(attr
)->getDef();
167 if (def
->isSubClassOf("Array"))
169 return !def
->getValueAsString("cParser").empty();
173 zip(parsedArgs
, argNames
),
174 [&](std::tuple
<const Init
*&, const std::string
&> it
) {
175 const Record
*attr
= cast
<DefInit
>(std::get
<0>(it
))->getDef();
177 if (auto optParser
= attr
->getValueAsOptionalString("cParser")) {
179 } else if (attr
->isSubClassOf("Array")) {
180 const Record
*def
= attr
->getValueAsDef("elemT");
181 bool composite
= def
->isSubClassOf("CompositeBytecode");
182 if (!composite
&& def
->isSubClassOf("AttributeKind"))
183 parser
= "succeeded($_reader.readAttributes($_var))";
184 else if (!composite
&& def
->isSubClassOf("TypeKind"))
185 parser
= "succeeded($_reader.readTypes($_var))";
187 parser
= ("succeeded($_reader.readList($_var, " +
188 listHelperName(std::get
<1>(it
)) + "))")
191 PrintFatalError(attr
->getLoc(), "No parser specified");
193 std::string type
= getCType(attr
);
194 ios
<< format(parser
, {{"$_reader", "reader"},
195 {"$_resultType", type
},
196 {"$_var", std::get
<1>(it
)}});
198 [&]() { ios
<< " &&\n"; });
201 void Generator::emitParseHelper(StringRef kind
, StringRef returnType
,
202 StringRef builder
, ArrayRef
<const Init
*> args
,
203 ArrayRef
<std::string
> argNames
,
205 mlir::raw_indented_ostream
&ios
) {
206 auto funScope
= ios
.scope("{\n", "}");
209 ios
<< formatv("return get<{0}>(context);\n", returnType
);
214 std::string lastCType
= "";
215 for (auto [arg
, name
] : zip(args
, argNames
)) {
216 const DefInit
*first
= dyn_cast
<DefInit
>(arg
);
218 PrintFatalError("Unexpected type for " + name
);
219 const Record
*def
= first
->getDef();
221 // Create variable decls, if there are a block of same type then create
222 // comma separated list of them.
223 std::string cType
= getCType(def
);
224 if (lastCType
== cType
) {
227 if (!lastCType
.empty())
236 // Returns the name of the helper used in list parsing. E.g., the name of the
237 // lambda passed to array parsing.
238 auto listHelperName
= [](StringRef name
) {
239 return formatv("read{0}", capitalize(name
));
242 // Emit list helper functions.
243 for (auto [arg
, name
] : zip(args
, argNames
)) {
244 const Record
*attr
= cast
<DefInit
>(arg
)->getDef();
245 if (!attr
->isSubClassOf("Array"))
248 // TODO: Dedupe readers.
249 const Record
*def
= attr
->getValueAsDef("elemT");
250 if (!def
->isSubClassOf("CompositeBytecode") &&
251 (def
->isSubClassOf("AttributeKind") || def
->isSubClassOf("TypeKind")))
254 std::string returnType
= getCType(def
);
255 ios
<< "auto " << listHelperName(name
) << " = [&]() -> FailureOr<"
256 << returnType
<< "> ";
257 SmallVector
<const Init
*> args
;
258 SmallVector
<std::string
> argNames
;
259 if (def
->isSubClassOf("CompositeBytecode")) {
260 const DagInit
*members
= def
->getValueAsDag("members");
261 args
= llvm::to_vector(members
->getArgs());
262 argNames
= llvm::to_vector(
263 map_range(members
->getArgNames(), [](const StringInit
*init
) {
264 return init
->getAsUnquotedString();
267 args
= {def
->getDefInit()};
270 StringRef builder
= def
->getValueAsString("cBuilder");
271 emitParseHelper(kind
, returnType
, builder
, args
, argNames
, "failure()",
276 // Print parse conditional.
277 printParseConditional(ios
, args
, argNames
);
279 // Compute args to pass to create method.
280 auto passedArgs
= llvm::to_vector(make_filter_range(
281 argNames
, [](StringRef str
) { return !str
.starts_with("_"); }));
283 raw_string_ostream
argStream(argStr
);
284 interleaveComma(passedArgs
, argStream
,
285 [&](const std::string
&str
) { argStream
<< str
; });
286 // Return the invoked constructor.
288 << format(builder
, {{"$_resultType", returnType
.str()},
289 {"$_args", argStream
.str()}})
293 // TODO: Emit error in debug.
294 // This assumes the result types in error case can always be empty
296 ios
<< "}\nreturn " << failure
<< ";\n";
299 void Generator::emitPrint(StringRef kind
, StringRef type
,
300 ArrayRef
<std::pair
<int64_t, const Record
*>> vec
) {
301 if (type
== "ReservedOrDead")
305 R
"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
306 mlir::raw_indented_ostream
os(output
);
307 os
<< formatv(head
, type
, kind
);
308 auto funScope
= os
.scope("{\n", "}\n\n");
310 // Check that predicates specified if multiple bytecode instances.
311 for (const Record
*rec
: make_second_range(vec
)) {
312 StringRef pred
= rec
->getValueAsString("printerPredicate");
313 if (vec
.size() > 1 && pred
.empty()) {
314 for (auto [index
, rec
] : vec
) {
316 StringRef pred
= rec
->getValueAsString("printerPredicate");
317 if (vec
.size() > 1 && pred
.empty())
318 PrintError(rec
->getLoc(),
319 "Requires parsing predicate given common cType");
321 PrintFatalError("Unspecified for shared cType " + type
);
325 for (auto [index
, rec
] : vec
) {
326 StringRef pred
= rec
->getValueAsString("printerPredicate");
328 os
<< "if (" << format(pred
, {{"$_val", kind
.str()}}) << ") {\n";
332 os
<< "writer.writeVarInt(/* " << rec
->getName() << " */ " << index
335 auto *members
= rec
->getValueAsDag("members");
336 for (auto [arg
, name
] :
337 llvm::zip(members
->getArgs(), members
->getArgNames())) {
338 const DefInit
*def
= dyn_cast
<DefInit
>(arg
);
340 const Record
*memberRec
= def
->getDef();
341 emitPrintHelper(memberRec
, kind
, kind
, name
->getAsUnquotedString(), os
);
351 void Generator::emitPrintHelper(const Record
*memberRec
, StringRef kind
,
352 StringRef parent
, StringRef name
,
353 mlir::raw_indented_ostream
&ios
) {
355 if (auto cGetter
= memberRec
->getValueAsOptionalString("cGetter");
356 cGetter
&& !cGetter
->empty()) {
359 {{"$_attrType", parent
.str()},
360 {"$_member", name
.str()},
361 {"$_getMember", "get" + convertToCamelFromSnakeCase(name
, true)}});
364 formatv("{0}.get{1}()", parent
, convertToCamelFromSnakeCase(name
, true))
368 if (memberRec
->isSubClassOf("Array")) {
369 const Record
*def
= memberRec
->getValueAsDef("elemT");
370 if (!def
->isSubClassOf("CompositeBytecode")) {
371 if (def
->isSubClassOf("AttributeKind")) {
372 ios
<< "writer.writeAttributes(" << getter
<< ");\n";
375 if (def
->isSubClassOf("TypeKind")) {
376 ios
<< "writer.writeTypes(" << getter
<< ");\n";
380 std::string returnType
= getCType(def
);
381 std::string nestedName
= kind
.str();
382 ios
<< "writer.writeList(" << getter
<< ", [&](" << returnType
<< " "
383 << nestedName
<< ") ";
384 auto lambdaScope
= ios
.scope("{\n", "});\n");
385 return emitPrintHelper(def
, kind
, nestedName
, nestedName
, ios
);
387 if (memberRec
->isSubClassOf("CompositeBytecode")) {
388 auto *members
= memberRec
->getValueAsDag("members");
389 for (auto [arg
, argName
] :
390 zip(members
->getArgs(), members
->getArgNames())) {
391 const DefInit
*def
= dyn_cast
<DefInit
>(arg
);
393 emitPrintHelper(def
->getDef(), kind
, parent
,
394 argName
->getAsUnquotedString(), ios
);
398 if (std::string printer
= memberRec
->getValueAsString("cPrinter").str();
400 ios
<< format(printer
, {{"$_writer", "writer"},
401 {"$_name", kind
.str()},
402 {"$_getter", getter
}})
406 void Generator::emitPrintDispatch(StringRef kind
, ArrayRef
<std::string
> vec
) {
407 mlir::raw_indented_ostream
os(output
);
408 char const *head
= R
"(static LogicalResult write{0}({0} {1},
409 DialectBytecodeWriter &writer))";
410 os
<< formatv(head
, capitalize(kind
), kind
);
411 auto funScope
= os
.scope(" {\n", "}\n\n");
413 os
<< "return TypeSwitch<" << capitalize(kind
) << ", LogicalResult>(" << kind
415 auto switchScope
= os
.scope("", "");
416 for (StringRef type
: vec
) {
417 if (type
== "ReservedOrDead")
420 os
<< "\n.Case([&](" << type
<< " t)";
421 auto caseScope
= os
.scope(" {\n", "})");
422 os
<< "return write(t, writer), success();\n";
424 os
<< "\n.Default([&](" << capitalize(kind
) << ") { return failure(); });\n";
428 /// Container of Attribute or Type for Dialect.
430 std::vector
<const Record
*> attr
, type
;
434 static bool emitBCRW(const RecordKeeper
&records
, raw_ostream
&os
) {
435 MapVector
<StringRef
, AttrOrType
> dialectAttrOrType
;
436 for (const Record
*it
:
437 records
.getAllDerivedDefinitions("DialectAttributes")) {
438 if (!selectedBcDialect
.empty() &&
439 it
->getValueAsString("dialect") != selectedBcDialect
)
441 dialectAttrOrType
[it
->getValueAsString("dialect")].attr
=
442 it
->getValueAsListOfDefs("elems");
444 for (const Record
*it
: records
.getAllDerivedDefinitions("DialectTypes")) {
445 if (!selectedBcDialect
.empty() &&
446 it
->getValueAsString("dialect") != selectedBcDialect
)
448 dialectAttrOrType
[it
->getValueAsString("dialect")].type
=
449 it
->getValueAsListOfDefs("elems");
452 if (dialectAttrOrType
.size() != 1)
453 PrintFatalError("Single dialect per invocation required (either only "
454 "one in input file or specified via dialect option)");
456 auto it
= dialectAttrOrType
.front();
459 SmallVector
<std::vector
<const Record
*> *, 2> vecs
;
460 SmallVector
<std::string
, 2> kinds
;
461 vecs
.push_back(&it
.second
.attr
);
462 kinds
.push_back("attribute");
463 vecs
.push_back(&it
.second
.type
);
464 kinds
.push_back("type");
465 for (auto [vec
, kind
] : zip(vecs
, kinds
)) {
466 // Handle Attribute/Type emission.
467 std::map
<std::string
, std::vector
<std::pair
<int64_t, const Record
*>>>
469 for (auto kt
: llvm::enumerate(*vec
))
470 perType
[getCType(kt
.value())].emplace_back(kt
.index(), kt
.value());
471 for (const auto &jt
: perType
) {
472 for (auto kt
: jt
.second
)
473 gen
.emitParse(kind
, *std::get
<1>(kt
));
474 gen
.emitPrint(kind
, jt
.first
, jt
.second
);
476 gen
.emitParseDispatch(kind
, *vec
);
478 SmallVector
<std::string
> types
;
479 for (const auto &it
: perType
) {
480 types
.push_back(it
.first
);
482 gen
.emitPrintDispatch(kind
, types
);
488 static mlir::GenRegistration
489 genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
490 [](const RecordKeeper
&records
, raw_ostream
&os
) {
491 return emitBCRW(records
, os
);