//===- TestBytecodeCallbacks.cpp - Pass to test bytecode callback hooks  --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"

#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/raw_ostream.h"

#include <cstdint>
#include <memory>
#include <optional>
#include <string>

using namespace mlir;
using namespace llvm;
26 class TestDialectVersionParser
: public cl::parser
<test::TestDialectVersion
> {
28 TestDialectVersionParser(cl::Option
&o
)
29 : cl::parser
<test::TestDialectVersion
>(o
) {}
31 bool parse(cl::Option
&o
, StringRef
/*argName*/, StringRef arg
,
32 test::TestDialectVersion
&v
) {
33 long long major
, minor
;
34 if (getAsSignedInteger(arg
.split(".").first
, 10, major
))
35 return o
.error("Invalid argument '" + arg
);
36 if (getAsSignedInteger(arg
.split(".").second
, 10, minor
))
37 return o
.error("Invalid argument '" + arg
);
38 v
= test::TestDialectVersion(major
, minor
);
39 // Returns true on error.
42 static void print(raw_ostream
&os
, const test::TestDialectVersion
&v
) {
43 os
<< v
.major_
<< "." << v
.minor_
;
47 /// This is a test pass which uses callbacks to encode attributes and types in a
49 struct TestBytecodeRoundtripPass
50 : public PassWrapper
<TestBytecodeRoundtripPass
, OperationPass
<ModuleOp
>> {
51 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeRoundtripPass
)
53 StringRef
getArgument() const final
{ return "test-bytecode-roundtrip"; }
54 StringRef
getDescription() const final
{
55 return "Test pass to implement bytecode roundtrip tests.";
57 void getDependentDialects(DialectRegistry
®istry
) const override
{
58 registry
.insert
<test::TestDialect
>();
60 TestBytecodeRoundtripPass() = default;
61 TestBytecodeRoundtripPass(const TestBytecodeRoundtripPass
&) {}
63 LogicalResult
initialize(MLIRContext
*context
) override
{
64 testDialect
= context
->getOrLoadDialect
<test::TestDialect
>();
68 void runOnOperation() override
{
70 // Tests 0-5 implement a custom roundtrip with callbacks.
72 return runTest0(getOperation());
74 return runTest1(getOperation());
76 return runTest2(getOperation());
78 return runTest3(getOperation());
80 return runTest4(getOperation());
82 return runTest5(getOperation());
84 // test-kind 6 is a plain roundtrip with downgrade/upgrade to/from
86 return runTest6(getOperation());
88 llvm_unreachable("unhandled test kind for TestBytecodeCallbacks pass");
92 mlir::Pass::Option
<test::TestDialectVersion
, TestDialectVersionParser
>
93 targetVersion
{*this, "test-dialect-version",
95 "Specifies the test dialect version to emit and parse"),
96 cl::init(test::TestDialectVersion())};
98 mlir::Pass::Option
<int> testKind
{
99 *this, "test-kind", llvm::cl::desc("Specifies the test kind to execute"),
103 void doRoundtripWithConfigs(Operation
*op
,
104 const BytecodeWriterConfig
&writeConfig
,
105 const ParserConfig
&parseConfig
) {
106 std::string bytecode
;
107 llvm::raw_string_ostream
os(bytecode
);
108 if (failed(writeBytecodeToFile(op
, os
, writeConfig
))) {
109 op
->emitError() << "failed to write bytecode\n";
113 auto newModuleOp
= parseSourceString(StringRef(bytecode
), parseConfig
);
114 if (!newModuleOp
.get()) {
115 op
->emitError() << "failed to read bytecode\n";
119 // Print the module to the output stream, so that we can filecheck the
121 newModuleOp
->print(llvm::outs());
124 // Test0: let's assume that versions older than 2.0 were relying on a special
125 // integer attribute of a deprecated dialect called "funky". Assume that its
126 // encoding was made by two varInts, the first was the ID (999) and the second
127 // contained width and signedness info. We can emit it using a callback
128 // writing a custom encoding for the "funky" dialect group, and parse it back
129 // with a custom parser reading the same encoding in the same dialect group.
130 // Note that the ID 999 does not correspond to a valid integer type in the
131 // current encodings of builtin types.
132 void runTest0(Operation
*op
) {
133 auto newCtx
= std::make_shared
<MLIRContext
>();
134 test::TestDialectVersion targetEmissionVersion
= targetVersion
;
135 BytecodeWriterConfig writeConfig
;
136 // Set the emission version for the test dialect.
137 writeConfig
.setDialectVersion
<test::TestDialect
>(
138 std::make_unique
<test::TestDialectVersion
>(targetEmissionVersion
));
139 writeConfig
.attachTypeCallback(
140 [&](Type entryValue
, std::optional
<StringRef
> &dialectGroupName
,
141 DialectBytecodeWriter
&writer
) -> LogicalResult
{
142 // Do not override anything if version greater than 2.0.
143 auto versionOr
= writer
.getDialectVersion
<test::TestDialect
>();
144 assert(succeeded(versionOr
) && "expected reader to be able to access "
145 "the version for test dialect");
146 const auto *version
=
147 reinterpret_cast<const test::TestDialectVersion
*>(*versionOr
);
148 if (version
->major_
>= 2)
151 // For version less than 2.0, override the encoding of IntegerType.
152 if (auto type
= llvm::dyn_cast
<IntegerType
>(entryValue
)) {
153 llvm::outs() << "Overriding IntegerType encoding...\n";
154 dialectGroupName
= StringLiteral("funky");
155 writer
.writeVarInt(/* IntegerType */ 999);
156 writer
.writeVarInt(type
.getWidth() << 2 | type
.getSignedness());
161 newCtx
->appendDialectRegistry(op
->getContext()->getDialectRegistry());
162 newCtx
->allowUnregisteredDialects();
163 ParserConfig
parseConfig(newCtx
.get(), /*verifyAfterParse=*/true);
164 parseConfig
.getBytecodeReaderConfig().attachTypeCallback(
165 [&](DialectBytecodeReader
&reader
, StringRef dialectName
,
166 Type
&entry
) -> LogicalResult
{
167 // Get test dialect version from the version map.
168 auto versionOr
= reader
.getDialectVersion
<test::TestDialect
>();
169 assert(succeeded(versionOr
) && "expected reader to be able to access "
170 "the version for test dialect");
171 const auto *version
=
172 reinterpret_cast<const test::TestDialectVersion
*>(*versionOr
);
173 if (version
->major_
>= 2)
176 // `dialectName` is the name of the group we have the opportunity to
177 // override. In this case, override only the dialect group "funky",
178 // for which does not exist in memory.
179 if (dialectName
!= StringLiteral("funky"))
183 if (failed(reader
.readVarInt(encoding
)) || encoding
!= 999)
185 llvm::outs() << "Overriding parsing of IntegerType encoding...\n";
186 uint64_t widthAndSignedness
, width
;
187 IntegerType::SignednessSemantics signedness
;
188 if (succeeded(reader
.readVarInt(widthAndSignedness
)) &&
189 ((width
= widthAndSignedness
>> 2), true) &&
190 ((signedness
= static_cast<IntegerType::SignednessSemantics
>(
191 widthAndSignedness
& 0x3)),
193 entry
= IntegerType::get(reader
.getContext(), width
, signedness
);
194 // Return nullopt to fall through the rest of the parsing code path.
197 doRoundtripWithConfigs(op
, writeConfig
, parseConfig
);
200 // Test1: When writing bytecode, we override the encoding of TestI32Type with
201 // the encoding of builtin IntegerType. We can natively parse this without
202 // the use of a callback, relying on the existing builtin reader mechanism.
203 void runTest1(Operation
*op
) {
204 auto *builtin
= op
->getContext()->getLoadedDialect
<mlir::BuiltinDialect
>();
205 BytecodeDialectInterface
*iface
=
206 builtin
->getRegisteredInterface
<BytecodeDialectInterface
>();
207 BytecodeWriterConfig writeConfig
;
208 writeConfig
.attachTypeCallback(
209 [&](Type entryValue
, std::optional
<StringRef
> &dialectGroupName
,
210 DialectBytecodeWriter
&writer
) -> LogicalResult
{
211 // Emit TestIntegerType using the builtin dialect encoding.
212 if (llvm::isa
<test::TestI32Type
>(entryValue
)) {
213 llvm::outs() << "Overriding TestI32Type encoding...\n";
214 auto builtinI32Type
=
215 IntegerType::get(op
->getContext(), 32,
216 IntegerType::SignednessSemantics::Signless
);
217 // Specify that this type will need to be written as part of the
218 // builtin group. This will override the default dialect group of
219 // the attribute (test).
220 dialectGroupName
= StringLiteral("builtin");
221 if (succeeded(iface
->writeType(builtinI32Type
, writer
)))
226 // We natively parse the attribute as a builtin, so no callback needed.
227 ParserConfig
parseConfig(op
->getContext(), /*verifyAfterParse=*/true);
228 doRoundtripWithConfigs(op
, writeConfig
, parseConfig
);
231 // Test2: When writing bytecode, we write standard builtin IntegerTypes. At
232 // parsing, we use the encoding of IntegerType to intercept all i32. Then,
233 // instead of creating i32s, we assemble TestI32Type and return it.
234 void runTest2(Operation
*op
) {
235 auto *builtin
= op
->getContext()->getLoadedDialect
<mlir::BuiltinDialect
>();
236 BytecodeDialectInterface
*iface
=
237 builtin
->getRegisteredInterface
<BytecodeDialectInterface
>();
238 BytecodeWriterConfig writeConfig
;
239 ParserConfig
parseConfig(op
->getContext(), /*verifyAfterParse=*/true);
240 parseConfig
.getBytecodeReaderConfig().attachTypeCallback(
241 [&](DialectBytecodeReader
&reader
, StringRef dialectName
,
242 Type
&entry
) -> LogicalResult
{
243 if (dialectName
!= StringLiteral("builtin"))
245 Type builtinAttr
= iface
->readType(reader
);
246 if (auto integerType
=
247 llvm::dyn_cast_or_null
<IntegerType
>(builtinAttr
)) {
248 if (integerType
.getWidth() == 32 && integerType
.isSignless()) {
249 llvm::outs() << "Overriding parsing of TestI32Type encoding...\n";
250 entry
= test::TestI32Type::get(reader
.getContext());
255 doRoundtripWithConfigs(op
, writeConfig
, parseConfig
);
258 // Test3: When writing bytecode, we override the encoding of
259 // TestAttrParamsAttr with the encoding of builtin DenseIntElementsAttr. We
260 // can natively parse this without the use of a callback, relying on the
261 // existing builtin reader mechanism.
262 void runTest3(Operation
*op
) {
263 auto *builtin
= op
->getContext()->getLoadedDialect
<mlir::BuiltinDialect
>();
264 BytecodeDialectInterface
*iface
=
265 builtin
->getRegisteredInterface
<BytecodeDialectInterface
>();
266 auto i32Type
= IntegerType::get(op
->getContext(), 32,
267 IntegerType::SignednessSemantics::Signless
);
268 BytecodeWriterConfig writeConfig
;
269 writeConfig
.attachAttributeCallback(
270 [&](Attribute entryValue
, std::optional
<StringRef
> &dialectGroupName
,
271 DialectBytecodeWriter
&writer
) -> LogicalResult
{
272 // Emit TestIntegerType using the builtin dialect encoding.
273 if (auto testParamAttrs
=
274 llvm::dyn_cast
<test::TestAttrParamsAttr
>(entryValue
)) {
275 llvm::outs() << "Overriding TestAttrParamsAttr encoding...\n";
276 // Specify that this attribute will need to be written as part of
277 // the builtin group. This will override the default dialect group
278 // of the attribute (test).
279 dialectGroupName
= StringLiteral("builtin");
280 auto denseAttr
= DenseIntElementsAttr::get(
281 RankedTensorType::get({2}, i32Type
),
282 {testParamAttrs
.getV0(), testParamAttrs
.getV1()});
283 if (succeeded(iface
->writeAttribute(denseAttr
, writer
)))
288 // We natively parse the attribute as a builtin, so no callback needed.
289 ParserConfig
parseConfig(op
->getContext(), /*verifyAfterParse=*/false);
290 doRoundtripWithConfigs(op
, writeConfig
, parseConfig
);
293 // Test4: When writing bytecode, we write standard builtin
294 // DenseIntElementsAttr. At parsing, we use the encoding of
295 // DenseIntElementsAttr to intercept all ElementsAttr that have shaped type of
296 // <2xi32>. Instead of assembling a DenseIntElementsAttr, we assemble
297 // TestAttrParamsAttr and return it.
298 void runTest4(Operation
*op
) {
299 auto *builtin
= op
->getContext()->getLoadedDialect
<mlir::BuiltinDialect
>();
300 BytecodeDialectInterface
*iface
=
301 builtin
->getRegisteredInterface
<BytecodeDialectInterface
>();
302 auto i32Type
= IntegerType::get(op
->getContext(), 32,
303 IntegerType::SignednessSemantics::Signless
);
304 BytecodeWriterConfig writeConfig
;
305 ParserConfig
parseConfig(op
->getContext(), /*verifyAfterParse=*/false);
306 parseConfig
.getBytecodeReaderConfig().attachAttributeCallback(
307 [&](DialectBytecodeReader
&reader
, StringRef dialectName
,
308 Attribute
&entry
) -> LogicalResult
{
309 // Override only the case where the return type of the builtin reader
310 // is an i32 and fall through on all the other cases, since we want to
311 // still use TestDialect normal codepath to parse the other types.
312 Attribute builtinAttr
= iface
->readAttribute(reader
);
314 llvm::dyn_cast_or_null
<DenseIntElementsAttr
>(builtinAttr
)) {
315 if (denseAttr
.getType().getShape() == ArrayRef
<int64_t>(2) &&
316 denseAttr
.getElementType() == i32Type
) {
318 << "Overriding parsing of TestAttrParamsAttr encoding...\n";
319 int v0
= denseAttr
.getValues
<IntegerAttr
>()[0].getInt();
320 int v1
= denseAttr
.getValues
<IntegerAttr
>()[1].getInt();
322 test::TestAttrParamsAttr::get(reader
.getContext(), v0
, v1
);
327 doRoundtripWithConfigs(op
, writeConfig
, parseConfig
);
330 // Test5: When writing bytecode, we want TestDialect to use nothing else than
331 // the builtin types and attributes and take full control of the encoding,
332 // returning failure if any type or attribute is not part of builtin.
333 void runTest5(Operation
*op
) {
334 auto *builtin
= op
->getContext()->getLoadedDialect
<mlir::BuiltinDialect
>();
335 BytecodeDialectInterface
*iface
=
336 builtin
->getRegisteredInterface
<BytecodeDialectInterface
>();
337 BytecodeWriterConfig writeConfig
;
338 writeConfig
.attachAttributeCallback(
339 [&](Attribute attr
, std::optional
<StringRef
> &dialectGroupName
,
340 DialectBytecodeWriter
&writer
) -> LogicalResult
{
341 return iface
->writeAttribute(attr
, writer
);
343 writeConfig
.attachTypeCallback(
344 [&](Type type
, std::optional
<StringRef
> &dialectGroupName
,
345 DialectBytecodeWriter
&writer
) -> LogicalResult
{
346 return iface
->writeType(type
, writer
);
348 ParserConfig
parseConfig(op
->getContext(), /*verifyAfterParse=*/false);
349 parseConfig
.getBytecodeReaderConfig().attachAttributeCallback(
350 [&](DialectBytecodeReader
&reader
, StringRef dialectName
,
351 Attribute
&entry
) -> LogicalResult
{
352 Attribute builtinAttr
= iface
->readAttribute(reader
);
358 parseConfig
.getBytecodeReaderConfig().attachTypeCallback(
359 [&](DialectBytecodeReader
&reader
, StringRef dialectName
,
360 Type
&entry
) -> LogicalResult
{
361 Type builtinType
= iface
->readType(reader
);
368 doRoundtripWithConfigs(op
, writeConfig
, parseConfig
);
371 LogicalResult
downgradeToVersion(Operation
*op
,
372 const test::TestDialectVersion
&version
) {
373 if ((version
.major_
== 2) && (version
.minor_
== 0))
375 if (version
.major_
> 2 || (version
.major_
== 2 && version
.minor_
> 0)) {
376 return op
->emitError() << "current test dialect version is 2.0, "
377 "can't downgrade to version: "
378 << version
.major_
<< "." << version
.minor_
;
380 // Prior version 2.0, the old op supported only a single attribute called
381 // "dimensions". We need to check that the modifier is false, otherwise we
382 // can't do the downgrade.
383 auto status
= op
->walk([&](test::TestVersionedOpA op
) {
384 auto &prop
= op
.getProperties();
385 if (prop
.modifier
.getValue()) {
386 op
->emitOpError() << "cannot downgrade to version " << version
.major_
387 << "." << version
.minor_
388 << " since the modifier is not compatible";
389 return WalkResult::interrupt();
391 llvm::outs() << "downgrading op...\n";
392 return WalkResult::advance();
394 return failure(status
.wasInterrupted());
397 // Test6: Downgrade IR to `targetVersion`, write to bytecode. Then, read and
398 // upgrade IR when back in memory. The module is expected to be unmodified at
399 // the end of the function.
400 void runTest6(Operation
*op
) {
401 test::TestDialectVersion targetEmissionVersion
= targetVersion
;
403 // Downgrade IR constructs before writing the IR to bytecode.
404 auto status
= downgradeToVersion(op
, targetEmissionVersion
);
405 assert(succeeded(status
) && "expected the downgrade to succeed");
408 BytecodeWriterConfig writeConfig
;
409 writeConfig
.setDialectVersion
<test::TestDialect
>(
410 std::make_unique
<test::TestDialectVersion
>(targetEmissionVersion
));
411 ParserConfig
parseConfig(op
->getContext(), /*verifyAfterParse=*/true);
412 doRoundtripWithConfigs(op
, writeConfig
, parseConfig
);
415 test::TestDialect
*testDialect
;
420 void registerTestBytecodeRoundtripPasses() {
421 PassRegistration
<TestBytecodeRoundtripPass
>();