[mlir][spirv] NFC: Shuffle code around to better follow convention
[llvm-project.git] / mlir / lib / Target / SPIRV / Serialization.cpp
blob cb41c7be93cb57c5156cf868702b1ee9d6cedf4a
1 //===- Serializer.cpp - MLIR SPIR-V Serialization -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MLIR SPIR-V module to SPIR-V binary serialization.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Target/SPIRV/Serialization.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/RegionGraphTraits.h"
21 #include "mlir/Support/LogicalResult.h"
22 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
23 #include "llvm/ADT/DepthFirstIterator.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringExtras.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/ADT/bit.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
34 #define DEBUG_TYPE "spirv-serialization"
36 using namespace mlir;
38 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
39 /// the given `binary` vector.
40 static LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
41 spirv::Opcode op,
42 ArrayRef<uint32_t> operands) {
43 uint32_t wordCount = 1 + operands.size();
44 binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
45 binary.append(operands.begin(), operands.end());
46 return success();
49 /// A pre-order depth-first visitor function for processing basic blocks.
50 ///
51 /// Visits the basic blocks starting from the given `headerBlock` in pre-order
52 /// depth-first manner and calls `blockHandler` on each block. Skips handling
53 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
54 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
55 /// successors.
56 ///
57 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
58 /// of blocks in a function must satisfy the rule that blocks appear before
59 /// all blocks they dominate." This can be achieved by a pre-order CFG
60 /// traversal algorithm. To make the serialization output more logical and
61 /// readable to human, we perform depth-first CFG traversal and delay the
62 /// serialization of the merge block and the continue block, if exists, until
63 /// after all other blocks have been processed.
64 static LogicalResult
65 visitInPrettyBlockOrder(Block *headerBlock,
66 function_ref<LogicalResult(Block *)> blockHandler,
67 bool skipHeader = false, BlockRange skipBlocks = {}) {
68 llvm::df_iterator_default_set<Block *, 4> doneBlocks;
69 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
71 for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
72 if (skipHeader && block == headerBlock)
73 continue;
74 if (failed(blockHandler(block)))
75 return failure();
77 return success();
80 /// Returns the merge block if the given `op` is a structured control flow op.
81 /// Otherwise returns nullptr.
82 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) {
83 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
84 return selectionOp.getMergeBlock();
85 if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
86 return loopOp.getMergeBlock();
87 return nullptr;
90 /// Given a predecessor `block` for a block with arguments, returns the block
91 /// that should be used as the parent block for SPIR-V OpPhi instructions
92 /// corresponding to the block arguments.
93 static Block *getPhiIncomingBlock(Block *block) {
94 // If the predecessor block in question is the entry block for a spv.loop,
95 // we jump to this spv.loop from its enclosing block.
96 if (block->isEntryBlock()) {
97 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
98 // Then the incoming parent block for OpPhi should be the merge block of
99 // the structured control flow op before this loop.
100 Operation *op = loopOp.getOperation();
101 while ((op = op->getPrevNode()) != nullptr)
102 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
103 return incomingBlock;
104 // Or the enclosing block itself if no structured control flow ops
105 // exists before this loop.
106 return loopOp->getBlock();
110 // Otherwise, we jump from the given predecessor block. Try to see if there is
111 // a structured control flow op inside it.
112 for (Operation &op : llvm::reverse(block->getOperations())) {
113 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
114 return incomingBlock;
116 return block;
119 namespace {
121 /// A SPIR-V module serializer.
123 /// A SPIR-V binary module is a single linear stream of instructions; each
124 /// instruction is composed of 32-bit words with the layout:
126 /// | <word-count>|<opcode> | <operand> | <operand> | ... |
127 /// | <------ word -------> | <-- word --> | <-- word --> | ... |
129 /// For the first word, the 16 high-order bits are the word count of the
130 /// instruction, the 16 low-order bits are the opcode enumerant. The
131 /// instructions then belong to different sections, which must be laid out in
132 /// the particular order as specified in "2.4 Logical Layout of a Module" of
133 /// the SPIR-V spec.
134 class Serializer {
135 public:
136 /// Creates a serializer for the given SPIR-V `module`.
137 explicit Serializer(spirv::ModuleOp module, bool emitDebugInfo = false);
139 /// Serializes the remembered SPIR-V module.
140 LogicalResult serialize();
142 /// Collects the final SPIR-V `binary`.
143 void collect(SmallVectorImpl<uint32_t> &binary);
145 #ifndef NDEBUG
146 /// (For debugging) prints each value and its corresponding result <id>.
147 void printValueIDMap(raw_ostream &os);
148 #endif
150 private:
151 // Note that there are two main categories of methods in this class:
152 // * process*() methods are meant to fully serialize a SPIR-V module entity
153 // (header, type, op, etc.). They update internal vectors containing
154 // different binary sections. They are not meant to be called except the
155 // top-level serialization loop.
156 // * prepare*() methods are meant to be helpers that prepare for serializing
157 // certain entity. They may or may not update internal vectors containing
158 // different binary sections. They are meant to be called among themselves
159 // or by other process*() methods for subtasks.
161 //===--------------------------------------------------------------------===//
162 // <id>
163 //===--------------------------------------------------------------------===//
165 // Note that it is illegal to use id <0> in SPIR-V binary module. Various
166 // methods in this class, if using SPIR-V word (uint32_t) as interface,
167 // check or return id <0> to indicate error in processing.
169 /// Consumes the next unused <id>. This method will never return 0.
170 uint32_t getNextID() { return nextID++; }
172 //===--------------------------------------------------------------------===//
173 // Module structure
174 //===--------------------------------------------------------------------===//
176 uint32_t getSpecConstID(StringRef constName) const {
177 return specConstIDMap.lookup(constName);
180 uint32_t getVariableID(StringRef varName) const {
181 return globalVarIDMap.lookup(varName);
184 uint32_t getFunctionID(StringRef fnName) const {
185 return funcIDMap.lookup(fnName);
188 /// Gets the <id> for the function with the given name. Assigns the next
189 /// available <id> if the function haven't been deserialized.
190 uint32_t getOrCreateFunctionID(StringRef fnName);
192 void processCapability();
194 void processDebugInfo();
196 void processExtension();
198 void processMemoryModel();
200 LogicalResult processConstantOp(spirv::ConstantOp op);
202 LogicalResult processSpecConstantOp(spirv::SpecConstantOp op);
204 LogicalResult
205 processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op);
207 /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA
208 /// value to use with other operations. The SPIR-V spec recommends that
209 /// OpUndef be generated at module level. The serialization generates an
210 /// OpUndef for each type needed at module level.
211 LogicalResult processUndefOp(spirv::UndefOp op);
213 /// Emit OpName for the given `resultID`.
214 LogicalResult processName(uint32_t resultID, StringRef name);
216 /// Processes a SPIR-V function op.
217 LogicalResult processFuncOp(spirv::FuncOp op);
219 LogicalResult processVariableOp(spirv::VariableOp op);
221 /// Process a SPIR-V GlobalVariableOp
222 LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);
224 /// Process attributes that translate to decorations on the result <id>
225 LogicalResult processDecoration(Location loc, uint32_t resultID,
226 NamedAttribute attr);
228 template <typename DType>
229 LogicalResult processTypeDecoration(Location loc, DType type,
230 uint32_t resultId) {
231 return emitError(loc, "unhandled decoration for type:") << type;
234 /// Process member decoration
235 LogicalResult processMemberDecoration(
236 uint32_t structID,
237 const spirv::StructType::MemberDecorationInfo &memberDecorationInfo);
239 //===--------------------------------------------------------------------===//
240 // Types
241 //===--------------------------------------------------------------------===//
243 uint32_t getTypeID(Type type) const { return typeIDMap.lookup(type); }
245 Type getVoidType() { return mlirBuilder.getNoneType(); }
247 bool isVoidType(Type type) const { return type.isa<NoneType>(); }
249 /// Returns true if the given type is a pointer type to a struct in some
250 /// interface storage class.
251 bool isInterfaceStructPtrType(Type type) const;
253 /// Main dispatch method for serializing a type. The result <id> of the
254 /// serialized type will be returned as `typeID`.
255 LogicalResult processType(Location loc, Type type, uint32_t &typeID);
256 LogicalResult processTypeImpl(Location loc, Type type, uint32_t &typeID,
257 llvm::SetVector<StringRef> &serializationCtx);
259 /// Method for preparing basic SPIR-V type serialization. Returns the type's
260 /// opcode and operands for the instruction via `typeEnum` and `operands`.
261 LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID,
262 spirv::Opcode &typeEnum,
263 SmallVectorImpl<uint32_t> &operands,
264 bool &deferSerialization,
265 llvm::SetVector<StringRef> &serializationCtx);
267 LogicalResult prepareFunctionType(Location loc, FunctionType type,
268 spirv::Opcode &typeEnum,
269 SmallVectorImpl<uint32_t> &operands);
271 //===--------------------------------------------------------------------===//
272 // Constant
273 //===--------------------------------------------------------------------===//
275 uint32_t getConstantID(Attribute value) const {
276 return constIDMap.lookup(value);
279 /// Main dispatch method for processing a constant with the given `constType`
280 /// and `valueAttr`. `constType` is needed here because we can interpret the
281 /// `valueAttr` as a different type than the type of `valueAttr` itself; for
282 /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType
283 /// constants.
284 uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr);
286 /// Prepares array attribute serialization. This method emits corresponding
287 /// OpConstant* and returns the result <id> associated with it. Returns 0 if
288 /// failed.
289 uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr);
291 /// Prepares bool/int/float DenseElementsAttr serialization. This method
292 /// iterates the DenseElementsAttr to construct the constant array, and
293 /// returns the result <id> associated with it. Returns 0 if failed. Note
294 /// that the size of `index` must match the rank.
295 /// TODO: Consider to enhance splat elements cases. For splat cases,
296 /// we don't need to loop over all elements, especially when the splat value
297 /// is zero. We can use OpConstantNull when the value is zero.
298 uint32_t prepareDenseElementsConstant(Location loc, Type constType,
299 DenseElementsAttr valueAttr, int dim,
300 MutableArrayRef<uint64_t> index);
302 /// Prepares scalar attribute serialization. This method emits corresponding
303 /// OpConstant* and returns the result <id> associated with it. Returns 0 if
304 /// the attribute is not for a scalar bool/integer/float value. If `isSpec` is
305 /// true, then the constant will be serialized as a specialization constant.
306 uint32_t prepareConstantScalar(Location loc, Attribute valueAttr,
307 bool isSpec = false);
309 uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr,
310 bool isSpec = false);
312 uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr,
313 bool isSpec = false);
315 uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr,
316 bool isSpec = false);
318 //===--------------------------------------------------------------------===//
319 // Control flow
320 //===--------------------------------------------------------------------===//
322 /// Returns the result <id> for the given block.
323 uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(block); }
325 /// Returns the result <id> for the given block. If no <id> has been assigned,
326 /// assigns the next available <id>
327 uint32_t getOrCreateBlockID(Block *block);
329 /// Processes the given `block` and emits SPIR-V instructions for all ops
330 /// inside. Does not emit OpLabel for this block if `omitLabel` is true.
331 /// `actionBeforeTerminator` is a callback that will be invoked before
332 /// handling the terminator op. It can be used to inject the Op*Merge
333 /// instruction if this is a SPIR-V selection/loop header block.
334 LogicalResult
335 processBlock(Block *block, bool omitLabel = false,
336 function_ref<void()> actionBeforeTerminator = nullptr);
338 /// Emits OpPhi instructions for the given block if it has block arguments.
339 LogicalResult emitPhiForBlockArguments(Block *block);
341 LogicalResult processSelectionOp(spirv::SelectionOp selectionOp);
343 LogicalResult processLoopOp(spirv::LoopOp loopOp);
345 LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp);
347 LogicalResult processBranchOp(spirv::BranchOp branchOp);
349 //===--------------------------------------------------------------------===//
350 // Operations
351 //===--------------------------------------------------------------------===//
353 LogicalResult encodeExtensionInstruction(Operation *op,
354 StringRef extensionSetName,
355 uint32_t opcode,
356 ArrayRef<uint32_t> operands);
358 uint32_t getValueID(Value val) const { return valueIDMap.lookup(val); }
360 LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);
362 LogicalResult processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp);
364 /// Main dispatch method for serializing an operation.
365 LogicalResult processOperation(Operation *op);
367 /// Method to dispatch to the serialization function for an operation in
368 /// SPIR-V dialect that is a mirror of an instruction in the SPIR-V spec.
369 /// This is auto-generated from ODS. Dispatch is handled for all operations
370 /// in SPIR-V dialect that have hasOpcode == 1.
371 LogicalResult dispatchToAutogenSerialization(Operation *op);
373 /// Method to serialize an operation in the SPIR-V dialect that is a mirror of
374 /// an instruction in the SPIR-V spec. This is auto generated if hasOpcode ==
375 /// 1 and autogenSerialization == 1 in ODS.
376 template <typename OpTy>
377 LogicalResult processOp(OpTy op) {
378 return op.emitError("unsupported op serialization");
381 //===--------------------------------------------------------------------===//
382 // Utilities
383 //===--------------------------------------------------------------------===//
385 /// Emits an OpDecorate instruction to decorate the given `target` with the
386 /// given `decoration`.
387 LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration,
388 ArrayRef<uint32_t> params = {});
390 /// Emits an OpLine instruction with the given `loc` location information into
391 /// the given `binary` vector.
392 LogicalResult emitDebugLine(SmallVectorImpl<uint32_t> &binary, Location loc);
394 private:
395 /// The SPIR-V module to be serialized.
396 spirv::ModuleOp module;
398 /// An MLIR builder for getting MLIR constructs.
399 mlir::Builder mlirBuilder;
401 /// A flag which indicates if the debuginfo should be emitted.
402 bool emitDebugInfo = false;
404 /// A flag which indicates if the last processed instruction was a merge
405 /// instruction.
406 /// According to SPIR-V spec: "If a branch merge instruction is used, the last
407 /// OpLine in the block must be before its merge instruction".
408 bool lastProcessedWasMergeInst = false;
410 /// The <id> of the OpString instruction, which specifies a file name, for
411 /// use by other debug instructions.
412 uint32_t fileID = 0;
414 /// The next available result <id>.
415 uint32_t nextID = 1;
417 // The following are for different SPIR-V instruction sections. They follow
418 // the logical layout of a SPIR-V module.
420 SmallVector<uint32_t, 4> capabilities;
421 SmallVector<uint32_t, 0> extensions;
422 SmallVector<uint32_t, 0> extendedSets;
423 SmallVector<uint32_t, 3> memoryModel;
424 SmallVector<uint32_t, 0> entryPoints;
425 SmallVector<uint32_t, 4> executionModes;
426 SmallVector<uint32_t, 0> debug;
427 SmallVector<uint32_t, 0> names;
428 SmallVector<uint32_t, 0> decorations;
429 SmallVector<uint32_t, 0> typesGlobalValues;
430 SmallVector<uint32_t, 0> functions;
432 /// Recursive struct references are serialized as OpTypePointer instructions
433 /// to the recursive struct type. However, the OpTypePointer instruction
434 /// cannot be emitted before the recursive struct's OpTypeStruct.
435 /// RecursiveStructPointerInfo stores the data needed to emit such
436 /// OpTypePointer instructions after forward references to such types.
437 struct RecursiveStructPointerInfo {
438 uint32_t pointerTypeID;
439 spirv::StorageClass storageClass;
442 // Maps spirv::StructType to its recursive reference member info.
443 DenseMap<Type, SmallVector<RecursiveStructPointerInfo, 0>>
444 recursiveStructInfos;
446 /// `functionHeader` contains all the instructions that must be in the first
447 /// block in the function, and `functionBody` contains the rest. After
448 /// processing FuncOp, the encoded instructions of a function are appended to
449 /// `functions`. An example of instructions in `functionHeader` in order:
450 /// OpFunction ...
451 /// OpFunctionParameter ...
452 /// OpFunctionParameter ...
453 /// OpLabel ...
454 /// OpVariable ...
455 /// OpVariable ...
456 SmallVector<uint32_t, 0> functionHeader;
457 SmallVector<uint32_t, 0> functionBody;
459 /// Map from type used in SPIR-V module to their <id>s.
460 DenseMap<Type, uint32_t> typeIDMap;
462 /// Map from constant values to their <id>s.
463 DenseMap<Attribute, uint32_t> constIDMap;
465 /// Map from specialization constant names to their <id>s.
466 llvm::StringMap<uint32_t> specConstIDMap;
468 /// Map from GlobalVariableOps name to <id>s.
469 llvm::StringMap<uint32_t> globalVarIDMap;
471 /// Map from FuncOps name to <id>s.
472 llvm::StringMap<uint32_t> funcIDMap;
474 /// Map from blocks to their <id>s.
475 DenseMap<Block *, uint32_t> blockIDMap;
477 /// Map from the Type to the <id> that represents undef value of that type.
478 DenseMap<Type, uint32_t> undefValIDMap;
480 /// Map from results of normal operations to their <id>s.
481 DenseMap<Value, uint32_t> valueIDMap;
483 /// Map from extended instruction set name to <id>s.
484 llvm::StringMap<uint32_t> extendedInstSetIDMap;
486 /// Map from values used in OpPhi instructions to their offset in the
487 /// `functions` section.
489 /// When processing a block with arguments, we need to emit OpPhi
490 /// instructions to record the predecessor block <id>s and the values they
491 /// send to the block in question. But it's not guaranteed all values are
492 /// visited and thus assigned result <id>s. So we need this list to capture
493 /// the offsets into `functions` where a value is used so that we can fix it
494 /// up later after processing all the blocks in a function.
496 /// More concretely, say if we are visiting the following blocks:
498 /// ```mlir
499 /// ^phi(%arg0: i32):
500 /// ...
501 /// ^parent1:
502 /// ...
503 /// spv.Branch ^phi(%val0: i32)
504 /// ^parent2:
505 /// ...
506 /// spv.Branch ^phi(%val1: i32)
507 /// ```
509 /// When we are serializing the `^phi` block, we need to emit at the beginning
510 /// of the block OpPhi instructions which has the following parameters:
512 /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1
513 /// id-for-%val1 id-for-^parent2
515 /// But we don't know the <id> for %val0 and %val1 yet. One way is to visit
516 /// all the blocks twice and use the first visit to assign an <id> to each
517 /// value. But it's paying the overheads just for OpPhi emission. Instead,
518 /// we still visit the blocks once for emission. When we emit the OpPhi
519 /// instructions, we use 0 as a placeholder for the <id>s for %val0 and %val1.
520 /// At the same time, we record their offsets in the emitted binary (which is
521 /// placed inside `functions`) here. And then after emitting all blocks, we
522 /// replace the dummy <id> 0 with the real result <id> by overwriting
523 /// `functions[offset]`.
524 DenseMap<Value, SmallVector<size_t, 1>> deferredPhiValues;
526 } // namespace
528 Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo)
529 : module(module), mlirBuilder(module.getContext()),
530 emitDebugInfo(emitDebugInfo) {}
532 LogicalResult Serializer::serialize() {
533 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
535 if (failed(module.verify()))
536 return failure();
538 // TODO: handle the other sections
539 processCapability();
540 processExtension();
541 processMemoryModel();
542 processDebugInfo();
544 // Iterate over the module body to serialize it. Assumptions are that there is
545 // only one basic block in the moduleOp
546 for (auto &op : module.getBlock()) {
547 if (failed(processOperation(&op))) {
548 return failure();
552 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
553 return success();
556 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
557 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
558 extensions.size() + extendedSets.size() +
559 memoryModel.size() + entryPoints.size() +
560 executionModes.size() + decorations.size() +
561 typesGlobalValues.size() + functions.size();
563 binary.clear();
564 binary.reserve(moduleSize);
566 spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID);
567 binary.append(capabilities.begin(), capabilities.end());
568 binary.append(extensions.begin(), extensions.end());
569 binary.append(extendedSets.begin(), extendedSets.end());
570 binary.append(memoryModel.begin(), memoryModel.end());
571 binary.append(entryPoints.begin(), entryPoints.end());
572 binary.append(executionModes.begin(), executionModes.end());
573 binary.append(debug.begin(), debug.end());
574 binary.append(names.begin(), names.end());
575 binary.append(decorations.begin(), decorations.end());
576 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
577 binary.append(functions.begin(), functions.end());
580 #ifndef NDEBUG
581 void Serializer::printValueIDMap(raw_ostream &os) {
582 os << "\n= Value <id> Map =\n\n";
583 for (auto valueIDPair : valueIDMap) {
584 Value val = valueIDPair.first;
585 os << " " << val << " "
586 << "id = " << valueIDPair.second << ' ';
587 if (auto *op = val.getDefiningOp()) {
588 os << "from op '" << op->getName() << "'";
589 } else if (auto arg = val.dyn_cast<BlockArgument>()) {
590 Block *block = arg.getOwner();
591 os << "from argument of block " << block << ' ';
592 os << " in op '" << block->getParentOp()->getName() << "'";
594 os << '\n';
597 #endif
599 //===----------------------------------------------------------------------===//
600 // Module structure
601 //===----------------------------------------------------------------------===//
603 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
604 auto funcID = funcIDMap.lookup(fnName);
605 if (!funcID) {
606 funcID = getNextID();
607 funcIDMap[fnName] = funcID;
609 return funcID;
612 void Serializer::processCapability() {
613 for (auto cap : module.vce_triple()->getCapabilities())
614 encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
615 {static_cast<uint32_t>(cap)});
618 void Serializer::processDebugInfo() {
619 if (!emitDebugInfo)
620 return;
621 auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>();
622 auto fileName = fileLoc ? fileLoc.getFilename() : "<unknown>";
623 fileID = getNextID();
624 SmallVector<uint32_t, 16> operands;
625 operands.push_back(fileID);
626 spirv::encodeStringLiteralInto(operands, fileName);
627 encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
628 // TODO: Encode more debug instructions.
631 void Serializer::processExtension() {
632 llvm::SmallVector<uint32_t, 16> extName;
633 for (spirv::Extension ext : module.vce_triple()->getExtensions()) {
634 extName.clear();
635 spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
636 encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
640 void Serializer::processMemoryModel() {
641 uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt();
642 uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt();
644 encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
647 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
648 if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) {
649 valueIDMap[op.getResult()] = resultID;
650 return success();
652 return failure();
655 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
656 if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(),
657 /*isSpec=*/true)) {
658 // Emit the OpDecorate instruction for SpecId.
659 if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
660 auto val = static_cast<uint32_t>(specID.getInt());
661 emitDecoration(resultID, spirv::Decoration::SpecId, {val});
664 specConstIDMap[op.sym_name()] = resultID;
665 return processName(resultID, op.sym_name());
667 return failure();
670 LogicalResult
671 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
672 uint32_t typeID = 0;
673 if (failed(processType(op.getLoc(), op.type(), typeID))) {
674 return failure();
677 auto resultID = getNextID();
679 SmallVector<uint32_t, 8> operands;
680 operands.push_back(typeID);
681 operands.push_back(resultID);
683 auto constituents = op.constituents();
685 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
686 auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
688 auto constituentName = constituent.getValue();
689 auto constituentID = getSpecConstID(constituentName);
691 if (!constituentID) {
692 return op.emitError("unknown result <id> for specialization constant ")
693 << constituentName;
696 operands.push_back(constituentID);
699 encodeInstructionInto(typesGlobalValues,
700 spirv::Opcode::OpSpecConstantComposite, operands);
701 specConstIDMap[op.sym_name()] = resultID;
703 return processName(resultID, op.sym_name());
706 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
707 auto undefType = op.getType();
708 auto &id = undefValIDMap[undefType];
709 if (!id) {
710 id = getNextID();
711 uint32_t typeID = 0;
712 if (failed(processType(op.getLoc(), undefType, typeID)) ||
713 failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
714 {typeID, id}))) {
715 return failure();
718 valueIDMap[op.getResult()] = id;
719 return success();
722 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
723 NamedAttribute attr) {
724 auto attrName = attr.first.strref();
725 auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true);
726 auto decoration = spirv::symbolizeDecoration(decorationName);
727 if (!decoration) {
728 return emitError(
729 loc, "non-argument attributes expected to have snake-case-ified "
730 "decoration name, unhandled attribute with name : ")
731 << attrName;
733 SmallVector<uint32_t, 1> args;
734 switch (decoration.getValue()) {
735 case spirv::Decoration::Binding:
736 case spirv::Decoration::DescriptorSet:
737 case spirv::Decoration::Location:
738 if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) {
739 args.push_back(intAttr.getValue().getZExtValue());
740 break;
742 return emitError(loc, "expected integer attribute for ") << attrName;
743 case spirv::Decoration::BuiltIn:
744 if (auto strAttr = attr.second.dyn_cast<StringAttr>()) {
745 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
746 if (enumVal) {
747 args.push_back(static_cast<uint32_t>(enumVal.getValue()));
748 break;
750 return emitError(loc, "invalid ")
751 << attrName << " attribute " << strAttr.getValue();
753 return emitError(loc, "expected string attribute for ") << attrName;
754 case spirv::Decoration::Aliased:
755 case spirv::Decoration::Flat:
756 case spirv::Decoration::NonReadable:
757 case spirv::Decoration::NonWritable:
758 case spirv::Decoration::NoPerspective:
759 case spirv::Decoration::Restrict:
760 // For unit attributes, the args list has no values so we do nothing
761 if (auto unitAttr = attr.second.dyn_cast<UnitAttr>())
762 break;
763 return emitError(loc, "expected unit attribute for ") << attrName;
764 default:
765 return emitError(loc, "unhandled decoration ") << decorationName;
767 return emitDecoration(resultID, decoration.getValue(), args);
770 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
771 assert(!name.empty() && "unexpected empty string for OpName");
773 SmallVector<uint32_t, 4> nameOperands;
774 nameOperands.push_back(resultID);
775 if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) {
776 return failure();
778 return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
781 namespace {
782 template <>
783 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
784 Location loc, spirv::ArrayType type, uint32_t resultID) {
785 if (unsigned stride = type.getArrayStride()) {
786 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
787 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
789 return success();
792 template <>
793 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
794 Location Loc, spirv::RuntimeArrayType type, uint32_t resultID) {
795 if (unsigned stride = type.getArrayStride()) {
796 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
797 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
799 return success();
802 LogicalResult Serializer::processMemberDecoration(
803 uint32_t structID,
804 const spirv::StructType::MemberDecorationInfo &memberDecoration) {
805 SmallVector<uint32_t, 4> args(
806 {structID, memberDecoration.memberIndex,
807 static_cast<uint32_t>(memberDecoration.decoration)});
808 if (memberDecoration.hasValue) {
809 args.push_back(memberDecoration.decorationValue);
811 return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate,
812 args);
814 } // namespace
816 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
817 LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
818 assert(functionHeader.empty() && functionBody.empty());
820 uint32_t fnTypeID = 0;
821 // Generate type of the function.
822 processType(op.getLoc(), op.getType(), fnTypeID);
824 // Add the function definition.
825 SmallVector<uint32_t, 4> operands;
826 uint32_t resTypeID = 0;
827 auto resultTypes = op.getType().getResults();
828 if (resultTypes.size() > 1) {
829 return op.emitError("cannot serialize function with multiple return types");
831 if (failed(processType(op.getLoc(),
832 (resultTypes.empty() ? getVoidType() : resultTypes[0]),
833 resTypeID))) {
834 return failure();
836 operands.push_back(resTypeID);
837 auto funcID = getOrCreateFunctionID(op.getName());
838 operands.push_back(funcID);
839 operands.push_back(static_cast<uint32_t>(op.function_control()));
840 operands.push_back(fnTypeID);
841 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
843 // Add function name.
844 if (failed(processName(funcID, op.getName()))) {
845 return failure();
848 // Declare the parameters.
849 for (auto arg : op.getArguments()) {
850 uint32_t argTypeID = 0;
851 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
852 return failure();
854 auto argValueID = getNextID();
855 valueIDMap[arg] = argValueID;
856 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
857 {argTypeID, argValueID});
860 // Process the body.
861 if (op.isExternal()) {
862 return op.emitError("external function is unhandled");
865 // Some instructions (e.g., OpVariable) in a function must be in the first
866 // block in the function. These instructions will be put in functionHeader.
867 // Thus, we put the label in functionHeader first, and omit it from the first
868 // block.
869 encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
870 {getOrCreateBlockID(&op.front())});
871 processBlock(&op.front(), /*omitLabel=*/true);
872 if (failed(visitInPrettyBlockOrder(
873 &op.front(), [&](Block *block) { return processBlock(block); },
874 /*skipHeader=*/true))) {
875 return failure();
878 // There might be OpPhi instructions who have value references needing to fix.
879 for (auto deferredValue : deferredPhiValues) {
880 Value value = deferredValue.first;
881 uint32_t id = getValueID(value);
882 LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
883 << " to id = " << id << '\n');
884 assert(id && "OpPhi references undefined value!");
885 for (size_t offset : deferredValue.second)
886 functionBody[offset] = id;
888 deferredPhiValues.clear();
890 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
891 << "' --\n");
892 // Insert OpFunctionEnd.
893 if (failed(encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd,
894 {}))) {
895 return failure();
898 functions.append(functionHeader.begin(), functionHeader.end());
899 functions.append(functionBody.begin(), functionBody.end());
900 functionHeader.clear();
901 functionBody.clear();
903 return success();
906 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
907 SmallVector<uint32_t, 4> operands;
908 SmallVector<StringRef, 2> elidedAttrs;
909 uint32_t resultID = 0;
910 uint32_t resultTypeID = 0;
911 if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
912 return failure();
914 operands.push_back(resultTypeID);
915 resultID = getNextID();
916 valueIDMap[op.getResult()] = resultID;
917 operands.push_back(resultID);
918 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
919 if (attr) {
920 operands.push_back(static_cast<uint32_t>(
921 attr.cast<IntegerAttr>().getValue().getZExtValue()));
923 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
924 for (auto arg : op.getODSOperands(0)) {
925 auto argID = getValueID(arg);
926 if (!argID) {
927 return emitError(op.getLoc(), "operand 0 has a use before def");
929 operands.push_back(argID);
931 emitDebugLine(functionHeader, op.getLoc());
932 encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);
933 for (auto attr : op->getAttrs()) {
934 if (llvm::any_of(elidedAttrs,
935 [&](StringRef elided) { return attr.first == elided; })) {
936 continue;
938 if (failed(processDecoration(op.getLoc(), resultID, attr))) {
939 return failure();
942 return success();
945 LogicalResult
946 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
947 // Get TypeID.
948 uint32_t resultTypeID = 0;
949 SmallVector<StringRef, 4> elidedAttrs;
950 if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
951 return failure();
954 if (isInterfaceStructPtrType(varOp.type())) {
955 auto structType = varOp.type()
956 .cast<spirv::PointerType>()
957 .getPointeeType()
958 .cast<spirv::StructType>();
959 if (failed(
960 emitDecoration(getTypeID(structType), spirv::Decoration::Block))) {
961 return varOp.emitError("cannot decorate ")
962 << structType << " with Block decoration";
966 elidedAttrs.push_back("type");
967 SmallVector<uint32_t, 4> operands;
968 operands.push_back(resultTypeID);
969 auto resultID = getNextID();
971 // Encode the name.
972 auto varName = varOp.sym_name();
973 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
974 if (failed(processName(resultID, varName))) {
975 return failure();
977 globalVarIDMap[varName] = resultID;
978 operands.push_back(resultID);
980 // Encode StorageClass.
981 operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
983 // Encode initialization.
984 if (auto initializer = varOp.initializer()) {
985 auto initializerID = getVariableID(initializer.getValue());
986 if (!initializerID) {
987 return emitError(varOp.getLoc(),
988 "invalid usage of undefined variable as initializer");
990 operands.push_back(initializerID);
991 elidedAttrs.push_back("initializer");
994 emitDebugLine(typesGlobalValues, varOp.getLoc());
995 if (failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable,
996 operands))) {
997 elidedAttrs.push_back("initializer");
998 return failure();
1001 // Encode decorations.
1002 for (auto attr : varOp->getAttrs()) {
1003 if (llvm::any_of(elidedAttrs,
1004 [&](StringRef elided) { return attr.first == elided; })) {
1005 continue;
1007 if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
1008 return failure();
1011 return success();
1014 //===----------------------------------------------------------------------===//
1015 // Type
1016 //===----------------------------------------------------------------------===//
1018 // According to the SPIR-V spec "Validation Rules for Shader Capabilities":
1019 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
1020 // PushConstant Storage Classes must be explicitly laid out."
1021 bool Serializer::isInterfaceStructPtrType(Type type) const {
1022 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
1023 switch (ptrType.getStorageClass()) {
1024 case spirv::StorageClass::PhysicalStorageBuffer:
1025 case spirv::StorageClass::PushConstant:
1026 case spirv::StorageClass::StorageBuffer:
1027 case spirv::StorageClass::Uniform:
1028 return ptrType.getPointeeType().isa<spirv::StructType>();
1029 default:
1030 break;
1033 return false;
1036 LogicalResult Serializer::processType(Location loc, Type type,
1037 uint32_t &typeID) {
1038 // Maintains a set of names for nested identified struct types. This is used
1039 // to properly serialize resursive references.
1040 llvm::SetVector<StringRef> serializationCtx;
1041 return processTypeImpl(loc, type, typeID, serializationCtx);
1044 LogicalResult
1045 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
1046 llvm::SetVector<StringRef> &serializationCtx) {
1047 typeID = getTypeID(type);
1048 if (typeID) {
1049 return success();
1051 typeID = getNextID();
1052 SmallVector<uint32_t, 4> operands;
1054 operands.push_back(typeID);
1055 auto typeEnum = spirv::Opcode::OpTypeVoid;
1056 bool deferSerialization = false;
1058 if ((type.isa<FunctionType>() &&
1059 succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum,
1060 operands))) ||
1061 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
1062 deferSerialization, serializationCtx))) {
1063 if (deferSerialization)
1064 return success();
1066 typeIDMap[type] = typeID;
1068 if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands)))
1069 return failure();
1071 if (recursiveStructInfos.count(type) != 0) {
1072 // This recursive struct type is emitted already, now the OpTypePointer
1073 // instructions referring to recursive references are emitted as well.
1074 for (auto &ptrInfo : recursiveStructInfos[type]) {
1075 // TODO: This might not work if more than 1 recursive reference is
1076 // present in the struct.
1077 SmallVector<uint32_t, 4> ptrOperands;
1078 ptrOperands.push_back(ptrInfo.pointerTypeID);
1079 ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
1080 ptrOperands.push_back(typeIDMap[type]);
1082 if (failed(encodeInstructionInto(
1083 typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands)))
1084 return failure();
1087 recursiveStructInfos[type].clear();
1090 return success();
1093 return failure();
1096 LogicalResult Serializer::prepareBasicType(
1097 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
1098 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
1099 llvm::SetVector<StringRef> &serializationCtx) {
1100 deferSerialization = false;
1102 if (isVoidType(type)) {
1103 typeEnum = spirv::Opcode::OpTypeVoid;
1104 return success();
1107 if (auto intType = type.dyn_cast<IntegerType>()) {
1108 if (intType.getWidth() == 1) {
1109 typeEnum = spirv::Opcode::OpTypeBool;
1110 return success();
1113 typeEnum = spirv::Opcode::OpTypeInt;
1114 operands.push_back(intType.getWidth());
1115 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
1116 // to preserve or validate.
1117 // 0 indicates unsigned, or no signedness semantics
1118 // 1 indicates signed semantics."
1119 operands.push_back(intType.isSigned() ? 1 : 0);
1120 return success();
1123 if (auto floatType = type.dyn_cast<FloatType>()) {
1124 typeEnum = spirv::Opcode::OpTypeFloat;
1125 operands.push_back(floatType.getWidth());
1126 return success();
1129 if (auto vectorType = type.dyn_cast<VectorType>()) {
1130 uint32_t elementTypeID = 0;
1131 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
1132 serializationCtx))) {
1133 return failure();
1135 typeEnum = spirv::Opcode::OpTypeVector;
1136 operands.push_back(elementTypeID);
1137 operands.push_back(vectorType.getNumElements());
1138 return success();
1141 if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
1142 typeEnum = spirv::Opcode::OpTypeArray;
1143 uint32_t elementTypeID = 0;
1144 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
1145 serializationCtx))) {
1146 return failure();
1148 operands.push_back(elementTypeID);
1149 if (auto elementCountID = prepareConstantInt(
1150 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
1151 operands.push_back(elementCountID);
1153 return processTypeDecoration(loc, arrayType, resultID);
1156 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
1157 uint32_t pointeeTypeID = 0;
1158 spirv::StructType pointeeStruct =
1159 ptrType.getPointeeType().dyn_cast<spirv::StructType>();
1161 if (pointeeStruct && pointeeStruct.isIdentified() &&
1162 serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
1163 // A recursive reference to an enclosing struct is found.
1165 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
1166 // class as operands.
1167 SmallVector<uint32_t, 2> forwardPtrOperands;
1168 forwardPtrOperands.push_back(resultID);
1169 forwardPtrOperands.push_back(
1170 static_cast<uint32_t>(ptrType.getStorageClass()));
1172 encodeInstructionInto(typesGlobalValues,
1173 spirv::Opcode::OpTypeForwardPointer,
1174 forwardPtrOperands);
1176 // 2. Find the pointee (enclosing) struct.
1177 auto structType = spirv::StructType::getIdentified(
1178 module.getContext(), pointeeStruct.getIdentifier());
1180 if (!structType)
1181 return failure();
1183 // 3. Mark the OpTypePointer that is supposed to be emitted by this call
1184 // as deferred.
1185 deferSerialization = true;
1187 // 4. Record the info needed to emit the deferred OpTypePointer
1188 // instruction when the enclosing struct is completely serialized.
1189 recursiveStructInfos[structType].push_back(
1190 {resultID, ptrType.getStorageClass()});
1191 } else {
1192 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
1193 serializationCtx)))
1194 return failure();
1197 typeEnum = spirv::Opcode::OpTypePointer;
1198 operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
1199 operands.push_back(pointeeTypeID);
1200 return success();
1203 if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
1204 uint32_t elementTypeID = 0;
1205 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
1206 elementTypeID, serializationCtx))) {
1207 return failure();
1209 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
1210 operands.push_back(elementTypeID);
1211 return processTypeDecoration(loc, runtimeArrayType, resultID);
1214 if (auto structType = type.dyn_cast<spirv::StructType>()) {
1215 if (structType.isIdentified()) {
1216 processName(resultID, structType.getIdentifier());
1217 serializationCtx.insert(structType.getIdentifier());
1220 bool hasOffset = structType.hasOffset();
1221 for (auto elementIndex :
1222 llvm::seq<uint32_t>(0, structType.getNumElements())) {
1223 uint32_t elementTypeID = 0;
1224 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
1225 elementTypeID, serializationCtx))) {
1226 return failure();
1228 operands.push_back(elementTypeID);
1229 if (hasOffset) {
1230 // Decorate each struct member with an offset
1231 spirv::StructType::MemberDecorationInfo offsetDecoration{
1232 elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
1233 static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
1234 if (failed(processMemberDecoration(resultID, offsetDecoration))) {
1235 return emitError(loc, "cannot decorate ")
1236 << elementIndex << "-th member of " << structType
1237 << " with its offset";
1241 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
1242 structType.getMemberDecorations(memberDecorations);
1244 for (auto &memberDecoration : memberDecorations) {
1245 if (failed(processMemberDecoration(resultID, memberDecoration))) {
1246 return emitError(loc, "cannot decorate ")
1247 << static_cast<uint32_t>(memberDecoration.memberIndex)
1248 << "-th member of " << structType << " with "
1249 << stringifyDecoration(memberDecoration.decoration);
1253 typeEnum = spirv::Opcode::OpTypeStruct;
1255 if (structType.isIdentified())
1256 serializationCtx.remove(structType.getIdentifier());
1258 return success();
1261 if (auto cooperativeMatrixType =
1262 type.dyn_cast<spirv::CooperativeMatrixNVType>()) {
1263 uint32_t elementTypeID = 0;
1264 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
1265 elementTypeID, serializationCtx))) {
1266 return failure();
1268 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
1269 auto getConstantOp = [&](uint32_t id) {
1270 auto attr = IntegerAttr::get(IntegerType::get(32, type.getContext()), id);
1271 return prepareConstantInt(loc, attr);
1273 operands.push_back(elementTypeID);
1274 operands.push_back(
1275 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
1276 operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
1277 operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
1278 return success();
1281 if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
1282 uint32_t elementTypeID = 0;
1283 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
1284 serializationCtx))) {
1285 return failure();
1287 typeEnum = spirv::Opcode::OpTypeMatrix;
1288 operands.push_back(elementTypeID);
1289 operands.push_back(matrixType.getNumColumns());
1290 return success();
1293 // TODO: Handle other types.
1294 return emitError(loc, "unhandled type in serialization: ") << type;
1297 LogicalResult
1298 Serializer::prepareFunctionType(Location loc, FunctionType type,
1299 spirv::Opcode &typeEnum,
1300 SmallVectorImpl<uint32_t> &operands) {
1301 typeEnum = spirv::Opcode::OpTypeFunction;
1302 assert(type.getNumResults() <= 1 &&
1303 "serialization supports only a single return value");
1304 uint32_t resultID = 0;
1305 if (failed(processType(
1306 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
1307 resultID))) {
1308 return failure();
1310 operands.push_back(resultID);
1311 for (auto &res : type.getInputs()) {
1312 uint32_t argTypeID = 0;
1313 if (failed(processType(loc, res, argTypeID))) {
1314 return failure();
1316 operands.push_back(argTypeID);
1318 return success();
1321 //===----------------------------------------------------------------------===//
1322 // Constant
1323 //===----------------------------------------------------------------------===//
1325 uint32_t Serializer::prepareConstant(Location loc, Type constType,
1326 Attribute valueAttr) {
1327 if (auto id = prepareConstantScalar(loc, valueAttr)) {
1328 return id;
1331 // This is a composite literal. We need to handle each component separately
1332 // and then emit an OpConstantComposite for the whole.
1334 if (auto id = getConstantID(valueAttr)) {
1335 return id;
1338 uint32_t typeID = 0;
1339 if (failed(processType(loc, constType, typeID))) {
1340 return 0;
1343 uint32_t resultID = 0;
1344 if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) {
1345 int rank = attr.getType().dyn_cast<ShapedType>().getRank();
1346 SmallVector<uint64_t, 4> index(rank);
1347 resultID = prepareDenseElementsConstant(loc, constType, attr,
1348 /*dim=*/0, index);
1349 } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
1350 resultID = prepareArrayConstant(loc, constType, arrayAttr);
1353 if (resultID == 0) {
1354 emitError(loc, "cannot serialize attribute: ") << valueAttr;
1355 return 0;
1358 constIDMap[valueAttr] = resultID;
1359 return resultID;
1362 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
1363 ArrayAttr attr) {
1364 uint32_t typeID = 0;
1365 if (failed(processType(loc, constType, typeID))) {
1366 return 0;
1369 uint32_t resultID = getNextID();
1370 SmallVector<uint32_t, 4> operands = {typeID, resultID};
1371 operands.reserve(attr.size() + 2);
1372 auto elementType = constType.cast<spirv::ArrayType>().getElementType();
1373 for (Attribute elementAttr : attr) {
1374 if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
1375 operands.push_back(elementID);
1376 } else {
1377 return 0;
1380 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
1381 encodeInstructionInto(typesGlobalValues, opcode, operands);
1383 return resultID;
1386 // TODO: Turn the below function into iterative function, instead of
1387 // recursive function.
1388 uint32_t
1389 Serializer::prepareDenseElementsConstant(Location loc, Type constType,
1390 DenseElementsAttr valueAttr, int dim,
1391 MutableArrayRef<uint64_t> index) {
1392 auto shapedType = valueAttr.getType().dyn_cast<ShapedType>();
1393 assert(dim <= shapedType.getRank());
1394 if (shapedType.getRank() == dim) {
1395 if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
1396 return attr.getType().getElementType().isInteger(1)
1397 ? prepareConstantBool(loc, attr.getValue<BoolAttr>(index))
1398 : prepareConstantInt(loc, attr.getValue<IntegerAttr>(index));
1400 if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
1401 return prepareConstantFp(loc, attr.getValue<FloatAttr>(index));
1403 return 0;
1406 uint32_t typeID = 0;
1407 if (failed(processType(loc, constType, typeID))) {
1408 return 0;
1411 uint32_t resultID = getNextID();
1412 SmallVector<uint32_t, 4> operands = {typeID, resultID};
1413 operands.reserve(shapedType.getDimSize(dim) + 2);
1414 auto elementType = constType.cast<spirv::CompositeType>().getElementType(0);
1415 for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
1416 index[dim] = i;
1417 if (auto elementID = prepareDenseElementsConstant(
1418 loc, elementType, valueAttr, dim + 1, index)) {
1419 operands.push_back(elementID);
1420 } else {
1421 return 0;
1424 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
1425 encodeInstructionInto(typesGlobalValues, opcode, operands);
1427 return resultID;
1430 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
1431 bool isSpec) {
1432 if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
1433 return prepareConstantFp(loc, floatAttr, isSpec);
1435 if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
1436 return prepareConstantBool(loc, boolAttr, isSpec);
1438 if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
1439 return prepareConstantInt(loc, intAttr, isSpec);
1442 return 0;
1445 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
1446 bool isSpec) {
1447 if (!isSpec) {
1448 // We can de-duplicate normal constants, but not specialization constants.
1449 if (auto id = getConstantID(boolAttr)) {
1450 return id;
1454 // Process the type for this bool literal
1455 uint32_t typeID = 0;
1456 if (failed(processType(loc, boolAttr.getType(), typeID))) {
1457 return 0;
1460 auto resultID = getNextID();
1461 auto opcode = boolAttr.getValue()
1462 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
1463 : spirv::Opcode::OpConstantTrue)
1464 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
1465 : spirv::Opcode::OpConstantFalse);
1466 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
1468 if (!isSpec) {
1469 constIDMap[boolAttr] = resultID;
1471 return resultID;
1474 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
1475 bool isSpec) {
1476 if (!isSpec) {
1477 // We can de-duplicate normal constants, but not specialization constants.
1478 if (auto id = getConstantID(intAttr)) {
1479 return id;
1483 // Process the type for this integer literal
1484 uint32_t typeID = 0;
1485 if (failed(processType(loc, intAttr.getType(), typeID))) {
1486 return 0;
1489 auto resultID = getNextID();
1490 APInt value = intAttr.getValue();
1491 unsigned bitwidth = value.getBitWidth();
1492 bool isSigned = value.isSignedIntN(bitwidth);
1494 auto opcode =
1495 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1497 // According to SPIR-V spec, "When the type's bit width is less than 32-bits,
1498 // the literal's value appears in the low-order bits of the word, and the
1499 // high-order bits must be 0 for a floating-point type, or 0 for an integer
1500 // type with Signedness of 0, or sign extended when Signedness is 1."
1501 if (bitwidth == 32 || bitwidth == 16) {
1502 uint32_t word = 0;
1503 if (isSigned) {
1504 word = static_cast<int32_t>(value.getSExtValue());
1505 } else {
1506 word = static_cast<uint32_t>(value.getZExtValue());
1508 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1510 // According to SPIR-V spec: "When the type's bit width is larger than one
1511 // word, the literal’s low-order words appear first."
1512 else if (bitwidth == 64) {
1513 struct DoubleWord {
1514 uint32_t word1;
1515 uint32_t word2;
1516 } words;
1517 if (isSigned) {
1518 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
1519 } else {
1520 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
1522 encodeInstructionInto(typesGlobalValues, opcode,
1523 {typeID, resultID, words.word1, words.word2});
1524 } else {
1525 std::string valueStr;
1526 llvm::raw_string_ostream rss(valueStr);
1527 value.print(rss, /*isSigned=*/false);
1529 emitError(loc, "cannot serialize ")
1530 << bitwidth << "-bit integer literal: " << rss.str();
1531 return 0;
1534 if (!isSpec) {
1535 constIDMap[intAttr] = resultID;
1537 return resultID;
1540 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
1541 bool isSpec) {
1542 if (!isSpec) {
1543 // We can de-duplicate normal constants, but not specialization constants.
1544 if (auto id = getConstantID(floatAttr)) {
1545 return id;
1549 // Process the type for this float literal
1550 uint32_t typeID = 0;
1551 if (failed(processType(loc, floatAttr.getType(), typeID))) {
1552 return 0;
1555 auto resultID = getNextID();
1556 APFloat value = floatAttr.getValue();
1557 APInt intValue = value.bitcastToAPInt();
1559 auto opcode =
1560 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1562 if (&value.getSemantics() == &APFloat::IEEEsingle()) {
1563 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1564 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1565 } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
1566 struct DoubleWord {
1567 uint32_t word1;
1568 uint32_t word2;
1569 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1570 encodeInstructionInto(typesGlobalValues, opcode,
1571 {typeID, resultID, words.word1, words.word2});
1572 } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
1573 uint32_t word =
1574 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
1575 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1576 } else {
1577 std::string valueStr;
1578 llvm::raw_string_ostream rss(valueStr);
1579 value.print(rss);
1581 emitError(loc, "cannot serialize ")
1582 << floatAttr.getType() << "-typed float literal: " << rss.str();
1583 return 0;
1586 if (!isSpec) {
1587 constIDMap[floatAttr] = resultID;
1589 return resultID;
1592 //===----------------------------------------------------------------------===//
1593 // Control flow
1594 //===----------------------------------------------------------------------===//
1596 uint32_t Serializer::getOrCreateBlockID(Block *block) {
1597 if (uint32_t id = getBlockID(block))
1598 return id;
1599 return blockIDMap[block] = getNextID();
1602 LogicalResult
1603 Serializer::processBlock(Block *block, bool omitLabel,
1604 function_ref<void()> actionBeforeTerminator) {
1605 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
1606 LLVM_DEBUG(block->print(llvm::dbgs()));
1607 LLVM_DEBUG(llvm::dbgs() << '\n');
1608 if (!omitLabel) {
1609 uint32_t blockID = getOrCreateBlockID(block);
1610 LLVM_DEBUG(llvm::dbgs()
1611 << "[block] " << block << " (id = " << blockID << ")\n");
1613 // Emit OpLabel for this block.
1614 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1617 // Emit OpPhi instructions for block arguments, if any.
1618 if (failed(emitPhiForBlockArguments(block)))
1619 return failure();
1621 // Process each op in this block except the terminator.
1622 for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) {
1623 if (failed(processOperation(&op)))
1624 return failure();
1627 // Process the terminator.
1628 if (actionBeforeTerminator)
1629 actionBeforeTerminator();
1630 if (failed(processOperation(&block->back())))
1631 return failure();
1633 return success();
1636 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
1637 // Nothing to do if this block has no arguments or it's the entry block, which
1638 // always has the same arguments as the function signature.
1639 if (block->args_empty() || block->isEntryBlock())
1640 return success();
1642 // If the block has arguments, we need to create SPIR-V OpPhi instructions.
1643 // A SPIR-V OpPhi instruction is of the syntax:
1644 // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
1645 // So we need to collect all predecessor blocks and the arguments they send
1646 // to this block.
1647 SmallVector<std::pair<Block *, Operation::operand_iterator>, 4> predecessors;
1648 for (Block *predecessor : block->getPredecessors()) {
1649 auto *terminator = predecessor->getTerminator();
1650 // The predecessor here is the immediate one according to MLIR's IR
1651 // structure. It does not directly map to the incoming parent block for the
1652 // OpPhi instructions at SPIR-V binary level. This is because structured
1653 // control flow ops are serialized to multiple SPIR-V blocks. If there is a
1654 // spv.selection/spv.loop op in the MLIR predecessor block, the branch op
1655 // jumping to the OpPhi's block then resides in the previous structured
1656 // control flow op's merge block.
1657 predecessor = getPhiIncomingBlock(predecessor);
1658 if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1659 predecessors.emplace_back(predecessor, branchOp.operand_begin());
1660 } else {
1661 return terminator->emitError("unimplemented terminator for Phi creation");
1665 // Then create OpPhi instruction for each of the block argument.
1666 for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
1667 BlockArgument arg = block->getArgument(argIndex);
1669 // Get the type <id> and result <id> for this OpPhi instruction.
1670 uint32_t phiTypeID = 0;
1671 if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
1672 return failure();
1673 uint32_t phiID = getNextID();
1675 LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
1676 << arg << " (id = " << phiID << ")\n");
1678 // Prepare the (value <id>, parent block <id>) pairs.
1679 SmallVector<uint32_t, 8> phiArgs;
1680 phiArgs.push_back(phiTypeID);
1681 phiArgs.push_back(phiID);
1683 for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1684 Value value = *(predecessors[predIndex].second + argIndex);
1685 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1686 LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
1687 << ") value " << value << ' ');
1688 // Each pair is a value <id> ...
1689 uint32_t valueId = getValueID(value);
1690 if (valueId == 0) {
1691 // The op generating this value hasn't been visited yet so we don't have
1692 // an <id> assigned yet. Record this to fix up later.
1693 LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
1694 deferredPhiValues[value].push_back(functionBody.size() + 1 +
1695 phiArgs.size());
1696 } else {
1697 LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
1699 phiArgs.push_back(valueId);
1700 // ... and a parent block <id>.
1701 phiArgs.push_back(predBlockId);
1704 encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
1705 valueIDMap[arg] = phiID;
1708 return success();
1711 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
1712 // Assign <id>s to all blocks so that branches inside the SelectionOp can
1713 // resolve properly.
1714 auto &body = selectionOp.body();
1715 for (Block &block : body)
1716 getOrCreateBlockID(&block);
1718 auto *headerBlock = selectionOp.getHeaderBlock();
1719 auto *mergeBlock = selectionOp.getMergeBlock();
1720 auto mergeID = getBlockID(mergeBlock);
1721 auto loc = selectionOp.getLoc();
1723 // Emit the selection header block, which dominates all other blocks, first.
1724 // We need to emit an OpSelectionMerge instruction before the selection header
1725 // block's terminator.
1726 auto emitSelectionMerge = [&]() {
1727 emitDebugLine(functionBody, loc);
1728 lastProcessedWasMergeInst = true;
1729 encodeInstructionInto(
1730 functionBody, spirv::Opcode::OpSelectionMerge,
1731 {mergeID, static_cast<uint32_t>(selectionOp.selection_control())});
1733 // For structured selection, we cannot have blocks in the selection construct
1734 // branching to the selection header block. Entering the selection (and
1735 // reaching the selection header) must be from the block containing the
1736 // spv.selection op. If there are ops ahead of the spv.selection op in the
1737 // block, we can "merge" them into the selection header. So here we don't need
1738 // to emit a separate block; just continue with the existing block.
1739 if (failed(processBlock(headerBlock, /*omitLabel=*/true, emitSelectionMerge)))
1740 return failure();
1742 // Process all blocks with a depth-first visitor starting from the header
1743 // block. The selection header block and merge block are skipped by this
1744 // visitor.
1745 if (failed(visitInPrettyBlockOrder(
1746 headerBlock, [&](Block *block) { return processBlock(block); },
1747 /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
1748 return failure();
1750 // There is nothing to do for the merge block in the selection, which just
1751 // contains a spv.mlir.merge op, itself. But we need to have an OpLabel
1752 // instruction to start a new SPIR-V block for ops following this SelectionOp.
1753 // The block should use the <id> for the merge block.
1754 return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
1757 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
1758 // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
1759 // properly. We don't need to assign for the entry block, which is just for
1760 // satisfying MLIR region's structural requirement.
1761 auto &body = loopOp.body();
1762 for (Block &block :
1763 llvm::make_range(std::next(body.begin(), 1), body.end())) {
1764 getOrCreateBlockID(&block);
1766 auto *headerBlock = loopOp.getHeaderBlock();
1767 auto *continueBlock = loopOp.getContinueBlock();
1768 auto *mergeBlock = loopOp.getMergeBlock();
1769 auto headerID = getBlockID(headerBlock);
1770 auto continueID = getBlockID(continueBlock);
1771 auto mergeID = getBlockID(mergeBlock);
1772 auto loc = loopOp.getLoc();
1774 // This LoopOp is in some MLIR block with preceding and following ops. In the
1775 // binary format, it should reside in separate SPIR-V blocks from its
1776 // preceding and following ops. So we need to emit unconditional branches to
1777 // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
1778 // afterwards.
1779 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
1781 // LoopOp's entry block is just there for satisfying MLIR's structural
1782 // requirements so we omit it and start serialization from the loop header
1783 // block.
1785 // Emit the loop header block, which dominates all other blocks, first. We
1786 // need to emit an OpLoopMerge instruction before the loop header block's
1787 // terminator.
1788 auto emitLoopMerge = [&]() {
1789 emitDebugLine(functionBody, loc);
1790 lastProcessedWasMergeInst = true;
1791 encodeInstructionInto(
1792 functionBody, spirv::Opcode::OpLoopMerge,
1793 {mergeID, continueID, static_cast<uint32_t>(loopOp.loop_control())});
1795 if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
1796 return failure();
1798 // Process all blocks with a depth-first visitor starting from the header
1799 // block. The loop header block, loop continue block, and loop merge block are
1800 // skipped by this visitor and handled later in this function.
1801 if (failed(visitInPrettyBlockOrder(
1802 headerBlock, [&](Block *block) { return processBlock(block); },
1803 /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
1804 return failure();
1806 // We have handled all other blocks. Now get to the loop continue block.
1807 if (failed(processBlock(continueBlock)))
1808 return failure();
1810 // There is nothing to do for the merge block in the loop, which just contains
1811 // a spv.mlir.merge op, itself. But we need to have an OpLabel instruction to
1812 // start a new SPIR-V block for ops following this LoopOp. The block should
1813 // use the <id> for the merge block.
1814 return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
1817 LogicalResult Serializer::processBranchConditionalOp(
1818 spirv::BranchConditionalOp condBranchOp) {
1819 auto conditionID = getValueID(condBranchOp.condition());
1820 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
1821 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
1822 SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
1824 if (auto weights = condBranchOp.branch_weights()) {
1825 for (auto val : weights->getValue())
1826 arguments.push_back(val.cast<IntegerAttr>().getInt());
1829 emitDebugLine(functionBody, condBranchOp.getLoc());
1830 return encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
1831 arguments);
1834 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
1835 emitDebugLine(functionBody, branchOp.getLoc());
1836 return encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
1837 {getOrCreateBlockID(branchOp.getTarget())});
1840 //===----------------------------------------------------------------------===//
1841 // Operation
1842 //===----------------------------------------------------------------------===//
1844 LogicalResult Serializer::encodeExtensionInstruction(
1845 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1846 ArrayRef<uint32_t> operands) {
1847 // Check if the extension has been imported.
1848 auto &setID = extendedInstSetIDMap[extensionSetName];
1849 if (!setID) {
1850 setID = getNextID();
1851 SmallVector<uint32_t, 16> importOperands;
1852 importOperands.push_back(setID);
1853 if (failed(
1854 spirv::encodeStringLiteralInto(importOperands, extensionSetName)) ||
1855 failed(encodeInstructionInto(
1856 extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) {
1857 return failure();
1861 // The first two operands are the result type <id> and result <id>. The set
1862 // <id> and the opcode need to be insert after this.
1863 if (operands.size() < 2) {
1864 return op->emitError("extended instructions must have a result encoding");
1866 SmallVector<uint32_t, 8> extInstOperands;
1867 extInstOperands.reserve(operands.size() + 2);
1868 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1869 extInstOperands.push_back(setID);
1870 extInstOperands.push_back(extensionOpcode);
1871 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1872 return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1873 extInstOperands);
1876 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
1877 auto varName = addressOfOp.variable();
1878 auto variableID = getVariableID(varName);
1879 if (!variableID) {
1880 return addressOfOp.emitError("unknown result <id> for variable ")
1881 << varName;
1883 valueIDMap[addressOfOp.pointer()] = variableID;
1884 return success();
1887 LogicalResult
1888 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
1889 auto constName = referenceOfOp.spec_const();
1890 auto constID = getSpecConstID(constName);
1891 if (!constID) {
1892 return referenceOfOp.emitError(
1893 "unknown result <id> for specialization constant ")
1894 << constName;
1896 valueIDMap[referenceOfOp.reference()] = constID;
1897 return success();
1900 LogicalResult Serializer::processOperation(Operation *opInst) {
1901 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1903 // First dispatch the ops that do not directly mirror an instruction from
1904 // the SPIR-V spec.
1905 return TypeSwitch<Operation *, LogicalResult>(opInst)
1906 .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1907 .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
1908 .Case([&](spirv::BranchConditionalOp op) {
1909 return processBranchConditionalOp(op);
1911 .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
1912 .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
1913 .Case([&](spirv::GlobalVariableOp op) {
1914 return processGlobalVariableOp(op);
1916 .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
1917 .Case([&](spirv::ModuleEndOp) { return success(); })
1918 .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1919 .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
1920 .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1921 .Case([&](spirv::SpecConstantCompositeOp op) {
1922 return processSpecConstantCompositeOp(op);
1924 .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
1925 .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
1927 // Then handle all the ops that directly mirror SPIR-V instructions with
1928 // auto-generated methods.
1929 .Default(
1930 [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1933 namespace {
1934 template <>
1935 LogicalResult
1936 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
1937 SmallVector<uint32_t, 4> operands;
1938 // Add the ExecutionModel.
1939 operands.push_back(static_cast<uint32_t>(op.execution_model()));
1940 // Add the function <id>.
1941 auto funcID = getFunctionID(op.fn());
1942 if (!funcID) {
1943 return op.emitError("missing <id> for function ")
1944 << op.fn()
1945 << "; function needs to be defined before spv.EntryPoint is "
1946 "serialized";
1948 operands.push_back(funcID);
1949 // Add the name of the function.
1950 spirv::encodeStringLiteralInto(operands, op.fn());
1952 // Add the interface values.
1953 if (auto interface = op.interface()) {
1954 for (auto var : interface.getValue()) {
1955 auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue());
1956 if (!id) {
1957 return op.emitError("referencing undefined global variable."
1958 "spv.EntryPoint is at the end of spv.module. All "
1959 "referenced variables should already be defined");
1961 operands.push_back(id);
1964 return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint,
1965 operands);
1968 template <>
1969 LogicalResult
1970 Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) {
1971 StringRef argNames[] = {"execution_scope", "memory_scope",
1972 "memory_semantics"};
1973 SmallVector<uint32_t, 3> operands;
1975 for (auto argName : argNames) {
1976 auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
1977 auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
1978 if (!operand) {
1979 return failure();
1981 operands.push_back(operand);
1984 return encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier,
1985 operands);
1988 template <>
1989 LogicalResult
1990 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
1991 SmallVector<uint32_t, 4> operands;
1992 // Add the function <id>.
1993 auto funcID = getFunctionID(op.fn());
1994 if (!funcID) {
1995 return op.emitError("missing <id> for function ")
1996 << op.fn()
1997 << "; function needs to be serialized before ExecutionModeOp is "
1998 "serialized";
2000 operands.push_back(funcID);
2001 // Add the ExecutionMode.
2002 operands.push_back(static_cast<uint32_t>(op.execution_mode()));
2004 // Serialize values if any.
2005 auto values = op.values();
2006 if (values) {
2007 for (auto &intVal : values.getValue()) {
2008 operands.push_back(static_cast<uint32_t>(
2009 intVal.cast<IntegerAttr>().getValue().getZExtValue()));
2012 return encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
2013 operands);
2016 template <>
2017 LogicalResult
2018 Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) {
2019 StringRef argNames[] = {"memory_scope", "memory_semantics"};
2020 SmallVector<uint32_t, 2> operands;
2022 for (auto argName : argNames) {
2023 auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
2024 auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
2025 if (!operand) {
2026 return failure();
2028 operands.push_back(operand);
2031 return encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier,
2032 operands);
2035 template <>
2036 LogicalResult
2037 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
2038 auto funcName = op.callee();
2039 uint32_t resTypeID = 0;
2041 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
2042 if (failed(processType(op.getLoc(), resultTy, resTypeID)))
2043 return failure();
2045 auto funcID = getOrCreateFunctionID(funcName);
2046 auto funcCallID = getNextID();
2047 SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
2049 for (auto value : op.arguments()) {
2050 auto valueID = getValueID(value);
2051 assert(valueID && "cannot find a value for spv.FunctionCall");
2052 operands.push_back(valueID);
2055 if (!resultTy.isa<NoneType>())
2056 valueIDMap[op.getResult(0)] = funcCallID;
2058 return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall,
2059 operands);
2062 template <>
2063 LogicalResult
2064 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
2065 SmallVector<uint32_t, 4> operands;
2066 SmallVector<StringRef, 2> elidedAttrs;
2068 for (Value operand : op->getOperands()) {
2069 auto id = getValueID(operand);
2070 assert(id && "use before def!");
2071 operands.push_back(id);
2074 if (auto attr = op->getAttr("memory_access")) {
2075 operands.push_back(static_cast<uint32_t>(
2076 attr.cast<IntegerAttr>().getValue().getZExtValue()));
2079 elidedAttrs.push_back("memory_access");
2081 if (auto attr = op->getAttr("alignment")) {
2082 operands.push_back(static_cast<uint32_t>(
2083 attr.cast<IntegerAttr>().getValue().getZExtValue()));
2086 elidedAttrs.push_back("alignment");
2088 if (auto attr = op->getAttr("source_memory_access")) {
2089 operands.push_back(static_cast<uint32_t>(
2090 attr.cast<IntegerAttr>().getValue().getZExtValue()));
2093 elidedAttrs.push_back("source_memory_access");
2095 if (auto attr = op->getAttr("source_alignment")) {
2096 operands.push_back(static_cast<uint32_t>(
2097 attr.cast<IntegerAttr>().getValue().getZExtValue()));
2100 elidedAttrs.push_back("source_alignment");
2101 emitDebugLine(functionBody, op.getLoc());
2102 encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
2104 return success();
2107 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
2108 // various Serializer::processOp<...>() specializations.
2109 #define GET_SERIALIZATION_FNS
2110 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
2111 } // namespace
2113 LogicalResult Serializer::emitDecoration(uint32_t target,
2114 spirv::Decoration decoration,
2115 ArrayRef<uint32_t> params) {
2116 uint32_t wordCount = 3 + params.size();
2117 decorations.push_back(
2118 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate));
2119 decorations.push_back(target);
2120 decorations.push_back(static_cast<uint32_t>(decoration));
2121 decorations.append(params.begin(), params.end());
2122 return success();
2125 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
2126 Location loc) {
2127 if (!emitDebugInfo)
2128 return success();
2130 if (lastProcessedWasMergeInst) {
2131 lastProcessedWasMergeInst = false;
2132 return success();
2135 auto fileLoc = loc.dyn_cast<FileLineColLoc>();
2136 if (fileLoc)
2137 encodeInstructionInto(binary, spirv::Opcode::OpLine,
2138 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
2139 return success();
2142 namespace mlir {
2143 LogicalResult spirv::serialize(spirv::ModuleOp module,
2144 SmallVectorImpl<uint32_t> &binary,
2145 bool emitDebugInfo) {
2146 if (!module.vce_triple().hasValue())
2147 return module.emitError(
2148 "module must have 'vce_triple' attribute to be serializeable");
2150 Serializer serializer(module, emitDebugInfo);
2152 if (failed(serializer.serialize()))
2153 return failure();
2155 LLVM_DEBUG(serializer.printValueIDMap(llvm::dbgs()));
2157 serializer.collect(binary);
2158 return success();
2160 } // namespace mlir