//===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Deserializer methods for SPIR-V binary instructions.
//
//===----------------------------------------------------------------------===//
#include "Deserializer.h"

#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include <optional>

using namespace mlir;

#define DEBUG_TYPE "spirv-deserialization"
//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//
33 /// Extracts the opcode from the given first word of a SPIR-V instruction.
34 static inline spirv::Opcode
extractOpcode(uint32_t word
) {
35 return static_cast<spirv::Opcode
>(word
& 0xffff);
//===----------------------------------------------------------------------===//
// Instruction
//===----------------------------------------------------------------------===//
42 Value
spirv::Deserializer::getValue(uint32_t id
) {
43 if (auto constInfo
= getConstant(id
)) {
44 // Materialize a `spirv.Constant` op at every use site.
45 return opBuilder
.create
<spirv::ConstantOp
>(unknownLoc
, constInfo
->second
,
48 if (auto varOp
= getGlobalVariable(id
)) {
49 auto addressOfOp
= opBuilder
.create
<spirv::AddressOfOp
>(
50 unknownLoc
, varOp
.getType(), SymbolRefAttr::get(varOp
.getOperation()));
51 return addressOfOp
.getPointer();
53 if (auto constOp
= getSpecConstant(id
)) {
54 auto referenceOfOp
= opBuilder
.create
<spirv::ReferenceOfOp
>(
55 unknownLoc
, constOp
.getDefaultValue().getType(),
56 SymbolRefAttr::get(constOp
.getOperation()));
57 return referenceOfOp
.getReference();
59 if (auto constCompositeOp
= getSpecConstantComposite(id
)) {
60 auto referenceOfOp
= opBuilder
.create
<spirv::ReferenceOfOp
>(
61 unknownLoc
, constCompositeOp
.getType(),
62 SymbolRefAttr::get(constCompositeOp
.getOperation()));
63 return referenceOfOp
.getReference();
65 if (auto specConstOperationInfo
= getSpecConstantOperation(id
)) {
66 return materializeSpecConstantOperation(
67 id
, specConstOperationInfo
->enclodesOpcode
,
68 specConstOperationInfo
->resultTypeID
,
69 specConstOperationInfo
->enclosedOpOperands
);
71 if (auto undef
= getUndefType(id
)) {
72 return opBuilder
.create
<spirv::UndefOp
>(unknownLoc
, undef
);
74 return valueMap
.lookup(id
);
77 LogicalResult
spirv::Deserializer::sliceInstruction(
78 spirv::Opcode
&opcode
, ArrayRef
<uint32_t> &operands
,
79 std::optional
<spirv::Opcode
> expectedOpcode
) {
80 auto binarySize
= binary
.size();
81 if (curOffset
>= binarySize
) {
82 return emitError(unknownLoc
, "expected ")
83 << (expectedOpcode
? spirv::stringifyOpcode(*expectedOpcode
)
88 // For each instruction, get its word count from the first word to slice it
89 // from the stream properly, and then dispatch to the instruction handler.
91 uint32_t wordCount
= binary
[curOffset
] >> 16;
94 return emitError(unknownLoc
, "word count cannot be zero");
96 uint32_t nextOffset
= curOffset
+ wordCount
;
97 if (nextOffset
> binarySize
)
98 return emitError(unknownLoc
, "insufficient words for the last instruction");
100 opcode
= extractOpcode(binary
[curOffset
]);
101 operands
= binary
.slice(curOffset
+ 1, wordCount
- 1);
102 curOffset
= nextOffset
;
106 LogicalResult
spirv::Deserializer::processInstruction(
107 spirv::Opcode opcode
, ArrayRef
<uint32_t> operands
, bool deferInstructions
) {
108 LLVM_DEBUG(logger
.startLine() << "[inst] processing instruction "
109 << spirv::stringifyOpcode(opcode
) << "\n");
111 // First dispatch all the instructions whose opcode does not correspond to
112 // those that have a direct mirror in the SPIR-V dialect
114 case spirv::Opcode::OpCapability
:
115 return processCapability(operands
);
116 case spirv::Opcode::OpExtension
:
117 return processExtension(operands
);
118 case spirv::Opcode::OpExtInst
:
119 return processExtInst(operands
);
120 case spirv::Opcode::OpExtInstImport
:
121 return processExtInstImport(operands
);
122 case spirv::Opcode::OpMemberName
:
123 return processMemberName(operands
);
124 case spirv::Opcode::OpMemoryModel
:
125 return processMemoryModel(operands
);
126 case spirv::Opcode::OpEntryPoint
:
127 case spirv::Opcode::OpExecutionMode
:
128 if (deferInstructions
) {
129 deferredInstructions
.emplace_back(opcode
, operands
);
133 case spirv::Opcode::OpVariable
:
134 if (isa
<spirv::ModuleOp
>(opBuilder
.getBlock()->getParentOp())) {
135 return processGlobalVariable(operands
);
138 case spirv::Opcode::OpLine
:
139 return processDebugLine(operands
);
140 case spirv::Opcode::OpNoLine
:
143 case spirv::Opcode::OpName
:
144 return processName(operands
);
145 case spirv::Opcode::OpString
:
146 return processDebugString(operands
);
147 case spirv::Opcode::OpModuleProcessed
:
148 case spirv::Opcode::OpSource
:
149 case spirv::Opcode::OpSourceContinued
:
150 case spirv::Opcode::OpSourceExtension
:
151 // TODO: This is debug information embedded in the binary which should be
152 // translated into the spirv.module.
154 case spirv::Opcode::OpTypeVoid
:
155 case spirv::Opcode::OpTypeBool
:
156 case spirv::Opcode::OpTypeInt
:
157 case spirv::Opcode::OpTypeFloat
:
158 case spirv::Opcode::OpTypeVector
:
159 case spirv::Opcode::OpTypeMatrix
:
160 case spirv::Opcode::OpTypeArray
:
161 case spirv::Opcode::OpTypeFunction
:
162 case spirv::Opcode::OpTypeImage
:
163 case spirv::Opcode::OpTypeSampledImage
:
164 case spirv::Opcode::OpTypeRuntimeArray
:
165 case spirv::Opcode::OpTypeStruct
:
166 case spirv::Opcode::OpTypePointer
:
167 case spirv::Opcode::OpTypeCooperativeMatrixKHR
:
168 return processType(opcode
, operands
);
169 case spirv::Opcode::OpTypeForwardPointer
:
170 return processTypeForwardPointer(operands
);
171 case spirv::Opcode::OpConstant
:
172 return processConstant(operands
, /*isSpec=*/false);
173 case spirv::Opcode::OpSpecConstant
:
174 return processConstant(operands
, /*isSpec=*/true);
175 case spirv::Opcode::OpConstantComposite
:
176 return processConstantComposite(operands
);
177 case spirv::Opcode::OpSpecConstantComposite
:
178 return processSpecConstantComposite(operands
);
179 case spirv::Opcode::OpSpecConstantOp
:
180 return processSpecConstantOperation(operands
);
181 case spirv::Opcode::OpConstantTrue
:
182 return processConstantBool(/*isTrue=*/true, operands
, /*isSpec=*/false);
183 case spirv::Opcode::OpSpecConstantTrue
:
184 return processConstantBool(/*isTrue=*/true, operands
, /*isSpec=*/true);
185 case spirv::Opcode::OpConstantFalse
:
186 return processConstantBool(/*isTrue=*/false, operands
, /*isSpec=*/false);
187 case spirv::Opcode::OpSpecConstantFalse
:
188 return processConstantBool(/*isTrue=*/false, operands
, /*isSpec=*/true);
189 case spirv::Opcode::OpConstantNull
:
190 return processConstantNull(operands
);
191 case spirv::Opcode::OpDecorate
:
192 return processDecoration(operands
);
193 case spirv::Opcode::OpMemberDecorate
:
194 return processMemberDecoration(operands
);
195 case spirv::Opcode::OpFunction
:
196 return processFunction(operands
);
197 case spirv::Opcode::OpLabel
:
198 return processLabel(operands
);
199 case spirv::Opcode::OpBranch
:
200 return processBranch(operands
);
201 case spirv::Opcode::OpBranchConditional
:
202 return processBranchConditional(operands
);
203 case spirv::Opcode::OpSelectionMerge
:
204 return processSelectionMerge(operands
);
205 case spirv::Opcode::OpLoopMerge
:
206 return processLoopMerge(operands
);
207 case spirv::Opcode::OpPhi
:
208 return processPhi(operands
);
209 case spirv::Opcode::OpUndef
:
210 return processUndef(operands
);
214 return dispatchToAutogenDeserialization(opcode
, operands
);
217 LogicalResult
spirv::Deserializer::processOpWithoutGrammarAttr(
218 ArrayRef
<uint32_t> words
, StringRef opName
, bool hasResult
,
219 unsigned numOperands
) {
220 SmallVector
<Type
, 1> resultTypes
;
221 uint32_t valueID
= 0;
223 size_t wordIndex
= 0;
225 if (wordIndex
>= words
.size())
226 return emitError(unknownLoc
,
227 "expected result type <id> while deserializing for ")
230 // Decode the type <id>
231 auto type
= getType(words
[wordIndex
]);
233 return emitError(unknownLoc
, "unknown type result <id>: ")
235 resultTypes
.push_back(type
);
238 // Decode the result <id>
239 if (wordIndex
>= words
.size())
240 return emitError(unknownLoc
,
241 "expected result <id> while deserializing for ")
243 valueID
= words
[wordIndex
];
247 SmallVector
<Value
, 4> operands
;
248 SmallVector
<NamedAttribute
, 4> attributes
;
251 size_t operandIndex
= 0;
252 for (; operandIndex
< numOperands
&& wordIndex
< words
.size();
253 ++operandIndex
, ++wordIndex
) {
254 auto arg
= getValue(words
[wordIndex
]);
256 return emitError(unknownLoc
, "unknown result <id>: ") << words
[wordIndex
];
257 operands
.push_back(arg
);
259 if (operandIndex
!= numOperands
) {
262 "found less operands than expected when deserializing for ")
263 << opName
<< "; only " << operandIndex
<< " of " << numOperands
266 if (wordIndex
!= words
.size()) {
269 "found more operands than expected when deserializing for ")
270 << opName
<< "; only " << wordIndex
<< " of " << words
.size()
274 // Attach attributes from decorations
275 if (decorations
.count(valueID
)) {
276 auto attrs
= decorations
[valueID
].getAttrs();
277 attributes
.append(attrs
.begin(), attrs
.end());
280 // Create the op and update bookkeeping maps
281 Location loc
= createFileLineColLoc(opBuilder
);
282 OperationState
opState(loc
, opName
);
283 opState
.addOperands(operands
);
285 opState
.addTypes(resultTypes
);
286 opState
.addAttributes(attributes
);
287 Operation
*op
= opBuilder
.create(opState
);
289 valueMap
[valueID
] = op
->getResult(0);
291 if (op
->hasTrait
<OpTrait::IsTerminator
>())
297 LogicalResult
spirv::Deserializer::processUndef(ArrayRef
<uint32_t> operands
) {
298 if (operands
.size() != 2) {
299 return emitError(unknownLoc
, "OpUndef instruction must have two operands");
301 auto type
= getType(operands
[0]);
303 return emitError(unknownLoc
, "unknown type <id> with OpUndef instruction");
305 undefMap
[operands
[1]] = type
;
309 LogicalResult
spirv::Deserializer::processExtInst(ArrayRef
<uint32_t> operands
) {
310 if (operands
.size() < 4) {
311 return emitError(unknownLoc
,
312 "OpExtInst must have at least 4 operands, result type "
313 "<id>, result <id>, set <id> and instruction opcode");
315 if (!extendedInstSets
.count(operands
[2])) {
316 return emitError(unknownLoc
, "undefined set <id> in OpExtInst");
318 SmallVector
<uint32_t, 4> slicedOperands
;
319 slicedOperands
.append(operands
.begin(), std::next(operands
.begin(), 2));
320 slicedOperands
.append(std::next(operands
.begin(), 4), operands
.end());
321 return dispatchToExtensionSetAutogenDeserialization(
322 extendedInstSets
[operands
[2]], operands
[3], slicedOperands
);
330 Deserializer::processOp
<spirv::EntryPointOp
>(ArrayRef
<uint32_t> words
) {
331 unsigned wordIndex
= 0;
332 if (wordIndex
>= words
.size()) {
333 return emitError(unknownLoc
,
334 "missing Execution Model specification in OpEntryPoint");
336 auto execModel
= spirv::ExecutionModelAttr::get(
337 context
, static_cast<spirv::ExecutionModel
>(words
[wordIndex
++]));
338 if (wordIndex
>= words
.size()) {
339 return emitError(unknownLoc
, "missing <id> in OpEntryPoint");
341 // Get the function <id>
342 auto fnID
= words
[wordIndex
++];
343 // Get the function name
344 auto fnName
= decodeStringLiteral(words
, wordIndex
);
345 // Verify that the function <id> matches the fnName
346 auto parsedFunc
= getFunction(fnID
);
348 return emitError(unknownLoc
, "no function matching <id> ") << fnID
;
350 if (parsedFunc
.getName() != fnName
) {
351 // The deserializer uses "spirv_fn_<id>" as the function name if the input
352 // SPIR-V blob does not contain a name for it. We should use a more clear
353 // indication for such case rather than relying on naming details.
354 if (!parsedFunc
.getName().starts_with("spirv_fn_"))
355 return emitError(unknownLoc
,
356 "function name mismatch between OpEntryPoint "
357 "and OpFunction with <id> ")
358 << fnID
<< ": " << fnName
<< " vs. " << parsedFunc
.getName();
359 parsedFunc
.setName(fnName
);
361 SmallVector
<Attribute
, 4> interface
;
362 while (wordIndex
< words
.size()) {
363 auto arg
= getGlobalVariable(words
[wordIndex
]);
365 return emitError(unknownLoc
, "undefined result <id> ")
366 << words
[wordIndex
] << " while decoding OpEntryPoint";
368 interface
.push_back(SymbolRefAttr::get(arg
.getOperation()));
371 opBuilder
.create
<spirv::EntryPointOp
>(
372 unknownLoc
, execModel
, SymbolRefAttr::get(opBuilder
.getContext(), fnName
),
373 opBuilder
.getArrayAttr(interface
));
379 Deserializer::processOp
<spirv::ExecutionModeOp
>(ArrayRef
<uint32_t> words
) {
380 unsigned wordIndex
= 0;
381 if (wordIndex
>= words
.size()) {
382 return emitError(unknownLoc
,
383 "missing function result <id> in OpExecutionMode");
385 // Get the function <id> to get the name of the function
386 auto fnID
= words
[wordIndex
++];
387 auto fn
= getFunction(fnID
);
389 return emitError(unknownLoc
, "no function matching <id> ") << fnID
;
391 // Get the Execution mode
392 if (wordIndex
>= words
.size()) {
393 return emitError(unknownLoc
, "missing Execution Mode in OpExecutionMode");
395 auto execMode
= spirv::ExecutionModeAttr::get(
396 context
, static_cast<spirv::ExecutionMode
>(words
[wordIndex
++]));
399 SmallVector
<Attribute
, 4> attrListElems
;
400 while (wordIndex
< words
.size()) {
401 attrListElems
.push_back(opBuilder
.getI32IntegerAttr(words
[wordIndex
++]));
403 auto values
= opBuilder
.getArrayAttr(attrListElems
);
404 opBuilder
.create
<spirv::ExecutionModeOp
>(
405 unknownLoc
, SymbolRefAttr::get(opBuilder
.getContext(), fn
.getName()),
412 Deserializer::processOp
<spirv::FunctionCallOp
>(ArrayRef
<uint32_t> operands
) {
413 if (operands
.size() < 3) {
414 return emitError(unknownLoc
,
415 "OpFunctionCall must have at least 3 operands");
418 Type resultType
= getType(operands
[0]);
420 return emitError(unknownLoc
, "undefined result type from <id> ")
424 // Use null type to mean no result type.
425 if (isVoidType(resultType
))
426 resultType
= nullptr;
428 auto resultID
= operands
[1];
429 auto functionID
= operands
[2];
431 auto functionName
= getFunctionSymbol(functionID
);
433 SmallVector
<Value
, 4> arguments
;
434 for (auto operand
: llvm::drop_begin(operands
, 3)) {
435 auto value
= getValue(operand
);
437 return emitError(unknownLoc
, "unknown <id> ")
438 << operand
<< " used by OpFunctionCall";
440 arguments
.push_back(value
);
443 auto opFunctionCall
= opBuilder
.create
<spirv::FunctionCallOp
>(
444 unknownLoc
, resultType
,
445 SymbolRefAttr::get(opBuilder
.getContext(), functionName
), arguments
);
448 valueMap
[resultID
] = opFunctionCall
.getResult(0);
454 Deserializer::processOp
<spirv::CopyMemoryOp
>(ArrayRef
<uint32_t> words
) {
455 SmallVector
<Type
, 1> resultTypes
;
456 size_t wordIndex
= 0;
457 SmallVector
<Value
, 4> operands
;
458 SmallVector
<NamedAttribute
, 4> attributes
;
460 if (wordIndex
< words
.size()) {
461 auto arg
= getValue(words
[wordIndex
]);
464 return emitError(unknownLoc
, "unknown result <id> : ")
468 operands
.push_back(arg
);
472 if (wordIndex
< words
.size()) {
473 auto arg
= getValue(words
[wordIndex
]);
476 return emitError(unknownLoc
, "unknown result <id> : ")
480 operands
.push_back(arg
);
484 bool isAlignedAttr
= false;
486 if (wordIndex
< words
.size()) {
487 auto attrValue
= words
[wordIndex
++];
488 auto attr
= opBuilder
.getAttr
<spirv::MemoryAccessAttr
>(
489 static_cast<spirv::MemoryAccess
>(attrValue
));
490 attributes
.push_back(
491 opBuilder
.getNamedAttr(attributeName
<MemoryAccess
>(), attr
));
492 isAlignedAttr
= (attrValue
== 2);
495 if (isAlignedAttr
&& wordIndex
< words
.size()) {
496 attributes
.push_back(opBuilder
.getNamedAttr(
497 "alignment", opBuilder
.getI32IntegerAttr(words
[wordIndex
++])));
500 if (wordIndex
< words
.size()) {
501 auto attrValue
= words
[wordIndex
++];
502 auto attr
= opBuilder
.getAttr
<spirv::MemoryAccessAttr
>(
503 static_cast<spirv::MemoryAccess
>(attrValue
));
504 attributes
.push_back(opBuilder
.getNamedAttr("source_memory_access", attr
));
507 if (wordIndex
< words
.size()) {
508 attributes
.push_back(opBuilder
.getNamedAttr(
509 "source_alignment", opBuilder
.getI32IntegerAttr(words
[wordIndex
++])));
512 if (wordIndex
!= words
.size()) {
513 return emitError(unknownLoc
,
514 "found more operands than expected when deserializing "
515 "spirv::CopyMemoryOp, only ")
516 << wordIndex
<< " of " << words
.size() << " processed";
519 Location loc
= createFileLineColLoc(opBuilder
);
520 opBuilder
.create
<spirv::CopyMemoryOp
>(loc
, resultTypes
, operands
, attributes
);
526 LogicalResult
Deserializer::processOp
<spirv::GenericCastToPtrExplicitOp
>(
527 ArrayRef
<uint32_t> words
) {
528 if (words
.size() != 4) {
529 return emitError(unknownLoc
,
530 "expected 4 words in GenericCastToPtrExplicitOp"
534 SmallVector
<Type
, 1> resultTypes
;
535 SmallVector
<Value
, 4> operands
;
536 uint32_t valueID
= 0;
537 auto type
= getType(words
[0]);
540 return emitError(unknownLoc
, "unknown type result <id> : ") << words
[0];
541 resultTypes
.push_back(type
);
545 auto arg
= getValue(words
[2]);
547 return emitError(unknownLoc
, "unknown result <id> : ") << words
[2];
548 operands
.push_back(arg
);
550 Location loc
= createFileLineColLoc(opBuilder
);
551 Operation
*op
= opBuilder
.create
<spirv::GenericCastToPtrExplicitOp
>(
552 loc
, resultTypes
, operands
);
553 valueMap
[valueID
] = op
->getResult(0);
557 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
558 // various Deserializer::processOp<...>() specializations.
559 #define GET_DESERIALIZATION_FNS
560 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"