1 //===- TestDialectInterfaces.cpp - Test dialect interface definitions -----===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "TestDialect.h"
11 #include "mlir/Interfaces/FoldInterfaces.h"
12 #include "mlir/Reducer/ReductionPatternInterface.h"
13 #include "mlir/Transforms/InliningUtils.h"
18 //===----------------------------------------------------------------------===//
19 // TestDialect Interfaces
20 //===----------------------------------------------------------------------===//
24 /// Testing the correctness of some traits.
26 llvm::is_detected
<OpTrait::has_implicit_terminator_t
,
27 SingleBlockImplicitTerminatorOp
>::value
,
28 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
29 static_assert(OpTrait::hasSingleBlockImplicitTerminator
<
30 SingleBlockImplicitTerminatorOp
>::value
,
31 "hasSingleBlockImplicitTerminator does not match "
32 "SingleBlockImplicitTerminatorOp");
34 struct TestResourceBlobManagerInterface
35 : public ResourceBlobManagerDialectInterfaceBase
<
36 TestDialectResourceBlobHandle
> {
37 using ResourceBlobManagerDialectInterfaceBase
<
38 TestDialectResourceBlobHandle
>::ResourceBlobManagerDialectInterfaceBase
;
/// Integer tags used as the leading var-int of an encoded test dialect entity.
/// Kept as an unscoped enum so the enumerators convert implicitly to the
/// integer type used when writing and comparing encodings.
enum test_encoding {
  /// Tag for an attribute carrying the (v0, v1) parameter pair.
  k_attr_params = 0,
  /// Tag for the test i32 type.
  k_test_i32 = 99
};
45 // Test support for interacting with the Bytecode reader/writer.
46 struct TestBytecodeDialectInterface
: public BytecodeDialectInterface
{
47 using BytecodeDialectInterface::BytecodeDialectInterface
;
48 TestBytecodeDialectInterface(Dialect
*dialect
)
49 : BytecodeDialectInterface(dialect
) {}
51 LogicalResult
writeType(Type type
,
52 DialectBytecodeWriter
&writer
) const final
{
53 if (auto concreteType
= llvm::dyn_cast
<TestI32Type
>(type
)) {
54 writer
.writeVarInt(test_encoding::k_test_i32
);
60 Type
readType(DialectBytecodeReader
&reader
) const final
{
62 if (failed(reader
.readVarInt(encoding
)))
64 if (encoding
== test_encoding::k_test_i32
)
65 return TestI32Type::get(getContext());
69 LogicalResult
writeAttribute(Attribute attr
,
70 DialectBytecodeWriter
&writer
) const final
{
71 if (auto concreteAttr
= llvm::dyn_cast
<TestAttrParamsAttr
>(attr
)) {
72 writer
.writeVarInt(test_encoding::k_attr_params
);
73 writer
.writeVarInt(concreteAttr
.getV0());
74 writer
.writeVarInt(concreteAttr
.getV1());
80 Attribute
readAttribute(DialectBytecodeReader
&reader
) const final
{
81 auto versionOr
= reader
.getDialectVersion
<test::TestDialect
>();
82 // Assume current version if not available through the reader.
84 (succeeded(versionOr
))
85 ? *reinterpret_cast<const TestDialectVersion
*>(*versionOr
)
86 : TestDialectVersion();
87 if (version
.major_
< 2)
88 return readAttrOldEncoding(reader
);
89 if (version
.major_
== 2 && version
.minor_
== 0)
90 return readAttrNewEncoding(reader
);
91 // Forbid reading future versions by returning nullptr.
95 // Emit a specific version of the dialect.
96 void writeVersion(DialectBytecodeWriter
&writer
) const final
{
97 // Construct the current dialect version.
98 test::TestDialectVersion versionToEmit
;
100 // Check if a target version to emit was specified on the writer configs.
101 auto versionOr
= writer
.getDialectVersion
<test::TestDialect
>();
102 if (succeeded(versionOr
))
104 *reinterpret_cast<const test::TestDialectVersion
*>(*versionOr
);
105 writer
.writeVarInt(versionToEmit
.major_
); // major
106 writer
.writeVarInt(versionToEmit
.minor_
); // minor
109 std::unique_ptr
<DialectVersion
>
110 readVersion(DialectBytecodeReader
&reader
) const final
{
111 uint64_t major_
, minor_
;
112 if (failed(reader
.readVarInt(major_
)) || failed(reader
.readVarInt(minor_
)))
114 auto version
= std::make_unique
<TestDialectVersion
>();
115 version
->major_
= major_
;
116 version
->minor_
= minor_
;
120 LogicalResult
upgradeFromVersion(Operation
*topLevelOp
,
121 const DialectVersion
&version_
) const final
{
122 const auto &version
= static_cast<const TestDialectVersion
&>(version_
);
123 if ((version
.major_
== 2) && (version
.minor_
== 0))
125 if (version
.major_
> 2 || (version
.major_
== 2 && version
.minor_
> 0)) {
126 return topLevelOp
->emitError()
127 << "current test dialect version is 2.0, can't parse version: "
128 << version
.major_
<< "." << version
.minor_
;
130 // Prior version 2.0, the old op supported only a single attribute called
131 // "dimensions". We can perform the upgrade.
132 topLevelOp
->walk([](TestVersionedOpA op
) {
133 // Prior version 2.0, `readProperties` did not process the modifier
134 // attribute. Handle that according to the version here.
135 auto &prop
= op
.getProperties();
136 prop
.modifier
= BoolAttr::get(op
->getContext(), false);
142 Attribute
readAttrNewEncoding(DialectBytecodeReader
&reader
) const {
144 if (failed(reader
.readVarInt(encoding
)) ||
145 encoding
!= test_encoding::k_attr_params
)
147 // The new encoding has v0 first, v1 second.
149 if (failed(reader
.readVarInt(v0
)) || failed(reader
.readVarInt(v1
)))
151 return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0
),
152 static_cast<int>(v1
));
155 Attribute
readAttrOldEncoding(DialectBytecodeReader
&reader
) const {
157 if (failed(reader
.readVarInt(encoding
)) ||
158 encoding
!= test_encoding::k_attr_params
)
160 // The old encoding has v1 first, v0 second.
162 if (failed(reader
.readVarInt(v1
)) || failed(reader
.readVarInt(v0
)))
164 return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0
),
165 static_cast<int>(v1
));
169 // Test support for interacting with the AsmPrinter.
170 struct TestOpAsmInterface
: public OpAsmDialectInterface
{
171 using OpAsmDialectInterface::OpAsmDialectInterface
;
172 TestOpAsmInterface(Dialect
*dialect
, TestResourceBlobManagerInterface
&mgr
)
173 : OpAsmDialectInterface(dialect
), blobManager(mgr
) {}
175 //===------------------------------------------------------------------===//
177 //===------------------------------------------------------------------===//
179 AliasResult
getAlias(Attribute attr
, raw_ostream
&os
) const final
{
180 StringAttr strAttr
= dyn_cast
<StringAttr
>(attr
);
182 return AliasResult::NoAlias
;
184 // Check the contents of the string attribute to pick the test alias name.
186 std::optional
<StringRef
> aliasName
=
187 StringSwitch
<std::optional
<StringRef
>>(strAttr
.getValue())
188 .Case("alias_test:dot_in_name", StringRef("test.alias"))
189 .Case("alias_test:trailing_digit", StringRef("test_alias0"))
190 .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
191 .Case("alias_test:prefixed_symbol", StringRef("%test"))
192 .Case("alias_test:sanitize_conflict_a",
193 StringRef("test_alias_conflict0"))
194 .Case("alias_test:sanitize_conflict_b",
195 StringRef("test_alias_conflict0_"))
196 .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
197 .Default(std::nullopt
);
199 return AliasResult::NoAlias
;
202 return AliasResult::FinalAlias
;
205 AliasResult
getAlias(Type type
, raw_ostream
&os
) const final
{
206 if (auto tupleType
= dyn_cast
<TupleType
>(type
)) {
207 if (tupleType
.size() > 0 &&
208 llvm::all_of(tupleType
.getTypes(), [](Type elemType
) {
209 return isa
<SimpleAType
>(elemType
);
212 return AliasResult::FinalAlias
;
215 if (auto intType
= dyn_cast
<TestIntegerType
>(type
)) {
216 if (intType
.getSignedness() ==
217 TestIntegerType::SignednessSemantics::Unsigned
&&
218 intType
.getWidth() == 8) {
220 return AliasResult::FinalAlias
;
223 if (auto recType
= dyn_cast
<TestRecursiveType
>(type
)) {
224 if (recType
.getName() == "type_to_alias") {
225 // We only make alias for a specific recursive type.
227 return AliasResult::FinalAlias
;
230 if (auto recAliasType
= dyn_cast
<TestRecursiveAliasType
>(type
)) {
231 os
<< recAliasType
.getName();
232 return AliasResult::FinalAlias
;
234 return AliasResult::NoAlias
;
237 //===------------------------------------------------------------------===//
239 //===------------------------------------------------------------------===//
242 getResourceKey(const AsmDialectResourceHandle
&handle
) const override
{
243 return cast
<TestDialectResourceBlobHandle
>(handle
).getKey().str();
246 FailureOr
<AsmDialectResourceHandle
>
247 declareResource(StringRef key
) const final
{
248 return blobManager
.insert(key
);
251 LogicalResult
parseResource(AsmParsedResourceEntry
&entry
) const final
{
252 FailureOr
<AsmResourceBlob
> blob
= entry
.parseAsBlob();
256 // Update the blob for this entry.
257 blobManager
.update(entry
.getKey(), std::move(*blob
));
262 buildResources(Operation
*op
,
263 const SetVector
<AsmDialectResourceHandle
> &referencedResources
,
264 AsmResourceBuilder
&provider
) const final
{
265 blobManager
.buildResources(provider
, referencedResources
.getArrayRef());
269 /// The blob manager for the dialect.
270 TestResourceBlobManagerInterface
&blobManager
;
273 struct TestDialectFoldInterface
: public DialectFoldInterface
{
274 using DialectFoldInterface::DialectFoldInterface
;
276 /// Registered hook to check if the given region, which is attached to an
277 /// operation that is *not* isolated from above, should be used when
278 /// materializing constants.
279 bool shouldMaterializeInto(Region
*region
) const final
{
280 // If this is a one region operation, then insert into it.
281 return isa
<OneRegionOp
>(region
->getParentOp());
285 /// This class defines the interface for handling inlining in the test dialect.
287 struct TestInlinerInterface
: public DialectInlinerInterface
{
288 using DialectInlinerInterface::DialectInlinerInterface
;
290 //===--------------------------------------------------------------------===//
292 //===--------------------------------------------------------------------===//
294 bool isLegalToInline(Operation
*call
, Operation
*callable
,
295 bool wouldBeCloned
) const final
{
296 // Don't allow inlining calls that are marked `noinline`.
297 return !call
->hasAttr("noinline");
299 bool isLegalToInline(Region
*, Region
*, bool, IRMapping
&) const final
{
300 // Inlining into test dialect regions is legal.
303 bool isLegalToInline(Operation
*, Region
*, bool, IRMapping
&) const final
{
307 bool shouldAnalyzeRecursively(Operation
*op
) const final
{
308 // Analyze recursively unless this is a functional region operation, which
309 // forms a separate functional scope.
310 return !isa
<FunctionalRegionOp
>(op
);
313 //===--------------------------------------------------------------------===//
314 // Transformation Hooks
315 //===--------------------------------------------------------------------===//
317 /// Handle the given inlined terminator by redirecting all uses of the values in `valuesToRepl` to the terminator's operands.
319 void handleTerminator(Operation
*op
, ValueRange valuesToRepl
) const final
{
320 // Only handle "test.return" here.
321 auto returnOp
= dyn_cast
<TestReturnOp
>(op
);
325 // Replace the values directly with the return operands.
326 assert(returnOp
.getNumOperands() == valuesToRepl
.size());
327 for (const auto &it
: llvm::enumerate(returnOp
.getOperands()))
328 valuesToRepl
[it
.index()].replaceAllUsesWith(it
.value());
331 /// Attempt to materialize a conversion for a type mismatch between a call
332 /// from this dialect, and a callable region. This method should generate an
333 /// operation that takes 'input' as the only operand, and produces a single
334 /// result of 'resultType'. If a conversion can not be generated, nullptr
335 /// should be returned.
336 Operation
*materializeCallConversion(OpBuilder
&builder
, Value input
,
338 Location conversionLoc
) const final
{
339 // Only allow conversion for i16/i32 types.
340 if (!(resultType
.isSignlessInteger(16) ||
341 resultType
.isSignlessInteger(32)) ||
342 !(input
.getType().isSignlessInteger(16) ||
343 input
.getType().isSignlessInteger(32)))
345 return builder
.create
<TestCastOp
>(conversionLoc
, resultType
, input
);
348 Value
handleArgument(OpBuilder
&builder
, Operation
*call
, Operation
*callable
,
350 DictionaryAttr argumentAttrs
) const final
{
351 if (!argumentAttrs
.contains("test.handle_argument"))
353 return builder
.create
<TestTypeChangerOp
>(call
->getLoc(), argument
.getType(),
357 Value
handleResult(OpBuilder
&builder
, Operation
*call
, Operation
*callable
,
358 Value result
, DictionaryAttr resultAttrs
) const final
{
359 if (!resultAttrs
.contains("test.handle_result"))
361 return builder
.create
<TestTypeChangerOp
>(call
->getLoc(), result
.getType(),
365 void processInlinedCallBlocks(
367 iterator_range
<Region::iterator
> inlinedBlocks
) const final
{
368 if (!isa
<ConversionCallOp
>(call
))
371 // Set an attribute on all ops in the inlined blocks.
372 for (Block
&block
: inlinedBlocks
) {
373 block
.walk([&](Operation
*op
) {
374 op
->setAttr("inlined_conversion", UnitAttr::get(call
->getContext()));
380 struct TestReductionPatternInterface
: public DialectReductionPatternInterface
{
382 TestReductionPatternInterface(Dialect
*dialect
)
383 : DialectReductionPatternInterface(dialect
) {}
385 void populateReductionPatterns(RewritePatternSet
&patterns
) const final
{
386 populateTestReductionPatterns(patterns
);
392 void TestDialect::registerInterfaces() {
393 auto &blobInterface
= addInterface
<TestResourceBlobManagerInterface
>();
394 addInterface
<TestOpAsmInterface
>(blobInterface
);
396 addInterfaces
<TestDialectFoldInterface
, TestInlinerInterface
,
397 TestReductionPatternInterface
, TestBytecodeDialectInterface
>();