1 //===- Serializer.cpp - MLIR SPIR-V Serialization -------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file defines the MLIR SPIR-V module to SPIR-V binary serialization.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Target/SPIRV/Serialization.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/RegionGraphTraits.h"
21 #include "mlir/Support/LogicalResult.h"
22 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
23 #include "llvm/ADT/DepthFirstIterator.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringExtras.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/ADT/bit.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
34 #define DEBUG_TYPE "spirv-serialization"
38 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
39 /// the given `binary` vector.
40 static LogicalResult
encodeInstructionInto(SmallVectorImpl
<uint32_t> &binary
,
42 ArrayRef
<uint32_t> operands
) {
43 uint32_t wordCount
= 1 + operands
.size();
44 binary
.push_back(spirv::getPrefixedOpcode(wordCount
, op
));
45 binary
.append(operands
.begin(), operands
.end());
49 /// A pre-order depth-first visitor function for processing basic blocks.
51 /// Visits the basic blocks starting from the given `headerBlock` in pre-order
52 /// depth-first manner and calls `blockHandler` on each block. Skips handling
53 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
54 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
57 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
58 /// of blocks in a function must satisfy the rule that blocks appear before
59 /// all blocks they dominate." This can be achieved by a pre-order CFG
60 /// traversal algorithm. To make the serialization output more logical and
61 /// readable to human, we perform depth-first CFG traversal and delay the
62 /// serialization of the merge block and the continue block, if exists, until
63 /// after all other blocks have been processed.
65 visitInPrettyBlockOrder(Block
*headerBlock
,
66 function_ref
<LogicalResult(Block
*)> blockHandler
,
67 bool skipHeader
= false, BlockRange skipBlocks
= {}) {
68 llvm::df_iterator_default_set
<Block
*, 4> doneBlocks
;
69 doneBlocks
.insert(skipBlocks
.begin(), skipBlocks
.end());
71 for (Block
*block
: llvm::depth_first_ext(headerBlock
, doneBlocks
)) {
72 if (skipHeader
&& block
== headerBlock
)
74 if (failed(blockHandler(block
)))
80 /// Returns the merge block if the given `op` is a structured control flow op.
81 /// Otherwise returns nullptr.
82 static Block
*getStructuredControlFlowOpMergeBlock(Operation
*op
) {
83 if (auto selectionOp
= dyn_cast
<spirv::SelectionOp
>(op
))
84 return selectionOp
.getMergeBlock();
85 if (auto loopOp
= dyn_cast
<spirv::LoopOp
>(op
))
86 return loopOp
.getMergeBlock();
90 /// Given a predecessor `block` for a block with arguments, returns the block
91 /// that should be used as the parent block for SPIR-V OpPhi instructions
92 /// corresponding to the block arguments.
93 static Block
*getPhiIncomingBlock(Block
*block
) {
94 // If the predecessor block in question is the entry block for a spv.loop,
95 // we jump to this spv.loop from its enclosing block.
96 if (block
->isEntryBlock()) {
97 if (auto loopOp
= dyn_cast
<spirv::LoopOp
>(block
->getParentOp())) {
98 // Then the incoming parent block for OpPhi should be the merge block of
99 // the structured control flow op before this loop.
100 Operation
*op
= loopOp
.getOperation();
101 while ((op
= op
->getPrevNode()) != nullptr)
102 if (Block
*incomingBlock
= getStructuredControlFlowOpMergeBlock(op
))
103 return incomingBlock
;
104 // Or the enclosing block itself if no structured control flow ops
105 // exists before this loop.
106 return loopOp
->getBlock();
110 // Otherwise, we jump from the given predecessor block. Try to see if there is
111 // a structured control flow op inside it.
112 for (Operation
&op
: llvm::reverse(block
->getOperations())) {
113 if (Block
*incomingBlock
= getStructuredControlFlowOpMergeBlock(&op
))
114 return incomingBlock
;
121 /// A SPIR-V module serializer.
123 /// A SPIR-V binary module is a single linear stream of instructions; each
124 /// instruction is composed of 32-bit words with the layout:
126 /// | <word-count>|<opcode> | <operand> | <operand> | ... |
127 /// | <------ word -------> | <-- word --> | <-- word --> | ... |
129 /// For the first word, the 16 high-order bits are the word count of the
130 /// instruction, the 16 low-order bits are the opcode enumerant. The
131 /// instructions then belong to different sections, which must be laid out in
132 /// the particular order as specified in "2.4 Logical Layout of a Module" of
136 /// Creates a serializer for the given SPIR-V `module`.
137 explicit Serializer(spirv::ModuleOp module
, bool emitDebugInfo
= false);
139 /// Serializes the remembered SPIR-V module.
140 LogicalResult
serialize();
142 /// Collects the final SPIR-V `binary`.
143 void collect(SmallVectorImpl
<uint32_t> &binary
);
146 /// (For debugging) prints each value and its corresponding result <id>.
147 void printValueIDMap(raw_ostream
&os
);
151 // Note that there are two main categories of methods in this class:
152 // * process*() methods are meant to fully serialize a SPIR-V module entity
153 // (header, type, op, etc.). They update internal vectors containing
154 // different binary sections. They are not meant to be called except the
155 // top-level serialization loop.
156 // * prepare*() methods are meant to be helpers that prepare for serializing
157 // certain entity. They may or may not update internal vectors containing
158 // different binary sections. They are meant to be called among themselves
159 // or by other process*() methods for subtasks.
161 //===--------------------------------------------------------------------===//
163 //===--------------------------------------------------------------------===//
165 // Note that it is illegal to use id <0> in SPIR-V binary module. Various
166 // methods in this class, if using SPIR-V word (uint32_t) as interface,
167 // check or return id <0> to indicate error in processing.
169 /// Consumes the next unused <id>. This method will never return 0.
170 uint32_t getNextID() { return nextID
++; }
172 //===--------------------------------------------------------------------===//
174 //===--------------------------------------------------------------------===//
176 uint32_t getSpecConstID(StringRef constName
) const {
177 return specConstIDMap
.lookup(constName
);
180 uint32_t getVariableID(StringRef varName
) const {
181 return globalVarIDMap
.lookup(varName
);
184 uint32_t getFunctionID(StringRef fnName
) const {
185 return funcIDMap
.lookup(fnName
);
188 /// Gets the <id> for the function with the given name. Assigns the next
189 /// available <id> if the function haven't been deserialized.
190 uint32_t getOrCreateFunctionID(StringRef fnName
);
192 void processCapability();
194 void processDebugInfo();
196 void processExtension();
198 void processMemoryModel();
200 LogicalResult
processConstantOp(spirv::ConstantOp op
);
202 LogicalResult
processSpecConstantOp(spirv::SpecConstantOp op
);
205 processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op
);
207 /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA
208 /// value to use with other operations. The SPIR-V spec recommends that
209 /// OpUndef be generated at module level. The serialization generates an
210 /// OpUndef for each type needed at module level.
211 LogicalResult
processUndefOp(spirv::UndefOp op
);
213 /// Emit OpName for the given `resultID`.
214 LogicalResult
processName(uint32_t resultID
, StringRef name
);
216 /// Processes a SPIR-V function op.
217 LogicalResult
processFuncOp(spirv::FuncOp op
);
219 LogicalResult
processVariableOp(spirv::VariableOp op
);
221 /// Process a SPIR-V GlobalVariableOp
222 LogicalResult
processGlobalVariableOp(spirv::GlobalVariableOp varOp
);
224 /// Process attributes that translate to decorations on the result <id>
225 LogicalResult
processDecoration(Location loc
, uint32_t resultID
,
226 NamedAttribute attr
);
228 template <typename DType
>
229 LogicalResult
processTypeDecoration(Location loc
, DType type
,
231 return emitError(loc
, "unhandled decoration for type:") << type
;
234 /// Process member decoration
235 LogicalResult
processMemberDecoration(
237 const spirv::StructType::MemberDecorationInfo
&memberDecorationInfo
);
239 //===--------------------------------------------------------------------===//
241 //===--------------------------------------------------------------------===//
243 uint32_t getTypeID(Type type
) const { return typeIDMap
.lookup(type
); }
245 Type
getVoidType() { return mlirBuilder
.getNoneType(); }
247 bool isVoidType(Type type
) const { return type
.isa
<NoneType
>(); }
249 /// Returns true if the given type is a pointer type to a struct in some
250 /// interface storage class.
251 bool isInterfaceStructPtrType(Type type
) const;
253 /// Main dispatch method for serializing a type. The result <id> of the
254 /// serialized type will be returned as `typeID`.
255 LogicalResult
processType(Location loc
, Type type
, uint32_t &typeID
);
256 LogicalResult
processTypeImpl(Location loc
, Type type
, uint32_t &typeID
,
257 llvm::SetVector
<StringRef
> &serializationCtx
);
259 /// Method for preparing basic SPIR-V type serialization. Returns the type's
260 /// opcode and operands for the instruction via `typeEnum` and `operands`.
261 LogicalResult
prepareBasicType(Location loc
, Type type
, uint32_t resultID
,
262 spirv::Opcode
&typeEnum
,
263 SmallVectorImpl
<uint32_t> &operands
,
264 bool &deferSerialization
,
265 llvm::SetVector
<StringRef
> &serializationCtx
);
267 LogicalResult
prepareFunctionType(Location loc
, FunctionType type
,
268 spirv::Opcode
&typeEnum
,
269 SmallVectorImpl
<uint32_t> &operands
);
271 //===--------------------------------------------------------------------===//
273 //===--------------------------------------------------------------------===//
275 uint32_t getConstantID(Attribute value
) const {
276 return constIDMap
.lookup(value
);
279 /// Main dispatch method for processing a constant with the given `constType`
280 /// and `valueAttr`. `constType` is needed here because we can interpret the
281 /// `valueAttr` as a different type than the type of `valueAttr` itself; for
282 /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType
284 uint32_t prepareConstant(Location loc
, Type constType
, Attribute valueAttr
);
286 /// Prepares array attribute serialization. This method emits corresponding
287 /// OpConstant* and returns the result <id> associated with it. Returns 0 if
289 uint32_t prepareArrayConstant(Location loc
, Type constType
, ArrayAttr attr
);
291 /// Prepares bool/int/float DenseElementsAttr serialization. This method
292 /// iterates the DenseElementsAttr to construct the constant array, and
293 /// returns the result <id> associated with it. Returns 0 if failed. Note
294 /// that the size of `index` must match the rank.
295 /// TODO: Consider to enhance splat elements cases. For splat cases,
296 /// we don't need to loop over all elements, especially when the splat value
297 /// is zero. We can use OpConstantNull when the value is zero.
298 uint32_t prepareDenseElementsConstant(Location loc
, Type constType
,
299 DenseElementsAttr valueAttr
, int dim
,
300 MutableArrayRef
<uint64_t> index
);
302 /// Prepares scalar attribute serialization. This method emits corresponding
303 /// OpConstant* and returns the result <id> associated with it. Returns 0 if
304 /// the attribute is not for a scalar bool/integer/float value. If `isSpec` is
305 /// true, then the constant will be serialized as a specialization constant.
306 uint32_t prepareConstantScalar(Location loc
, Attribute valueAttr
,
307 bool isSpec
= false);
309 uint32_t prepareConstantBool(Location loc
, BoolAttr boolAttr
,
310 bool isSpec
= false);
312 uint32_t prepareConstantInt(Location loc
, IntegerAttr intAttr
,
313 bool isSpec
= false);
315 uint32_t prepareConstantFp(Location loc
, FloatAttr floatAttr
,
316 bool isSpec
= false);
318 //===--------------------------------------------------------------------===//
320 //===--------------------------------------------------------------------===//
322 /// Returns the result <id> for the given block.
323 uint32_t getBlockID(Block
*block
) const { return blockIDMap
.lookup(block
); }
325 /// Returns the result <id> for the given block. If no <id> has been assigned,
326 /// assigns the next available <id>
327 uint32_t getOrCreateBlockID(Block
*block
);
329 /// Processes the given `block` and emits SPIR-V instructions for all ops
330 /// inside. Does not emit OpLabel for this block if `omitLabel` is true.
331 /// `actionBeforeTerminator` is a callback that will be invoked before
332 /// handling the terminator op. It can be used to inject the Op*Merge
333 /// instruction if this is a SPIR-V selection/loop header block.
335 processBlock(Block
*block
, bool omitLabel
= false,
336 function_ref
<void()> actionBeforeTerminator
= nullptr);
338 /// Emits OpPhi instructions for the given block if it has block arguments.
339 LogicalResult
emitPhiForBlockArguments(Block
*block
);
341 LogicalResult
processSelectionOp(spirv::SelectionOp selectionOp
);
343 LogicalResult
processLoopOp(spirv::LoopOp loopOp
);
345 LogicalResult
processBranchConditionalOp(spirv::BranchConditionalOp
);
347 LogicalResult
processBranchOp(spirv::BranchOp branchOp
);
349 //===--------------------------------------------------------------------===//
351 //===--------------------------------------------------------------------===//
353 LogicalResult
encodeExtensionInstruction(Operation
*op
,
354 StringRef extensionSetName
,
356 ArrayRef
<uint32_t> operands
);
358 uint32_t getValueID(Value val
) const { return valueIDMap
.lookup(val
); }
360 LogicalResult
processAddressOfOp(spirv::AddressOfOp addressOfOp
);
362 LogicalResult
processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp
);
364 /// Main dispatch method for serializing an operation.
365 LogicalResult
processOperation(Operation
*op
);
367 /// Method to dispatch to the serialization function for an operation in
368 /// SPIR-V dialect that is a mirror of an instruction in the SPIR-V spec.
369 /// This is auto-generated from ODS. Dispatch is handled for all operations
370 /// in SPIR-V dialect that have hasOpcode == 1.
371 LogicalResult
dispatchToAutogenSerialization(Operation
*op
);
373 /// Method to serialize an operation in the SPIR-V dialect that is a mirror of
374 /// an instruction in the SPIR-V spec. This is auto generated if hasOpcode ==
375 /// 1 and autogenSerialization == 1 in ODS.
376 template <typename OpTy
>
377 LogicalResult
processOp(OpTy op
) {
378 return op
.emitError("unsupported op serialization");
381 //===--------------------------------------------------------------------===//
383 //===--------------------------------------------------------------------===//
385 /// Emits an OpDecorate instruction to decorate the given `target` with the
386 /// given `decoration`.
387 LogicalResult
emitDecoration(uint32_t target
, spirv::Decoration decoration
,
388 ArrayRef
<uint32_t> params
= {});
390 /// Emits an OpLine instruction with the given `loc` location information into
391 /// the given `binary` vector.
392 LogicalResult
emitDebugLine(SmallVectorImpl
<uint32_t> &binary
, Location loc
);
395 /// The SPIR-V module to be serialized.
396 spirv::ModuleOp module
;
398 /// An MLIR builder for getting MLIR constructs.
399 mlir::Builder mlirBuilder
;
401 /// A flag which indicates if the debuginfo should be emitted.
402 bool emitDebugInfo
= false;
404 /// A flag which indicates if the last processed instruction was a merge
406 /// According to SPIR-V spec: "If a branch merge instruction is used, the last
407 /// OpLine in the block must be before its merge instruction".
408 bool lastProcessedWasMergeInst
= false;
410 /// The <id> of the OpString instruction, which specifies a file name, for
411 /// use by other debug instructions.
414 /// The next available result <id>.
417 // The following are for different SPIR-V instruction sections. They follow
418 // the logical layout of a SPIR-V module.
420 SmallVector
<uint32_t, 4> capabilities
;
421 SmallVector
<uint32_t, 0> extensions
;
422 SmallVector
<uint32_t, 0> extendedSets
;
423 SmallVector
<uint32_t, 3> memoryModel
;
424 SmallVector
<uint32_t, 0> entryPoints
;
425 SmallVector
<uint32_t, 4> executionModes
;
426 SmallVector
<uint32_t, 0> debug
;
427 SmallVector
<uint32_t, 0> names
;
428 SmallVector
<uint32_t, 0> decorations
;
429 SmallVector
<uint32_t, 0> typesGlobalValues
;
430 SmallVector
<uint32_t, 0> functions
;
432 /// Recursive struct references are serialized as OpTypePointer instructions
433 /// to the recursive struct type. However, the OpTypePointer instruction
434 /// cannot be emitted before the recursive struct's OpTypeStruct.
435 /// RecursiveStructPointerInfo stores the data needed to emit such
436 /// OpTypePointer instructions after forward references to such types.
437 struct RecursiveStructPointerInfo
{
438 uint32_t pointerTypeID
;
439 spirv::StorageClass storageClass
;
442 // Maps spirv::StructType to its recursive reference member info.
443 DenseMap
<Type
, SmallVector
<RecursiveStructPointerInfo
, 0>>
444 recursiveStructInfos
;
446 /// `functionHeader` contains all the instructions that must be in the first
447 /// block in the function, and `functionBody` contains the rest. After
448 /// processing FuncOp, the encoded instructions of a function are appended to
449 /// `functions`. An example of instructions in `functionHeader` in order:
451 /// OpFunctionParameter ...
452 /// OpFunctionParameter ...
456 SmallVector
<uint32_t, 0> functionHeader
;
457 SmallVector
<uint32_t, 0> functionBody
;
459 /// Map from type used in SPIR-V module to their <id>s.
460 DenseMap
<Type
, uint32_t> typeIDMap
;
462 /// Map from constant values to their <id>s.
463 DenseMap
<Attribute
, uint32_t> constIDMap
;
465 /// Map from specialization constant names to their <id>s.
466 llvm::StringMap
<uint32_t> specConstIDMap
;
468 /// Map from GlobalVariableOps name to <id>s.
469 llvm::StringMap
<uint32_t> globalVarIDMap
;
471 /// Map from FuncOps name to <id>s.
472 llvm::StringMap
<uint32_t> funcIDMap
;
474 /// Map from blocks to their <id>s.
475 DenseMap
<Block
*, uint32_t> blockIDMap
;
477 /// Map from the Type to the <id> that represents undef value of that type.
478 DenseMap
<Type
, uint32_t> undefValIDMap
;
480 /// Map from results of normal operations to their <id>s.
481 DenseMap
<Value
, uint32_t> valueIDMap
;
483 /// Map from extended instruction set name to <id>s.
484 llvm::StringMap
<uint32_t> extendedInstSetIDMap
;
486 /// Map from values used in OpPhi instructions to their offset in the
487 /// `functions` section.
489 /// When processing a block with arguments, we need to emit OpPhi
490 /// instructions to record the predecessor block <id>s and the values they
491 /// send to the block in question. But it's not guaranteed all values are
492 /// visited and thus assigned result <id>s. So we need this list to capture
493 /// the offsets into `functions` where a value is used so that we can fix it
494 /// up later after processing all the blocks in a function.
496 /// More concretely, say if we are visiting the following blocks:
499 /// ^phi(%arg0: i32):
503 /// spv.Branch ^phi(%val0: i32)
506 /// spv.Branch ^phi(%val1: i32)
509 /// When we are serializing the `^phi` block, we need to emit at the beginning
510 /// of the block OpPhi instructions which has the following parameters:
512 /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1
513 /// id-for-%val1 id-for-^parent2
515 /// But we don't know the <id> for %val0 and %val1 yet. One way is to visit
516 /// all the blocks twice and use the first visit to assign an <id> to each
517 /// value. But it's paying the overheads just for OpPhi emission. Instead,
518 /// we still visit the blocks once for emission. When we emit the OpPhi
519 /// instructions, we use 0 as a placeholder for the <id>s for %val0 and %val1.
520 /// At the same time, we record their offsets in the emitted binary (which is
521 /// placed inside `functions`) here. And then after emitting all blocks, we
522 /// replace the dummy <id> 0 with the real result <id> by overwriting
523 /// `functions[offset]`.
524 DenseMap
<Value
, SmallVector
<size_t, 1>> deferredPhiValues
;
528 Serializer::Serializer(spirv::ModuleOp module
, bool emitDebugInfo
)
529 : module(module
), mlirBuilder(module
.getContext()),
530 emitDebugInfo(emitDebugInfo
) {}
532 LogicalResult
Serializer::serialize() {
533 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
535 if (failed(module
.verify()))
538 // TODO: handle the other sections
541 processMemoryModel();
544 // Iterate over the module body to serialize it. Assumptions are that there is
545 // only one basic block in the moduleOp
546 for (auto &op
: module
.getBlock()) {
547 if (failed(processOperation(&op
))) {
552 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
556 void Serializer::collect(SmallVectorImpl
<uint32_t> &binary
) {
557 auto moduleSize
= spirv::kHeaderWordCount
+ capabilities
.size() +
558 extensions
.size() + extendedSets
.size() +
559 memoryModel
.size() + entryPoints
.size() +
560 executionModes
.size() + decorations
.size() +
561 typesGlobalValues
.size() + functions
.size();
564 binary
.reserve(moduleSize
);
566 spirv::appendModuleHeader(binary
, module
.vce_triple()->getVersion(), nextID
);
567 binary
.append(capabilities
.begin(), capabilities
.end());
568 binary
.append(extensions
.begin(), extensions
.end());
569 binary
.append(extendedSets
.begin(), extendedSets
.end());
570 binary
.append(memoryModel
.begin(), memoryModel
.end());
571 binary
.append(entryPoints
.begin(), entryPoints
.end());
572 binary
.append(executionModes
.begin(), executionModes
.end());
573 binary
.append(debug
.begin(), debug
.end());
574 binary
.append(names
.begin(), names
.end());
575 binary
.append(decorations
.begin(), decorations
.end());
576 binary
.append(typesGlobalValues
.begin(), typesGlobalValues
.end());
577 binary
.append(functions
.begin(), functions
.end());
581 void Serializer::printValueIDMap(raw_ostream
&os
) {
582 os
<< "\n= Value <id> Map =\n\n";
583 for (auto valueIDPair
: valueIDMap
) {
584 Value val
= valueIDPair
.first
;
585 os
<< " " << val
<< " "
586 << "id = " << valueIDPair
.second
<< ' ';
587 if (auto *op
= val
.getDefiningOp()) {
588 os
<< "from op '" << op
->getName() << "'";
589 } else if (auto arg
= val
.dyn_cast
<BlockArgument
>()) {
590 Block
*block
= arg
.getOwner();
591 os
<< "from argument of block " << block
<< ' ';
592 os
<< " in op '" << block
->getParentOp()->getName() << "'";
599 //===----------------------------------------------------------------------===//
601 //===----------------------------------------------------------------------===//
603 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName
) {
604 auto funcID
= funcIDMap
.lookup(fnName
);
606 funcID
= getNextID();
607 funcIDMap
[fnName
] = funcID
;
612 void Serializer::processCapability() {
613 for (auto cap
: module
.vce_triple()->getCapabilities())
614 encodeInstructionInto(capabilities
, spirv::Opcode::OpCapability
,
615 {static_cast<uint32_t>(cap
)});
618 void Serializer::processDebugInfo() {
621 auto fileLoc
= module
.getLoc().dyn_cast
<FileLineColLoc
>();
622 auto fileName
= fileLoc
? fileLoc
.getFilename() : "<unknown>";
623 fileID
= getNextID();
624 SmallVector
<uint32_t, 16> operands
;
625 operands
.push_back(fileID
);
626 spirv::encodeStringLiteralInto(operands
, fileName
);
627 encodeInstructionInto(debug
, spirv::Opcode::OpString
, operands
);
628 // TODO: Encode more debug instructions.
631 void Serializer::processExtension() {
632 llvm::SmallVector
<uint32_t, 16> extName
;
633 for (spirv::Extension ext
: module
.vce_triple()->getExtensions()) {
635 spirv::encodeStringLiteralInto(extName
, spirv::stringifyExtension(ext
));
636 encodeInstructionInto(extensions
, spirv::Opcode::OpExtension
, extName
);
640 void Serializer::processMemoryModel() {
641 uint32_t mm
= module
->getAttrOfType
<IntegerAttr
>("memory_model").getInt();
642 uint32_t am
= module
->getAttrOfType
<IntegerAttr
>("addressing_model").getInt();
644 encodeInstructionInto(memoryModel
, spirv::Opcode::OpMemoryModel
, {am
, mm
});
647 LogicalResult
Serializer::processConstantOp(spirv::ConstantOp op
) {
648 if (auto resultID
= prepareConstant(op
.getLoc(), op
.getType(), op
.value())) {
649 valueIDMap
[op
.getResult()] = resultID
;
655 LogicalResult
Serializer::processSpecConstantOp(spirv::SpecConstantOp op
) {
656 if (auto resultID
= prepareConstantScalar(op
.getLoc(), op
.default_value(),
658 // Emit the OpDecorate instruction for SpecId.
659 if (auto specID
= op
->getAttrOfType
<IntegerAttr
>("spec_id")) {
660 auto val
= static_cast<uint32_t>(specID
.getInt());
661 emitDecoration(resultID
, spirv::Decoration::SpecId
, {val
});
664 specConstIDMap
[op
.sym_name()] = resultID
;
665 return processName(resultID
, op
.sym_name());
671 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op
) {
673 if (failed(processType(op
.getLoc(), op
.type(), typeID
))) {
677 auto resultID
= getNextID();
679 SmallVector
<uint32_t, 8> operands
;
680 operands
.push_back(typeID
);
681 operands
.push_back(resultID
);
683 auto constituents
= op
.constituents();
685 for (auto index
: llvm::seq
<uint32_t>(0, constituents
.size())) {
686 auto constituent
= constituents
[index
].dyn_cast
<FlatSymbolRefAttr
>();
688 auto constituentName
= constituent
.getValue();
689 auto constituentID
= getSpecConstID(constituentName
);
691 if (!constituentID
) {
692 return op
.emitError("unknown result <id> for specialization constant ")
696 operands
.push_back(constituentID
);
699 encodeInstructionInto(typesGlobalValues
,
700 spirv::Opcode::OpSpecConstantComposite
, operands
);
701 specConstIDMap
[op
.sym_name()] = resultID
;
703 return processName(resultID
, op
.sym_name());
706 LogicalResult
Serializer::processUndefOp(spirv::UndefOp op
) {
707 auto undefType
= op
.getType();
708 auto &id
= undefValIDMap
[undefType
];
712 if (failed(processType(op
.getLoc(), undefType
, typeID
)) ||
713 failed(encodeInstructionInto(typesGlobalValues
, spirv::Opcode::OpUndef
,
718 valueIDMap
[op
.getResult()] = id
;
722 LogicalResult
Serializer::processDecoration(Location loc
, uint32_t resultID
,
723 NamedAttribute attr
) {
724 auto attrName
= attr
.first
.strref();
725 auto decorationName
= llvm::convertToCamelFromSnakeCase(attrName
, true);
726 auto decoration
= spirv::symbolizeDecoration(decorationName
);
729 loc
, "non-argument attributes expected to have snake-case-ified "
730 "decoration name, unhandled attribute with name : ")
733 SmallVector
<uint32_t, 1> args
;
734 switch (decoration
.getValue()) {
735 case spirv::Decoration::Binding
:
736 case spirv::Decoration::DescriptorSet
:
737 case spirv::Decoration::Location
:
738 if (auto intAttr
= attr
.second
.dyn_cast
<IntegerAttr
>()) {
739 args
.push_back(intAttr
.getValue().getZExtValue());
742 return emitError(loc
, "expected integer attribute for ") << attrName
;
743 case spirv::Decoration::BuiltIn
:
744 if (auto strAttr
= attr
.second
.dyn_cast
<StringAttr
>()) {
745 auto enumVal
= spirv::symbolizeBuiltIn(strAttr
.getValue());
747 args
.push_back(static_cast<uint32_t>(enumVal
.getValue()));
750 return emitError(loc
, "invalid ")
751 << attrName
<< " attribute " << strAttr
.getValue();
753 return emitError(loc
, "expected string attribute for ") << attrName
;
754 case spirv::Decoration::Aliased
:
755 case spirv::Decoration::Flat
:
756 case spirv::Decoration::NonReadable
:
757 case spirv::Decoration::NonWritable
:
758 case spirv::Decoration::NoPerspective
:
759 case spirv::Decoration::Restrict
:
760 // For unit attributes, the args list has no values so we do nothing
761 if (auto unitAttr
= attr
.second
.dyn_cast
<UnitAttr
>())
763 return emitError(loc
, "expected unit attribute for ") << attrName
;
765 return emitError(loc
, "unhandled decoration ") << decorationName
;
767 return emitDecoration(resultID
, decoration
.getValue(), args
);
770 LogicalResult
Serializer::processName(uint32_t resultID
, StringRef name
) {
771 assert(!name
.empty() && "unexpected empty string for OpName");
773 SmallVector
<uint32_t, 4> nameOperands
;
774 nameOperands
.push_back(resultID
);
775 if (failed(spirv::encodeStringLiteralInto(nameOperands
, name
))) {
778 return encodeInstructionInto(names
, spirv::Opcode::OpName
, nameOperands
);
783 LogicalResult
Serializer::processTypeDecoration
<spirv::ArrayType
>(
784 Location loc
, spirv::ArrayType type
, uint32_t resultID
) {
785 if (unsigned stride
= type
.getArrayStride()) {
786 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
787 return emitDecoration(resultID
, spirv::Decoration::ArrayStride
, {stride
});
793 LogicalResult
Serializer::processTypeDecoration
<spirv::RuntimeArrayType
>(
794 Location Loc
, spirv::RuntimeArrayType type
, uint32_t resultID
) {
795 if (unsigned stride
= type
.getArrayStride()) {
796 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
797 return emitDecoration(resultID
, spirv::Decoration::ArrayStride
, {stride
});
802 LogicalResult
Serializer::processMemberDecoration(
804 const spirv::StructType::MemberDecorationInfo
&memberDecoration
) {
805 SmallVector
<uint32_t, 4> args(
806 {structID
, memberDecoration
.memberIndex
,
807 static_cast<uint32_t>(memberDecoration
.decoration
)});
808 if (memberDecoration
.hasValue
) {
809 args
.push_back(memberDecoration
.decorationValue
);
811 return encodeInstructionInto(decorations
, spirv::Opcode::OpMemberDecorate
,
816 LogicalResult
Serializer::processFuncOp(spirv::FuncOp op
) {
817 LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op
.getName() << "' --\n");
818 assert(functionHeader
.empty() && functionBody
.empty());
820 uint32_t fnTypeID
= 0;
821 // Generate type of the function.
822 processType(op
.getLoc(), op
.getType(), fnTypeID
);
824 // Add the function definition.
825 SmallVector
<uint32_t, 4> operands
;
826 uint32_t resTypeID
= 0;
827 auto resultTypes
= op
.getType().getResults();
828 if (resultTypes
.size() > 1) {
829 return op
.emitError("cannot serialize function with multiple return types");
831 if (failed(processType(op
.getLoc(),
832 (resultTypes
.empty() ? getVoidType() : resultTypes
[0]),
836 operands
.push_back(resTypeID
);
837 auto funcID
= getOrCreateFunctionID(op
.getName());
838 operands
.push_back(funcID
);
839 operands
.push_back(static_cast<uint32_t>(op
.function_control()));
840 operands
.push_back(fnTypeID
);
841 encodeInstructionInto(functionHeader
, spirv::Opcode::OpFunction
, operands
);
843 // Add function name.
844 if (failed(processName(funcID
, op
.getName()))) {
848 // Declare the parameters.
849 for (auto arg
: op
.getArguments()) {
850 uint32_t argTypeID
= 0;
851 if (failed(processType(op
.getLoc(), arg
.getType(), argTypeID
))) {
854 auto argValueID
= getNextID();
855 valueIDMap
[arg
] = argValueID
;
856 encodeInstructionInto(functionHeader
, spirv::Opcode::OpFunctionParameter
,
857 {argTypeID
, argValueID
});
861 if (op
.isExternal()) {
862 return op
.emitError("external function is unhandled");
865 // Some instructions (e.g., OpVariable) in a function must be in the first
866 // block in the function. These instructions will be put in functionHeader.
867 // Thus, we put the label in functionHeader first, and omit it from the first
869 encodeInstructionInto(functionHeader
, spirv::Opcode::OpLabel
,
870 {getOrCreateBlockID(&op
.front())});
871 processBlock(&op
.front(), /*omitLabel=*/true);
872 if (failed(visitInPrettyBlockOrder(
873 &op
.front(), [&](Block
*block
) { return processBlock(block
); },
874 /*skipHeader=*/true))) {
878 // There might be OpPhi instructions who have value references needing to fix.
879 for (auto deferredValue
: deferredPhiValues
) {
880 Value value
= deferredValue
.first
;
881 uint32_t id
= getValueID(value
);
882 LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
883 << " to id = " << id
<< '\n');
884 assert(id
&& "OpPhi references undefined value!");
885 for (size_t offset
: deferredValue
.second
)
886 functionBody
[offset
] = id
;
888 deferredPhiValues
.clear();
890 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op
.getName()
892 // Insert OpFunctionEnd.
893 if (failed(encodeInstructionInto(functionBody
, spirv::Opcode::OpFunctionEnd
,
898 functions
.append(functionHeader
.begin(), functionHeader
.end());
899 functions
.append(functionBody
.begin(), functionBody
.end());
900 functionHeader
.clear();
901 functionBody
.clear();
906 LogicalResult
Serializer::processVariableOp(spirv::VariableOp op
) {
907 SmallVector
<uint32_t, 4> operands
;
908 SmallVector
<StringRef
, 2> elidedAttrs
;
909 uint32_t resultID
= 0;
910 uint32_t resultTypeID
= 0;
911 if (failed(processType(op
.getLoc(), op
.getType(), resultTypeID
))) {
914 operands
.push_back(resultTypeID
);
915 resultID
= getNextID();
916 valueIDMap
[op
.getResult()] = resultID
;
917 operands
.push_back(resultID
);
918 auto attr
= op
->getAttr(spirv::attributeName
<spirv::StorageClass
>());
920 operands
.push_back(static_cast<uint32_t>(
921 attr
.cast
<IntegerAttr
>().getValue().getZExtValue()));
923 elidedAttrs
.push_back(spirv::attributeName
<spirv::StorageClass
>());
924 for (auto arg
: op
.getODSOperands(0)) {
925 auto argID
= getValueID(arg
);
927 return emitError(op
.getLoc(), "operand 0 has a use before def");
929 operands
.push_back(argID
);
931 emitDebugLine(functionHeader
, op
.getLoc());
932 encodeInstructionInto(functionHeader
, spirv::Opcode::OpVariable
, operands
);
933 for (auto attr
: op
->getAttrs()) {
934 if (llvm::any_of(elidedAttrs
,
935 [&](StringRef elided
) { return attr
.first
== elided
; })) {
938 if (failed(processDecoration(op
.getLoc(), resultID
, attr
))) {
946 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp
) {
948 uint32_t resultTypeID
= 0;
949 SmallVector
<StringRef
, 4> elidedAttrs
;
950 if (failed(processType(varOp
.getLoc(), varOp
.type(), resultTypeID
))) {
954 if (isInterfaceStructPtrType(varOp
.type())) {
955 auto structType
= varOp
.type()
956 .cast
<spirv::PointerType
>()
958 .cast
<spirv::StructType
>();
960 emitDecoration(getTypeID(structType
), spirv::Decoration::Block
))) {
961 return varOp
.emitError("cannot decorate ")
962 << structType
<< " with Block decoration";
966 elidedAttrs
.push_back("type");
967 SmallVector
<uint32_t, 4> operands
;
968 operands
.push_back(resultTypeID
);
969 auto resultID
= getNextID();
972 auto varName
= varOp
.sym_name();
973 elidedAttrs
.push_back(SymbolTable::getSymbolAttrName());
974 if (failed(processName(resultID
, varName
))) {
977 globalVarIDMap
[varName
] = resultID
;
978 operands
.push_back(resultID
);
980 // Encode StorageClass.
981 operands
.push_back(static_cast<uint32_t>(varOp
.storageClass()));
983 // Encode initialization.
984 if (auto initializer
= varOp
.initializer()) {
985 auto initializerID
= getVariableID(initializer
.getValue());
986 if (!initializerID
) {
987 return emitError(varOp
.getLoc(),
988 "invalid usage of undefined variable as initializer");
990 operands
.push_back(initializerID
);
991 elidedAttrs
.push_back("initializer");
994 emitDebugLine(typesGlobalValues
, varOp
.getLoc());
995 if (failed(encodeInstructionInto(typesGlobalValues
, spirv::Opcode::OpVariable
,
997 elidedAttrs
.push_back("initializer");
1001 // Encode decorations.
1002 for (auto attr
: varOp
->getAttrs()) {
1003 if (llvm::any_of(elidedAttrs
,
1004 [&](StringRef elided
) { return attr
.first
== elided
; })) {
1007 if (failed(processDecoration(varOp
.getLoc(), resultID
, attr
))) {
1014 //===----------------------------------------------------------------------===//
1016 //===----------------------------------------------------------------------===//
1018 // According to the SPIR-V spec "Validation Rules for Shader Capabilities":
1019 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
1020 // PushConstant Storage Classes must be explicitly laid out."
1021 bool Serializer::isInterfaceStructPtrType(Type type
) const {
1022 if (auto ptrType
= type
.dyn_cast
<spirv::PointerType
>()) {
1023 switch (ptrType
.getStorageClass()) {
1024 case spirv::StorageClass::PhysicalStorageBuffer
:
1025 case spirv::StorageClass::PushConstant
:
1026 case spirv::StorageClass::StorageBuffer
:
1027 case spirv::StorageClass::Uniform
:
1028 return ptrType
.getPointeeType().isa
<spirv::StructType
>();
1036 LogicalResult
Serializer::processType(Location loc
, Type type
,
1038 // Maintains a set of names for nested identified struct types. This is used
1039 // to properly serialize resursive references.
1040 llvm::SetVector
<StringRef
> serializationCtx
;
1041 return processTypeImpl(loc
, type
, typeID
, serializationCtx
);
1045 Serializer::processTypeImpl(Location loc
, Type type
, uint32_t &typeID
,
1046 llvm::SetVector
<StringRef
> &serializationCtx
) {
1047 typeID
= getTypeID(type
);
1051 typeID
= getNextID();
1052 SmallVector
<uint32_t, 4> operands
;
1054 operands
.push_back(typeID
);
1055 auto typeEnum
= spirv::Opcode::OpTypeVoid
;
1056 bool deferSerialization
= false;
1058 if ((type
.isa
<FunctionType
>() &&
1059 succeeded(prepareFunctionType(loc
, type
.cast
<FunctionType
>(), typeEnum
,
1061 succeeded(prepareBasicType(loc
, type
, typeID
, typeEnum
, operands
,
1062 deferSerialization
, serializationCtx
))) {
1063 if (deferSerialization
)
1066 typeIDMap
[type
] = typeID
;
1068 if (failed(encodeInstructionInto(typesGlobalValues
, typeEnum
, operands
)))
1071 if (recursiveStructInfos
.count(type
) != 0) {
1072 // This recursive struct type is emitted already, now the OpTypePointer
1073 // instructions referring to recursive references are emitted as well.
1074 for (auto &ptrInfo
: recursiveStructInfos
[type
]) {
1075 // TODO: This might not work if more than 1 recursive reference is
1076 // present in the struct.
1077 SmallVector
<uint32_t, 4> ptrOperands
;
1078 ptrOperands
.push_back(ptrInfo
.pointerTypeID
);
1079 ptrOperands
.push_back(static_cast<uint32_t>(ptrInfo
.storageClass
));
1080 ptrOperands
.push_back(typeIDMap
[type
]);
1082 if (failed(encodeInstructionInto(
1083 typesGlobalValues
, spirv::Opcode::OpTypePointer
, ptrOperands
)))
1087 recursiveStructInfos
[type
].clear();
1096 LogicalResult
Serializer::prepareBasicType(
1097 Location loc
, Type type
, uint32_t resultID
, spirv::Opcode
&typeEnum
,
1098 SmallVectorImpl
<uint32_t> &operands
, bool &deferSerialization
,
1099 llvm::SetVector
<StringRef
> &serializationCtx
) {
1100 deferSerialization
= false;
1102 if (isVoidType(type
)) {
1103 typeEnum
= spirv::Opcode::OpTypeVoid
;
1107 if (auto intType
= type
.dyn_cast
<IntegerType
>()) {
1108 if (intType
.getWidth() == 1) {
1109 typeEnum
= spirv::Opcode::OpTypeBool
;
1113 typeEnum
= spirv::Opcode::OpTypeInt
;
1114 operands
.push_back(intType
.getWidth());
1115 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
1116 // to preserve or validate.
1117 // 0 indicates unsigned, or no signedness semantics
1118 // 1 indicates signed semantics."
1119 operands
.push_back(intType
.isSigned() ? 1 : 0);
1123 if (auto floatType
= type
.dyn_cast
<FloatType
>()) {
1124 typeEnum
= spirv::Opcode::OpTypeFloat
;
1125 operands
.push_back(floatType
.getWidth());
1129 if (auto vectorType
= type
.dyn_cast
<VectorType
>()) {
1130 uint32_t elementTypeID
= 0;
1131 if (failed(processTypeImpl(loc
, vectorType
.getElementType(), elementTypeID
,
1132 serializationCtx
))) {
1135 typeEnum
= spirv::Opcode::OpTypeVector
;
1136 operands
.push_back(elementTypeID
);
1137 operands
.push_back(vectorType
.getNumElements());
1141 if (auto arrayType
= type
.dyn_cast
<spirv::ArrayType
>()) {
1142 typeEnum
= spirv::Opcode::OpTypeArray
;
1143 uint32_t elementTypeID
= 0;
1144 if (failed(processTypeImpl(loc
, arrayType
.getElementType(), elementTypeID
,
1145 serializationCtx
))) {
1148 operands
.push_back(elementTypeID
);
1149 if (auto elementCountID
= prepareConstantInt(
1150 loc
, mlirBuilder
.getI32IntegerAttr(arrayType
.getNumElements()))) {
1151 operands
.push_back(elementCountID
);
1153 return processTypeDecoration(loc
, arrayType
, resultID
);
1156 if (auto ptrType
= type
.dyn_cast
<spirv::PointerType
>()) {
1157 uint32_t pointeeTypeID
= 0;
1158 spirv::StructType pointeeStruct
=
1159 ptrType
.getPointeeType().dyn_cast
<spirv::StructType
>();
1161 if (pointeeStruct
&& pointeeStruct
.isIdentified() &&
1162 serializationCtx
.count(pointeeStruct
.getIdentifier()) != 0) {
1163 // A recursive reference to an enclosing struct is found.
1165 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
1166 // class as operands.
1167 SmallVector
<uint32_t, 2> forwardPtrOperands
;
1168 forwardPtrOperands
.push_back(resultID
);
1169 forwardPtrOperands
.push_back(
1170 static_cast<uint32_t>(ptrType
.getStorageClass()));
1172 encodeInstructionInto(typesGlobalValues
,
1173 spirv::Opcode::OpTypeForwardPointer
,
1174 forwardPtrOperands
);
1176 // 2. Find the pointee (enclosing) struct.
1177 auto structType
= spirv::StructType::getIdentified(
1178 module
.getContext(), pointeeStruct
.getIdentifier());
1183 // 3. Mark the OpTypePointer that is supposed to be emitted by this call
1185 deferSerialization
= true;
1187 // 4. Record the info needed to emit the deferred OpTypePointer
1188 // instruction when the enclosing struct is completely serialized.
1189 recursiveStructInfos
[structType
].push_back(
1190 {resultID
, ptrType
.getStorageClass()});
1192 if (failed(processTypeImpl(loc
, ptrType
.getPointeeType(), pointeeTypeID
,
1197 typeEnum
= spirv::Opcode::OpTypePointer
;
1198 operands
.push_back(static_cast<uint32_t>(ptrType
.getStorageClass()));
1199 operands
.push_back(pointeeTypeID
);
1203 if (auto runtimeArrayType
= type
.dyn_cast
<spirv::RuntimeArrayType
>()) {
1204 uint32_t elementTypeID
= 0;
1205 if (failed(processTypeImpl(loc
, runtimeArrayType
.getElementType(),
1206 elementTypeID
, serializationCtx
))) {
1209 typeEnum
= spirv::Opcode::OpTypeRuntimeArray
;
1210 operands
.push_back(elementTypeID
);
1211 return processTypeDecoration(loc
, runtimeArrayType
, resultID
);
1214 if (auto structType
= type
.dyn_cast
<spirv::StructType
>()) {
1215 if (structType
.isIdentified()) {
1216 processName(resultID
, structType
.getIdentifier());
1217 serializationCtx
.insert(structType
.getIdentifier());
1220 bool hasOffset
= structType
.hasOffset();
1221 for (auto elementIndex
:
1222 llvm::seq
<uint32_t>(0, structType
.getNumElements())) {
1223 uint32_t elementTypeID
= 0;
1224 if (failed(processTypeImpl(loc
, structType
.getElementType(elementIndex
),
1225 elementTypeID
, serializationCtx
))) {
1228 operands
.push_back(elementTypeID
);
1230 // Decorate each struct member with an offset
1231 spirv::StructType::MemberDecorationInfo offsetDecoration
{
1232 elementIndex
, /*hasValue=*/1, spirv::Decoration::Offset
,
1233 static_cast<uint32_t>(structType
.getMemberOffset(elementIndex
))};
1234 if (failed(processMemberDecoration(resultID
, offsetDecoration
))) {
1235 return emitError(loc
, "cannot decorate ")
1236 << elementIndex
<< "-th member of " << structType
1237 << " with its offset";
1241 SmallVector
<spirv::StructType::MemberDecorationInfo
, 4> memberDecorations
;
1242 structType
.getMemberDecorations(memberDecorations
);
1244 for (auto &memberDecoration
: memberDecorations
) {
1245 if (failed(processMemberDecoration(resultID
, memberDecoration
))) {
1246 return emitError(loc
, "cannot decorate ")
1247 << static_cast<uint32_t>(memberDecoration
.memberIndex
)
1248 << "-th member of " << structType
<< " with "
1249 << stringifyDecoration(memberDecoration
.decoration
);
1253 typeEnum
= spirv::Opcode::OpTypeStruct
;
1255 if (structType
.isIdentified())
1256 serializationCtx
.remove(structType
.getIdentifier());
1261 if (auto cooperativeMatrixType
=
1262 type
.dyn_cast
<spirv::CooperativeMatrixNVType
>()) {
1263 uint32_t elementTypeID
= 0;
1264 if (failed(processTypeImpl(loc
, cooperativeMatrixType
.getElementType(),
1265 elementTypeID
, serializationCtx
))) {
1268 typeEnum
= spirv::Opcode::OpTypeCooperativeMatrixNV
;
1269 auto getConstantOp
= [&](uint32_t id
) {
1270 auto attr
= IntegerAttr::get(IntegerType::get(32, type
.getContext()), id
);
1271 return prepareConstantInt(loc
, attr
);
1273 operands
.push_back(elementTypeID
);
1275 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType
.getScope())));
1276 operands
.push_back(getConstantOp(cooperativeMatrixType
.getRows()));
1277 operands
.push_back(getConstantOp(cooperativeMatrixType
.getColumns()));
1281 if (auto matrixType
= type
.dyn_cast
<spirv::MatrixType
>()) {
1282 uint32_t elementTypeID
= 0;
1283 if (failed(processTypeImpl(loc
, matrixType
.getColumnType(), elementTypeID
,
1284 serializationCtx
))) {
1287 typeEnum
= spirv::Opcode::OpTypeMatrix
;
1288 operands
.push_back(elementTypeID
);
1289 operands
.push_back(matrixType
.getNumColumns());
1293 // TODO: Handle other types.
1294 return emitError(loc
, "unhandled type in serialization: ") << type
;
1298 Serializer::prepareFunctionType(Location loc
, FunctionType type
,
1299 spirv::Opcode
&typeEnum
,
1300 SmallVectorImpl
<uint32_t> &operands
) {
1301 typeEnum
= spirv::Opcode::OpTypeFunction
;
1302 assert(type
.getNumResults() <= 1 &&
1303 "serialization supports only a single return value");
1304 uint32_t resultID
= 0;
1305 if (failed(processType(
1306 loc
, type
.getNumResults() == 1 ? type
.getResult(0) : getVoidType(),
1310 operands
.push_back(resultID
);
1311 for (auto &res
: type
.getInputs()) {
1312 uint32_t argTypeID
= 0;
1313 if (failed(processType(loc
, res
, argTypeID
))) {
1316 operands
.push_back(argTypeID
);
1321 //===----------------------------------------------------------------------===//
1323 //===----------------------------------------------------------------------===//
1325 uint32_t Serializer::prepareConstant(Location loc
, Type constType
,
1326 Attribute valueAttr
) {
1327 if (auto id
= prepareConstantScalar(loc
, valueAttr
)) {
1331 // This is a composite literal. We need to handle each component separately
1332 // and then emit an OpConstantComposite for the whole.
1334 if (auto id
= getConstantID(valueAttr
)) {
1338 uint32_t typeID
= 0;
1339 if (failed(processType(loc
, constType
, typeID
))) {
1343 uint32_t resultID
= 0;
1344 if (auto attr
= valueAttr
.dyn_cast
<DenseElementsAttr
>()) {
1345 int rank
= attr
.getType().dyn_cast
<ShapedType
>().getRank();
1346 SmallVector
<uint64_t, 4> index(rank
);
1347 resultID
= prepareDenseElementsConstant(loc
, constType
, attr
,
1349 } else if (auto arrayAttr
= valueAttr
.dyn_cast
<ArrayAttr
>()) {
1350 resultID
= prepareArrayConstant(loc
, constType
, arrayAttr
);
1353 if (resultID
== 0) {
1354 emitError(loc
, "cannot serialize attribute: ") << valueAttr
;
1358 constIDMap
[valueAttr
] = resultID
;
1362 uint32_t Serializer::prepareArrayConstant(Location loc
, Type constType
,
1364 uint32_t typeID
= 0;
1365 if (failed(processType(loc
, constType
, typeID
))) {
1369 uint32_t resultID
= getNextID();
1370 SmallVector
<uint32_t, 4> operands
= {typeID
, resultID
};
1371 operands
.reserve(attr
.size() + 2);
1372 auto elementType
= constType
.cast
<spirv::ArrayType
>().getElementType();
1373 for (Attribute elementAttr
: attr
) {
1374 if (auto elementID
= prepareConstant(loc
, elementType
, elementAttr
)) {
1375 operands
.push_back(elementID
);
1380 spirv::Opcode opcode
= spirv::Opcode::OpConstantComposite
;
1381 encodeInstructionInto(typesGlobalValues
, opcode
, operands
);
1386 // TODO: Turn the below function into iterative function, instead of
1387 // recursive function.
1389 Serializer::prepareDenseElementsConstant(Location loc
, Type constType
,
1390 DenseElementsAttr valueAttr
, int dim
,
1391 MutableArrayRef
<uint64_t> index
) {
1392 auto shapedType
= valueAttr
.getType().dyn_cast
<ShapedType
>();
1393 assert(dim
<= shapedType
.getRank());
1394 if (shapedType
.getRank() == dim
) {
1395 if (auto attr
= valueAttr
.dyn_cast
<DenseIntElementsAttr
>()) {
1396 return attr
.getType().getElementType().isInteger(1)
1397 ? prepareConstantBool(loc
, attr
.getValue
<BoolAttr
>(index
))
1398 : prepareConstantInt(loc
, attr
.getValue
<IntegerAttr
>(index
));
1400 if (auto attr
= valueAttr
.dyn_cast
<DenseFPElementsAttr
>()) {
1401 return prepareConstantFp(loc
, attr
.getValue
<FloatAttr
>(index
));
1406 uint32_t typeID
= 0;
1407 if (failed(processType(loc
, constType
, typeID
))) {
1411 uint32_t resultID
= getNextID();
1412 SmallVector
<uint32_t, 4> operands
= {typeID
, resultID
};
1413 operands
.reserve(shapedType
.getDimSize(dim
) + 2);
1414 auto elementType
= constType
.cast
<spirv::CompositeType
>().getElementType(0);
1415 for (int i
= 0; i
< shapedType
.getDimSize(dim
); ++i
) {
1417 if (auto elementID
= prepareDenseElementsConstant(
1418 loc
, elementType
, valueAttr
, dim
+ 1, index
)) {
1419 operands
.push_back(elementID
);
1424 spirv::Opcode opcode
= spirv::Opcode::OpConstantComposite
;
1425 encodeInstructionInto(typesGlobalValues
, opcode
, operands
);
1430 uint32_t Serializer::prepareConstantScalar(Location loc
, Attribute valueAttr
,
1432 if (auto floatAttr
= valueAttr
.dyn_cast
<FloatAttr
>()) {
1433 return prepareConstantFp(loc
, floatAttr
, isSpec
);
1435 if (auto boolAttr
= valueAttr
.dyn_cast
<BoolAttr
>()) {
1436 return prepareConstantBool(loc
, boolAttr
, isSpec
);
1438 if (auto intAttr
= valueAttr
.dyn_cast
<IntegerAttr
>()) {
1439 return prepareConstantInt(loc
, intAttr
, isSpec
);
1445 uint32_t Serializer::prepareConstantBool(Location loc
, BoolAttr boolAttr
,
1448 // We can de-duplicate normal constants, but not specialization constants.
1449 if (auto id
= getConstantID(boolAttr
)) {
1454 // Process the type for this bool literal
1455 uint32_t typeID
= 0;
1456 if (failed(processType(loc
, boolAttr
.getType(), typeID
))) {
1460 auto resultID
= getNextID();
1461 auto opcode
= boolAttr
.getValue()
1462 ? (isSpec
? spirv::Opcode::OpSpecConstantTrue
1463 : spirv::Opcode::OpConstantTrue
)
1464 : (isSpec
? spirv::Opcode::OpSpecConstantFalse
1465 : spirv::Opcode::OpConstantFalse
);
1466 encodeInstructionInto(typesGlobalValues
, opcode
, {typeID
, resultID
});
1469 constIDMap
[boolAttr
] = resultID
;
1474 uint32_t Serializer::prepareConstantInt(Location loc
, IntegerAttr intAttr
,
1477 // We can de-duplicate normal constants, but not specialization constants.
1478 if (auto id
= getConstantID(intAttr
)) {
1483 // Process the type for this integer literal
1484 uint32_t typeID
= 0;
1485 if (failed(processType(loc
, intAttr
.getType(), typeID
))) {
1489 auto resultID
= getNextID();
1490 APInt value
= intAttr
.getValue();
1491 unsigned bitwidth
= value
.getBitWidth();
1492 bool isSigned
= value
.isSignedIntN(bitwidth
);
1495 isSpec
? spirv::Opcode::OpSpecConstant
: spirv::Opcode::OpConstant
;
1497 // According to SPIR-V spec, "When the type's bit width is less than 32-bits,
1498 // the literal's value appears in the low-order bits of the word, and the
1499 // high-order bits must be 0 for a floating-point type, or 0 for an integer
1500 // type with Signedness of 0, or sign extended when Signedness is 1."
1501 if (bitwidth
== 32 || bitwidth
== 16) {
1504 word
= static_cast<int32_t>(value
.getSExtValue());
1506 word
= static_cast<uint32_t>(value
.getZExtValue());
1508 encodeInstructionInto(typesGlobalValues
, opcode
, {typeID
, resultID
, word
});
1510 // According to SPIR-V spec: "When the type's bit width is larger than one
1511 // word, the literal’s low-order words appear first."
1512 else if (bitwidth
== 64) {
1518 words
= llvm::bit_cast
<DoubleWord
>(value
.getSExtValue());
1520 words
= llvm::bit_cast
<DoubleWord
>(value
.getZExtValue());
1522 encodeInstructionInto(typesGlobalValues
, opcode
,
1523 {typeID
, resultID
, words
.word1
, words
.word2
});
1525 std::string valueStr
;
1526 llvm::raw_string_ostream
rss(valueStr
);
1527 value
.print(rss
, /*isSigned=*/false);
1529 emitError(loc
, "cannot serialize ")
1530 << bitwidth
<< "-bit integer literal: " << rss
.str();
1535 constIDMap
[intAttr
] = resultID
;
1540 uint32_t Serializer::prepareConstantFp(Location loc
, FloatAttr floatAttr
,
1543 // We can de-duplicate normal constants, but not specialization constants.
1544 if (auto id
= getConstantID(floatAttr
)) {
1549 // Process the type for this float literal
1550 uint32_t typeID
= 0;
1551 if (failed(processType(loc
, floatAttr
.getType(), typeID
))) {
1555 auto resultID
= getNextID();
1556 APFloat value
= floatAttr
.getValue();
1557 APInt intValue
= value
.bitcastToAPInt();
1560 isSpec
? spirv::Opcode::OpSpecConstant
: spirv::Opcode::OpConstant
;
1562 if (&value
.getSemantics() == &APFloat::IEEEsingle()) {
1563 uint32_t word
= llvm::bit_cast
<uint32_t>(value
.convertToFloat());
1564 encodeInstructionInto(typesGlobalValues
, opcode
, {typeID
, resultID
, word
});
1565 } else if (&value
.getSemantics() == &APFloat::IEEEdouble()) {
1569 } words
= llvm::bit_cast
<DoubleWord
>(value
.convertToDouble());
1570 encodeInstructionInto(typesGlobalValues
, opcode
,
1571 {typeID
, resultID
, words
.word1
, words
.word2
});
1572 } else if (&value
.getSemantics() == &APFloat::IEEEhalf()) {
1574 static_cast<uint32_t>(value
.bitcastToAPInt().getZExtValue());
1575 encodeInstructionInto(typesGlobalValues
, opcode
, {typeID
, resultID
, word
});
1577 std::string valueStr
;
1578 llvm::raw_string_ostream
rss(valueStr
);
1581 emitError(loc
, "cannot serialize ")
1582 << floatAttr
.getType() << "-typed float literal: " << rss
.str();
1587 constIDMap
[floatAttr
] = resultID
;
1592 //===----------------------------------------------------------------------===//
1594 //===----------------------------------------------------------------------===//
1596 uint32_t Serializer::getOrCreateBlockID(Block
*block
) {
1597 if (uint32_t id
= getBlockID(block
))
1599 return blockIDMap
[block
] = getNextID();
1603 Serializer::processBlock(Block
*block
, bool omitLabel
,
1604 function_ref
<void()> actionBeforeTerminator
) {
1605 LLVM_DEBUG(llvm::dbgs() << "processing block " << block
<< ":\n");
1606 LLVM_DEBUG(block
->print(llvm::dbgs()));
1607 LLVM_DEBUG(llvm::dbgs() << '\n');
1609 uint32_t blockID
= getOrCreateBlockID(block
);
1610 LLVM_DEBUG(llvm::dbgs()
1611 << "[block] " << block
<< " (id = " << blockID
<< ")\n");
1613 // Emit OpLabel for this block.
1614 encodeInstructionInto(functionBody
, spirv::Opcode::OpLabel
, {blockID
});
1617 // Emit OpPhi instructions for block arguments, if any.
1618 if (failed(emitPhiForBlockArguments(block
)))
1621 // Process each op in this block except the terminator.
1622 for (auto &op
: llvm::make_range(block
->begin(), std::prev(block
->end()))) {
1623 if (failed(processOperation(&op
)))
1627 // Process the terminator.
1628 if (actionBeforeTerminator
)
1629 actionBeforeTerminator();
1630 if (failed(processOperation(&block
->back())))
1636 LogicalResult
Serializer::emitPhiForBlockArguments(Block
*block
) {
1637 // Nothing to do if this block has no arguments or it's the entry block, which
1638 // always has the same arguments as the function signature.
1639 if (block
->args_empty() || block
->isEntryBlock())
1642 // If the block has arguments, we need to create SPIR-V OpPhi instructions.
1643 // A SPIR-V OpPhi instruction is of the syntax:
1644 // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
1645 // So we need to collect all predecessor blocks and the arguments they send
1647 SmallVector
<std::pair
<Block
*, Operation::operand_iterator
>, 4> predecessors
;
1648 for (Block
*predecessor
: block
->getPredecessors()) {
1649 auto *terminator
= predecessor
->getTerminator();
1650 // The predecessor here is the immediate one according to MLIR's IR
1651 // structure. It does not directly map to the incoming parent block for the
1652 // OpPhi instructions at SPIR-V binary level. This is because structured
1653 // control flow ops are serialized to multiple SPIR-V blocks. If there is a
1654 // spv.selection/spv.loop op in the MLIR predecessor block, the branch op
1655 // jumping to the OpPhi's block then resides in the previous structured
1656 // control flow op's merge block.
1657 predecessor
= getPhiIncomingBlock(predecessor
);
1658 if (auto branchOp
= dyn_cast
<spirv::BranchOp
>(terminator
)) {
1659 predecessors
.emplace_back(predecessor
, branchOp
.operand_begin());
1661 return terminator
->emitError("unimplemented terminator for Phi creation");
1665 // Then create OpPhi instruction for each of the block argument.
1666 for (auto argIndex
: llvm::seq
<unsigned>(0, block
->getNumArguments())) {
1667 BlockArgument arg
= block
->getArgument(argIndex
);
1669 // Get the type <id> and result <id> for this OpPhi instruction.
1670 uint32_t phiTypeID
= 0;
1671 if (failed(processType(arg
.getLoc(), arg
.getType(), phiTypeID
)))
1673 uint32_t phiID
= getNextID();
1675 LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex
<< ' '
1676 << arg
<< " (id = " << phiID
<< ")\n");
1678 // Prepare the (value <id>, parent block <id>) pairs.
1679 SmallVector
<uint32_t, 8> phiArgs
;
1680 phiArgs
.push_back(phiTypeID
);
1681 phiArgs
.push_back(phiID
);
1683 for (auto predIndex
: llvm::seq
<unsigned>(0, predecessors
.size())) {
1684 Value value
= *(predecessors
[predIndex
].second
+ argIndex
);
1685 uint32_t predBlockId
= getOrCreateBlockID(predecessors
[predIndex
].first
);
1686 LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
1687 << ") value " << value
<< ' ');
1688 // Each pair is a value <id> ...
1689 uint32_t valueId
= getValueID(value
);
1691 // The op generating this value hasn't been visited yet so we don't have
1692 // an <id> assigned yet. Record this to fix up later.
1693 LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
1694 deferredPhiValues
[value
].push_back(functionBody
.size() + 1 +
1697 LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId
<< ")\n");
1699 phiArgs
.push_back(valueId
);
1700 // ... and a parent block <id>.
1701 phiArgs
.push_back(predBlockId
);
1704 encodeInstructionInto(functionBody
, spirv::Opcode::OpPhi
, phiArgs
);
1705 valueIDMap
[arg
] = phiID
;
1711 LogicalResult
Serializer::processSelectionOp(spirv::SelectionOp selectionOp
) {
1712 // Assign <id>s to all blocks so that branches inside the SelectionOp can
1713 // resolve properly.
1714 auto &body
= selectionOp
.body();
1715 for (Block
&block
: body
)
1716 getOrCreateBlockID(&block
);
1718 auto *headerBlock
= selectionOp
.getHeaderBlock();
1719 auto *mergeBlock
= selectionOp
.getMergeBlock();
1720 auto mergeID
= getBlockID(mergeBlock
);
1721 auto loc
= selectionOp
.getLoc();
1723 // Emit the selection header block, which dominates all other blocks, first.
1724 // We need to emit an OpSelectionMerge instruction before the selection header
1725 // block's terminator.
1726 auto emitSelectionMerge
= [&]() {
1727 emitDebugLine(functionBody
, loc
);
1728 lastProcessedWasMergeInst
= true;
1729 encodeInstructionInto(
1730 functionBody
, spirv::Opcode::OpSelectionMerge
,
1731 {mergeID
, static_cast<uint32_t>(selectionOp
.selection_control())});
1733 // For structured selection, we cannot have blocks in the selection construct
1734 // branching to the selection header block. Entering the selection (and
1735 // reaching the selection header) must be from the block containing the
1736 // spv.selection op. If there are ops ahead of the spv.selection op in the
1737 // block, we can "merge" them into the selection header. So here we don't need
1738 // to emit a separate block; just continue with the existing block.
1739 if (failed(processBlock(headerBlock
, /*omitLabel=*/true, emitSelectionMerge
)))
1742 // Process all blocks with a depth-first visitor starting from the header
1743 // block. The selection header block and merge block are skipped by this
1745 if (failed(visitInPrettyBlockOrder(
1746 headerBlock
, [&](Block
*block
) { return processBlock(block
); },
1747 /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock
})))
1750 // There is nothing to do for the merge block in the selection, which just
1751 // contains a spv.mlir.merge op, itself. But we need to have an OpLabel
1752 // instruction to start a new SPIR-V block for ops following this SelectionOp.
1753 // The block should use the <id> for the merge block.
1754 return encodeInstructionInto(functionBody
, spirv::Opcode::OpLabel
, {mergeID
});
1757 LogicalResult
Serializer::processLoopOp(spirv::LoopOp loopOp
) {
1758 // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
1759 // properly. We don't need to assign for the entry block, which is just for
1760 // satisfying MLIR region's structural requirement.
1761 auto &body
= loopOp
.body();
1763 llvm::make_range(std::next(body
.begin(), 1), body
.end())) {
1764 getOrCreateBlockID(&block
);
1766 auto *headerBlock
= loopOp
.getHeaderBlock();
1767 auto *continueBlock
= loopOp
.getContinueBlock();
1768 auto *mergeBlock
= loopOp
.getMergeBlock();
1769 auto headerID
= getBlockID(headerBlock
);
1770 auto continueID
= getBlockID(continueBlock
);
1771 auto mergeID
= getBlockID(mergeBlock
);
1772 auto loc
= loopOp
.getLoc();
1774 // This LoopOp is in some MLIR block with preceding and following ops. In the
1775 // binary format, it should reside in separate SPIR-V blocks from its
1776 // preceding and following ops. So we need to emit unconditional branches to
1777 // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
1779 encodeInstructionInto(functionBody
, spirv::Opcode::OpBranch
, {headerID
});
1781 // LoopOp's entry block is just there for satisfying MLIR's structural
1782 // requirements so we omit it and start serialization from the loop header
1785 // Emit the loop header block, which dominates all other blocks, first. We
1786 // need to emit an OpLoopMerge instruction before the loop header block's
1788 auto emitLoopMerge
= [&]() {
1789 emitDebugLine(functionBody
, loc
);
1790 lastProcessedWasMergeInst
= true;
1791 encodeInstructionInto(
1792 functionBody
, spirv::Opcode::OpLoopMerge
,
1793 {mergeID
, continueID
, static_cast<uint32_t>(loopOp
.loop_control())});
1795 if (failed(processBlock(headerBlock
, /*omitLabel=*/false, emitLoopMerge
)))
1798 // Process all blocks with a depth-first visitor starting from the header
1799 // block. The loop header block, loop continue block, and loop merge block are
1800 // skipped by this visitor and handled later in this function.
1801 if (failed(visitInPrettyBlockOrder(
1802 headerBlock
, [&](Block
*block
) { return processBlock(block
); },
1803 /*skipHeader=*/true, /*skipBlocks=*/{continueBlock
, mergeBlock
})))
1806 // We have handled all other blocks. Now get to the loop continue block.
1807 if (failed(processBlock(continueBlock
)))
1810 // There is nothing to do for the merge block in the loop, which just contains
1811 // a spv.mlir.merge op, itself. But we need to have an OpLabel instruction to
1812 // start a new SPIR-V block for ops following this LoopOp. The block should
1813 // use the <id> for the merge block.
1814 return encodeInstructionInto(functionBody
, spirv::Opcode::OpLabel
, {mergeID
});
1817 LogicalResult
Serializer::processBranchConditionalOp(
1818 spirv::BranchConditionalOp condBranchOp
) {
1819 auto conditionID
= getValueID(condBranchOp
.condition());
1820 auto trueLabelID
= getOrCreateBlockID(condBranchOp
.getTrueBlock());
1821 auto falseLabelID
= getOrCreateBlockID(condBranchOp
.getFalseBlock());
1822 SmallVector
<uint32_t, 5> arguments
{conditionID
, trueLabelID
, falseLabelID
};
1824 if (auto weights
= condBranchOp
.branch_weights()) {
1825 for (auto val
: weights
->getValue())
1826 arguments
.push_back(val
.cast
<IntegerAttr
>().getInt());
1829 emitDebugLine(functionBody
, condBranchOp
.getLoc());
1830 return encodeInstructionInto(functionBody
, spirv::Opcode::OpBranchConditional
,
1834 LogicalResult
Serializer::processBranchOp(spirv::BranchOp branchOp
) {
1835 emitDebugLine(functionBody
, branchOp
.getLoc());
1836 return encodeInstructionInto(functionBody
, spirv::Opcode::OpBranch
,
1837 {getOrCreateBlockID(branchOp
.getTarget())});
1840 //===----------------------------------------------------------------------===//
1842 //===----------------------------------------------------------------------===//
1844 LogicalResult
Serializer::encodeExtensionInstruction(
1845 Operation
*op
, StringRef extensionSetName
, uint32_t extensionOpcode
,
1846 ArrayRef
<uint32_t> operands
) {
1847 // Check if the extension has been imported.
1848 auto &setID
= extendedInstSetIDMap
[extensionSetName
];
1850 setID
= getNextID();
1851 SmallVector
<uint32_t, 16> importOperands
;
1852 importOperands
.push_back(setID
);
1854 spirv::encodeStringLiteralInto(importOperands
, extensionSetName
)) ||
1855 failed(encodeInstructionInto(
1856 extendedSets
, spirv::Opcode::OpExtInstImport
, importOperands
))) {
1861 // The first two operands are the result type <id> and result <id>. The set
1862 // <id> and the opcode need to be insert after this.
1863 if (operands
.size() < 2) {
1864 return op
->emitError("extended instructions must have a result encoding");
1866 SmallVector
<uint32_t, 8> extInstOperands
;
1867 extInstOperands
.reserve(operands
.size() + 2);
1868 extInstOperands
.append(operands
.begin(), std::next(operands
.begin(), 2));
1869 extInstOperands
.push_back(setID
);
1870 extInstOperands
.push_back(extensionOpcode
);
1871 extInstOperands
.append(std::next(operands
.begin(), 2), operands
.end());
1872 return encodeInstructionInto(functionBody
, spirv::Opcode::OpExtInst
,
1876 LogicalResult
Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp
) {
1877 auto varName
= addressOfOp
.variable();
1878 auto variableID
= getVariableID(varName
);
1880 return addressOfOp
.emitError("unknown result <id> for variable ")
1883 valueIDMap
[addressOfOp
.pointer()] = variableID
;
1888 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp
) {
1889 auto constName
= referenceOfOp
.spec_const();
1890 auto constID
= getSpecConstID(constName
);
1892 return referenceOfOp
.emitError(
1893 "unknown result <id> for specialization constant ")
1896 valueIDMap
[referenceOfOp
.reference()] = constID
;
1900 LogicalResult
Serializer::processOperation(Operation
*opInst
) {
1901 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst
->getName() << "'\n");
1903 // First dispatch the ops that do not directly mirror an instruction from
1905 return TypeSwitch
<Operation
*, LogicalResult
>(opInst
)
1906 .Case([&](spirv::AddressOfOp op
) { return processAddressOfOp(op
); })
1907 .Case([&](spirv::BranchOp op
) { return processBranchOp(op
); })
1908 .Case([&](spirv::BranchConditionalOp op
) {
1909 return processBranchConditionalOp(op
);
1911 .Case([&](spirv::ConstantOp op
) { return processConstantOp(op
); })
1912 .Case([&](spirv::FuncOp op
) { return processFuncOp(op
); })
1913 .Case([&](spirv::GlobalVariableOp op
) {
1914 return processGlobalVariableOp(op
);
1916 .Case([&](spirv::LoopOp op
) { return processLoopOp(op
); })
1917 .Case([&](spirv::ModuleEndOp
) { return success(); })
1918 .Case([&](spirv::ReferenceOfOp op
) { return processReferenceOfOp(op
); })
1919 .Case([&](spirv::SelectionOp op
) { return processSelectionOp(op
); })
1920 .Case([&](spirv::SpecConstantOp op
) { return processSpecConstantOp(op
); })
1921 .Case([&](spirv::SpecConstantCompositeOp op
) {
1922 return processSpecConstantCompositeOp(op
);
1924 .Case([&](spirv::UndefOp op
) { return processUndefOp(op
); })
1925 .Case([&](spirv::VariableOp op
) { return processVariableOp(op
); })
1927 // Then handle all the ops that directly mirror SPIR-V instructions with
1928 // auto-generated methods.
1930 [&](Operation
*op
) { return dispatchToAutogenSerialization(op
); });
1936 Serializer::processOp
<spirv::EntryPointOp
>(spirv::EntryPointOp op
) {
1937 SmallVector
<uint32_t, 4> operands
;
1938 // Add the ExecutionModel.
1939 operands
.push_back(static_cast<uint32_t>(op
.execution_model()));
1940 // Add the function <id>.
1941 auto funcID
= getFunctionID(op
.fn());
1943 return op
.emitError("missing <id> for function ")
1945 << "; function needs to be defined before spv.EntryPoint is "
1948 operands
.push_back(funcID
);
1949 // Add the name of the function.
1950 spirv::encodeStringLiteralInto(operands
, op
.fn());
1952 // Add the interface values.
1953 if (auto interface
= op
.interface()) {
1954 for (auto var
: interface
.getValue()) {
1955 auto id
= getVariableID(var
.cast
<FlatSymbolRefAttr
>().getValue());
1957 return op
.emitError("referencing undefined global variable."
1958 "spv.EntryPoint is at the end of spv.module. All "
1959 "referenced variables should already be defined");
1961 operands
.push_back(id
);
1964 return encodeInstructionInto(entryPoints
, spirv::Opcode::OpEntryPoint
,
1970 Serializer::processOp
<spirv::ControlBarrierOp
>(spirv::ControlBarrierOp op
) {
1971 StringRef argNames
[] = {"execution_scope", "memory_scope",
1972 "memory_semantics"};
1973 SmallVector
<uint32_t, 3> operands
;
1975 for (auto argName
: argNames
) {
1976 auto argIntAttr
= op
->getAttrOfType
<IntegerAttr
>(argName
);
1977 auto operand
= prepareConstantInt(op
.getLoc(), argIntAttr
);
1981 operands
.push_back(operand
);
1984 return encodeInstructionInto(functionBody
, spirv::Opcode::OpControlBarrier
,
1990 Serializer::processOp
<spirv::ExecutionModeOp
>(spirv::ExecutionModeOp op
) {
1991 SmallVector
<uint32_t, 4> operands
;
1992 // Add the function <id>.
1993 auto funcID
= getFunctionID(op
.fn());
1995 return op
.emitError("missing <id> for function ")
1997 << "; function needs to be serialized before ExecutionModeOp is "
2000 operands
.push_back(funcID
);
2001 // Add the ExecutionMode.
2002 operands
.push_back(static_cast<uint32_t>(op
.execution_mode()));
2004 // Serialize values if any.
2005 auto values
= op
.values();
2007 for (auto &intVal
: values
.getValue()) {
2008 operands
.push_back(static_cast<uint32_t>(
2009 intVal
.cast
<IntegerAttr
>().getValue().getZExtValue()));
2012 return encodeInstructionInto(executionModes
, spirv::Opcode::OpExecutionMode
,
2018 Serializer::processOp
<spirv::MemoryBarrierOp
>(spirv::MemoryBarrierOp op
) {
2019 StringRef argNames
[] = {"memory_scope", "memory_semantics"};
2020 SmallVector
<uint32_t, 2> operands
;
2022 for (auto argName
: argNames
) {
2023 auto argIntAttr
= op
->getAttrOfType
<IntegerAttr
>(argName
);
2024 auto operand
= prepareConstantInt(op
.getLoc(), argIntAttr
);
2028 operands
.push_back(operand
);
2031 return encodeInstructionInto(functionBody
, spirv::Opcode::OpMemoryBarrier
,
2037 Serializer::processOp
<spirv::FunctionCallOp
>(spirv::FunctionCallOp op
) {
2038 auto funcName
= op
.callee();
2039 uint32_t resTypeID
= 0;
2041 Type resultTy
= op
.getNumResults() ? *op
.result_type_begin() : getVoidType();
2042 if (failed(processType(op
.getLoc(), resultTy
, resTypeID
)))
2045 auto funcID
= getOrCreateFunctionID(funcName
);
2046 auto funcCallID
= getNextID();
2047 SmallVector
<uint32_t, 8> operands
{resTypeID
, funcCallID
, funcID
};
2049 for (auto value
: op
.arguments()) {
2050 auto valueID
= getValueID(value
);
2051 assert(valueID
&& "cannot find a value for spv.FunctionCall");
2052 operands
.push_back(valueID
);
2055 if (!resultTy
.isa
<NoneType
>())
2056 valueIDMap
[op
.getResult(0)] = funcCallID
;
2058 return encodeInstructionInto(functionBody
, spirv::Opcode::OpFunctionCall
,
2064 Serializer::processOp
<spirv::CopyMemoryOp
>(spirv::CopyMemoryOp op
) {
2065 SmallVector
<uint32_t, 4> operands
;
2066 SmallVector
<StringRef
, 2> elidedAttrs
;
2068 for (Value operand
: op
->getOperands()) {
2069 auto id
= getValueID(operand
);
2070 assert(id
&& "use before def!");
2071 operands
.push_back(id
);
2074 if (auto attr
= op
->getAttr("memory_access")) {
2075 operands
.push_back(static_cast<uint32_t>(
2076 attr
.cast
<IntegerAttr
>().getValue().getZExtValue()));
2079 elidedAttrs
.push_back("memory_access");
2081 if (auto attr
= op
->getAttr("alignment")) {
2082 operands
.push_back(static_cast<uint32_t>(
2083 attr
.cast
<IntegerAttr
>().getValue().getZExtValue()));
2086 elidedAttrs
.push_back("alignment");
2088 if (auto attr
= op
->getAttr("source_memory_access")) {
2089 operands
.push_back(static_cast<uint32_t>(
2090 attr
.cast
<IntegerAttr
>().getValue().getZExtValue()));
2093 elidedAttrs
.push_back("source_memory_access");
2095 if (auto attr
= op
->getAttr("source_alignment")) {
2096 operands
.push_back(static_cast<uint32_t>(
2097 attr
.cast
<IntegerAttr
>().getValue().getZExtValue()));
2100 elidedAttrs
.push_back("source_alignment");
2101 emitDebugLine(functionBody
, op
.getLoc());
2102 encodeInstructionInto(functionBody
, spirv::Opcode::OpCopyMemory
, operands
);
2107 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
2108 // various Serializer::processOp<...>() specializations.
2109 #define GET_SERIALIZATION_FNS
2110 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
2113 LogicalResult
Serializer::emitDecoration(uint32_t target
,
2114 spirv::Decoration decoration
,
2115 ArrayRef
<uint32_t> params
) {
2116 uint32_t wordCount
= 3 + params
.size();
2117 decorations
.push_back(
2118 spirv::getPrefixedOpcode(wordCount
, spirv::Opcode::OpDecorate
));
2119 decorations
.push_back(target
);
2120 decorations
.push_back(static_cast<uint32_t>(decoration
));
2121 decorations
.append(params
.begin(), params
.end());
2125 LogicalResult
Serializer::emitDebugLine(SmallVectorImpl
<uint32_t> &binary
,
2130 if (lastProcessedWasMergeInst
) {
2131 lastProcessedWasMergeInst
= false;
2135 auto fileLoc
= loc
.dyn_cast
<FileLineColLoc
>();
2137 encodeInstructionInto(binary
, spirv::Opcode::OpLine
,
2138 {fileID
, fileLoc
.getLine(), fileLoc
.getColumn()});
2143 LogicalResult
spirv::serialize(spirv::ModuleOp module
,
2144 SmallVectorImpl
<uint32_t> &binary
,
2145 bool emitDebugInfo
) {
2146 if (!module
.vce_triple().hasValue())
2147 return module
.emitError(
2148 "module must have 'vce_triple' attribute to be serializeable");
2150 Serializer
serializer(module
, emitDebugInfo
);
2152 if (failed(serializer
.serialize()))
2155 LLVM_DEBUG(serializer
.printValueIDMap(llvm::dbgs()));
2157 serializer
.collect(binary
);