1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
11 //===----------------------------------------------------------------------===//
13 #include "Deserializer.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/Support/LogicalResult.h"
23 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/bit.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/raw_ostream.h"
36 #define DEBUG_TYPE "spirv-deserialization"
38 //===----------------------------------------------------------------------===//
40 //===----------------------------------------------------------------------===//
42 /// Returns true if the given `block` is a function entry block.
43 static inline bool isFnEntryBlock(Block
*block
) {
44 return block
->isEntryBlock() &&
45 isa_and_nonnull
<spirv::FuncOp
>(block
->getParentOp());
48 //===----------------------------------------------------------------------===//
49 // Deserializer Method Definitions
50 //===----------------------------------------------------------------------===//
52 spirv::Deserializer::Deserializer(ArrayRef
<uint32_t> binary
,
54 : binary(binary
), context(context
), unknownLoc(UnknownLoc::get(context
)),
55 module(createModuleOp()), opBuilder(module
->getRegion())
63 LogicalResult
spirv::Deserializer::deserialize() {
67 << "//+++---------- start deserialization ----------+++//\n";
70 if (failed(processHeader()))
73 spirv::Opcode opcode
= spirv::Opcode::OpNop
;
74 ArrayRef
<uint32_t> operands
;
75 auto binarySize
= binary
.size();
76 while (curOffset
< binarySize
) {
77 // Slice the next instruction out and populate `opcode` and `operands`.
78 // Internally this also updates `curOffset`.
79 if (failed(sliceInstruction(opcode
, operands
)))
82 if (failed(processInstruction(opcode
, operands
)))
86 assert(curOffset
== binarySize
&&
87 "deserializer should never index beyond the binary end");
89 for (auto &deferred
: deferredInstructions
) {
90 if (failed(processInstruction(deferred
.first
, deferred
.second
, false))) {
97 LLVM_DEBUG(logger
.startLine()
98 << "//+++-------- completed deserialization --------+++//\n");
102 OwningOpRef
<spirv::ModuleOp
> spirv::Deserializer::collect() {
103 return std::move(module
);
106 //===----------------------------------------------------------------------===//
108 //===----------------------------------------------------------------------===//
110 OwningOpRef
<spirv::ModuleOp
> spirv::Deserializer::createModuleOp() {
111 OpBuilder
builder(context
);
112 OperationState
state(unknownLoc
, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder
, state
);
114 return cast
<spirv::ModuleOp
>(Operation::create(state
));
117 LogicalResult
spirv::Deserializer::processHeader() {
118 if (binary
.size() < spirv::kHeaderWordCount
)
119 return emitError(unknownLoc
,
120 "SPIR-V binary module must have a 5-word header");
122 if (binary
[0] != spirv::kMagicNumber
)
123 return emitError(unknownLoc
, "incorrect magic number");
125 // Version number bytes: 0 | major number | minor number | 0
126 uint32_t majorVersion
= (binary
[1] << 8) >> 24;
127 uint32_t minorVersion
= (binary
[1] << 16) >> 24;
128 if (majorVersion
== 1) {
129 switch (minorVersion
) {
130 #define MIN_VERSION_CASE(v) \
132 version = spirv::Version::V_1_##v; \
141 #undef MIN_VERSION_CASE
143 return emitError(unknownLoc
, "unsupported SPIR-V minor version: ")
147 return emitError(unknownLoc
, "unsupported SPIR-V major version: ")
151 // TODO: generator number, bound, schema
152 curOffset
= spirv::kHeaderWordCount
;
157 spirv::Deserializer::processCapability(ArrayRef
<uint32_t> operands
) {
158 if (operands
.size() != 1)
159 return emitError(unknownLoc
, "OpMemoryModel must have one parameter");
161 auto cap
= spirv::symbolizeCapability(operands
[0]);
163 return emitError(unknownLoc
, "unknown capability: ") << operands
[0];
165 capabilities
.insert(*cap
);
169 LogicalResult
spirv::Deserializer::processExtension(ArrayRef
<uint32_t> words
) {
173 "OpExtension must have a literal string for the extension name");
176 unsigned wordIndex
= 0;
177 StringRef extName
= decodeStringLiteral(words
, wordIndex
);
178 if (wordIndex
!= words
.size())
179 return emitError(unknownLoc
,
180 "unexpected trailing words in OpExtension instruction");
181 auto ext
= spirv::symbolizeExtension(extName
);
183 return emitError(unknownLoc
, "unknown extension: ") << extName
;
185 extensions
.insert(*ext
);
190 spirv::Deserializer::processExtInstImport(ArrayRef
<uint32_t> words
) {
191 if (words
.size() < 2) {
192 return emitError(unknownLoc
,
193 "OpExtInstImport must have a result <id> and a literal "
194 "string for the extended instruction set name");
197 unsigned wordIndex
= 1;
198 extendedInstSets
[words
[0]] = decodeStringLiteral(words
, wordIndex
);
199 if (wordIndex
!= words
.size()) {
200 return emitError(unknownLoc
,
201 "unexpected trailing words in OpExtInstImport");
206 void spirv::Deserializer::attachVCETriple() {
208 spirv::ModuleOp::getVCETripleAttrName(),
209 spirv::VerCapExtAttr::get(version
, capabilities
.getArrayRef(),
210 extensions
.getArrayRef(), context
));
214 spirv::Deserializer::processMemoryModel(ArrayRef
<uint32_t> operands
) {
215 if (operands
.size() != 2)
216 return emitError(unknownLoc
, "OpMemoryModel must have two operands");
219 module
->getAddressingModelAttrName(),
220 opBuilder
.getAttr
<spirv::AddressingModelAttr
>(
221 static_cast<spirv::AddressingModel
>(operands
.front())));
223 (*module
)->setAttr(module
->getMemoryModelAttrName(),
224 opBuilder
.getAttr
<spirv::MemoryModelAttr
>(
225 static_cast<spirv::MemoryModel
>(operands
.back())));
230 LogicalResult
spirv::Deserializer::processDecoration(ArrayRef
<uint32_t> words
) {
231 // TODO: This function should also be auto-generated. For now, since only a
232 // few decorations are processed/handled in a meaningful manner, going with a
233 // manual implementation.
234 if (words
.size() < 2) {
236 unknownLoc
, "OpDecorate must have at least result <id> and Decoration");
238 auto decorationName
=
239 stringifyDecoration(static_cast<spirv::Decoration
>(words
[1]));
240 if (decorationName
.empty()) {
241 return emitError(unknownLoc
, "invalid Decoration code : ") << words
[1];
243 auto symbol
= getSymbolDecoration(decorationName
);
244 switch (static_cast<spirv::Decoration
>(words
[1])) {
245 case spirv::Decoration::FPFastMathMode
:
246 if (words
.size() != 3) {
247 return emitError(unknownLoc
, "OpDecorate with ")
248 << decorationName
<< " needs a single integer literal";
250 decorations
[words
[0]].set(
251 symbol
, FPFastMathModeAttr::get(opBuilder
.getContext(),
252 static_cast<FPFastMathMode
>(words
[2])));
254 case spirv::Decoration::DescriptorSet
:
255 case spirv::Decoration::Binding
:
256 if (words
.size() != 3) {
257 return emitError(unknownLoc
, "OpDecorate with ")
258 << decorationName
<< " needs a single integer literal";
260 decorations
[words
[0]].set(
261 symbol
, opBuilder
.getI32IntegerAttr(static_cast<int32_t>(words
[2])));
263 case spirv::Decoration::BuiltIn
:
264 if (words
.size() != 3) {
265 return emitError(unknownLoc
, "OpDecorate with ")
266 << decorationName
<< " needs a single integer literal";
268 decorations
[words
[0]].set(
269 symbol
, opBuilder
.getStringAttr(
270 stringifyBuiltIn(static_cast<spirv::BuiltIn
>(words
[2]))));
272 case spirv::Decoration::ArrayStride
:
273 if (words
.size() != 3) {
274 return emitError(unknownLoc
, "OpDecorate with ")
275 << decorationName
<< " needs a single integer literal";
277 typeDecorations
[words
[0]] = words
[2];
279 case spirv::Decoration::LinkageAttributes
: {
280 if (words
.size() < 4) {
281 return emitError(unknownLoc
, "OpDecorate with ")
283 << " needs at least 1 string and 1 integer literal";
285 // LinkageAttributes has two parameters ["linkageName", linkageType]
286 // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
287 // "linkageName" is a stringliteral encoded as uint32_t,
288 // hence the size of name is variable length which results in words.size()
289 // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
290 // 3 + ceildiv(strlen(name), 4).
291 unsigned wordIndex
= 2;
292 auto linkageName
= spirv::decodeStringLiteral(words
, wordIndex
).str();
293 auto linkageTypeAttr
= opBuilder
.getAttr
<::mlir::spirv::LinkageTypeAttr
>(
294 static_cast<::mlir::spirv::LinkageType
>(words
[wordIndex
++]));
295 auto linkageAttr
= opBuilder
.getAttr
<::mlir::spirv::LinkageAttributesAttr
>(
296 StringAttr::get(context
, linkageName
), linkageTypeAttr
);
297 decorations
[words
[0]].set(symbol
, llvm::dyn_cast
<Attribute
>(linkageAttr
));
300 case spirv::Decoration::Aliased
:
301 case spirv::Decoration::AliasedPointer
:
302 case spirv::Decoration::Block
:
303 case spirv::Decoration::BufferBlock
:
304 case spirv::Decoration::Flat
:
305 case spirv::Decoration::NonReadable
:
306 case spirv::Decoration::NonWritable
:
307 case spirv::Decoration::NoPerspective
:
308 case spirv::Decoration::NoSignedWrap
:
309 case spirv::Decoration::NoUnsignedWrap
:
310 case spirv::Decoration::RelaxedPrecision
:
311 case spirv::Decoration::Restrict
:
312 case spirv::Decoration::RestrictPointer
:
313 case spirv::Decoration::NoContraction
:
314 if (words
.size() != 2) {
315 return emitError(unknownLoc
, "OpDecoration with ")
316 << decorationName
<< "needs a single target <id>";
318 // Block decoration does not affect spirv.struct type, but is still stored
320 // TODO: Update StructType to contain this information since
321 // it is needed for many validation rules.
322 decorations
[words
[0]].set(symbol
, opBuilder
.getUnitAttr());
324 case spirv::Decoration::Location
:
325 case spirv::Decoration::SpecId
:
326 if (words
.size() != 3) {
327 return emitError(unknownLoc
, "OpDecoration with ")
328 << decorationName
<< "needs a single integer literal";
330 decorations
[words
[0]].set(
331 symbol
, opBuilder
.getI32IntegerAttr(static_cast<int32_t>(words
[2])));
334 return emitError(unknownLoc
, "unhandled Decoration : '") << decorationName
;
340 spirv::Deserializer::processMemberDecoration(ArrayRef
<uint32_t> words
) {
341 // The binary layout of OpMemberDecorate is different comparing to OpDecorate
342 if (words
.size() < 3) {
343 return emitError(unknownLoc
,
344 "OpMemberDecorate must have at least 3 operands");
347 auto decoration
= static_cast<spirv::Decoration
>(words
[2]);
348 if (decoration
== spirv::Decoration::Offset
&& words
.size() != 4) {
349 return emitError(unknownLoc
,
350 " missing offset specification in OpMemberDecorate with "
351 "Offset decoration");
353 ArrayRef
<uint32_t> decorationOperands
;
354 if (words
.size() > 3) {
355 decorationOperands
= words
.slice(3);
357 memberDecorationMap
[words
[0]][words
[1]][decoration
] = decorationOperands
;
361 LogicalResult
spirv::Deserializer::processMemberName(ArrayRef
<uint32_t> words
) {
362 if (words
.size() < 3) {
363 return emitError(unknownLoc
, "OpMemberName must have at least 3 operands");
365 unsigned wordIndex
= 2;
366 auto name
= decodeStringLiteral(words
, wordIndex
);
367 if (wordIndex
!= words
.size()) {
368 return emitError(unknownLoc
,
369 "unexpected trailing words in OpMemberName instruction");
371 memberNameMap
[words
[0]][words
[1]] = name
;
375 LogicalResult
spirv::Deserializer::setFunctionArgAttrs(
376 uint32_t argID
, SmallVectorImpl
<Attribute
> &argAttrs
, size_t argIndex
) {
377 if (!decorations
.contains(argID
)) {
378 argAttrs
[argIndex
] = DictionaryAttr::get(context
, {});
382 spirv::DecorationAttr foundDecorationAttr
;
383 for (NamedAttribute decAttr
: decorations
[argID
]) {
384 for (auto decoration
:
385 {spirv::Decoration::Aliased
, spirv::Decoration::Restrict
,
386 spirv::Decoration::AliasedPointer
,
387 spirv::Decoration::RestrictPointer
}) {
389 if (decAttr
.getName() !=
390 getSymbolDecoration(stringifyDecoration(decoration
)))
393 if (foundDecorationAttr
)
394 return emitError(unknownLoc
,
395 "more than one Aliased/Restrict decorations for "
396 "function argument with result <id> ")
399 foundDecorationAttr
= spirv::DecorationAttr::get(context
, decoration
);
404 if (!foundDecorationAttr
)
405 return emitError(unknownLoc
, "unimplemented decoration support for "
406 "function argument with result <id> ")
409 NamedAttribute
attr(StringAttr::get(context
, spirv::DecorationAttr::name
),
410 foundDecorationAttr
);
411 argAttrs
[argIndex
] = DictionaryAttr::get(context
, attr
);
416 spirv::Deserializer::processFunction(ArrayRef
<uint32_t> operands
) {
418 return emitError(unknownLoc
, "found function inside function");
421 // Get the result type
422 if (operands
.size() != 4) {
423 return emitError(unknownLoc
, "OpFunction must have 4 parameters");
425 Type resultType
= getType(operands
[0]);
427 return emitError(unknownLoc
, "undefined result type from <id> ")
431 uint32_t fnID
= operands
[1];
432 if (funcMap
.count(fnID
)) {
433 return emitError(unknownLoc
, "duplicate function definition/declaration");
436 auto fnControl
= spirv::symbolizeFunctionControl(operands
[2]);
438 return emitError(unknownLoc
, "unknown Function Control: ") << operands
[2];
441 Type fnType
= getType(operands
[3]);
442 if (!fnType
|| !isa
<FunctionType
>(fnType
)) {
443 return emitError(unknownLoc
, "unknown function type from <id> ")
446 auto functionType
= cast
<FunctionType
>(fnType
);
448 if ((isVoidType(resultType
) && functionType
.getNumResults() != 0) ||
449 (functionType
.getNumResults() == 1 &&
450 functionType
.getResult(0) != resultType
)) {
451 return emitError(unknownLoc
, "mismatch in function type ")
452 << functionType
<< " and return type " << resultType
<< " specified";
455 std::string fnName
= getFunctionSymbol(fnID
);
456 auto funcOp
= opBuilder
.create
<spirv::FuncOp
>(
457 unknownLoc
, fnName
, functionType
, fnControl
.value());
458 // Processing other function attributes.
459 if (decorations
.count(fnID
)) {
460 for (auto attr
: decorations
[fnID
].getAttrs()) {
461 funcOp
->setAttr(attr
.getName(), attr
.getValue());
464 curFunction
= funcMap
[fnID
] = funcOp
;
465 auto *entryBlock
= funcOp
.addEntryBlock();
468 << "//===-------------------------------------------===//\n";
469 logger
.startLine() << "[fn] name: " << fnName
<< "\n";
470 logger
.startLine() << "[fn] type: " << fnType
<< "\n";
471 logger
.startLine() << "[fn] ID: " << fnID
<< "\n";
472 logger
.startLine() << "[fn] entry block: " << entryBlock
<< "\n";
476 SmallVector
<Attribute
> argAttrs
;
477 argAttrs
.resize(functionType
.getNumInputs());
479 // Parse the op argument instructions
480 if (functionType
.getNumInputs()) {
481 for (size_t i
= 0, e
= functionType
.getNumInputs(); i
!= e
; ++i
) {
482 auto argType
= functionType
.getInput(i
);
483 spirv::Opcode opcode
= spirv::Opcode::OpNop
;
484 ArrayRef
<uint32_t> operands
;
485 if (failed(sliceInstruction(opcode
, operands
,
486 spirv::Opcode::OpFunctionParameter
))) {
489 if (opcode
!= spirv::Opcode::OpFunctionParameter
) {
492 "missing OpFunctionParameter instruction for argument ")
495 if (operands
.size() != 2) {
498 "expected result type and result <id> for OpFunctionParameter");
500 auto argDefinedType
= getType(operands
[0]);
501 if (!argDefinedType
|| argDefinedType
!= argType
) {
502 return emitError(unknownLoc
,
503 "mismatch in argument type between function type "
505 << functionType
<< " and argument type definition "
506 << argDefinedType
<< " at argument " << i
;
508 if (getValue(operands
[1])) {
509 return emitError(unknownLoc
, "duplicate definition of result <id> ")
512 if (failed(setFunctionArgAttrs(operands
[1], argAttrs
, i
))) {
516 auto argValue
= funcOp
.getArgument(i
);
517 valueMap
[operands
[1]] = argValue
;
521 if (llvm::any_of(argAttrs
, [](Attribute attr
) {
522 auto argAttr
= cast
<DictionaryAttr
>(attr
);
523 return !argAttr
.empty();
525 funcOp
.setArgAttrsAttr(ArrayAttr::get(context
, argAttrs
));
527 // entryBlock is needed to access the arguments, Once that is done, we can
528 // erase the block for functions with 'Import' LinkageAttributes, since these
529 // are essentially function declarations, so they have no body.
530 auto linkageAttr
= funcOp
.getLinkageAttributes();
531 auto hasImportLinkage
=
532 linkageAttr
&& (linkageAttr
.value().getLinkageType().getValue() ==
533 spirv::LinkageType::Import
);
534 if (hasImportLinkage
)
537 // RAII guard to reset the insertion point to the module's region after
538 // deserializing the body of this function.
539 OpBuilder::InsertionGuard
moduleInsertionGuard(opBuilder
);
541 spirv::Opcode opcode
= spirv::Opcode::OpNop
;
542 ArrayRef
<uint32_t> instOperands
;
544 // Special handling for the entry block. We need to make sure it starts with
545 // an OpLabel instruction. The entry block takes the same parameters as the
546 // function. All other blocks do not take any parameter. We have already
547 // created the entry block, here we need to register it to the correct label
549 if (failed(sliceInstruction(opcode
, instOperands
,
550 spirv::Opcode::OpFunctionEnd
))) {
553 if (opcode
== spirv::Opcode::OpFunctionEnd
) {
554 return processFunctionEnd(instOperands
);
556 if (opcode
!= spirv::Opcode::OpLabel
) {
557 return emitError(unknownLoc
, "a basic block must start with OpLabel");
559 if (instOperands
.size() != 1) {
560 return emitError(unknownLoc
, "OpLabel should only have result <id>");
562 blockMap
[instOperands
[0]] = entryBlock
;
563 if (failed(processLabel(instOperands
))) {
567 // Then process all the other instructions in the function until we hit
569 while (succeeded(sliceInstruction(opcode
, instOperands
,
570 spirv::Opcode::OpFunctionEnd
)) &&
571 opcode
!= spirv::Opcode::OpFunctionEnd
) {
572 if (failed(processInstruction(opcode
, instOperands
))) {
576 if (opcode
!= spirv::Opcode::OpFunctionEnd
) {
580 return processFunctionEnd(instOperands
);
584 spirv::Deserializer::processFunctionEnd(ArrayRef
<uint32_t> operands
) {
585 // Process OpFunctionEnd.
586 if (!operands
.empty()) {
587 return emitError(unknownLoc
, "unexpected operands for OpFunctionEnd");
590 // Wire up block arguments from OpPhi instructions.
591 // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
593 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
598 curFunction
= std::nullopt
;
603 << "//===-------------------------------------------===//\n";
608 std::optional
<std::pair
<Attribute
, Type
>>
609 spirv::Deserializer::getConstant(uint32_t id
) {
610 auto constIt
= constantMap
.find(id
);
611 if (constIt
== constantMap
.end())
613 return constIt
->getSecond();
616 std::optional
<spirv::SpecConstOperationMaterializationInfo
>
617 spirv::Deserializer::getSpecConstantOperation(uint32_t id
) {
618 auto constIt
= specConstOperationMap
.find(id
);
619 if (constIt
== specConstOperationMap
.end())
621 return constIt
->getSecond();
624 std::string
spirv::Deserializer::getFunctionSymbol(uint32_t id
) {
625 auto funcName
= nameMap
.lookup(id
).str();
626 if (funcName
.empty()) {
627 funcName
= "spirv_fn_" + std::to_string(id
);
632 std::string
spirv::Deserializer::getSpecConstantSymbol(uint32_t id
) {
633 auto constName
= nameMap
.lookup(id
).str();
634 if (constName
.empty()) {
635 constName
= "spirv_spec_const_" + std::to_string(id
);
640 spirv::SpecConstantOp
641 spirv::Deserializer::createSpecConstant(Location loc
, uint32_t resultID
,
642 TypedAttr defaultValue
) {
643 auto symName
= opBuilder
.getStringAttr(getSpecConstantSymbol(resultID
));
644 auto op
= opBuilder
.create
<spirv::SpecConstantOp
>(unknownLoc
, symName
,
646 if (decorations
.count(resultID
)) {
647 for (auto attr
: decorations
[resultID
].getAttrs())
648 op
->setAttr(attr
.getName(), attr
.getValue());
650 specConstMap
[resultID
] = op
;
655 spirv::Deserializer::processGlobalVariable(ArrayRef
<uint32_t> operands
) {
656 unsigned wordIndex
= 0;
657 if (operands
.size() < 3) {
660 "OpVariable needs at least 3 operands, type, <id> and storage class");
664 auto type
= getType(operands
[wordIndex
]);
666 return emitError(unknownLoc
, "unknown result type <id> : ")
667 << operands
[wordIndex
];
669 auto ptrType
= dyn_cast
<spirv::PointerType
>(type
);
671 return emitError(unknownLoc
,
672 "expected a result type <id> to be a spirv.ptr, found : ")
678 auto variableID
= operands
[wordIndex
];
679 auto variableName
= nameMap
.lookup(variableID
).str();
680 if (variableName
.empty()) {
681 variableName
= "spirv_var_" + std::to_string(variableID
);
686 auto storageClass
= static_cast<spirv::StorageClass
>(operands
[wordIndex
]);
687 if (ptrType
.getStorageClass() != storageClass
) {
688 return emitError(unknownLoc
, "mismatch in storage class of pointer type ")
689 << type
<< " and that specified in OpVariable instruction : "
690 << stringifyStorageClass(storageClass
);
695 FlatSymbolRefAttr initializer
= nullptr;
697 if (wordIndex
< operands
.size()) {
698 Operation
*op
= nullptr;
700 if (auto initOp
= getGlobalVariable(operands
[wordIndex
]))
702 else if (auto initOp
= getSpecConstant(operands
[wordIndex
]))
704 else if (auto initOp
= getSpecConstantComposite(operands
[wordIndex
]))
707 return emitError(unknownLoc
, "unknown <id> ")
708 << operands
[wordIndex
] << "used as initializer";
710 initializer
= SymbolRefAttr::get(op
);
713 if (wordIndex
!= operands
.size()) {
714 return emitError(unknownLoc
,
715 "found more operands than expected when deserializing "
716 "OpVariable instruction, only ")
717 << wordIndex
<< " of " << operands
.size() << " processed";
719 auto loc
= createFileLineColLoc(opBuilder
);
720 auto varOp
= opBuilder
.create
<spirv::GlobalVariableOp
>(
721 loc
, TypeAttr::get(type
), opBuilder
.getStringAttr(variableName
),
725 if (decorations
.count(variableID
)) {
726 for (auto attr
: decorations
[variableID
].getAttrs())
727 varOp
->setAttr(attr
.getName(), attr
.getValue());
729 globalVariableMap
[variableID
] = varOp
;
733 IntegerAttr
spirv::Deserializer::getConstantInt(uint32_t id
) {
734 auto constInfo
= getConstant(id
);
738 return dyn_cast
<IntegerAttr
>(constInfo
->first
);
741 LogicalResult
spirv::Deserializer::processName(ArrayRef
<uint32_t> operands
) {
742 if (operands
.size() < 2) {
743 return emitError(unknownLoc
, "OpName needs at least 2 operands");
745 if (!nameMap
.lookup(operands
[0]).empty()) {
746 return emitError(unknownLoc
, "duplicate name found for result <id> ")
749 unsigned wordIndex
= 1;
750 StringRef name
= decodeStringLiteral(operands
, wordIndex
);
751 if (wordIndex
!= operands
.size()) {
752 return emitError(unknownLoc
,
753 "unexpected trailing words in OpName instruction");
755 nameMap
[operands
[0]] = name
;
759 //===----------------------------------------------------------------------===//
761 //===----------------------------------------------------------------------===//
763 LogicalResult
spirv::Deserializer::processType(spirv::Opcode opcode
,
764 ArrayRef
<uint32_t> operands
) {
765 if (operands
.empty()) {
766 return emitError(unknownLoc
, "type instruction with opcode ")
767 << spirv::stringifyOpcode(opcode
) << " needs at least one <id>";
770 /// TODO: Types might be forward declared in some instructions and need to be
771 /// handled appropriately.
772 if (typeMap
.count(operands
[0])) {
773 return emitError(unknownLoc
, "duplicate definition for result <id> ")
778 case spirv::Opcode::OpTypeVoid
:
779 if (operands
.size() != 1)
780 return emitError(unknownLoc
, "OpTypeVoid must have no parameters");
781 typeMap
[operands
[0]] = opBuilder
.getNoneType();
783 case spirv::Opcode::OpTypeBool
:
784 if (operands
.size() != 1)
785 return emitError(unknownLoc
, "OpTypeBool must have no parameters");
786 typeMap
[operands
[0]] = opBuilder
.getI1Type();
788 case spirv::Opcode::OpTypeInt
: {
789 if (operands
.size() != 3)
791 unknownLoc
, "OpTypeInt must have bitwidth and signedness parameters");
793 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
794 // to preserve or validate.
795 // 0 indicates unsigned, or no signedness semantics
796 // 1 indicates signed semantics."
798 // So we cannot differentiate signless and unsigned integers; always use
799 // signless semantics for such cases.
800 auto sign
= operands
[2] == 1 ? IntegerType::SignednessSemantics::Signed
801 : IntegerType::SignednessSemantics::Signless
;
802 typeMap
[operands
[0]] = IntegerType::get(context
, operands
[1], sign
);
804 case spirv::Opcode::OpTypeFloat
: {
805 if (operands
.size() != 2)
806 return emitError(unknownLoc
, "OpTypeFloat must have bitwidth parameter");
809 switch (operands
[1]) {
811 floatTy
= opBuilder
.getF16Type();
814 floatTy
= opBuilder
.getF32Type();
817 floatTy
= opBuilder
.getF64Type();
820 return emitError(unknownLoc
, "unsupported OpTypeFloat bitwidth: ")
823 typeMap
[operands
[0]] = floatTy
;
825 case spirv::Opcode::OpTypeVector
: {
826 if (operands
.size() != 3) {
829 "OpTypeVector must have element type and count parameters");
831 Type elementTy
= getType(operands
[1]);
833 return emitError(unknownLoc
, "OpTypeVector references undefined <id> ")
836 typeMap
[operands
[0]] = VectorType::get({operands
[2]}, elementTy
);
838 case spirv::Opcode::OpTypePointer
: {
839 return processOpTypePointer(operands
);
841 case spirv::Opcode::OpTypeArray
:
842 return processArrayType(operands
);
843 case spirv::Opcode::OpTypeCooperativeMatrixKHR
:
844 return processCooperativeMatrixTypeKHR(operands
);
845 case spirv::Opcode::OpTypeFunction
:
846 return processFunctionType(operands
);
847 case spirv::Opcode::OpTypeJointMatrixINTEL
:
848 return processJointMatrixType(operands
);
849 case spirv::Opcode::OpTypeImage
:
850 return processImageType(operands
);
851 case spirv::Opcode::OpTypeSampledImage
:
852 return processSampledImageType(operands
);
853 case spirv::Opcode::OpTypeRuntimeArray
:
854 return processRuntimeArrayType(operands
);
855 case spirv::Opcode::OpTypeStruct
:
856 return processStructType(operands
);
857 case spirv::Opcode::OpTypeMatrix
:
858 return processMatrixType(operands
);
860 return emitError(unknownLoc
, "unhandled type instruction");
866 spirv::Deserializer::processOpTypePointer(ArrayRef
<uint32_t> operands
) {
867 if (operands
.size() != 3)
868 return emitError(unknownLoc
, "OpTypePointer must have two parameters");
870 auto pointeeType
= getType(operands
[2]);
872 return emitError(unknownLoc
, "unknown OpTypePointer pointee type <id> ")
875 uint32_t typePointerID
= operands
[0];
876 auto storageClass
= static_cast<spirv::StorageClass
>(operands
[1]);
877 typeMap
[typePointerID
] = spirv::PointerType::get(pointeeType
, storageClass
);
879 for (auto *deferredStructIt
= std::begin(deferredStructTypesInfos
);
880 deferredStructIt
!= std::end(deferredStructTypesInfos
);) {
881 for (auto *unresolvedMemberIt
=
882 std::begin(deferredStructIt
->unresolvedMemberTypes
);
883 unresolvedMemberIt
!=
884 std::end(deferredStructIt
->unresolvedMemberTypes
);) {
885 if (unresolvedMemberIt
->first
== typePointerID
) {
886 // The newly constructed pointer type can resolve one of the
887 // deferred struct type members; update the memberTypes list and
888 // clean the unresolvedMemberTypes list accordingly.
889 deferredStructIt
->memberTypes
[unresolvedMemberIt
->second
] =
890 typeMap
[typePointerID
];
892 deferredStructIt
->unresolvedMemberTypes
.erase(unresolvedMemberIt
);
894 ++unresolvedMemberIt
;
898 if (deferredStructIt
->unresolvedMemberTypes
.empty()) {
899 // All deferred struct type members are now resolved, set the struct body.
900 auto structType
= deferredStructIt
->deferredStructType
;
902 assert(structType
&& "expected a spirv::StructType");
903 assert(structType
.isIdentified() && "expected an indentified struct");
905 if (failed(structType
.trySetBody(
906 deferredStructIt
->memberTypes
, deferredStructIt
->offsetInfo
,
907 deferredStructIt
->memberDecorationsInfo
)))
910 deferredStructIt
= deferredStructTypesInfos
.erase(deferredStructIt
);
920 spirv::Deserializer::processArrayType(ArrayRef
<uint32_t> operands
) {
921 if (operands
.size() != 3) {
922 return emitError(unknownLoc
,
923 "OpTypeArray must have element type and count parameters");
926 Type elementTy
= getType(operands
[1]);
928 return emitError(unknownLoc
, "OpTypeArray references undefined <id> ")
933 // TODO: The count can also come frome a specialization constant.
934 auto countInfo
= getConstant(operands
[2]);
936 return emitError(unknownLoc
, "OpTypeArray count <id> ")
937 << operands
[2] << "can only come from normal constant right now";
940 if (auto intVal
= dyn_cast
<IntegerAttr
>(countInfo
->first
)) {
941 count
= intVal
.getValue().getZExtValue();
943 return emitError(unknownLoc
, "OpTypeArray count must come from a "
944 "scalar integer constant instruction");
947 typeMap
[operands
[0]] = spirv::ArrayType::get(
948 elementTy
, count
, typeDecorations
.lookup(operands
[0]));
953 spirv::Deserializer::processFunctionType(ArrayRef
<uint32_t> operands
) {
954 assert(!operands
.empty() && "No operands for processing function type");
955 if (operands
.size() == 1) {
956 return emitError(unknownLoc
, "missing return type for OpTypeFunction");
958 auto returnType
= getType(operands
[1]);
960 return emitError(unknownLoc
, "unknown return type in OpTypeFunction");
962 SmallVector
<Type
, 1> argTypes
;
963 for (size_t i
= 2, e
= operands
.size(); i
< e
; ++i
) {
964 auto ty
= getType(operands
[i
]);
966 return emitError(unknownLoc
, "unknown argument type in OpTypeFunction");
968 argTypes
.push_back(ty
);
970 ArrayRef
<Type
> returnTypes
;
971 if (!isVoidType(returnType
)) {
972 returnTypes
= llvm::ArrayRef(returnType
);
974 typeMap
[operands
[0]] = FunctionType::get(context
, argTypes
, returnTypes
);
978 LogicalResult
spirv::Deserializer::processCooperativeMatrixTypeKHR(
979 ArrayRef
<uint32_t> operands
) {
980 if (operands
.size() != 6) {
981 return emitError(unknownLoc
,
982 "OpTypeCooperativeMatrixKHR must have element type, "
983 "scope, row and column parameters, and use");
986 Type elementTy
= getType(operands
[1]);
988 return emitError(unknownLoc
,
989 "OpTypeCooperativeMatrixKHR references undefined <id> ")
993 std::optional
<spirv::Scope
> scope
=
994 spirv::symbolizeScope(getConstantInt(operands
[2]).getInt());
998 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1002 unsigned rows
= getConstantInt(operands
[3]).getInt();
1003 unsigned columns
= getConstantInt(operands
[4]).getInt();
1005 std::optional
<spirv::CooperativeMatrixUseKHR
> use
=
1006 spirv::symbolizeCooperativeMatrixUseKHR(
1007 getConstantInt(operands
[5]).getInt());
1011 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1015 typeMap
[operands
[0]] =
1016 spirv::CooperativeMatrixType::get(elementTy
, rows
, columns
, *scope
, *use
);
1021 spirv::Deserializer::processJointMatrixType(ArrayRef
<uint32_t> operands
) {
1022 if (operands
.size() != 6) {
1023 return emitError(unknownLoc
, "OpTypeJointMatrix must have element "
1024 "type and row x column parameters");
1027 Type elementTy
= getType(operands
[1]);
1029 return emitError(unknownLoc
, "OpTypeJointMatrix references undefined <id> ")
1033 auto scope
= spirv::symbolizeScope(getConstantInt(operands
[5]).getInt());
1035 return emitError(unknownLoc
,
1036 "OpTypeJointMatrix references undefined scope <id> ")
1040 spirv::symbolizeMatrixLayout(getConstantInt(operands
[4]).getInt());
1041 if (!matrixLayout
) {
1042 return emitError(unknownLoc
,
1043 "OpTypeJointMatrix references undefined scope <id> ")
1046 unsigned rows
= getConstantInt(operands
[2]).getInt();
1047 unsigned columns
= getConstantInt(operands
[3]).getInt();
1049 typeMap
[operands
[0]] = spirv::JointMatrixINTELType::get(
1050 elementTy
, scope
.value(), rows
, columns
, matrixLayout
.value());
1055 spirv::Deserializer::processRuntimeArrayType(ArrayRef
<uint32_t> operands
) {
1056 if (operands
.size() != 2) {
1057 return emitError(unknownLoc
, "OpTypeRuntimeArray must have two operands");
1059 Type memberType
= getType(operands
[1]);
1061 return emitError(unknownLoc
,
1062 "OpTypeRuntimeArray references undefined <id> ")
1065 typeMap
[operands
[0]] = spirv::RuntimeArrayType::get(
1066 memberType
, typeDecorations
.lookup(operands
[0]));
1071 spirv::Deserializer::processStructType(ArrayRef
<uint32_t> operands
) {
1072 // TODO: Find a way to handle identified structs when debug info is stripped.
1074 if (operands
.empty()) {
1075 return emitError(unknownLoc
, "OpTypeStruct must have at least result <id>");
1078 if (operands
.size() == 1) {
1079 // Handle empty struct.
1080 typeMap
[operands
[0]] =
1081 spirv::StructType::getEmpty(context
, nameMap
.lookup(operands
[0]).str());
1085 // First element is operand ID, second element is member index in the struct.
1086 SmallVector
<std::pair
<uint32_t, unsigned>, 0> unresolvedMemberTypes
;
1087 SmallVector
<Type
, 4> memberTypes
;
1089 for (auto op
: llvm::drop_begin(operands
, 1)) {
1090 Type memberType
= getType(op
);
1091 bool typeForwardPtr
= (typeForwardPointerIDs
.count(op
) != 0);
1093 if (!memberType
&& !typeForwardPtr
)
1094 return emitError(unknownLoc
, "OpTypeStruct references undefined <id> ")
1098 unresolvedMemberTypes
.emplace_back(op
, memberTypes
.size());
1100 memberTypes
.push_back(memberType
);
1103 SmallVector
<spirv::StructType::OffsetInfo
, 0> offsetInfo
;
1104 SmallVector
<spirv::StructType::MemberDecorationInfo
, 0> memberDecorationsInfo
;
1105 if (memberDecorationMap
.count(operands
[0])) {
1106 auto &allMemberDecorations
= memberDecorationMap
[operands
[0]];
1107 for (auto memberIndex
: llvm::seq
<uint32_t>(0, memberTypes
.size())) {
1108 if (allMemberDecorations
.count(memberIndex
)) {
1109 for (auto &memberDecoration
: allMemberDecorations
[memberIndex
]) {
1110 // Check for offset.
1111 if (memberDecoration
.first
== spirv::Decoration::Offset
) {
1112 // If offset info is empty, resize to the number of members;
1113 if (offsetInfo
.empty()) {
1114 offsetInfo
.resize(memberTypes
.size());
1116 offsetInfo
[memberIndex
] = memberDecoration
.second
[0];
1118 if (!memberDecoration
.second
.empty()) {
1119 memberDecorationsInfo
.emplace_back(memberIndex
, /*hasValue=*/1,
1120 memberDecoration
.first
,
1121 memberDecoration
.second
[0]);
1123 memberDecorationsInfo
.emplace_back(memberIndex
, /*hasValue=*/0,
1124 memberDecoration
.first
, 0);
1132 uint32_t structID
= operands
[0];
1133 std::string structIdentifier
= nameMap
.lookup(structID
).str();
1135 if (structIdentifier
.empty()) {
1136 assert(unresolvedMemberTypes
.empty() &&
1137 "didn't expect unresolved member types");
1139 spirv::StructType::get(memberTypes
, offsetInfo
, memberDecorationsInfo
);
1141 auto structTy
= spirv::StructType::getIdentified(context
, structIdentifier
);
1142 typeMap
[structID
] = structTy
;
1144 if (!unresolvedMemberTypes
.empty())
1145 deferredStructTypesInfos
.push_back({structTy
, unresolvedMemberTypes
,
1146 memberTypes
, offsetInfo
,
1147 memberDecorationsInfo
});
1148 else if (failed(structTy
.trySetBody(memberTypes
, offsetInfo
,
1149 memberDecorationsInfo
)))
1153 // TODO: Update StructType to have member name as attribute as
1159 spirv::Deserializer::processMatrixType(ArrayRef
<uint32_t> operands
) {
1160 if (operands
.size() != 3) {
1161 // Three operands are needed: result_id, column_type, and column_count
1162 return emitError(unknownLoc
, "OpTypeMatrix must have 3 operands"
1163 " (result_id, column_type, and column_count)");
1165 // Matrix columns must be of vector type
1166 Type elementTy
= getType(operands
[1]);
1168 return emitError(unknownLoc
,
1169 "OpTypeMatrix references undefined column type.")
1173 uint32_t colsCount
= operands
[2];
1174 typeMap
[operands
[0]] = spirv::MatrixType::get(elementTy
, colsCount
);
1179 spirv::Deserializer::processTypeForwardPointer(ArrayRef
<uint32_t> operands
) {
1180 if (operands
.size() != 2)
1181 return emitError(unknownLoc
,
1182 "OpTypeForwardPointer instruction must have two operands");
1184 typeForwardPointerIDs
.insert(operands
[0]);
1185 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1186 // instruction that defines the actual type.
1192 spirv::Deserializer::processImageType(ArrayRef
<uint32_t> operands
) {
1193 // TODO: Add support for Access Qualifier.
1194 if (operands
.size() != 8)
1197 "OpTypeImage with non-eight operands are not supported yet");
1199 Type elementTy
= getType(operands
[1]);
1201 return emitError(unknownLoc
, "OpTypeImage references undefined <id>: ")
1204 auto dim
= spirv::symbolizeDim(operands
[2]);
1206 return emitError(unknownLoc
, "unknown Dim for OpTypeImage: ")
1209 auto depthInfo
= spirv::symbolizeImageDepthInfo(operands
[3]);
1211 return emitError(unknownLoc
, "unknown Depth for OpTypeImage: ")
1214 auto arrayedInfo
= spirv::symbolizeImageArrayedInfo(operands
[4]);
1216 return emitError(unknownLoc
, "unknown Arrayed for OpTypeImage: ")
1219 auto samplingInfo
= spirv::symbolizeImageSamplingInfo(operands
[5]);
1221 return emitError(unknownLoc
, "unknown MS for OpTypeImage: ") << operands
[5];
1223 auto samplerUseInfo
= spirv::symbolizeImageSamplerUseInfo(operands
[6]);
1224 if (!samplerUseInfo
)
1225 return emitError(unknownLoc
, "unknown Sampled for OpTypeImage: ")
1228 auto format
= spirv::symbolizeImageFormat(operands
[7]);
1230 return emitError(unknownLoc
, "unknown Format for OpTypeImage: ")
1233 typeMap
[operands
[0]] = spirv::ImageType::get(
1234 elementTy
, dim
.value(), depthInfo
.value(), arrayedInfo
.value(),
1235 samplingInfo
.value(), samplerUseInfo
.value(), format
.value());
1240 spirv::Deserializer::processSampledImageType(ArrayRef
<uint32_t> operands
) {
1241 if (operands
.size() != 2)
1242 return emitError(unknownLoc
, "OpTypeSampledImage must have two operands");
1244 Type elementTy
= getType(operands
[1]);
1246 return emitError(unknownLoc
,
1247 "OpTypeSampledImage references undefined <id>: ")
1250 typeMap
[operands
[0]] = spirv::SampledImageType::get(elementTy
);
1254 //===----------------------------------------------------------------------===//
1256 //===----------------------------------------------------------------------===//
1258 LogicalResult
spirv::Deserializer::processConstant(ArrayRef
<uint32_t> operands
,
1260 StringRef opname
= isSpec
? "OpSpecConstant" : "OpConstant";
1262 if (operands
.size() < 2) {
1263 return emitError(unknownLoc
)
1264 << opname
<< " must have type <id> and result <id>";
1266 if (operands
.size() < 3) {
1267 return emitError(unknownLoc
)
1268 << opname
<< " must have at least 1 more parameter";
1271 Type resultType
= getType(operands
[0]);
1273 return emitError(unknownLoc
, "undefined result type from <id> ")
1277 auto checkOperandSizeForBitwidth
= [&](unsigned bitwidth
) -> LogicalResult
{
1278 if (bitwidth
== 64) {
1279 if (operands
.size() == 4) {
1282 return emitError(unknownLoc
)
1283 << opname
<< " should have 2 parameters for 64-bit values";
1285 if (bitwidth
<= 32) {
1286 if (operands
.size() == 3) {
1290 return emitError(unknownLoc
)
1292 << " should have 1 parameter for values with no more than 32 bits";
1294 return emitError(unknownLoc
, "unsupported OpConstant bitwidth: ")
1298 auto resultID
= operands
[1];
1300 if (auto intType
= dyn_cast
<IntegerType
>(resultType
)) {
1301 auto bitwidth
= intType
.getWidth();
1302 if (failed(checkOperandSizeForBitwidth(bitwidth
))) {
1307 if (bitwidth
== 64) {
1308 // 64-bit integers are represented with two SPIR-V words. According to
1309 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1310 // literal’s low-order words appear first."
1314 } words
= {operands
[2], operands
[3]};
1315 value
= APInt(64, llvm::bit_cast
<uint64_t>(words
), /*isSigned=*/true);
1316 } else if (bitwidth
<= 32) {
1317 value
= APInt(bitwidth
, operands
[2], /*isSigned=*/true);
1320 auto attr
= opBuilder
.getIntegerAttr(intType
, value
);
1323 createSpecConstant(unknownLoc
, resultID
, attr
);
1325 // For normal constants, we just record the attribute (and its type) for
1326 // later materialization at use sites.
1327 constantMap
.try_emplace(resultID
, attr
, intType
);
1333 if (auto floatType
= dyn_cast
<FloatType
>(resultType
)) {
1334 auto bitwidth
= floatType
.getWidth();
1335 if (failed(checkOperandSizeForBitwidth(bitwidth
))) {
1340 if (floatType
.isF64()) {
1341 // Double values are represented with two SPIR-V words. According to
1342 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1343 // literal’s low-order words appear first."
1347 } words
= {operands
[2], operands
[3]};
1348 value
= APFloat(llvm::bit_cast
<double>(words
));
1349 } else if (floatType
.isF32()) {
1350 value
= APFloat(llvm::bit_cast
<float>(operands
[2]));
1351 } else if (floatType
.isF16()) {
1352 APInt
data(16, operands
[2]);
1353 value
= APFloat(APFloat::IEEEhalf(), data
);
1356 auto attr
= opBuilder
.getFloatAttr(floatType
, value
);
1358 createSpecConstant(unknownLoc
, resultID
, attr
);
1360 // For normal constants, we just record the attribute (and its type) for
1361 // later materialization at use sites.
1362 constantMap
.try_emplace(resultID
, attr
, floatType
);
1368 return emitError(unknownLoc
, "OpConstant can only generate values of "
1369 "scalar integer or floating-point type");
1372 LogicalResult
spirv::Deserializer::processConstantBool(
1373 bool isTrue
, ArrayRef
<uint32_t> operands
, bool isSpec
) {
1374 if (operands
.size() != 2) {
1375 return emitError(unknownLoc
, "Op")
1376 << (isSpec
? "Spec" : "") << "Constant"
1377 << (isTrue
? "True" : "False")
1378 << " must have type <id> and result <id>";
1381 auto attr
= opBuilder
.getBoolAttr(isTrue
);
1382 auto resultID
= operands
[1];
1384 createSpecConstant(unknownLoc
, resultID
, attr
);
1386 // For normal constants, we just record the attribute (and its type) for
1387 // later materialization at use sites.
1388 constantMap
.try_emplace(resultID
, attr
, opBuilder
.getI1Type());
1395 spirv::Deserializer::processConstantComposite(ArrayRef
<uint32_t> operands
) {
1396 if (operands
.size() < 2) {
1397 return emitError(unknownLoc
,
1398 "OpConstantComposite must have type <id> and result <id>");
1400 if (operands
.size() < 3) {
1401 return emitError(unknownLoc
,
1402 "OpConstantComposite must have at least 1 parameter");
1405 Type resultType
= getType(operands
[0]);
1407 return emitError(unknownLoc
, "undefined result type from <id> ")
1411 SmallVector
<Attribute
, 4> elements
;
1412 elements
.reserve(operands
.size() - 2);
1413 for (unsigned i
= 2, e
= operands
.size(); i
< e
; ++i
) {
1414 auto elementInfo
= getConstant(operands
[i
]);
1416 return emitError(unknownLoc
, "OpConstantComposite component <id> ")
1417 << operands
[i
] << " must come from a normal constant";
1419 elements
.push_back(elementInfo
->first
);
1422 auto resultID
= operands
[1];
1423 if (auto vectorType
= dyn_cast
<VectorType
>(resultType
)) {
1424 auto attr
= DenseElementsAttr::get(vectorType
, elements
);
1425 // For normal constants, we just record the attribute (and its type) for
1426 // later materialization at use sites.
1427 constantMap
.try_emplace(resultID
, attr
, resultType
);
1428 } else if (auto arrayType
= dyn_cast
<spirv::ArrayType
>(resultType
)) {
1429 auto attr
= opBuilder
.getArrayAttr(elements
);
1430 constantMap
.try_emplace(resultID
, attr
, resultType
);
1432 return emitError(unknownLoc
, "unsupported OpConstantComposite type: ")
1440 spirv::Deserializer::processSpecConstantComposite(ArrayRef
<uint32_t> operands
) {
1441 if (operands
.size() < 2) {
1442 return emitError(unknownLoc
,
1443 "OpConstantComposite must have type <id> and result <id>");
1445 if (operands
.size() < 3) {
1446 return emitError(unknownLoc
,
1447 "OpConstantComposite must have at least 1 parameter");
1450 Type resultType
= getType(operands
[0]);
1452 return emitError(unknownLoc
, "undefined result type from <id> ")
1456 auto resultID
= operands
[1];
1457 auto symName
= opBuilder
.getStringAttr(getSpecConstantSymbol(resultID
));
1459 SmallVector
<Attribute
, 4> elements
;
1460 elements
.reserve(operands
.size() - 2);
1461 for (unsigned i
= 2, e
= operands
.size(); i
< e
; ++i
) {
1462 auto elementInfo
= getSpecConstant(operands
[i
]);
1463 elements
.push_back(SymbolRefAttr::get(elementInfo
));
1466 auto op
= opBuilder
.create
<spirv::SpecConstantCompositeOp
>(
1467 unknownLoc
, TypeAttr::get(resultType
), symName
,
1468 opBuilder
.getArrayAttr(elements
));
1469 specConstCompositeMap
[resultID
] = op
;
1475 spirv::Deserializer::processSpecConstantOperation(ArrayRef
<uint32_t> operands
) {
1476 if (operands
.size() < 3)
1477 return emitError(unknownLoc
, "OpConstantOperation must have type <id>, "
1478 "result <id>, and operand opcode");
1480 uint32_t resultTypeID
= operands
[0];
1482 if (!getType(resultTypeID
))
1483 return emitError(unknownLoc
, "undefined result type from <id> ")
1486 uint32_t resultID
= operands
[1];
1487 spirv::Opcode enclosedOpcode
= static_cast<spirv::Opcode
>(operands
[2]);
1488 auto emplaceResult
= specConstOperationMap
.try_emplace(
1490 SpecConstOperationMaterializationInfo
{
1491 enclosedOpcode
, resultTypeID
,
1492 SmallVector
<uint32_t>{operands
.begin() + 3, operands
.end()}});
1494 if (!emplaceResult
.second
)
1495 return emitError(unknownLoc
, "value with <id>: ")
1496 << resultID
<< " is probably defined before.";
1501 Value
spirv::Deserializer::materializeSpecConstantOperation(
1502 uint32_t resultID
, spirv::Opcode enclosedOpcode
, uint32_t resultTypeID
,
1503 ArrayRef
<uint32_t> enclosedOpOperands
) {
1505 Type resultType
= getType(resultTypeID
);
1507 // Instructions wrapped by OpSpecConstantOp need an ID for their
1508 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1509 // dialect wrapped op. For that purpose, a new value map is created and "fake"
1510 // ID in that map is assigned to the result of the enclosed instruction. Note
1511 // that there is no need to update this fake ID since we only need to
1512 // reference the created Value for the enclosed op from the spv::YieldOp
1513 // created later in this method (both of which are the only values in their
1514 // region: the SpecConstantOperation's region). If we encounter another
1515 // SpecConstantOperation in the module, we simply re-use the fake ID since the
1516 // previous Value assigned to it isn't visible in the current scope anyway.
1517 DenseMap
<uint32_t, Value
> newValueMap
;
1518 llvm::SaveAndRestore
valueMapGuard(valueMap
, newValueMap
);
1519 constexpr uint32_t fakeID
= static_cast<uint32_t>(-3);
1521 SmallVector
<uint32_t, 4> enclosedOpResultTypeAndOperands
;
1522 enclosedOpResultTypeAndOperands
.push_back(resultTypeID
);
1523 enclosedOpResultTypeAndOperands
.push_back(fakeID
);
1524 enclosedOpResultTypeAndOperands
.append(enclosedOpOperands
.begin(),
1525 enclosedOpOperands
.end());
1527 // Process enclosed instruction before creating the enclosing
1528 // specConstantOperation (and its region). This way, references to constants,
1529 // global variables, and spec constants will be materialized outside the new
1530 // op's region. For more info, see Deserializer::getValue's implementation.
1532 processInstruction(enclosedOpcode
, enclosedOpResultTypeAndOperands
)))
1535 // Since the enclosed op is emitted in the current block, split it in a
1536 // separate new block.
1537 Block
*enclosedBlock
= curBlock
->splitBlock(&curBlock
->back());
1539 auto loc
= createFileLineColLoc(opBuilder
);
1540 auto specConstOperationOp
=
1541 opBuilder
.create
<spirv::SpecConstantOperationOp
>(loc
, resultType
);
1543 Region
&body
= specConstOperationOp
.getBody();
1544 // Move the new block into SpecConstantOperation's body.
1545 body
.getBlocks().splice(body
.end(), curBlock
->getParent()->getBlocks(),
1546 Region::iterator(enclosedBlock
));
1547 Block
&block
= body
.back();
1549 // RAII guard to reset the insertion point to the module's region after
1550 // deserializing the body of the specConstantOperation.
1551 OpBuilder::InsertionGuard
moduleInsertionGuard(opBuilder
);
1552 opBuilder
.setInsertionPointToEnd(&block
);
1554 opBuilder
.create
<spirv::YieldOp
>(loc
, block
.front().getResult(0));
1555 return specConstOperationOp
.getResult();
1559 spirv::Deserializer::processConstantNull(ArrayRef
<uint32_t> operands
) {
1560 if (operands
.size() != 2) {
1561 return emitError(unknownLoc
,
1562 "OpConstantNull must have type <id> and result <id>");
1565 Type resultType
= getType(operands
[0]);
1567 return emitError(unknownLoc
, "undefined result type from <id> ")
1571 auto resultID
= operands
[1];
1572 if (resultType
.isIntOrFloat() || isa
<VectorType
>(resultType
)) {
1573 auto attr
= opBuilder
.getZeroAttr(resultType
);
1574 // For normal constants, we just record the attribute (and its type) for
1575 // later materialization at use sites.
1576 constantMap
.try_emplace(resultID
, attr
, resultType
);
1580 return emitError(unknownLoc
, "unsupported OpConstantNull type: ")
1584 //===----------------------------------------------------------------------===//
1586 //===----------------------------------------------------------------------===//
1588 Block
*spirv::Deserializer::getOrCreateBlock(uint32_t id
) {
1589 if (auto *block
= getBlock(id
)) {
1590 LLVM_DEBUG(logger
.startLine() << "[block] got exiting block for id = " << id
1591 << " @ " << block
<< "\n");
1595 // We don't know where this block will be placed finally (in a
1596 // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1597 // function for now and sort out the proper place later.
1598 auto *block
= curFunction
->addBlock();
1599 LLVM_DEBUG(logger
.startLine() << "[block] created block for id = " << id
1600 << " @ " << block
<< "\n");
1601 return blockMap
[id
] = block
;
1604 LogicalResult
spirv::Deserializer::processBranch(ArrayRef
<uint32_t> operands
) {
1606 return emitError(unknownLoc
, "OpBranch must appear inside a block");
1609 if (operands
.size() != 1) {
1610 return emitError(unknownLoc
, "OpBranch must take exactly one target label");
1613 auto *target
= getOrCreateBlock(operands
[0]);
1614 auto loc
= createFileLineColLoc(opBuilder
);
1615 // The preceding instruction for the OpBranch instruction could be an
1616 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1617 // the same OpLine information.
1618 opBuilder
.create
<spirv::BranchOp
>(loc
, target
);
1625 spirv::Deserializer::processBranchConditional(ArrayRef
<uint32_t> operands
) {
1627 return emitError(unknownLoc
,
1628 "OpBranchConditional must appear inside a block");
1631 if (operands
.size() != 3 && operands
.size() != 5) {
1632 return emitError(unknownLoc
,
1633 "OpBranchConditional must have condition, true label, "
1634 "false label, and optionally two branch weights");
1637 auto condition
= getValue(operands
[0]);
1638 auto *trueBlock
= getOrCreateBlock(operands
[1]);
1639 auto *falseBlock
= getOrCreateBlock(operands
[2]);
1641 std::optional
<std::pair
<uint32_t, uint32_t>> weights
;
1642 if (operands
.size() == 5) {
1643 weights
= std::make_pair(operands
[3], operands
[4]);
1645 // The preceding instruction for the OpBranchConditional instruction could be
1646 // an OpSelectionMerge instruction, in this case they will have the same
1647 // OpLine information.
1648 auto loc
= createFileLineColLoc(opBuilder
);
1649 opBuilder
.create
<spirv::BranchConditionalOp
>(
1650 loc
, condition
, trueBlock
,
1651 /*trueArguments=*/ArrayRef
<Value
>(), falseBlock
,
1652 /*falseArguments=*/ArrayRef
<Value
>(), weights
);
1658 LogicalResult
spirv::Deserializer::processLabel(ArrayRef
<uint32_t> operands
) {
1660 return emitError(unknownLoc
, "OpLabel must appear inside a function");
1663 if (operands
.size() != 1) {
1664 return emitError(unknownLoc
, "OpLabel should only have result <id>");
1667 auto labelID
= operands
[0];
1668 // We may have forward declared this block.
1669 auto *block
= getOrCreateBlock(labelID
);
1670 LLVM_DEBUG(logger
.startLine()
1671 << "[block] populating block " << block
<< "\n");
1672 // If we have seen this block, make sure it was just a forward declaration.
1673 assert(block
->empty() && "re-deserialize the same block!");
1675 opBuilder
.setInsertionPointToStart(block
);
1676 blockMap
[labelID
] = curBlock
= block
;
1682 spirv::Deserializer::processSelectionMerge(ArrayRef
<uint32_t> operands
) {
1684 return emitError(unknownLoc
, "OpSelectionMerge must appear in a block");
1687 if (operands
.size() < 2) {
1690 "OpSelectionMerge must specify merge target and selection control");
1693 auto *mergeBlock
= getOrCreateBlock(operands
[0]);
1694 auto loc
= createFileLineColLoc(opBuilder
);
1695 auto selectionControl
= operands
[1];
1697 if (!blockMergeInfo
.try_emplace(curBlock
, loc
, selectionControl
, mergeBlock
)
1701 "a block cannot have more than one OpSelectionMerge instruction");
1708 spirv::Deserializer::processLoopMerge(ArrayRef
<uint32_t> operands
) {
1710 return emitError(unknownLoc
, "OpLoopMerge must appear in a block");
1713 if (operands
.size() < 3) {
1714 return emitError(unknownLoc
, "OpLoopMerge must specify merge target, "
1715 "continue target and loop control");
1718 auto *mergeBlock
= getOrCreateBlock(operands
[0]);
1719 auto *continueBlock
= getOrCreateBlock(operands
[1]);
1720 auto loc
= createFileLineColLoc(opBuilder
);
1721 uint32_t loopControl
= operands
[2];
1724 .try_emplace(curBlock
, loc
, loopControl
, mergeBlock
, continueBlock
)
1728 "a block cannot have more than one OpLoopMerge instruction");
1734 LogicalResult
spirv::Deserializer::processPhi(ArrayRef
<uint32_t> operands
) {
1736 return emitError(unknownLoc
, "OpPhi must appear in a block");
1739 if (operands
.size() < 4) {
1740 return emitError(unknownLoc
, "OpPhi must specify result type, result <id>, "
1741 "and variable-parent pairs");
1744 // Create a block argument for this OpPhi instruction.
1745 Type blockArgType
= getType(operands
[0]);
1746 BlockArgument blockArg
= curBlock
->addArgument(blockArgType
, unknownLoc
);
1747 valueMap
[operands
[1]] = blockArg
;
1748 LLVM_DEBUG(logger
.startLine()
1749 << "[phi] created block argument " << blockArg
1750 << " id = " << operands
[1] << " of type " << blockArgType
<< "\n");
1752 // For each (value, predecessor) pair, insert the value to the predecessor's
1753 // blockPhiInfo entry so later we can fix the block argument there.
1754 for (unsigned i
= 2, e
= operands
.size(); i
< e
; i
+= 2) {
1755 uint32_t value
= operands
[i
];
1756 Block
*predecessor
= getOrCreateBlock(operands
[i
+ 1]);
1757 std::pair
<Block
*, Block
*> predecessorTargetPair
{predecessor
, curBlock
};
1758 blockPhiInfo
[predecessorTargetPair
].push_back(value
);
1759 LLVM_DEBUG(logger
.startLine() << "[phi] predecessor @ " << predecessor
1760 << " with arg id = " << value
<< "\n");
1767 /// A class for putting all blocks in a structured selection/loop in a
1768 /// spirv.mlir.selection/spirv.mlir.loop op.
1769 class ControlFlowStructurizer
{
1772 ControlFlowStructurizer(Location loc
, uint32_t control
,
1773 spirv::BlockMergeInfoMap
&mergeInfo
, Block
*header
,
1774 Block
*merge
, Block
*cont
,
1775 llvm::ScopedPrinter
&logger
)
1776 : location(loc
), control(control
), blockMergeInfo(mergeInfo
),
1777 headerBlock(header
), mergeBlock(merge
), continueBlock(cont
),
1780 ControlFlowStructurizer(Location loc
, uint32_t control
,
1781 spirv::BlockMergeInfoMap
&mergeInfo
, Block
*header
,
1782 Block
*merge
, Block
*cont
)
1783 : location(loc
), control(control
), blockMergeInfo(mergeInfo
),
1784 headerBlock(header
), mergeBlock(merge
), continueBlock(cont
) {}
1787 /// Structurizes the loop at the given `headerBlock`.
1789 /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
1790 /// all blocks in the structured loop into the spirv.mlir.loop's region. All
1791 /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
1792 /// method will also update `mergeInfo` by remapping all blocks inside to the
1793 /// newly cloned ones inside structured control flow op's regions.
1794 LogicalResult
structurize();
1797 /// Creates a new spirv.mlir.selection op at the beginning of the
1799 spirv::SelectionOp
createSelectionOp(uint32_t selectionControl
);
1801 /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
1802 spirv::LoopOp
createLoopOp(uint32_t loopControl
);
1804 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
1805 void collectBlocksInConstruct();
1810 spirv::BlockMergeInfoMap
&blockMergeInfo
;
1814 Block
*continueBlock
; // nullptr for spirv.mlir.selection
1816 SetVector
<Block
*> constructBlocks
;
1819 /// A logger used to emit information during the deserialzation process.
1820 llvm::ScopedPrinter
&logger
;
1826 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl
) {
1827 // Create a builder and set the insertion point to the beginning of the
1828 // merge block so that the newly created SelectionOp will be inserted there.
1829 OpBuilder
builder(&mergeBlock
->front());
1831 auto control
= static_cast<spirv::SelectionControl
>(selectionControl
);
1832 auto selectionOp
= builder
.create
<spirv::SelectionOp
>(location
, control
);
1833 selectionOp
.addMergeBlock(builder
);
1838 spirv::LoopOp
ControlFlowStructurizer::createLoopOp(uint32_t loopControl
) {
1839 // Create a builder and set the insertion point to the beginning of the
1840 // merge block so that the newly created LoopOp will be inserted there.
1841 OpBuilder
builder(&mergeBlock
->front());
1843 auto control
= static_cast<spirv::LoopControl
>(loopControl
);
1844 auto loopOp
= builder
.create
<spirv::LoopOp
>(location
, control
);
1845 loopOp
.addEntryAndMergeBlock(builder
);
1850 void ControlFlowStructurizer::collectBlocksInConstruct() {
1851 assert(constructBlocks
.empty() && "expected empty constructBlocks");
1853 // Put the header block in the work list first.
1854 constructBlocks
.insert(headerBlock
);
1856 // For each item in the work list, add its successors excluding the merge
1858 for (unsigned i
= 0; i
< constructBlocks
.size(); ++i
) {
1859 for (auto *successor
: constructBlocks
[i
]->getSuccessors())
1860 if (successor
!= mergeBlock
)
1861 constructBlocks
.insert(successor
);
1865 LogicalResult
ControlFlowStructurizer::structurize() {
1866 Operation
*op
= nullptr;
1867 bool isLoop
= continueBlock
!= nullptr;
1869 if (auto loopOp
= createLoopOp(control
))
1870 op
= loopOp
.getOperation();
1872 if (auto selectionOp
= createSelectionOp(control
))
1873 op
= selectionOp
.getOperation();
1877 Region
&body
= op
->getRegion(0);
1880 // All references to the old merge block should be directed to the
1881 // selection/loop merge block in the SelectionOp/LoopOp's region.
1882 mapper
.map(mergeBlock
, &body
.back());
1884 collectBlocksInConstruct();
1886 // We've identified all blocks belonging to the selection/loop's region. Now
1887 // need to "move" them into the selection/loop. Instead of really moving the
1888 // blocks, in the following we copy them and remap all values and branches.
1890 // * Inserting a block into a region requires the block not in any region
1891 // before. But selections/loops can nest so we can create selection/loop ops
1892 // in a nested manner, which means some blocks may already be in a
1893 // selection/loop region when to be moved again.
1894 // * It's much trickier to fix up the branches into and out of the loop's
1895 // region: we need to treat not-moved blocks and moved blocks differently:
1896 // Not-moved blocks jumping to the loop header block need to jump to the
1897 // merge point containing the new loop op but not the loop continue block's
1898 // back edge. Moved blocks jumping out of the loop need to jump to the
1899 // merge block inside the loop region but not other not-moved blocks.
1900 // We cannot use replaceAllUsesWith clearly and it's harder to follow the
1903 // Create a corresponding block in the SelectionOp/LoopOp's region for each
1904 // block in this loop construct.
1905 OpBuilder
builder(body
);
1906 for (auto *block
: constructBlocks
) {
1907 // Create a block and insert it before the selection/loop merge block in the
1908 // SelectionOp/LoopOp's region.
1909 auto *newBlock
= builder
.createBlock(&body
.back());
1910 mapper
.map(block
, newBlock
);
1911 LLVM_DEBUG(logger
.startLine() << "[cf] cloned block " << newBlock
1912 << " from block " << block
<< "\n");
1913 if (!isFnEntryBlock(block
)) {
1914 for (BlockArgument blockArg
: block
->getArguments()) {
1916 newBlock
->addArgument(blockArg
.getType(), blockArg
.getLoc());
1917 mapper
.map(blockArg
, newArg
);
1918 LLVM_DEBUG(logger
.startLine() << "[cf] remapped block argument "
1919 << blockArg
<< " to " << newArg
<< "\n");
1922 LLVM_DEBUG(logger
.startLine()
1923 << "[cf] block " << block
<< " is a function entry block\n");
1926 for (auto &op
: *block
)
1927 newBlock
->push_back(op
.clone(mapper
));
1930 // Go through all ops and remap the operands.
1931 auto remapOperands
= [&](Operation
*op
) {
1932 for (auto &operand
: op
->getOpOperands())
1933 if (Value mappedOp
= mapper
.lookupOrNull(operand
.get()))
1934 operand
.set(mappedOp
);
1935 for (auto &succOp
: op
->getBlockOperands())
1936 if (Block
*mappedOp
= mapper
.lookupOrNull(succOp
.get()))
1937 succOp
.set(mappedOp
);
1939 for (auto &block
: body
)
1940 block
.walk(remapOperands
);
1942 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1943 // the selection/loop construct into its region. Next we need to fix the
1944 // connections between this new SelectionOp/LoopOp with existing blocks.
1946 // All existing incoming branches should go to the merge block, where the
1947 // SelectionOp/LoopOp resides right now.
1948 headerBlock
->replaceAllUsesWith(mergeBlock
);
1951 logger
.startLine() << "[cf] after cloning and fixing references:\n";
1952 headerBlock
->getParentOp()->print(logger
.getOStream());
1953 logger
.startLine() << "\n";
1957 if (!mergeBlock
->args_empty()) {
1958 return mergeBlock
->getParentOp()->emitError(
1959 "OpPhi in loop merge block unsupported");
1962 // The loop header block may have block arguments. Since now we place the
1963 // loop op inside the old merge block, we need to make sure the old merge
1964 // block has the same block argument list.
1965 for (BlockArgument blockArg
: headerBlock
->getArguments())
1966 mergeBlock
->addArgument(blockArg
.getType(), blockArg
.getLoc());
1968 // If the loop header block has block arguments, make sure the spirv.Branch
1970 SmallVector
<Value
, 4> blockArgs
;
1971 if (!headerBlock
->args_empty())
1972 blockArgs
= {mergeBlock
->args_begin(), mergeBlock
->args_end()};
1974 // The loop entry block should have a unconditional branch jumping to the
1975 // loop header block.
1976 builder
.setInsertionPointToEnd(&body
.front());
1977 builder
.create
<spirv::BranchOp
>(location
, mapper
.lookupOrNull(headerBlock
),
1978 ArrayRef
<Value
>(blockArgs
));
1981 // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1983 LLVM_DEBUG(logger
.startLine() << "[cf] cleaning up blocks after clone\n");
1984 // First we need to drop all operands' references inside all blocks. This is
1985 // needed because we can have blocks referencing SSA values from one another.
1986 for (auto *block
: constructBlocks
)
1987 block
->dropAllReferences();
1989 // Check that whether some op in the to-be-erased blocks still has uses. Those
1990 // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
1991 // region. We cannot handle such cases given that once a value is sinked into
1992 // the SelectionOp/LoopOp's region, there is no escape for it:
1993 // SelectionOp/LooOp does not support yield values right now.
1994 for (auto *block
: constructBlocks
) {
1995 for (Operation
&op
: *block
)
1996 if (!op
.use_empty())
1997 return op
.emitOpError(
1998 "failed control flow structurization: it has uses outside of the "
1999 "enclosing selection/loop construct");
2002 // Then erase all old blocks.
2003 for (auto *block
: constructBlocks
) {
2004 // We've cloned all blocks belonging to this construct into the structured
2005 // control flow op's region. Among these blocks, some may compose another
2006 // selection/loop. If so, they will be recorded within blockMergeInfo.
2007 // We need to update the pointers there to the newly remapped ones so we can
2008 // continue structurizing them later.
2009 // TODO: The asserts in the following assumes input SPIR-V blob forms
2010 // correctly nested selection/loop constructs. We should relax this and
2011 // support error cases better.
2012 auto it
= blockMergeInfo
.find(block
);
2013 if (it
!= blockMergeInfo
.end()) {
2014 // Use the original location for nested selection/loop ops.
2015 Location loc
= it
->second
.loc
;
2017 Block
*newHeader
= mapper
.lookupOrNull(block
);
2019 return emitError(loc
, "failed control flow structurization: nested "
2020 "loop header block should be remapped!");
2022 Block
*newContinue
= it
->second
.continueBlock
;
2024 newContinue
= mapper
.lookupOrNull(newContinue
);
2026 return emitError(loc
, "failed control flow structurization: nested "
2027 "loop continue block should be remapped!");
2030 Block
*newMerge
= it
->second
.mergeBlock
;
2031 if (Block
*mappedTo
= mapper
.lookupOrNull(newMerge
))
2032 newMerge
= mappedTo
;
2034 // The iterator should be erased before adding a new entry into
2035 // blockMergeInfo to avoid iterator invalidation.
2036 blockMergeInfo
.erase(it
);
2037 blockMergeInfo
.try_emplace(newHeader
, loc
, it
->second
.control
, newMerge
,
2041 // The structured selection/loop's entry block does not have arguments.
2042 // If the function's header block is also part of the structured control
2043 // flow, we cannot just simply erase it because it may contain arguments
2044 // matching the function signature and used by the cloned blocks.
2045 if (isFnEntryBlock(block
)) {
2046 LLVM_DEBUG(logger
.startLine() << "[cf] changing entry block " << block
2047 << " to only contain a spirv.Branch op\n");
2048 // Still keep the function entry block for the potential block arguments,
2049 // but replace all ops inside with a branch to the merge block.
2051 builder
.setInsertionPointToEnd(block
);
2052 builder
.create
<spirv::BranchOp
>(location
, mergeBlock
);
2054 LLVM_DEBUG(logger
.startLine() << "[cf] erasing block " << block
<< "\n");
2059 LLVM_DEBUG(logger
.startLine()
2060 << "[cf] after structurizing construct with header block "
2061 << headerBlock
<< ":\n"
2067 LogicalResult
spirv::Deserializer::wireUpBlockArgument() {
2070 << "//----- [phi] start wiring up block arguments -----//\n";
2074 OpBuilder::InsertionGuard
guard(opBuilder
);
2076 for (const auto &info
: blockPhiInfo
) {
2077 Block
*block
= info
.first
.first
;
2078 Block
*target
= info
.first
.second
;
2079 const BlockPhiInfo
&phiInfo
= info
.second
;
2081 logger
.startLine() << "[phi] block " << block
<< "\n";
2082 logger
.startLine() << "[phi] before creating block argument:\n";
2083 block
->getParentOp()->print(logger
.getOStream());
2084 logger
.startLine() << "\n";
2087 // Set insertion point to before this block's terminator early because we
2088 // may materialize ops via getValue() call.
2089 auto *op
= block
->getTerminator();
2090 opBuilder
.setInsertionPoint(op
);
2092 SmallVector
<Value
, 4> blockArgs
;
2093 blockArgs
.reserve(phiInfo
.size());
2094 for (uint32_t valueId
: phiInfo
) {
2095 if (Value value
= getValue(valueId
)) {
2096 blockArgs
.push_back(value
);
2097 LLVM_DEBUG(logger
.startLine() << "[phi] block argument " << value
2098 << " id = " << valueId
<< "\n");
2100 return emitError(unknownLoc
, "OpPhi references undefined value!");
2104 if (auto branchOp
= dyn_cast
<spirv::BranchOp
>(op
)) {
2105 // Replace the previous branch op with a new one with block arguments.
2106 opBuilder
.create
<spirv::BranchOp
>(branchOp
.getLoc(), branchOp
.getTarget(),
2109 } else if (auto branchCondOp
= dyn_cast
<spirv::BranchConditionalOp
>(op
)) {
2110 assert((branchCondOp
.getTrueBlock() == target
||
2111 branchCondOp
.getFalseBlock() == target
) &&
2112 "expected target to be either the true or false target");
2113 if (target
== branchCondOp
.getTrueTarget())
2114 opBuilder
.create
<spirv::BranchConditionalOp
>(
2115 branchCondOp
.getLoc(), branchCondOp
.getCondition(), blockArgs
,
2116 branchCondOp
.getFalseBlockArguments(),
2117 branchCondOp
.getBranchWeightsAttr(), branchCondOp
.getTrueTarget(),
2118 branchCondOp
.getFalseTarget());
2120 opBuilder
.create
<spirv::BranchConditionalOp
>(
2121 branchCondOp
.getLoc(), branchCondOp
.getCondition(),
2122 branchCondOp
.getTrueBlockArguments(), blockArgs
,
2123 branchCondOp
.getBranchWeightsAttr(), branchCondOp
.getTrueBlock(),
2124 branchCondOp
.getFalseBlock());
2126 branchCondOp
.erase();
2128 return emitError(unknownLoc
, "unimplemented terminator for Phi creation");
2132 logger
.startLine() << "[phi] after creating block argument:\n";
2133 block
->getParentOp()->print(logger
.getOStream());
2134 logger
.startLine() << "\n";
2137 blockPhiInfo
.clear();
2142 << "//--- [phi] completed wiring up block arguments ---//\n";
2147 LogicalResult
spirv::Deserializer::structurizeControlFlow() {
2150 << "//----- [cf] start structurizing control flow -----//\n";
2154 while (!blockMergeInfo
.empty()) {
2155 Block
*headerBlock
= blockMergeInfo
.begin()->first
;
2156 BlockMergeInfo mergeInfo
= blockMergeInfo
.begin()->second
;
2159 logger
.startLine() << "[cf] header block " << headerBlock
<< ":\n";
2160 headerBlock
->print(logger
.getOStream());
2161 logger
.startLine() << "\n";
2164 auto *mergeBlock
= mergeInfo
.mergeBlock
;
2165 assert(mergeBlock
&& "merge block cannot be nullptr");
2166 if (!mergeBlock
->args_empty())
2167 return emitError(unknownLoc
, "OpPhi in loop merge block unimplemented");
2169 logger
.startLine() << "[cf] merge block " << mergeBlock
<< ":\n";
2170 mergeBlock
->print(logger
.getOStream());
2171 logger
.startLine() << "\n";
2174 auto *continueBlock
= mergeInfo
.continueBlock
;
2175 LLVM_DEBUG(if (continueBlock
) {
2176 logger
.startLine() << "[cf] continue block " << continueBlock
<< ":\n";
2177 continueBlock
->print(logger
.getOStream());
2178 logger
.startLine() << "\n";
2180 // Erase this case before calling into structurizer, who will update
2182 blockMergeInfo
.erase(blockMergeInfo
.begin());
2183 ControlFlowStructurizer
structurizer(mergeInfo
.loc
, mergeInfo
.control
,
2184 blockMergeInfo
, headerBlock
,
2185 mergeBlock
, continueBlock
2191 if (failed(structurizer
.structurize()))
2198 << "//--- [cf] completed structurizing control flow ---//\n";
2203 //===----------------------------------------------------------------------===//
2205 //===----------------------------------------------------------------------===//
2207 Location
spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder
) {
2211 auto fileName
= debugInfoMap
.lookup(debugLine
->fileID
).str();
2212 if (fileName
.empty())
2213 fileName
= "<unknown>";
2214 return FileLineColLoc::get(opBuilder
.getStringAttr(fileName
), debugLine
->line
,
2219 spirv::Deserializer::processDebugLine(ArrayRef
<uint32_t> operands
) {
2220 // According to SPIR-V spec:
2221 // "This location information applies to the instructions physically
2222 // following this instruction, up to the first occurrence of any of the
2223 // following: the next end of block, the next OpLine instruction, or the next
2224 // OpNoLine instruction."
2225 if (operands
.size() != 3)
2226 return emitError(unknownLoc
, "OpLine must have 3 operands");
2227 debugLine
= DebugLine
{operands
[0], operands
[1], operands
[2]};
2231 void spirv::Deserializer::clearDebugLine() { debugLine
= std::nullopt
; }
2234 spirv::Deserializer::processDebugString(ArrayRef
<uint32_t> operands
) {
2235 if (operands
.size() < 2)
2236 return emitError(unknownLoc
, "OpString needs at least 2 operands");
2238 if (!debugInfoMap
.lookup(operands
[0]).empty())
2239 return emitError(unknownLoc
,
2240 "duplicate debug string found for result <id> ")
2243 unsigned wordIndex
= 1;
2244 StringRef debugString
= decodeStringLiteral(operands
, wordIndex
);
2245 if (wordIndex
!= operands
.size())
2246 return emitError(unknownLoc
,
2247 "unexpected trailing words in OpString instruction");
2249 debugInfoMap
[operands
[0]] = debugString
;