1 //===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===//
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/ExtensibleDialect.h"
10 #include "mlir/IR/AttributeSupport.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "mlir/IR/OperationSupport.h"
13 #include "mlir/IR/StorageUniquerSupport.h"
17 //===----------------------------------------------------------------------===//
18 // Dynamic types and attributes shared functions
19 //===----------------------------------------------------------------------===//
21 /// Default parser for dynamic attribute or type parameters.
22 /// Parse in the format '(<>)?' or '<attr (,attr)*>'.
24 typeOrAttrParser(AsmParser
&parser
, SmallVectorImpl
<Attribute
> &parsedParams
) {
26 if (parser
.parseOptionalLess() || !parser
.parseOptionalGreater())
30 if (parser
.parseAttribute(attr
))
32 parsedParams
.push_back(attr
);
34 while (parser
.parseOptionalGreater()) {
36 if (parser
.parseComma() || parser
.parseAttribute(attr
))
38 parsedParams
.push_back(attr
);
44 /// Default printer for dynamic attribute or type parameters.
45 /// Print in the format '(<>)?' or '<attr (,attr)*>'.
46 static void typeOrAttrPrinter(AsmPrinter
&printer
, ArrayRef
<Attribute
> params
) {
51 interleaveComma(params
, printer
.getStream());
//===----------------------------------------------------------------------===//
// Dynamic type
//===----------------------------------------------------------------------===//
59 std::unique_ptr
<DynamicTypeDefinition
>
60 DynamicTypeDefinition::get(StringRef name
, ExtensibleDialect
*dialect
,
61 VerifierFn
&&verifier
) {
62 return DynamicTypeDefinition::get(name
, dialect
, std::move(verifier
),
63 typeOrAttrParser
, typeOrAttrPrinter
);
66 std::unique_ptr
<DynamicTypeDefinition
>
67 DynamicTypeDefinition::get(StringRef name
, ExtensibleDialect
*dialect
,
68 VerifierFn
&&verifier
, ParserFn
&&parser
,
69 PrinterFn
&&printer
) {
70 return std::unique_ptr
<DynamicTypeDefinition
>(
71 new DynamicTypeDefinition(name
, dialect
, std::move(verifier
),
72 std::move(parser
), std::move(printer
)));
75 DynamicTypeDefinition::DynamicTypeDefinition(StringRef nameRef
,
76 ExtensibleDialect
*dialect
,
77 VerifierFn
&&verifier
,
80 : name(nameRef
), dialect(dialect
), verifier(std::move(verifier
)),
81 parser(std::move(parser
)), printer(std::move(printer
)),
82 ctx(dialect
->getContext()) {}
84 DynamicTypeDefinition::DynamicTypeDefinition(ExtensibleDialect
*dialect
,
86 : name(nameRef
), dialect(dialect
), ctx(dialect
->getContext()) {}
88 void DynamicTypeDefinition::registerInTypeUniquer() {
89 detail::TypeUniquer::registerType
<DynamicType
>(&getContext(), getTypeID());
94 /// Storage of DynamicType.
95 /// Contains a pointer to the type definition and type parameters.
96 struct DynamicTypeStorage
: public TypeStorage
{
98 using KeyTy
= std::pair
<DynamicTypeDefinition
*, ArrayRef
<Attribute
>>;
100 explicit DynamicTypeStorage(DynamicTypeDefinition
*typeDef
,
101 ArrayRef
<Attribute
> params
)
102 : typeDef(typeDef
), params(params
) {}
104 bool operator==(const KeyTy
&key
) const {
105 return typeDef
== key
.first
&& params
== key
.second
;
108 static llvm::hash_code
hashKey(const KeyTy
&key
) {
109 return llvm::hash_value(key
);
112 static DynamicTypeStorage
*construct(TypeStorageAllocator
&alloc
,
114 return new (alloc
.allocate
<DynamicTypeStorage
>())
115 DynamicTypeStorage(key
.first
, alloc
.copyInto(key
.second
));
118 /// Definition of the type.
119 DynamicTypeDefinition
*typeDef
;
121 /// The type parameters.
122 ArrayRef
<Attribute
> params
;
124 } // namespace detail
127 DynamicType
DynamicType::get(DynamicTypeDefinition
*typeDef
,
128 ArrayRef
<Attribute
> params
) {
129 auto &ctx
= typeDef
->getContext();
130 auto emitError
= detail::getDefaultDiagnosticEmitFn(&ctx
);
131 assert(succeeded(typeDef
->verify(emitError
, params
)));
132 return detail::TypeUniquer::getWithTypeID
<DynamicType
>(
133 &ctx
, typeDef
->getTypeID(), typeDef
, params
);
137 DynamicType::getChecked(function_ref
<InFlightDiagnostic()> emitError
,
138 DynamicTypeDefinition
*typeDef
,
139 ArrayRef
<Attribute
> params
) {
140 if (failed(typeDef
->verify(emitError
, params
)))
142 auto &ctx
= typeDef
->getContext();
143 return detail::TypeUniquer::getWithTypeID
<DynamicType
>(
144 &ctx
, typeDef
->getTypeID(), typeDef
, params
);
147 DynamicTypeDefinition
*DynamicType::getTypeDef() { return getImpl()->typeDef
; }
149 ArrayRef
<Attribute
> DynamicType::getParams() { return getImpl()->params
; }
151 bool DynamicType::classof(Type type
) {
152 return type
.hasTrait
<TypeTrait::IsDynamicType
>();
155 ParseResult
DynamicType::parse(AsmParser
&parser
,
156 DynamicTypeDefinition
*typeDef
,
157 DynamicType
&parsedType
) {
158 SmallVector
<Attribute
> params
;
159 if (failed(typeDef
->parser(parser
, params
)))
161 parsedType
= parser
.getChecked
<DynamicType
>(typeDef
, params
);
167 void DynamicType::print(AsmPrinter
&printer
) {
168 printer
<< getTypeDef()->getName();
169 getTypeDef()->printer(printer
, getParams());
//===----------------------------------------------------------------------===//
// Dynamic attribute
//===----------------------------------------------------------------------===//
176 std::unique_ptr
<DynamicAttrDefinition
>
177 DynamicAttrDefinition::get(StringRef name
, ExtensibleDialect
*dialect
,
178 VerifierFn
&&verifier
) {
179 return DynamicAttrDefinition::get(name
, dialect
, std::move(verifier
),
180 typeOrAttrParser
, typeOrAttrPrinter
);
183 std::unique_ptr
<DynamicAttrDefinition
>
184 DynamicAttrDefinition::get(StringRef name
, ExtensibleDialect
*dialect
,
185 VerifierFn
&&verifier
, ParserFn
&&parser
,
186 PrinterFn
&&printer
) {
187 return std::unique_ptr
<DynamicAttrDefinition
>(
188 new DynamicAttrDefinition(name
, dialect
, std::move(verifier
),
189 std::move(parser
), std::move(printer
)));
192 DynamicAttrDefinition::DynamicAttrDefinition(StringRef nameRef
,
193 ExtensibleDialect
*dialect
,
194 VerifierFn
&&verifier
,
197 : name(nameRef
), dialect(dialect
), verifier(std::move(verifier
)),
198 parser(std::move(parser
)), printer(std::move(printer
)),
199 ctx(dialect
->getContext()) {}
201 DynamicAttrDefinition::DynamicAttrDefinition(ExtensibleDialect
*dialect
,
203 : name(nameRef
), dialect(dialect
), ctx(dialect
->getContext()) {}
205 void DynamicAttrDefinition::registerInAttrUniquer() {
206 detail::AttributeUniquer::registerAttribute
<DynamicAttr
>(&getContext(),
212 /// Storage of DynamicAttr.
213 /// Contains a pointer to the attribute definition and attribute parameters.
214 struct DynamicAttrStorage
: public AttributeStorage
{
215 using KeyTy
= std::pair
<DynamicAttrDefinition
*, ArrayRef
<Attribute
>>;
217 explicit DynamicAttrStorage(DynamicAttrDefinition
*attrDef
,
218 ArrayRef
<Attribute
> params
)
219 : attrDef(attrDef
), params(params
) {}
221 bool operator==(const KeyTy
&key
) const {
222 return attrDef
== key
.first
&& params
== key
.second
;
225 static llvm::hash_code
hashKey(const KeyTy
&key
) {
226 return llvm::hash_value(key
);
229 static DynamicAttrStorage
*construct(AttributeStorageAllocator
&alloc
,
231 return new (alloc
.allocate
<DynamicAttrStorage
>())
232 DynamicAttrStorage(key
.first
, alloc
.copyInto(key
.second
));
235 /// Definition of the type.
236 DynamicAttrDefinition
*attrDef
;
238 /// The type parameters.
239 ArrayRef
<Attribute
> params
;
241 } // namespace detail
244 DynamicAttr
DynamicAttr::get(DynamicAttrDefinition
*attrDef
,
245 ArrayRef
<Attribute
> params
) {
246 auto &ctx
= attrDef
->getContext();
247 return detail::AttributeUniquer::getWithTypeID
<DynamicAttr
>(
248 &ctx
, attrDef
->getTypeID(), attrDef
, params
);
252 DynamicAttr::getChecked(function_ref
<InFlightDiagnostic()> emitError
,
253 DynamicAttrDefinition
*attrDef
,
254 ArrayRef
<Attribute
> params
) {
255 if (failed(attrDef
->verify(emitError
, params
)))
257 return get(attrDef
, params
);
260 DynamicAttrDefinition
*DynamicAttr::getAttrDef() { return getImpl()->attrDef
; }
262 ArrayRef
<Attribute
> DynamicAttr::getParams() { return getImpl()->params
; }
264 bool DynamicAttr::classof(Attribute attr
) {
265 return attr
.hasTrait
<AttributeTrait::IsDynamicAttr
>();
268 ParseResult
DynamicAttr::parse(AsmParser
&parser
,
269 DynamicAttrDefinition
*attrDef
,
270 DynamicAttr
&parsedAttr
) {
271 SmallVector
<Attribute
> params
;
272 if (failed(attrDef
->parser(parser
, params
)))
274 parsedAttr
= parser
.getChecked
<DynamicAttr
>(attrDef
, params
);
280 void DynamicAttr::print(AsmPrinter
&printer
) {
281 printer
<< getAttrDef()->getName();
282 getAttrDef()->printer(printer
, getParams());
//===----------------------------------------------------------------------===//
// Dynamic operation
//===----------------------------------------------------------------------===//
289 DynamicOpDefinition::DynamicOpDefinition(
290 StringRef name
, ExtensibleDialect
*dialect
,
291 OperationName::VerifyInvariantsFn
&&verifyFn
,
292 OperationName::VerifyRegionInvariantsFn
&&verifyRegionFn
,
293 OperationName::ParseAssemblyFn
&&parseFn
,
294 OperationName::PrintAssemblyFn
&&printFn
,
295 OperationName::FoldHookFn
&&foldHookFn
,
296 GetCanonicalizationPatternsFn
&&getCanonicalizationPatternsFn
,
297 OperationName::PopulateDefaultAttrsFn
&&populateDefaultAttrsFn
)
298 : Impl(StringAttr::get(dialect
->getContext(),
299 (dialect
->getNamespace() + "." + name
).str()),
300 dialect
, dialect
->allocateTypeID(),
301 /*interfaceMap=*/detail::InterfaceMap()),
302 verifyFn(std::move(verifyFn
)), verifyRegionFn(std::move(verifyRegionFn
)),
303 parseFn(std::move(parseFn
)), printFn(std::move(printFn
)),
304 foldHookFn(std::move(foldHookFn
)),
305 getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn
)),
306 populateDefaultAttrsFn(std::move(populateDefaultAttrsFn
)) {
307 typeID
= dialect
->allocateTypeID();
310 std::unique_ptr
<DynamicOpDefinition
> DynamicOpDefinition::get(
311 StringRef name
, ExtensibleDialect
*dialect
,
312 OperationName::VerifyInvariantsFn
&&verifyFn
,
313 OperationName::VerifyRegionInvariantsFn
&&verifyRegionFn
) {
314 auto parseFn
= [](OpAsmParser
&parser
, OperationState
&result
) {
315 return parser
.emitError(
316 parser
.getCurrentLocation(),
317 "dynamic operation do not define any parser function");
320 auto printFn
= [](Operation
*op
, OpAsmPrinter
&printer
, StringRef
) {
321 printer
.printGenericOp(op
);
324 return DynamicOpDefinition::get(name
, dialect
, std::move(verifyFn
),
325 std::move(verifyRegionFn
), std::move(parseFn
),
329 std::unique_ptr
<DynamicOpDefinition
> DynamicOpDefinition::get(
330 StringRef name
, ExtensibleDialect
*dialect
,
331 OperationName::VerifyInvariantsFn
&&verifyFn
,
332 OperationName::VerifyRegionInvariantsFn
&&verifyRegionFn
,
333 OperationName::ParseAssemblyFn
&&parseFn
,
334 OperationName::PrintAssemblyFn
&&printFn
) {
335 auto foldHookFn
= [](Operation
*op
, ArrayRef
<Attribute
> operands
,
336 SmallVectorImpl
<OpFoldResult
> &results
) {
340 auto getCanonicalizationPatternsFn
= [](RewritePatternSet
&, MLIRContext
*) {
343 auto populateDefaultAttrsFn
= [](const OperationName
&, NamedAttrList
&) {};
345 return DynamicOpDefinition::get(name
, dialect
, std::move(verifyFn
),
346 std::move(verifyRegionFn
), std::move(parseFn
),
347 std::move(printFn
), std::move(foldHookFn
),
348 std::move(getCanonicalizationPatternsFn
),
349 std::move(populateDefaultAttrsFn
));
352 std::unique_ptr
<DynamicOpDefinition
> DynamicOpDefinition::get(
353 StringRef name
, ExtensibleDialect
*dialect
,
354 OperationName::VerifyInvariantsFn
&&verifyFn
,
355 OperationName::VerifyInvariantsFn
&&verifyRegionFn
,
356 OperationName::ParseAssemblyFn
&&parseFn
,
357 OperationName::PrintAssemblyFn
&&printFn
,
358 OperationName::FoldHookFn
&&foldHookFn
,
359 GetCanonicalizationPatternsFn
&&getCanonicalizationPatternsFn
,
360 OperationName::PopulateDefaultAttrsFn
&&populateDefaultAttrsFn
) {
361 return std::unique_ptr
<DynamicOpDefinition
>(new DynamicOpDefinition(
362 name
, dialect
, std::move(verifyFn
), std::move(verifyRegionFn
),
363 std::move(parseFn
), std::move(printFn
), std::move(foldHookFn
),
364 std::move(getCanonicalizationPatternsFn
),
365 std::move(populateDefaultAttrsFn
)));
368 //===----------------------------------------------------------------------===//
369 // Extensible dialect
370 //===----------------------------------------------------------------------===//
373 /// Interface that can only be implemented by extensible dialects.
374 /// The interface is used to check if a dialect is extensible or not.
375 class IsExtensibleDialect
: public DialectInterface::Base
<IsExtensibleDialect
> {
377 IsExtensibleDialect(Dialect
*dialect
) : Base(dialect
) {}
379 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsExtensibleDialect
)
383 ExtensibleDialect::ExtensibleDialect(StringRef name
, MLIRContext
*ctx
,
385 : Dialect(name
, ctx
, typeID
) {
386 addInterfaces
<IsExtensibleDialect
>();
389 void ExtensibleDialect::registerDynamicType(
390 std::unique_ptr
<DynamicTypeDefinition
> &&type
) {
391 DynamicTypeDefinition
*typePtr
= type
.get();
392 TypeID typeID
= type
->getTypeID();
393 StringRef name
= type
->getName();
394 ExtensibleDialect
*dialect
= type
->getDialect();
396 assert(dialect
== this &&
397 "trying to register a dynamic type in the wrong dialect");
399 // If a type with the same name is already defined, fail.
400 auto registered
= dynTypes
.try_emplace(typeID
, std::move(type
)).second
;
402 assert(registered
&& "type TypeID was not unique");
404 registered
= nameToDynTypes
.insert({name
, typePtr
}).second
;
407 "Trying to create a new dynamic type with an existing name");
409 // The StringAttr allocates the type name StringRef for the duration of the
411 MLIRContext
*ctx
= getContext();
413 StringAttr::get(ctx
, getNamespace() + "." + typePtr
->getName());
415 auto abstractType
= AbstractType::get(
416 *dialect
, DynamicAttr::getInterfaceMap(), DynamicType::getHasTraitFn(),
417 DynamicType::getWalkImmediateSubElementsFn(),
418 DynamicType::getReplaceImmediateSubElementsFn(), typeID
, nameAttr
);
420 /// Add the type to the dialect and the type uniquer.
421 addType(typeID
, std::move(abstractType
));
422 typePtr
->registerInTypeUniquer();
425 void ExtensibleDialect::registerDynamicAttr(
426 std::unique_ptr
<DynamicAttrDefinition
> &&attr
) {
427 auto *attrPtr
= attr
.get();
428 auto typeID
= attr
->getTypeID();
429 auto name
= attr
->getName();
430 auto *dialect
= attr
->getDialect();
432 assert(dialect
== this &&
433 "trying to register a dynamic attribute in the wrong dialect");
435 // If an attribute with the same name is already defined, fail.
436 auto registered
= dynAttrs
.try_emplace(typeID
, std::move(attr
)).second
;
438 assert(registered
&& "attribute TypeID was not unique");
440 registered
= nameToDynAttrs
.insert({name
, attrPtr
}).second
;
443 "Trying to create a new dynamic attribute with an existing name");
445 // The StringAttr allocates the attribute name StringRef for the duration of
447 MLIRContext
*ctx
= getContext();
449 StringAttr::get(ctx
, getNamespace() + "." + attrPtr
->getName());
451 auto abstractAttr
= AbstractAttribute::get(
452 *dialect
, DynamicAttr::getInterfaceMap(), DynamicAttr::getHasTraitFn(),
453 DynamicAttr::getWalkImmediateSubElementsFn(),
454 DynamicAttr::getReplaceImmediateSubElementsFn(), typeID
, nameAttr
);
456 /// Add the type to the dialect and the type uniquer.
457 addAttribute(typeID
, std::move(abstractAttr
));
458 attrPtr
->registerInAttrUniquer();
461 void ExtensibleDialect::registerDynamicOp(
462 std::unique_ptr
<DynamicOpDefinition
> &&op
) {
463 assert(op
->dialect
== this &&
464 "trying to register a dynamic op in the wrong dialect");
465 RegisteredOperationName::insert(std::move(op
), /*attrNames=*/{});
468 bool ExtensibleDialect::classof(const Dialect
*dialect
) {
469 return const_cast<Dialect
*>(dialect
)
470 ->getRegisteredInterface
<IsExtensibleDialect
>();
473 OptionalParseResult
ExtensibleDialect::parseOptionalDynamicType(
474 StringRef typeName
, AsmParser
&parser
, Type
&resultType
) const {
475 DynamicTypeDefinition
*typeDef
= lookupTypeDefinition(typeName
);
480 if (DynamicType::parse(parser
, typeDef
, dynType
))
482 resultType
= dynType
;
486 LogicalResult
ExtensibleDialect::printIfDynamicType(Type type
,
487 AsmPrinter
&printer
) {
488 if (auto dynType
= llvm::dyn_cast
<DynamicType
>(type
)) {
489 dynType
.print(printer
);
495 OptionalParseResult
ExtensibleDialect::parseOptionalDynamicAttr(
496 StringRef attrName
, AsmParser
&parser
, Attribute
&resultAttr
) const {
497 DynamicAttrDefinition
*attrDef
= lookupAttrDefinition(attrName
);
502 if (DynamicAttr::parse(parser
, attrDef
, dynAttr
))
504 resultAttr
= dynAttr
;
508 LogicalResult
ExtensibleDialect::printIfDynamicAttr(Attribute attribute
,
509 AsmPrinter
&printer
) {
510 if (auto dynAttr
= llvm::dyn_cast
<DynamicAttr
>(attribute
)) {
511 dynAttr
.print(printer
);
//===----------------------------------------------------------------------===//
// Dynamic dialect
//===----------------------------------------------------------------------===//
522 /// Interface that can only be implemented by extensible dialects.
523 /// The interface is used to check if a dialect is extensible or not.
524 class IsDynamicDialect
: public DialectInterface::Base
<IsDynamicDialect
> {
526 IsDynamicDialect(Dialect
*dialect
) : Base(dialect
) {}
528 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsDynamicDialect
)
532 DynamicDialect::DynamicDialect(StringRef name
, MLIRContext
*ctx
)
533 : SelfOwningTypeID(),
534 ExtensibleDialect(name
, ctx
, SelfOwningTypeID::getTypeID()) {
535 addInterfaces
<IsDynamicDialect
>();
538 bool DynamicDialect::classof(const Dialect
*dialect
) {
539 return const_cast<Dialect
*>(dialect
)
540 ->getRegisteredInterface
<IsDynamicDialect
>();
543 Type
DynamicDialect::parseType(DialectAsmParser
&parser
) const {
544 auto loc
= parser
.getCurrentLocation();
546 if (failed(parser
.parseKeyword(&typeTag
)))
551 auto parseResult
= parseOptionalDynamicType(typeTag
, parser
, dynType
);
552 if (parseResult
.has_value()) {
553 if (succeeded(parseResult
.value()))
559 parser
.emitError(loc
, "expected dynamic type");
563 void DynamicDialect::printType(Type type
, DialectAsmPrinter
&printer
) const {
564 auto wasDynamic
= printIfDynamicType(type
, printer
);
566 assert(succeeded(wasDynamic
) &&
567 "non-dynamic type defined in dynamic dialect");
570 Attribute
DynamicDialect::parseAttribute(DialectAsmParser
&parser
,
572 auto loc
= parser
.getCurrentLocation();
574 if (failed(parser
.parseKeyword(&typeTag
)))
579 auto parseResult
= parseOptionalDynamicAttr(typeTag
, parser
, dynAttr
);
580 if (parseResult
.has_value()) {
581 if (succeeded(parseResult
.value()))
587 parser
.emitError(loc
, "expected dynamic attribute");
590 void DynamicDialect::printAttribute(Attribute attr
,
591 DialectAsmPrinter
&printer
) const {
592 auto wasDynamic
= printIfDynamicAttr(attr
, printer
);
594 assert(succeeded(wasDynamic
) &&
595 "non-dynamic attribute defined in dynamic dialect");