[NFC] Add libcxx python reformat SHA to .git-blame-ignore-revs
[llvm-project.git] / mlir / tools / mlir-tblgen / SPIRVUtilsGen.cpp
blob7489c3134fc73ebe8aaeda6c584db894b89e1377
//===- SPIRVUtilsGen.cpp - SPIR-V utility generator -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// SPIRVUtilsGen generates common utility code for the SPIR-V dialect:
// availability interface declarations and definitions, enum availability
// queries, and op serialization/deserialization routines.
//
//===----------------------------------------------------------------------===//
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Operator.h"
19 #include "llvm/ADT/Sequence.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringMap.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSet.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include "llvm/TableGen/Error.h"
28 #include "llvm/TableGen/Record.h"
29 #include "llvm/TableGen/TableGenBackend.h"
31 #include <list>
32 #include <optional>
34 using llvm::ArrayRef;
35 using llvm::formatv;
36 using llvm::raw_ostream;
37 using llvm::raw_string_ostream;
38 using llvm::Record;
39 using llvm::RecordKeeper;
40 using llvm::SmallVector;
41 using llvm::SMLoc;
42 using llvm::StringMap;
43 using llvm::StringRef;
44 using mlir::tblgen::Attribute;
45 using mlir::tblgen::EnumAttr;
46 using mlir::tblgen::EnumAttrCase;
47 using mlir::tblgen::NamedAttribute;
48 using mlir::tblgen::NamedTypeConstraint;
49 using mlir::tblgen::NamespaceEmitter;
50 using mlir::tblgen::Operator;
52 //===----------------------------------------------------------------------===//
53 // Availability Wrapper Class
54 //===----------------------------------------------------------------------===//
56 namespace {
57 // Wrapper class with helper methods for accessing availability defined in
58 // TableGen.
59 class Availability {
60 public:
61 explicit Availability(const Record *def);
63 // Returns the name of the direct TableGen class for this availability
64 // instance.
65 StringRef getClass() const;
67 // Returns the generated C++ interface's class namespace.
68 StringRef getInterfaceClassNamespace() const;
70 // Returns the generated C++ interface's class name.
71 StringRef getInterfaceClassName() const;
73 // Returns the generated C++ interface's description.
74 StringRef getInterfaceDescription() const;
76 // Returns the name of the query function insided the generated C++ interface.
77 StringRef getQueryFnName() const;
79 // Returns the return type of the query function insided the generated C++
80 // interface.
81 StringRef getQueryFnRetType() const;
83 // Returns the code for merging availability requirements.
84 StringRef getMergeActionCode() const;
86 // Returns the initializer expression for initializing the final availability
87 // requirements.
88 StringRef getMergeInitializer() const;
90 // Returns the C++ type for an availability instance.
91 StringRef getMergeInstanceType() const;
93 // Returns the C++ statements for preparing availability instance.
94 StringRef getMergeInstancePreparation() const;
96 // Returns the concrete availability instance carried in this case.
97 StringRef getMergeInstance() const;
99 // Returns the underlying LLVM TableGen Record.
100 const llvm::Record *getDef() const { return def; }
102 private:
103 // The TableGen definition of this availability.
104 const llvm::Record *def;
106 } // namespace
108 Availability::Availability(const llvm::Record *def) : def(def) {
109 assert(def->isSubClassOf("Availability") &&
110 "must be subclass of TableGen 'Availability' class");
113 StringRef Availability::getClass() const {
114 SmallVector<Record *, 1> parentClass;
115 def->getDirectSuperClasses(parentClass);
116 if (parentClass.size() != 1) {
117 PrintFatalError(def->getLoc(),
118 "expected to only have one direct superclass");
120 return parentClass.front()->getName();
123 StringRef Availability::getInterfaceClassNamespace() const {
124 return def->getValueAsString("cppNamespace");
127 StringRef Availability::getInterfaceClassName() const {
128 return def->getValueAsString("interfaceName");
131 StringRef Availability::getInterfaceDescription() const {
132 return def->getValueAsString("interfaceDescription");
135 StringRef Availability::getQueryFnRetType() const {
136 return def->getValueAsString("queryFnRetType");
139 StringRef Availability::getQueryFnName() const {
140 return def->getValueAsString("queryFnName");
143 StringRef Availability::getMergeActionCode() const {
144 return def->getValueAsString("mergeAction");
147 StringRef Availability::getMergeInitializer() const {
148 return def->getValueAsString("initializer");
151 StringRef Availability::getMergeInstanceType() const {
152 return def->getValueAsString("instanceType");
155 StringRef Availability::getMergeInstancePreparation() const {
156 return def->getValueAsString("instancePreparation");
159 StringRef Availability::getMergeInstance() const {
160 return def->getValueAsString("instance");
163 // Returns the availability spec of the given `def`.
164 std::vector<Availability> getAvailabilities(const Record &def) {
165 std::vector<Availability> availabilities;
167 if (def.getValue("availability")) {
168 std::vector<Record *> availDefs = def.getValueAsListOfDefs("availability");
169 availabilities.reserve(availDefs.size());
170 for (const Record *avail : availDefs)
171 availabilities.emplace_back(avail);
174 return availabilities;
177 //===----------------------------------------------------------------------===//
178 // Availability Interface Definitions AutoGen
179 //===----------------------------------------------------------------------===//
181 static void emitInterfaceDef(const Availability &availability,
182 raw_ostream &os) {
184 os << availability.getQueryFnRetType() << " ";
186 StringRef cppNamespace = availability.getInterfaceClassNamespace();
187 cppNamespace.consume_front("::");
188 if (!cppNamespace.empty())
189 os << cppNamespace << "::";
191 StringRef methodName = availability.getQueryFnName();
192 os << availability.getInterfaceClassName() << "::" << methodName << "() {\n"
193 << " return getImpl()->" << methodName << "(getImpl(), getOperation());\n"
194 << "}\n";
197 static bool emitInterfaceDefs(const RecordKeeper &recordKeeper,
198 raw_ostream &os) {
199 llvm::emitSourceFileHeader("Availability Interface Definitions", os);
201 auto defs = recordKeeper.getAllDerivedDefinitions("Availability");
202 SmallVector<const Record *, 1> handledClasses;
203 for (const Record *def : defs) {
204 SmallVector<Record *, 1> parent;
205 def->getDirectSuperClasses(parent);
206 if (parent.size() != 1) {
207 PrintFatalError(def->getLoc(),
208 "expected to only have one direct superclass");
210 if (llvm::is_contained(handledClasses, parent.front()))
211 continue;
213 Availability availability(def);
214 emitInterfaceDef(availability, os);
215 handledClasses.push_back(parent.front());
217 return false;
220 //===----------------------------------------------------------------------===//
221 // Availability Interface Declarations AutoGen
222 //===----------------------------------------------------------------------===//
224 static void emitConceptDecl(const Availability &availability, raw_ostream &os) {
225 os << " class Concept {\n"
226 << " public:\n"
227 << " virtual ~Concept() = default;\n"
228 << " virtual " << availability.getQueryFnRetType() << " "
229 << availability.getQueryFnName()
230 << "(const Concept *impl, Operation *tblgen_opaque_op) const = 0;\n"
231 << " };\n";
234 static void emitModelDecl(const Availability &availability, raw_ostream &os) {
235 for (const char *modelClass : {"Model", "FallbackModel"}) {
236 os << " template<typename ConcreteOp>\n";
237 os << " class " << modelClass << " : public Concept {\n"
238 << " public:\n"
239 << " using Interface = " << availability.getInterfaceClassName()
240 << ";\n"
241 << " " << availability.getQueryFnRetType() << " "
242 << availability.getQueryFnName()
243 << "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n"
244 << " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
245 << " (void)op;\n"
246 // Forward to the method on the concrete operation type.
247 << " return op." << availability.getQueryFnName() << "();\n"
248 << " }\n"
249 << " };\n";
251 os << " template<typename ConcreteModel, typename ConcreteOp>\n";
252 os << " class ExternalModel : public FallbackModel<ConcreteOp> {};\n";
255 static void emitInterfaceDecl(const Availability &availability,
256 raw_ostream &os) {
257 StringRef interfaceName = availability.getInterfaceClassName();
258 std::string interfaceTraitsName =
259 std::string(formatv("{0}Traits", interfaceName));
261 StringRef cppNamespace = availability.getInterfaceClassNamespace();
262 NamespaceEmitter nsEmitter(os, cppNamespace);
263 os << "class " << interfaceName << ";\n\n";
265 // Emit the traits struct containing the concept and model declarations.
266 os << "namespace detail {\n"
267 << "struct " << interfaceTraitsName << " {\n";
268 emitConceptDecl(availability, os);
269 os << '\n';
270 emitModelDecl(availability, os);
271 os << "};\n} // namespace detail\n\n";
273 // Emit the main interface class declaration.
274 os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n";
275 os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n"
276 "public:\n"
277 " using OpInterface<{1}, detail::{2}>::OpInterface;\n",
278 interfaceName, interfaceName, interfaceTraitsName);
280 // Emit query function declaration.
281 os << " " << availability.getQueryFnRetType() << " "
282 << availability.getQueryFnName() << "();\n";
283 os << "};\n\n";
286 static bool emitInterfaceDecls(const RecordKeeper &recordKeeper,
287 raw_ostream &os) {
288 llvm::emitSourceFileHeader("Availability Interface Declarations", os);
290 auto defs = recordKeeper.getAllDerivedDefinitions("Availability");
291 SmallVector<const Record *, 4> handledClasses;
292 for (const Record *def : defs) {
293 SmallVector<Record *, 1> parent;
294 def->getDirectSuperClasses(parent);
295 if (parent.size() != 1) {
296 PrintFatalError(def->getLoc(),
297 "expected to only have one direct superclass");
299 if (llvm::is_contained(handledClasses, parent.front()))
300 continue;
302 Availability avail(def);
303 emitInterfaceDecl(avail, os);
304 handledClasses.push_back(parent.front());
306 return false;
309 //===----------------------------------------------------------------------===//
310 // Availability Interface Hook Registration
311 //===----------------------------------------------------------------------===//
313 // Registers the operation interface generator to mlir-tblgen.
314 static mlir::GenRegistration
315 genInterfaceDecls("gen-avail-interface-decls",
316 "Generate availability interface declarations",
317 [](const RecordKeeper &records, raw_ostream &os) {
318 return emitInterfaceDecls(records, os);
321 // Registers the operation interface generator to mlir-tblgen.
322 static mlir::GenRegistration
323 genInterfaceDefs("gen-avail-interface-defs",
324 "Generate op interface definitions",
325 [](const RecordKeeper &records, raw_ostream &os) {
326 return emitInterfaceDefs(records, os);
329 //===----------------------------------------------------------------------===//
330 // Enum Availability Query AutoGen
331 //===----------------------------------------------------------------------===//
333 static void emitAvailabilityQueryForIntEnum(const Record &enumDef,
334 raw_ostream &os) {
335 EnumAttr enumAttr(enumDef);
336 StringRef enumName = enumAttr.getEnumClassName();
337 std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases();
339 // Mapping from availability class name to (enumerant, availability
340 // specification) pairs.
341 llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>>
342 classCaseMap;
344 // Place all availability specifications to their corresponding
345 // availability classes.
346 for (const EnumAttrCase &enumerant : enumerants)
347 for (const Availability &avail : getAvailabilities(enumerant.getDef()))
348 classCaseMap[avail.getClass()].push_back({enumerant, avail});
350 for (const auto &classCasePair : classCaseMap) {
351 Availability avail = classCasePair.getValue().front().second;
353 os << formatv("std::optional<{0}> {1}({2} value) {{\n",
354 avail.getMergeInstanceType(), avail.getQueryFnName(),
355 enumName);
357 os << " switch (value) {\n";
358 for (const auto &caseSpecPair : classCasePair.getValue()) {
359 EnumAttrCase enumerant = caseSpecPair.first;
360 Availability avail = caseSpecPair.second;
361 os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName,
362 enumerant.getSymbol(), avail.getMergeInstancePreparation(),
363 avail.getMergeInstanceType(), avail.getMergeInstance());
365 // Only emit default if uncovered cases.
366 if (classCasePair.getValue().size() < enumAttr.getAllCases().size())
367 os << " default: break;\n";
368 os << " }\n"
369 << " return std::nullopt;\n"
370 << "}\n";
374 static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
375 raw_ostream &os) {
376 EnumAttr enumAttr(enumDef);
377 StringRef enumName = enumAttr.getEnumClassName();
378 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
379 std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases();
381 // Mapping from availability class name to (enumerant, availability
382 // specification) pairs.
383 llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>>
384 classCaseMap;
386 // Place all availability specifications to their corresponding
387 // availability classes.
388 for (const EnumAttrCase &enumerant : enumerants)
389 for (const Availability &avail : getAvailabilities(enumerant.getDef()))
390 classCaseMap[avail.getClass()].push_back({enumerant, avail});
392 for (const auto &classCasePair : classCaseMap) {
393 Availability avail = classCasePair.getValue().front().second;
395 os << formatv("std::optional<{0}> {1}({2} value) {{\n",
396 avail.getMergeInstanceType(), avail.getQueryFnName(),
397 enumName);
399 os << formatv(
400 " assert(::llvm::popcount(static_cast<{0}>(value)) <= 1"
401 " && \"cannot have more than one bit set\");\n",
402 underlyingType);
404 os << " switch (value) {\n";
405 for (const auto &caseSpecPair : classCasePair.getValue()) {
406 EnumAttrCase enumerant = caseSpecPair.first;
407 Availability avail = caseSpecPair.second;
408 os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName,
409 enumerant.getSymbol(), avail.getMergeInstancePreparation(),
410 avail.getMergeInstanceType(), avail.getMergeInstance());
412 os << " default: break;\n";
413 os << " }\n"
414 << " return std::nullopt;\n"
415 << "}\n";
419 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
420 EnumAttr enumAttr(enumDef);
421 StringRef enumName = enumAttr.getEnumClassName();
422 StringRef cppNamespace = enumAttr.getCppNamespace();
423 auto enumerants = enumAttr.getAllCases();
425 llvm::SmallVector<StringRef, 2> namespaces;
426 llvm::SplitString(cppNamespace, namespaces, "::");
428 for (auto ns : namespaces)
429 os << "namespace " << ns << " {\n";
431 llvm::StringSet<> handledClasses;
433 // Place all availability specifications to their corresponding
434 // availability classes.
435 for (const EnumAttrCase &enumerant : enumerants)
436 for (const Availability &avail : getAvailabilities(enumerant.getDef())) {
437 StringRef className = avail.getClass();
438 if (handledClasses.count(className))
439 continue;
440 os << formatv("std::optional<{0}> {1}({2} value);\n",
441 avail.getMergeInstanceType(), avail.getQueryFnName(),
442 enumName);
443 handledClasses.insert(className);
446 for (auto ns : llvm::reverse(namespaces))
447 os << "} // namespace " << ns << "\n";
450 static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
451 llvm::emitSourceFileHeader("SPIR-V Enum Availability Declarations", os);
453 auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
454 for (const auto *def : defs)
455 emitEnumDecl(*def, os);
457 return false;
460 static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
461 EnumAttr enumAttr(enumDef);
462 StringRef cppNamespace = enumAttr.getCppNamespace();
464 llvm::SmallVector<StringRef, 2> namespaces;
465 llvm::SplitString(cppNamespace, namespaces, "::");
467 for (auto ns : namespaces)
468 os << "namespace " << ns << " {\n";
470 if (enumAttr.isBitEnum()) {
471 emitAvailabilityQueryForBitEnum(enumDef, os);
472 } else {
473 emitAvailabilityQueryForIntEnum(enumDef, os);
476 for (auto ns : llvm::reverse(namespaces))
477 os << "} // namespace " << ns << "\n";
478 os << "\n";
481 static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
482 llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os);
484 auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
485 for (const auto *def : defs)
486 emitEnumDef(*def, os);
488 return false;
491 //===----------------------------------------------------------------------===//
492 // Enum Availability Query Hook Registration
493 //===----------------------------------------------------------------------===//
495 // Registers the enum utility generator to mlir-tblgen.
496 static mlir::GenRegistration
497 genEnumDecls("gen-spirv-enum-avail-decls",
498 "Generate SPIR-V enum availability declarations",
499 [](const RecordKeeper &records, raw_ostream &os) {
500 return emitEnumDecls(records, os);
503 // Registers the enum utility generator to mlir-tblgen.
504 static mlir::GenRegistration
505 genEnumDefs("gen-spirv-enum-avail-defs",
506 "Generate SPIR-V enum availability definitions",
507 [](const RecordKeeper &records, raw_ostream &os) {
508 return emitEnumDefs(records, os);
511 //===----------------------------------------------------------------------===//
512 // Serialization AutoGen
513 //===----------------------------------------------------------------------===//
515 /// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
516 /// generates code extracts the attribute with name `attrName` from
517 /// `operandList` of `op`.
518 static void emitAttributeSerialization(const Attribute &attr,
519 ArrayRef<SMLoc> loc, StringRef tabs,
520 StringRef opVar, StringRef operandList,
521 StringRef attrName, raw_ostream &os) {
522 os << tabs
523 << formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
524 if (attr.getAttrDefName() == "SPIRV_ScopeAttr" ||
525 attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" ||
526 attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") {
527 // These two enums are encoded as <id> to constant values in SPIR-V blob,
528 // but we directly use the constant value as attribute in SPIR-V dialect. So
529 // need to handle them separately from normal enum attributes.
530 EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
531 os << tabs
532 << formatv(" {0}.push_back(prepareConstantInt({1}.getLoc(), "
533 "Builder({1}).getI32IntegerAttr(static_cast<uint32_t>("
534 "attr.cast<{2}::{3}Attr>().getValue()))));\n",
535 operandList, opVar, baseEnum.getCppNamespace(),
536 baseEnum.getEnumClassName());
537 } else if (attr.isSubClassOf("SPIRV_BitEnumAttr") ||
538 attr.isSubClassOf("SPIRV_I32EnumAttr")) {
539 EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
540 os << tabs
541 << formatv(" {0}.push_back(static_cast<uint32_t>("
542 "attr.cast<{1}::{2}Attr>().getValue()));\n",
543 operandList, baseEnum.getCppNamespace(),
544 baseEnum.getEnumClassName());
545 } else if (attr.getAttrDefName() == "I32ArrayAttr") {
546 // Serialize all the elements of the array
547 os << tabs << " for (auto attrElem : attr.cast<ArrayAttr>()) {\n";
548 os << tabs
549 << formatv(" {0}.push_back(static_cast<uint32_t>("
550 "attrElem.cast<IntegerAttr>().getValue().getZExtValue()));\n",
551 operandList);
552 os << tabs << " }\n";
553 } else if (attr.getAttrDefName() == "I32Attr") {
554 os << tabs
555 << formatv(" {0}.push_back(static_cast<uint32_t>("
556 "attr.cast<IntegerAttr>().getValue().getZExtValue()));\n",
557 operandList);
558 } else if (attr.isEnumAttr() || attr.getAttrDefName() == "TypeAttr") {
559 os << tabs
560 << formatv(" {0}.push_back(static_cast<uint32_t>("
561 "getTypeID(attr.cast<TypeAttr>().getValue())));\n",
562 operandList);
563 } else {
564 PrintFatalError(
565 loc,
566 llvm::Twine(
567 "unhandled attribute type in SPIR-V serialization generation : '") +
568 attr.getAttrDefName() + llvm::Twine("'"));
570 os << tabs << "}\n";
573 /// Generates code to serialize the operands of a SPIRV_Op `op` into `os`. The
574 /// generated queries the SSA-ID if operand is a SSA-Value, or serializes the
575 /// attributes. The `operands` vector is updated appropriately. `elidedAttrs`
576 /// updated as well to include the serialized attributes.
577 static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc,
578 StringRef tabs, StringRef opVar,
579 StringRef operands, StringRef elidedAttrs,
580 raw_ostream &os) {
581 using mlir::tblgen::Argument;
583 // SPIR-V ops can mix operands and attributes in the definition. These
584 // operands and attributes are serialized in the exact order of the definition
585 // to match SPIR-V binary format requirements. It can cause excessive
586 // generated code bloat because we are emitting code to handle each
587 // operand/attribute separately. So here we probe first to check whether all
588 // the operands are ahead of attributes. Then we can serialize all operands
589 // together.
591 // Whether all operands are ahead of all attributes in the op's spec.
592 bool areOperandsAheadOfAttrs = true;
593 // Find the first attribute.
594 const Argument *it = llvm::find_if(op.getArgs(), [](const Argument &arg) {
595 return arg.is<NamedAttribute *>();
597 // Check whether all following arguments are attributes.
598 for (const Argument *ie = op.arg_end(); it != ie; ++it) {
599 if (!it->is<NamedAttribute *>()) {
600 areOperandsAheadOfAttrs = false;
601 break;
605 // Serialize all operands together.
606 if (areOperandsAheadOfAttrs) {
607 if (op.getNumOperands() != 0) {
608 os << tabs
609 << formatv("for (Value operand : {0}->getOperands()) {{\n", opVar);
610 os << tabs << " auto id = getValueID(operand);\n";
611 os << tabs << " assert(id && \"use before def!\");\n";
612 os << tabs << formatv(" {0}.push_back(id);\n", operands);
613 os << tabs << "}\n";
615 for (const NamedAttribute &attr : op.getAttributes()) {
616 emitAttributeSerialization(
617 (attr.attr.isOptional() ? attr.attr.getBaseAttr() : attr.attr), loc,
618 tabs, opVar, operands, attr.name, os);
619 os << tabs
620 << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs, attr.name);
622 return;
625 // Serialize operands separately.
626 auto operandNum = 0;
627 for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
628 auto argument = op.getArg(i);
629 os << tabs << "{\n";
630 if (argument.is<NamedTypeConstraint *>()) {
631 os << tabs
632 << formatv(" for (auto arg : {0}.getODSOperands({1})) {{\n", opVar,
633 operandNum);
634 os << tabs << " auto argID = getValueID(arg);\n";
635 os << tabs << " if (!argID) {\n";
636 os << tabs
637 << formatv(" return emitError({0}.getLoc(), "
638 "\"operand #{1} has a use before def\");\n",
639 opVar, operandNum);
640 os << tabs << " }\n";
641 os << tabs << formatv(" {0}.push_back(argID);\n", operands);
642 os << " }\n";
643 operandNum++;
644 } else {
645 NamedAttribute *attr = argument.get<NamedAttribute *>();
646 auto newtabs = tabs.str() + " ";
647 emitAttributeSerialization(
648 (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
649 loc, newtabs, opVar, operands, attr->name, os);
650 os << newtabs
651 << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs, attr->name);
653 os << tabs << "}\n";
657 /// Generates code to serializes the result of SPIRV_Op `op` into `os`. The
658 /// generated gets the ID for the type of the result (if any), the SSA-ID of
659 /// the result and updates `resultID` with the SSA-ID.
660 static void emitResultSerialization(const Operator &op, ArrayRef<SMLoc> loc,
661 StringRef tabs, StringRef opVar,
662 StringRef operands, StringRef resultID,
663 raw_ostream &os) {
664 if (op.getNumResults() == 1) {
665 StringRef resultTypeID("resultTypeID");
666 os << tabs << formatv("uint32_t {0} = 0;\n", resultTypeID);
667 os << tabs
668 << formatv(
669 "if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n",
670 opVar, resultTypeID);
671 os << tabs << " return failure();\n";
672 os << tabs << "}\n";
673 os << tabs << formatv("{0}.push_back({1});\n", operands, resultTypeID);
674 // Create an SSA result <id> for the op
675 os << tabs << formatv("{0} = getNextID();\n", resultID);
676 os << tabs
677 << formatv("valueIDMap[{0}.getResult()] = {1};\n", opVar, resultID);
678 os << tabs << formatv("{0}.push_back({1});\n", operands, resultID);
679 } else if (op.getNumResults() != 0) {
680 PrintFatalError(loc, "SPIR-V ops can only have zero or one result");
684 /// Generates code to serialize attributes of SPIRV_Op `op` that become
685 /// decorations on the `resultID` of the serialized operation `opVar` in the
686 /// SPIR-V binary.
687 static void emitDecorationSerialization(const Operator &op, StringRef tabs,
688 StringRef opVar, StringRef elidedAttrs,
689 StringRef resultID, raw_ostream &os) {
690 if (op.getNumResults() == 1) {
691 // All non-argument attributes translated into OpDecorate instruction
692 os << tabs << formatv("for (auto attr : {0}->getAttrs()) {{\n", opVar);
693 os << tabs
694 << formatv(" if (llvm::is_contained({0}, attr.getName())) {{",
695 elidedAttrs);
696 os << tabs << " continue;\n";
697 os << tabs << " }\n";
698 os << tabs
699 << formatv(
700 " if (failed(processDecoration({0}.getLoc(), {1}, attr))) {{\n",
701 opVar, resultID);
702 os << tabs << " return failure();\n";
703 os << tabs << " }\n";
704 os << tabs << "}\n";
708 /// Generates code to serialize an SPIRV_Op `op` into `os`.
709 static void emitSerializationFunction(const Record *attrClass,
710 const Record *record, const Operator &op,
711 raw_ostream &os) {
712 // If the record has 'autogenSerialization' set to 0, nothing to do
713 if (!record->getValueAsBit("autogenSerialization"))
714 return;
716 StringRef opVar("op"), operands("operands"), elidedAttrs("elidedAttrs"),
717 resultID("resultID");
719 os << formatv(
720 "template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n",
721 op.getQualCppClassName(), opVar);
723 // Special case for ops without attributes in TableGen definitions
724 if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) {
725 std::string extInstSet;
726 std::string opcode;
727 if (record->isSubClassOf("SPIRV_ExtInstOp")) {
728 extInstSet =
729 formatv("\"{0}\"", record->getValueAsString("extendedInstSetName"));
730 opcode = std::to_string(record->getValueAsInt("extendedInstOpcode"));
731 } else {
732 extInstSet = "\"\"";
733 opcode = formatv("static_cast<uint32_t>(spirv::Opcode::{0})",
734 record->getValueAsString("spirvOpName"));
737 os << formatv(" return processOpWithoutGrammarAttr({0}, {1}, {2});\n}\n\n",
738 opVar, extInstSet, opcode);
739 return;
742 os << formatv(" SmallVector<uint32_t, 4> {0};\n", operands);
743 os << formatv(" SmallVector<StringRef, 2> {0};\n", elidedAttrs);
745 // Serialize result information.
746 if (op.getNumResults() == 1) {
747 os << formatv(" uint32_t {0} = 0;\n", resultID);
748 emitResultSerialization(op, record->getLoc(), " ", opVar, operands,
749 resultID, os);
752 // Process arguments.
753 emitArgumentSerialization(op, record->getLoc(), " ", opVar, operands,
754 elidedAttrs, os);
756 if (record->isSubClassOf("SPIRV_ExtInstOp")) {
757 os << formatv(
758 " (void)encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n", opVar,
759 record->getValueAsString("extendedInstSetName"),
760 record->getValueAsInt("extendedInstOpcode"), operands);
761 } else {
762 // Emit debug info.
763 os << formatv(" (void)emitDebugLine(functionBody, {0}.getLoc());\n",
764 opVar);
765 os << formatv(" (void)encodeInstructionInto("
766 "functionBody, spirv::Opcode::{1}, {2});\n",
767 op.getQualCppClassName(),
768 record->getValueAsString("spirvOpName"), operands);
771 // Process decorations.
772 emitDecorationSerialization(op, " ", opVar, elidedAttrs, resultID, os);
774 os << " return success();\n";
775 os << "}\n\n";
778 /// Generates the prologue for the function that dispatches the serialization of
779 /// the operation `opVar` based on its opcode.
780 static void initDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
781 os << formatv(
782 "LogicalResult Serializer::dispatchToAutogenSerialization(Operation "
783 "*{0}) {{\n",
784 opVar);
787 /// Generates the body of the dispatch function. This function generates the
788 /// check that if satisfied, will call the serialization function generated for
789 /// the `op`.
790 static void emitSerializationDispatch(const Operator &op, StringRef tabs,
791 StringRef opVar, raw_ostream &os) {
792 os << tabs
793 << formatv("if (isa<{0}>({1})) {{\n", op.getQualCppClassName(), opVar);
794 os << tabs
795 << formatv(" return processOp(cast<{0}>({1}));\n",
796 op.getQualCppClassName(), opVar);
797 os << tabs << "}\n";
800 /// Generates the epilogue for the function that dispatches the serialization of
801 /// the operation.
802 static void finalizeDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
803 os << formatv(
804 " return {0}->emitError(\"unhandled operation serialization\");\n",
805 opVar);
806 os << "}\n\n";
809 /// Generates code to deserialize the attribute of a SPIRV_Op into `os`. The
810 /// generated code reads the `words` of the serialized instruction at
811 /// position `wordIndex` and adds the deserialized attribute into `attrList`.
812 static void emitAttributeDeserialization(const Attribute &attr,
813 ArrayRef<SMLoc> loc, StringRef tabs,
814 StringRef attrList, StringRef attrName,
815 StringRef words, StringRef wordIndex,
816 raw_ostream &os) {
817 if (attr.getAttrDefName() == "SPIRV_ScopeAttr" ||
818 attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" ||
819 attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") {
820 // These two enums are encoded as <id> to constant values in SPIR-V blob,
821 // but we directly use the constant value as attribute in SPIR-V dialect. So
822 // need to handle them separately from normal enum attributes.
823 EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
824 os << tabs
825 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
826 "opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>("
827 "getConstantInt({4}[{5}++]).getValue().getZExtValue()))));\n",
828 attrList, attrName, baseEnum.getCppNamespace(),
829 baseEnum.getEnumClassName(), words, wordIndex);
830 } else if (attr.isSubClassOf("SPIRV_BitEnumAttr") ||
831 attr.isSubClassOf("SPIRV_I32EnumAttr")) {
832 EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
833 os << tabs
834 << formatv(" {0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
835 "opBuilder.getAttr<{2}::{3}Attr>("
836 "static_cast<{2}::{3}>({4}[{5}++]))));\n",
837 attrList, attrName, baseEnum.getCppNamespace(),
838 baseEnum.getEnumClassName(), words, wordIndex);
839 } else if (attr.getAttrDefName() == "I32ArrayAttr") {
840 os << tabs << "SmallVector<Attribute, 4> attrListElems;\n";
841 os << tabs << formatv("while ({0} < {1}.size()) {{\n", wordIndex, words);
842 os << tabs
843 << formatv(
845 "attrListElems.push_back(opBuilder.getI32IntegerAttr({0}[{1}++]))"
846 ";\n",
847 words, wordIndex);
848 os << tabs << "}\n";
849 os << tabs
850 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
851 "opBuilder.getArrayAttr(attrListElems)));\n",
852 attrList, attrName);
853 } else if (attr.getAttrDefName() == "I32Attr") {
854 os << tabs
855 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
856 "opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
857 attrList, attrName, words, wordIndex);
858 } else if (attr.isEnumAttr() || attr.getAttrDefName() == "TypeAttr") {
859 os << tabs
860 << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
861 "TypeAttr::get(getType({2}[{3}++]))));\n",
862 attrList, attrName, words, wordIndex);
863 } else {
864 PrintFatalError(
865 loc, llvm::Twine(
866 "unhandled attribute type in deserialization generation : '") +
867 attr.getAttrDefName() + llvm::Twine("'"));
871 /// Generates the code to deserialize the result of an SPIRV_Op `op` into
872 /// `os`. The generated code gets the type of the result specified at
873 /// `words`[`wordIndex`], the SSA ID for the result at position `wordIndex` + 1
874 /// and updates the `resultType` and `valueID` with the parsed type and SSA ID,
875 /// respectively.
876 static void emitResultDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
877 StringRef tabs, StringRef words,
878 StringRef wordIndex,
879 StringRef resultTypes, StringRef valueID,
880 raw_ostream &os) {
881 // Deserialize result information if it exists
882 if (op.getNumResults() == 1) {
883 os << tabs << "{\n";
884 os << tabs << formatv(" if ({0} >= {1}.size()) {{\n", wordIndex, words);
885 os << tabs
886 << formatv(
887 " return emitError(unknownLoc, \"expected result type <id> "
888 "while deserializing {0}\");\n",
889 op.getQualCppClassName());
890 os << tabs << " }\n";
891 os << tabs << formatv(" auto ty = getType({0}[{1}]);\n", words, wordIndex);
892 os << tabs << " if (!ty) {\n";
893 os << tabs
894 << formatv(
895 " return emitError(unknownLoc, \"unknown type result <id> : "
896 "\") << {0}[{1}];\n",
897 words, wordIndex);
898 os << tabs << " }\n";
899 os << tabs << formatv(" {0}.push_back(ty);\n", resultTypes);
900 os << tabs << formatv(" {0}++;\n", wordIndex);
901 os << tabs << formatv(" if ({0} >= {1}.size()) {{\n", wordIndex, words);
902 os << tabs
903 << formatv(
904 " return emitError(unknownLoc, \"expected result <id> while "
905 "deserializing {0}\");\n",
906 op.getQualCppClassName());
907 os << tabs << " }\n";
908 os << tabs << "}\n";
909 os << tabs << formatv("{0} = {1}[{2}++];\n", valueID, words, wordIndex);
910 } else if (op.getNumResults() != 0) {
911 PrintFatalError(loc, "SPIR-V ops can have only zero or one result");
915 /// Generates the code to deserialize the operands of an SPIRV_Op `op` into
916 /// `os`. The generated code reads the `words` of the binary instruction, from
917 /// position `wordIndex` to the end, and either gets the Value corresponding to
918 /// the ID encoded, or deserializes the attributes encoded. The parsed
919 /// operand(attribute) is added to the `operands` list or `attributes` list.
920 static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
921 StringRef tabs, StringRef words,
922 StringRef wordIndex, StringRef operands,
923 StringRef attributes, raw_ostream &os) {
924 // Process operands/attributes
925 for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
926 auto argument = op.getArg(i);
927 if (auto *valueArg = argument.dyn_cast<NamedTypeConstraint *>()) {
928 if (valueArg->isVariableLength()) {
929 if (i != e - 1) {
930 PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or "
931 "std::optional<...> arguments only if "
932 "it's the last argument");
934 os << tabs
935 << formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words);
936 } else {
937 os << tabs << formatv("if ({0} < {1}.size())", wordIndex, words);
939 os << " {\n";
940 os << tabs
941 << formatv(" auto arg = getValue({0}[{1}]);\n", words, wordIndex);
942 os << tabs << " if (!arg) {\n";
943 os << tabs
944 << formatv(
945 " return emitError(unknownLoc, \"unknown result <id> : \") "
946 "<< {0}[{1}];\n",
947 words, wordIndex);
948 os << tabs << " }\n";
949 os << tabs << formatv(" {0}.push_back(arg);\n", operands);
950 if (!valueArg->isVariableLength()) {
951 os << tabs << formatv(" {0}++;\n", wordIndex);
953 os << tabs << "}\n";
954 } else {
955 os << tabs << formatv("if ({0} < {1}.size()) {{\n", wordIndex, words);
956 auto *attr = argument.get<NamedAttribute *>();
957 auto newtabs = tabs.str() + " ";
958 emitAttributeDeserialization(
959 (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
960 loc, newtabs, attributes, attr->name, words, wordIndex, os);
961 os << " }\n";
965 os << tabs << formatv("if ({0} != {1}.size()) {{\n", wordIndex, words);
966 os << tabs
967 << formatv(
968 " return emitError(unknownLoc, \"found more operands than "
969 "expected when deserializing {0}, only \") << {1} << \" of \" << "
970 "{2}.size() << \" processed\";\n",
971 op.getQualCppClassName(), wordIndex, words);
972 os << tabs << "}\n\n";
975 /// Generates code to update the `attributes` vector with the attributes
976 /// obtained from parsing the decorations in the SPIR-V binary associated with
977 /// an <id> `valueID`
978 static void emitDecorationDeserialization(const Operator &op, StringRef tabs,
979 StringRef valueID,
980 StringRef attributes,
981 raw_ostream &os) {
982 // Import decorations parsed
983 if (op.getNumResults() == 1) {
984 os << tabs << formatv("if (decorations.count({0})) {{\n", valueID);
985 os << tabs
986 << formatv(" auto attrs = decorations[{0}].getAttrs();\n", valueID);
987 os << tabs
988 << formatv(" {0}.append(attrs.begin(), attrs.end());\n", attributes);
989 os << tabs << "}\n";
993 /// Generates code to deserialize an SPIRV_Op `op` into `os`.
994 static void emitDeserializationFunction(const Record *attrClass,
995 const Record *record,
996 const Operator &op, raw_ostream &os) {
997 // If the record has 'autogenSerialization' set to 0, nothing to do
998 if (!record->getValueAsBit("autogenSerialization"))
999 return;
1001 StringRef resultTypes("resultTypes"), valueID("valueID"), words("words"),
1002 wordIndex("wordIndex"), opVar("op"), operands("operands"),
1003 attributes("attributes");
1005 // Method declaration
1006 os << formatv("template <> "
1007 "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
1008 "uint32_t> {1}) {{\n",
1009 op.getQualCppClassName(), words);
1011 // Special case for ops without attributes in TableGen definitions
1012 if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) {
1013 os << formatv(" return processOpWithoutGrammarAttr("
1014 "{0}, \"{1}\", {2}, {3});\n}\n\n",
1015 words, op.getOperationName(),
1016 op.getNumResults() ? "true" : "false", op.getNumOperands());
1017 return;
1020 os << formatv(" SmallVector<Type, 1> {0};\n", resultTypes);
1021 os << formatv(" size_t {0} = 0; (void){0};\n", wordIndex);
1022 os << formatv(" uint32_t {0} = 0; (void){0};\n", valueID);
1024 // Deserialize result information
1025 emitResultDeserialization(op, record->getLoc(), " ", words, wordIndex,
1026 resultTypes, valueID, os);
1028 os << formatv(" SmallVector<Value, 4> {0};\n", operands);
1029 os << formatv(" SmallVector<NamedAttribute, 4> {0};\n", attributes);
1030 // Operand deserialization
1031 emitOperandDeserialization(op, record->getLoc(), " ", words, wordIndex,
1032 operands, attributes, os);
1034 // Decorations
1035 emitDecorationDeserialization(op, " ", valueID, attributes, os);
1037 os << formatv(" Location loc = createFileLineColLoc(opBuilder);\n");
1038 os << formatv(" auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); "
1039 "(void){1};\n",
1040 op.getQualCppClassName(), opVar, resultTypes, operands,
1041 attributes);
1042 if (op.getNumResults() == 1) {
1043 os << formatv(" valueMap[{0}] = {1}.getResult();\n\n", valueID, opVar);
1046 // According to SPIR-V spec:
1047 // This location information applies to the instructions physically following
1048 // this instruction, up to the first occurrence of any of the following: the
1049 // next end of block.
1050 os << formatv(" if ({0}.hasTrait<OpTrait::IsTerminator>())\n", opVar);
1051 os << formatv(" (void)clearDebugLine();\n");
1052 os << " return success();\n";
1053 os << "}\n\n";
1056 /// Generates the prologue for the function that dispatches the deserialization
1057 /// based on the `opcode`.
1058 static void initDispatchDeserializationFn(StringRef opcode, StringRef words,
1059 raw_ostream &os) {
1060 os << formatv("LogicalResult spirv::Deserializer::"
1061 "dispatchToAutogenDeserialization(spirv::Opcode {0},"
1062 " ArrayRef<uint32_t> {1}) {{\n",
1063 opcode, words);
1064 os << formatv(" switch ({0}) {{\n", opcode);
1067 /// Generates the body of the dispatch function, by generating the case label
1068 /// for an opcode and the call to the method to perform the deserialization.
1069 static void emitDeserializationDispatch(const Operator &op, const Record *def,
1070 StringRef tabs, StringRef words,
1071 raw_ostream &os) {
1072 os << tabs
1073 << formatv("case spirv::Opcode::{0}:\n",
1074 def->getValueAsString("spirvOpName"));
1075 os << tabs
1076 << formatv(" return processOp<{0}>({1});\n", op.getQualCppClassName(),
1077 words);
1080 /// Generates the epilogue for the function that dispatches the deserialization
1081 /// of the operation.
1082 static void finalizeDispatchDeserializationFn(StringRef opcode,
1083 raw_ostream &os) {
1084 os << " default:\n";
1085 os << " ;\n";
1086 os << " }\n";
1087 StringRef opcodeVar("opcodeString");
1088 os << formatv(" auto {0} = spirv::stringifyOpcode({1});\n", opcodeVar,
1089 opcode);
1090 os << formatv(" if (!{0}.empty()) {{\n", opcodeVar);
1091 os << formatv(" return emitError(unknownLoc, \"unhandled deserialization "
1092 "of \") << {0};\n",
1093 opcodeVar);
1094 os << " } else {\n";
1095 os << formatv(" return emitError(unknownLoc, \"unhandled opcode \") << "
1096 "static_cast<uint32_t>({0});\n",
1097 opcode);
1098 os << " }\n";
1099 os << "}\n";
1102 static void initExtendedSetDeserializationDispatch(StringRef extensionSetName,
1103 StringRef instructionID,
1104 StringRef words,
1105 raw_ostream &os) {
1106 os << formatv("LogicalResult spirv::Deserializer::"
1107 "dispatchToExtensionSetAutogenDeserialization("
1108 "StringRef {0}, uint32_t {1}, ArrayRef<uint32_t> {2}) {{\n",
1109 extensionSetName, instructionID, words);
1112 static void
1113 emitExtendedSetDeserializationDispatch(const RecordKeeper &recordKeeper,
1114 raw_ostream &os) {
1115 StringRef extensionSetName("extensionSetName"),
1116 instructionID("instructionID"), words("words");
1118 // First iterate over all ops derived from SPIRV_ExtensionSetOps to get all
1119 // extensionSets.
1121 // For each of the extensions a separate raw_string_ostream is used to
1122 // generate code into. These are then concatenated at the end. Since
1123 // raw_string_ostream needs a string&, use a vector to store all the string
1124 // that are captured by reference within raw_string_ostream.
1125 StringMap<raw_string_ostream> extensionSets;
1126 std::list<std::string> extensionSetNames;
1128 initExtendedSetDeserializationDispatch(extensionSetName, instructionID, words,
1129 os);
1130 auto defs = recordKeeper.getAllDerivedDefinitions("SPIRV_ExtInstOp");
1131 for (const auto *def : defs) {
1132 if (!def->getValueAsBit("autogenSerialization")) {
1133 continue;
1135 Operator op(def);
1136 auto setName = def->getValueAsString("extendedInstSetName");
1137 if (!extensionSets.count(setName)) {
1138 extensionSetNames.emplace_back("");
1139 extensionSets.try_emplace(setName, extensionSetNames.back());
1140 auto &setos = extensionSets.find(setName)->second;
1141 setos << formatv(" if ({0} == \"{1}\") {{\n", extensionSetName, setName);
1142 setos << formatv(" switch ({0}) {{\n", instructionID);
1144 auto &setos = extensionSets.find(setName)->second;
1145 setos << formatv(" case {0}:\n",
1146 def->getValueAsInt("extendedInstOpcode"));
1147 setos << formatv(" return processOp<{0}>({1});\n",
1148 op.getQualCppClassName(), words);
1151 // Append the dispatch code for all the extended sets.
1152 for (auto &extensionSet : extensionSets) {
1153 os << extensionSet.second.str();
1154 os << " default:\n";
1155 os << formatv(
1156 " return emitError(unknownLoc, \"unhandled deserializations of "
1157 "\") << {0} << \" from extension set \" << {1};\n",
1158 instructionID, extensionSetName);
1159 os << " }\n";
1160 os << " }\n";
1163 os << formatv(" return emitError(unknownLoc, \"unhandled deserialization of "
1164 "extended instruction set {0}\");\n",
1165 extensionSetName);
1166 os << "}\n";
1169 /// Emits all the autogenerated serialization/deserializations functions for the
1170 /// SPIRV_Ops.
1171 static bool emitSerializationFns(const RecordKeeper &recordKeeper,
1172 raw_ostream &os) {
1173 llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os);
1175 std::string dSerFnString, dDesFnString, serFnString, deserFnString,
1176 utilsString;
1177 raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString),
1178 serFn(serFnString), deserFn(deserFnString);
1179 Record *attrClass = recordKeeper.getClass("Attr");
1181 // Emit the serialization and deserialization functions simultaneously.
1182 StringRef opVar("op");
1183 StringRef opcode("opcode"), words("words");
1185 // Handle the SPIR-V ops.
1186 initDispatchSerializationFn(opVar, dSerFn);
1187 initDispatchDeserializationFn(opcode, words, dDesFn);
1188 auto defs = recordKeeper.getAllDerivedDefinitions("SPIRV_Op");
1189 for (const auto *def : defs) {
1190 Operator op(def);
1191 emitSerializationFunction(attrClass, def, op, serFn);
1192 emitDeserializationFunction(attrClass, def, op, deserFn);
1193 if (def->getValueAsBit("hasOpcode") ||
1194 def->isSubClassOf("SPIRV_ExtInstOp")) {
1195 emitSerializationDispatch(op, " ", opVar, dSerFn);
1197 if (def->getValueAsBit("hasOpcode")) {
1198 emitDeserializationDispatch(op, def, " ", words, dDesFn);
1201 finalizeDispatchSerializationFn(opVar, dSerFn);
1202 finalizeDispatchDeserializationFn(opcode, dDesFn);
1204 emitExtendedSetDeserializationDispatch(recordKeeper, dDesFn);
1206 os << "#ifdef GET_SERIALIZATION_FNS\n\n";
1207 os << serFn.str();
1208 os << dSerFn.str();
1209 os << "#endif // GET_SERIALIZATION_FNS\n\n";
1211 os << "#ifdef GET_DESERIALIZATION_FNS\n\n";
1212 os << deserFn.str();
1213 os << dDesFn.str();
1214 os << "#endif // GET_DESERIALIZATION_FNS\n\n";
1216 return false;
1219 //===----------------------------------------------------------------------===//
1220 // Serialization Hook Registration
1221 //===----------------------------------------------------------------------===//
1223 static mlir::GenRegistration genSerialization(
1224 "gen-spirv-serialization",
1225 "Generate SPIR-V (de)serialization utilities and functions",
1226 [](const RecordKeeper &records, raw_ostream &os) {
1227 return emitSerializationFns(records, os);
1230 //===----------------------------------------------------------------------===//
1231 // Op Utils AutoGen
1232 //===----------------------------------------------------------------------===//
1234 static void emitEnumGetAttrNameFnDecl(raw_ostream &os) {
1235 os << formatv("template <typename EnumClass> inline constexpr StringRef "
1236 "attributeName();\n");
1239 static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
1240 raw_ostream &os) {
1241 auto enumName = enumAttr.getEnumClassName();
1242 os << formatv("template <> inline StringRef attributeName<{0}>() {{\n",
1243 enumName);
1244 os << " "
1245 << formatv("static constexpr const char attrName[] = \"{0}\";\n",
1246 llvm::convertToSnakeFromCamelCase(enumName));
1247 os << " return attrName;\n";
1248 os << "}\n";
1251 static bool emitAttrUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
1252 llvm::emitSourceFileHeader("SPIR-V Attribute Utilities", os);
1254 auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
1255 os << "#ifndef MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
1256 os << "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
1257 emitEnumGetAttrNameFnDecl(os);
1258 for (const auto *def : defs) {
1259 EnumAttr enumAttr(*def);
1260 emitEnumGetAttrNameFnDefn(enumAttr, os);
1262 os << "#endif // MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H\n";
1263 return false;
1266 //===----------------------------------------------------------------------===//
1267 // Op Utils Hook Registration
1268 //===----------------------------------------------------------------------===//
1270 static mlir::GenRegistration
1271 genOpUtils("gen-spirv-attr-utils",
1272 "Generate SPIR-V attribute utility definitions",
1273 [](const RecordKeeper &records, raw_ostream &os) {
1274 return emitAttrUtils(records, os);
1277 //===----------------------------------------------------------------------===//
1278 // SPIR-V Availability Impl AutoGen
1279 //===----------------------------------------------------------------------===//
1281 static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
1282 mlir::tblgen::FmtContext fctx;
1283 fctx.addSubst("overall", "tblgen_overall");
1285 std::vector<Availability> opAvailabilities =
1286 getAvailabilities(srcOp.getDef());
1288 // First collect all availability classes this op should implement.
1289 // All availability instances keep information for the generated interface and
1290 // the instance's specific requirement. Here we remember a random instance so
1291 // we can get the information regarding the generated interface.
1292 llvm::StringMap<Availability> availClasses;
1293 for (const Availability &avail : opAvailabilities)
1294 availClasses.try_emplace(avail.getClass(), avail);
1295 for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
1296 if (!namedAttr.attr.isSubClassOf("SPIRV_BitEnumAttr") &&
1297 !namedAttr.attr.isSubClassOf("SPIRV_I32EnumAttr"))
1298 continue;
1299 EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum"));
1301 for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
1302 for (const Availability &caseAvail :
1303 getAvailabilities(enumerant.getDef()))
1304 availClasses.try_emplace(caseAvail.getClass(), caseAvail);
1307 // Then generate implementation for each availability class.
1308 for (const auto &availClass : availClasses) {
1309 StringRef availClassName = availClass.getKey();
1310 Availability avail = availClass.getValue();
1312 // Generate the implementation method signature.
1313 os << formatv("{0} {1}::{2}() {{\n", avail.getQueryFnRetType(),
1314 srcOp.getCppClassName(), avail.getQueryFnName());
1316 // Create the variable for the final requirement and initialize it.
1317 os << formatv(" {0} tblgen_overall = {1};\n", avail.getQueryFnRetType(),
1318 avail.getMergeInitializer());
1320 // Update with the op's specific availability spec.
1321 for (const Availability &avail : opAvailabilities)
1322 if (avail.getClass() == availClassName &&
1323 (!avail.getMergeInstancePreparation().empty() ||
1324 !avail.getMergeActionCode().empty())) {
1325 os << " {\n "
1326 // Prepare this instance.
1327 << avail.getMergeInstancePreparation()
1328 << "\n "
1329 // Merge this instance.
1330 << std::string(
1331 tgfmt(avail.getMergeActionCode(),
1332 &fctx.addSubst("instance", avail.getMergeInstance())))
1333 << ";\n }\n";
1336 // Update with enum attributes' specific availability spec.
1337 for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
1338 if (!namedAttr.attr.isSubClassOf("SPIRV_BitEnumAttr") &&
1339 !namedAttr.attr.isSubClassOf("SPIRV_I32EnumAttr"))
1340 continue;
1341 EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum"));
1343 // (enumerant, availability specification) pairs for this availability
1344 // class.
1345 SmallVector<std::pair<EnumAttrCase, Availability>, 1> caseSpecs;
1347 // Collect all cases' availability specs.
1348 for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
1349 for (const Availability &caseAvail :
1350 getAvailabilities(enumerant.getDef()))
1351 if (availClassName == caseAvail.getClass())
1352 caseSpecs.push_back({enumerant, caseAvail});
1354 // If this attribute kind does not have any availability spec from any of
1355 // its cases, no more work to do.
1356 if (caseSpecs.empty())
1357 continue;
1359 if (enumAttr.isBitEnum()) {
1360 // For BitEnumAttr, we need to iterate over each bit to query its
1361 // availability spec.
1362 os << formatv(" for (unsigned i = 0; "
1363 "i < std::numeric_limits<{0}>::digits; ++i) {{\n",
1364 enumAttr.getUnderlyingType());
1365 os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & "
1366 "static_cast<{0}::{1}>(1 << i);\n",
1367 enumAttr.getCppNamespace(), enumAttr.getEnumClassName(),
1368 srcOp.getGetterName(namedAttr.name));
1369 os << formatv(
1370 " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
1371 enumAttr.getUnderlyingType());
1372 } else {
1373 // For IntEnumAttr, we just need to query the value as a whole.
1374 os << " {\n";
1375 os << formatv(" auto tblgen_attrVal = this->{0}();\n",
1376 srcOp.getGetterName(namedAttr.name));
1378 os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
1379 enumAttr.getCppNamespace(), avail.getQueryFnName());
1380 os << " if (tblgen_instance) "
1381 // TODO` here once ODS supports
1382 // dialect-specific contents so that we can use not implementing the
1383 // availability interface as indication of no requirements.
1384 << std::string(tgfmt(caseSpecs.front().second.getMergeActionCode(),
1385 &fctx.addSubst("instance", "*tblgen_instance")))
1386 << ";\n";
1387 os << " }\n";
1390 os << " return tblgen_overall;\n";
1391 os << "}\n";
1395 static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper,
1396 raw_ostream &os) {
1397 llvm::emitSourceFileHeader("SPIR-V Op Availability Implementations", os);
1399 auto defs = recordKeeper.getAllDerivedDefinitions("SPIRV_Op");
1400 for (const auto *def : defs) {
1401 Operator op(def);
1402 if (def->getValueAsBit("autogenAvailability"))
1403 emitAvailabilityImpl(op, os);
1405 return false;
1408 //===----------------------------------------------------------------------===//
1409 // Op Availability Implementation Hook Registration
1410 //===----------------------------------------------------------------------===//
1412 static mlir::GenRegistration
1413 genOpAvailabilityImpl("gen-spirv-avail-impls",
1414 "Generate SPIR-V operation utility definitions",
1415 [](const RecordKeeper &records, raw_ostream &os) {
1416 return emitAvailabilityImpl(records, os);
1419 //===----------------------------------------------------------------------===//
1420 // SPIR-V Capability Implication AutoGen
1421 //===----------------------------------------------------------------------===//
1423 static bool emitCapabilityImplication(const RecordKeeper &recordKeeper,
1424 raw_ostream &os) {
1425 llvm::emitSourceFileHeader("SPIR-V Capability Implication", os);
1427 EnumAttr enumAttr(
1428 recordKeeper.getDef("SPIRV_CapabilityAttr")->getValueAsDef("enum"));
1430 os << "ArrayRef<spirv::Capability> "
1431 "spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n"
1432 << " switch (cap) {\n"
1433 << " default: return {};\n";
1434 for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) {
1435 const Record &def = enumerant.getDef();
1436 if (!def.getValue("implies"))
1437 continue;
1439 std::vector<Record *> impliedCapsDefs = def.getValueAsListOfDefs("implies");
1440 os << " case spirv::Capability::" << enumerant.getSymbol()
1441 << ": {static const spirv::Capability implies[" << impliedCapsDefs.size()
1442 << "] = {";
1443 llvm::interleaveComma(impliedCapsDefs, os, [&](const Record *capDef) {
1444 os << "spirv::Capability::" << EnumAttrCase(capDef).getSymbol();
1446 os << "}; return ArrayRef<spirv::Capability>(implies, "
1447 << impliedCapsDefs.size() << "); }\n";
1449 os << " }\n";
1450 os << "}\n";
1452 return false;
1455 //===----------------------------------------------------------------------===//
1456 // SPIR-V Capability Implication Hook Registration
1457 //===----------------------------------------------------------------------===//
1459 static mlir::GenRegistration
1460 genCapabilityImplication("gen-spirv-capability-implication",
1461 "Generate utility function to return implied "
1462 "capabilities for a given capability",
1463 [](const RecordKeeper &records, raw_ostream &os) {
1464 return emitCapabilityImplication(records, os);