1 //===- OpDefinitionsGen.cpp - IRDL op definitions generator ---------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // OpDefinitionsGen uses the description of operations to generate IRDL
10 // definitions for ops.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/IRDL/IR/IRDL.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/TableGen/AttrOrTypeDef.h"
22 #include "mlir/TableGen/GenInfo.h"
23 #include "mlir/TableGen/GenNameParser.h"
24 #include "mlir/TableGen/Interfaces.h"
25 #include "mlir/TableGen/Operator.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/InitLLVM.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include "llvm/TableGen/Main.h"
30 #include "llvm/TableGen/Record.h"
31 #include "llvm/TableGen/TableGenBackend.h"
35 using tblgen::NamedTypeConstraint
;
37 static llvm::cl::OptionCategory
dialectGenCat("Options for -gen-irdl-dialect");
38 llvm::cl::opt
<std::string
>
39 selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
40 llvm::cl::cat(dialectGenCat
), llvm::cl::Required
);
42 Value
createPredicate(OpBuilder
&builder
, tblgen::Pred pred
) {
43 MLIRContext
*ctx
= builder
.getContext();
45 if (pred
.isCombined()) {
46 auto combiner
= pred
.getDef().getValueAsDef("kind")->getName();
47 if (combiner
== "PredCombinerAnd" || combiner
== "PredCombinerOr") {
48 std::vector
<Value
> constraints
;
49 for (auto *child
: pred
.getDef().getValueAsListOfDefs("children")) {
50 constraints
.push_back(createPredicate(builder
, tblgen::Pred(child
)));
52 if (combiner
== "PredCombinerAnd") {
54 builder
.create
<irdl::AllOfOp
>(UnknownLoc::get(ctx
), constraints
);
55 return op
.getOutput();
58 builder
.create
<irdl::AnyOfOp
>(UnknownLoc::get(ctx
), constraints
);
59 return op
.getOutput();
63 std::string condition
= pred
.getCondition();
64 // Build a CPredOp to match the C constraint built.
65 irdl::CPredOp op
= builder
.create
<irdl::CPredOp
>(
66 UnknownLoc::get(ctx
), StringAttr::get(ctx
, condition
));
70 Value
typeToConstraint(OpBuilder
&builder
, Type type
) {
71 MLIRContext
*ctx
= builder
.getContext();
73 builder
.create
<irdl::IsOp
>(UnknownLoc::get(ctx
), TypeAttr::get(type
));
74 return op
.getOutput();
77 Value
baseToConstraint(OpBuilder
&builder
, StringRef baseClass
) {
78 MLIRContext
*ctx
= builder
.getContext();
79 auto op
= builder
.create
<irdl::BaseOp
>(UnknownLoc::get(ctx
),
80 StringAttr::get(ctx
, baseClass
));
81 return op
.getOutput();
84 std::optional
<Type
> recordToType(MLIRContext
*ctx
, const Record
&predRec
) {
85 if (predRec
.isSubClassOf("I")) {
86 auto width
= predRec
.getValueAsInt("bitwidth");
87 return IntegerType::get(ctx
, width
, IntegerType::Signless
);
90 if (predRec
.isSubClassOf("SI")) {
91 auto width
= predRec
.getValueAsInt("bitwidth");
92 return IntegerType::get(ctx
, width
, IntegerType::Signed
);
95 if (predRec
.isSubClassOf("UI")) {
96 auto width
= predRec
.getValueAsInt("bitwidth");
97 return IntegerType::get(ctx
, width
, IntegerType::Unsigned
);
101 if (predRec
.getName() == "Index") {
102 return IndexType::get(ctx
);
106 if (predRec
.isSubClassOf("F")) {
107 auto width
= predRec
.getValueAsInt("bitwidth");
110 return FloatType::getF16(ctx
);
112 return FloatType::getF32(ctx
);
114 return FloatType::getF64(ctx
);
116 return FloatType::getF80(ctx
);
118 return FloatType::getF128(ctx
);
122 if (predRec
.getName() == "NoneType") {
123 return NoneType::get(ctx
);
126 if (predRec
.getName() == "BF16") {
127 return FloatType::getBF16(ctx
);
130 if (predRec
.getName() == "TF32") {
131 return FloatType::getTF32(ctx
);
134 if (predRec
.getName() == "F8E4M3FN") {
135 return FloatType::getFloat8E4M3FN(ctx
);
138 if (predRec
.getName() == "F8E5M2") {
139 return FloatType::getFloat8E5M2(ctx
);
142 if (predRec
.getName() == "F8E4M3") {
143 return FloatType::getFloat8E4M3(ctx
);
146 if (predRec
.getName() == "F8E4M3FNUZ") {
147 return FloatType::getFloat8E4M3FNUZ(ctx
);
150 if (predRec
.getName() == "F8E4M3B11FNUZ") {
151 return FloatType::getFloat8E4M3B11FNUZ(ctx
);
154 if (predRec
.getName() == "F8E5M2FNUZ") {
155 return FloatType::getFloat8E5M2FNUZ(ctx
);
158 if (predRec
.getName() == "F8E3M4") {
159 return FloatType::getFloat8E3M4(ctx
);
162 if (predRec
.isSubClassOf("Complex")) {
163 const Record
*elementRec
= predRec
.getValueAsDef("elementType");
164 auto elementType
= recordToType(ctx
, *elementRec
);
165 if (elementType
.has_value()) {
166 return ComplexType::get(elementType
.value());
173 Value
createTypeConstraint(OpBuilder
&builder
, tblgen::Constraint constraint
) {
174 MLIRContext
*ctx
= builder
.getContext();
175 const Record
&predRec
= constraint
.getDef();
177 if (predRec
.isSubClassOf("Variadic") || predRec
.isSubClassOf("Optional"))
178 return createTypeConstraint(builder
, predRec
.getValueAsDef("baseType"));
180 if (predRec
.getName() == "AnyType") {
181 auto op
= builder
.create
<irdl::AnyOp
>(UnknownLoc::get(ctx
));
182 return op
.getOutput();
185 if (predRec
.isSubClassOf("TypeDef")) {
186 auto dialect
= predRec
.getValueAsDef("dialect")->getValueAsString("name");
187 if (dialect
== selectedDialect
) {
188 std::string combined
= ("!" + predRec
.getValueAsString("mnemonic")).str();
189 SmallVector
<FlatSymbolRefAttr
> nested
= {
190 SymbolRefAttr::get(ctx
, combined
)};
191 auto typeSymbol
= SymbolRefAttr::get(ctx
, dialect
, nested
);
192 auto op
= builder
.create
<irdl::BaseOp
>(UnknownLoc::get(ctx
), typeSymbol
);
193 return op
.getOutput();
195 std::string typeName
= ("!" + predRec
.getValueAsString("typeName")).str();
196 auto op
= builder
.create
<irdl::BaseOp
>(UnknownLoc::get(ctx
),
197 StringAttr::get(ctx
, typeName
));
198 return op
.getOutput();
201 if (predRec
.isSubClassOf("AnyTypeOf")) {
202 std::vector
<Value
> constraints
;
203 for (const Record
*child
: predRec
.getValueAsListOfDefs("allowedTypes")) {
204 constraints
.push_back(
205 createTypeConstraint(builder
, tblgen::Constraint(child
)));
207 auto op
= builder
.create
<irdl::AnyOfOp
>(UnknownLoc::get(ctx
), constraints
);
208 return op
.getOutput();
211 if (predRec
.isSubClassOf("AllOfType")) {
212 std::vector
<Value
> constraints
;
213 for (const Record
*child
: predRec
.getValueAsListOfDefs("allowedTypes")) {
214 constraints
.push_back(
215 createTypeConstraint(builder
, tblgen::Constraint(child
)));
217 auto op
= builder
.create
<irdl::AllOfOp
>(UnknownLoc::get(ctx
), constraints
);
218 return op
.getOutput();
222 if (predRec
.getName() == "AnyInteger") {
223 auto op
= builder
.create
<irdl::BaseOp
>(
224 UnknownLoc::get(ctx
), StringAttr::get(ctx
, "!builtin.integer"));
225 return op
.getOutput();
228 if (predRec
.isSubClassOf("AnyI")) {
229 auto width
= predRec
.getValueAsInt("bitwidth");
230 std::vector
<Value
> types
= {
231 typeToConstraint(builder
,
232 IntegerType::get(ctx
, width
, IntegerType::Signless
)),
233 typeToConstraint(builder
,
234 IntegerType::get(ctx
, width
, IntegerType::Signed
)),
235 typeToConstraint(builder
,
236 IntegerType::get(ctx
, width
, IntegerType::Unsigned
))};
237 auto op
= builder
.create
<irdl::AnyOfOp
>(UnknownLoc::get(ctx
), types
);
238 return op
.getOutput();
241 auto type
= recordToType(ctx
, predRec
);
243 if (type
.has_value()) {
244 return typeToConstraint(builder
, type
.value());
248 if (predRec
.isSubClassOf("ConfinedType")) {
249 std::vector
<Value
> constraints
;
250 constraints
.push_back(createTypeConstraint(
251 builder
, tblgen::Constraint(predRec
.getValueAsDef("baseType"))));
252 for (const Record
*child
: predRec
.getValueAsListOfDefs("predicateList")) {
253 constraints
.push_back(createPredicate(builder
, tblgen::Pred(child
)));
255 auto op
= builder
.create
<irdl::AllOfOp
>(UnknownLoc::get(ctx
), constraints
);
256 return op
.getOutput();
259 return createPredicate(builder
, constraint
.getPredicate());
262 Value
createAttrConstraint(OpBuilder
&builder
, tblgen::Constraint constraint
) {
263 MLIRContext
*ctx
= builder
.getContext();
264 const Record
&predRec
= constraint
.getDef();
266 if (predRec
.isSubClassOf("DefaultValuedAttr") ||
267 predRec
.isSubClassOf("DefaultValuedOptionalAttr") ||
268 predRec
.isSubClassOf("OptionalAttr")) {
269 return createAttrConstraint(builder
, predRec
.getValueAsDef("baseAttr"));
272 if (predRec
.isSubClassOf("ConfinedAttr")) {
273 std::vector
<Value
> constraints
;
274 constraints
.push_back(createAttrConstraint(
275 builder
, tblgen::Constraint(predRec
.getValueAsDef("baseAttr"))));
276 for (const Record
*child
:
277 predRec
.getValueAsListOfDefs("attrConstraints")) {
278 constraints
.push_back(createPredicate(
279 builder
, tblgen::Pred(child
->getValueAsDef("predicate"))));
281 auto op
= builder
.create
<irdl::AllOfOp
>(UnknownLoc::get(ctx
), constraints
);
282 return op
.getOutput();
285 if (predRec
.isSubClassOf("AnyAttrOf")) {
286 std::vector
<Value
> constraints
;
287 for (const Record
*child
:
288 predRec
.getValueAsListOfDefs("allowedAttributes")) {
289 constraints
.push_back(
290 createAttrConstraint(builder
, tblgen::Constraint(child
)));
292 auto op
= builder
.create
<irdl::AnyOfOp
>(UnknownLoc::get(ctx
), constraints
);
293 return op
.getOutput();
296 if (predRec
.getName() == "AnyAttr") {
297 auto op
= builder
.create
<irdl::AnyOp
>(UnknownLoc::get(ctx
));
298 return op
.getOutput();
301 if (predRec
.isSubClassOf("AnyIntegerAttrBase") ||
302 predRec
.isSubClassOf("SignlessIntegerAttrBase") ||
303 predRec
.isSubClassOf("SignedIntegerAttrBase") ||
304 predRec
.isSubClassOf("UnsignedIntegerAttrBase") ||
305 predRec
.isSubClassOf("BoolAttr")) {
306 return baseToConstraint(builder
, "!builtin.integer");
309 if (predRec
.isSubClassOf("FloatAttrBase")) {
310 return baseToConstraint(builder
, "!builtin.float");
313 if (predRec
.isSubClassOf("StringBasedAttr")) {
314 return baseToConstraint(builder
, "!builtin.string");
317 if (predRec
.getName() == "UnitAttr") {
319 builder
.create
<irdl::IsOp
>(UnknownLoc::get(ctx
), UnitAttr::get(ctx
));
320 return op
.getOutput();
323 if (predRec
.isSubClassOf("AttrDef")) {
324 auto dialect
= predRec
.getValueAsDef("dialect")->getValueAsString("name");
325 if (dialect
== selectedDialect
) {
326 std::string combined
= ("#" + predRec
.getValueAsString("mnemonic")).str();
327 SmallVector
<FlatSymbolRefAttr
> nested
= {SymbolRefAttr::get(ctx
, combined
)
330 auto typeSymbol
= SymbolRefAttr::get(ctx
, dialect
, nested
);
331 auto op
= builder
.create
<irdl::BaseOp
>(UnknownLoc::get(ctx
), typeSymbol
);
332 return op
.getOutput();
334 std::string typeName
= ("#" + predRec
.getValueAsString("attrName")).str();
335 auto op
= builder
.create
<irdl::BaseOp
>(UnknownLoc::get(ctx
),
336 StringAttr::get(ctx
, typeName
));
337 return op
.getOutput();
340 return createPredicate(builder
, constraint
.getPredicate());
343 Value
createRegionConstraint(OpBuilder
&builder
, tblgen::Region constraint
) {
344 MLIRContext
*ctx
= builder
.getContext();
345 const Record
&predRec
= constraint
.getDef();
347 if (predRec
.getName() == "AnyRegion") {
348 ValueRange entryBlockArgs
= {};
350 builder
.create
<irdl::RegionOp
>(UnknownLoc::get(ctx
), entryBlockArgs
);
351 return op
.getResult();
354 if (predRec
.isSubClassOf("SizedRegion")) {
355 ValueRange entryBlockArgs
= {};
356 auto ty
= IntegerType::get(ctx
, 32);
357 auto op
= builder
.create
<irdl::RegionOp
>(
358 UnknownLoc::get(ctx
), entryBlockArgs
,
359 IntegerAttr::get(ty
, predRec
.getValueAsInt("blocks")));
360 return op
.getResult();
363 return createPredicate(builder
, constraint
.getPredicate());
366 /// Returns the name of the operation without the dialect prefix.
367 static StringRef
getOperatorName(tblgen::Operator
&tblgenOp
) {
368 StringRef opName
= tblgenOp
.getDef().getValueAsString("opName");
372 /// Returns the name of the type without the dialect prefix.
373 static StringRef
getTypeName(tblgen::TypeDef
&tblgenType
) {
374 StringRef opName
= tblgenType
.getDef()->getValueAsString("mnemonic");
378 /// Returns the name of the attr without the dialect prefix.
379 static StringRef
getAttrName(tblgen::AttrDef
&tblgenType
) {
380 StringRef opName
= tblgenType
.getDef()->getValueAsString("mnemonic");
384 /// Extract an operation to IRDL.
385 irdl::OperationOp
createIRDLOperation(OpBuilder
&builder
,
386 tblgen::Operator
&tblgenOp
) {
387 MLIRContext
*ctx
= builder
.getContext();
388 StringRef opName
= getOperatorName(tblgenOp
);
390 irdl::OperationOp op
= builder
.create
<irdl::OperationOp
>(
391 UnknownLoc::get(ctx
), StringAttr::get(ctx
, opName
));
393 // Add the block in the region.
394 Block
&opBlock
= op
.getBody().emplaceBlock();
395 OpBuilder consBuilder
= OpBuilder::atBlockBegin(&opBlock
);
397 auto getValues
= [&](tblgen::Operator::const_value_range namedCons
) {
398 SmallVector
<Value
> operands
;
399 SmallVector
<irdl::VariadicityAttr
> variadicity
;
400 for (const NamedTypeConstraint
&namedCons
: namedCons
) {
401 auto operand
= createTypeConstraint(consBuilder
, namedCons
.constraint
);
402 operands
.push_back(operand
);
404 irdl::VariadicityAttr var
;
405 if (namedCons
.isOptional())
406 var
= consBuilder
.getAttr
<irdl::VariadicityAttr
>(
407 irdl::Variadicity::optional
);
408 else if (namedCons
.isVariadic())
409 var
= consBuilder
.getAttr
<irdl::VariadicityAttr
>(
410 irdl::Variadicity::variadic
);
412 var
= consBuilder
.getAttr
<irdl::VariadicityAttr
>(
413 irdl::Variadicity::single
);
415 variadicity
.push_back(var
);
417 return std::make_tuple(operands
, variadicity
);
420 auto [operands
, operandVariadicity
] = getValues(tblgenOp
.getOperands());
421 auto [results
, resultVariadicity
] = getValues(tblgenOp
.getResults());
423 SmallVector
<Value
> attributes
;
424 SmallVector
<Attribute
> attrNames
;
425 for (auto namedAttr
: tblgenOp
.getAttributes()) {
426 if (namedAttr
.attr
.isOptional())
428 attributes
.push_back(createAttrConstraint(consBuilder
, namedAttr
.attr
));
429 attrNames
.push_back(StringAttr::get(ctx
, namedAttr
.name
));
432 SmallVector
<Value
> regions
;
433 for (auto namedRegion
: tblgenOp
.getRegions()) {
435 createRegionConstraint(consBuilder
, namedRegion
.constraint
));
438 // Create the operands and results operations.
439 if (!operands
.empty())
440 consBuilder
.create
<irdl::OperandsOp
>(UnknownLoc::get(ctx
), operands
,
442 if (!results
.empty())
443 consBuilder
.create
<irdl::ResultsOp
>(UnknownLoc::get(ctx
), results
,
445 if (!attributes
.empty())
446 consBuilder
.create
<irdl::AttributesOp
>(UnknownLoc::get(ctx
), attributes
,
447 ArrayAttr::get(ctx
, attrNames
));
448 if (!regions
.empty())
449 consBuilder
.create
<irdl::RegionsOp
>(UnknownLoc::get(ctx
), regions
);
454 irdl::TypeOp
createIRDLType(OpBuilder
&builder
, tblgen::TypeDef
&tblgenType
) {
455 MLIRContext
*ctx
= builder
.getContext();
456 StringRef typeName
= getTypeName(tblgenType
);
457 std::string combined
= ("!" + typeName
).str();
459 irdl::TypeOp op
= builder
.create
<irdl::TypeOp
>(
460 UnknownLoc::get(ctx
), StringAttr::get(ctx
, combined
));
462 op
.getBody().emplaceBlock();
467 irdl::AttributeOp
createIRDLAttr(OpBuilder
&builder
,
468 tblgen::AttrDef
&tblgenAttr
) {
469 MLIRContext
*ctx
= builder
.getContext();
470 StringRef attrName
= getAttrName(tblgenAttr
);
471 std::string combined
= ("#" + attrName
).str();
473 irdl::AttributeOp op
= builder
.create
<irdl::AttributeOp
>(
474 UnknownLoc::get(ctx
), StringAttr::get(ctx
, combined
));
476 op
.getBody().emplaceBlock();
481 static irdl::DialectOp
createIRDLDialect(OpBuilder
&builder
) {
482 MLIRContext
*ctx
= builder
.getContext();
483 return builder
.create
<irdl::DialectOp
>(UnknownLoc::get(ctx
),
484 StringAttr::get(ctx
, selectedDialect
));
487 static bool emitDialectIRDLDefs(const RecordKeeper
&records
, raw_ostream
&os
) {
490 ctx
.getOrLoadDialect
<irdl::IRDLDialect
>();
491 OpBuilder
builder(&ctx
);
493 // Create a module op and set it as the insertion point.
494 OwningOpRef
<ModuleOp
> module
=
495 builder
.create
<ModuleOp
>(UnknownLoc::get(&ctx
));
496 builder
= builder
.atBlockBegin(module
->getBody());
497 // Create the dialect and insert it.
498 irdl::DialectOp dialect
= createIRDLDialect(builder
);
499 // Set insertion point to start of DialectOp.
500 builder
= builder
.atBlockBegin(&dialect
.getBody().emplaceBlock());
502 for (const Record
*type
:
503 records
.getAllDerivedDefinitionsIfDefined("TypeDef")) {
504 tblgen::TypeDef
tblgenType(type
);
505 if (tblgenType
.getDialect().getName() != selectedDialect
)
507 createIRDLType(builder
, tblgenType
);
510 for (const Record
*attr
:
511 records
.getAllDerivedDefinitionsIfDefined("AttrDef")) {
512 tblgen::AttrDef
tblgenAttr(attr
);
513 if (tblgenAttr
.getDialect().getName() != selectedDialect
)
515 createIRDLAttr(builder
, tblgenAttr
);
518 for (const Record
*def
: records
.getAllDerivedDefinitionsIfDefined("Op")) {
519 tblgen::Operator
tblgenOp(def
);
520 if (tblgenOp
.getDialectName() != selectedDialect
)
523 createIRDLOperation(builder
, tblgenOp
);
532 static mlir::GenRegistration
533 genOpDefs("gen-dialect-irdl-defs", "Generate IRDL dialect definitions",
534 [](const RecordKeeper
&records
, raw_ostream
&os
) {
535 return emitDialectIRDLDefs(records
, os
);