1 //===- SPIRVSerializationGen.cpp - SPIR-V serialization utility generator -===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
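// SPIRVSerializationGen generates common utility functions for SPIR-V:
// availability interfaces, enum availability queries, and the
// (de)serialization routines of SPIR-V ops.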
12 //===----------------------------------------------------------------------===//
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Operator.h"
19 #include "llvm/ADT/Sequence.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringMap.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSet.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include "llvm/TableGen/Error.h"
28 #include "llvm/TableGen/Record.h"
29 #include "llvm/TableGen/TableGenBackend.h"
36 using llvm::raw_ostream
;
37 using llvm::raw_string_ostream
;
39 using llvm::RecordKeeper
;
40 using llvm::SmallVector
;
42 using llvm::StringMap
;
43 using llvm::StringRef
;
44 using mlir::tblgen::Attribute
;
45 using mlir::tblgen::EnumAttr
;
46 using mlir::tblgen::EnumAttrCase
;
47 using mlir::tblgen::NamedAttribute
;
48 using mlir::tblgen::NamedTypeConstraint
;
49 using mlir::tblgen::NamespaceEmitter
;
50 using mlir::tblgen::Operator
;
52 //===----------------------------------------------------------------------===//
53 // Availability Wrapper Class
54 //===----------------------------------------------------------------------===//
57 // Wrapper class with helper methods for accessing availability defined in
61 explicit Availability(const Record
*def
);
63 // Returns the name of the direct TableGen class for this availability
65 StringRef
getClass() const;
67 // Returns the generated C++ interface's class namespace.
68 StringRef
getInterfaceClassNamespace() const;
70 // Returns the generated C++ interface's class name.
71 StringRef
getInterfaceClassName() const;
73 // Returns the generated C++ interface's description.
74 StringRef
getInterfaceDescription() const;
76 // Returns the name of the query function insided the generated C++ interface.
77 StringRef
getQueryFnName() const;
79 // Returns the return type of the query function insided the generated C++
81 StringRef
getQueryFnRetType() const;
83 // Returns the code for merging availability requirements.
84 StringRef
getMergeActionCode() const;
86 // Returns the initializer expression for initializing the final availability
88 StringRef
getMergeInitializer() const;
90 // Returns the C++ type for an availability instance.
91 StringRef
getMergeInstanceType() const;
93 // Returns the C++ statements for preparing availability instance.
94 StringRef
getMergeInstancePreparation() const;
96 // Returns the concrete availability instance carried in this case.
97 StringRef
getMergeInstance() const;
99 // Returns the underlying LLVM TableGen Record.
100 const llvm::Record
*getDef() const { return def
; }
103 // The TableGen definition of this availability.
104 const llvm::Record
*def
;
108 Availability::Availability(const llvm::Record
*def
) : def(def
) {
109 assert(def
->isSubClassOf("Availability") &&
110 "must be subclass of TableGen 'Availability' class");
113 StringRef
Availability::getClass() const {
114 SmallVector
<Record
*, 1> parentClass
;
115 def
->getDirectSuperClasses(parentClass
);
116 if (parentClass
.size() != 1) {
117 PrintFatalError(def
->getLoc(),
118 "expected to only have one direct superclass");
120 return parentClass
.front()->getName();
123 StringRef
Availability::getInterfaceClassNamespace() const {
124 return def
->getValueAsString("cppNamespace");
127 StringRef
Availability::getInterfaceClassName() const {
128 return def
->getValueAsString("interfaceName");
131 StringRef
Availability::getInterfaceDescription() const {
132 return def
->getValueAsString("interfaceDescription");
135 StringRef
Availability::getQueryFnRetType() const {
136 return def
->getValueAsString("queryFnRetType");
139 StringRef
Availability::getQueryFnName() const {
140 return def
->getValueAsString("queryFnName");
143 StringRef
Availability::getMergeActionCode() const {
144 return def
->getValueAsString("mergeAction");
147 StringRef
Availability::getMergeInitializer() const {
148 return def
->getValueAsString("initializer");
151 StringRef
Availability::getMergeInstanceType() const {
152 return def
->getValueAsString("instanceType");
155 StringRef
Availability::getMergeInstancePreparation() const {
156 return def
->getValueAsString("instancePreparation");
159 StringRef
Availability::getMergeInstance() const {
160 return def
->getValueAsString("instance");
163 // Returns the availability spec of the given `def`.
164 std::vector
<Availability
> getAvailabilities(const Record
&def
) {
165 std::vector
<Availability
> availabilities
;
167 if (def
.getValue("availability")) {
168 std::vector
<Record
*> availDefs
= def
.getValueAsListOfDefs("availability");
169 availabilities
.reserve(availDefs
.size());
170 for (const Record
*avail
: availDefs
)
171 availabilities
.emplace_back(avail
);
174 return availabilities
;
177 //===----------------------------------------------------------------------===//
178 // Availability Interface Definitions AutoGen
179 //===----------------------------------------------------------------------===//
181 static void emitInterfaceDef(const Availability
&availability
,
184 os
<< availability
.getQueryFnRetType() << " ";
186 StringRef cppNamespace
= availability
.getInterfaceClassNamespace();
187 cppNamespace
.consume_front("::");
188 if (!cppNamespace
.empty())
189 os
<< cppNamespace
<< "::";
191 StringRef methodName
= availability
.getQueryFnName();
192 os
<< availability
.getInterfaceClassName() << "::" << methodName
<< "() {\n"
193 << " return getImpl()->" << methodName
<< "(getImpl(), getOperation());\n"
197 static bool emitInterfaceDefs(const RecordKeeper
&recordKeeper
,
199 llvm::emitSourceFileHeader("Availability Interface Definitions", os
);
201 auto defs
= recordKeeper
.getAllDerivedDefinitions("Availability");
202 SmallVector
<const Record
*, 1> handledClasses
;
203 for (const Record
*def
: defs
) {
204 SmallVector
<Record
*, 1> parent
;
205 def
->getDirectSuperClasses(parent
);
206 if (parent
.size() != 1) {
207 PrintFatalError(def
->getLoc(),
208 "expected to only have one direct superclass");
210 if (llvm::is_contained(handledClasses
, parent
.front()))
213 Availability
availability(def
);
214 emitInterfaceDef(availability
, os
);
215 handledClasses
.push_back(parent
.front());
220 //===----------------------------------------------------------------------===//
221 // Availability Interface Declarations AutoGen
222 //===----------------------------------------------------------------------===//
224 static void emitConceptDecl(const Availability
&availability
, raw_ostream
&os
) {
225 os
<< " class Concept {\n"
227 << " virtual ~Concept() = default;\n"
228 << " virtual " << availability
.getQueryFnRetType() << " "
229 << availability
.getQueryFnName()
230 << "(const Concept *impl, Operation *tblgen_opaque_op) const = 0;\n"
234 static void emitModelDecl(const Availability
&availability
, raw_ostream
&os
) {
235 for (const char *modelClass
: {"Model", "FallbackModel"}) {
236 os
<< " template<typename ConcreteOp>\n";
237 os
<< " class " << modelClass
<< " : public Concept {\n"
239 << " using Interface = " << availability
.getInterfaceClassName()
241 << " " << availability
.getQueryFnRetType() << " "
242 << availability
.getQueryFnName()
243 << "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n"
244 << " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
246 // Forward to the method on the concrete operation type.
247 << " return op." << availability
.getQueryFnName() << "();\n"
251 os
<< " template<typename ConcreteModel, typename ConcreteOp>\n";
252 os
<< " class ExternalModel : public FallbackModel<ConcreteOp> {};\n";
255 static void emitInterfaceDecl(const Availability
&availability
,
257 StringRef interfaceName
= availability
.getInterfaceClassName();
258 std::string interfaceTraitsName
=
259 std::string(formatv("{0}Traits", interfaceName
));
261 StringRef cppNamespace
= availability
.getInterfaceClassNamespace();
262 NamespaceEmitter
nsEmitter(os
, cppNamespace
);
263 os
<< "class " << interfaceName
<< ";\n\n";
265 // Emit the traits struct containing the concept and model declarations.
266 os
<< "namespace detail {\n"
267 << "struct " << interfaceTraitsName
<< " {\n";
268 emitConceptDecl(availability
, os
);
270 emitModelDecl(availability
, os
);
271 os
<< "};\n} // namespace detail\n\n";
273 // Emit the main interface class declaration.
274 os
<< "/*\n" << availability
.getInterfaceDescription().trim() << "\n*/\n";
275 os
<< llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n"
277 " using OpInterface<{1}, detail::{2}>::OpInterface;\n",
278 interfaceName
, interfaceName
, interfaceTraitsName
);
280 // Emit query function declaration.
281 os
<< " " << availability
.getQueryFnRetType() << " "
282 << availability
.getQueryFnName() << "();\n";
286 static bool emitInterfaceDecls(const RecordKeeper
&recordKeeper
,
288 llvm::emitSourceFileHeader("Availability Interface Declarations", os
);
290 auto defs
= recordKeeper
.getAllDerivedDefinitions("Availability");
291 SmallVector
<const Record
*, 4> handledClasses
;
292 for (const Record
*def
: defs
) {
293 SmallVector
<Record
*, 1> parent
;
294 def
->getDirectSuperClasses(parent
);
295 if (parent
.size() != 1) {
296 PrintFatalError(def
->getLoc(),
297 "expected to only have one direct superclass");
299 if (llvm::is_contained(handledClasses
, parent
.front()))
302 Availability
avail(def
);
303 emitInterfaceDecl(avail
, os
);
304 handledClasses
.push_back(parent
.front());
309 //===----------------------------------------------------------------------===//
310 // Availability Interface Hook Registration
311 //===----------------------------------------------------------------------===//
313 // Registers the operation interface generator to mlir-tblgen.
314 static mlir::GenRegistration
315 genInterfaceDecls("gen-avail-interface-decls",
316 "Generate availability interface declarations",
317 [](const RecordKeeper
&records
, raw_ostream
&os
) {
318 return emitInterfaceDecls(records
, os
);
321 // Registers the operation interface generator to mlir-tblgen.
322 static mlir::GenRegistration
323 genInterfaceDefs("gen-avail-interface-defs",
324 "Generate op interface definitions",
325 [](const RecordKeeper
&records
, raw_ostream
&os
) {
326 return emitInterfaceDefs(records
, os
);
329 //===----------------------------------------------------------------------===//
330 // Enum Availability Query AutoGen
331 //===----------------------------------------------------------------------===//
333 static void emitAvailabilityQueryForIntEnum(const Record
&enumDef
,
335 EnumAttr
enumAttr(enumDef
);
336 StringRef enumName
= enumAttr
.getEnumClassName();
337 std::vector
<EnumAttrCase
> enumerants
= enumAttr
.getAllCases();
339 // Mapping from availability class name to (enumerant, availability
340 // specification) pairs.
341 llvm::StringMap
<llvm::SmallVector
<std::pair
<EnumAttrCase
, Availability
>, 1>>
344 // Place all availability specifications to their corresponding
345 // availability classes.
346 for (const EnumAttrCase
&enumerant
: enumerants
)
347 for (const Availability
&avail
: getAvailabilities(enumerant
.getDef()))
348 classCaseMap
[avail
.getClass()].push_back({enumerant
, avail
});
350 for (const auto &classCasePair
: classCaseMap
) {
351 Availability avail
= classCasePair
.getValue().front().second
;
353 os
<< formatv("std::optional<{0}> {1}({2} value) {{\n",
354 avail
.getMergeInstanceType(), avail
.getQueryFnName(),
357 os
<< " switch (value) {\n";
358 for (const auto &caseSpecPair
: classCasePair
.getValue()) {
359 EnumAttrCase enumerant
= caseSpecPair
.first
;
360 Availability avail
= caseSpecPair
.second
;
361 os
<< formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName
,
362 enumerant
.getSymbol(), avail
.getMergeInstancePreparation(),
363 avail
.getMergeInstanceType(), avail
.getMergeInstance());
365 // Only emit default if uncovered cases.
366 if (classCasePair
.getValue().size() < enumAttr
.getAllCases().size())
367 os
<< " default: break;\n";
369 << " return std::nullopt;\n"
374 static void emitAvailabilityQueryForBitEnum(const Record
&enumDef
,
376 EnumAttr
enumAttr(enumDef
);
377 StringRef enumName
= enumAttr
.getEnumClassName();
378 std::string underlyingType
= std::string(enumAttr
.getUnderlyingType());
379 std::vector
<EnumAttrCase
> enumerants
= enumAttr
.getAllCases();
381 // Mapping from availability class name to (enumerant, availability
382 // specification) pairs.
383 llvm::StringMap
<llvm::SmallVector
<std::pair
<EnumAttrCase
, Availability
>, 1>>
386 // Place all availability specifications to their corresponding
387 // availability classes.
388 for (const EnumAttrCase
&enumerant
: enumerants
)
389 for (const Availability
&avail
: getAvailabilities(enumerant
.getDef()))
390 classCaseMap
[avail
.getClass()].push_back({enumerant
, avail
});
392 for (const auto &classCasePair
: classCaseMap
) {
393 Availability avail
= classCasePair
.getValue().front().second
;
395 os
<< formatv("std::optional<{0}> {1}({2} value) {{\n",
396 avail
.getMergeInstanceType(), avail
.getQueryFnName(),
400 " assert(::llvm::popcount(static_cast<{0}>(value)) <= 1"
401 " && \"cannot have more than one bit set\");\n",
404 os
<< " switch (value) {\n";
405 for (const auto &caseSpecPair
: classCasePair
.getValue()) {
406 EnumAttrCase enumerant
= caseSpecPair
.first
;
407 Availability avail
= caseSpecPair
.second
;
408 os
<< formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName
,
409 enumerant
.getSymbol(), avail
.getMergeInstancePreparation(),
410 avail
.getMergeInstanceType(), avail
.getMergeInstance());
412 os
<< " default: break;\n";
414 << " return std::nullopt;\n"
419 static void emitEnumDecl(const Record
&enumDef
, raw_ostream
&os
) {
420 EnumAttr
enumAttr(enumDef
);
421 StringRef enumName
= enumAttr
.getEnumClassName();
422 StringRef cppNamespace
= enumAttr
.getCppNamespace();
423 auto enumerants
= enumAttr
.getAllCases();
425 llvm::SmallVector
<StringRef
, 2> namespaces
;
426 llvm::SplitString(cppNamespace
, namespaces
, "::");
428 for (auto ns
: namespaces
)
429 os
<< "namespace " << ns
<< " {\n";
431 llvm::StringSet
<> handledClasses
;
433 // Place all availability specifications to their corresponding
434 // availability classes.
435 for (const EnumAttrCase
&enumerant
: enumerants
)
436 for (const Availability
&avail
: getAvailabilities(enumerant
.getDef())) {
437 StringRef className
= avail
.getClass();
438 if (handledClasses
.count(className
))
440 os
<< formatv("std::optional<{0}> {1}({2} value);\n",
441 avail
.getMergeInstanceType(), avail
.getQueryFnName(),
443 handledClasses
.insert(className
);
446 for (auto ns
: llvm::reverse(namespaces
))
447 os
<< "} // namespace " << ns
<< "\n";
450 static bool emitEnumDecls(const RecordKeeper
&recordKeeper
, raw_ostream
&os
) {
451 llvm::emitSourceFileHeader("SPIR-V Enum Availability Declarations", os
);
453 auto defs
= recordKeeper
.getAllDerivedDefinitions("EnumAttrInfo");
454 for (const auto *def
: defs
)
455 emitEnumDecl(*def
, os
);
460 static void emitEnumDef(const Record
&enumDef
, raw_ostream
&os
) {
461 EnumAttr
enumAttr(enumDef
);
462 StringRef cppNamespace
= enumAttr
.getCppNamespace();
464 llvm::SmallVector
<StringRef
, 2> namespaces
;
465 llvm::SplitString(cppNamespace
, namespaces
, "::");
467 for (auto ns
: namespaces
)
468 os
<< "namespace " << ns
<< " {\n";
470 if (enumAttr
.isBitEnum()) {
471 emitAvailabilityQueryForBitEnum(enumDef
, os
);
473 emitAvailabilityQueryForIntEnum(enumDef
, os
);
476 for (auto ns
: llvm::reverse(namespaces
))
477 os
<< "} // namespace " << ns
<< "\n";
481 static bool emitEnumDefs(const RecordKeeper
&recordKeeper
, raw_ostream
&os
) {
482 llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os
);
484 auto defs
= recordKeeper
.getAllDerivedDefinitions("EnumAttrInfo");
485 for (const auto *def
: defs
)
486 emitEnumDef(*def
, os
);
491 //===----------------------------------------------------------------------===//
492 // Enum Availability Query Hook Registration
493 //===----------------------------------------------------------------------===//
495 // Registers the enum utility generator to mlir-tblgen.
496 static mlir::GenRegistration
497 genEnumDecls("gen-spirv-enum-avail-decls",
498 "Generate SPIR-V enum availability declarations",
499 [](const RecordKeeper
&records
, raw_ostream
&os
) {
500 return emitEnumDecls(records
, os
);
503 // Registers the enum utility generator to mlir-tblgen.
504 static mlir::GenRegistration
505 genEnumDefs("gen-spirv-enum-avail-defs",
506 "Generate SPIR-V enum availability definitions",
507 [](const RecordKeeper
&records
, raw_ostream
&os
) {
508 return emitEnumDefs(records
, os
);
511 //===----------------------------------------------------------------------===//
512 // Serialization AutoGen
513 //===----------------------------------------------------------------------===//
515 /// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
516 /// generates code extracts the attribute with name `attrName` from
517 /// `operandList` of `op`.
518 static void emitAttributeSerialization(const Attribute
&attr
,
519 ArrayRef
<SMLoc
> loc
, StringRef tabs
,
520 StringRef opVar
, StringRef operandList
,
521 StringRef attrName
, raw_ostream
&os
) {
523 << formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar
, attrName
);
524 if (attr
.getAttrDefName() == "SPIRV_ScopeAttr" ||
525 attr
.getAttrDefName() == "SPIRV_MemorySemanticsAttr" ||
526 attr
.getAttrDefName() == "SPIRV_MatrixLayoutAttr") {
527 // These two enums are encoded as <id> to constant values in SPIR-V blob,
528 // but we directly use the constant value as attribute in SPIR-V dialect. So
529 // need to handle them separately from normal enum attributes.
530 EnumAttr
baseEnum(attr
.getDef().getValueAsDef("enum"));
532 << formatv(" {0}.push_back(prepareConstantInt({1}.getLoc(), "
533 "Builder({1}).getI32IntegerAttr(static_cast<uint32_t>("
534 "attr.cast<{2}::{3}Attr>().getValue()))));\n",
535 operandList
, opVar
, baseEnum
.getCppNamespace(),
536 baseEnum
.getEnumClassName());
537 } else if (attr
.isSubClassOf("SPIRV_BitEnumAttr") ||
538 attr
.isSubClassOf("SPIRV_I32EnumAttr")) {
539 EnumAttr
baseEnum(attr
.getDef().getValueAsDef("enum"));
541 << formatv(" {0}.push_back(static_cast<uint32_t>("
542 "attr.cast<{1}::{2}Attr>().getValue()));\n",
543 operandList
, baseEnum
.getCppNamespace(),
544 baseEnum
.getEnumClassName());
545 } else if (attr
.getAttrDefName() == "I32ArrayAttr") {
546 // Serialize all the elements of the array
547 os
<< tabs
<< " for (auto attrElem : attr.cast<ArrayAttr>()) {\n";
549 << formatv(" {0}.push_back(static_cast<uint32_t>("
550 "attrElem.cast<IntegerAttr>().getValue().getZExtValue()));\n",
552 os
<< tabs
<< " }\n";
553 } else if (attr
.getAttrDefName() == "I32Attr") {
555 << formatv(" {0}.push_back(static_cast<uint32_t>("
556 "attr.cast<IntegerAttr>().getValue().getZExtValue()));\n",
558 } else if (attr
.isEnumAttr() || attr
.getAttrDefName() == "TypeAttr") {
560 << formatv(" {0}.push_back(static_cast<uint32_t>("
561 "getTypeID(attr.cast<TypeAttr>().getValue())));\n",
567 "unhandled attribute type in SPIR-V serialization generation : '") +
568 attr
.getAttrDefName() + llvm::Twine("'"));
573 /// Generates code to serialize the operands of a SPIRV_Op `op` into `os`. The
574 /// generated queries the SSA-ID if operand is a SSA-Value, or serializes the
575 /// attributes. The `operands` vector is updated appropriately. `elidedAttrs`
576 /// updated as well to include the serialized attributes.
577 static void emitArgumentSerialization(const Operator
&op
, ArrayRef
<SMLoc
> loc
,
578 StringRef tabs
, StringRef opVar
,
579 StringRef operands
, StringRef elidedAttrs
,
581 using mlir::tblgen::Argument
;
583 // SPIR-V ops can mix operands and attributes in the definition. These
584 // operands and attributes are serialized in the exact order of the definition
585 // to match SPIR-V binary format requirements. It can cause excessive
586 // generated code bloat because we are emitting code to handle each
587 // operand/attribute separately. So here we probe first to check whether all
588 // the operands are ahead of attributes. Then we can serialize all operands
591 // Whether all operands are ahead of all attributes in the op's spec.
592 bool areOperandsAheadOfAttrs
= true;
593 // Find the first attribute.
594 const Argument
*it
= llvm::find_if(op
.getArgs(), [](const Argument
&arg
) {
595 return arg
.is
<NamedAttribute
*>();
597 // Check whether all following arguments are attributes.
598 for (const Argument
*ie
= op
.arg_end(); it
!= ie
; ++it
) {
599 if (!it
->is
<NamedAttribute
*>()) {
600 areOperandsAheadOfAttrs
= false;
605 // Serialize all operands together.
606 if (areOperandsAheadOfAttrs
) {
607 if (op
.getNumOperands() != 0) {
609 << formatv("for (Value operand : {0}->getOperands()) {{\n", opVar
);
610 os
<< tabs
<< " auto id = getValueID(operand);\n";
611 os
<< tabs
<< " assert(id && \"use before def!\");\n";
612 os
<< tabs
<< formatv(" {0}.push_back(id);\n", operands
);
615 for (const NamedAttribute
&attr
: op
.getAttributes()) {
616 emitAttributeSerialization(
617 (attr
.attr
.isOptional() ? attr
.attr
.getBaseAttr() : attr
.attr
), loc
,
618 tabs
, opVar
, operands
, attr
.name
, os
);
620 << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs
, attr
.name
);
625 // Serialize operands separately.
627 for (unsigned i
= 0, e
= op
.getNumArgs(); i
< e
; ++i
) {
628 auto argument
= op
.getArg(i
);
630 if (argument
.is
<NamedTypeConstraint
*>()) {
632 << formatv(" for (auto arg : {0}.getODSOperands({1})) {{\n", opVar
,
634 os
<< tabs
<< " auto argID = getValueID(arg);\n";
635 os
<< tabs
<< " if (!argID) {\n";
637 << formatv(" return emitError({0}.getLoc(), "
638 "\"operand #{1} has a use before def\");\n",
640 os
<< tabs
<< " }\n";
641 os
<< tabs
<< formatv(" {0}.push_back(argID);\n", operands
);
645 NamedAttribute
*attr
= argument
.get
<NamedAttribute
*>();
646 auto newtabs
= tabs
.str() + " ";
647 emitAttributeSerialization(
648 (attr
->attr
.isOptional() ? attr
->attr
.getBaseAttr() : attr
->attr
),
649 loc
, newtabs
, opVar
, operands
, attr
->name
, os
);
651 << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs
, attr
->name
);
657 /// Generates code to serializes the result of SPIRV_Op `op` into `os`. The
658 /// generated gets the ID for the type of the result (if any), the SSA-ID of
659 /// the result and updates `resultID` with the SSA-ID.
660 static void emitResultSerialization(const Operator
&op
, ArrayRef
<SMLoc
> loc
,
661 StringRef tabs
, StringRef opVar
,
662 StringRef operands
, StringRef resultID
,
664 if (op
.getNumResults() == 1) {
665 StringRef
resultTypeID("resultTypeID");
666 os
<< tabs
<< formatv("uint32_t {0} = 0;\n", resultTypeID
);
669 "if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n",
670 opVar
, resultTypeID
);
671 os
<< tabs
<< " return failure();\n";
673 os
<< tabs
<< formatv("{0}.push_back({1});\n", operands
, resultTypeID
);
674 // Create an SSA result <id> for the op
675 os
<< tabs
<< formatv("{0} = getNextID();\n", resultID
);
677 << formatv("valueIDMap[{0}.getResult()] = {1};\n", opVar
, resultID
);
678 os
<< tabs
<< formatv("{0}.push_back({1});\n", operands
, resultID
);
679 } else if (op
.getNumResults() != 0) {
680 PrintFatalError(loc
, "SPIR-V ops can only have zero or one result");
684 /// Generates code to serialize attributes of SPIRV_Op `op` that become
685 /// decorations on the `resultID` of the serialized operation `opVar` in the
687 static void emitDecorationSerialization(const Operator
&op
, StringRef tabs
,
688 StringRef opVar
, StringRef elidedAttrs
,
689 StringRef resultID
, raw_ostream
&os
) {
690 if (op
.getNumResults() == 1) {
691 // All non-argument attributes translated into OpDecorate instruction
692 os
<< tabs
<< formatv("for (auto attr : {0}->getAttrs()) {{\n", opVar
);
694 << formatv(" if (llvm::is_contained({0}, attr.getName())) {{",
696 os
<< tabs
<< " continue;\n";
697 os
<< tabs
<< " }\n";
700 " if (failed(processDecoration({0}.getLoc(), {1}, attr))) {{\n",
702 os
<< tabs
<< " return failure();\n";
703 os
<< tabs
<< " }\n";
708 /// Generates code to serialize an SPIRV_Op `op` into `os`.
709 static void emitSerializationFunction(const Record
*attrClass
,
710 const Record
*record
, const Operator
&op
,
712 // If the record has 'autogenSerialization' set to 0, nothing to do
713 if (!record
->getValueAsBit("autogenSerialization"))
716 StringRef
opVar("op"), operands("operands"), elidedAttrs("elidedAttrs"),
717 resultID("resultID");
720 "template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n",
721 op
.getQualCppClassName(), opVar
);
723 // Special case for ops without attributes in TableGen definitions
724 if (op
.getNumAttributes() == 0 && op
.getNumVariableLengthOperands() == 0) {
725 std::string extInstSet
;
727 if (record
->isSubClassOf("SPIRV_ExtInstOp")) {
729 formatv("\"{0}\"", record
->getValueAsString("extendedInstSetName"));
730 opcode
= std::to_string(record
->getValueAsInt("extendedInstOpcode"));
733 opcode
= formatv("static_cast<uint32_t>(spirv::Opcode::{0})",
734 record
->getValueAsString("spirvOpName"));
737 os
<< formatv(" return processOpWithoutGrammarAttr({0}, {1}, {2});\n}\n\n",
738 opVar
, extInstSet
, opcode
);
742 os
<< formatv(" SmallVector<uint32_t, 4> {0};\n", operands
);
743 os
<< formatv(" SmallVector<StringRef, 2> {0};\n", elidedAttrs
);
745 // Serialize result information.
746 if (op
.getNumResults() == 1) {
747 os
<< formatv(" uint32_t {0} = 0;\n", resultID
);
748 emitResultSerialization(op
, record
->getLoc(), " ", opVar
, operands
,
752 // Process arguments.
753 emitArgumentSerialization(op
, record
->getLoc(), " ", opVar
, operands
,
756 if (record
->isSubClassOf("SPIRV_ExtInstOp")) {
758 " (void)encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n", opVar
,
759 record
->getValueAsString("extendedInstSetName"),
760 record
->getValueAsInt("extendedInstOpcode"), operands
);
763 os
<< formatv(" (void)emitDebugLine(functionBody, {0}.getLoc());\n",
765 os
<< formatv(" (void)encodeInstructionInto("
766 "functionBody, spirv::Opcode::{1}, {2});\n",
767 op
.getQualCppClassName(),
768 record
->getValueAsString("spirvOpName"), operands
);
771 // Process decorations.
772 emitDecorationSerialization(op
, " ", opVar
, elidedAttrs
, resultID
, os
);
774 os
<< " return success();\n";
778 /// Generates the prologue for the function that dispatches the serialization of
779 /// the operation `opVar` based on its opcode.
780 static void initDispatchSerializationFn(StringRef opVar
, raw_ostream
&os
) {
782 "LogicalResult Serializer::dispatchToAutogenSerialization(Operation "
787 /// Generates the body of the dispatch function. This function generates the
788 /// check that if satisfied, will call the serialization function generated for
790 static void emitSerializationDispatch(const Operator
&op
, StringRef tabs
,
791 StringRef opVar
, raw_ostream
&os
) {
793 << formatv("if (isa<{0}>({1})) {{\n", op
.getQualCppClassName(), opVar
);
795 << formatv(" return processOp(cast<{0}>({1}));\n",
796 op
.getQualCppClassName(), opVar
);
800 /// Generates the epilogue for the function that dispatches the serialization of
802 static void finalizeDispatchSerializationFn(StringRef opVar
, raw_ostream
&os
) {
804 " return {0}->emitError(\"unhandled operation serialization\");\n",
809 /// Generates code to deserialize the attribute of a SPIRV_Op into `os`. The
810 /// generated code reads the `words` of the serialized instruction at
811 /// position `wordIndex` and adds the deserialized attribute into `attrList`.
812 static void emitAttributeDeserialization(const Attribute
&attr
,
813 ArrayRef
<SMLoc
> loc
, StringRef tabs
,
814 StringRef attrList
, StringRef attrName
,
815 StringRef words
, StringRef wordIndex
,
817 if (attr
.getAttrDefName() == "SPIRV_ScopeAttr" ||
818 attr
.getAttrDefName() == "SPIRV_MemorySemanticsAttr" ||
819 attr
.getAttrDefName() == "SPIRV_MatrixLayoutAttr") {
820 // These two enums are encoded as <id> to constant values in SPIR-V blob,
821 // but we directly use the constant value as attribute in SPIR-V dialect. So
822 // need to handle them separately from normal enum attributes.
823 EnumAttr
baseEnum(attr
.getDef().getValueAsDef("enum"));
825 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
826 "opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>("
827 "getConstantInt({4}[{5}++]).getValue().getZExtValue()))));\n",
828 attrList
, attrName
, baseEnum
.getCppNamespace(),
829 baseEnum
.getEnumClassName(), words
, wordIndex
);
830 } else if (attr
.isSubClassOf("SPIRV_BitEnumAttr") ||
831 attr
.isSubClassOf("SPIRV_I32EnumAttr")) {
832 EnumAttr
baseEnum(attr
.getDef().getValueAsDef("enum"));
834 << formatv(" {0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
835 "opBuilder.getAttr<{2}::{3}Attr>("
836 "static_cast<{2}::{3}>({4}[{5}++]))));\n",
837 attrList
, attrName
, baseEnum
.getCppNamespace(),
838 baseEnum
.getEnumClassName(), words
, wordIndex
);
839 } else if (attr
.getAttrDefName() == "I32ArrayAttr") {
840 os
<< tabs
<< "SmallVector<Attribute, 4> attrListElems;\n";
841 os
<< tabs
<< formatv("while ({0} < {1}.size()) {{\n", wordIndex
, words
);
845 "attrListElems.push_back(opBuilder.getI32IntegerAttr({0}[{1}++]))"
850 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
851 "opBuilder.getArrayAttr(attrListElems)));\n",
853 } else if (attr
.getAttrDefName() == "I32Attr") {
855 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
856 "opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
857 attrList
, attrName
, words
, wordIndex
);
858 } else if (attr
.isEnumAttr() || attr
.getAttrDefName() == "TypeAttr") {
860 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
861 "TypeAttr::get(getType({2}[{3}++]))));\n",
862 attrList
, attrName
, words
, wordIndex
);
866 "unhandled attribute type in deserialization generation : '") +
867 attr
.getAttrDefName() + llvm::Twine("'"));
871 /// Generates the code to deserialize the result of an SPIRV_Op `op` into
872 /// `os`. The generated code gets the type of the result specified at
873 /// `words`[`wordIndex`], the SSA ID for the result at position `wordIndex` + 1
874 /// and updates the `resultType` and `valueID` with the parsed type and SSA ID,
876 static void emitResultDeserialization(const Operator
&op
, ArrayRef
<SMLoc
> loc
,
877 StringRef tabs
, StringRef words
,
879 StringRef resultTypes
, StringRef valueID
,
881 // Deserialize result information if it exists
882 if (op
.getNumResults() == 1) {
884 os
<< tabs
<< formatv(" if ({0} >= {1}.size()) {{\n", wordIndex
, words
);
887 " return emitError(unknownLoc, \"expected result type <id> "
888 "while deserializing {0}\");\n",
889 op
.getQualCppClassName());
890 os
<< tabs
<< " }\n";
891 os
<< tabs
<< formatv(" auto ty = getType({0}[{1}]);\n", words
, wordIndex
);
892 os
<< tabs
<< " if (!ty) {\n";
895 " return emitError(unknownLoc, \"unknown type result <id> : "
896 "\") << {0}[{1}];\n",
898 os
<< tabs
<< " }\n";
899 os
<< tabs
<< formatv(" {0}.push_back(ty);\n", resultTypes
);
900 os
<< tabs
<< formatv(" {0}++;\n", wordIndex
);
901 os
<< tabs
<< formatv(" if ({0} >= {1}.size()) {{\n", wordIndex
, words
);
904 " return emitError(unknownLoc, \"expected result <id> while "
905 "deserializing {0}\");\n",
906 op
.getQualCppClassName());
907 os
<< tabs
<< " }\n";
909 os
<< tabs
<< formatv("{0} = {1}[{2}++];\n", valueID
, words
, wordIndex
);
910 } else if (op
.getNumResults() != 0) {
911 PrintFatalError(loc
, "SPIR-V ops can have only zero or one result");
915 /// Generates the code to deserialize the operands of an SPIRV_Op `op` into
916 /// `os`. The generated code reads the `words` of the binary instruction, from
917 /// position `wordIndex` to the end, and either gets the Value corresponding to
918 /// the ID encoded, or deserializes the attributes encoded. The parsed
919 /// operand(attribute) is added to the `operands` list or `attributes` list.
920 static void emitOperandDeserialization(const Operator
&op
, ArrayRef
<SMLoc
> loc
,
921 StringRef tabs
, StringRef words
,
922 StringRef wordIndex
, StringRef operands
,
923 StringRef attributes
, raw_ostream
&os
) {
924 // Process operands/attributes
925 for (unsigned i
= 0, e
= op
.getNumArgs(); i
< e
; ++i
) {
926 auto argument
= op
.getArg(i
);
927 if (auto *valueArg
= argument
.dyn_cast
<NamedTypeConstraint
*>()) {
928 if (valueArg
->isVariableLength()) {
930 PrintFatalError(loc
, "SPIR-V ops can have Variadic<..> or "
931 "std::optional<...> arguments only if "
932 "it's the last argument");
935 << formatv("for (; {0} < {1}.size(); ++{0})", wordIndex
, words
);
937 os
<< tabs
<< formatv("if ({0} < {1}.size())", wordIndex
, words
);
941 << formatv(" auto arg = getValue({0}[{1}]);\n", words
, wordIndex
);
942 os
<< tabs
<< " if (!arg) {\n";
945 " return emitError(unknownLoc, \"unknown result <id> : \") "
948 os
<< tabs
<< " }\n";
949 os
<< tabs
<< formatv(" {0}.push_back(arg);\n", operands
);
950 if (!valueArg
->isVariableLength()) {
951 os
<< tabs
<< formatv(" {0}++;\n", wordIndex
);
955 os
<< tabs
<< formatv("if ({0} < {1}.size()) {{\n", wordIndex
, words
);
956 auto *attr
= argument
.get
<NamedAttribute
*>();
957 auto newtabs
= tabs
.str() + " ";
958 emitAttributeDeserialization(
959 (attr
->attr
.isOptional() ? attr
->attr
.getBaseAttr() : attr
->attr
),
960 loc
, newtabs
, attributes
, attr
->name
, words
, wordIndex
, os
);
965 os
<< tabs
<< formatv("if ({0} != {1}.size()) {{\n", wordIndex
, words
);
968 " return emitError(unknownLoc, \"found more operands than "
969 "expected when deserializing {0}, only \") << {1} << \" of \" << "
970 "{2}.size() << \" processed\";\n",
971 op
.getQualCppClassName(), wordIndex
, words
);
972 os
<< tabs
<< "}\n\n";
975 /// Generates code to update the `attributes` vector with the attributes
976 /// obtained from parsing the decorations in the SPIR-V binary associated with
977 /// an <id> `valueID`
978 static void emitDecorationDeserialization(const Operator
&op
, StringRef tabs
,
980 StringRef attributes
,
982 // Import decorations parsed
983 if (op
.getNumResults() == 1) {
984 os
<< tabs
<< formatv("if (decorations.count({0})) {{\n", valueID
);
986 << formatv(" auto attrs = decorations[{0}].getAttrs();\n", valueID
);
988 << formatv(" {0}.append(attrs.begin(), attrs.end());\n", attributes
);
993 /// Generates code to deserialize an SPIRV_Op `op` into `os`.
994 static void emitDeserializationFunction(const Record
*attrClass
,
995 const Record
*record
,
996 const Operator
&op
, raw_ostream
&os
) {
997 // If the record has 'autogenSerialization' set to 0, nothing to do
998 if (!record
->getValueAsBit("autogenSerialization"))
1001 StringRef
resultTypes("resultTypes"), valueID("valueID"), words("words"),
1002 wordIndex("wordIndex"), opVar("op"), operands("operands"),
1003 attributes("attributes");
1005 // Method declaration
1006 os
<< formatv("template <> "
1007 "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
1008 "uint32_t> {1}) {{\n",
1009 op
.getQualCppClassName(), words
);
1011 // Special case for ops without attributes in TableGen definitions
1012 if (op
.getNumAttributes() == 0 && op
.getNumVariableLengthOperands() == 0) {
1013 os
<< formatv(" return processOpWithoutGrammarAttr("
1014 "{0}, \"{1}\", {2}, {3});\n}\n\n",
1015 words
, op
.getOperationName(),
1016 op
.getNumResults() ? "true" : "false", op
.getNumOperands());
1020 os
<< formatv(" SmallVector<Type, 1> {0};\n", resultTypes
);
1021 os
<< formatv(" size_t {0} = 0; (void){0};\n", wordIndex
);
1022 os
<< formatv(" uint32_t {0} = 0; (void){0};\n", valueID
);
1024 // Deserialize result information
1025 emitResultDeserialization(op
, record
->getLoc(), " ", words
, wordIndex
,
1026 resultTypes
, valueID
, os
);
1028 os
<< formatv(" SmallVector<Value, 4> {0};\n", operands
);
1029 os
<< formatv(" SmallVector<NamedAttribute, 4> {0};\n", attributes
);
1030 // Operand deserialization
1031 emitOperandDeserialization(op
, record
->getLoc(), " ", words
, wordIndex
,
1032 operands
, attributes
, os
);
1035 emitDecorationDeserialization(op
, " ", valueID
, attributes
, os
);
1037 os
<< formatv(" Location loc = createFileLineColLoc(opBuilder);\n");
1038 os
<< formatv(" auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); "
1040 op
.getQualCppClassName(), opVar
, resultTypes
, operands
,
1042 if (op
.getNumResults() == 1) {
1043 os
<< formatv(" valueMap[{0}] = {1}.getResult();\n\n", valueID
, opVar
);
1046 // According to SPIR-V spec:
1047 // This location information applies to the instructions physically following
1048 // this instruction, up to the first occurrence of any of the following: the
1049 // next end of block.
1050 os
<< formatv(" if ({0}.hasTrait<OpTrait::IsTerminator>())\n", opVar
);
1051 os
<< formatv(" (void)clearDebugLine();\n");
1052 os
<< " return success();\n";
1056 /// Generates the prologue for the function that dispatches the deserialization
1057 /// based on the `opcode`.
1058 static void initDispatchDeserializationFn(StringRef opcode
, StringRef words
,
1060 os
<< formatv("LogicalResult spirv::Deserializer::"
1061 "dispatchToAutogenDeserialization(spirv::Opcode {0},"
1062 " ArrayRef<uint32_t> {1}) {{\n",
1064 os
<< formatv(" switch ({0}) {{\n", opcode
);
1067 /// Generates the body of the dispatch function, by generating the case label
1068 /// for an opcode and the call to the method to perform the deserialization.
1069 static void emitDeserializationDispatch(const Operator
&op
, const Record
*def
,
1070 StringRef tabs
, StringRef words
,
1073 << formatv("case spirv::Opcode::{0}:\n",
1074 def
->getValueAsString("spirvOpName"));
1076 << formatv(" return processOp<{0}>({1});\n", op
.getQualCppClassName(),
1080 /// Generates the epilogue for the function that dispatches the deserialization
1081 /// of the operation.
1082 static void finalizeDispatchDeserializationFn(StringRef opcode
,
1084 os
<< " default:\n";
1087 StringRef
opcodeVar("opcodeString");
1088 os
<< formatv(" auto {0} = spirv::stringifyOpcode({1});\n", opcodeVar
,
1090 os
<< formatv(" if (!{0}.empty()) {{\n", opcodeVar
);
1091 os
<< formatv(" return emitError(unknownLoc, \"unhandled deserialization "
1094 os
<< " } else {\n";
1095 os
<< formatv(" return emitError(unknownLoc, \"unhandled opcode \") << "
1096 "static_cast<uint32_t>({0});\n",
1102 static void initExtendedSetDeserializationDispatch(StringRef extensionSetName
,
1103 StringRef instructionID
,
1106 os
<< formatv("LogicalResult spirv::Deserializer::"
1107 "dispatchToExtensionSetAutogenDeserialization("
1108 "StringRef {0}, uint32_t {1}, ArrayRef<uint32_t> {2}) {{\n",
1109 extensionSetName
, instructionID
, words
);
1113 emitExtendedSetDeserializationDispatch(const RecordKeeper
&recordKeeper
,
1115 StringRef
extensionSetName("extensionSetName"),
1116 instructionID("instructionID"), words("words");
1118 // First iterate over all ops derived from SPIRV_ExtensionSetOps to get all
1121 // For each of the extensions a separate raw_string_ostream is used to
1122 // generate code into. These are then concatenated at the end. Since
1123 // raw_string_ostream needs a string&, use a vector to store all the string
1124 // that are captured by reference within raw_string_ostream.
1125 StringMap
<raw_string_ostream
> extensionSets
;
1126 std::list
<std::string
> extensionSetNames
;
1128 initExtendedSetDeserializationDispatch(extensionSetName
, instructionID
, words
,
1130 auto defs
= recordKeeper
.getAllDerivedDefinitions("SPIRV_ExtInstOp");
1131 for (const auto *def
: defs
) {
1132 if (!def
->getValueAsBit("autogenSerialization")) {
1136 auto setName
= def
->getValueAsString("extendedInstSetName");
1137 if (!extensionSets
.count(setName
)) {
1138 extensionSetNames
.emplace_back("");
1139 extensionSets
.try_emplace(setName
, extensionSetNames
.back());
1140 auto &setos
= extensionSets
.find(setName
)->second
;
1141 setos
<< formatv(" if ({0} == \"{1}\") {{\n", extensionSetName
, setName
);
1142 setos
<< formatv(" switch ({0}) {{\n", instructionID
);
1144 auto &setos
= extensionSets
.find(setName
)->second
;
1145 setos
<< formatv(" case {0}:\n",
1146 def
->getValueAsInt("extendedInstOpcode"));
1147 setos
<< formatv(" return processOp<{0}>({1});\n",
1148 op
.getQualCppClassName(), words
);
1151 // Append the dispatch code for all the extended sets.
1152 for (auto &extensionSet
: extensionSets
) {
1153 os
<< extensionSet
.second
.str();
1154 os
<< " default:\n";
1156 " return emitError(unknownLoc, \"unhandled deserializations of "
1157 "\") << {0} << \" from extension set \" << {1};\n",
1158 instructionID
, extensionSetName
);
1163 os
<< formatv(" return emitError(unknownLoc, \"unhandled deserialization of "
1164 "extended instruction set {0}\");\n",
1169 /// Emits all the autogenerated serialization/deserializations functions for the
1171 static bool emitSerializationFns(const RecordKeeper
&recordKeeper
,
1173 llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os
);
1175 std::string dSerFnString
, dDesFnString
, serFnString
, deserFnString
,
1177 raw_string_ostream
dSerFn(dSerFnString
), dDesFn(dDesFnString
),
1178 serFn(serFnString
), deserFn(deserFnString
);
1179 Record
*attrClass
= recordKeeper
.getClass("Attr");
1181 // Emit the serialization and deserialization functions simultaneously.
1182 StringRef
opVar("op");
1183 StringRef
opcode("opcode"), words("words");
1185 // Handle the SPIR-V ops.
1186 initDispatchSerializationFn(opVar
, dSerFn
);
1187 initDispatchDeserializationFn(opcode
, words
, dDesFn
);
1188 auto defs
= recordKeeper
.getAllDerivedDefinitions("SPIRV_Op");
1189 for (const auto *def
: defs
) {
1191 emitSerializationFunction(attrClass
, def
, op
, serFn
);
1192 emitDeserializationFunction(attrClass
, def
, op
, deserFn
);
1193 if (def
->getValueAsBit("hasOpcode") ||
1194 def
->isSubClassOf("SPIRV_ExtInstOp")) {
1195 emitSerializationDispatch(op
, " ", opVar
, dSerFn
);
1197 if (def
->getValueAsBit("hasOpcode")) {
1198 emitDeserializationDispatch(op
, def
, " ", words
, dDesFn
);
1201 finalizeDispatchSerializationFn(opVar
, dSerFn
);
1202 finalizeDispatchDeserializationFn(opcode
, dDesFn
);
1204 emitExtendedSetDeserializationDispatch(recordKeeper
, dDesFn
);
1206 os
<< "#ifdef GET_SERIALIZATION_FNS\n\n";
1209 os
<< "#endif // GET_SERIALIZATION_FNS\n\n";
1211 os
<< "#ifdef GET_DESERIALIZATION_FNS\n\n";
1212 os
<< deserFn
.str();
1214 os
<< "#endif // GET_DESERIALIZATION_FNS\n\n";
1219 //===----------------------------------------------------------------------===//
1220 // Serialization Hook Registration
1221 //===----------------------------------------------------------------------===//
1223 static mlir::GenRegistration
genSerialization(
1224 "gen-spirv-serialization",
1225 "Generate SPIR-V (de)serialization utilities and functions",
1226 [](const RecordKeeper
&records
, raw_ostream
&os
) {
1227 return emitSerializationFns(records
, os
);
1230 //===----------------------------------------------------------------------===//
1232 //===----------------------------------------------------------------------===//
1234 static void emitEnumGetAttrNameFnDecl(raw_ostream
&os
) {
1235 os
<< formatv("template <typename EnumClass> inline constexpr StringRef "
1236 "attributeName();\n");
1239 static void emitEnumGetAttrNameFnDefn(const EnumAttr
&enumAttr
,
1241 auto enumName
= enumAttr
.getEnumClassName();
1242 os
<< formatv("template <> inline StringRef attributeName<{0}>() {{\n",
1245 << formatv("static constexpr const char attrName[] = \"{0}\";\n",
1246 llvm::convertToSnakeFromCamelCase(enumName
));
1247 os
<< " return attrName;\n";
1251 static bool emitAttrUtils(const RecordKeeper
&recordKeeper
, raw_ostream
&os
) {
1252 llvm::emitSourceFileHeader("SPIR-V Attribute Utilities", os
);
1254 auto defs
= recordKeeper
.getAllDerivedDefinitions("EnumAttrInfo");
1255 os
<< "#ifndef MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
1256 os
<< "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
1257 emitEnumGetAttrNameFnDecl(os
);
1258 for (const auto *def
: defs
) {
1259 EnumAttr
enumAttr(*def
);
1260 emitEnumGetAttrNameFnDefn(enumAttr
, os
);
1262 os
<< "#endif // MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H\n";
1266 //===----------------------------------------------------------------------===//
1267 // Op Utils Hook Registration
1268 //===----------------------------------------------------------------------===//
1270 static mlir::GenRegistration
1271 genOpUtils("gen-spirv-attr-utils",
1272 "Generate SPIR-V attribute utility definitions",
1273 [](const RecordKeeper
&records
, raw_ostream
&os
) {
1274 return emitAttrUtils(records
, os
);
1277 //===----------------------------------------------------------------------===//
1278 // SPIR-V Availability Impl AutoGen
1279 //===----------------------------------------------------------------------===//
1281 static void emitAvailabilityImpl(const Operator
&srcOp
, raw_ostream
&os
) {
1282 mlir::tblgen::FmtContext fctx
;
1283 fctx
.addSubst("overall", "tblgen_overall");
1285 std::vector
<Availability
> opAvailabilities
=
1286 getAvailabilities(srcOp
.getDef());
1288 // First collect all availability classes this op should implement.
1289 // All availability instances keep information for the generated interface and
1290 // the instance's specific requirement. Here we remember a random instance so
1291 // we can get the information regarding the generated interface.
1292 llvm::StringMap
<Availability
> availClasses
;
1293 for (const Availability
&avail
: opAvailabilities
)
1294 availClasses
.try_emplace(avail
.getClass(), avail
);
1295 for (const NamedAttribute
&namedAttr
: srcOp
.getAttributes()) {
1296 if (!namedAttr
.attr
.isSubClassOf("SPIRV_BitEnumAttr") &&
1297 !namedAttr
.attr
.isSubClassOf("SPIRV_I32EnumAttr"))
1299 EnumAttr
enumAttr(namedAttr
.attr
.getDef().getValueAsDef("enum"));
1301 for (const EnumAttrCase
&enumerant
: enumAttr
.getAllCases())
1302 for (const Availability
&caseAvail
:
1303 getAvailabilities(enumerant
.getDef()))
1304 availClasses
.try_emplace(caseAvail
.getClass(), caseAvail
);
1307 // Then generate implementation for each availability class.
1308 for (const auto &availClass
: availClasses
) {
1309 StringRef availClassName
= availClass
.getKey();
1310 Availability avail
= availClass
.getValue();
1312 // Generate the implementation method signature.
1313 os
<< formatv("{0} {1}::{2}() {{\n", avail
.getQueryFnRetType(),
1314 srcOp
.getCppClassName(), avail
.getQueryFnName());
1316 // Create the variable for the final requirement and initialize it.
1317 os
<< formatv(" {0} tblgen_overall = {1};\n", avail
.getQueryFnRetType(),
1318 avail
.getMergeInitializer());
1320 // Update with the op's specific availability spec.
1321 for (const Availability
&avail
: opAvailabilities
)
1322 if (avail
.getClass() == availClassName
&&
1323 (!avail
.getMergeInstancePreparation().empty() ||
1324 !avail
.getMergeActionCode().empty())) {
1326 // Prepare this instance.
1327 << avail
.getMergeInstancePreparation()
1329 // Merge this instance.
1331 tgfmt(avail
.getMergeActionCode(),
1332 &fctx
.addSubst("instance", avail
.getMergeInstance())))
1336 // Update with enum attributes' specific availability spec.
1337 for (const NamedAttribute
&namedAttr
: srcOp
.getAttributes()) {
1338 if (!namedAttr
.attr
.isSubClassOf("SPIRV_BitEnumAttr") &&
1339 !namedAttr
.attr
.isSubClassOf("SPIRV_I32EnumAttr"))
1341 EnumAttr
enumAttr(namedAttr
.attr
.getDef().getValueAsDef("enum"));
1343 // (enumerant, availability specification) pairs for this availability
1345 SmallVector
<std::pair
<EnumAttrCase
, Availability
>, 1> caseSpecs
;
1347 // Collect all cases' availability specs.
1348 for (const EnumAttrCase
&enumerant
: enumAttr
.getAllCases())
1349 for (const Availability
&caseAvail
:
1350 getAvailabilities(enumerant
.getDef()))
1351 if (availClassName
== caseAvail
.getClass())
1352 caseSpecs
.push_back({enumerant
, caseAvail
});
1354 // If this attribute kind does not have any availability spec from any of
1355 // its cases, no more work to do.
1356 if (caseSpecs
.empty())
1359 if (enumAttr
.isBitEnum()) {
1360 // For BitEnumAttr, we need to iterate over each bit to query its
1361 // availability spec.
1362 os
<< formatv(" for (unsigned i = 0; "
1363 "i < std::numeric_limits<{0}>::digits; ++i) {{\n",
1364 enumAttr
.getUnderlyingType());
1365 os
<< formatv(" {0}::{1} tblgen_attrVal = this->{2}() & "
1366 "static_cast<{0}::{1}>(1 << i);\n",
1367 enumAttr
.getCppNamespace(), enumAttr
.getEnumClassName(),
1368 srcOp
.getGetterName(namedAttr
.name
));
1370 " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
1371 enumAttr
.getUnderlyingType());
1373 // For IntEnumAttr, we just need to query the value as a whole.
1375 os
<< formatv(" auto tblgen_attrVal = this->{0}();\n",
1376 srcOp
.getGetterName(namedAttr
.name
));
1378 os
<< formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
1379 enumAttr
.getCppNamespace(), avail
.getQueryFnName());
1380 os
<< " if (tblgen_instance) "
1381 // TODO` here once ODS supports
1382 // dialect-specific contents so that we can use not implementing the
1383 // availability interface as indication of no requirements.
1384 << std::string(tgfmt(caseSpecs
.front().second
.getMergeActionCode(),
1385 &fctx
.addSubst("instance", "*tblgen_instance")))
1390 os
<< " return tblgen_overall;\n";
1395 static bool emitAvailabilityImpl(const RecordKeeper
&recordKeeper
,
1397 llvm::emitSourceFileHeader("SPIR-V Op Availability Implementations", os
);
1399 auto defs
= recordKeeper
.getAllDerivedDefinitions("SPIRV_Op");
1400 for (const auto *def
: defs
) {
1402 if (def
->getValueAsBit("autogenAvailability"))
1403 emitAvailabilityImpl(op
, os
);
1408 //===----------------------------------------------------------------------===//
1409 // Op Availability Implementation Hook Registration
1410 //===----------------------------------------------------------------------===//
1412 static mlir::GenRegistration
1413 genOpAvailabilityImpl("gen-spirv-avail-impls",
1414 "Generate SPIR-V operation utility definitions",
1415 [](const RecordKeeper
&records
, raw_ostream
&os
) {
1416 return emitAvailabilityImpl(records
, os
);
1419 //===----------------------------------------------------------------------===//
1420 // SPIR-V Capability Implication AutoGen
1421 //===----------------------------------------------------------------------===//
1423 static bool emitCapabilityImplication(const RecordKeeper
&recordKeeper
,
1425 llvm::emitSourceFileHeader("SPIR-V Capability Implication", os
);
1428 recordKeeper
.getDef("SPIRV_CapabilityAttr")->getValueAsDef("enum"));
1430 os
<< "ArrayRef<spirv::Capability> "
1431 "spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n"
1432 << " switch (cap) {\n"
1433 << " default: return {};\n";
1434 for (const EnumAttrCase
&enumerant
: enumAttr
.getAllCases()) {
1435 const Record
&def
= enumerant
.getDef();
1436 if (!def
.getValue("implies"))
1439 std::vector
<Record
*> impliedCapsDefs
= def
.getValueAsListOfDefs("implies");
1440 os
<< " case spirv::Capability::" << enumerant
.getSymbol()
1441 << ": {static const spirv::Capability implies[" << impliedCapsDefs
.size()
1443 llvm::interleaveComma(impliedCapsDefs
, os
, [&](const Record
*capDef
) {
1444 os
<< "spirv::Capability::" << EnumAttrCase(capDef
).getSymbol();
1446 os
<< "}; return ArrayRef<spirv::Capability>(implies, "
1447 << impliedCapsDefs
.size() << "); }\n";
1455 //===----------------------------------------------------------------------===//
1456 // SPIR-V Capability Implication Hook Registration
1457 //===----------------------------------------------------------------------===//
1459 static mlir::GenRegistration
1460 genCapabilityImplication("gen-spirv-capability-implication",
1461 "Generate utility function to return implied "
1462 "capabilities for a given capability",
1463 [](const RecordKeeper
&records
, raw_ostream
&os
) {
1464 return emitCapabilityImplication(records
, os
);