//===- DialectTest.cpp - Dialect unit tests -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
#include "gtest/gtest.h"

#include <memory>

using namespace mlir;
using namespace mlir::detail;
17 struct TestDialect
: public Dialect
{
18 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect
)
20 static StringRef
getDialectNamespace() { return "test"; };
21 TestDialect(MLIRContext
*context
)
22 : Dialect(getDialectNamespace(), context
, TypeID::get
<TestDialect
>()) {}
24 struct AnotherTestDialect
: public Dialect
{
25 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnotherTestDialect
)
27 static StringRef
getDialectNamespace() { return "test"; };
28 AnotherTestDialect(MLIRContext
*context
)
29 : Dialect(getDialectNamespace(), context
,
30 TypeID::get
<AnotherTestDialect
>()) {}
33 TEST(DialectDeathTest
, MultipleDialectsWithSameNamespace
) {
36 // Registering a dialect with the same namespace twice should result in a
38 context
.loadDialect
<TestDialect
>();
39 ASSERT_DEATH(context
.loadDialect
<AnotherTestDialect
>(), "");
42 struct SecondTestDialect
: public Dialect
{
43 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialect
)
45 static StringRef
getDialectNamespace() { return "test2"; }
46 SecondTestDialect(MLIRContext
*context
)
47 : Dialect(getDialectNamespace(), context
,
48 TypeID::get
<SecondTestDialect
>()) {}
51 struct TestDialectInterfaceBase
52 : public DialectInterface::Base
<TestDialectInterfaceBase
> {
53 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterfaceBase
)
55 TestDialectInterfaceBase(Dialect
*dialect
) : Base(dialect
) {}
56 virtual int function() const { return 42; }
59 struct TestDialectInterface
: public TestDialectInterfaceBase
{
60 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterface
)
62 using TestDialectInterfaceBase::TestDialectInterfaceBase
;
63 int function() const final
{ return 56; }
66 struct SecondTestDialectInterface
: public TestDialectInterfaceBase
{
67 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialectInterface
)
69 using TestDialectInterfaceBase::TestDialectInterfaceBase
;
70 int function() const final
{ return 78; }
73 TEST(Dialect
, DelayedInterfaceRegistration
) {
74 DialectRegistry registry
;
75 registry
.insert
<TestDialect
, SecondTestDialect
>();
77 // Delayed registration of an interface for TestDialect.
78 registry
.addExtension(+[](MLIRContext
*ctx
, TestDialect
*dialect
) {
79 dialect
->addInterfaces
<TestDialectInterface
>();
82 MLIRContext
context(registry
);
84 // Load the TestDialect and check that the interface got registered for it.
85 Dialect
*testDialect
= context
.getOrLoadDialect
<TestDialect
>();
86 ASSERT_TRUE(testDialect
!= nullptr);
87 auto *testDialectInterface
= dyn_cast
<TestDialectInterfaceBase
>(testDialect
);
88 EXPECT_TRUE(testDialectInterface
!= nullptr);
90 // Load the SecondTestDialect and check that the interface is not registered
92 Dialect
*secondTestDialect
= context
.getOrLoadDialect
<SecondTestDialect
>();
93 ASSERT_TRUE(secondTestDialect
!= nullptr);
94 auto *secondTestDialectInterface
=
95 dyn_cast
<SecondTestDialectInterface
>(secondTestDialect
);
96 EXPECT_TRUE(secondTestDialectInterface
== nullptr);
98 // Use the same mechanism as for delayed registration but for an already
99 // loaded dialect and check that the interface is now registered.
100 DialectRegistry secondRegistry
;
101 secondRegistry
.insert
<SecondTestDialect
>();
102 secondRegistry
.addExtension(
103 +[](MLIRContext
*ctx
, SecondTestDialect
*dialect
) {
104 dialect
->addInterfaces
<SecondTestDialectInterface
>();
106 context
.appendDialectRegistry(secondRegistry
);
107 secondTestDialectInterface
=
108 dyn_cast
<SecondTestDialectInterface
>(secondTestDialect
);
109 EXPECT_TRUE(secondTestDialectInterface
!= nullptr);
112 TEST(Dialect
, RepeatedDelayedRegistration
) {
113 // Set up the delayed registration.
114 DialectRegistry registry
;
115 registry
.insert
<TestDialect
>();
116 registry
.addExtension(+[](MLIRContext
*ctx
, TestDialect
*dialect
) {
117 dialect
->addInterfaces
<TestDialectInterface
>();
119 MLIRContext
context(registry
);
121 // Load the TestDialect and check that the interface got registered for it.
122 Dialect
*testDialect
= context
.getOrLoadDialect
<TestDialect
>();
123 ASSERT_TRUE(testDialect
!= nullptr);
124 auto *testDialectInterface
= dyn_cast
<TestDialectInterfaceBase
>(testDialect
);
125 EXPECT_TRUE(testDialectInterface
!= nullptr);
127 // Try adding the same dialect interface again and check that we don't crash
128 // on repeated interface registration.
129 DialectRegistry secondRegistry
;
130 secondRegistry
.insert
<TestDialect
>();
131 secondRegistry
.addExtension(+[](MLIRContext
*ctx
, TestDialect
*dialect
) {
132 dialect
->addInterfaces
<TestDialectInterface
>();
134 context
.appendDialectRegistry(secondRegistry
);
135 testDialectInterface
= dyn_cast
<TestDialectInterfaceBase
>(testDialect
);
136 EXPECT_TRUE(testDialectInterface
!= nullptr);
140 /// A dummy extension that increases a counter when being applied and
141 /// recursively adds additional extensions.
142 struct DummyExtension
: DialectExtension
<DummyExtension
, TestDialect
> {
143 DummyExtension(int *counter
, int numRecursive
)
144 : DialectExtension(), counter(counter
), numRecursive(numRecursive
) {}
146 void apply(MLIRContext
*ctx
, TestDialect
*dialect
) const final
{
148 DialectRegistry nestedRegistry
;
149 for (int i
= 0; i
< numRecursive
; ++i
)
150 nestedRegistry
.addExtension(
151 std::make_unique
<DummyExtension
>(counter
, /*numRecursive=*/0));
152 // Adding additional extensions may trigger a reallocation of the
153 // `extensions` vector in the dialect registry.
154 ctx
->appendDialectRegistry(nestedRegistry
);
163 TEST(Dialect
, NestedDialectExtension
) {
164 DialectRegistry registry
;
165 registry
.insert
<TestDialect
>();
167 // Add an extension that adds 100 more extensions.
169 registry
.addExtension(std::make_unique
<DummyExtension
>(&counter1
, 100));
170 // Add one more extension. This should not crash.
172 registry
.addExtension(std::make_unique
<DummyExtension
>(&counter2
, 0));
174 // Load dialect and apply extensions.
175 MLIRContext
context(registry
);
176 Dialect
*testDialect
= context
.getOrLoadDialect
<TestDialect
>();
177 ASSERT_TRUE(testDialect
!= nullptr);
179 // Extensions may be applied multiple times. Make sure that each expected
180 // extension was applied at least once.
181 EXPECT_GE(counter1
, 101);
182 EXPECT_GE(counter2
, 1);