[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / tools / tblgen-to-irdl / OpDefinitionsGen.cpp
blob a763105fa0fd6a33543ad8d38d3b7cce1a94a044
1 //===- OpDefinitionsGen.cpp - IRDL op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate IRDL
10 // definitions for ops.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/IRDL/IR/IRDL.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/TableGen/AttrOrTypeDef.h"
22 #include "mlir/TableGen/GenInfo.h"
23 #include "mlir/TableGen/GenNameParser.h"
24 #include "mlir/TableGen/Interfaces.h"
25 #include "mlir/TableGen/Operator.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/InitLLVM.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include "llvm/TableGen/Main.h"
30 #include "llvm/TableGen/Record.h"
31 #include "llvm/TableGen/TableGenBackend.h"
33 using namespace llvm;
34 using namespace mlir;
35 using tblgen::NamedTypeConstraint;
37 static llvm::cl::OptionCategory dialectGenCat("Options for -gen-irdl-dialect");
38 llvm::cl::opt<std::string>
39 selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
40 llvm::cl::cat(dialectGenCat), llvm::cl::Required);
42 Value createPredicate(OpBuilder &builder, tblgen::Pred pred) {
43 MLIRContext *ctx = builder.getContext();
45 if (pred.isCombined()) {
46 auto combiner = pred.getDef().getValueAsDef("kind")->getName();
47 if (combiner == "PredCombinerAnd" || combiner == "PredCombinerOr") {
48 std::vector<Value> constraints;
49 for (auto *child : pred.getDef().getValueAsListOfDefs("children")) {
50 constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
52 if (combiner == "PredCombinerAnd") {
53 auto op =
54 builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
55 return op.getOutput();
57 auto op =
58 builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
59 return op.getOutput();
63 std::string condition = pred.getCondition();
64 // Build a CPredOp to match the C constraint built.
65 irdl::CPredOp op = builder.create<irdl::CPredOp>(
66 UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
67 return op;
70 Value typeToConstraint(OpBuilder &builder, Type type) {
71 MLIRContext *ctx = builder.getContext();
72 auto op =
73 builder.create<irdl::IsOp>(UnknownLoc::get(ctx), TypeAttr::get(type));
74 return op.getOutput();
77 Value baseToConstraint(OpBuilder &builder, StringRef baseClass) {
78 MLIRContext *ctx = builder.getContext();
79 auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
80 StringAttr::get(ctx, baseClass));
81 return op.getOutput();
84 std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
85 if (predRec.isSubClassOf("I")) {
86 auto width = predRec.getValueAsInt("bitwidth");
87 return IntegerType::get(ctx, width, IntegerType::Signless);
90 if (predRec.isSubClassOf("SI")) {
91 auto width = predRec.getValueAsInt("bitwidth");
92 return IntegerType::get(ctx, width, IntegerType::Signed);
95 if (predRec.isSubClassOf("UI")) {
96 auto width = predRec.getValueAsInt("bitwidth");
97 return IntegerType::get(ctx, width, IntegerType::Unsigned);
100 // Index type
101 if (predRec.getName() == "Index") {
102 return IndexType::get(ctx);
105 // Float types
106 if (predRec.isSubClassOf("F")) {
107 auto width = predRec.getValueAsInt("bitwidth");
108 switch (width) {
109 case 16:
110 return FloatType::getF16(ctx);
111 case 32:
112 return FloatType::getF32(ctx);
113 case 64:
114 return FloatType::getF64(ctx);
115 case 80:
116 return FloatType::getF80(ctx);
117 case 128:
118 return FloatType::getF128(ctx);
122 if (predRec.getName() == "NoneType") {
123 return NoneType::get(ctx);
126 if (predRec.getName() == "BF16") {
127 return FloatType::getBF16(ctx);
130 if (predRec.getName() == "TF32") {
131 return FloatType::getTF32(ctx);
134 if (predRec.getName() == "F8E4M3FN") {
135 return FloatType::getFloat8E4M3FN(ctx);
138 if (predRec.getName() == "F8E5M2") {
139 return FloatType::getFloat8E5M2(ctx);
142 if (predRec.getName() == "F8E4M3") {
143 return FloatType::getFloat8E4M3(ctx);
146 if (predRec.getName() == "F8E4M3FNUZ") {
147 return FloatType::getFloat8E4M3FNUZ(ctx);
150 if (predRec.getName() == "F8E4M3B11FNUZ") {
151 return FloatType::getFloat8E4M3B11FNUZ(ctx);
154 if (predRec.getName() == "F8E5M2FNUZ") {
155 return FloatType::getFloat8E5M2FNUZ(ctx);
158 if (predRec.getName() == "F8E3M4") {
159 return FloatType::getFloat8E3M4(ctx);
162 if (predRec.isSubClassOf("Complex")) {
163 const Record *elementRec = predRec.getValueAsDef("elementType");
164 auto elementType = recordToType(ctx, *elementRec);
165 if (elementType.has_value()) {
166 return ComplexType::get(elementType.value());
170 return std::nullopt;
173 Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
174 MLIRContext *ctx = builder.getContext();
175 const Record &predRec = constraint.getDef();
177 if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
178 return createTypeConstraint(builder, predRec.getValueAsDef("baseType"));
180 if (predRec.getName() == "AnyType") {
181 auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
182 return op.getOutput();
185 if (predRec.isSubClassOf("TypeDef")) {
186 auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name");
187 if (dialect == selectedDialect) {
188 std::string combined = ("!" + predRec.getValueAsString("mnemonic")).str();
189 SmallVector<FlatSymbolRefAttr> nested = {
190 SymbolRefAttr::get(ctx, combined)};
191 auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
192 auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol);
193 return op.getOutput();
195 std::string typeName = ("!" + predRec.getValueAsString("typeName")).str();
196 auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
197 StringAttr::get(ctx, typeName));
198 return op.getOutput();
201 if (predRec.isSubClassOf("AnyTypeOf")) {
202 std::vector<Value> constraints;
203 for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
204 constraints.push_back(
205 createTypeConstraint(builder, tblgen::Constraint(child)));
207 auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
208 return op.getOutput();
211 if (predRec.isSubClassOf("AllOfType")) {
212 std::vector<Value> constraints;
213 for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
214 constraints.push_back(
215 createTypeConstraint(builder, tblgen::Constraint(child)));
217 auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
218 return op.getOutput();
221 // Integer types
222 if (predRec.getName() == "AnyInteger") {
223 auto op = builder.create<irdl::BaseOp>(
224 UnknownLoc::get(ctx), StringAttr::get(ctx, "!builtin.integer"));
225 return op.getOutput();
228 if (predRec.isSubClassOf("AnyI")) {
229 auto width = predRec.getValueAsInt("bitwidth");
230 std::vector<Value> types = {
231 typeToConstraint(builder,
232 IntegerType::get(ctx, width, IntegerType::Signless)),
233 typeToConstraint(builder,
234 IntegerType::get(ctx, width, IntegerType::Signed)),
235 typeToConstraint(builder,
236 IntegerType::get(ctx, width, IntegerType::Unsigned))};
237 auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), types);
238 return op.getOutput();
241 auto type = recordToType(ctx, predRec);
243 if (type.has_value()) {
244 return typeToConstraint(builder, type.value());
247 // Confined type
248 if (predRec.isSubClassOf("ConfinedType")) {
249 std::vector<Value> constraints;
250 constraints.push_back(createTypeConstraint(
251 builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
252 for (const Record *child : predRec.getValueAsListOfDefs("predicateList")) {
253 constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
255 auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
256 return op.getOutput();
259 return createPredicate(builder, constraint.getPredicate());
262 Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
263 MLIRContext *ctx = builder.getContext();
264 const Record &predRec = constraint.getDef();
266 if (predRec.isSubClassOf("DefaultValuedAttr") ||
267 predRec.isSubClassOf("DefaultValuedOptionalAttr") ||
268 predRec.isSubClassOf("OptionalAttr")) {
269 return createAttrConstraint(builder, predRec.getValueAsDef("baseAttr"));
272 if (predRec.isSubClassOf("ConfinedAttr")) {
273 std::vector<Value> constraints;
274 constraints.push_back(createAttrConstraint(
275 builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr"))));
276 for (const Record *child :
277 predRec.getValueAsListOfDefs("attrConstraints")) {
278 constraints.push_back(createPredicate(
279 builder, tblgen::Pred(child->getValueAsDef("predicate"))));
281 auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
282 return op.getOutput();
285 if (predRec.isSubClassOf("AnyAttrOf")) {
286 std::vector<Value> constraints;
287 for (const Record *child :
288 predRec.getValueAsListOfDefs("allowedAttributes")) {
289 constraints.push_back(
290 createAttrConstraint(builder, tblgen::Constraint(child)));
292 auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
293 return op.getOutput();
296 if (predRec.getName() == "AnyAttr") {
297 auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
298 return op.getOutput();
301 if (predRec.isSubClassOf("AnyIntegerAttrBase") ||
302 predRec.isSubClassOf("SignlessIntegerAttrBase") ||
303 predRec.isSubClassOf("SignedIntegerAttrBase") ||
304 predRec.isSubClassOf("UnsignedIntegerAttrBase") ||
305 predRec.isSubClassOf("BoolAttr")) {
306 return baseToConstraint(builder, "!builtin.integer");
309 if (predRec.isSubClassOf("FloatAttrBase")) {
310 return baseToConstraint(builder, "!builtin.float");
313 if (predRec.isSubClassOf("StringBasedAttr")) {
314 return baseToConstraint(builder, "!builtin.string");
317 if (predRec.getName() == "UnitAttr") {
318 auto op =
319 builder.create<irdl::IsOp>(UnknownLoc::get(ctx), UnitAttr::get(ctx));
320 return op.getOutput();
323 if (predRec.isSubClassOf("AttrDef")) {
324 auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name");
325 if (dialect == selectedDialect) {
326 std::string combined = ("#" + predRec.getValueAsString("mnemonic")).str();
327 SmallVector<FlatSymbolRefAttr> nested = {SymbolRefAttr::get(ctx, combined)
330 auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
331 auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol);
332 return op.getOutput();
334 std::string typeName = ("#" + predRec.getValueAsString("attrName")).str();
335 auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
336 StringAttr::get(ctx, typeName));
337 return op.getOutput();
340 return createPredicate(builder, constraint.getPredicate());
343 Value createRegionConstraint(OpBuilder &builder, tblgen::Region constraint) {
344 MLIRContext *ctx = builder.getContext();
345 const Record &predRec = constraint.getDef();
347 if (predRec.getName() == "AnyRegion") {
348 ValueRange entryBlockArgs = {};
349 auto op =
350 builder.create<irdl::RegionOp>(UnknownLoc::get(ctx), entryBlockArgs);
351 return op.getResult();
354 if (predRec.isSubClassOf("SizedRegion")) {
355 ValueRange entryBlockArgs = {};
356 auto ty = IntegerType::get(ctx, 32);
357 auto op = builder.create<irdl::RegionOp>(
358 UnknownLoc::get(ctx), entryBlockArgs,
359 IntegerAttr::get(ty, predRec.getValueAsInt("blocks")));
360 return op.getResult();
363 return createPredicate(builder, constraint.getPredicate());
366 /// Returns the name of the operation without the dialect prefix.
367 static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
368 StringRef opName = tblgenOp.getDef().getValueAsString("opName");
369 return opName;
372 /// Returns the name of the type without the dialect prefix.
373 static StringRef getTypeName(tblgen::TypeDef &tblgenType) {
374 StringRef opName = tblgenType.getDef()->getValueAsString("mnemonic");
375 return opName;
378 /// Returns the name of the attr without the dialect prefix.
379 static StringRef getAttrName(tblgen::AttrDef &tblgenType) {
380 StringRef opName = tblgenType.getDef()->getValueAsString("mnemonic");
381 return opName;
384 /// Extract an operation to IRDL.
385 irdl::OperationOp createIRDLOperation(OpBuilder &builder,
386 tblgen::Operator &tblgenOp) {
387 MLIRContext *ctx = builder.getContext();
388 StringRef opName = getOperatorName(tblgenOp);
390 irdl::OperationOp op = builder.create<irdl::OperationOp>(
391 UnknownLoc::get(ctx), StringAttr::get(ctx, opName));
393 // Add the block in the region.
394 Block &opBlock = op.getBody().emplaceBlock();
395 OpBuilder consBuilder = OpBuilder::atBlockBegin(&opBlock);
397 auto getValues = [&](tblgen::Operator::const_value_range namedCons) {
398 SmallVector<Value> operands;
399 SmallVector<irdl::VariadicityAttr> variadicity;
400 for (const NamedTypeConstraint &namedCons : namedCons) {
401 auto operand = createTypeConstraint(consBuilder, namedCons.constraint);
402 operands.push_back(operand);
404 irdl::VariadicityAttr var;
405 if (namedCons.isOptional())
406 var = consBuilder.getAttr<irdl::VariadicityAttr>(
407 irdl::Variadicity::optional);
408 else if (namedCons.isVariadic())
409 var = consBuilder.getAttr<irdl::VariadicityAttr>(
410 irdl::Variadicity::variadic);
411 else
412 var = consBuilder.getAttr<irdl::VariadicityAttr>(
413 irdl::Variadicity::single);
415 variadicity.push_back(var);
417 return std::make_tuple(operands, variadicity);
420 auto [operands, operandVariadicity] = getValues(tblgenOp.getOperands());
421 auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
423 SmallVector<Value> attributes;
424 SmallVector<Attribute> attrNames;
425 for (auto namedAttr : tblgenOp.getAttributes()) {
426 if (namedAttr.attr.isOptional())
427 continue;
428 attributes.push_back(createAttrConstraint(consBuilder, namedAttr.attr));
429 attrNames.push_back(StringAttr::get(ctx, namedAttr.name));
432 SmallVector<Value> regions;
433 for (auto namedRegion : tblgenOp.getRegions()) {
434 regions.push_back(
435 createRegionConstraint(consBuilder, namedRegion.constraint));
438 // Create the operands and results operations.
439 if (!operands.empty())
440 consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
441 operandVariadicity);
442 if (!results.empty())
443 consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
444 resultVariadicity);
445 if (!attributes.empty())
446 consBuilder.create<irdl::AttributesOp>(UnknownLoc::get(ctx), attributes,
447 ArrayAttr::get(ctx, attrNames));
448 if (!regions.empty())
449 consBuilder.create<irdl::RegionsOp>(UnknownLoc::get(ctx), regions);
451 return op;
454 irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) {
455 MLIRContext *ctx = builder.getContext();
456 StringRef typeName = getTypeName(tblgenType);
457 std::string combined = ("!" + typeName).str();
459 irdl::TypeOp op = builder.create<irdl::TypeOp>(
460 UnknownLoc::get(ctx), StringAttr::get(ctx, combined));
462 op.getBody().emplaceBlock();
464 return op;
467 irdl::AttributeOp createIRDLAttr(OpBuilder &builder,
468 tblgen::AttrDef &tblgenAttr) {
469 MLIRContext *ctx = builder.getContext();
470 StringRef attrName = getAttrName(tblgenAttr);
471 std::string combined = ("#" + attrName).str();
473 irdl::AttributeOp op = builder.create<irdl::AttributeOp>(
474 UnknownLoc::get(ctx), StringAttr::get(ctx, combined));
476 op.getBody().emplaceBlock();
478 return op;
481 static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
482 MLIRContext *ctx = builder.getContext();
483 return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
484 StringAttr::get(ctx, selectedDialect));
487 static bool emitDialectIRDLDefs(const RecordKeeper &records, raw_ostream &os) {
488 // Initialize.
489 MLIRContext ctx;
490 ctx.getOrLoadDialect<irdl::IRDLDialect>();
491 OpBuilder builder(&ctx);
493 // Create a module op and set it as the insertion point.
494 OwningOpRef<ModuleOp> module =
495 builder.create<ModuleOp>(UnknownLoc::get(&ctx));
496 builder = builder.atBlockBegin(module->getBody());
497 // Create the dialect and insert it.
498 irdl::DialectOp dialect = createIRDLDialect(builder);
499 // Set insertion point to start of DialectOp.
500 builder = builder.atBlockBegin(&dialect.getBody().emplaceBlock());
502 for (const Record *type :
503 records.getAllDerivedDefinitionsIfDefined("TypeDef")) {
504 tblgen::TypeDef tblgenType(type);
505 if (tblgenType.getDialect().getName() != selectedDialect)
506 continue;
507 createIRDLType(builder, tblgenType);
510 for (const Record *attr :
511 records.getAllDerivedDefinitionsIfDefined("AttrDef")) {
512 tblgen::AttrDef tblgenAttr(attr);
513 if (tblgenAttr.getDialect().getName() != selectedDialect)
514 continue;
515 createIRDLAttr(builder, tblgenAttr);
518 for (const Record *def : records.getAllDerivedDefinitionsIfDefined("Op")) {
519 tblgen::Operator tblgenOp(def);
520 if (tblgenOp.getDialectName() != selectedDialect)
521 continue;
523 createIRDLOperation(builder, tblgenOp);
526 // Print the module.
527 module->print(os);
529 return false;
532 static mlir::GenRegistration
533 genOpDefs("gen-dialect-irdl-defs", "Generate IRDL dialect definitions",
534 [](const RecordKeeper &records, raw_ostream &os) {
535 return emitDialectIRDLDefs(records, os);