[Infra] Fix version-check workflow (#100090)
[llvm-project.git] / mlir / tools / mlir-tblgen / SPIRVUtilsGen.cpp
blob9aeb14d14eeca5ac046ff1aec216a35ba953517e
//===- SPIRVUtilsGen.cpp - SPIR-V utility generator -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// SPIRVUtilsGen generates common utility code for the SPIR-V dialect:
// availability interface declarations/definitions, enum availability
// queries, and (de)serialization routines.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Operator.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringMap.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/ADT/StringSet.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Record.h"
30 #include "llvm/TableGen/TableGenBackend.h"
32 #include <list>
33 #include <optional>
35 using llvm::ArrayRef;
36 using llvm::formatv;
37 using llvm::raw_ostream;
38 using llvm::raw_string_ostream;
39 using llvm::Record;
40 using llvm::RecordKeeper;
41 using llvm::SmallVector;
42 using llvm::SMLoc;
43 using llvm::StringMap;
44 using llvm::StringRef;
45 using mlir::tblgen::Attribute;
46 using mlir::tblgen::EnumAttr;
47 using mlir::tblgen::EnumAttrCase;
48 using mlir::tblgen::NamedAttribute;
49 using mlir::tblgen::NamedTypeConstraint;
50 using mlir::tblgen::NamespaceEmitter;
51 using mlir::tblgen::Operator;
53 //===----------------------------------------------------------------------===//
54 // Availability Wrapper Class
55 //===----------------------------------------------------------------------===//
57 namespace {
58 // Wrapper class with helper methods for accessing availability defined in
59 // TableGen.
60 class Availability {
61 public:
62 explicit Availability(const Record *def);
64 // Returns the name of the direct TableGen class for this availability
65 // instance.
66 StringRef getClass() const;
68 // Returns the generated C++ interface's class namespace.
69 StringRef getInterfaceClassNamespace() const;
71 // Returns the generated C++ interface's class name.
72 StringRef getInterfaceClassName() const;
74 // Returns the generated C++ interface's description.
75 StringRef getInterfaceDescription() const;
77 // Returns the name of the query function insided the generated C++ interface.
78 StringRef getQueryFnName() const;
80 // Returns the return type of the query function insided the generated C++
81 // interface.
82 StringRef getQueryFnRetType() const;
84 // Returns the code for merging availability requirements.
85 StringRef getMergeActionCode() const;
87 // Returns the initializer expression for initializing the final availability
88 // requirements.
89 StringRef getMergeInitializer() const;
91 // Returns the C++ type for an availability instance.
92 StringRef getMergeInstanceType() const;
94 // Returns the C++ statements for preparing availability instance.
95 StringRef getMergeInstancePreparation() const;
97 // Returns the concrete availability instance carried in this case.
98 StringRef getMergeInstance() const;
100 // Returns the underlying LLVM TableGen Record.
101 const llvm::Record *getDef() const { return def; }
103 private:
104 // The TableGen definition of this availability.
105 const llvm::Record *def;
107 } // namespace
109 Availability::Availability(const llvm::Record *def) : def(def) {
110 assert(def->isSubClassOf("Availability") &&
111 "must be subclass of TableGen 'Availability' class");
114 StringRef Availability::getClass() const {
115 SmallVector<Record *, 1> parentClass;
116 def->getDirectSuperClasses(parentClass);
117 if (parentClass.size() != 1) {
118 PrintFatalError(def->getLoc(),
119 "expected to only have one direct superclass");
121 return parentClass.front()->getName();
124 StringRef Availability::getInterfaceClassNamespace() const {
125 return def->getValueAsString("cppNamespace");
128 StringRef Availability::getInterfaceClassName() const {
129 return def->getValueAsString("interfaceName");
132 StringRef Availability::getInterfaceDescription() const {
133 return def->getValueAsString("interfaceDescription");
136 StringRef Availability::getQueryFnRetType() const {
137 return def->getValueAsString("queryFnRetType");
140 StringRef Availability::getQueryFnName() const {
141 return def->getValueAsString("queryFnName");
144 StringRef Availability::getMergeActionCode() const {
145 return def->getValueAsString("mergeAction");
148 StringRef Availability::getMergeInitializer() const {
149 return def->getValueAsString("initializer");
152 StringRef Availability::getMergeInstanceType() const {
153 return def->getValueAsString("instanceType");
156 StringRef Availability::getMergeInstancePreparation() const {
157 return def->getValueAsString("instancePreparation");
160 StringRef Availability::getMergeInstance() const {
161 return def->getValueAsString("instance");
164 // Returns the availability spec of the given `def`.
165 std::vector<Availability> getAvailabilities(const Record &def) {
166 std::vector<Availability> availabilities;
168 if (def.getValue("availability")) {
169 std::vector<Record *> availDefs = def.getValueAsListOfDefs("availability");
170 availabilities.reserve(availDefs.size());
171 for (const Record *avail : availDefs)
172 availabilities.emplace_back(avail);
175 return availabilities;
178 //===----------------------------------------------------------------------===//
179 // Availability Interface Definitions AutoGen
180 //===----------------------------------------------------------------------===//
182 static void emitInterfaceDef(const Availability &availability,
183 raw_ostream &os) {
185 os << availability.getQueryFnRetType() << " ";
187 StringRef cppNamespace = availability.getInterfaceClassNamespace();
188 cppNamespace.consume_front("::");
189 if (!cppNamespace.empty())
190 os << cppNamespace << "::";
192 StringRef methodName = availability.getQueryFnName();
193 os << availability.getInterfaceClassName() << "::" << methodName << "() {\n"
194 << " return getImpl()->" << methodName << "(getImpl(), getOperation());\n"
195 << "}\n";
198 static bool emitInterfaceDefs(const RecordKeeper &recordKeeper,
199 raw_ostream &os) {
200 llvm::emitSourceFileHeader("Availability Interface Definitions", os,
201 recordKeeper);
203 auto defs = recordKeeper.getAllDerivedDefinitions("Availability");
204 SmallVector<const Record *, 1> handledClasses;
205 for (const Record *def : defs) {
206 SmallVector<Record *, 1> parent;
207 def->getDirectSuperClasses(parent);
208 if (parent.size() != 1) {
209 PrintFatalError(def->getLoc(),
210 "expected to only have one direct superclass");
212 if (llvm::is_contained(handledClasses, parent.front()))
213 continue;
215 Availability availability(def);
216 emitInterfaceDef(availability, os);
217 handledClasses.push_back(parent.front());
219 return false;
222 //===----------------------------------------------------------------------===//
223 // Availability Interface Declarations AutoGen
224 //===----------------------------------------------------------------------===//
226 static void emitConceptDecl(const Availability &availability, raw_ostream &os) {
227 os << " class Concept {\n"
228 << " public:\n"
229 << " virtual ~Concept() = default;\n"
230 << " virtual " << availability.getQueryFnRetType() << " "
231 << availability.getQueryFnName()
232 << "(const Concept *impl, Operation *tblgen_opaque_op) const = 0;\n"
233 << " };\n";
236 static void emitModelDecl(const Availability &availability, raw_ostream &os) {
237 for (const char *modelClass : {"Model", "FallbackModel"}) {
238 os << " template<typename ConcreteOp>\n";
239 os << " class " << modelClass << " : public Concept {\n"
240 << " public:\n"
241 << " using Interface = " << availability.getInterfaceClassName()
242 << ";\n"
243 << " " << availability.getQueryFnRetType() << " "
244 << availability.getQueryFnName()
245 << "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n"
246 << " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
247 << " (void)op;\n"
248 // Forward to the method on the concrete operation type.
249 << " return op." << availability.getQueryFnName() << "();\n"
250 << " }\n"
251 << " };\n";
253 os << " template<typename ConcreteModel, typename ConcreteOp>\n";
254 os << " class ExternalModel : public FallbackModel<ConcreteOp> {};\n";
257 static void emitInterfaceDecl(const Availability &availability,
258 raw_ostream &os) {
259 StringRef interfaceName = availability.getInterfaceClassName();
260 std::string interfaceTraitsName =
261 std::string(formatv("{0}Traits", interfaceName));
263 StringRef cppNamespace = availability.getInterfaceClassNamespace();
264 NamespaceEmitter nsEmitter(os, cppNamespace);
265 os << "class " << interfaceName << ";\n\n";
267 // Emit the traits struct containing the concept and model declarations.
268 os << "namespace detail {\n"
269 << "struct " << interfaceTraitsName << " {\n";
270 emitConceptDecl(availability, os);
271 os << '\n';
272 emitModelDecl(availability, os);
273 os << "};\n} // namespace detail\n\n";
275 // Emit the main interface class declaration.
276 os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n";
277 os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n"
278 "public:\n"
279 " using OpInterface<{1}, detail::{2}>::OpInterface;\n",
280 interfaceName, interfaceName, interfaceTraitsName);
282 // Emit query function declaration.
283 os << " " << availability.getQueryFnRetType() << " "
284 << availability.getQueryFnName() << "();\n";
285 os << "};\n\n";
288 static bool emitInterfaceDecls(const RecordKeeper &recordKeeper,
289 raw_ostream &os) {
290 llvm::emitSourceFileHeader("Availability Interface Declarations", os,
291 recordKeeper);
293 auto defs = recordKeeper.getAllDerivedDefinitions("Availability");
294 SmallVector<const Record *, 4> handledClasses;
295 for (const Record *def : defs) {
296 SmallVector<Record *, 1> parent;
297 def->getDirectSuperClasses(parent);
298 if (parent.size() != 1) {
299 PrintFatalError(def->getLoc(),
300 "expected to only have one direct superclass");
302 if (llvm::is_contained(handledClasses, parent.front()))
303 continue;
305 Availability avail(def);
306 emitInterfaceDecl(avail, os);
307 handledClasses.push_back(parent.front());
309 return false;
312 //===----------------------------------------------------------------------===//
313 // Availability Interface Hook Registration
314 //===----------------------------------------------------------------------===//
316 // Registers the operation interface generator to mlir-tblgen.
317 static mlir::GenRegistration
318 genInterfaceDecls("gen-avail-interface-decls",
319 "Generate availability interface declarations",
320 [](const RecordKeeper &records, raw_ostream &os) {
321 return emitInterfaceDecls(records, os);
324 // Registers the operation interface generator to mlir-tblgen.
325 static mlir::GenRegistration
326 genInterfaceDefs("gen-avail-interface-defs",
327 "Generate op interface definitions",
328 [](const RecordKeeper &records, raw_ostream &os) {
329 return emitInterfaceDefs(records, os);
332 //===----------------------------------------------------------------------===//
333 // Enum Availability Query AutoGen
334 //===----------------------------------------------------------------------===//
336 static void emitAvailabilityQueryForIntEnum(const Record &enumDef,
337 raw_ostream &os) {
338 EnumAttr enumAttr(enumDef);
339 StringRef enumName = enumAttr.getEnumClassName();
340 std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases();
342 // Mapping from availability class name to (enumerant, availability
343 // specification) pairs.
344 llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>>
345 classCaseMap;
347 // Place all availability specifications to their corresponding
348 // availability classes.
349 for (const EnumAttrCase &enumerant : enumerants)
350 for (const Availability &avail : getAvailabilities(enumerant.getDef()))
351 classCaseMap[avail.getClass()].push_back({enumerant, avail});
353 for (const auto &classCasePair : classCaseMap) {
354 Availability avail = classCasePair.getValue().front().second;
356 os << formatv("std::optional<{0}> {1}({2} value) {{\n",
357 avail.getMergeInstanceType(), avail.getQueryFnName(),
358 enumName);
360 os << " switch (value) {\n";
361 for (const auto &caseSpecPair : classCasePair.getValue()) {
362 EnumAttrCase enumerant = caseSpecPair.first;
363 Availability avail = caseSpecPair.second;
364 os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName,
365 enumerant.getSymbol(), avail.getMergeInstancePreparation(),
366 avail.getMergeInstanceType(), avail.getMergeInstance());
368 // Only emit default if uncovered cases.
369 if (classCasePair.getValue().size() < enumAttr.getAllCases().size())
370 os << " default: break;\n";
371 os << " }\n"
372 << " return std::nullopt;\n"
373 << "}\n";
377 static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
378 raw_ostream &os) {
379 EnumAttr enumAttr(enumDef);
380 StringRef enumName = enumAttr.getEnumClassName();
381 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
382 std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases();
384 // Mapping from availability class name to (enumerant, availability
385 // specification) pairs.
386 llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>>
387 classCaseMap;
389 // Place all availability specifications to their corresponding
390 // availability classes.
391 for (const EnumAttrCase &enumerant : enumerants)
392 for (const Availability &avail : getAvailabilities(enumerant.getDef()))
393 classCaseMap[avail.getClass()].push_back({enumerant, avail});
395 for (const auto &classCasePair : classCaseMap) {
396 Availability avail = classCasePair.getValue().front().second;
398 os << formatv("std::optional<{0}> {1}({2} value) {{\n",
399 avail.getMergeInstanceType(), avail.getQueryFnName(),
400 enumName);
402 os << formatv(
403 " assert(::llvm::popcount(static_cast<{0}>(value)) <= 1"
404 " && \"cannot have more than one bit set\");\n",
405 underlyingType);
407 os << " switch (value) {\n";
408 for (const auto &caseSpecPair : classCasePair.getValue()) {
409 EnumAttrCase enumerant = caseSpecPair.first;
410 Availability avail = caseSpecPair.second;
411 os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName,
412 enumerant.getSymbol(), avail.getMergeInstancePreparation(),
413 avail.getMergeInstanceType(), avail.getMergeInstance());
415 os << " default: break;\n";
416 os << " }\n"
417 << " return std::nullopt;\n"
418 << "}\n";
422 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
423 EnumAttr enumAttr(enumDef);
424 StringRef enumName = enumAttr.getEnumClassName();
425 StringRef cppNamespace = enumAttr.getCppNamespace();
426 auto enumerants = enumAttr.getAllCases();
428 llvm::SmallVector<StringRef, 2> namespaces;
429 llvm::SplitString(cppNamespace, namespaces, "::");
431 for (auto ns : namespaces)
432 os << "namespace " << ns << " {\n";
434 llvm::StringSet<> handledClasses;
436 // Place all availability specifications to their corresponding
437 // availability classes.
438 for (const EnumAttrCase &enumerant : enumerants)
439 for (const Availability &avail : getAvailabilities(enumerant.getDef())) {
440 StringRef className = avail.getClass();
441 if (handledClasses.count(className))
442 continue;
443 os << formatv("std::optional<{0}> {1}({2} value);\n",
444 avail.getMergeInstanceType(), avail.getQueryFnName(),
445 enumName);
446 handledClasses.insert(className);
449 for (auto ns : llvm::reverse(namespaces))
450 os << "} // namespace " << ns << "\n";
453 static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
454 llvm::emitSourceFileHeader("SPIR-V Enum Availability Declarations", os,
455 recordKeeper);
457 auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
458 for (const auto *def : defs)
459 emitEnumDecl(*def, os);
461 return false;
464 static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
465 EnumAttr enumAttr(enumDef);
466 StringRef cppNamespace = enumAttr.getCppNamespace();
468 llvm::SmallVector<StringRef, 2> namespaces;
469 llvm::SplitString(cppNamespace, namespaces, "::");
471 for (auto ns : namespaces)
472 os << "namespace " << ns << " {\n";
474 if (enumAttr.isBitEnum()) {
475 emitAvailabilityQueryForBitEnum(enumDef, os);
476 } else {
477 emitAvailabilityQueryForIntEnum(enumDef, os);
480 for (auto ns : llvm::reverse(namespaces))
481 os << "} // namespace " << ns << "\n";
482 os << "\n";
485 static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
486 llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os,
487 recordKeeper);
489 auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
490 for (const auto *def : defs)
491 emitEnumDef(*def, os);
493 return false;
496 //===----------------------------------------------------------------------===//
497 // Enum Availability Query Hook Registration
498 //===----------------------------------------------------------------------===//
500 // Registers the enum utility generator to mlir-tblgen.
501 static mlir::GenRegistration
502 genEnumDecls("gen-spirv-enum-avail-decls",
503 "Generate SPIR-V enum availability declarations",
504 [](const RecordKeeper &records, raw_ostream &os) {
505 return emitEnumDecls(records, os);
508 // Registers the enum utility generator to mlir-tblgen.
509 static mlir::GenRegistration
510 genEnumDefs("gen-spirv-enum-avail-defs",
511 "Generate SPIR-V enum availability definitions",
512 [](const RecordKeeper &records, raw_ostream &os) {
513 return emitEnumDefs(records, os);
516 //===----------------------------------------------------------------------===//
517 // Serialization AutoGen
518 //===----------------------------------------------------------------------===//
520 // These enums are encoded as <id> to constant values in SPIR-V blob, but we
521 // directly use the constant value as attribute in SPIR-V dialect. So need
522 // to handle them separately from normal enum attributes.
523 constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
524 "SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
525 "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
526 "SPIRV_MatrixLayoutAttr"};
528 /// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
529 /// generates code extracts the attribute with name `attrName` from
530 /// `operandList` of `op`.
531 static void emitAttributeSerialization(const Attribute &attr,
532 ArrayRef<SMLoc> loc, StringRef tabs,
533 StringRef opVar, StringRef operandList,
534 StringRef attrName, raw_ostream &os) {
535 os << tabs
536 << formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
537 if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) {
538 EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
539 os << tabs
540 << formatv(" {0}.push_back(prepareConstantInt({1}.getLoc(), "
541 "Builder({1}).getI32IntegerAttr(static_cast<uint32_t>("
542 "::llvm::cast<{2}::{3}Attr>(attr).getValue()))));\n",
543 operandList, opVar, baseEnum.getCppNamespace(),
544 baseEnum.getEnumClassName());
545 } else if (attr.isSubClassOf("SPIRV_BitEnumAttr") ||
546 attr.isSubClassOf("SPIRV_I32EnumAttr")) {
547 EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
548 os << tabs
549 << formatv(" {0}.push_back(static_cast<uint32_t>("
550 "::llvm::cast<{1}::{2}Attr>(attr).getValue()));\n",
551 operandList, baseEnum.getCppNamespace(),
552 baseEnum.getEnumClassName());
553 } else if (attr.getAttrDefName() == "I32ArrayAttr") {
554 // Serialize all the elements of the array
555 os << tabs << " for (auto attrElem : llvm::cast<ArrayAttr>(attr)) {\n";
556 os << tabs
557 << formatv(" {0}.push_back(static_cast<uint32_t>("
558 "llvm::cast<IntegerAttr>(attrElem).getValue().getZExtValue())"
559 ");\n",
560 operandList);
561 os << tabs << " }\n";
562 } else if (attr.getAttrDefName() == "I32Attr") {
563 os << tabs
564 << formatv(
565 " {0}.push_back(static_cast<uint32_t>("
566 "llvm::cast<IntegerAttr>(attr).getValue().getZExtValue()));\n",
567 operandList);
568 } else if (attr.isEnumAttr() || attr.isTypeAttr()) {
569 // It may be the first time this type appears in the IR, so we need to
570 // process it.
571 StringRef attrTypeID = "attrTypeID";
572 os << tabs << formatv(" uint32_t {0} = 0;\n", attrTypeID);
573 os << tabs
574 << formatv(" if (failed(processType({0}.getLoc(), "
575 "llvm::cast<TypeAttr>(attr).getValue(), {1}))) {{\n",
576 opVar, attrTypeID);
577 os << tabs << " return failure();\n";
578 os << tabs << " }\n";
579 os << tabs << formatv(" {0}.push_back(attrTypeID);\n", operandList);
580 } else {
581 PrintFatalError(
582 loc,
583 llvm::Twine(
584 "unhandled attribute type in SPIR-V serialization generation : '") +
585 attr.getAttrDefName() + llvm::Twine("'"));
587 os << tabs << "}\n";
590 /// Generates code to serialize the operands of a SPIRV_Op `op` into `os`. The
591 /// generated queries the SSA-ID if operand is a SSA-Value, or serializes the
592 /// attributes. The `operands` vector is updated appropriately. `elidedAttrs`
593 /// updated as well to include the serialized attributes.
594 static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc,
595 StringRef tabs, StringRef opVar,
596 StringRef operands, StringRef elidedAttrs,
597 raw_ostream &os) {
598 using mlir::tblgen::Argument;
600 // SPIR-V ops can mix operands and attributes in the definition. These
601 // operands and attributes are serialized in the exact order of the definition
602 // to match SPIR-V binary format requirements. It can cause excessive
603 // generated code bloat because we are emitting code to handle each
604 // operand/attribute separately. So here we probe first to check whether all
605 // the operands are ahead of attributes. Then we can serialize all operands
606 // together.
608 // Whether all operands are ahead of all attributes in the op's spec.
609 bool areOperandsAheadOfAttrs = true;
610 // Find the first attribute.
611 const Argument *it = llvm::find_if(op.getArgs(), [](const Argument &arg) {
612 return arg.is<NamedAttribute *>();
614 // Check whether all following arguments are attributes.
615 for (const Argument *ie = op.arg_end(); it != ie; ++it) {
616 if (!it->is<NamedAttribute *>()) {
617 areOperandsAheadOfAttrs = false;
618 break;
622 // Serialize all operands together.
623 if (areOperandsAheadOfAttrs) {
624 if (op.getNumOperands() != 0) {
625 os << tabs
626 << formatv("for (Value operand : {0}->getOperands()) {{\n", opVar);
627 os << tabs << " auto id = getValueID(operand);\n";
628 os << tabs << " assert(id && \"use before def!\");\n";
629 os << tabs << formatv(" {0}.push_back(id);\n", operands);
630 os << tabs << "}\n";
632 for (const NamedAttribute &attr : op.getAttributes()) {
633 emitAttributeSerialization(
634 (attr.attr.isOptional() ? attr.attr.getBaseAttr() : attr.attr), loc,
635 tabs, opVar, operands, attr.name, os);
636 os << tabs
637 << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs, attr.name);
639 return;
642 // Serialize operands separately.
643 auto operandNum = 0;
644 for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
645 auto argument = op.getArg(i);
646 os << tabs << "{\n";
647 if (argument.is<NamedTypeConstraint *>()) {
648 os << tabs
649 << formatv(" for (auto arg : {0}.getODSOperands({1})) {{\n", opVar,
650 operandNum);
651 os << tabs << " auto argID = getValueID(arg);\n";
652 os << tabs << " if (!argID) {\n";
653 os << tabs
654 << formatv(" return emitError({0}.getLoc(), "
655 "\"operand #{1} has a use before def\");\n",
656 opVar, operandNum);
657 os << tabs << " }\n";
658 os << tabs << formatv(" {0}.push_back(argID);\n", operands);
659 os << " }\n";
660 operandNum++;
661 } else {
662 NamedAttribute *attr = argument.get<NamedAttribute *>();
663 auto newtabs = tabs.str() + " ";
664 emitAttributeSerialization(
665 (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
666 loc, newtabs, opVar, operands, attr->name, os);
667 os << newtabs
668 << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs, attr->name);
670 os << tabs << "}\n";
674 /// Generates code to serializes the result of SPIRV_Op `op` into `os`. The
675 /// generated gets the ID for the type of the result (if any), the SSA-ID of
676 /// the result and updates `resultID` with the SSA-ID.
677 static void emitResultSerialization(const Operator &op, ArrayRef<SMLoc> loc,
678 StringRef tabs, StringRef opVar,
679 StringRef operands, StringRef resultID,
680 raw_ostream &os) {
681 if (op.getNumResults() == 1) {
682 StringRef resultTypeID("resultTypeID");
683 os << tabs << formatv("uint32_t {0} = 0;\n", resultTypeID);
684 os << tabs
685 << formatv(
686 "if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n",
687 opVar, resultTypeID);
688 os << tabs << " return failure();\n";
689 os << tabs << "}\n";
690 os << tabs << formatv("{0}.push_back({1});\n", operands, resultTypeID);
691 // Create an SSA result <id> for the op
692 os << tabs << formatv("{0} = getNextID();\n", resultID);
693 os << tabs
694 << formatv("valueIDMap[{0}.getResult()] = {1};\n", opVar, resultID);
695 os << tabs << formatv("{0}.push_back({1});\n", operands, resultID);
696 } else if (op.getNumResults() != 0) {
697 PrintFatalError(loc, "SPIR-V ops can only have zero or one result");
701 /// Generates code to serialize attributes of SPIRV_Op `op` that become
702 /// decorations on the `resultID` of the serialized operation `opVar` in the
703 /// SPIR-V binary.
704 static void emitDecorationSerialization(const Operator &op, StringRef tabs,
705 StringRef opVar, StringRef elidedAttrs,
706 StringRef resultID, raw_ostream &os) {
707 if (op.getNumResults() == 1) {
708 // All non-argument attributes translated into OpDecorate instruction
709 os << tabs << formatv("for (auto attr : {0}->getAttrs()) {{\n", opVar);
710 os << tabs
711 << formatv(" if (llvm::is_contained({0}, attr.getName())) {{",
712 elidedAttrs);
713 os << tabs << " continue;\n";
714 os << tabs << " }\n";
715 os << tabs
716 << formatv(
717 " if (failed(processDecoration({0}.getLoc(), {1}, attr))) {{\n",
718 opVar, resultID);
719 os << tabs << " return failure();\n";
720 os << tabs << " }\n";
721 os << tabs << "}\n";
725 /// Generates code to serialize an SPIRV_Op `op` into `os`.
726 static void emitSerializationFunction(const Record *attrClass,
727 const Record *record, const Operator &op,
728 raw_ostream &os) {
729 // If the record has 'autogenSerialization' set to 0, nothing to do
730 if (!record->getValueAsBit("autogenSerialization"))
731 return;
733 StringRef opVar("op"), operands("operands"), elidedAttrs("elidedAttrs"),
734 resultID("resultID");
736 os << formatv(
737 "template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n",
738 op.getQualCppClassName(), opVar);
740 // Special case for ops without attributes in TableGen definitions
741 if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) {
742 std::string extInstSet;
743 std::string opcode;
744 if (record->isSubClassOf("SPIRV_ExtInstOp")) {
745 extInstSet =
746 formatv("\"{0}\"", record->getValueAsString("extendedInstSetName"));
747 opcode = std::to_string(record->getValueAsInt("extendedInstOpcode"));
748 } else {
749 extInstSet = "\"\"";
750 opcode = formatv("static_cast<uint32_t>(spirv::Opcode::{0})",
751 record->getValueAsString("spirvOpName"));
754 os << formatv(" return processOpWithoutGrammarAttr({0}, {1}, {2});\n}\n\n",
755 opVar, extInstSet, opcode);
756 return;
759 os << formatv(" SmallVector<uint32_t, 4> {0};\n", operands);
760 os << formatv(" SmallVector<StringRef, 2> {0};\n", elidedAttrs);
762 // Serialize result information.
763 if (op.getNumResults() == 1) {
764 os << formatv(" uint32_t {0} = 0;\n", resultID);
765 emitResultSerialization(op, record->getLoc(), " ", opVar, operands,
766 resultID, os);
769 // Process arguments.
770 emitArgumentSerialization(op, record->getLoc(), " ", opVar, operands,
771 elidedAttrs, os);
773 if (record->isSubClassOf("SPIRV_ExtInstOp")) {
774 os << formatv(
775 " (void)encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n", opVar,
776 record->getValueAsString("extendedInstSetName"),
777 record->getValueAsInt("extendedInstOpcode"), operands);
778 } else {
779 // Emit debug info.
780 os << formatv(" (void)emitDebugLine(functionBody, {0}.getLoc());\n",
781 opVar);
782 os << formatv(" (void)encodeInstructionInto("
783 "functionBody, spirv::Opcode::{1}, {2});\n",
784 op.getQualCppClassName(),
785 record->getValueAsString("spirvOpName"), operands);
788 // Process decorations.
789 emitDecorationSerialization(op, " ", opVar, elidedAttrs, resultID, os);
791 os << " return success();\n";
792 os << "}\n\n";
795 /// Generates the prologue for the function that dispatches the serialization of
796 /// the operation `opVar` based on its opcode.
797 static void initDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
798 os << formatv(
799 "LogicalResult Serializer::dispatchToAutogenSerialization(Operation "
800 "*{0}) {{\n",
801 opVar);
804 /// Generates the body of the dispatch function. This function generates the
805 /// check that if satisfied, will call the serialization function generated for
806 /// the `op`.
807 static void emitSerializationDispatch(const Operator &op, StringRef tabs,
808 StringRef opVar, raw_ostream &os) {
809 os << tabs
810 << formatv("if (isa<{0}>({1})) {{\n", op.getQualCppClassName(), opVar);
811 os << tabs
812 << formatv(" return processOp(cast<{0}>({1}));\n",
813 op.getQualCppClassName(), opVar);
814 os << tabs << "}\n";
817 /// Generates the epilogue for the function that dispatches the serialization of
818 /// the operation.
819 static void finalizeDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
820 os << formatv(
821 " return {0}->emitError(\"unhandled operation serialization\");\n",
822 opVar);
823 os << "}\n\n";
826 /// Generates code to deserialize the attribute of a SPIRV_Op into `os`. The
827 /// generated code reads the `words` of the serialized instruction at
828 /// position `wordIndex` and adds the deserialized attribute into `attrList`.
829 static void emitAttributeDeserialization(const Attribute &attr,
830 ArrayRef<SMLoc> loc, StringRef tabs,
831 StringRef attrList, StringRef attrName,
832 StringRef words, StringRef wordIndex,
833 raw_ostream &os) {
834 if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) {
835 EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
836 os << tabs
837 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
838 "opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>("
839 "getConstantInt({4}[{5}++]).getValue().getZExtValue()))));\n",
840 attrList, attrName, baseEnum.getCppNamespace(),
841 baseEnum.getEnumClassName(), words, wordIndex);
842 } else if (attr.isSubClassOf("SPIRV_BitEnumAttr") ||
843 attr.isSubClassOf("SPIRV_I32EnumAttr")) {
844 EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
845 os << tabs
846 << formatv(" {0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
847 "opBuilder.getAttr<{2}::{3}Attr>("
848 "static_cast<{2}::{3}>({4}[{5}++]))));\n",
849 attrList, attrName, baseEnum.getCppNamespace(),
850 baseEnum.getEnumClassName(), words, wordIndex);
851 } else if (attr.getAttrDefName() == "I32ArrayAttr") {
852 os << tabs << "SmallVector<Attribute, 4> attrListElems;\n";
853 os << tabs << formatv("while ({0} < {1}.size()) {{\n", wordIndex, words);
854 os << tabs
855 << formatv(
857 "attrListElems.push_back(opBuilder.getI32IntegerAttr({0}[{1}++]))"
858 ";\n",
859 words, wordIndex);
860 os << tabs << "}\n";
861 os << tabs
862 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
863 "opBuilder.getArrayAttr(attrListElems)));\n",
864 attrList, attrName);
865 } else if (attr.getAttrDefName() == "I32Attr") {
866 os << tabs
867 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
868 "opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
869 attrList, attrName, words, wordIndex);
870 } else if (attr.isEnumAttr() || attr.isTypeAttr()) {
871 os << tabs
872 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
873 "TypeAttr::get(getType({2}[{3}++]))));\n",
874 attrList, attrName, words, wordIndex);
875 } else {
876 PrintFatalError(
877 loc, llvm::Twine(
878 "unhandled attribute type in deserialization generation : '") +
879 attrName + llvm::Twine("'"));
883 /// Generates the code to deserialize the result of an SPIRV_Op `op` into
884 /// `os`. The generated code gets the type of the result specified at
885 /// `words`[`wordIndex`], the SSA ID for the result at position `wordIndex` + 1
886 /// and updates the `resultType` and `valueID` with the parsed type and SSA ID,
887 /// respectively.
888 static void emitResultDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
889 StringRef tabs, StringRef words,
890 StringRef wordIndex,
891 StringRef resultTypes, StringRef valueID,
892 raw_ostream &os) {
893 // Deserialize result information if it exists
894 if (op.getNumResults() == 1) {
895 os << tabs << "{\n";
896 os << tabs << formatv(" if ({0} >= {1}.size()) {{\n", wordIndex, words);
897 os << tabs
898 << formatv(
899 " return emitError(unknownLoc, \"expected result type <id> "
900 "while deserializing {0}\");\n",
901 op.getQualCppClassName());
902 os << tabs << " }\n";
903 os << tabs << formatv(" auto ty = getType({0}[{1}]);\n", words, wordIndex);
904 os << tabs << " if (!ty) {\n";
905 os << tabs
906 << formatv(
907 " return emitError(unknownLoc, \"unknown type result <id> : "
908 "\") << {0}[{1}];\n",
909 words, wordIndex);
910 os << tabs << " }\n";
911 os << tabs << formatv(" {0}.push_back(ty);\n", resultTypes);
912 os << tabs << formatv(" {0}++;\n", wordIndex);
913 os << tabs << formatv(" if ({0} >= {1}.size()) {{\n", wordIndex, words);
914 os << tabs
915 << formatv(
916 " return emitError(unknownLoc, \"expected result <id> while "
917 "deserializing {0}\");\n",
918 op.getQualCppClassName());
919 os << tabs << " }\n";
920 os << tabs << "}\n";
921 os << tabs << formatv("{0} = {1}[{2}++];\n", valueID, words, wordIndex);
922 } else if (op.getNumResults() != 0) {
923 PrintFatalError(loc, "SPIR-V ops can have only zero or one result");
927 /// Generates the code to deserialize the operands of an SPIRV_Op `op` into
928 /// `os`. The generated code reads the `words` of the binary instruction, from
929 /// position `wordIndex` to the end, and either gets the Value corresponding to
930 /// the ID encoded, or deserializes the attributes encoded. The parsed
931 /// operand(attribute) is added to the `operands` list or `attributes` list.
932 static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
933 StringRef tabs, StringRef words,
934 StringRef wordIndex, StringRef operands,
935 StringRef attributes, raw_ostream &os) {
936 // Process operands/attributes
937 for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
938 auto argument = op.getArg(i);
939 if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) {
940 if (valueArg->isVariableLength()) {
941 if (i != e - 1) {
942 PrintFatalError(
943 loc, "SPIR-V ops can have Variadic<..> or "
944 "Optional<...> arguments only if it's the last argument");
946 os << tabs
947 << formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words);
948 } else {
949 os << tabs << formatv("if ({0} < {1}.size())", wordIndex, words);
951 os << " {\n";
952 os << tabs
953 << formatv(" auto arg = getValue({0}[{1}]);\n", words, wordIndex);
954 os << tabs << " if (!arg) {\n";
955 os << tabs
956 << formatv(
957 " return emitError(unknownLoc, \"unknown result <id> : \") "
958 "<< {0}[{1}];\n",
959 words, wordIndex);
960 os << tabs << " }\n";
961 os << tabs << formatv(" {0}.push_back(arg);\n", operands);
962 if (!valueArg->isVariableLength()) {
963 os << tabs << formatv(" {0}++;\n", wordIndex);
965 os << tabs << "}\n";
966 } else {
967 os << tabs << formatv("if ({0} < {1}.size()) {{\n", wordIndex, words);
968 auto *attr = argument.get<NamedAttribute *>();
969 auto newtabs = tabs.str() + " ";
970 emitAttributeDeserialization(
971 (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
972 loc, newtabs, attributes, attr->name, words, wordIndex, os);
973 os << " }\n";
977 os << tabs << formatv("if ({0} != {1}.size()) {{\n", wordIndex, words);
978 os << tabs
979 << formatv(
980 " return emitError(unknownLoc, \"found more operands than "
981 "expected when deserializing {0}, only \") << {1} << \" of \" << "
982 "{2}.size() << \" processed\";\n",
983 op.getQualCppClassName(), wordIndex, words);
984 os << tabs << "}\n\n";
987 /// Generates code to update the `attributes` vector with the attributes
988 /// obtained from parsing the decorations in the SPIR-V binary associated with
989 /// an <id> `valueID`
990 static void emitDecorationDeserialization(const Operator &op, StringRef tabs,
991 StringRef valueID,
992 StringRef attributes,
993 raw_ostream &os) {
994 // Import decorations parsed
995 if (op.getNumResults() == 1) {
996 os << tabs << formatv("if (decorations.count({0})) {{\n", valueID);
997 os << tabs
998 << formatv(" auto attrs = decorations[{0}].getAttrs();\n", valueID);
999 os << tabs
1000 << formatv(" {0}.append(attrs.begin(), attrs.end());\n", attributes);
1001 os << tabs << "}\n";
1005 /// Generates code to deserialize an SPIRV_Op `op` into `os`.
1006 static void emitDeserializationFunction(const Record *attrClass,
1007 const Record *record,
1008 const Operator &op, raw_ostream &os) {
1009 // If the record has 'autogenSerialization' set to 0, nothing to do
1010 if (!record->getValueAsBit("autogenSerialization"))
1011 return;
1013 StringRef resultTypes("resultTypes"), valueID("valueID"), words("words"),
1014 wordIndex("wordIndex"), opVar("op"), operands("operands"),
1015 attributes("attributes");
1017 // Method declaration
1018 os << formatv("template <> "
1019 "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
1020 "uint32_t> {1}) {{\n",
1021 op.getQualCppClassName(), words);
1023 // Special case for ops without attributes in TableGen definitions
1024 if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) {
1025 os << formatv(" return processOpWithoutGrammarAttr("
1026 "{0}, \"{1}\", {2}, {3});\n}\n\n",
1027 words, op.getOperationName(),
1028 op.getNumResults() ? "true" : "false", op.getNumOperands());
1029 return;
1032 os << formatv(" SmallVector<Type, 1> {0};\n", resultTypes);
1033 os << formatv(" size_t {0} = 0; (void){0};\n", wordIndex);
1034 os << formatv(" uint32_t {0} = 0; (void){0};\n", valueID);
1036 // Deserialize result information
1037 emitResultDeserialization(op, record->getLoc(), " ", words, wordIndex,
1038 resultTypes, valueID, os);
1040 os << formatv(" SmallVector<Value, 4> {0};\n", operands);
1041 os << formatv(" SmallVector<NamedAttribute, 4> {0};\n", attributes);
1042 // Operand deserialization
1043 emitOperandDeserialization(op, record->getLoc(), " ", words, wordIndex,
1044 operands, attributes, os);
1046 // Decorations
1047 emitDecorationDeserialization(op, " ", valueID, attributes, os);
1049 os << formatv(" Location loc = createFileLineColLoc(opBuilder);\n");
1050 os << formatv(" auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); "
1051 "(void){1};\n",
1052 op.getQualCppClassName(), opVar, resultTypes, operands,
1053 attributes);
1054 if (op.getNumResults() == 1) {
1055 os << formatv(" valueMap[{0}] = {1}.getResult();\n\n", valueID, opVar);
1058 // According to SPIR-V spec:
1059 // This location information applies to the instructions physically following
1060 // this instruction, up to the first occurrence of any of the following: the
1061 // next end of block.
1062 os << formatv(" if ({0}.hasTrait<OpTrait::IsTerminator>())\n", opVar);
1063 os << formatv(" (void)clearDebugLine();\n");
1064 os << " return success();\n";
1065 os << "}\n\n";
1068 /// Generates the prologue for the function that dispatches the deserialization
1069 /// based on the `opcode`.
1070 static void initDispatchDeserializationFn(StringRef opcode, StringRef words,
1071 raw_ostream &os) {
1072 os << formatv("LogicalResult spirv::Deserializer::"
1073 "dispatchToAutogenDeserialization(spirv::Opcode {0},"
1074 " ArrayRef<uint32_t> {1}) {{\n",
1075 opcode, words);
1076 os << formatv(" switch ({0}) {{\n", opcode);
1079 /// Generates the body of the dispatch function, by generating the case label
1080 /// for an opcode and the call to the method to perform the deserialization.
1081 static void emitDeserializationDispatch(const Operator &op, const Record *def,
1082 StringRef tabs, StringRef words,
1083 raw_ostream &os) {
1084 os << tabs
1085 << formatv("case spirv::Opcode::{0}:\n",
1086 def->getValueAsString("spirvOpName"));
1087 os << tabs
1088 << formatv(" return processOp<{0}>({1});\n", op.getQualCppClassName(),
1089 words);
1092 /// Generates the epilogue for the function that dispatches the deserialization
1093 /// of the operation.
1094 static void finalizeDispatchDeserializationFn(StringRef opcode,
1095 raw_ostream &os) {
1096 os << " default:\n";
1097 os << " ;\n";
1098 os << " }\n";
1099 StringRef opcodeVar("opcodeString");
1100 os << formatv(" auto {0} = spirv::stringifyOpcode({1});\n", opcodeVar,
1101 opcode);
1102 os << formatv(" if (!{0}.empty()) {{\n", opcodeVar);
1103 os << formatv(" return emitError(unknownLoc, \"unhandled deserialization "
1104 "of \") << {0};\n",
1105 opcodeVar);
1106 os << " } else {\n";
1107 os << formatv(" return emitError(unknownLoc, \"unhandled opcode \") << "
1108 "static_cast<uint32_t>({0});\n",
1109 opcode);
1110 os << " }\n";
1111 os << "}\n";
1114 static void initExtendedSetDeserializationDispatch(StringRef extensionSetName,
1115 StringRef instructionID,
1116 StringRef words,
1117 raw_ostream &os) {
1118 os << formatv("LogicalResult spirv::Deserializer::"
1119 "dispatchToExtensionSetAutogenDeserialization("
1120 "StringRef {0}, uint32_t {1}, ArrayRef<uint32_t> {2}) {{\n",
1121 extensionSetName, instructionID, words);
1124 static void
1125 emitExtendedSetDeserializationDispatch(const RecordKeeper &recordKeeper,
1126 raw_ostream &os) {
1127 StringRef extensionSetName("extensionSetName"),
1128 instructionID("instructionID"), words("words");
1130 // First iterate over all ops derived from SPIRV_ExtensionSetOps to get all
1131 // extensionSets.
1133 // For each of the extensions a separate raw_string_ostream is used to
1134 // generate code into. These are then concatenated at the end. Since
1135 // raw_string_ostream needs a string&, use a vector to store all the string
1136 // that are captured by reference within raw_string_ostream.
1137 StringMap<raw_string_ostream> extensionSets;
1138 std::list<std::string> extensionSetNames;
1140 initExtendedSetDeserializationDispatch(extensionSetName, instructionID, words,
1141 os);
1142 auto defs = recordKeeper.getAllDerivedDefinitions("SPIRV_ExtInstOp");
1143 for (const auto *def : defs) {
1144 if (!def->getValueAsBit("autogenSerialization")) {
1145 continue;
1147 Operator op(def);
1148 auto setName = def->getValueAsString("extendedInstSetName");
1149 if (!extensionSets.count(setName)) {
1150 extensionSetNames.emplace_back("");
1151 extensionSets.try_emplace(setName, extensionSetNames.back());
1152 auto &setos = extensionSets.find(setName)->second;
1153 setos << formatv(" if ({0} == \"{1}\") {{\n", extensionSetName, setName);
1154 setos << formatv(" switch ({0}) {{\n", instructionID);
1156 auto &setos = extensionSets.find(setName)->second;
1157 setos << formatv(" case {0}:\n",
1158 def->getValueAsInt("extendedInstOpcode"));
1159 setos << formatv(" return processOp<{0}>({1});\n",
1160 op.getQualCppClassName(), words);
1163 // Append the dispatch code for all the extended sets.
1164 for (auto &extensionSet : extensionSets) {
1165 os << extensionSet.second.str();
1166 os << " default:\n";
1167 os << formatv(
1168 " return emitError(unknownLoc, \"unhandled deserializations of "
1169 "\") << {0} << \" from extension set \" << {1};\n",
1170 instructionID, extensionSetName);
1171 os << " }\n";
1172 os << " }\n";
1175 os << formatv(" return emitError(unknownLoc, \"unhandled deserialization of "
1176 "extended instruction set {0}\");\n",
1177 extensionSetName);
1178 os << "}\n";
1181 /// Emits all the autogenerated serialization/deserializations functions for the
1182 /// SPIRV_Ops.
1183 static bool emitSerializationFns(const RecordKeeper &recordKeeper,
1184 raw_ostream &os) {
1185 llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os,
1186 recordKeeper);
1188 std::string dSerFnString, dDesFnString, serFnString, deserFnString,
1189 utilsString;
1190 raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString),
1191 serFn(serFnString), deserFn(deserFnString);
1192 Record *attrClass = recordKeeper.getClass("Attr");
1194 // Emit the serialization and deserialization functions simultaneously.
1195 StringRef opVar("op");
1196 StringRef opcode("opcode"), words("words");
1198 // Handle the SPIR-V ops.
1199 initDispatchSerializationFn(opVar, dSerFn);
1200 initDispatchDeserializationFn(opcode, words, dDesFn);
1201 auto defs = recordKeeper.getAllDerivedDefinitions("SPIRV_Op");
1202 for (const auto *def : defs) {
1203 Operator op(def);
1204 emitSerializationFunction(attrClass, def, op, serFn);
1205 emitDeserializationFunction(attrClass, def, op, deserFn);
1206 if (def->getValueAsBit("hasOpcode") ||
1207 def->isSubClassOf("SPIRV_ExtInstOp")) {
1208 emitSerializationDispatch(op, " ", opVar, dSerFn);
1210 if (def->getValueAsBit("hasOpcode")) {
1211 emitDeserializationDispatch(op, def, " ", words, dDesFn);
1214 finalizeDispatchSerializationFn(opVar, dSerFn);
1215 finalizeDispatchDeserializationFn(opcode, dDesFn);
1217 emitExtendedSetDeserializationDispatch(recordKeeper, dDesFn);
1219 os << "#ifdef GET_SERIALIZATION_FNS\n\n";
1220 os << serFn.str();
1221 os << dSerFn.str();
1222 os << "#endif // GET_SERIALIZATION_FNS\n\n";
1224 os << "#ifdef GET_DESERIALIZATION_FNS\n\n";
1225 os << deserFn.str();
1226 os << dDesFn.str();
1227 os << "#endif // GET_DESERIALIZATION_FNS\n\n";
1229 return false;
1232 //===----------------------------------------------------------------------===//
1233 // Serialization Hook Registration
1234 //===----------------------------------------------------------------------===//
1236 static mlir::GenRegistration genSerialization(
1237 "gen-spirv-serialization",
1238 "Generate SPIR-V (de)serialization utilities and functions",
1239 [](const RecordKeeper &records, raw_ostream &os) {
1240 return emitSerializationFns(records, os);
1243 //===----------------------------------------------------------------------===//
1244 // Op Utils AutoGen
1245 //===----------------------------------------------------------------------===//
1247 static void emitEnumGetAttrNameFnDecl(raw_ostream &os) {
1248 os << formatv("template <typename EnumClass> inline constexpr StringRef "
1249 "attributeName();\n");
1252 static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
1253 raw_ostream &os) {
1254 auto enumName = enumAttr.getEnumClassName();
1255 os << formatv("template <> inline StringRef attributeName<{0}>() {{\n",
1256 enumName);
1257 os << " "
1258 << formatv("static constexpr const char attrName[] = \"{0}\";\n",
1259 llvm::convertToSnakeFromCamelCase(enumName));
1260 os << " return attrName;\n";
1261 os << "}\n";
1264 static bool emitAttrUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
1265 llvm::emitSourceFileHeader("SPIR-V Attribute Utilities", os, recordKeeper);
1267 auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
1268 os << "#ifndef MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
1269 os << "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
1270 emitEnumGetAttrNameFnDecl(os);
1271 for (const auto *def : defs) {
1272 EnumAttr enumAttr(*def);
1273 emitEnumGetAttrNameFnDefn(enumAttr, os);
1275 os << "#endif // MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H\n";
1276 return false;
1279 //===----------------------------------------------------------------------===//
1280 // Op Utils Hook Registration
1281 //===----------------------------------------------------------------------===//
1283 static mlir::GenRegistration
1284 genOpUtils("gen-spirv-attr-utils",
1285 "Generate SPIR-V attribute utility definitions",
1286 [](const RecordKeeper &records, raw_ostream &os) {
1287 return emitAttrUtils(records, os);
1290 //===----------------------------------------------------------------------===//
1291 // SPIR-V Availability Impl AutoGen
1292 //===----------------------------------------------------------------------===//
1294 static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
1295 mlir::tblgen::FmtContext fctx;
1296 fctx.addSubst("overall", "tblgen_overall");
1298 std::vector<Availability> opAvailabilities =
1299 getAvailabilities(srcOp.getDef());
1301 // First collect all availability classes this op should implement.
1302 // All availability instances keep information for the generated interface and
1303 // the instance's specific requirement. Here we remember a random instance so
1304 // we can get the information regarding the generated interface.
1305 llvm::StringMap<Availability> availClasses;
1306 for (const Availability &avail : opAvailabilities)
1307 availClasses.try_emplace(avail.getClass(), avail);
1308 for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
1309 if (!namedAttr.attr.isSubClassOf("SPIRV_BitEnumAttr") &&
1310 !namedAttr.attr.isSubClassOf("SPIRV_I32EnumAttr"))
1311 continue;
1312 EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum"));
1314 for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
1315 for (const Availability &caseAvail :
1316 getAvailabilities(enumerant.getDef()))
1317 availClasses.try_emplace(caseAvail.getClass(), caseAvail);
1320 // Then generate implementation for each availability class.
1321 for (const auto &availClass : availClasses) {
1322 StringRef availClassName = availClass.getKey();
1323 Availability avail = availClass.getValue();
1325 // Generate the implementation method signature.
1326 os << formatv("{0} {1}::{2}() {{\n", avail.getQueryFnRetType(),
1327 srcOp.getCppClassName(), avail.getQueryFnName());
1329 // Create the variable for the final requirement and initialize it.
1330 os << formatv(" {0} tblgen_overall = {1};\n", avail.getQueryFnRetType(),
1331 avail.getMergeInitializer());
1333 // Update with the op's specific availability spec.
1334 for (const Availability &avail : opAvailabilities)
1335 if (avail.getClass() == availClassName &&
1336 (!avail.getMergeInstancePreparation().empty() ||
1337 !avail.getMergeActionCode().empty())) {
1338 os << " {\n "
1339 // Prepare this instance.
1340 << avail.getMergeInstancePreparation()
1341 << "\n "
1342 // Merge this instance.
1343 << std::string(
1344 tgfmt(avail.getMergeActionCode(),
1345 &fctx.addSubst("instance", avail.getMergeInstance())))
1346 << ";\n }\n";
1349 // Update with enum attributes' specific availability spec.
1350 for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
1351 if (!namedAttr.attr.isSubClassOf("SPIRV_BitEnumAttr") &&
1352 !namedAttr.attr.isSubClassOf("SPIRV_I32EnumAttr"))
1353 continue;
1354 EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum"));
1356 // (enumerant, availability specification) pairs for this availability
1357 // class.
1358 SmallVector<std::pair<EnumAttrCase, Availability>, 1> caseSpecs;
1360 // Collect all cases' availability specs.
1361 for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
1362 for (const Availability &caseAvail :
1363 getAvailabilities(enumerant.getDef()))
1364 if (availClassName == caseAvail.getClass())
1365 caseSpecs.push_back({enumerant, caseAvail});
1367 // If this attribute kind does not have any availability spec from any of
1368 // its cases, no more work to do.
1369 if (caseSpecs.empty())
1370 continue;
1372 if (enumAttr.isBitEnum()) {
1373 // For BitEnumAttr, we need to iterate over each bit to query its
1374 // availability spec.
1375 os << formatv(" for (unsigned i = 0; "
1376 "i < std::numeric_limits<{0}>::digits; ++i) {{\n",
1377 enumAttr.getUnderlyingType());
1378 os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & "
1379 "static_cast<{0}::{1}>(1 << i);\n",
1380 enumAttr.getCppNamespace(), enumAttr.getEnumClassName(),
1381 srcOp.getGetterName(namedAttr.name));
1382 os << formatv(
1383 " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
1384 enumAttr.getUnderlyingType());
1385 } else {
1386 // For IntEnumAttr, we just need to query the value as a whole.
1387 os << " {\n";
1388 os << formatv(" auto tblgen_attrVal = this->{0}();\n",
1389 srcOp.getGetterName(namedAttr.name));
1391 os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
1392 enumAttr.getCppNamespace(), avail.getQueryFnName());
1393 os << " if (tblgen_instance) "
1394 // TODO` here once ODS supports
1395 // dialect-specific contents so that we can use not implementing the
1396 // availability interface as indication of no requirements.
1397 << std::string(tgfmt(caseSpecs.front().second.getMergeActionCode(),
1398 &fctx.addSubst("instance", "*tblgen_instance")))
1399 << ";\n";
1400 os << " }\n";
1403 os << " return tblgen_overall;\n";
1404 os << "}\n";
1408 static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper,
1409 raw_ostream &os) {
1410 llvm::emitSourceFileHeader("SPIR-V Op Availability Implementations", os,
1411 recordKeeper);
1413 auto defs = recordKeeper.getAllDerivedDefinitions("SPIRV_Op");
1414 for (const auto *def : defs) {
1415 Operator op(def);
1416 if (def->getValueAsBit("autogenAvailability"))
1417 emitAvailabilityImpl(op, os);
1419 return false;
1422 //===----------------------------------------------------------------------===//
1423 // Op Availability Implementation Hook Registration
1424 //===----------------------------------------------------------------------===//
1426 static mlir::GenRegistration
1427 genOpAvailabilityImpl("gen-spirv-avail-impls",
1428 "Generate SPIR-V operation utility definitions",
1429 [](const RecordKeeper &records, raw_ostream &os) {
1430 return emitAvailabilityImpl(records, os);
1433 //===----------------------------------------------------------------------===//
1434 // SPIR-V Capability Implication AutoGen
1435 //===----------------------------------------------------------------------===//
1437 static bool emitCapabilityImplication(const RecordKeeper &recordKeeper,
1438 raw_ostream &os) {
1439 llvm::emitSourceFileHeader("SPIR-V Capability Implication", os, recordKeeper);
1441 EnumAttr enumAttr(
1442 recordKeeper.getDef("SPIRV_CapabilityAttr")->getValueAsDef("enum"));
1444 os << "ArrayRef<spirv::Capability> "
1445 "spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n"
1446 << " switch (cap) {\n"
1447 << " default: return {};\n";
1448 for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) {
1449 const Record &def = enumerant.getDef();
1450 if (!def.getValue("implies"))
1451 continue;
1453 std::vector<Record *> impliedCapsDefs = def.getValueAsListOfDefs("implies");
1454 os << " case spirv::Capability::" << enumerant.getSymbol()
1455 << ": {static const spirv::Capability implies[" << impliedCapsDefs.size()
1456 << "] = {";
1457 llvm::interleaveComma(impliedCapsDefs, os, [&](const Record *capDef) {
1458 os << "spirv::Capability::" << EnumAttrCase(capDef).getSymbol();
1460 os << "}; return ArrayRef<spirv::Capability>(implies, "
1461 << impliedCapsDefs.size() << "); }\n";
1463 os << " }\n";
1464 os << "}\n";
1466 return false;
1469 //===----------------------------------------------------------------------===//
1470 // SPIR-V Capability Implication Hook Registration
1471 //===----------------------------------------------------------------------===//
1473 static mlir::GenRegistration
1474 genCapabilityImplication("gen-spirv-capability-implication",
1475 "Generate utility function to return implied "
1476 "capabilities for a given capability",
1477 [](const RecordKeeper &records, raw_ostream &os) {
1478 return emitCapabilityImplication(records, os);