//===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Deserializer methods for SPIR-V binary instructions.
//
//===----------------------------------------------------------------------===//
#include "Deserializer.h"

#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"

#include <optional>

using namespace mlir;

#define DEBUG_TYPE "spirv-deserialization"
//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//
33 /// Extracts the opcode from the given first word of a SPIR-V instruction.
34 static inline spirv::Opcode
extractOpcode(uint32_t word
) {
35 return static_cast<spirv::Opcode
>(word
& 0xffff);
//===----------------------------------------------------------------------===//
// Instruction
//===----------------------------------------------------------------------===//
42 Value
spirv::Deserializer::getValue(uint32_t id
) {
43 if (auto constInfo
= getConstant(id
)) {
44 // Materialize a `spirv.Constant` op at every use site.
45 return opBuilder
.create
<spirv::ConstantOp
>(unknownLoc
, constInfo
->second
,
48 if (auto varOp
= getGlobalVariable(id
)) {
49 auto addressOfOp
= opBuilder
.create
<spirv::AddressOfOp
>(
50 unknownLoc
, varOp
.getType(), SymbolRefAttr::get(varOp
.getOperation()));
51 return addressOfOp
.getPointer();
53 if (auto constOp
= getSpecConstant(id
)) {
54 auto referenceOfOp
= opBuilder
.create
<spirv::ReferenceOfOp
>(
55 unknownLoc
, constOp
.getDefaultValue().getType(),
56 SymbolRefAttr::get(constOp
.getOperation()));
57 return referenceOfOp
.getReference();
59 if (auto constCompositeOp
= getSpecConstantComposite(id
)) {
60 auto referenceOfOp
= opBuilder
.create
<spirv::ReferenceOfOp
>(
61 unknownLoc
, constCompositeOp
.getType(),
62 SymbolRefAttr::get(constCompositeOp
.getOperation()));
63 return referenceOfOp
.getReference();
65 if (auto specConstOperationInfo
= getSpecConstantOperation(id
)) {
66 return materializeSpecConstantOperation(
67 id
, specConstOperationInfo
->enclodesOpcode
,
68 specConstOperationInfo
->resultTypeID
,
69 specConstOperationInfo
->enclosedOpOperands
);
71 if (auto undef
= getUndefType(id
)) {
72 return opBuilder
.create
<spirv::UndefOp
>(unknownLoc
, undef
);
74 return valueMap
.lookup(id
);
77 LogicalResult
spirv::Deserializer::sliceInstruction(
78 spirv::Opcode
&opcode
, ArrayRef
<uint32_t> &operands
,
79 std::optional
<spirv::Opcode
> expectedOpcode
) {
80 auto binarySize
= binary
.size();
81 if (curOffset
>= binarySize
) {
82 return emitError(unknownLoc
, "expected ")
83 << (expectedOpcode
? spirv::stringifyOpcode(*expectedOpcode
)
88 // For each instruction, get its word count from the first word to slice it
89 // from the stream properly, and then dispatch to the instruction handler.
91 uint32_t wordCount
= binary
[curOffset
] >> 16;
94 return emitError(unknownLoc
, "word count cannot be zero");
96 uint32_t nextOffset
= curOffset
+ wordCount
;
97 if (nextOffset
> binarySize
)
98 return emitError(unknownLoc
, "insufficient words for the last instruction");
100 opcode
= extractOpcode(binary
[curOffset
]);
101 operands
= binary
.slice(curOffset
+ 1, wordCount
- 1);
102 curOffset
= nextOffset
;
106 LogicalResult
spirv::Deserializer::processInstruction(
107 spirv::Opcode opcode
, ArrayRef
<uint32_t> operands
, bool deferInstructions
) {
108 LLVM_DEBUG(logger
.startLine() << "[inst] processing instruction "
109 << spirv::stringifyOpcode(opcode
) << "\n");
111 // First dispatch all the instructions whose opcode does not correspond to
112 // those that have a direct mirror in the SPIR-V dialect
114 case spirv::Opcode::OpCapability
:
115 return processCapability(operands
);
116 case spirv::Opcode::OpExtension
:
117 return processExtension(operands
);
118 case spirv::Opcode::OpExtInst
:
119 return processExtInst(operands
);
120 case spirv::Opcode::OpExtInstImport
:
121 return processExtInstImport(operands
);
122 case spirv::Opcode::OpMemberName
:
123 return processMemberName(operands
);
124 case spirv::Opcode::OpMemoryModel
:
125 return processMemoryModel(operands
);
126 case spirv::Opcode::OpEntryPoint
:
127 case spirv::Opcode::OpExecutionMode
:
128 if (deferInstructions
) {
129 deferredInstructions
.emplace_back(opcode
, operands
);
133 case spirv::Opcode::OpVariable
:
134 if (isa
<spirv::ModuleOp
>(opBuilder
.getBlock()->getParentOp())) {
135 return processGlobalVariable(operands
);
138 case spirv::Opcode::OpLine
:
139 return processDebugLine(operands
);
140 case spirv::Opcode::OpNoLine
:
143 case spirv::Opcode::OpName
:
144 return processName(operands
);
145 case spirv::Opcode::OpString
:
146 return processDebugString(operands
);
147 case spirv::Opcode::OpModuleProcessed
:
148 case spirv::Opcode::OpSource
:
149 case spirv::Opcode::OpSourceContinued
:
150 case spirv::Opcode::OpSourceExtension
:
151 // TODO: This is debug information embedded in the binary which should be
152 // translated into the spirv.module.
154 case spirv::Opcode::OpTypeVoid
:
155 case spirv::Opcode::OpTypeBool
:
156 case spirv::Opcode::OpTypeInt
:
157 case spirv::Opcode::OpTypeFloat
:
158 case spirv::Opcode::OpTypeVector
:
159 case spirv::Opcode::OpTypeMatrix
:
160 case spirv::Opcode::OpTypeArray
:
161 case spirv::Opcode::OpTypeFunction
:
162 case spirv::Opcode::OpTypeImage
:
163 case spirv::Opcode::OpTypeSampledImage
:
164 case spirv::Opcode::OpTypeRuntimeArray
:
165 case spirv::Opcode::OpTypeStruct
:
166 case spirv::Opcode::OpTypePointer
:
167 case spirv::Opcode::OpTypeCooperativeMatrixKHR
:
168 return processType(opcode
, operands
);
169 case spirv::Opcode::OpTypeForwardPointer
:
170 return processTypeForwardPointer(operands
);
171 case spirv::Opcode::OpTypeJointMatrixINTEL
:
172 return processType(opcode
, operands
);
173 case spirv::Opcode::OpConstant
:
174 return processConstant(operands
, /*isSpec=*/false);
175 case spirv::Opcode::OpSpecConstant
:
176 return processConstant(operands
, /*isSpec=*/true);
177 case spirv::Opcode::OpConstantComposite
:
178 return processConstantComposite(operands
);
179 case spirv::Opcode::OpSpecConstantComposite
:
180 return processSpecConstantComposite(operands
);
181 case spirv::Opcode::OpSpecConstantOp
:
182 return processSpecConstantOperation(operands
);
183 case spirv::Opcode::OpConstantTrue
:
184 return processConstantBool(/*isTrue=*/true, operands
, /*isSpec=*/false);
185 case spirv::Opcode::OpSpecConstantTrue
:
186 return processConstantBool(/*isTrue=*/true, operands
, /*isSpec=*/true);
187 case spirv::Opcode::OpConstantFalse
:
188 return processConstantBool(/*isTrue=*/false, operands
, /*isSpec=*/false);
189 case spirv::Opcode::OpSpecConstantFalse
:
190 return processConstantBool(/*isTrue=*/false, operands
, /*isSpec=*/true);
191 case spirv::Opcode::OpConstantNull
:
192 return processConstantNull(operands
);
193 case spirv::Opcode::OpDecorate
:
194 return processDecoration(operands
);
195 case spirv::Opcode::OpMemberDecorate
:
196 return processMemberDecoration(operands
);
197 case spirv::Opcode::OpFunction
:
198 return processFunction(operands
);
199 case spirv::Opcode::OpLabel
:
200 return processLabel(operands
);
201 case spirv::Opcode::OpBranch
:
202 return processBranch(operands
);
203 case spirv::Opcode::OpBranchConditional
:
204 return processBranchConditional(operands
);
205 case spirv::Opcode::OpSelectionMerge
:
206 return processSelectionMerge(operands
);
207 case spirv::Opcode::OpLoopMerge
:
208 return processLoopMerge(operands
);
209 case spirv::Opcode::OpPhi
:
210 return processPhi(operands
);
211 case spirv::Opcode::OpUndef
:
212 return processUndef(operands
);
216 return dispatchToAutogenDeserialization(opcode
, operands
);
219 LogicalResult
spirv::Deserializer::processOpWithoutGrammarAttr(
220 ArrayRef
<uint32_t> words
, StringRef opName
, bool hasResult
,
221 unsigned numOperands
) {
222 SmallVector
<Type
, 1> resultTypes
;
223 uint32_t valueID
= 0;
225 size_t wordIndex
= 0;
227 if (wordIndex
>= words
.size())
228 return emitError(unknownLoc
,
229 "expected result type <id> while deserializing for ")
232 // Decode the type <id>
233 auto type
= getType(words
[wordIndex
]);
235 return emitError(unknownLoc
, "unknown type result <id>: ")
237 resultTypes
.push_back(type
);
240 // Decode the result <id>
241 if (wordIndex
>= words
.size())
242 return emitError(unknownLoc
,
243 "expected result <id> while deserializing for ")
245 valueID
= words
[wordIndex
];
249 SmallVector
<Value
, 4> operands
;
250 SmallVector
<NamedAttribute
, 4> attributes
;
253 size_t operandIndex
= 0;
254 for (; operandIndex
< numOperands
&& wordIndex
< words
.size();
255 ++operandIndex
, ++wordIndex
) {
256 auto arg
= getValue(words
[wordIndex
]);
258 return emitError(unknownLoc
, "unknown result <id>: ") << words
[wordIndex
];
259 operands
.push_back(arg
);
261 if (operandIndex
!= numOperands
) {
264 "found less operands than expected when deserializing for ")
265 << opName
<< "; only " << operandIndex
<< " of " << numOperands
268 if (wordIndex
!= words
.size()) {
271 "found more operands than expected when deserializing for ")
272 << opName
<< "; only " << wordIndex
<< " of " << words
.size()
276 // Attach attributes from decorations
277 if (decorations
.count(valueID
)) {
278 auto attrs
= decorations
[valueID
].getAttrs();
279 attributes
.append(attrs
.begin(), attrs
.end());
282 // Create the op and update bookkeeping maps
283 Location loc
= createFileLineColLoc(opBuilder
);
284 OperationState
opState(loc
, opName
);
285 opState
.addOperands(operands
);
287 opState
.addTypes(resultTypes
);
288 opState
.addAttributes(attributes
);
289 Operation
*op
= opBuilder
.create(opState
);
291 valueMap
[valueID
] = op
->getResult(0);
293 if (op
->hasTrait
<OpTrait::IsTerminator
>())
299 LogicalResult
spirv::Deserializer::processUndef(ArrayRef
<uint32_t> operands
) {
300 if (operands
.size() != 2) {
301 return emitError(unknownLoc
, "OpUndef instruction must have two operands");
303 auto type
= getType(operands
[0]);
305 return emitError(unknownLoc
, "unknown type <id> with OpUndef instruction");
307 undefMap
[operands
[1]] = type
;
311 LogicalResult
spirv::Deserializer::processExtInst(ArrayRef
<uint32_t> operands
) {
312 if (operands
.size() < 4) {
313 return emitError(unknownLoc
,
314 "OpExtInst must have at least 4 operands, result type "
315 "<id>, result <id>, set <id> and instruction opcode");
317 if (!extendedInstSets
.count(operands
[2])) {
318 return emitError(unknownLoc
, "undefined set <id> in OpExtInst");
320 SmallVector
<uint32_t, 4> slicedOperands
;
321 slicedOperands
.append(operands
.begin(), std::next(operands
.begin(), 2));
322 slicedOperands
.append(std::next(operands
.begin(), 4), operands
.end());
323 return dispatchToExtensionSetAutogenDeserialization(
324 extendedInstSets
[operands
[2]], operands
[3], slicedOperands
);
332 Deserializer::processOp
<spirv::EntryPointOp
>(ArrayRef
<uint32_t> words
) {
333 unsigned wordIndex
= 0;
334 if (wordIndex
>= words
.size()) {
335 return emitError(unknownLoc
,
336 "missing Execution Model specification in OpEntryPoint");
338 auto execModel
= spirv::ExecutionModelAttr::get(
339 context
, static_cast<spirv::ExecutionModel
>(words
[wordIndex
++]));
340 if (wordIndex
>= words
.size()) {
341 return emitError(unknownLoc
, "missing <id> in OpEntryPoint");
343 // Get the function <id>
344 auto fnID
= words
[wordIndex
++];
345 // Get the function name
346 auto fnName
= decodeStringLiteral(words
, wordIndex
);
347 // Verify that the function <id> matches the fnName
348 auto parsedFunc
= getFunction(fnID
);
350 return emitError(unknownLoc
, "no function matching <id> ") << fnID
;
352 if (parsedFunc
.getName() != fnName
) {
353 // The deserializer uses "spirv_fn_<id>" as the function name if the input
354 // SPIR-V blob does not contain a name for it. We should use a more clear
355 // indication for such case rather than relying on naming details.
356 if (!parsedFunc
.getName().starts_with("spirv_fn_"))
357 return emitError(unknownLoc
,
358 "function name mismatch between OpEntryPoint "
359 "and OpFunction with <id> ")
360 << fnID
<< ": " << fnName
<< " vs. " << parsedFunc
.getName();
361 parsedFunc
.setName(fnName
);
363 SmallVector
<Attribute
, 4> interface
;
364 while (wordIndex
< words
.size()) {
365 auto arg
= getGlobalVariable(words
[wordIndex
]);
367 return emitError(unknownLoc
, "undefined result <id> ")
368 << words
[wordIndex
] << " while decoding OpEntryPoint";
370 interface
.push_back(SymbolRefAttr::get(arg
.getOperation()));
373 opBuilder
.create
<spirv::EntryPointOp
>(
374 unknownLoc
, execModel
, SymbolRefAttr::get(opBuilder
.getContext(), fnName
),
375 opBuilder
.getArrayAttr(interface
));
381 Deserializer::processOp
<spirv::ExecutionModeOp
>(ArrayRef
<uint32_t> words
) {
382 unsigned wordIndex
= 0;
383 if (wordIndex
>= words
.size()) {
384 return emitError(unknownLoc
,
385 "missing function result <id> in OpExecutionMode");
387 // Get the function <id> to get the name of the function
388 auto fnID
= words
[wordIndex
++];
389 auto fn
= getFunction(fnID
);
391 return emitError(unknownLoc
, "no function matching <id> ") << fnID
;
393 // Get the Execution mode
394 if (wordIndex
>= words
.size()) {
395 return emitError(unknownLoc
, "missing Execution Mode in OpExecutionMode");
397 auto execMode
= spirv::ExecutionModeAttr::get(
398 context
, static_cast<spirv::ExecutionMode
>(words
[wordIndex
++]));
401 SmallVector
<Attribute
, 4> attrListElems
;
402 while (wordIndex
< words
.size()) {
403 attrListElems
.push_back(opBuilder
.getI32IntegerAttr(words
[wordIndex
++]));
405 auto values
= opBuilder
.getArrayAttr(attrListElems
);
406 opBuilder
.create
<spirv::ExecutionModeOp
>(
407 unknownLoc
, SymbolRefAttr::get(opBuilder
.getContext(), fn
.getName()),
414 Deserializer::processOp
<spirv::FunctionCallOp
>(ArrayRef
<uint32_t> operands
) {
415 if (operands
.size() < 3) {
416 return emitError(unknownLoc
,
417 "OpFunctionCall must have at least 3 operands");
420 Type resultType
= getType(operands
[0]);
422 return emitError(unknownLoc
, "undefined result type from <id> ")
426 // Use null type to mean no result type.
427 if (isVoidType(resultType
))
428 resultType
= nullptr;
430 auto resultID
= operands
[1];
431 auto functionID
= operands
[2];
433 auto functionName
= getFunctionSymbol(functionID
);
435 SmallVector
<Value
, 4> arguments
;
436 for (auto operand
: llvm::drop_begin(operands
, 3)) {
437 auto value
= getValue(operand
);
439 return emitError(unknownLoc
, "unknown <id> ")
440 << operand
<< " used by OpFunctionCall";
442 arguments
.push_back(value
);
445 auto opFunctionCall
= opBuilder
.create
<spirv::FunctionCallOp
>(
446 unknownLoc
, resultType
,
447 SymbolRefAttr::get(opBuilder
.getContext(), functionName
), arguments
);
450 valueMap
[resultID
] = opFunctionCall
.getResult(0);
456 Deserializer::processOp
<spirv::CopyMemoryOp
>(ArrayRef
<uint32_t> words
) {
457 SmallVector
<Type
, 1> resultTypes
;
458 size_t wordIndex
= 0;
459 SmallVector
<Value
, 4> operands
;
460 SmallVector
<NamedAttribute
, 4> attributes
;
462 if (wordIndex
< words
.size()) {
463 auto arg
= getValue(words
[wordIndex
]);
466 return emitError(unknownLoc
, "unknown result <id> : ")
470 operands
.push_back(arg
);
474 if (wordIndex
< words
.size()) {
475 auto arg
= getValue(words
[wordIndex
]);
478 return emitError(unknownLoc
, "unknown result <id> : ")
482 operands
.push_back(arg
);
486 bool isAlignedAttr
= false;
488 if (wordIndex
< words
.size()) {
489 auto attrValue
= words
[wordIndex
++];
490 auto attr
= opBuilder
.getAttr
<spirv::MemoryAccessAttr
>(
491 static_cast<spirv::MemoryAccess
>(attrValue
));
492 attributes
.push_back(
493 opBuilder
.getNamedAttr(attributeName
<MemoryAccess
>(), attr
));
494 isAlignedAttr
= (attrValue
== 2);
497 if (isAlignedAttr
&& wordIndex
< words
.size()) {
498 attributes
.push_back(opBuilder
.getNamedAttr(
499 "alignment", opBuilder
.getI32IntegerAttr(words
[wordIndex
++])));
502 if (wordIndex
< words
.size()) {
503 auto attrValue
= words
[wordIndex
++];
504 auto attr
= opBuilder
.getAttr
<spirv::MemoryAccessAttr
>(
505 static_cast<spirv::MemoryAccess
>(attrValue
));
506 attributes
.push_back(opBuilder
.getNamedAttr("source_memory_access", attr
));
509 if (wordIndex
< words
.size()) {
510 attributes
.push_back(opBuilder
.getNamedAttr(
511 "source_alignment", opBuilder
.getI32IntegerAttr(words
[wordIndex
++])));
514 if (wordIndex
!= words
.size()) {
515 return emitError(unknownLoc
,
516 "found more operands than expected when deserializing "
517 "spirv::CopyMemoryOp, only ")
518 << wordIndex
<< " of " << words
.size() << " processed";
521 Location loc
= createFileLineColLoc(opBuilder
);
522 opBuilder
.create
<spirv::CopyMemoryOp
>(loc
, resultTypes
, operands
, attributes
);
528 LogicalResult
Deserializer::processOp
<spirv::GenericCastToPtrExplicitOp
>(
529 ArrayRef
<uint32_t> words
) {
530 if (words
.size() != 4) {
531 return emitError(unknownLoc
,
532 "expected 4 words in GenericCastToPtrExplicitOp"
536 SmallVector
<Type
, 1> resultTypes
;
537 SmallVector
<Value
, 4> operands
;
538 uint32_t valueID
= 0;
539 auto type
= getType(words
[0]);
542 return emitError(unknownLoc
, "unknown type result <id> : ") << words
[0];
543 resultTypes
.push_back(type
);
547 auto arg
= getValue(words
[2]);
549 return emitError(unknownLoc
, "unknown result <id> : ") << words
[2];
550 operands
.push_back(arg
);
552 Location loc
= createFileLineColLoc(opBuilder
);
553 Operation
*op
= opBuilder
.create
<spirv::GenericCastToPtrExplicitOp
>(
554 loc
, resultTypes
, operands
);
555 valueMap
[valueID
] = op
->getResult(0);
559 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
560 // various Deserializer::processOp<...>() specializations.
561 #define GET_DESERIALIZATION_FNS
562 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"