//===- BytecodeDialectGen.cpp - Dialect bytecode reader/writer gen --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include <cstdint>
#include <map>
#include <regex>
#include <string>
#include <utility>
#include <vector>

using namespace llvm;
21 static llvm::cl::OptionCategory
dialectGenCat("Options for -gen-bytecode");
22 static llvm::cl::opt
<std::string
>
23 selectedBcDialect("bytecode-dialect",
24 llvm::cl::desc("The dialect to gen for"),
25 llvm::cl::cat(dialectGenCat
), llvm::cl::CommaSeparated
);
29 /// Helper class to generate C++ bytecode parser helpers.
32 Generator(raw_ostream
&output
) : output(output
) {}
34 /// Returns whether successfully emitted attribute/type parsers.
35 void emitParse(StringRef kind
, Record
&x
);
37 /// Returns whether successfully emitted attribute/type printers.
38 void emitPrint(StringRef kind
, StringRef type
,
39 ArrayRef
<std::pair
<int64_t, Record
*>> vec
);
41 /// Emits parse dispatch table.
42 void emitParseDispatch(StringRef kind
, ArrayRef
<Record
*> vec
);
44 /// Emits print dispatch table.
45 void emitPrintDispatch(StringRef kind
, ArrayRef
<std::string
> vec
);
48 /// Emits parse calls to construct given kind.
49 void emitParseHelper(StringRef kind
, StringRef returnType
, StringRef builder
,
50 ArrayRef
<Init
*> args
, ArrayRef
<std::string
> argNames
,
51 StringRef failure
, mlir::raw_indented_ostream
&ios
);
53 /// Emits print instructions.
54 void emitPrintHelper(Record
*memberRec
, StringRef kind
, StringRef parent
,
55 StringRef name
, mlir::raw_indented_ostream
&ios
);
61 /// Helper to replace set of from strings to target in `s`.
62 /// Assumed: non-overlapping replacements.
63 static std::string
format(StringRef templ
,
64 std::map
<std::string
, std::string
> &&map
) {
65 std::string s
= templ
.str();
66 for (const auto &[from
, to
] : map
)
67 // All replacements start with $, don't treat as anchor.
68 s
= std::regex_replace(s
, std::regex("\\" + from
), to
);
72 /// Return string with first character capitalized.
73 static std::string
capitalize(StringRef str
) {
74 return ((Twine
)toUpper(str
[0]) + str
.drop_front()).str();
77 /// Return the C++ type for the given record.
78 static std::string
getCType(Record
*def
) {
79 std::string format
= "{0}";
80 if (def
->isSubClassOf("Array")) {
81 def
= def
->getValueAsDef("elemT");
82 format
= "SmallVector<{0}>";
85 StringRef cType
= def
->getValueAsString("cType");
87 if (def
->isAnonymous())
88 PrintFatalError(def
->getLoc(), "Unable to determine cType");
90 return formatv(format
.c_str(), def
->getName().str());
92 return formatv(format
.c_str(), cType
.str());
95 void Generator::emitParseDispatch(StringRef kind
, ArrayRef
<Record
*> vec
) {
96 mlir::raw_indented_ostream
os(output
);
98 R
"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
99 os
<< formatv(head
, capitalize(kind
));
100 auto funScope
= os
.scope(" {\n", "}\n\n");
103 os
<< "return reader.emitError() << \"unknown attribute\", "
104 << capitalize(kind
) << "();\n";
108 os
<< "uint64_t kind;\n";
109 os
<< "if (failed(reader.readVarInt(kind)))\n"
110 << " return " << capitalize(kind
) << "();\n";
111 os
<< "switch (kind) ";
113 auto switchScope
= os
.scope("{\n", "}\n");
114 for (const auto &it
: llvm::enumerate(vec
)) {
115 if (it
.value()->getName() == "ReservedOrDead")
118 os
<< formatv("case {1}:\n return read{0}(context, reader);\n",
119 it
.value()->getName(), it
.index());
122 << " reader.emitError() << \"unknown attribute code: \" "
124 << " return " << capitalize(kind
) << "();\n";
126 os
<< "return " << capitalize(kind
) << "();\n";
129 void Generator::emitParse(StringRef kind
, Record
&x
) {
130 if (x
.getNameInitAsString() == "ReservedOrDead")
134 R
"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
135 mlir::raw_indented_ostream
os(output
);
136 std::string returnType
= getCType(&x
);
137 os
<< formatv(head
, kind
== "attribute" ? "::mlir::Attribute" : "::mlir::Type", x
.getName());
138 DagInit
*members
= x
.getValueAsDag("members");
139 SmallVector
<std::string
> argNames
=
140 llvm::to_vector(map_range(members
->getArgNames(), [](StringInit
*init
) {
141 return init
->getAsUnquotedString();
143 StringRef builder
= x
.getValueAsString("cBuilder").trim();
144 emitParseHelper(kind
, returnType
, builder
, members
->getArgs(), argNames
,
145 returnType
+ "()", os
);
149 void printParseConditional(mlir::raw_indented_ostream
&ios
,
150 ArrayRef
<Init
*> args
,
151 ArrayRef
<std::string
> argNames
) {
153 auto parenScope
= ios
.scope("(", ") {");
156 auto listHelperName
= [](StringRef name
) {
157 return formatv("read{0}", capitalize(name
));
161 llvm::to_vector(make_filter_range(args
, [](Init
*const attr
) {
162 Record
*def
= cast
<DefInit
>(attr
)->getDef();
163 if (def
->isSubClassOf("Array"))
165 return !def
->getValueAsString("cParser").empty();
169 zip(parsedArgs
, argNames
),
170 [&](std::tuple
<llvm::Init
*&, const std::string
&> it
) {
171 Record
*attr
= cast
<DefInit
>(std::get
<0>(it
))->getDef();
173 if (auto optParser
= attr
->getValueAsOptionalString("cParser")) {
175 } else if (attr
->isSubClassOf("Array")) {
176 Record
*def
= attr
->getValueAsDef("elemT");
177 bool composite
= def
->isSubClassOf("CompositeBytecode");
178 if (!composite
&& def
->isSubClassOf("AttributeKind"))
179 parser
= "succeeded($_reader.readAttributes($_var))";
180 else if (!composite
&& def
->isSubClassOf("TypeKind"))
181 parser
= "succeeded($_reader.readTypes($_var))";
183 parser
= ("succeeded($_reader.readList($_var, " +
184 listHelperName(std::get
<1>(it
)) + "))")
187 PrintFatalError(attr
->getLoc(), "No parser specified");
189 std::string type
= getCType(attr
);
190 ios
<< format(parser
, {{"$_reader", "reader"},
191 {"$_resultType", type
},
192 {"$_var", std::get
<1>(it
)}});
194 [&]() { ios
<< " &&\n"; });
197 void Generator::emitParseHelper(StringRef kind
, StringRef returnType
,
198 StringRef builder
, ArrayRef
<Init
*> args
,
199 ArrayRef
<std::string
> argNames
,
201 mlir::raw_indented_ostream
&ios
) {
202 auto funScope
= ios
.scope("{\n", "}");
205 ios
<< formatv("return get<{0}>(context);\n", returnType
);
210 std::string lastCType
= "";
211 for (auto [arg
, name
] : zip(args
, argNames
)) {
212 DefInit
*first
= dyn_cast
<DefInit
>(arg
);
214 PrintFatalError("Unexpected type for " + name
);
215 Record
*def
= first
->getDef();
217 // Create variable decls, if there are a block of same type then create
218 // comma separated list of them.
219 std::string cType
= getCType(def
);
220 if (lastCType
== cType
) {
223 if (!lastCType
.empty())
232 // Returns the name of the helper used in list parsing. E.g., the name of the
233 // lambda passed to array parsing.
234 auto listHelperName
= [](StringRef name
) {
235 return formatv("read{0}", capitalize(name
));
238 // Emit list helper functions.
239 for (auto [arg
, name
] : zip(args
, argNames
)) {
240 Record
*attr
= cast
<DefInit
>(arg
)->getDef();
241 if (!attr
->isSubClassOf("Array"))
244 // TODO: Dedupe readers.
245 Record
*def
= attr
->getValueAsDef("elemT");
246 if (!def
->isSubClassOf("CompositeBytecode") &&
247 (def
->isSubClassOf("AttributeKind") || def
->isSubClassOf("TypeKind")))
250 std::string returnType
= getCType(def
);
251 ios
<< "auto " << listHelperName(name
) << " = [&]() -> FailureOr<"
252 << returnType
<< "> ";
253 SmallVector
<Init
*> args
;
254 SmallVector
<std::string
> argNames
;
255 if (def
->isSubClassOf("CompositeBytecode")) {
256 DagInit
*members
= def
->getValueAsDag("members");
257 args
= llvm::to_vector(members
->getArgs());
258 argNames
= llvm::to_vector(
259 map_range(members
->getArgNames(), [](StringInit
*init
) {
260 return init
->getAsUnquotedString();
263 args
= {def
->getDefInit()};
266 StringRef builder
= def
->getValueAsString("cBuilder");
267 emitParseHelper(kind
, returnType
, builder
, args
, argNames
, "failure()",
272 // Print parse conditional.
273 printParseConditional(ios
, args
, argNames
);
275 // Compute args to pass to create method.
276 auto passedArgs
= llvm::to_vector(make_filter_range(
277 argNames
, [](StringRef str
) { return !str
.starts_with("_"); }));
279 raw_string_ostream
argStream(argStr
);
280 interleaveComma(passedArgs
, argStream
,
281 [&](const std::string
&str
) { argStream
<< str
; });
282 // Return the invoked constructor.
284 << format(builder
, {{"$_resultType", returnType
.str()},
285 {"$_args", argStream
.str()}})
289 // TODO: Emit error in debug.
290 // This assumes the result types in error case can always be empty
292 ios
<< "}\nreturn " << failure
<< ";\n";
295 void Generator::emitPrint(StringRef kind
, StringRef type
,
296 ArrayRef
<std::pair
<int64_t, Record
*>> vec
) {
297 if (type
== "ReservedOrDead")
301 R
"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
302 mlir::raw_indented_ostream
os(output
);
303 os
<< formatv(head
, type
, kind
);
304 auto funScope
= os
.scope("{\n", "}\n\n");
306 // Check that predicates specified if multiple bytecode instances.
307 for (llvm::Record
*rec
: make_second_range(vec
)) {
308 StringRef pred
= rec
->getValueAsString("printerPredicate");
309 if (vec
.size() > 1 && pred
.empty()) {
310 for (auto [index
, rec
] : vec
) {
312 StringRef pred
= rec
->getValueAsString("printerPredicate");
313 if (vec
.size() > 1 && pred
.empty())
314 PrintError(rec
->getLoc(),
315 "Requires parsing predicate given common cType");
317 PrintFatalError("Unspecified for shared cType " + type
);
321 for (auto [index
, rec
] : vec
) {
322 StringRef pred
= rec
->getValueAsString("printerPredicate");
324 os
<< "if (" << format(pred
, {{"$_val", kind
.str()}}) << ") {\n";
328 os
<< "writer.writeVarInt(/* " << rec
->getName() << " */ " << index
331 auto *members
= rec
->getValueAsDag("members");
332 for (auto [arg
, name
] :
333 llvm::zip(members
->getArgs(), members
->getArgNames())) {
334 DefInit
*def
= dyn_cast
<DefInit
>(arg
);
336 Record
*memberRec
= def
->getDef();
337 emitPrintHelper(memberRec
, kind
, kind
, name
->getAsUnquotedString(), os
);
347 void Generator::emitPrintHelper(Record
*memberRec
, StringRef kind
,
348 StringRef parent
, StringRef name
,
349 mlir::raw_indented_ostream
&ios
) {
351 if (auto cGetter
= memberRec
->getValueAsOptionalString("cGetter");
352 cGetter
&& !cGetter
->empty()) {
355 {{"$_attrType", parent
.str()},
356 {"$_member", name
.str()},
357 {"$_getMember", "get" + convertToCamelFromSnakeCase(name
, true)}});
360 formatv("{0}.get{1}()", parent
, convertToCamelFromSnakeCase(name
, true))
364 if (memberRec
->isSubClassOf("Array")) {
365 Record
*def
= memberRec
->getValueAsDef("elemT");
366 if (!def
->isSubClassOf("CompositeBytecode")) {
367 if (def
->isSubClassOf("AttributeKind")) {
368 ios
<< "writer.writeAttributes(" << getter
<< ");\n";
371 if (def
->isSubClassOf("TypeKind")) {
372 ios
<< "writer.writeTypes(" << getter
<< ");\n";
376 std::string returnType
= getCType(def
);
377 std::string nestedName
= kind
.str();
378 ios
<< "writer.writeList(" << getter
<< ", [&](" << returnType
<< " "
379 << nestedName
<< ") ";
380 auto lambdaScope
= ios
.scope("{\n", "});\n");
381 return emitPrintHelper(def
, kind
, nestedName
, nestedName
, ios
);
383 if (memberRec
->isSubClassOf("CompositeBytecode")) {
384 auto *members
= memberRec
->getValueAsDag("members");
385 for (auto [arg
, argName
] :
386 zip(members
->getArgs(), members
->getArgNames())) {
387 DefInit
*def
= dyn_cast
<DefInit
>(arg
);
389 emitPrintHelper(def
->getDef(), kind
, parent
,
390 argName
->getAsUnquotedString(), ios
);
394 if (std::string printer
= memberRec
->getValueAsString("cPrinter").str();
396 ios
<< format(printer
, {{"$_writer", "writer"},
397 {"$_name", kind
.str()},
398 {"$_getter", getter
}})
402 void Generator::emitPrintDispatch(StringRef kind
, ArrayRef
<std::string
> vec
) {
403 mlir::raw_indented_ostream
os(output
);
404 char const *head
= R
"(static LogicalResult write{0}({0} {1},
405 DialectBytecodeWriter &writer))";
406 os
<< formatv(head
, capitalize(kind
), kind
);
407 auto funScope
= os
.scope(" {\n", "}\n\n");
409 os
<< "return TypeSwitch<" << capitalize(kind
) << ", LogicalResult>(" << kind
411 auto switchScope
= os
.scope("", "");
412 for (StringRef type
: vec
) {
413 if (type
== "ReservedOrDead")
416 os
<< "\n.Case([&](" << type
<< " t)";
417 auto caseScope
= os
.scope(" {\n", "})");
418 os
<< "return write(t, writer), success();\n";
420 os
<< "\n.Default([&](" << capitalize(kind
) << ") { return failure(); });\n";
424 /// Container of Attribute or Type for Dialect.
426 std::vector
<Record
*> attr
, type
;
430 static bool emitBCRW(const RecordKeeper
&records
, raw_ostream
&os
) {
431 MapVector
<StringRef
, AttrOrType
> dialectAttrOrType
;
432 for (auto &it
: records
.getAllDerivedDefinitions("DialectAttributes")) {
433 if (!selectedBcDialect
.empty() &&
434 it
->getValueAsString("dialect") != selectedBcDialect
)
436 dialectAttrOrType
[it
->getValueAsString("dialect")].attr
=
437 it
->getValueAsListOfDefs("elems");
439 for (auto &it
: records
.getAllDerivedDefinitions("DialectTypes")) {
440 if (!selectedBcDialect
.empty() &&
441 it
->getValueAsString("dialect") != selectedBcDialect
)
443 dialectAttrOrType
[it
->getValueAsString("dialect")].type
=
444 it
->getValueAsListOfDefs("elems");
447 if (dialectAttrOrType
.size() != 1)
448 PrintFatalError("Single dialect per invocation required (either only "
449 "one in input file or specified via dialect option)");
451 auto it
= dialectAttrOrType
.front();
454 SmallVector
<std::vector
<Record
*> *, 2> vecs
;
455 SmallVector
<std::string
, 2> kinds
;
456 vecs
.push_back(&it
.second
.attr
);
457 kinds
.push_back("attribute");
458 vecs
.push_back(&it
.second
.type
);
459 kinds
.push_back("type");
460 for (auto [vec
, kind
] : zip(vecs
, kinds
)) {
461 // Handle Attribute/Type emission.
462 std::map
<std::string
, std::vector
<std::pair
<int64_t, Record
*>>> perType
;
463 for (auto kt
: llvm::enumerate(*vec
))
464 perType
[getCType(kt
.value())].emplace_back(kt
.index(), kt
.value());
465 for (const auto &jt
: perType
) {
466 for (auto kt
: jt
.second
)
467 gen
.emitParse(kind
, *std::get
<1>(kt
));
468 gen
.emitPrint(kind
, jt
.first
, jt
.second
);
470 gen
.emitParseDispatch(kind
, *vec
);
472 SmallVector
<std::string
> types
;
473 for (const auto &it
: perType
) {
474 types
.push_back(it
.first
);
476 gen
.emitPrintDispatch(kind
, types
);
482 static mlir::GenRegistration
483 genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
484 [](const RecordKeeper
&records
, raw_ostream
&os
) {
485 return emitBCRW(records
, os
);