[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / lib / IR / ExtensibleDialect.cpp
blob8a7d74700006cee7e1cd11ec3497e7751131762e
1 //===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/ExtensibleDialect.h"
10 #include "mlir/IR/AttributeSupport.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "mlir/IR/OperationSupport.h"
13 #include "mlir/IR/StorageUniquerSupport.h"
15 using namespace mlir;
17 //===----------------------------------------------------------------------===//
18 // Dynamic types and attributes shared functions
19 //===----------------------------------------------------------------------===//
21 /// Default parser for dynamic attribute or type parameters.
22 /// Parse in the format '(<>)?' or '<attr (,attr)*>'.
23 static LogicalResult
24 typeOrAttrParser(AsmParser &parser, SmallVectorImpl<Attribute> &parsedParams) {
25 // No parameters
26 if (parser.parseOptionalLess() || !parser.parseOptionalGreater())
27 return success();
29 Attribute attr;
30 if (parser.parseAttribute(attr))
31 return failure();
32 parsedParams.push_back(attr);
34 while (parser.parseOptionalGreater()) {
35 Attribute attr;
36 if (parser.parseComma() || parser.parseAttribute(attr))
37 return failure();
38 parsedParams.push_back(attr);
41 return success();
44 /// Default printer for dynamic attribute or type parameters.
45 /// Print in the format '(<>)?' or '<attr (,attr)*>'.
46 static void typeOrAttrPrinter(AsmPrinter &printer, ArrayRef<Attribute> params) {
47 if (params.empty())
48 return;
50 printer << "<";
51 interleaveComma(params, printer.getStream());
52 printer << ">";
55 //===----------------------------------------------------------------------===//
56 // Dynamic type
57 //===----------------------------------------------------------------------===//
59 std::unique_ptr<DynamicTypeDefinition>
60 DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
61 VerifierFn &&verifier) {
62 return DynamicTypeDefinition::get(name, dialect, std::move(verifier),
63 typeOrAttrParser, typeOrAttrPrinter);
66 std::unique_ptr<DynamicTypeDefinition>
67 DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
68 VerifierFn &&verifier, ParserFn &&parser,
69 PrinterFn &&printer) {
70 return std::unique_ptr<DynamicTypeDefinition>(
71 new DynamicTypeDefinition(name, dialect, std::move(verifier),
72 std::move(parser), std::move(printer)));
75 DynamicTypeDefinition::DynamicTypeDefinition(StringRef nameRef,
76 ExtensibleDialect *dialect,
77 VerifierFn &&verifier,
78 ParserFn &&parser,
79 PrinterFn &&printer)
80 : name(nameRef), dialect(dialect), verifier(std::move(verifier)),
81 parser(std::move(parser)), printer(std::move(printer)),
82 ctx(dialect->getContext()) {}
84 DynamicTypeDefinition::DynamicTypeDefinition(ExtensibleDialect *dialect,
85 StringRef nameRef)
86 : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {}
88 void DynamicTypeDefinition::registerInTypeUniquer() {
89 detail::TypeUniquer::registerType<DynamicType>(&getContext(), getTypeID());
92 namespace mlir {
93 namespace detail {
94 /// Storage of DynamicType.
95 /// Contains a pointer to the type definition and type parameters.
96 struct DynamicTypeStorage : public TypeStorage {
98 using KeyTy = std::pair<DynamicTypeDefinition *, ArrayRef<Attribute>>;
100 explicit DynamicTypeStorage(DynamicTypeDefinition *typeDef,
101 ArrayRef<Attribute> params)
102 : typeDef(typeDef), params(params) {}
104 bool operator==(const KeyTy &key) const {
105 return typeDef == key.first && params == key.second;
108 static llvm::hash_code hashKey(const KeyTy &key) {
109 return llvm::hash_value(key);
112 static DynamicTypeStorage *construct(TypeStorageAllocator &alloc,
113 const KeyTy &key) {
114 return new (alloc.allocate<DynamicTypeStorage>())
115 DynamicTypeStorage(key.first, alloc.copyInto(key.second));
118 /// Definition of the type.
119 DynamicTypeDefinition *typeDef;
121 /// The type parameters.
122 ArrayRef<Attribute> params;
124 } // namespace detail
125 } // namespace mlir
127 DynamicType DynamicType::get(DynamicTypeDefinition *typeDef,
128 ArrayRef<Attribute> params) {
129 auto &ctx = typeDef->getContext();
130 auto emitError = detail::getDefaultDiagnosticEmitFn(&ctx);
131 assert(succeeded(typeDef->verify(emitError, params)));
132 return detail::TypeUniquer::getWithTypeID<DynamicType>(
133 &ctx, typeDef->getTypeID(), typeDef, params);
136 DynamicType
137 DynamicType::getChecked(function_ref<InFlightDiagnostic()> emitError,
138 DynamicTypeDefinition *typeDef,
139 ArrayRef<Attribute> params) {
140 if (failed(typeDef->verify(emitError, params)))
141 return {};
142 auto &ctx = typeDef->getContext();
143 return detail::TypeUniquer::getWithTypeID<DynamicType>(
144 &ctx, typeDef->getTypeID(), typeDef, params);
147 DynamicTypeDefinition *DynamicType::getTypeDef() { return getImpl()->typeDef; }
149 ArrayRef<Attribute> DynamicType::getParams() { return getImpl()->params; }
151 bool DynamicType::classof(Type type) {
152 return type.hasTrait<TypeTrait::IsDynamicType>();
155 ParseResult DynamicType::parse(AsmParser &parser,
156 DynamicTypeDefinition *typeDef,
157 DynamicType &parsedType) {
158 SmallVector<Attribute> params;
159 if (failed(typeDef->parser(parser, params)))
160 return failure();
161 parsedType = parser.getChecked<DynamicType>(typeDef, params);
162 if (!parsedType)
163 return failure();
164 return success();
167 void DynamicType::print(AsmPrinter &printer) {
168 printer << getTypeDef()->getName();
169 getTypeDef()->printer(printer, getParams());
172 //===----------------------------------------------------------------------===//
173 // Dynamic attribute
174 //===----------------------------------------------------------------------===//
176 std::unique_ptr<DynamicAttrDefinition>
177 DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
178 VerifierFn &&verifier) {
179 return DynamicAttrDefinition::get(name, dialect, std::move(verifier),
180 typeOrAttrParser, typeOrAttrPrinter);
183 std::unique_ptr<DynamicAttrDefinition>
184 DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
185 VerifierFn &&verifier, ParserFn &&parser,
186 PrinterFn &&printer) {
187 return std::unique_ptr<DynamicAttrDefinition>(
188 new DynamicAttrDefinition(name, dialect, std::move(verifier),
189 std::move(parser), std::move(printer)));
192 DynamicAttrDefinition::DynamicAttrDefinition(StringRef nameRef,
193 ExtensibleDialect *dialect,
194 VerifierFn &&verifier,
195 ParserFn &&parser,
196 PrinterFn &&printer)
197 : name(nameRef), dialect(dialect), verifier(std::move(verifier)),
198 parser(std::move(parser)), printer(std::move(printer)),
199 ctx(dialect->getContext()) {}
201 DynamicAttrDefinition::DynamicAttrDefinition(ExtensibleDialect *dialect,
202 StringRef nameRef)
203 : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {}
205 void DynamicAttrDefinition::registerInAttrUniquer() {
206 detail::AttributeUniquer::registerAttribute<DynamicAttr>(&getContext(),
207 getTypeID());
210 namespace mlir {
211 namespace detail {
212 /// Storage of DynamicAttr.
213 /// Contains a pointer to the attribute definition and attribute parameters.
214 struct DynamicAttrStorage : public AttributeStorage {
215 using KeyTy = std::pair<DynamicAttrDefinition *, ArrayRef<Attribute>>;
217 explicit DynamicAttrStorage(DynamicAttrDefinition *attrDef,
218 ArrayRef<Attribute> params)
219 : attrDef(attrDef), params(params) {}
221 bool operator==(const KeyTy &key) const {
222 return attrDef == key.first && params == key.second;
225 static llvm::hash_code hashKey(const KeyTy &key) {
226 return llvm::hash_value(key);
229 static DynamicAttrStorage *construct(AttributeStorageAllocator &alloc,
230 const KeyTy &key) {
231 return new (alloc.allocate<DynamicAttrStorage>())
232 DynamicAttrStorage(key.first, alloc.copyInto(key.second));
235 /// Definition of the type.
236 DynamicAttrDefinition *attrDef;
238 /// The type parameters.
239 ArrayRef<Attribute> params;
241 } // namespace detail
242 } // namespace mlir
244 DynamicAttr DynamicAttr::get(DynamicAttrDefinition *attrDef,
245 ArrayRef<Attribute> params) {
246 auto &ctx = attrDef->getContext();
247 return detail::AttributeUniquer::getWithTypeID<DynamicAttr>(
248 &ctx, attrDef->getTypeID(), attrDef, params);
251 DynamicAttr
252 DynamicAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
253 DynamicAttrDefinition *attrDef,
254 ArrayRef<Attribute> params) {
255 if (failed(attrDef->verify(emitError, params)))
256 return {};
257 return get(attrDef, params);
260 DynamicAttrDefinition *DynamicAttr::getAttrDef() { return getImpl()->attrDef; }
262 ArrayRef<Attribute> DynamicAttr::getParams() { return getImpl()->params; }
264 bool DynamicAttr::classof(Attribute attr) {
265 return attr.hasTrait<AttributeTrait::IsDynamicAttr>();
268 ParseResult DynamicAttr::parse(AsmParser &parser,
269 DynamicAttrDefinition *attrDef,
270 DynamicAttr &parsedAttr) {
271 SmallVector<Attribute> params;
272 if (failed(attrDef->parser(parser, params)))
273 return failure();
274 parsedAttr = parser.getChecked<DynamicAttr>(attrDef, params);
275 if (!parsedAttr)
276 return failure();
277 return success();
280 void DynamicAttr::print(AsmPrinter &printer) {
281 printer << getAttrDef()->getName();
282 getAttrDef()->printer(printer, getParams());
285 //===----------------------------------------------------------------------===//
286 // Dynamic operation
287 //===----------------------------------------------------------------------===//
289 DynamicOpDefinition::DynamicOpDefinition(
290 StringRef name, ExtensibleDialect *dialect,
291 OperationName::VerifyInvariantsFn &&verifyFn,
292 OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
293 OperationName::ParseAssemblyFn &&parseFn,
294 OperationName::PrintAssemblyFn &&printFn,
295 OperationName::FoldHookFn &&foldHookFn,
296 GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
297 OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn)
298 : Impl(StringAttr::get(dialect->getContext(),
299 (dialect->getNamespace() + "." + name).str()),
300 dialect, dialect->allocateTypeID(),
301 /*interfaceMap=*/detail::InterfaceMap()),
302 verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)),
303 parseFn(std::move(parseFn)), printFn(std::move(printFn)),
304 foldHookFn(std::move(foldHookFn)),
305 getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)),
306 populateDefaultAttrsFn(std::move(populateDefaultAttrsFn)) {
307 typeID = dialect->allocateTypeID();
310 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
311 StringRef name, ExtensibleDialect *dialect,
312 OperationName::VerifyInvariantsFn &&verifyFn,
313 OperationName::VerifyRegionInvariantsFn &&verifyRegionFn) {
314 auto parseFn = [](OpAsmParser &parser, OperationState &result) {
315 return parser.emitError(
316 parser.getCurrentLocation(),
317 "dynamic operation do not define any parser function");
320 auto printFn = [](Operation *op, OpAsmPrinter &printer, StringRef) {
321 printer.printGenericOp(op);
324 return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
325 std::move(verifyRegionFn), std::move(parseFn),
326 std::move(printFn));
329 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
330 StringRef name, ExtensibleDialect *dialect,
331 OperationName::VerifyInvariantsFn &&verifyFn,
332 OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
333 OperationName::ParseAssemblyFn &&parseFn,
334 OperationName::PrintAssemblyFn &&printFn) {
335 auto foldHookFn = [](Operation *op, ArrayRef<Attribute> operands,
336 SmallVectorImpl<OpFoldResult> &results) {
337 return failure();
340 auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) {
343 auto populateDefaultAttrsFn = [](const OperationName &, NamedAttrList &) {};
345 return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
346 std::move(verifyRegionFn), std::move(parseFn),
347 std::move(printFn), std::move(foldHookFn),
348 std::move(getCanonicalizationPatternsFn),
349 std::move(populateDefaultAttrsFn));
352 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
353 StringRef name, ExtensibleDialect *dialect,
354 OperationName::VerifyInvariantsFn &&verifyFn,
355 OperationName::VerifyInvariantsFn &&verifyRegionFn,
356 OperationName::ParseAssemblyFn &&parseFn,
357 OperationName::PrintAssemblyFn &&printFn,
358 OperationName::FoldHookFn &&foldHookFn,
359 GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
360 OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn) {
361 return std::unique_ptr<DynamicOpDefinition>(new DynamicOpDefinition(
362 name, dialect, std::move(verifyFn), std::move(verifyRegionFn),
363 std::move(parseFn), std::move(printFn), std::move(foldHookFn),
364 std::move(getCanonicalizationPatternsFn),
365 std::move(populateDefaultAttrsFn)));
368 //===----------------------------------------------------------------------===//
369 // Extensible dialect
370 //===----------------------------------------------------------------------===//
372 namespace {
373 /// Interface that can only be implemented by extensible dialects.
374 /// The interface is used to check if a dialect is extensible or not.
375 class IsExtensibleDialect : public DialectInterface::Base<IsExtensibleDialect> {
376 public:
377 IsExtensibleDialect(Dialect *dialect) : Base(dialect) {}
379 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsExtensibleDialect)
381 } // namespace
383 ExtensibleDialect::ExtensibleDialect(StringRef name, MLIRContext *ctx,
384 TypeID typeID)
385 : Dialect(name, ctx, typeID) {
386 addInterfaces<IsExtensibleDialect>();
389 void ExtensibleDialect::registerDynamicType(
390 std::unique_ptr<DynamicTypeDefinition> &&type) {
391 DynamicTypeDefinition *typePtr = type.get();
392 TypeID typeID = type->getTypeID();
393 StringRef name = type->getName();
394 ExtensibleDialect *dialect = type->getDialect();
396 assert(dialect == this &&
397 "trying to register a dynamic type in the wrong dialect");
399 // If a type with the same name is already defined, fail.
400 auto registered = dynTypes.try_emplace(typeID, std::move(type)).second;
401 (void)registered;
402 assert(registered && "type TypeID was not unique");
404 registered = nameToDynTypes.insert({name, typePtr}).second;
405 (void)registered;
406 assert(registered &&
407 "Trying to create a new dynamic type with an existing name");
409 // The StringAttr allocates the type name StringRef for the duration of the
410 // MLIR context.
411 MLIRContext *ctx = getContext();
412 auto nameAttr =
413 StringAttr::get(ctx, getNamespace() + "." + typePtr->getName());
415 auto abstractType = AbstractType::get(
416 *dialect, DynamicAttr::getInterfaceMap(), DynamicType::getHasTraitFn(),
417 DynamicType::getWalkImmediateSubElementsFn(),
418 DynamicType::getReplaceImmediateSubElementsFn(), typeID, nameAttr);
420 /// Add the type to the dialect and the type uniquer.
421 addType(typeID, std::move(abstractType));
422 typePtr->registerInTypeUniquer();
425 void ExtensibleDialect::registerDynamicAttr(
426 std::unique_ptr<DynamicAttrDefinition> &&attr) {
427 auto *attrPtr = attr.get();
428 auto typeID = attr->getTypeID();
429 auto name = attr->getName();
430 auto *dialect = attr->getDialect();
432 assert(dialect == this &&
433 "trying to register a dynamic attribute in the wrong dialect");
435 // If an attribute with the same name is already defined, fail.
436 auto registered = dynAttrs.try_emplace(typeID, std::move(attr)).second;
437 (void)registered;
438 assert(registered && "attribute TypeID was not unique");
440 registered = nameToDynAttrs.insert({name, attrPtr}).second;
441 (void)registered;
442 assert(registered &&
443 "Trying to create a new dynamic attribute with an existing name");
445 // The StringAttr allocates the attribute name StringRef for the duration of
446 // the MLIR context.
447 MLIRContext *ctx = getContext();
448 auto nameAttr =
449 StringAttr::get(ctx, getNamespace() + "." + attrPtr->getName());
451 auto abstractAttr = AbstractAttribute::get(
452 *dialect, DynamicAttr::getInterfaceMap(), DynamicAttr::getHasTraitFn(),
453 DynamicAttr::getWalkImmediateSubElementsFn(),
454 DynamicAttr::getReplaceImmediateSubElementsFn(), typeID, nameAttr);
456 /// Add the type to the dialect and the type uniquer.
457 addAttribute(typeID, std::move(abstractAttr));
458 attrPtr->registerInAttrUniquer();
461 void ExtensibleDialect::registerDynamicOp(
462 std::unique_ptr<DynamicOpDefinition> &&op) {
463 assert(op->dialect == this &&
464 "trying to register a dynamic op in the wrong dialect");
465 RegisteredOperationName::insert(std::move(op), /*attrNames=*/{});
468 bool ExtensibleDialect::classof(const Dialect *dialect) {
469 return const_cast<Dialect *>(dialect)
470 ->getRegisteredInterface<IsExtensibleDialect>();
473 OptionalParseResult ExtensibleDialect::parseOptionalDynamicType(
474 StringRef typeName, AsmParser &parser, Type &resultType) const {
475 DynamicTypeDefinition *typeDef = lookupTypeDefinition(typeName);
476 if (!typeDef)
477 return std::nullopt;
479 DynamicType dynType;
480 if (DynamicType::parse(parser, typeDef, dynType))
481 return failure();
482 resultType = dynType;
483 return success();
486 LogicalResult ExtensibleDialect::printIfDynamicType(Type type,
487 AsmPrinter &printer) {
488 if (auto dynType = llvm::dyn_cast<DynamicType>(type)) {
489 dynType.print(printer);
490 return success();
492 return failure();
495 OptionalParseResult ExtensibleDialect::parseOptionalDynamicAttr(
496 StringRef attrName, AsmParser &parser, Attribute &resultAttr) const {
497 DynamicAttrDefinition *attrDef = lookupAttrDefinition(attrName);
498 if (!attrDef)
499 return std::nullopt;
501 DynamicAttr dynAttr;
502 if (DynamicAttr::parse(parser, attrDef, dynAttr))
503 return failure();
504 resultAttr = dynAttr;
505 return success();
508 LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute,
509 AsmPrinter &printer) {
510 if (auto dynAttr = llvm::dyn_cast<DynamicAttr>(attribute)) {
511 dynAttr.print(printer);
512 return success();
514 return failure();
517 //===----------------------------------------------------------------------===//
518 // Dynamic dialect
519 //===----------------------------------------------------------------------===//
521 namespace {
522 /// Interface that can only be implemented by extensible dialects.
523 /// The interface is used to check if a dialect is extensible or not.
524 class IsDynamicDialect : public DialectInterface::Base<IsDynamicDialect> {
525 public:
526 IsDynamicDialect(Dialect *dialect) : Base(dialect) {}
528 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsDynamicDialect)
530 } // namespace
532 DynamicDialect::DynamicDialect(StringRef name, MLIRContext *ctx)
533 : SelfOwningTypeID(),
534 ExtensibleDialect(name, ctx, SelfOwningTypeID::getTypeID()) {
535 addInterfaces<IsDynamicDialect>();
538 bool DynamicDialect::classof(const Dialect *dialect) {
539 return const_cast<Dialect *>(dialect)
540 ->getRegisteredInterface<IsDynamicDialect>();
543 Type DynamicDialect::parseType(DialectAsmParser &parser) const {
544 auto loc = parser.getCurrentLocation();
545 StringRef typeTag;
546 if (failed(parser.parseKeyword(&typeTag)))
547 return Type();
550 Type dynType;
551 auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
552 if (parseResult.has_value()) {
553 if (succeeded(parseResult.value()))
554 return dynType;
555 return Type();
559 parser.emitError(loc, "expected dynamic type");
560 return Type();
563 void DynamicDialect::printType(Type type, DialectAsmPrinter &printer) const {
564 auto wasDynamic = printIfDynamicType(type, printer);
565 (void)wasDynamic;
566 assert(succeeded(wasDynamic) &&
567 "non-dynamic type defined in dynamic dialect");
570 Attribute DynamicDialect::parseAttribute(DialectAsmParser &parser,
571 Type type) const {
572 auto loc = parser.getCurrentLocation();
573 StringRef typeTag;
574 if (failed(parser.parseKeyword(&typeTag)))
575 return Attribute();
578 Attribute dynAttr;
579 auto parseResult = parseOptionalDynamicAttr(typeTag, parser, dynAttr);
580 if (parseResult.has_value()) {
581 if (succeeded(parseResult.value()))
582 return dynAttr;
583 return Attribute();
587 parser.emitError(loc, "expected dynamic attribute");
588 return Attribute();
590 void DynamicDialect::printAttribute(Attribute attr,
591 DialectAsmPrinter &printer) const {
592 auto wasDynamic = printIfDynamicAttr(attr, printer);
593 (void)wasDynamic;
594 assert(succeeded(wasDynamic) &&
595 "non-dynamic attribute defined in dynamic dialect");