1 //===- LLVMIRConversionGen.cpp - MLIR LLVM IR builder generator -----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file uses tablegen definitions of the LLVM IR Dialect operations to
10 // generate the code building the LLVM IR from it.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Support/LogicalResult.h"
15 #include "mlir/TableGen/Argument.h"
16 #include "mlir/TableGen/Attribute.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Operator.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/Support/FormatVariadic.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "llvm/TableGen/Error.h"
26 #include "llvm/TableGen/Record.h"
27 #include "llvm/TableGen/TableGenBackend.h"
32 static LogicalResult
emitError(const Record
&record
, const Twine
&message
) {
33 PrintError(&record
, message
);
38 // Helper structure to return a position of the substring in a string.
43 // Take a substring identified by this location in the given string.
44 StringRef
in(StringRef str
) const { return str
.substr(pos
, length
); }
46 // A location is invalid if its position is outside the string.
47 explicit operator bool() { return pos
!= std::string::npos
; }
51 // Find the next TableGen variable in the given pattern. These variables start
52 // with a `$` character and can contain alphanumeric characters or underscores.
53 // Return the position of the variable in the pattern and its length, including
54 // the `$` character. The escape syntax `$$` is also detected and returned.
55 static StringLoc
findNextVariable(StringRef str
) {
56 size_t startPos
= str
.find('$');
57 if (startPos
== std::string::npos
)
60 // If we see "$$", return immediately.
61 if (startPos
!= str
.size() - 1 && str
[startPos
+ 1] == '$')
64 // Otherwise, the symbol spans until the first character that is not
65 // alphanumeric or '_'.
66 size_t endPos
= str
.find_if_not([](char c
) { return isAlnum(c
) || c
== '_'; },
68 if (endPos
== std::string::npos
)
71 return {startPos
, endPos
- startPos
};
74 // Check if `name` is a variadic operand of `op`. Search all operands since the
75 // MLIR and LLVM IR operand order may differ and only for the latter the
76 // variadic operand is guaranteed to be at the end of the operands list.
77 static bool isVariadicOperandName(const tblgen::Operator
&op
, StringRef name
) {
78 for (int i
= 0, e
= op
.getNumOperands(); i
< e
; ++i
)
79 if (op
.getOperand(i
).name
== name
)
80 return op
.getOperand(i
).isVariadic();
84 // Check if `name` is a known name of a result of `op`.
85 static bool isResultName(const tblgen::Operator
&op
, StringRef name
) {
86 for (int i
= 0, e
= op
.getNumResults(); i
< e
; ++i
)
87 if (op
.getResultName(i
) == name
)
92 // Check if `name` is a known name of an attribute of `op`.
93 static bool isAttributeName(const tblgen::Operator
&op
, StringRef name
) {
96 [name
](const tblgen::NamedAttribute
&attr
) { return attr
.name
== name
; });
99 // Check if `name` is a known name of an operand of `op`.
100 static bool isOperandName(const tblgen::Operator
&op
, StringRef name
) {
101 for (int i
= 0, e
= op
.getNumOperands(); i
< e
; ++i
)
102 if (op
.getOperand(i
).name
== name
)
107 // Return the `op` argument index of the argument with the given `name`.
108 static FailureOr
<int> getArgumentIndex(const tblgen::Operator
&op
,
110 for (int i
= 0, e
= op
.getNumArgs(); i
!= e
; ++i
)
111 if (op
.getArgName(i
) == name
)
116 // Emit to `os` the operator-name driven check and the call to LLVM IRBuilder
117 // for one definition of an LLVM IR Dialect operation.
118 static LogicalResult
emitOneBuilder(const Record
&record
, raw_ostream
&os
) {
119 auto op
= tblgen::Operator(record
);
121 if (!record
.getValue("llvmBuilder"))
122 return emitError(record
, "expected 'llvmBuilder' field");
124 // Return early if there is no builder specified.
125 StringRef builderStrRef
= record
.getValueAsString("llvmBuilder");
126 if (builderStrRef
.empty())
129 // Progressively create the builder string by replacing $-variables with
130 // value lookups. Keep only the not-yet-traversed part of the builder pattern
131 // to avoid re-traversing the string multiple times.
133 llvm::raw_string_ostream
bs(builder
);
134 while (StringLoc loc
= findNextVariable(builderStrRef
)) {
135 auto name
= loc
.in(builderStrRef
).drop_front();
136 auto getterName
= op
.getGetterName(name
);
137 // First, insert the non-matched part as is.
138 bs
<< builderStrRef
.substr(0, loc
.pos
);
139 // Then, rewrite the name based on its kind.
140 bool isVariadicOperand
= isVariadicOperandName(op
, name
);
141 if (isOperandName(op
, name
)) {
144 ? formatv("moduleTranslation.lookupValues(op.{0}())", getterName
)
145 : formatv("moduleTranslation.lookupValue(op.{0}())", getterName
);
147 } else if (isAttributeName(op
, name
)) {
148 bs
<< formatv("op.{0}()", getterName
);
149 } else if (isResultName(op
, name
)) {
150 bs
<< formatv("moduleTranslation.mapValue(op.{0}())", getterName
);
151 } else if (name
== "_resultType") {
152 bs
<< "moduleTranslation.convertType(op.getResult().getType())";
153 } else if (name
== "_hasResult") {
154 bs
<< "opInst.getNumResults() == 1";
155 } else if (name
== "_location") {
156 bs
<< "opInst.getLoc()";
157 } else if (name
== "_numOperands") {
158 bs
<< "opInst.getNumOperands()";
159 } else if (name
== "$") {
163 record
, "expected keyword, argument, or result, but got " + name
);
165 // Finally, only keep the untraversed part of the string.
166 builderStrRef
= builderStrRef
.substr(loc
.pos
+ loc
.length
);
169 // Output the check and the rewritten builder string.
170 os
<< "if (auto op = dyn_cast<" << op
.getQualCppClassName()
172 os
<< bs
.str() << builderStrRef
<< "\n";
173 os
<< " return success();\n";
179 // Emit all builders. Returns false on success because of the generator
180 // registration requirements.
181 static bool emitBuilders(const RecordKeeper
&recordKeeper
, raw_ostream
&os
) {
182 for (const Record
*def
:
183 recordKeeper
.getAllDerivedDefinitions("LLVM_OpBase")) {
184 if (failed(emitOneBuilder(*def
, os
)))
190 using ConditionFn
= mlir::function_ref
<llvm::Twine(const Record
&record
)>;
192 // Emit a conditional call to the MLIR builder of the LLVM dialect operation to
193 // build for the given LLVM IR instruction. A condition function `conditionFn`
194 // emits a check to verify the opcode or intrinsic identifier of the LLVM IR
195 // instruction matches the LLVM dialect operation to build.
196 static LogicalResult
emitOneMLIRBuilder(const Record
&record
, raw_ostream
&os
,
197 ConditionFn conditionFn
) {
198 auto op
= tblgen::Operator(record
);
200 if (!record
.getValue("mlirBuilder"))
201 return emitError(record
, "expected 'mlirBuilder' field");
203 // Return early if there is no builder specified.
204 StringRef builderStrRef
= record
.getValueAsString("mlirBuilder");
205 if (builderStrRef
.empty())
208 // Access the argument index array that maps argument indices to LLVM IR
209 // operand indices. If the operation defines no custom mapping, set the array
210 // to the identity permutation.
211 std::vector
<int64_t> llvmArgIndices
=
212 record
.getValueAsListOfInts("llvmArgIndices");
213 if (llvmArgIndices
.empty())
214 append_range(llvmArgIndices
, seq
<int64_t>(0, op
.getNumArgs()));
215 if (llvmArgIndices
.size() != static_cast<size_t>(op
.getNumArgs())) {
218 "expected 'llvmArgIndices' size to match the number of arguments");
221 // Progressively create the builder string by replacing $-variables. Keep only
222 // the not-yet-traversed part of the builder pattern to avoid re-traversing
223 // the string multiple times. Additionally, emit an argument string
224 // immediately before the builder string. This argument string converts all
225 // operands used by the builder to MLIR values and returns failure if one of
226 // the conversions fails.
227 std::string arguments
, builder
;
228 llvm::raw_string_ostream
as(arguments
), bs(builder
);
229 while (StringLoc loc
= findNextVariable(builderStrRef
)) {
230 auto name
= loc
.in(builderStrRef
).drop_front();
231 // First, insert the non-matched part as is.
232 bs
<< builderStrRef
.substr(0, loc
.pos
);
233 // Then, rewrite the name based on its kind.
234 FailureOr
<int> argIndex
= getArgumentIndex(op
, name
);
235 if (succeeded(argIndex
)) {
236 // Access the LLVM IR operand that maps to the given argument index using
237 // the provided argument indices mapping.
238 int64_t idx
= llvmArgIndices
[*argIndex
];
241 record
, "expected non-negative operand index for argument " + name
);
243 if (isAttributeName(op
, name
)) {
244 bs
<< formatv("llvmOperands[{0}]", idx
);
246 if (isVariadicOperandName(op
, name
)) {
248 "FailureOr<SmallVector<Value>> _llvmir_gen_operand_{0} = "
249 "moduleImport.convertValues(llvmOperands.drop_front({1}));\n",
252 as
<< formatv("FailureOr<Value> _llvmir_gen_operand_{0} = "
253 "moduleImport.convertValue(llvmOperands[{1}]);\n",
256 as
<< formatv("if (failed(_llvmir_gen_operand_{0}))\n"
257 " return failure();\n",
259 bs
<< formatv("*_llvmir_gen_operand_{0}", name
);
261 } else if (isResultName(op
, name
)) {
262 if (op
.getNumResults() != 1)
263 return emitError(record
, "expected op to have one result");
264 bs
<< "moduleImport.mapValue(inst)";
265 } else if (name
== "_op") {
266 bs
<< "moduleImport.mapNoResultOp(inst)";
267 } else if (name
== "_int_attr") {
268 bs
<< "moduleImport.matchIntegerAttr";
269 } else if (name
== "_float_attr") {
270 bs
<< "moduleImport.matchFloatAttr";
271 } else if (name
== "_var_attr") {
272 bs
<< "moduleImport.matchLocalVariableAttr";
273 } else if (name
== "_resultType") {
274 bs
<< "moduleImport.convertType(inst->getType())";
275 } else if (name
== "_location") {
276 bs
<< "moduleImport.translateLoc(inst->getDebugLoc())";
277 } else if (name
== "_builder") {
279 } else if (name
== "_qualCppClassName") {
280 bs
<< op
.getQualCppClassName();
281 } else if (name
== "$") {
285 record
, "expected keyword, argument, or result, but got " + name
);
287 // Finally, only keep the untraversed part of the string.
288 builderStrRef
= builderStrRef
.substr(loc
.pos
+ loc
.length
);
291 // Output the check, the argument conversion, and the builder string.
292 os
<< "if (" << conditionFn(record
) << ") {\n";
293 os
<< as
.str() << "\n";
294 os
<< bs
.str() << builderStrRef
<< "\n";
295 os
<< " return success();\n";
301 // Emit all intrinsic MLIR builders. Returns false on success because of the
302 // generator registration requirements.
303 static bool emitIntrMLIRBuilders(const RecordKeeper
&recordKeeper
,
305 // Emit condition to check if "llvmEnumName" matches the intrinsic id.
306 auto emitIntrCond
= [](const Record
&record
) {
307 return "intrinsicID == llvm::Intrinsic::" +
308 record
.getValueAsString("llvmEnumName");
310 for (const Record
*def
:
311 recordKeeper
.getAllDerivedDefinitions("LLVM_IntrOpBase")) {
312 if (failed(emitOneMLIRBuilder(*def
, os
, emitIntrCond
)))
318 // Emit all op builders. Returns false on success because of the
319 // generator registration requirements.
320 static bool emitOpMLIRBuilders(const RecordKeeper
&recordKeeper
,
322 // Emit condition to check if "llvmInstName" matches the instruction opcode.
323 auto emitOpcodeCond
= [](const Record
&record
) {
324 return "inst->getOpcode() == llvm::Instruction::" +
325 record
.getValueAsString("llvmInstName");
327 for (const Record
*def
:
328 recordKeeper
.getAllDerivedDefinitions("LLVM_OpBase")) {
329 if (failed(emitOneMLIRBuilder(*def
, os
, emitOpcodeCond
)))
336 // Wrapper class around a Tablegen definition of an LLVM enum attribute case.
337 class LLVMEnumAttrCase
: public tblgen::EnumAttrCase
{
339 using tblgen::EnumAttrCase::EnumAttrCase
;
341 // Constructs a case from a non LLVM-specific enum attribute case.
342 explicit LLVMEnumAttrCase(const tblgen::EnumAttrCase
&other
)
343 : tblgen::EnumAttrCase(&other
.getDef()) {}
345 // Returns the C++ enumerant for the LLVM API.
346 StringRef
getLLVMEnumerant() const {
347 return def
->getValueAsString("llvmEnumerant");
351 // Wrapper class around a Tablegen definition of an LLVM enum attribute.
352 class LLVMEnumAttr
: public tblgen::EnumAttr
{
354 using tblgen::EnumAttr::EnumAttr
;
356 // Returns the C++ enum name for the LLVM API.
357 StringRef
getLLVMClassName() const {
358 return def
->getValueAsString("llvmClassName");
361 // Returns all associated cases viewed as LLVM-specific enum cases.
362 std::vector
<LLVMEnumAttrCase
> getAllCases() const {
363 std::vector
<LLVMEnumAttrCase
> cases
;
365 for (auto &c
: tblgen::EnumAttr::getAllCases())
366 cases
.emplace_back(c
);
371 std::vector
<LLVMEnumAttrCase
> getAllUnsupportedCases() const {
372 const auto *inits
= def
->getValueAsListInit("unsupported");
374 std::vector
<LLVMEnumAttrCase
> cases
;
375 cases
.reserve(inits
->size());
377 for (const llvm::Init
*init
: *inits
)
378 cases
.emplace_back(cast
<llvm::DefInit
>(init
));
384 // Wrapper class around a Tablegen definition of a C-style LLVM enum attribute.
385 class LLVMCEnumAttr
: public tblgen::EnumAttr
{
387 using tblgen::EnumAttr::EnumAttr
;
389 // Returns the C++ enum name for the LLVM API.
390 StringRef
getLLVMClassName() const {
391 return def
->getValueAsString("llvmClassName");
394 // Returns all associated cases viewed as LLVM-specific enum cases.
395 std::vector
<LLVMEnumAttrCase
> getAllCases() const {
396 std::vector
<LLVMEnumAttrCase
> cases
;
398 for (auto &c
: tblgen::EnumAttr::getAllCases())
399 cases
.emplace_back(c
);
406 // Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
407 // switch-based logic to convert from the MLIR LLVM dialect enum attribute case
408 // (Enum) to the corresponding LLVM API enumerant
409 static void emitOneEnumToConversion(const llvm::Record
*record
,
411 LLVMEnumAttr
enumAttr(record
);
412 StringRef llvmClass
= enumAttr
.getLLVMClassName();
413 StringRef cppClassName
= enumAttr
.getEnumClassName();
414 StringRef cppNamespace
= enumAttr
.getCppNamespace();
416 // Emit the function converting the enum attribute to its LLVM counterpart.
418 "static LLVM_ATTRIBUTE_UNUSED {0} convert{1}ToLLVM({2}::{1} value) {{\n",
419 llvmClass
, cppClassName
, cppNamespace
);
420 os
<< " switch (value) {\n";
422 for (const auto &enumerant
: enumAttr
.getAllCases()) {
423 StringRef llvmEnumerant
= enumerant
.getLLVMEnumerant();
424 StringRef cppEnumerant
= enumerant
.getSymbol();
425 os
<< formatv(" case {0}::{1}::{2}:\n", cppNamespace
, cppClassName
,
427 os
<< formatv(" return {0}::{1};\n", llvmClass
, llvmEnumerant
);
431 os
<< formatv(" llvm_unreachable(\"unknown {0} type\");\n",
432 enumAttr
.getEnumClassName());
436 // Emits conversion function "int64_t convertEnumToLLVM(Enum)" containing
437 // switch-based logic to convert from the MLIR LLVM dialect enum attribute case
438 // (Enum) to the integer value of the corresponding LLVM API C-style enumerant.
439 static void emitOneCEnumToConversion(const llvm::Record
*record
,
441 LLVMCEnumAttr
enumAttr(record
);
442 StringRef llvmClass
= enumAttr
.getLLVMClassName();
443 StringRef cppClassName
= enumAttr
.getEnumClassName();
444 StringRef cppNamespace
= enumAttr
.getCppNamespace();
446 // Emit the function converting the enum attribute to its LLVM counterpart.
447 os
<< formatv("static LLVM_ATTRIBUTE_UNUSED int64_t "
448 "convert{0}ToLLVM({1}::{0} value) {{\n",
449 cppClassName
, cppNamespace
);
450 os
<< " switch (value) {\n";
452 for (const auto &enumerant
: enumAttr
.getAllCases()) {
453 StringRef llvmEnumerant
= enumerant
.getLLVMEnumerant();
454 StringRef cppEnumerant
= enumerant
.getSymbol();
455 os
<< formatv(" case {0}::{1}::{2}:\n", cppNamespace
, cppClassName
,
457 os
<< formatv(" return static_cast<int64_t>({0}::{1});\n", llvmClass
,
462 os
<< formatv(" llvm_unreachable(\"unknown {0} type\");\n",
463 enumAttr
.getEnumClassName());
467 // Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and
468 // containing switch-based logic to convert from the LLVM API enumerant to MLIR
469 // LLVM dialect enum attribute (Enum).
470 static void emitOneEnumFromConversion(const llvm::Record
*record
,
472 LLVMEnumAttr
enumAttr(record
);
473 StringRef llvmClass
= enumAttr
.getLLVMClassName();
474 StringRef cppClassName
= enumAttr
.getEnumClassName();
475 StringRef cppNamespace
= enumAttr
.getCppNamespace();
477 // Emit the function converting the enum attribute from its LLVM counterpart.
478 os
<< formatv("inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM({2} "
480 cppNamespace
, cppClassName
, llvmClass
);
481 os
<< " switch (value) {\n";
483 for (const auto &enumerant
: enumAttr
.getAllCases()) {
484 StringRef llvmEnumerant
= enumerant
.getLLVMEnumerant();
485 StringRef cppEnumerant
= enumerant
.getSymbol();
486 os
<< formatv(" case {0}::{1}:\n", llvmClass
, llvmEnumerant
);
487 os
<< formatv(" return {0}::{1}::{2};\n", cppNamespace
, cppClassName
,
490 for (const auto &enumerant
: enumAttr
.getAllUnsupportedCases()) {
491 StringRef llvmEnumerant
= enumerant
.getLLVMEnumerant();
492 os
<< formatv(" case {0}::{1}:\n", llvmClass
, llvmEnumerant
);
493 os
<< formatv(" llvm_unreachable(\"unsupported case {0}::{1}\");\n",
494 enumAttr
.getLLVMClassName(), llvmEnumerant
);
498 os
<< formatv(" llvm_unreachable(\"unknown {0} type\");",
499 enumAttr
.getLLVMClassName());
503 // Emits conversion function "Enum convertEnumFromLLVM(int64_t)" containing
504 // switch-based logic to convert from the integer value of an LLVM API C-style
505 // enumerant to the MLIR LLVM dialect enum attribute (Enum).
506 static void emitOneCEnumFromConversion(const llvm::Record
*record
,
508 LLVMCEnumAttr
enumAttr(record
);
509 StringRef llvmClass
= enumAttr
.getLLVMClassName();
510 StringRef cppClassName
= enumAttr
.getEnumClassName();
511 StringRef cppNamespace
= enumAttr
.getCppNamespace();
513 // Emit the function converting the enum attribute from its LLVM counterpart.
515 "inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM(int64_t "
517 cppNamespace
, cppClassName
, llvmClass
);
518 os
<< " switch (value) {\n";
520 for (const auto &enumerant
: enumAttr
.getAllCases()) {
521 StringRef llvmEnumerant
= enumerant
.getLLVMEnumerant();
522 StringRef cppEnumerant
= enumerant
.getSymbol();
523 os
<< formatv(" case static_cast<int64_t>({0}::{1}):\n", llvmClass
,
525 os
<< formatv(" return {0}::{1}::{2};\n", cppNamespace
, cppClassName
,
530 os
<< formatv(" llvm_unreachable(\"unknown {0} type\");",
531 enumAttr
.getLLVMClassName());
535 // Emits conversion functions between MLIR enum attribute case and corresponding
536 // LLVM API enumerants for all registered LLVM dialect enum attributes.
537 template <bool ConvertTo
>
538 static bool emitEnumConversionDefs(const RecordKeeper
&recordKeeper
,
540 for (const Record
*def
:
541 recordKeeper
.getAllDerivedDefinitions("LLVM_EnumAttr"))
543 emitOneEnumToConversion(def
, os
);
545 emitOneEnumFromConversion(def
, os
);
547 for (const Record
*def
:
548 recordKeeper
.getAllDerivedDefinitions("LLVM_CEnumAttr"))
550 emitOneCEnumToConversion(def
, os
);
552 emitOneCEnumFromConversion(def
, os
);
557 static void emitOneIntrinsic(const Record
&record
, raw_ostream
&os
) {
558 auto op
= tblgen::Operator(record
);
559 os
<< "llvm::Intrinsic::" << record
.getValueAsString("llvmEnumName") << ",\n";
562 // Emit the list of LLVM IR intrinsics identifiers that are convertible to a
563 // matching MLIR LLVM dialect intrinsic operation.
564 static bool emitConvertibleIntrinsics(const RecordKeeper
&recordKeeper
,
566 for (const Record
*def
:
567 recordKeeper
.getAllDerivedDefinitions("LLVM_IntrOpBase"))
568 emitOneIntrinsic(*def
, os
);
573 static mlir::GenRegistration
574 genLLVMIRConversions("gen-llvmir-conversions",
575 "Generate LLVM IR conversions", emitBuilders
);
577 static mlir::GenRegistration
genOpFromLLVMIRConversions(
578 "gen-op-from-llvmir-conversions",
579 "Generate conversions of operations from LLVM IR", emitOpMLIRBuilders
);
581 static mlir::GenRegistration
genIntrFromLLVMIRConversions(
582 "gen-intr-from-llvmir-conversions",
583 "Generate conversions of intrinsics from LLVM IR", emitIntrMLIRBuilders
);
585 static mlir::GenRegistration
586 genEnumToLLVMConversion("gen-enum-to-llvmir-conversions",
587 "Generate conversions of EnumAttrs to LLVM IR",
588 emitEnumConversionDefs
</*ConvertTo=*/true>);
590 static mlir::GenRegistration
591 genEnumFromLLVMConversion("gen-enum-from-llvmir-conversions",
592 "Generate conversions of EnumAttrs from LLVM IR",
593 emitEnumConversionDefs
</*ConvertTo=*/false>);
595 static mlir::GenRegistration
genConvertibleLLVMIRIntrinsics(
596 "gen-convertible-llvmir-intrinsics",
597 "Generate list of convertible LLVM IR intrinsics",
598 emitConvertibleIntrinsics
);