[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / unittests / IR / InterfaceAttachmentTest.cpp
blobb6066dd5685dc616ba36158f55c83d82311e57f0
1 //===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This implements the tests for attaching interfaces to attributes, types and
// operations without having to specify them on the attribute, type or
// operation class directly.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/IR/BuiltinAttributes.h"
15 #include "mlir/IR/BuiltinDialect.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "gtest/gtest.h"
20 #include "../../test/lib/Dialect/Test/TestAttributes.h"
21 #include "../../test/lib/Dialect/Test/TestDialect.h"
22 #include "../../test/lib/Dialect/Test/TestOps.h"
23 #include "../../test/lib/Dialect/Test/TestTypes.h"
24 #include "mlir/IR/OwningOpRef.h"
26 using namespace mlir;
27 using namespace test;
29 namespace {
31 /// External interface model for the integer type. Only provides non-default
32 /// methods.
33 struct Model
34 : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> {
35 unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
36 return type.getIntOrFloatBitWidth() + arg;
39 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
42 /// External interface model for the float type. Provides non-deafult and
43 /// overrides default methods.
44 struct OverridingModel
45 : public TestExternalTypeInterface::ExternalModel<OverridingModel,
46 FloatType> {
47 unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
48 return type.getIntOrFloatBitWidth() + arg;
51 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
53 unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const {
54 return 128;
57 static unsigned staticGetArgument(unsigned arg) { return 420; }
60 TEST(InterfaceAttachment, Type) {
61 MLIRContext context;
63 // Check that the type has no interface.
64 IntegerType i8 = IntegerType::get(&context, 8);
65 ASSERT_FALSE(isa<TestExternalTypeInterface>(i8));
67 // Attach an interface and check that the type now has the interface.
68 IntegerType::attachInterface<Model>(context);
69 TestExternalTypeInterface iface = dyn_cast<TestExternalTypeInterface>(i8);
70 ASSERT_TRUE(iface != nullptr);
71 EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u);
72 EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u);
73 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u);
74 EXPECT_EQ(iface.staticGetArgument(17), 17u);
76 // Same, but with the default implementation overridden.
77 FloatType flt = Float32Type::get(&context);
78 ASSERT_FALSE(isa<TestExternalTypeInterface>(flt));
79 Float32Type::attachInterface<OverridingModel>(context);
80 iface = dyn_cast<TestExternalTypeInterface>(flt);
81 ASSERT_TRUE(iface != nullptr);
82 EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u);
83 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u);
84 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u);
85 EXPECT_EQ(iface.staticGetArgument(17), 420u);
87 // Other contexts shouldn't have the attribute attached.
88 MLIRContext other;
89 IntegerType i8other = IntegerType::get(&other, 8);
90 EXPECT_FALSE(isa<TestExternalTypeInterface>(i8other));
93 /// External interface model for the test type from the test dialect.
94 struct TestTypeModel
95 : public TestExternalTypeInterface::ExternalModel<TestTypeModel,
96 test::TestType> {
97 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; }
99 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; }
102 TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
103 // Put the interface in the registry.
104 DialectRegistry registry;
105 registry.insert<test::TestDialect>();
106 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
107 test::TestType::attachInterface<TestTypeModel>(*ctx);
110 // Check that when a context is constructed with the given registry, the type
111 // interface gets registered.
112 MLIRContext context(registry);
113 context.loadDialect<test::TestDialect>();
114 test::TestType testType = test::TestType::get(&context);
115 auto iface = dyn_cast<TestExternalTypeInterface>(testType);
116 ASSERT_TRUE(iface != nullptr);
117 EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u);
118 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u);
121 TEST(InterfaceAttachment, TypeDelayedContextAppend) {
122 // Put the interface in the registry.
123 DialectRegistry registry;
124 registry.insert<test::TestDialect>();
125 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
126 test::TestType::attachInterface<TestTypeModel>(*ctx);
129 // Check that when the registry gets appended to the context, the interface
130 // becomes available for objects in loaded dialects.
131 MLIRContext context;
132 context.loadDialect<test::TestDialect>();
133 test::TestType testType = test::TestType::get(&context);
134 EXPECT_FALSE(isa<TestExternalTypeInterface>(testType));
135 context.appendDialectRegistry(registry);
136 EXPECT_TRUE(isa<TestExternalTypeInterface>(testType));
139 TEST(InterfaceAttachment, RepeatedRegistration) {
140 DialectRegistry registry;
141 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
142 IntegerType::attachInterface<Model>(*ctx);
144 MLIRContext context(registry);
146 // Should't fail on repeated registration through the dialect registry.
147 context.appendDialectRegistry(registry);
150 TEST(InterfaceAttachment, TypeBuiltinDelayed) {
151 // Builtin dialect needs to registration or loading, but delayed interface
152 // registration must still work.
153 DialectRegistry registry;
154 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
155 IntegerType::attachInterface<Model>(*ctx);
158 MLIRContext context(registry);
159 IntegerType i16 = IntegerType::get(&context, 16);
160 EXPECT_TRUE(isa<TestExternalTypeInterface>(i16));
162 MLIRContext initiallyEmpty;
163 IntegerType i32 = IntegerType::get(&initiallyEmpty, 32);
164 EXPECT_FALSE(isa<TestExternalTypeInterface>(i32));
165 initiallyEmpty.appendDialectRegistry(registry);
166 EXPECT_TRUE(isa<TestExternalTypeInterface>(i32));
169 /// The interface provides a default implementation that expects
170 /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this
171 /// just derives from the ExternalModel.
172 struct TestExternalFallbackTypeIntegerModel
173 : public TestExternalFallbackTypeInterface::ExternalModel<
174 TestExternalFallbackTypeIntegerModel, IntegerType> {};
176 /// The interface provides a default implementation that expects
177 /// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use
178 /// FallbackModel instead to override this and make sure the code still compiles
179 /// because we never instantiate the ExternalModel class template with a
180 /// template argument that would have led to compilation failures.
181 struct TestExternalFallbackTypeVectorModel
182 : public TestExternalFallbackTypeInterface::FallbackModel<
183 TestExternalFallbackTypeVectorModel> {
184 unsigned getBitwidth(Type type) const {
185 IntegerType elementType =
186 dyn_cast_or_null<IntegerType>(cast<VectorType>(type).getElementType());
187 return elementType ? elementType.getWidth() : 0;
191 TEST(InterfaceAttachment, Fallback) {
192 MLIRContext context;
194 // Just check that we can attach the interface.
195 IntegerType i8 = IntegerType::get(&context, 8);
196 ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(i8));
197 IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context);
198 ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(i8));
200 // Call the method so it is guaranteed not to be instantiated.
201 VectorType vec = VectorType::get({42}, i8);
202 ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(vec));
203 VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context);
204 ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(vec));
205 EXPECT_EQ(cast<TestExternalFallbackTypeInterface>(vec).getBitwidth(), 8u);
208 /// External model for attribute interfaces.
209 struct TestExternalIntegerAttrModel
210 : public TestExternalAttrInterface::ExternalModel<
211 TestExternalIntegerAttrModel, IntegerAttr> {
212 const Dialect *getDialectPtr(Attribute attr) const {
213 return &cast<IntegerAttr>(attr).getDialect();
216 static int getSomeNumber() { return 42; }
219 TEST(InterfaceAttachment, Attribute) {
220 MLIRContext context;
222 // Attribute interfaces use the exact same mechanism as types, so just check
223 // that the basics work for attributes.
224 IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
225 ASSERT_FALSE(isa<TestExternalAttrInterface>(attr));
226 IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context);
227 auto iface = dyn_cast<TestExternalAttrInterface>(attr);
228 ASSERT_TRUE(iface != nullptr);
229 EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
230 EXPECT_EQ(iface.getSomeNumber(), 42);
233 /// External model for an interface attachable to a non-builtin attribute.
234 struct TestExternalSimpleAAttrModel
235 : public TestExternalAttrInterface::ExternalModel<
236 TestExternalSimpleAAttrModel, test::SimpleAAttr> {
237 const Dialect *getDialectPtr(Attribute attr) const {
238 return &attr.getDialect();
241 static int getSomeNumber() { return 21; }
244 TEST(InterfaceAttachmentTest, AttributeDelayed) {
245 // Attribute interfaces use the exact same mechanism as types, so just check
246 // that the delayed registration work for attributes.
247 DialectRegistry registry;
248 registry.insert<test::TestDialect>();
249 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
250 test::SimpleAAttr::attachInterface<TestExternalSimpleAAttrModel>(*ctx);
253 MLIRContext context(registry);
254 context.loadDialect<test::TestDialect>();
255 auto attr = test::SimpleAAttr::get(&context);
256 EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
258 MLIRContext initiallyEmpty;
259 initiallyEmpty.loadDialect<test::TestDialect>();
260 attr = test::SimpleAAttr::get(&initiallyEmpty);
261 EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
262 initiallyEmpty.appendDialectRegistry(registry);
263 EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
266 /// External interface model for the module operation. Only provides non-default
267 /// methods.
268 struct TestExternalOpModel
269 : public TestExternalOpInterface::ExternalModel<TestExternalOpModel,
270 ModuleOp> {
271 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
272 return op->getName().getStringRef().size() + arg;
275 static unsigned getNameLengthPlusArgTwice(unsigned arg) {
276 return ModuleOp::getOperationName().size() + 2 * arg;
280 /// External interface model for the func operation. Provides non-deafult and
281 /// overrides default methods.
282 struct TestExternalOpOverridingModel
283 : public TestExternalOpInterface::FallbackModel<
284 TestExternalOpOverridingModel> {
285 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
286 return op->getName().getStringRef().size() + arg;
289 static unsigned getNameLengthPlusArgTwice(unsigned arg) {
290 return UnrealizedConversionCastOp::getOperationName().size() + 2 * arg;
293 unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const {
294 return 42;
297 static unsigned getNameLengthMinusArg(unsigned arg) { return 21; }
300 TEST(InterfaceAttachment, Operation) {
301 MLIRContext context;
302 OpBuilder builder(&context);
304 // Initially, the operation doesn't have the interface.
305 OwningOpRef<ModuleOp> moduleOp =
306 builder.create<ModuleOp>(UnknownLoc::get(&context));
307 ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp->getOperation()));
309 // We can attach an external interface and now the operaiton has it.
310 ModuleOp::attachInterface<TestExternalOpModel>(context);
311 auto iface = dyn_cast<TestExternalOpInterface>(moduleOp->getOperation());
312 ASSERT_TRUE(iface != nullptr);
313 EXPECT_EQ(iface.getNameLengthPlusArg(10), 24u);
314 EXPECT_EQ(iface.getNameLengthTimesArg(3), 42u);
315 EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 50u);
316 EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u);
318 // Default implementation can be overridden.
319 OwningOpRef<UnrealizedConversionCastOp> castOp =
320 builder.create<UnrealizedConversionCastOp>(UnknownLoc::get(&context),
321 TypeRange(), ValueRange());
322 ASSERT_FALSE(isa<TestExternalOpInterface>(castOp->getOperation()));
323 UnrealizedConversionCastOp::attachInterface<TestExternalOpOverridingModel>(
324 context);
325 iface = dyn_cast<TestExternalOpInterface>(castOp->getOperation());
326 ASSERT_TRUE(iface != nullptr);
327 EXPECT_EQ(iface.getNameLengthPlusArg(10), 44u);
328 EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u);
329 EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 50u);
330 EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u);
332 // Another context doesn't have the interfaces registered.
333 MLIRContext other;
334 OwningOpRef<ModuleOp> otherModuleOp =
335 ModuleOp::create(UnknownLoc::get(&other));
336 ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp->getOperation()));
339 template <class ConcreteOp>
340 struct TestExternalTestOpModel
341 : public TestExternalOpInterface::ExternalModel<
342 TestExternalTestOpModel<ConcreteOp>, ConcreteOp> {
343 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
344 return op->getName().getStringRef().size() + arg;
347 static unsigned getNameLengthPlusArgTwice(unsigned arg) {
348 return ConcreteOp::getOperationName().size() + 2 * arg;
352 TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
353 DialectRegistry registry;
354 registry.insert<test::TestDialect>();
355 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
356 ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
358 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
359 test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
360 test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
363 // Construct the context directly from a registry. The interfaces are
364 // expected to be readily available on operations.
365 MLIRContext context(registry);
366 context.loadDialect<test::TestDialect>();
368 OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
369 OpBuilder builder(module->getBody(), module->getBody()->begin());
370 auto opJ =
371 builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
372 auto opH =
373 builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
374 auto opI =
375 builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
377 EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation()));
378 EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
379 EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
380 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
383 TEST(InterfaceAttachment, OperationDelayedContextAppend) {
384 DialectRegistry registry;
385 registry.insert<test::TestDialect>();
386 registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
387 ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
389 registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
390 test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
391 test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
394 // Construct the context, create ops, and only then append the registry. The
395 // interfaces are expected to be available after appending the registry.
396 MLIRContext context;
397 context.loadDialect<test::TestDialect>();
399 OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
400 OpBuilder builder(module->getBody(), module->getBody()->begin());
401 auto opJ =
402 builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
403 auto opH =
404 builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
405 auto opI =
406 builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
408 EXPECT_FALSE(isa<TestExternalOpInterface>(module->getOperation()));
409 EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation()));
410 EXPECT_FALSE(isa<TestExternalOpInterface>(opH.getOperation()));
411 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
413 context.appendDialectRegistry(registry);
415 EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation()));
416 EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
417 EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
418 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
421 TEST(InterfaceAttachmentTest, PromisedInterfaces) {
422 // Attribute interfaces use the exact same mechanism as types, so just check
423 // that the promise mechanism works for attributes.
424 MLIRContext context;
425 auto *testDialect = context.getOrLoadDialect<test::TestDialect>();
426 auto attr = test::SimpleAAttr::get(&context);
428 // `SimpleAAttr` doesn't implement nor promises the
429 // `TestExternalAttrInterface` interface.
430 EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
431 EXPECT_FALSE(
432 attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
434 // Add a promise `TestExternalAttrInterface`.
435 testDialect->declarePromisedInterface<TestExternalAttrInterface,
436 test::SimpleAAttr>();
437 EXPECT_TRUE(
438 attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
440 // Attach the interface.
441 test::SimpleAAttr::attachInterface<TestExternalAttrInterface>(context);
442 EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
443 EXPECT_TRUE(
444 attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
447 } // namespace