1 //===- SPIRVSerializationGen.cpp - SPIR-V serialization utility generator -===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // SPIRVSerializationGen generates common utility functions for SPIR-V
12 //===----------------------------------------------------------------------===//
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Operator.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringMap.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/ADT/StringSet.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Record.h"
30 #include "llvm/TableGen/TableGenBackend.h"
37 using llvm::raw_ostream
;
38 using llvm::raw_string_ostream
;
40 using llvm::RecordKeeper
;
41 using llvm::SmallVector
;
43 using llvm::StringMap
;
44 using llvm::StringRef
;
45 using mlir::tblgen::Attribute
;
46 using mlir::tblgen::EnumAttr
;
47 using mlir::tblgen::EnumAttrCase
;
48 using mlir::tblgen::NamedAttribute
;
49 using mlir::tblgen::NamedTypeConstraint
;
50 using mlir::tblgen::NamespaceEmitter
;
51 using mlir::tblgen::Operator
;
53 //===----------------------------------------------------------------------===//
54 // Availability Wrapper Class
55 //===----------------------------------------------------------------------===//
58 // Wrapper class with helper methods for accessing availability defined in
62 explicit Availability(const Record
*def
);
64 // Returns the name of the direct TableGen class for this availability
66 StringRef
getClass() const;
68 // Returns the generated C++ interface's class namespace.
69 StringRef
getInterfaceClassNamespace() const;
71 // Returns the generated C++ interface's class name.
72 StringRef
getInterfaceClassName() const;
74 // Returns the generated C++ interface's description.
75 StringRef
getInterfaceDescription() const;
77 // Returns the name of the query function insided the generated C++ interface.
78 StringRef
getQueryFnName() const;
80 // Returns the return type of the query function insided the generated C++
82 StringRef
getQueryFnRetType() const;
84 // Returns the code for merging availability requirements.
85 StringRef
getMergeActionCode() const;
87 // Returns the initializer expression for initializing the final availability
89 StringRef
getMergeInitializer() const;
91 // Returns the C++ type for an availability instance.
92 StringRef
getMergeInstanceType() const;
94 // Returns the C++ statements for preparing availability instance.
95 StringRef
getMergeInstancePreparation() const;
97 // Returns the concrete availability instance carried in this case.
98 StringRef
getMergeInstance() const;
100 // Returns the underlying LLVM TableGen Record.
101 const llvm::Record
*getDef() const { return def
; }
104 // The TableGen definition of this availability.
105 const llvm::Record
*def
;
109 Availability::Availability(const llvm::Record
*def
) : def(def
) {
110 assert(def
->isSubClassOf("Availability") &&
111 "must be subclass of TableGen 'Availability' class");
114 StringRef
Availability::getClass() const {
115 SmallVector
<Record
*, 1> parentClass
;
116 def
->getDirectSuperClasses(parentClass
);
117 if (parentClass
.size() != 1) {
118 PrintFatalError(def
->getLoc(),
119 "expected to only have one direct superclass");
121 return parentClass
.front()->getName();
124 StringRef
Availability::getInterfaceClassNamespace() const {
125 return def
->getValueAsString("cppNamespace");
128 StringRef
Availability::getInterfaceClassName() const {
129 return def
->getValueAsString("interfaceName");
132 StringRef
Availability::getInterfaceDescription() const {
133 return def
->getValueAsString("interfaceDescription");
136 StringRef
Availability::getQueryFnRetType() const {
137 return def
->getValueAsString("queryFnRetType");
140 StringRef
Availability::getQueryFnName() const {
141 return def
->getValueAsString("queryFnName");
144 StringRef
Availability::getMergeActionCode() const {
145 return def
->getValueAsString("mergeAction");
148 StringRef
Availability::getMergeInitializer() const {
149 return def
->getValueAsString("initializer");
152 StringRef
Availability::getMergeInstanceType() const {
153 return def
->getValueAsString("instanceType");
156 StringRef
Availability::getMergeInstancePreparation() const {
157 return def
->getValueAsString("instancePreparation");
160 StringRef
Availability::getMergeInstance() const {
161 return def
->getValueAsString("instance");
164 // Returns the availability spec of the given `def`.
165 std::vector
<Availability
> getAvailabilities(const Record
&def
) {
166 std::vector
<Availability
> availabilities
;
168 if (def
.getValue("availability")) {
169 std::vector
<Record
*> availDefs
= def
.getValueAsListOfDefs("availability");
170 availabilities
.reserve(availDefs
.size());
171 for (const Record
*avail
: availDefs
)
172 availabilities
.emplace_back(avail
);
175 return availabilities
;
178 //===----------------------------------------------------------------------===//
179 // Availability Interface Definitions AutoGen
180 //===----------------------------------------------------------------------===//
182 static void emitInterfaceDef(const Availability
&availability
,
185 os
<< availability
.getQueryFnRetType() << " ";
187 StringRef cppNamespace
= availability
.getInterfaceClassNamespace();
188 cppNamespace
.consume_front("::");
189 if (!cppNamespace
.empty())
190 os
<< cppNamespace
<< "::";
192 StringRef methodName
= availability
.getQueryFnName();
193 os
<< availability
.getInterfaceClassName() << "::" << methodName
<< "() {\n"
194 << " return getImpl()->" << methodName
<< "(getImpl(), getOperation());\n"
198 static bool emitInterfaceDefs(const RecordKeeper
&recordKeeper
,
200 llvm::emitSourceFileHeader("Availability Interface Definitions", os
,
203 auto defs
= recordKeeper
.getAllDerivedDefinitions("Availability");
204 SmallVector
<const Record
*, 1> handledClasses
;
205 for (const Record
*def
: defs
) {
206 SmallVector
<Record
*, 1> parent
;
207 def
->getDirectSuperClasses(parent
);
208 if (parent
.size() != 1) {
209 PrintFatalError(def
->getLoc(),
210 "expected to only have one direct superclass");
212 if (llvm::is_contained(handledClasses
, parent
.front()))
215 Availability
availability(def
);
216 emitInterfaceDef(availability
, os
);
217 handledClasses
.push_back(parent
.front());
222 //===----------------------------------------------------------------------===//
223 // Availability Interface Declarations AutoGen
224 //===----------------------------------------------------------------------===//
226 static void emitConceptDecl(const Availability
&availability
, raw_ostream
&os
) {
227 os
<< " class Concept {\n"
229 << " virtual ~Concept() = default;\n"
230 << " virtual " << availability
.getQueryFnRetType() << " "
231 << availability
.getQueryFnName()
232 << "(const Concept *impl, Operation *tblgen_opaque_op) const = 0;\n"
236 static void emitModelDecl(const Availability
&availability
, raw_ostream
&os
) {
237 for (const char *modelClass
: {"Model", "FallbackModel"}) {
238 os
<< " template<typename ConcreteOp>\n";
239 os
<< " class " << modelClass
<< " : public Concept {\n"
241 << " using Interface = " << availability
.getInterfaceClassName()
243 << " " << availability
.getQueryFnRetType() << " "
244 << availability
.getQueryFnName()
245 << "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n"
246 << " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
248 // Forward to the method on the concrete operation type.
249 << " return op." << availability
.getQueryFnName() << "();\n"
253 os
<< " template<typename ConcreteModel, typename ConcreteOp>\n";
254 os
<< " class ExternalModel : public FallbackModel<ConcreteOp> {};\n";
257 static void emitInterfaceDecl(const Availability
&availability
,
259 StringRef interfaceName
= availability
.getInterfaceClassName();
260 std::string interfaceTraitsName
=
261 std::string(formatv("{0}Traits", interfaceName
));
263 StringRef cppNamespace
= availability
.getInterfaceClassNamespace();
264 NamespaceEmitter
nsEmitter(os
, cppNamespace
);
265 os
<< "class " << interfaceName
<< ";\n\n";
267 // Emit the traits struct containing the concept and model declarations.
268 os
<< "namespace detail {\n"
269 << "struct " << interfaceTraitsName
<< " {\n";
270 emitConceptDecl(availability
, os
);
272 emitModelDecl(availability
, os
);
273 os
<< "};\n} // namespace detail\n\n";
275 // Emit the main interface class declaration.
276 os
<< "/*\n" << availability
.getInterfaceDescription().trim() << "\n*/\n";
277 os
<< llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n"
279 " using OpInterface<{1}, detail::{2}>::OpInterface;\n",
280 interfaceName
, interfaceName
, interfaceTraitsName
);
282 // Emit query function declaration.
283 os
<< " " << availability
.getQueryFnRetType() << " "
284 << availability
.getQueryFnName() << "();\n";
288 static bool emitInterfaceDecls(const RecordKeeper
&recordKeeper
,
290 llvm::emitSourceFileHeader("Availability Interface Declarations", os
,
293 auto defs
= recordKeeper
.getAllDerivedDefinitions("Availability");
294 SmallVector
<const Record
*, 4> handledClasses
;
295 for (const Record
*def
: defs
) {
296 SmallVector
<Record
*, 1> parent
;
297 def
->getDirectSuperClasses(parent
);
298 if (parent
.size() != 1) {
299 PrintFatalError(def
->getLoc(),
300 "expected to only have one direct superclass");
302 if (llvm::is_contained(handledClasses
, parent
.front()))
305 Availability
avail(def
);
306 emitInterfaceDecl(avail
, os
);
307 handledClasses
.push_back(parent
.front());
312 //===----------------------------------------------------------------------===//
313 // Availability Interface Hook Registration
314 //===----------------------------------------------------------------------===//
316 // Registers the operation interface generator to mlir-tblgen.
317 static mlir::GenRegistration
318 genInterfaceDecls("gen-avail-interface-decls",
319 "Generate availability interface declarations",
320 [](const RecordKeeper
&records
, raw_ostream
&os
) {
321 return emitInterfaceDecls(records
, os
);
324 // Registers the operation interface generator to mlir-tblgen.
325 static mlir::GenRegistration
326 genInterfaceDefs("gen-avail-interface-defs",
327 "Generate op interface definitions",
328 [](const RecordKeeper
&records
, raw_ostream
&os
) {
329 return emitInterfaceDefs(records
, os
);
332 //===----------------------------------------------------------------------===//
333 // Enum Availability Query AutoGen
334 //===----------------------------------------------------------------------===//
336 static void emitAvailabilityQueryForIntEnum(const Record
&enumDef
,
338 EnumAttr
enumAttr(enumDef
);
339 StringRef enumName
= enumAttr
.getEnumClassName();
340 std::vector
<EnumAttrCase
> enumerants
= enumAttr
.getAllCases();
342 // Mapping from availability class name to (enumerant, availability
343 // specification) pairs.
344 llvm::StringMap
<llvm::SmallVector
<std::pair
<EnumAttrCase
, Availability
>, 1>>
347 // Place all availability specifications to their corresponding
348 // availability classes.
349 for (const EnumAttrCase
&enumerant
: enumerants
)
350 for (const Availability
&avail
: getAvailabilities(enumerant
.getDef()))
351 classCaseMap
[avail
.getClass()].push_back({enumerant
, avail
});
353 for (const auto &classCasePair
: classCaseMap
) {
354 Availability avail
= classCasePair
.getValue().front().second
;
356 os
<< formatv("std::optional<{0}> {1}({2} value) {{\n",
357 avail
.getMergeInstanceType(), avail
.getQueryFnName(),
360 os
<< " switch (value) {\n";
361 for (const auto &caseSpecPair
: classCasePair
.getValue()) {
362 EnumAttrCase enumerant
= caseSpecPair
.first
;
363 Availability avail
= caseSpecPair
.second
;
364 os
<< formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName
,
365 enumerant
.getSymbol(), avail
.getMergeInstancePreparation(),
366 avail
.getMergeInstanceType(), avail
.getMergeInstance());
368 // Only emit default if uncovered cases.
369 if (classCasePair
.getValue().size() < enumAttr
.getAllCases().size())
370 os
<< " default: break;\n";
372 << " return std::nullopt;\n"
377 static void emitAvailabilityQueryForBitEnum(const Record
&enumDef
,
379 EnumAttr
enumAttr(enumDef
);
380 StringRef enumName
= enumAttr
.getEnumClassName();
381 std::string underlyingType
= std::string(enumAttr
.getUnderlyingType());
382 std::vector
<EnumAttrCase
> enumerants
= enumAttr
.getAllCases();
384 // Mapping from availability class name to (enumerant, availability
385 // specification) pairs.
386 llvm::StringMap
<llvm::SmallVector
<std::pair
<EnumAttrCase
, Availability
>, 1>>
389 // Place all availability specifications to their corresponding
390 // availability classes.
391 for (const EnumAttrCase
&enumerant
: enumerants
)
392 for (const Availability
&avail
: getAvailabilities(enumerant
.getDef()))
393 classCaseMap
[avail
.getClass()].push_back({enumerant
, avail
});
395 for (const auto &classCasePair
: classCaseMap
) {
396 Availability avail
= classCasePair
.getValue().front().second
;
398 os
<< formatv("std::optional<{0}> {1}({2} value) {{\n",
399 avail
.getMergeInstanceType(), avail
.getQueryFnName(),
403 " assert(::llvm::popcount(static_cast<{0}>(value)) <= 1"
404 " && \"cannot have more than one bit set\");\n",
407 os
<< " switch (value) {\n";
408 for (const auto &caseSpecPair
: classCasePair
.getValue()) {
409 EnumAttrCase enumerant
= caseSpecPair
.first
;
410 Availability avail
= caseSpecPair
.second
;
411 os
<< formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName
,
412 enumerant
.getSymbol(), avail
.getMergeInstancePreparation(),
413 avail
.getMergeInstanceType(), avail
.getMergeInstance());
415 os
<< " default: break;\n";
417 << " return std::nullopt;\n"
422 static void emitEnumDecl(const Record
&enumDef
, raw_ostream
&os
) {
423 EnumAttr
enumAttr(enumDef
);
424 StringRef enumName
= enumAttr
.getEnumClassName();
425 StringRef cppNamespace
= enumAttr
.getCppNamespace();
426 auto enumerants
= enumAttr
.getAllCases();
428 llvm::SmallVector
<StringRef
, 2> namespaces
;
429 llvm::SplitString(cppNamespace
, namespaces
, "::");
431 for (auto ns
: namespaces
)
432 os
<< "namespace " << ns
<< " {\n";
434 llvm::StringSet
<> handledClasses
;
436 // Place all availability specifications to their corresponding
437 // availability classes.
438 for (const EnumAttrCase
&enumerant
: enumerants
)
439 for (const Availability
&avail
: getAvailabilities(enumerant
.getDef())) {
440 StringRef className
= avail
.getClass();
441 if (handledClasses
.count(className
))
443 os
<< formatv("std::optional<{0}> {1}({2} value);\n",
444 avail
.getMergeInstanceType(), avail
.getQueryFnName(),
446 handledClasses
.insert(className
);
449 for (auto ns
: llvm::reverse(namespaces
))
450 os
<< "} // namespace " << ns
<< "\n";
453 static bool emitEnumDecls(const RecordKeeper
&recordKeeper
, raw_ostream
&os
) {
454 llvm::emitSourceFileHeader("SPIR-V Enum Availability Declarations", os
,
457 auto defs
= recordKeeper
.getAllDerivedDefinitions("EnumAttrInfo");
458 for (const auto *def
: defs
)
459 emitEnumDecl(*def
, os
);
464 static void emitEnumDef(const Record
&enumDef
, raw_ostream
&os
) {
465 EnumAttr
enumAttr(enumDef
);
466 StringRef cppNamespace
= enumAttr
.getCppNamespace();
468 llvm::SmallVector
<StringRef
, 2> namespaces
;
469 llvm::SplitString(cppNamespace
, namespaces
, "::");
471 for (auto ns
: namespaces
)
472 os
<< "namespace " << ns
<< " {\n";
474 if (enumAttr
.isBitEnum()) {
475 emitAvailabilityQueryForBitEnum(enumDef
, os
);
477 emitAvailabilityQueryForIntEnum(enumDef
, os
);
480 for (auto ns
: llvm::reverse(namespaces
))
481 os
<< "} // namespace " << ns
<< "\n";
485 static bool emitEnumDefs(const RecordKeeper
&recordKeeper
, raw_ostream
&os
) {
486 llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os
,
489 auto defs
= recordKeeper
.getAllDerivedDefinitions("EnumAttrInfo");
490 for (const auto *def
: defs
)
491 emitEnumDef(*def
, os
);
496 //===----------------------------------------------------------------------===//
497 // Enum Availability Query Hook Registration
498 //===----------------------------------------------------------------------===//
500 // Registers the enum utility generator to mlir-tblgen.
501 static mlir::GenRegistration
502 genEnumDecls("gen-spirv-enum-avail-decls",
503 "Generate SPIR-V enum availability declarations",
504 [](const RecordKeeper
&records
, raw_ostream
&os
) {
505 return emitEnumDecls(records
, os
);
508 // Registers the enum utility generator to mlir-tblgen.
509 static mlir::GenRegistration
510 genEnumDefs("gen-spirv-enum-avail-defs",
511 "Generate SPIR-V enum availability definitions",
512 [](const RecordKeeper
&records
, raw_ostream
&os
) {
513 return emitEnumDefs(records
, os
);
516 //===----------------------------------------------------------------------===//
517 // Serialization AutoGen
518 //===----------------------------------------------------------------------===//
520 // These enums are encoded as <id> to constant values in SPIR-V blob, but we
521 // directly use the constant value as attribute in SPIR-V dialect. So need
522 // to handle them separately from normal enum attributes.
523 constexpr llvm::StringLiteral constantIdEnumAttrs
[] = {
524 "SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
525 "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
526 "SPIRV_MatrixLayoutAttr"};
528 /// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
529 /// generates code extracts the attribute with name `attrName` from
530 /// `operandList` of `op`.
531 static void emitAttributeSerialization(const Attribute
&attr
,
532 ArrayRef
<SMLoc
> loc
, StringRef tabs
,
533 StringRef opVar
, StringRef operandList
,
534 StringRef attrName
, raw_ostream
&os
) {
536 << formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar
, attrName
);
537 if (llvm::is_contained(constantIdEnumAttrs
, attr
.getAttrDefName())) {
538 EnumAttr
baseEnum(attr
.getDef().getValueAsDef("enum"));
540 << formatv(" {0}.push_back(prepareConstantInt({1}.getLoc(), "
541 "Builder({1}).getI32IntegerAttr(static_cast<uint32_t>("
542 "::llvm::cast<{2}::{3}Attr>(attr).getValue()))));\n",
543 operandList
, opVar
, baseEnum
.getCppNamespace(),
544 baseEnum
.getEnumClassName());
545 } else if (attr
.isSubClassOf("SPIRV_BitEnumAttr") ||
546 attr
.isSubClassOf("SPIRV_I32EnumAttr")) {
547 EnumAttr
baseEnum(attr
.getDef().getValueAsDef("enum"));
549 << formatv(" {0}.push_back(static_cast<uint32_t>("
550 "::llvm::cast<{1}::{2}Attr>(attr).getValue()));\n",
551 operandList
, baseEnum
.getCppNamespace(),
552 baseEnum
.getEnumClassName());
553 } else if (attr
.getAttrDefName() == "I32ArrayAttr") {
554 // Serialize all the elements of the array
555 os
<< tabs
<< " for (auto attrElem : llvm::cast<ArrayAttr>(attr)) {\n";
557 << formatv(" {0}.push_back(static_cast<uint32_t>("
558 "llvm::cast<IntegerAttr>(attrElem).getValue().getZExtValue())"
561 os
<< tabs
<< " }\n";
562 } else if (attr
.getAttrDefName() == "I32Attr") {
565 " {0}.push_back(static_cast<uint32_t>("
566 "llvm::cast<IntegerAttr>(attr).getValue().getZExtValue()));\n",
568 } else if (attr
.isEnumAttr() || attr
.isTypeAttr()) {
569 // It may be the first time this type appears in the IR, so we need to
571 StringRef attrTypeID
= "attrTypeID";
572 os
<< tabs
<< formatv(" uint32_t {0} = 0;\n", attrTypeID
);
574 << formatv(" if (failed(processType({0}.getLoc(), "
575 "llvm::cast<TypeAttr>(attr).getValue(), {1}))) {{\n",
577 os
<< tabs
<< " return failure();\n";
578 os
<< tabs
<< " }\n";
579 os
<< tabs
<< formatv(" {0}.push_back(attrTypeID);\n", operandList
);
584 "unhandled attribute type in SPIR-V serialization generation : '") +
585 attr
.getAttrDefName() + llvm::Twine("'"));
590 /// Generates code to serialize the operands of a SPIRV_Op `op` into `os`. The
591 /// generated queries the SSA-ID if operand is a SSA-Value, or serializes the
592 /// attributes. The `operands` vector is updated appropriately. `elidedAttrs`
593 /// updated as well to include the serialized attributes.
594 static void emitArgumentSerialization(const Operator
&op
, ArrayRef
<SMLoc
> loc
,
595 StringRef tabs
, StringRef opVar
,
596 StringRef operands
, StringRef elidedAttrs
,
598 using mlir::tblgen::Argument
;
600 // SPIR-V ops can mix operands and attributes in the definition. These
601 // operands and attributes are serialized in the exact order of the definition
602 // to match SPIR-V binary format requirements. It can cause excessive
603 // generated code bloat because we are emitting code to handle each
604 // operand/attribute separately. So here we probe first to check whether all
605 // the operands are ahead of attributes. Then we can serialize all operands
608 // Whether all operands are ahead of all attributes in the op's spec.
609 bool areOperandsAheadOfAttrs
= true;
610 // Find the first attribute.
611 const Argument
*it
= llvm::find_if(op
.getArgs(), [](const Argument
&arg
) {
612 return arg
.is
<NamedAttribute
*>();
614 // Check whether all following arguments are attributes.
615 for (const Argument
*ie
= op
.arg_end(); it
!= ie
; ++it
) {
616 if (!it
->is
<NamedAttribute
*>()) {
617 areOperandsAheadOfAttrs
= false;
622 // Serialize all operands together.
623 if (areOperandsAheadOfAttrs
) {
624 if (op
.getNumOperands() != 0) {
626 << formatv("for (Value operand : {0}->getOperands()) {{\n", opVar
);
627 os
<< tabs
<< " auto id = getValueID(operand);\n";
628 os
<< tabs
<< " assert(id && \"use before def!\");\n";
629 os
<< tabs
<< formatv(" {0}.push_back(id);\n", operands
);
632 for (const NamedAttribute
&attr
: op
.getAttributes()) {
633 emitAttributeSerialization(
634 (attr
.attr
.isOptional() ? attr
.attr
.getBaseAttr() : attr
.attr
), loc
,
635 tabs
, opVar
, operands
, attr
.name
, os
);
637 << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs
, attr
.name
);
642 // Serialize operands separately.
644 for (unsigned i
= 0, e
= op
.getNumArgs(); i
< e
; ++i
) {
645 auto argument
= op
.getArg(i
);
647 if (argument
.is
<NamedTypeConstraint
*>()) {
649 << formatv(" for (auto arg : {0}.getODSOperands({1})) {{\n", opVar
,
651 os
<< tabs
<< " auto argID = getValueID(arg);\n";
652 os
<< tabs
<< " if (!argID) {\n";
654 << formatv(" return emitError({0}.getLoc(), "
655 "\"operand #{1} has a use before def\");\n",
657 os
<< tabs
<< " }\n";
658 os
<< tabs
<< formatv(" {0}.push_back(argID);\n", operands
);
662 NamedAttribute
*attr
= argument
.get
<NamedAttribute
*>();
663 auto newtabs
= tabs
.str() + " ";
664 emitAttributeSerialization(
665 (attr
->attr
.isOptional() ? attr
->attr
.getBaseAttr() : attr
->attr
),
666 loc
, newtabs
, opVar
, operands
, attr
->name
, os
);
668 << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs
, attr
->name
);
674 /// Generates code to serializes the result of SPIRV_Op `op` into `os`. The
675 /// generated gets the ID for the type of the result (if any), the SSA-ID of
676 /// the result and updates `resultID` with the SSA-ID.
677 static void emitResultSerialization(const Operator
&op
, ArrayRef
<SMLoc
> loc
,
678 StringRef tabs
, StringRef opVar
,
679 StringRef operands
, StringRef resultID
,
681 if (op
.getNumResults() == 1) {
682 StringRef
resultTypeID("resultTypeID");
683 os
<< tabs
<< formatv("uint32_t {0} = 0;\n", resultTypeID
);
686 "if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n",
687 opVar
, resultTypeID
);
688 os
<< tabs
<< " return failure();\n";
690 os
<< tabs
<< formatv("{0}.push_back({1});\n", operands
, resultTypeID
);
691 // Create an SSA result <id> for the op
692 os
<< tabs
<< formatv("{0} = getNextID();\n", resultID
);
694 << formatv("valueIDMap[{0}.getResult()] = {1};\n", opVar
, resultID
);
695 os
<< tabs
<< formatv("{0}.push_back({1});\n", operands
, resultID
);
696 } else if (op
.getNumResults() != 0) {
697 PrintFatalError(loc
, "SPIR-V ops can only have zero or one result");
701 /// Generates code to serialize attributes of SPIRV_Op `op` that become
702 /// decorations on the `resultID` of the serialized operation `opVar` in the
704 static void emitDecorationSerialization(const Operator
&op
, StringRef tabs
,
705 StringRef opVar
, StringRef elidedAttrs
,
706 StringRef resultID
, raw_ostream
&os
) {
707 if (op
.getNumResults() == 1) {
708 // All non-argument attributes translated into OpDecorate instruction
709 os
<< tabs
<< formatv("for (auto attr : {0}->getAttrs()) {{\n", opVar
);
711 << formatv(" if (llvm::is_contained({0}, attr.getName())) {{",
713 os
<< tabs
<< " continue;\n";
714 os
<< tabs
<< " }\n";
717 " if (failed(processDecoration({0}.getLoc(), {1}, attr))) {{\n",
719 os
<< tabs
<< " return failure();\n";
720 os
<< tabs
<< " }\n";
725 /// Generates code to serialize an SPIRV_Op `op` into `os`.
726 static void emitSerializationFunction(const Record
*attrClass
,
727 const Record
*record
, const Operator
&op
,
729 // If the record has 'autogenSerialization' set to 0, nothing to do
730 if (!record
->getValueAsBit("autogenSerialization"))
733 StringRef
opVar("op"), operands("operands"), elidedAttrs("elidedAttrs"),
734 resultID("resultID");
737 "template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n",
738 op
.getQualCppClassName(), opVar
);
740 // Special case for ops without attributes in TableGen definitions
741 if (op
.getNumAttributes() == 0 && op
.getNumVariableLengthOperands() == 0) {
742 std::string extInstSet
;
744 if (record
->isSubClassOf("SPIRV_ExtInstOp")) {
746 formatv("\"{0}\"", record
->getValueAsString("extendedInstSetName"));
747 opcode
= std::to_string(record
->getValueAsInt("extendedInstOpcode"));
750 opcode
= formatv("static_cast<uint32_t>(spirv::Opcode::{0})",
751 record
->getValueAsString("spirvOpName"));
754 os
<< formatv(" return processOpWithoutGrammarAttr({0}, {1}, {2});\n}\n\n",
755 opVar
, extInstSet
, opcode
);
759 os
<< formatv(" SmallVector<uint32_t, 4> {0};\n", operands
);
760 os
<< formatv(" SmallVector<StringRef, 2> {0};\n", elidedAttrs
);
762 // Serialize result information.
763 if (op
.getNumResults() == 1) {
764 os
<< formatv(" uint32_t {0} = 0;\n", resultID
);
765 emitResultSerialization(op
, record
->getLoc(), " ", opVar
, operands
,
769 // Process arguments.
770 emitArgumentSerialization(op
, record
->getLoc(), " ", opVar
, operands
,
773 if (record
->isSubClassOf("SPIRV_ExtInstOp")) {
775 " (void)encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n", opVar
,
776 record
->getValueAsString("extendedInstSetName"),
777 record
->getValueAsInt("extendedInstOpcode"), operands
);
780 os
<< formatv(" (void)emitDebugLine(functionBody, {0}.getLoc());\n",
782 os
<< formatv(" (void)encodeInstructionInto("
783 "functionBody, spirv::Opcode::{1}, {2});\n",
784 op
.getQualCppClassName(),
785 record
->getValueAsString("spirvOpName"), operands
);
788 // Process decorations.
789 emitDecorationSerialization(op
, " ", opVar
, elidedAttrs
, resultID
, os
);
791 os
<< " return success();\n";
795 /// Generates the prologue for the function that dispatches the serialization of
796 /// the operation `opVar` based on its opcode.
797 static void initDispatchSerializationFn(StringRef opVar
, raw_ostream
&os
) {
799 "LogicalResult Serializer::dispatchToAutogenSerialization(Operation "
804 /// Generates the body of the dispatch function. This function generates the
805 /// check that if satisfied, will call the serialization function generated for
807 static void emitSerializationDispatch(const Operator
&op
, StringRef tabs
,
808 StringRef opVar
, raw_ostream
&os
) {
810 << formatv("if (isa<{0}>({1})) {{\n", op
.getQualCppClassName(), opVar
);
812 << formatv(" return processOp(cast<{0}>({1}));\n",
813 op
.getQualCppClassName(), opVar
);
817 /// Generates the epilogue for the function that dispatches the serialization of
819 static void finalizeDispatchSerializationFn(StringRef opVar
, raw_ostream
&os
) {
821 " return {0}->emitError(\"unhandled operation serialization\");\n",
826 /// Generates code to deserialize the attribute of a SPIRV_Op into `os`. The
827 /// generated code reads the `words` of the serialized instruction at
828 /// position `wordIndex` and adds the deserialized attribute into `attrList`.
829 static void emitAttributeDeserialization(const Attribute
&attr
,
830 ArrayRef
<SMLoc
> loc
, StringRef tabs
,
831 StringRef attrList
, StringRef attrName
,
832 StringRef words
, StringRef wordIndex
,
834 if (llvm::is_contained(constantIdEnumAttrs
, attr
.getAttrDefName())) {
835 EnumAttr
baseEnum(attr
.getDef().getValueAsDef("enum"));
837 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
838 "opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>("
839 "getConstantInt({4}[{5}++]).getValue().getZExtValue()))));\n",
840 attrList
, attrName
, baseEnum
.getCppNamespace(),
841 baseEnum
.getEnumClassName(), words
, wordIndex
);
842 } else if (attr
.isSubClassOf("SPIRV_BitEnumAttr") ||
843 attr
.isSubClassOf("SPIRV_I32EnumAttr")) {
844 EnumAttr
baseEnum(attr
.getDef().getValueAsDef("enum"));
846 << formatv(" {0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
847 "opBuilder.getAttr<{2}::{3}Attr>("
848 "static_cast<{2}::{3}>({4}[{5}++]))));\n",
849 attrList
, attrName
, baseEnum
.getCppNamespace(),
850 baseEnum
.getEnumClassName(), words
, wordIndex
);
851 } else if (attr
.getAttrDefName() == "I32ArrayAttr") {
852 os
<< tabs
<< "SmallVector<Attribute, 4> attrListElems;\n";
853 os
<< tabs
<< formatv("while ({0} < {1}.size()) {{\n", wordIndex
, words
);
857 "attrListElems.push_back(opBuilder.getI32IntegerAttr({0}[{1}++]))"
862 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
863 "opBuilder.getArrayAttr(attrListElems)));\n",
865 } else if (attr
.getAttrDefName() == "I32Attr") {
867 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
868 "opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
869 attrList
, attrName
, words
, wordIndex
);
870 } else if (attr
.isEnumAttr() || attr
.isTypeAttr()) {
872 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
873 "TypeAttr::get(getType({2}[{3}++]))));\n",
874 attrList
, attrName
, words
, wordIndex
);
878 "unhandled attribute type in deserialization generation : '") +
879 attrName
+ llvm::Twine("'"));
883 /// Generates the code to deserialize the result of an SPIRV_Op `op` into
884 /// `os`. The generated code gets the type of the result specified at
885 /// `words`[`wordIndex`], the SSA ID for the result at position `wordIndex` + 1
886 /// and updates the `resultType` and `valueID` with the parsed type and SSA ID,
888 static void emitResultDeserialization(const Operator
&op
, ArrayRef
<SMLoc
> loc
,
889 StringRef tabs
, StringRef words
,
891 StringRef resultTypes
, StringRef valueID
,
893 // Deserialize result information if it exists
894 if (op
.getNumResults() == 1) {
896 os
<< tabs
<< formatv(" if ({0} >= {1}.size()) {{\n", wordIndex
, words
);
899 " return emitError(unknownLoc, \"expected result type <id> "
900 "while deserializing {0}\");\n",
901 op
.getQualCppClassName());
902 os
<< tabs
<< " }\n";
903 os
<< tabs
<< formatv(" auto ty = getType({0}[{1}]);\n", words
, wordIndex
);
904 os
<< tabs
<< " if (!ty) {\n";
907 " return emitError(unknownLoc, \"unknown type result <id> : "
908 "\") << {0}[{1}];\n",
910 os
<< tabs
<< " }\n";
911 os
<< tabs
<< formatv(" {0}.push_back(ty);\n", resultTypes
);
912 os
<< tabs
<< formatv(" {0}++;\n", wordIndex
);
913 os
<< tabs
<< formatv(" if ({0} >= {1}.size()) {{\n", wordIndex
, words
);
916 " return emitError(unknownLoc, \"expected result <id> while "
917 "deserializing {0}\");\n",
918 op
.getQualCppClassName());
919 os
<< tabs
<< " }\n";
921 os
<< tabs
<< formatv("{0} = {1}[{2}++];\n", valueID
, words
, wordIndex
);
922 } else if (op
.getNumResults() != 0) {
923 PrintFatalError(loc
, "SPIR-V ops can have only zero or one result");
927 /// Generates the code to deserialize the operands of an SPIRV_Op `op` into
928 /// `os`. The generated code reads the `words` of the binary instruction, from
929 /// position `wordIndex` to the end, and either gets the Value corresponding to
930 /// the ID encoded, or deserializes the attributes encoded. The parsed
931 /// operand(attribute) is added to the `operands` list or `attributes` list.
932 static void emitOperandDeserialization(const Operator
&op
, ArrayRef
<SMLoc
> loc
,
933 StringRef tabs
, StringRef words
,
934 StringRef wordIndex
, StringRef operands
,
935 StringRef attributes
, raw_ostream
&os
) {
936 // Process operands/attributes
937 for (unsigned i
= 0, e
= op
.getNumArgs(); i
< e
; ++i
) {
938 auto argument
= op
.getArg(i
);
939 if (auto *valueArg
= llvm::dyn_cast_if_present
<NamedTypeConstraint
*>(argument
)) {
940 if (valueArg
->isVariableLength()) {
943 loc
, "SPIR-V ops can have Variadic<..> or "
944 "Optional<...> arguments only if it's the last argument");
947 << formatv("for (; {0} < {1}.size(); ++{0})", wordIndex
, words
);
949 os
<< tabs
<< formatv("if ({0} < {1}.size())", wordIndex
, words
);
953 << formatv(" auto arg = getValue({0}[{1}]);\n", words
, wordIndex
);
954 os
<< tabs
<< " if (!arg) {\n";
957 " return emitError(unknownLoc, \"unknown result <id> : \") "
960 os
<< tabs
<< " }\n";
961 os
<< tabs
<< formatv(" {0}.push_back(arg);\n", operands
);
962 if (!valueArg
->isVariableLength()) {
963 os
<< tabs
<< formatv(" {0}++;\n", wordIndex
);
967 os
<< tabs
<< formatv("if ({0} < {1}.size()) {{\n", wordIndex
, words
);
968 auto *attr
= argument
.get
<NamedAttribute
*>();
969 auto newtabs
= tabs
.str() + " ";
970 emitAttributeDeserialization(
971 (attr
->attr
.isOptional() ? attr
->attr
.getBaseAttr() : attr
->attr
),
972 loc
, newtabs
, attributes
, attr
->name
, words
, wordIndex
, os
);
977 os
<< tabs
<< formatv("if ({0} != {1}.size()) {{\n", wordIndex
, words
);
980 " return emitError(unknownLoc, \"found more operands than "
981 "expected when deserializing {0}, only \") << {1} << \" of \" << "
982 "{2}.size() << \" processed\";\n",
983 op
.getQualCppClassName(), wordIndex
, words
);
984 os
<< tabs
<< "}\n\n";
987 /// Generates code to update the `attributes` vector with the attributes
988 /// obtained from parsing the decorations in the SPIR-V binary associated with
989 /// an <id> `valueID`
990 static void emitDecorationDeserialization(const Operator
&op
, StringRef tabs
,
992 StringRef attributes
,
994 // Import decorations parsed
995 if (op
.getNumResults() == 1) {
996 os
<< tabs
<< formatv("if (decorations.count({0})) {{\n", valueID
);
998 << formatv(" auto attrs = decorations[{0}].getAttrs();\n", valueID
);
1000 << formatv(" {0}.append(attrs.begin(), attrs.end());\n", attributes
);
1001 os
<< tabs
<< "}\n";
1005 /// Generates code to deserialize an SPIRV_Op `op` into `os`.
1006 static void emitDeserializationFunction(const Record
*attrClass
,
1007 const Record
*record
,
1008 const Operator
&op
, raw_ostream
&os
) {
1009 // If the record has 'autogenSerialization' set to 0, nothing to do
1010 if (!record
->getValueAsBit("autogenSerialization"))
1013 StringRef
resultTypes("resultTypes"), valueID("valueID"), words("words"),
1014 wordIndex("wordIndex"), opVar("op"), operands("operands"),
1015 attributes("attributes");
1017 // Method declaration
1018 os
<< formatv("template <> "
1019 "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
1020 "uint32_t> {1}) {{\n",
1021 op
.getQualCppClassName(), words
);
1023 // Special case for ops without attributes in TableGen definitions
1024 if (op
.getNumAttributes() == 0 && op
.getNumVariableLengthOperands() == 0) {
1025 os
<< formatv(" return processOpWithoutGrammarAttr("
1026 "{0}, \"{1}\", {2}, {3});\n}\n\n",
1027 words
, op
.getOperationName(),
1028 op
.getNumResults() ? "true" : "false", op
.getNumOperands());
1032 os
<< formatv(" SmallVector<Type, 1> {0};\n", resultTypes
);
1033 os
<< formatv(" size_t {0} = 0; (void){0};\n", wordIndex
);
1034 os
<< formatv(" uint32_t {0} = 0; (void){0};\n", valueID
);
1036 // Deserialize result information
1037 emitResultDeserialization(op
, record
->getLoc(), " ", words
, wordIndex
,
1038 resultTypes
, valueID
, os
);
1040 os
<< formatv(" SmallVector<Value, 4> {0};\n", operands
);
1041 os
<< formatv(" SmallVector<NamedAttribute, 4> {0};\n", attributes
);
1042 // Operand deserialization
1043 emitOperandDeserialization(op
, record
->getLoc(), " ", words
, wordIndex
,
1044 operands
, attributes
, os
);
1047 emitDecorationDeserialization(op
, " ", valueID
, attributes
, os
);
1049 os
<< formatv(" Location loc = createFileLineColLoc(opBuilder);\n");
1050 os
<< formatv(" auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); "
1052 op
.getQualCppClassName(), opVar
, resultTypes
, operands
,
1054 if (op
.getNumResults() == 1) {
1055 os
<< formatv(" valueMap[{0}] = {1}.getResult();\n\n", valueID
, opVar
);
1058 // According to SPIR-V spec:
1059 // This location information applies to the instructions physically following
1060 // this instruction, up to the first occurrence of any of the following: the
1061 // next end of block.
1062 os
<< formatv(" if ({0}.hasTrait<OpTrait::IsTerminator>())\n", opVar
);
1063 os
<< formatv(" (void)clearDebugLine();\n");
1064 os
<< " return success();\n";
1068 /// Generates the prologue for the function that dispatches the deserialization
1069 /// based on the `opcode`.
1070 static void initDispatchDeserializationFn(StringRef opcode
, StringRef words
,
1072 os
<< formatv("LogicalResult spirv::Deserializer::"
1073 "dispatchToAutogenDeserialization(spirv::Opcode {0},"
1074 " ArrayRef<uint32_t> {1}) {{\n",
1076 os
<< formatv(" switch ({0}) {{\n", opcode
);
1079 /// Generates the body of the dispatch function, by generating the case label
1080 /// for an opcode and the call to the method to perform the deserialization.
1081 static void emitDeserializationDispatch(const Operator
&op
, const Record
*def
,
1082 StringRef tabs
, StringRef words
,
1085 << formatv("case spirv::Opcode::{0}:\n",
1086 def
->getValueAsString("spirvOpName"));
1088 << formatv(" return processOp<{0}>({1});\n", op
.getQualCppClassName(),
1092 /// Generates the epilogue for the function that dispatches the deserialization
1093 /// of the operation.
1094 static void finalizeDispatchDeserializationFn(StringRef opcode
,
1096 os
<< " default:\n";
1099 StringRef
opcodeVar("opcodeString");
1100 os
<< formatv(" auto {0} = spirv::stringifyOpcode({1});\n", opcodeVar
,
1102 os
<< formatv(" if (!{0}.empty()) {{\n", opcodeVar
);
1103 os
<< formatv(" return emitError(unknownLoc, \"unhandled deserialization "
1106 os
<< " } else {\n";
1107 os
<< formatv(" return emitError(unknownLoc, \"unhandled opcode \") << "
1108 "static_cast<uint32_t>({0});\n",
1114 static void initExtendedSetDeserializationDispatch(StringRef extensionSetName
,
1115 StringRef instructionID
,
1118 os
<< formatv("LogicalResult spirv::Deserializer::"
1119 "dispatchToExtensionSetAutogenDeserialization("
1120 "StringRef {0}, uint32_t {1}, ArrayRef<uint32_t> {2}) {{\n",
1121 extensionSetName
, instructionID
, words
);
1125 emitExtendedSetDeserializationDispatch(const RecordKeeper
&recordKeeper
,
1127 StringRef
extensionSetName("extensionSetName"),
1128 instructionID("instructionID"), words("words");
1130 // First iterate over all ops derived from SPIRV_ExtensionSetOps to get all
1133 // For each of the extensions a separate raw_string_ostream is used to
1134 // generate code into. These are then concatenated at the end. Since
1135 // raw_string_ostream needs a string&, use a vector to store all the string
1136 // that are captured by reference within raw_string_ostream.
1137 StringMap
<raw_string_ostream
> extensionSets
;
1138 std::list
<std::string
> extensionSetNames
;
1140 initExtendedSetDeserializationDispatch(extensionSetName
, instructionID
, words
,
1142 auto defs
= recordKeeper
.getAllDerivedDefinitions("SPIRV_ExtInstOp");
1143 for (const auto *def
: defs
) {
1144 if (!def
->getValueAsBit("autogenSerialization")) {
1148 auto setName
= def
->getValueAsString("extendedInstSetName");
1149 if (!extensionSets
.count(setName
)) {
1150 extensionSetNames
.emplace_back("");
1151 extensionSets
.try_emplace(setName
, extensionSetNames
.back());
1152 auto &setos
= extensionSets
.find(setName
)->second
;
1153 setos
<< formatv(" if ({0} == \"{1}\") {{\n", extensionSetName
, setName
);
1154 setos
<< formatv(" switch ({0}) {{\n", instructionID
);
1156 auto &setos
= extensionSets
.find(setName
)->second
;
1157 setos
<< formatv(" case {0}:\n",
1158 def
->getValueAsInt("extendedInstOpcode"));
1159 setos
<< formatv(" return processOp<{0}>({1});\n",
1160 op
.getQualCppClassName(), words
);
1163 // Append the dispatch code for all the extended sets.
1164 for (auto &extensionSet
: extensionSets
) {
1165 os
<< extensionSet
.second
.str();
1166 os
<< " default:\n";
1168 " return emitError(unknownLoc, \"unhandled deserializations of "
1169 "\") << {0} << \" from extension set \" << {1};\n",
1170 instructionID
, extensionSetName
);
1175 os
<< formatv(" return emitError(unknownLoc, \"unhandled deserialization of "
1176 "extended instruction set {0}\");\n",
1181 /// Emits all the autogenerated serialization/deserializations functions for the
1183 static bool emitSerializationFns(const RecordKeeper
&recordKeeper
,
1185 llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os
,
1188 std::string dSerFnString
, dDesFnString
, serFnString
, deserFnString
,
1190 raw_string_ostream
dSerFn(dSerFnString
), dDesFn(dDesFnString
),
1191 serFn(serFnString
), deserFn(deserFnString
);
1192 Record
*attrClass
= recordKeeper
.getClass("Attr");
1194 // Emit the serialization and deserialization functions simultaneously.
1195 StringRef
opVar("op");
1196 StringRef
opcode("opcode"), words("words");
1198 // Handle the SPIR-V ops.
1199 initDispatchSerializationFn(opVar
, dSerFn
);
1200 initDispatchDeserializationFn(opcode
, words
, dDesFn
);
1201 auto defs
= recordKeeper
.getAllDerivedDefinitions("SPIRV_Op");
1202 for (const auto *def
: defs
) {
1204 emitSerializationFunction(attrClass
, def
, op
, serFn
);
1205 emitDeserializationFunction(attrClass
, def
, op
, deserFn
);
1206 if (def
->getValueAsBit("hasOpcode") ||
1207 def
->isSubClassOf("SPIRV_ExtInstOp")) {
1208 emitSerializationDispatch(op
, " ", opVar
, dSerFn
);
1210 if (def
->getValueAsBit("hasOpcode")) {
1211 emitDeserializationDispatch(op
, def
, " ", words
, dDesFn
);
1214 finalizeDispatchSerializationFn(opVar
, dSerFn
);
1215 finalizeDispatchDeserializationFn(opcode
, dDesFn
);
1217 emitExtendedSetDeserializationDispatch(recordKeeper
, dDesFn
);
1219 os
<< "#ifdef GET_SERIALIZATION_FNS\n\n";
1222 os
<< "#endif // GET_SERIALIZATION_FNS\n\n";
1224 os
<< "#ifdef GET_DESERIALIZATION_FNS\n\n";
1225 os
<< deserFn
.str();
1227 os
<< "#endif // GET_DESERIALIZATION_FNS\n\n";
1232 //===----------------------------------------------------------------------===//
1233 // Serialization Hook Registration
1234 //===----------------------------------------------------------------------===//
1236 static mlir::GenRegistration
genSerialization(
1237 "gen-spirv-serialization",
1238 "Generate SPIR-V (de)serialization utilities and functions",
1239 [](const RecordKeeper
&records
, raw_ostream
&os
) {
1240 return emitSerializationFns(records
, os
);
1243 //===----------------------------------------------------------------------===//
1245 //===----------------------------------------------------------------------===//
1247 static void emitEnumGetAttrNameFnDecl(raw_ostream
&os
) {
1248 os
<< formatv("template <typename EnumClass> inline constexpr StringRef "
1249 "attributeName();\n");
1252 static void emitEnumGetAttrNameFnDefn(const EnumAttr
&enumAttr
,
1254 auto enumName
= enumAttr
.getEnumClassName();
1255 os
<< formatv("template <> inline StringRef attributeName<{0}>() {{\n",
1258 << formatv("static constexpr const char attrName[] = \"{0}\";\n",
1259 llvm::convertToSnakeFromCamelCase(enumName
));
1260 os
<< " return attrName;\n";
1264 static bool emitAttrUtils(const RecordKeeper
&recordKeeper
, raw_ostream
&os
) {
1265 llvm::emitSourceFileHeader("SPIR-V Attribute Utilities", os
, recordKeeper
);
1267 auto defs
= recordKeeper
.getAllDerivedDefinitions("EnumAttrInfo");
1268 os
<< "#ifndef MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
1269 os
<< "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
1270 emitEnumGetAttrNameFnDecl(os
);
1271 for (const auto *def
: defs
) {
1272 EnumAttr
enumAttr(*def
);
1273 emitEnumGetAttrNameFnDefn(enumAttr
, os
);
1275 os
<< "#endif // MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H\n";
1279 //===----------------------------------------------------------------------===//
1280 // Op Utils Hook Registration
1281 //===----------------------------------------------------------------------===//
1283 static mlir::GenRegistration
1284 genOpUtils("gen-spirv-attr-utils",
1285 "Generate SPIR-V attribute utility definitions",
1286 [](const RecordKeeper
&records
, raw_ostream
&os
) {
1287 return emitAttrUtils(records
, os
);
1290 //===----------------------------------------------------------------------===//
1291 // SPIR-V Availability Impl AutoGen
1292 //===----------------------------------------------------------------------===//
1294 static void emitAvailabilityImpl(const Operator
&srcOp
, raw_ostream
&os
) {
1295 mlir::tblgen::FmtContext fctx
;
1296 fctx
.addSubst("overall", "tblgen_overall");
1298 std::vector
<Availability
> opAvailabilities
=
1299 getAvailabilities(srcOp
.getDef());
1301 // First collect all availability classes this op should implement.
1302 // All availability instances keep information for the generated interface and
1303 // the instance's specific requirement. Here we remember a random instance so
1304 // we can get the information regarding the generated interface.
1305 llvm::StringMap
<Availability
> availClasses
;
1306 for (const Availability
&avail
: opAvailabilities
)
1307 availClasses
.try_emplace(avail
.getClass(), avail
);
1308 for (const NamedAttribute
&namedAttr
: srcOp
.getAttributes()) {
1309 if (!namedAttr
.attr
.isSubClassOf("SPIRV_BitEnumAttr") &&
1310 !namedAttr
.attr
.isSubClassOf("SPIRV_I32EnumAttr"))
1312 EnumAttr
enumAttr(namedAttr
.attr
.getDef().getValueAsDef("enum"));
1314 for (const EnumAttrCase
&enumerant
: enumAttr
.getAllCases())
1315 for (const Availability
&caseAvail
:
1316 getAvailabilities(enumerant
.getDef()))
1317 availClasses
.try_emplace(caseAvail
.getClass(), caseAvail
);
1320 // Then generate implementation for each availability class.
1321 for (const auto &availClass
: availClasses
) {
1322 StringRef availClassName
= availClass
.getKey();
1323 Availability avail
= availClass
.getValue();
1325 // Generate the implementation method signature.
1326 os
<< formatv("{0} {1}::{2}() {{\n", avail
.getQueryFnRetType(),
1327 srcOp
.getCppClassName(), avail
.getQueryFnName());
1329 // Create the variable for the final requirement and initialize it.
1330 os
<< formatv(" {0} tblgen_overall = {1};\n", avail
.getQueryFnRetType(),
1331 avail
.getMergeInitializer());
1333 // Update with the op's specific availability spec.
1334 for (const Availability
&avail
: opAvailabilities
)
1335 if (avail
.getClass() == availClassName
&&
1336 (!avail
.getMergeInstancePreparation().empty() ||
1337 !avail
.getMergeActionCode().empty())) {
1339 // Prepare this instance.
1340 << avail
.getMergeInstancePreparation()
1342 // Merge this instance.
1344 tgfmt(avail
.getMergeActionCode(),
1345 &fctx
.addSubst("instance", avail
.getMergeInstance())))
1349 // Update with enum attributes' specific availability spec.
1350 for (const NamedAttribute
&namedAttr
: srcOp
.getAttributes()) {
1351 if (!namedAttr
.attr
.isSubClassOf("SPIRV_BitEnumAttr") &&
1352 !namedAttr
.attr
.isSubClassOf("SPIRV_I32EnumAttr"))
1354 EnumAttr
enumAttr(namedAttr
.attr
.getDef().getValueAsDef("enum"));
1356 // (enumerant, availability specification) pairs for this availability
1358 SmallVector
<std::pair
<EnumAttrCase
, Availability
>, 1> caseSpecs
;
1360 // Collect all cases' availability specs.
1361 for (const EnumAttrCase
&enumerant
: enumAttr
.getAllCases())
1362 for (const Availability
&caseAvail
:
1363 getAvailabilities(enumerant
.getDef()))
1364 if (availClassName
== caseAvail
.getClass())
1365 caseSpecs
.push_back({enumerant
, caseAvail
});
1367 // If this attribute kind does not have any availability spec from any of
1368 // its cases, no more work to do.
1369 if (caseSpecs
.empty())
1372 if (enumAttr
.isBitEnum()) {
1373 // For BitEnumAttr, we need to iterate over each bit to query its
1374 // availability spec.
1375 os
<< formatv(" for (unsigned i = 0; "
1376 "i < std::numeric_limits<{0}>::digits; ++i) {{\n",
1377 enumAttr
.getUnderlyingType());
1378 os
<< formatv(" {0}::{1} tblgen_attrVal = this->{2}() & "
1379 "static_cast<{0}::{1}>(1 << i);\n",
1380 enumAttr
.getCppNamespace(), enumAttr
.getEnumClassName(),
1381 srcOp
.getGetterName(namedAttr
.name
));
1383 " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
1384 enumAttr
.getUnderlyingType());
1386 // For IntEnumAttr, we just need to query the value as a whole.
1388 os
<< formatv(" auto tblgen_attrVal = this->{0}();\n",
1389 srcOp
.getGetterName(namedAttr
.name
));
1391 os
<< formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
1392 enumAttr
.getCppNamespace(), avail
.getQueryFnName());
1393 os
<< " if (tblgen_instance) "
1394 // TODO` here once ODS supports
1395 // dialect-specific contents so that we can use not implementing the
1396 // availability interface as indication of no requirements.
1397 << std::string(tgfmt(caseSpecs
.front().second
.getMergeActionCode(),
1398 &fctx
.addSubst("instance", "*tblgen_instance")))
1403 os
<< " return tblgen_overall;\n";
1408 static bool emitAvailabilityImpl(const RecordKeeper
&recordKeeper
,
1410 llvm::emitSourceFileHeader("SPIR-V Op Availability Implementations", os
,
1413 auto defs
= recordKeeper
.getAllDerivedDefinitions("SPIRV_Op");
1414 for (const auto *def
: defs
) {
1416 if (def
->getValueAsBit("autogenAvailability"))
1417 emitAvailabilityImpl(op
, os
);
1422 //===----------------------------------------------------------------------===//
1423 // Op Availability Implementation Hook Registration
1424 //===----------------------------------------------------------------------===//
1426 static mlir::GenRegistration
1427 genOpAvailabilityImpl("gen-spirv-avail-impls",
1428 "Generate SPIR-V operation utility definitions",
1429 [](const RecordKeeper
&records
, raw_ostream
&os
) {
1430 return emitAvailabilityImpl(records
, os
);
1433 //===----------------------------------------------------------------------===//
1434 // SPIR-V Capability Implication AutoGen
1435 //===----------------------------------------------------------------------===//
1437 static bool emitCapabilityImplication(const RecordKeeper
&recordKeeper
,
1439 llvm::emitSourceFileHeader("SPIR-V Capability Implication", os
, recordKeeper
);
1442 recordKeeper
.getDef("SPIRV_CapabilityAttr")->getValueAsDef("enum"));
1444 os
<< "ArrayRef<spirv::Capability> "
1445 "spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n"
1446 << " switch (cap) {\n"
1447 << " default: return {};\n";
1448 for (const EnumAttrCase
&enumerant
: enumAttr
.getAllCases()) {
1449 const Record
&def
= enumerant
.getDef();
1450 if (!def
.getValue("implies"))
1453 std::vector
<Record
*> impliedCapsDefs
= def
.getValueAsListOfDefs("implies");
1454 os
<< " case spirv::Capability::" << enumerant
.getSymbol()
1455 << ": {static const spirv::Capability implies[" << impliedCapsDefs
.size()
1457 llvm::interleaveComma(impliedCapsDefs
, os
, [&](const Record
*capDef
) {
1458 os
<< "spirv::Capability::" << EnumAttrCase(capDef
).getSymbol();
1460 os
<< "}; return ArrayRef<spirv::Capability>(implies, "
1461 << impliedCapsDefs
.size() << "); }\n";
1469 //===----------------------------------------------------------------------===//
1470 // SPIR-V Capability Implication Hook Registration
1471 //===----------------------------------------------------------------------===//
1473 static mlir::GenRegistration
1474 genCapabilityImplication("gen-spirv-capability-implication",
1475 "Generate utility function to return implied "
1476 "capabilities for a given capability",
1477 [](const RecordKeeper
&records
, raw_ostream
&os
) {
1478 return emitCapabilityImplication(records
, os
);