1 //===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Support/IndentedOstream.h"
10 #include "mlir/TableGen/GenInfo.h"
11 #include "llvm/ADT/MapVector.h"
12 #include "llvm/ADT/STLExtras.h"
13 #include "llvm/Support/CommandLine.h"
14 #include "llvm/Support/FormatVariadic.h"
15 #include "llvm/TableGen/Error.h"
16 #include "llvm/TableGen/Record.h"
// Command-line category for this generator and the `bytecode-dialect` option.
// Per its description the option names the dialect to generate for.
// NOTE(review): the option holds a single std::string yet carries
// llvm::cl::CommaSeparated -- confirm that the flag is intended.
21 static llvm::cl::OptionCategory
dialectGenCat("Options for -gen-bytecode");
22 static llvm::cl::opt
<std::string
>
23 selectedBcDialect("bytecode-dialect",
24 llvm::cl::desc("The dialect to gen for"),
25 llvm::cl::cat(dialectGenCat
), llvm::cl::CommaSeparated
);
29 /// Helper class to generate C++ bytecode parser helpers.
32 Generator(raw_ostream
&output
) : output(output
) {}
34 /// Emits the attribute/type parser for record `x` (no value is returned).
35 void emitParse(StringRef kind
, Record
&x
);
37 /// Emits the attribute/type printer for the records in `vec` (no value is returned).
38 void emitPrint(StringRef kind
, StringRef type
,
39 ArrayRef
<std::pair
<int64_t, Record
*>> vec
);
41 /// Emits parse dispatch table.
42 void emitParseDispatch(StringRef kind
, ArrayRef
<Record
*> vec
);
44 /// Emits print dispatch table.
45 void emitPrintDispatch(StringRef kind
, ArrayRef
<std::string
> vec
);
48 /// Emits parse calls to construct given kind.
49 void emitParseHelper(StringRef kind
, StringRef returnType
, StringRef builder
,
50 ArrayRef
<Init
*> args
, ArrayRef
<std::string
> argNames
,
51 StringRef failure
, mlir::raw_indented_ostream
&ios
);
53 /// Emits print instructions.
54 void emitPrintHelper(Record
*memberRec
, StringRef kind
, StringRef parent
,
55 StringRef name
, mlir::raw_indented_ostream
&ios
);
61 /// Helper that substitutes every key of `map` found in `templ` with its mapped
/// value. The substitutions are applied to a std::string copy of `templ`.
62 /// Assumed: non-overlapping replacements.
63 static std::string
format(StringRef templ
,
64 std::map
<std::string
, std::string
> &&map
) {
65 std::string s
= templ
.str();
66 for (const auto &[from
, to
] : map
)
67 // All replacements start with $, don't treat as anchor.
// The key is prefixed with a backslash so that std::regex matches its leading
// `$` literally; the remainder of the key is still read as a regex pattern.
68 s
= std::regex_replace(s
, std::regex("\\" + from
), to
);
72 /// Return string with first character capitalized.
/// The result is toUpper(str[0]) followed by `str` without its first character.
/// NOTE(review): `str[0]` is read without a visible emptiness check, so callers
/// presumably always pass a non-empty string -- confirm.
73 static std::string
capitalize(StringRef str
) {
74 return ((Twine
)toUpper(str
[0]) + str
.drop_front()).str();
77 /// Return the C++ type for the given record.
/// Records deriving from `Array` are unwrapped to their `elemT` record and the
/// result is wrapped as `SmallVector<...>`. The type text is either the
/// record's `cType` field or the record's own name; on the name path an
/// anonymous record is rejected with a fatal error.
/// NOTE(review): the test selecting between the two returns is missing from
/// this listing -- presumably it checks whether `cType` is empty.
78 static std::string
getCType(Record
*def
) {
79 std::string format
= "{0}";
80 if (def
->isSubClassOf("Array")) {
81 def
= def
->getValueAsDef("elemT");
82 format
= "SmallVector<{0}>";
85 StringRef cType
= def
->getValueAsString("cType");
87 if (def
->isAnonymous())
88 PrintFatalError(def
->getLoc(), "Unable to determine cType");
90 return formatv(format
.c_str(), def
->getName().str());
92 return formatv(format
.c_str(), cType
.str());
/// Emits the `read<Kind>` dispatch function for `kind`. The generated code
/// reads a var-int code and switches on it: each record in `vec` gets the case
/// label equal to its position in `vec` and forwards to `read<RecordName>`.
/// When the var-int read fails the generated code returns a default-constructed
/// `<Kind>()`, and the same value is returned at the end of the function.
/// NOTE(review): the error text emitted below reads `unknown attribute code`
/// verbatim, while emitBCRW also calls this function with kind `type` --
/// confirm whether the message should depend on `kind`.
95 void Generator::emitParseDispatch(StringRef kind
, ArrayRef
<Record
*> vec
) {
96 mlir::raw_indented_ostream
os(output
);
98 R
"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
99 os
<< formatv(head
, capitalize(kind
));
100 auto funScope
= os
.scope(" {\n", "}\n\n");
102 os
<< "uint64_t kind;\n";
103 os
<< "if (failed(reader.readVarInt(kind)))\n"
104 << " return " << capitalize(kind
) << "();\n";
105 os
<< "switch (kind) ";
107 auto switchScope
= os
.scope("{\n", "}\n");
108 for (const auto &it
: llvm::enumerate(vec
)) {
109 os
<< formatv("case {1}:\n return read{0}(context, reader);\n",
110 it
.value()->getName(), it
.index());
113 << " reader.emitError() << \"unknown attribute code: \" "
115 << " return " << capitalize(kind
) << "();\n";
117 os
<< "return " << capitalize(kind
) << "();\n";
/// Emits the reader function `read<RecordName>` for record `x`. The emitted
/// signature returns getCType(&x). The body is produced by emitParseHelper from
/// the arguments and unquoted argument names of the record's `members` dag and
/// from its `cBuilder` string, with `<returnType>()` passed as the failure
/// value.
120 void Generator::emitParse(StringRef kind
, Record
&x
) {
122 R
"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
123 mlir::raw_indented_ostream
os(output
);
124 std::string returnType
= getCType(&x
);
125 os
<< formatv(head
, returnType
, x
.getName());
126 DagInit
*members
= x
.getValueAsDag("members");
127 SmallVector
<std::string
> argNames
=
128 llvm::to_vector(map_range(members
->getArgNames(), [](StringInit
*init
) {
129 return init
->getAsUnquotedString();
131 StringRef builder
= x
.getValueAsString("cBuilder");
132 emitParseHelper(kind
, returnType
, builder
, members
->getArgs(), argNames
,
133 returnType
+ "()", os
);
/// Emits a parenthesized chain of read expressions to `ios`: the chain opens
/// with `(`, the expressions are separated by ` &&` plus a newline, and the
/// chain closes with `) {`.
/// For each argument the parser text is taken from its `cParser` field when
/// present. Otherwise, for `Array` records, it is a readAttributes / readTypes
/// call (non-composite attribute or type element kinds) or a readList call
/// using the `read<Name>` helper. Any other record is a fatal error.
/// `$_reader`, `$_resultType` and `$_var` are substituted in the parser text.
/// NOTE(review): the filtered `parsedArgs` is zipped with the unfiltered
/// `argNames`; if the filter ever drops an argument that is not the last one,
/// names would no longer line up with arguments -- confirm.
137 void printParseConditional(mlir::raw_indented_ostream
&ios
,
138 ArrayRef
<Init
*> args
,
139 ArrayRef
<std::string
> argNames
) {
141 auto parenScope
= ios
.scope("(", ") {");
144 auto listHelperName
= [](StringRef name
) {
145 return formatv("read{0}", capitalize(name
));
149 llvm::to_vector(make_filter_range(args
, [](Init
*const attr
) {
150 Record
*def
= cast
<DefInit
>(attr
)->getDef();
151 if (def
->isSubClassOf("Array"))
153 return !def
->getValueAsString("cParser").empty();
157 zip(parsedArgs
, argNames
),
158 [&](std::tuple
<llvm::Init
*&, const std::string
&> it
) {
159 Record
*attr
= cast
<DefInit
>(std::get
<0>(it
))->getDef();
161 if (auto optParser
= attr
->getValueAsOptionalString("cParser")) {
163 } else if (attr
->isSubClassOf("Array")) {
164 Record
*def
= attr
->getValueAsDef("elemT");
165 bool composite
= def
->isSubClassOf("CompositeBytecode");
166 if (!composite
&& def
->isSubClassOf("AttributeKind"))
167 parser
= "succeeded($_reader.readAttributes($_var))";
168 else if (!composite
&& def
->isSubClassOf("TypeKind"))
169 parser
= "succeeded($_reader.readTypes($_var))";
171 parser
= ("succeeded($_reader.readList($_var, " +
172 listHelperName(std::get
<1>(it
)) + "))")
175 PrintFatalError(attr
->getLoc(), "No parser specified");
177 std::string type
= getCType(attr
);
178 ios
<< format(parser
, {{"$_reader", "reader"},
179 {"$_resultType", type
},
180 {"$_var", std::get
<1>(it
)}});
182 [&]() { ios
<< " &&\n"; });
/// Emits a braced block to `ios` that parses the members described by `args`
/// and `argNames` and returns the expression built from the `builder` template.
/// Steps visible in this listing:
///  - a `return get<returnType>(context);` statement is emitted (its guarding
///    condition is not visible here);
///  - variable declarations are emitted per member, grouping consecutive
///    members of the same C++ type;
///  - for `Array` members a `read<Name>` lambda returning
///    `FailureOr<element type>` is emitted by recursing into this function
///    with `failure()` as the failure value;
///  - the parse conditional is emitted via printParseConditional;
///  - the builder text is emitted with `$_resultType` and `$_args`
///    substituted, where `$_args` lists the argument names that do not start
///    with an underscore;
///  - after the conditional block the generated code returns `failure`.
/// Every argument must be a DefInit; otherwise a fatal error is raised.
185 void Generator::emitParseHelper(StringRef kind
, StringRef returnType
,
186 StringRef builder
, ArrayRef
<Init
*> args
,
187 ArrayRef
<std::string
> argNames
,
189 mlir::raw_indented_ostream
&ios
) {
190 auto funScope
= ios
.scope("{\n", "}");
193 ios
<< formatv("return get<{0}>(context);\n", returnType
);
198 std::string lastCType
= "";
199 for (auto [arg
, name
] : zip(args
, argNames
)) {
200 DefInit
*first
= dyn_cast
<DefInit
>(arg
);
202 PrintFatalError("Unexpected type for " + name
);
203 Record
*def
= first
->getDef();
205 // Create variable decls, if there are a block of same type then create
206 // comma separated list of them.
207 std::string cType
= getCType(def
);
208 if (lastCType
== cType
) {
211 if (!lastCType
.empty())
220 // Returns the name of the helper used in list parsing. E.g., the name of the
221 // lambda passed to array parsing.
222 auto listHelperName
= [](StringRef name
) {
223 return formatv("read{0}", capitalize(name
));
226 // Emit list helper functions.
227 for (auto [arg
, name
] : zip(args
, argNames
)) {
228 Record
*attr
= cast
<DefInit
>(arg
)->getDef();
229 if (!attr
->isSubClassOf("Array"))
232 // TODO: Dedupe readers.
233 Record
*def
= attr
->getValueAsDef("elemT");
// Non-composite attribute/type element kinds are tested for here; the
// statement under the test is missing from this listing (presumably `continue`,
// since printParseConditional reads them via readAttributes / readTypes).
234 if (!def
->isSubClassOf("CompositeBytecode") &&
235 (def
->isSubClassOf("AttributeKind") || def
->isSubClassOf("TypeKind")))
// The locals declared from here on (returnType, args, argNames, builder)
// shadow the function parameters of the same names within this loop body.
238 std::string returnType
= getCType(def
);
239 ios
<< "auto " << listHelperName(name
) << " = [&]() -> FailureOr<"
240 << returnType
<< "> ";
241 SmallVector
<Init
*> args
;
242 SmallVector
<std::string
> argNames
;
243 if (def
->isSubClassOf("CompositeBytecode")) {
244 DagInit
*members
= def
->getValueAsDag("members");
245 args
= llvm::to_vector(members
->getArgs());
246 argNames
= llvm::to_vector(
247 map_range(members
->getArgNames(), [](StringInit
*init
) {
248 return init
->getAsUnquotedString();
251 args
= {def
->getDefInit()};
254 StringRef builder
= def
->getValueAsString("cBuilder");
255 emitParseHelper(kind
, returnType
, builder
, args
, argNames
, "failure()",
260 // Print parse conditional.
261 printParseConditional(ios
, args
, argNames
);
263 // Compute args to pass to create method.
264 auto passedArgs
= llvm::to_vector(make_filter_range(
265 argNames
, [](StringRef str
) { return !str
.starts_with("_"); }));
267 raw_string_ostream
argStream(argStr
);
268 interleaveComma(passedArgs
, argStream
,
269 [&](const std::string
&str
) { argStream
<< str
; });
270 // Return the invoked constructor.
272 << format(builder
, {{"$_resultType", returnType
.str()},
273 {"$_args", argStream
.str()}})
277 // TODO: Emit error in debug.
278 // This assumes the result types in error case can always be empty
280 ios
<< "}\nreturn " << failure
<< ";\n";
/// Emits `static void write(<type> <kind>, DialectBytecodeWriter &writer)` for
/// the records in `vec`, which all share the C++ type `type`. Each entry of
/// `vec` pairs a record with the integer emitted in the generated
/// `writer.writeVarInt` call.
/// When `vec` holds more than one record, every record must have a non-empty
/// `printerPredicate`: if one is missing, an error is printed for each
/// offending record and a fatal error follows.
/// For a record with a predicate, the predicate (with `$_val` replaced by
/// `kind`) is emitted as the condition of an `if`; then each entry of the
/// `members` dag of the record is written through emitPrintHelper.
283 void Generator::emitPrint(StringRef kind
, StringRef type
,
284 ArrayRef
<std::pair
<int64_t, Record
*>> vec
) {
286 R
"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
287 mlir::raw_indented_ostream
os(output
);
288 os
<< formatv(head
, type
, kind
);
289 auto funScope
= os
.scope("{\n", "}\n\n");
291 // Check that predicates specified if multiple bytecode instances.
292 for (llvm::Record
*rec
: make_second_range(vec
)) {
293 StringRef pred
= rec
->getValueAsString("printerPredicate");
294 if (vec
.size() > 1 && pred
.empty()) {
295 for (auto [index
, rec
] : vec
) {
297 StringRef pred
= rec
->getValueAsString("printerPredicate");
298 if (vec
.size() > 1 && pred
.empty())
299 PrintError(rec
->getLoc(),
300 "Requires parsing predicate given common cType");
302 PrintFatalError("Unspecified for shared cType " + type
);
306 for (auto [index
, rec
] : vec
) {
307 StringRef pred
= rec
->getValueAsString("printerPredicate");
309 os
<< "if (" << format(pred
, {{"$_val", kind
.str()}}) << ") {\n";
313 os
<< "writer.writeVarInt(/* " << rec
->getName() << " */ " << index
316 auto *members
= rec
->getValueAsDag("members");
317 for (auto [arg
, name
] :
318 llvm::zip(members
->getArgs(), members
->getArgNames())) {
319 DefInit
*def
= dyn_cast
<DefInit
>(arg
);
321 Record
*memberRec
= def
->getDef();
322 emitPrintHelper(memberRec
, kind
, kind
, name
->getAsUnquotedString(), os
);
/// Emits the write statements for one member record `memberRec`, named `name`,
/// of the value referred to as `parent` in the generated code.
/// The accessor expression comes from the `cGetter` field of the record when
/// it is present and non-empty (with `$_attrType`, `$_member` and `$_getMember`
/// substituted); otherwise it is `<parent>.get<CamelCaseName>()`.
/// Dispatch on the member record, as visible in this listing:
///  - `Array` with a non-composite `AttributeKind` / `TypeKind` element:
///    emits writer.writeAttributes / writer.writeTypes on the accessor;
///  - other `Array`: emits writer.writeList with a lambda whose body is
///    produced by recursing on the element record;
///  - `CompositeBytecode`: recurses on every entry of its `members` dag;
///  - otherwise: emits the `cPrinter` text of the record with `$_writer`,
///    `$_name` and `$_getter` substituted.
332 void Generator::emitPrintHelper(Record
*memberRec
, StringRef kind
,
333 StringRef parent
, StringRef name
,
334 mlir::raw_indented_ostream
&ios
) {
336 if (auto cGetter
= memberRec
->getValueAsOptionalString("cGetter");
337 cGetter
&& !cGetter
->empty()) {
340 {{"$_attrType", parent
.str()},
341 {"$_member", name
.str()},
342 {"$_getMember", "get" + convertToCamelFromSnakeCase(name
, true)}});
345 formatv("{0}.get{1}()", parent
, convertToCamelFromSnakeCase(name
, true))
349 if (memberRec
->isSubClassOf("Array")) {
350 Record
*def
= memberRec
->getValueAsDef("elemT");
351 if (!def
->isSubClassOf("CompositeBytecode")) {
352 if (def
->isSubClassOf("AttributeKind")) {
353 ios
<< "writer.writeAttributes(" << getter
<< ");\n";
356 if (def
->isSubClassOf("TypeKind")) {
357 ios
<< "writer.writeTypes(" << getter
<< ");\n";
361 std::string returnType
= getCType(def
);
362 ios
<< "writer.writeList(" << getter
<< ", [&](" << returnType
<< " "
364 auto lambdaScope
= ios
.scope("{\n", "});\n");
365 return emitPrintHelper(def
, kind
, kind
, kind
, ios
);
367 if (memberRec
->isSubClassOf("CompositeBytecode")) {
368 auto *members
= memberRec
->getValueAsDag("members");
369 for (auto [arg
, argName
] :
370 zip(members
->getArgs(), members
->getArgNames())) {
371 DefInit
*def
= dyn_cast
<DefInit
>(arg
);
373 emitPrintHelper(def
->getDef(), kind
, parent
,
374 argName
->getAsUnquotedString(), ios
);
378 if (std::string printer
= memberRec
->getValueAsString("cPrinter").str();
380 ios
<< format(printer
, {{"$_writer", "writer"},
381 {"$_name", kind
.str()},
382 {"$_getter", getter
}})
/// Emits the `write<Kind>` dispatch function, which takes a `<Kind>` value and
/// a DialectBytecodeWriter and returns LogicalResult. The generated body is a
/// TypeSwitch over the C++ type names in `vec`: every case calls
/// `write(t, writer)` and yields success(), and the default case yields
/// failure().
386 void Generator::emitPrintDispatch(StringRef kind
, ArrayRef
<std::string
> vec
) {
387 mlir::raw_indented_ostream
os(output
);
388 char const *head
= R
"(static LogicalResult write{0}({0} {1},
389 DialectBytecodeWriter &writer))";
390 os
<< formatv(head
, capitalize(kind
), kind
);
391 auto funScope
= os
.scope(" {\n", "}\n\n");
393 os
<< "return TypeSwitch<" << capitalize(kind
) << ", LogicalResult>(" << kind
395 auto switchScope
= os
.scope("", "");
396 for (StringRef type
: vec
) {
397 os
<< "\n.Case([&](" << type
<< " t)";
398 auto caseScope
= os
.scope(" {\n", "})");
399 os
<< "return write(t, writer), success();\n";
401 os
<< "\n.Default([&](" << capitalize(kind
) << ") { return failure(); });\n";
405 /// Container for the attribute and type records of a dialect.
407 std::vector
<Record
*> attr
, type
;
/// Entry point of the generator. Collects, per dialect name, the `elems` of
/// all `DialectAttributes` and `DialectTypes` records, requires that exactly
/// one dialect remains, and then emits for attributes and for types (in that
/// order) the per-record parsers, the per-C++-type printers, and the parse and
/// print dispatch functions.
/// When the `bytecode-dialect` option is non-empty, records whose `dialect`
/// differs from it are tested for; the statement under that test is missing
/// from this listing (presumably `continue`).
411 static bool emitBCRW(const RecordKeeper
&records
, raw_ostream
&os
) {
412 MapVector
<StringRef
, AttrOrType
> dialectAttrOrType
;
413 for (auto &it
: records
.getAllDerivedDefinitions("DialectAttributes")) {
414 if (!selectedBcDialect
.empty() &&
415 it
->getValueAsString("dialect") != selectedBcDialect
)
417 dialectAttrOrType
[it
->getValueAsString("dialect")].attr
=
418 it
->getValueAsListOfDefs("elems");
420 for (auto &it
: records
.getAllDerivedDefinitions("DialectTypes")) {
421 if (!selectedBcDialect
.empty() &&
422 it
->getValueAsString("dialect") != selectedBcDialect
)
424 dialectAttrOrType
[it
->getValueAsString("dialect")].type
=
425 it
->getValueAsListOfDefs("elems");
428 if (dialectAttrOrType
.size() != 1)
429 PrintFatalError("Single dialect per invocation required (either only "
430 "one in input file or specified via dialect option)");
432 auto it
= dialectAttrOrType
.front();
435 SmallVector
<std::vector
<Record
*> *, 2> vecs
;
436 SmallVector
<std::string
, 2> kinds
;
437 vecs
.push_back(&it
.second
.attr
);
438 kinds
.push_back("attribute");
439 vecs
.push_back(&it
.second
.type
);
440 kinds
.push_back("type");
441 for (auto [vec
, kind
] : zip(vecs
, kinds
)) {
442 // Handle Attribute/Type emission.
// Records are grouped by their C++ type; each entry keeps the position of the
// record in the original list, which is the integer handed on to emitPrint.
443 std::map
<std::string
, std::vector
<std::pair
<int64_t, Record
*>>> perType
;
444 for (auto kt
: llvm::enumerate(*vec
))
445 perType
[getCType(kt
.value())].emplace_back(kt
.index(), kt
.value());
446 for (const auto &jt
: perType
) {
447 for (auto kt
: jt
.second
)
448 gen
.emitParse(kind
, *std::get
<1>(kt
));
449 gen
.emitPrint(kind
, jt
.first
, jt
.second
);
451 gen
.emitParseDispatch(kind
, *vec
);
453 SmallVector
<std::string
> types
;
454 for (const auto &it
: perType
) {
455 types
.push_back(it
.first
);
457 gen
.emitPrintDispatch(kind
, types
);
463 static mlir::GenRegistration
464 genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
465 [](const RecordKeeper
&records
, raw_ostream
&os
) {
466 return emitBCRW(records
, os
);