1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
11 //===----------------------------------------------------------------------===//
13 #include "Deserializer.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/raw_ostream.h"
35 #define DEBUG_TYPE "spirv-deserialization"
37 //===----------------------------------------------------------------------===//
39 //===----------------------------------------------------------------------===//
41 /// Returns true if the given `block` is a function entry block.
42 static inline bool isFnEntryBlock(Block
*block
) {
43 return block
->isEntryBlock() &&
44 isa_and_nonnull
<spirv::FuncOp
>(block
->getParentOp());
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
51 spirv::Deserializer::Deserializer(ArrayRef
<uint32_t> binary
,
53 : binary(binary
), context(context
), unknownLoc(UnknownLoc::get(context
)),
54 module(createModuleOp()), opBuilder(module
->getRegion())
62 LogicalResult
spirv::Deserializer::deserialize() {
66 << "//+++---------- start deserialization ----------+++//\n";
69 if (failed(processHeader()))
72 spirv::Opcode opcode
= spirv::Opcode::OpNop
;
73 ArrayRef
<uint32_t> operands
;
74 auto binarySize
= binary
.size();
75 while (curOffset
< binarySize
) {
76 // Slice the next instruction out and populate `opcode` and `operands`.
77 // Internally this also updates `curOffset`.
78 if (failed(sliceInstruction(opcode
, operands
)))
81 if (failed(processInstruction(opcode
, operands
)))
85 assert(curOffset
== binarySize
&&
86 "deserializer should never index beyond the binary end");
88 for (auto &deferred
: deferredInstructions
) {
89 if (failed(processInstruction(deferred
.first
, deferred
.second
, false))) {
96 LLVM_DEBUG(logger
.startLine()
97 << "//+++-------- completed deserialization --------+++//\n");
101 OwningOpRef
<spirv::ModuleOp
> spirv::Deserializer::collect() {
102 return std::move(module
);
105 //===----------------------------------------------------------------------===//
107 //===----------------------------------------------------------------------===//
109 OwningOpRef
<spirv::ModuleOp
> spirv::Deserializer::createModuleOp() {
110 OpBuilder
builder(context
);
111 OperationState
state(unknownLoc
, spirv::ModuleOp::getOperationName());
112 spirv::ModuleOp::build(builder
, state
);
113 return cast
<spirv::ModuleOp
>(Operation::create(state
));
116 LogicalResult
spirv::Deserializer::processHeader() {
117 if (binary
.size() < spirv::kHeaderWordCount
)
118 return emitError(unknownLoc
,
119 "SPIR-V binary module must have a 5-word header");
121 if (binary
[0] != spirv::kMagicNumber
)
122 return emitError(unknownLoc
, "incorrect magic number");
124 // Version number bytes: 0 | major number | minor number | 0
125 uint32_t majorVersion
= (binary
[1] << 8) >> 24;
126 uint32_t minorVersion
= (binary
[1] << 16) >> 24;
127 if (majorVersion
== 1) {
128 switch (minorVersion
) {
129 #define MIN_VERSION_CASE(v) \
131 version = spirv::Version::V_1_##v; \
140 #undef MIN_VERSION_CASE
142 return emitError(unknownLoc
, "unsupported SPIR-V minor version: ")
146 return emitError(unknownLoc
, "unsupported SPIR-V major version: ")
150 // TODO: generator number, bound, schema
151 curOffset
= spirv::kHeaderWordCount
;
156 spirv::Deserializer::processCapability(ArrayRef
<uint32_t> operands
) {
157 if (operands
.size() != 1)
158 return emitError(unknownLoc
, "OpMemoryModel must have one parameter");
160 auto cap
= spirv::symbolizeCapability(operands
[0]);
162 return emitError(unknownLoc
, "unknown capability: ") << operands
[0];
164 capabilities
.insert(*cap
);
168 LogicalResult
spirv::Deserializer::processExtension(ArrayRef
<uint32_t> words
) {
172 "OpExtension must have a literal string for the extension name");
175 unsigned wordIndex
= 0;
176 StringRef extName
= decodeStringLiteral(words
, wordIndex
);
177 if (wordIndex
!= words
.size())
178 return emitError(unknownLoc
,
179 "unexpected trailing words in OpExtension instruction");
180 auto ext
= spirv::symbolizeExtension(extName
);
182 return emitError(unknownLoc
, "unknown extension: ") << extName
;
184 extensions
.insert(*ext
);
189 spirv::Deserializer::processExtInstImport(ArrayRef
<uint32_t> words
) {
190 if (words
.size() < 2) {
191 return emitError(unknownLoc
,
192 "OpExtInstImport must have a result <id> and a literal "
193 "string for the extended instruction set name");
196 unsigned wordIndex
= 1;
197 extendedInstSets
[words
[0]] = decodeStringLiteral(words
, wordIndex
);
198 if (wordIndex
!= words
.size()) {
199 return emitError(unknownLoc
,
200 "unexpected trailing words in OpExtInstImport");
205 void spirv::Deserializer::attachVCETriple() {
207 spirv::ModuleOp::getVCETripleAttrName(),
208 spirv::VerCapExtAttr::get(version
, capabilities
.getArrayRef(),
209 extensions
.getArrayRef(), context
));
213 spirv::Deserializer::processMemoryModel(ArrayRef
<uint32_t> operands
) {
214 if (operands
.size() != 2)
215 return emitError(unknownLoc
, "OpMemoryModel must have two operands");
218 module
->getAddressingModelAttrName(),
219 opBuilder
.getAttr
<spirv::AddressingModelAttr
>(
220 static_cast<spirv::AddressingModel
>(operands
.front())));
222 (*module
)->setAttr(module
->getMemoryModelAttrName(),
223 opBuilder
.getAttr
<spirv::MemoryModelAttr
>(
224 static_cast<spirv::MemoryModel
>(operands
.back())));
229 template <typename AttrTy
, typename EnumAttrTy
, typename EnumTy
>
230 LogicalResult
deserializeCacheControlDecoration(
231 Location loc
, OpBuilder
&opBuilder
,
232 DenseMap
<uint32_t, NamedAttrList
> &decorations
, ArrayRef
<uint32_t> words
,
233 StringAttr symbol
, StringRef decorationName
, StringRef cacheControlKind
) {
234 if (words
.size() != 4) {
235 return emitError(loc
, "OpDecoration with ")
236 << decorationName
<< "needs a cache control integer literal and a "
237 << cacheControlKind
<< " cache control literal";
239 unsigned cacheLevel
= words
[2];
240 auto cacheControlAttr
= static_cast<EnumTy
>(words
[3]);
241 auto value
= opBuilder
.getAttr
<AttrTy
>(cacheLevel
, cacheControlAttr
);
242 SmallVector
<Attribute
> attrs
;
244 llvm::dyn_cast_or_null
<ArrayAttr
>(decorations
[words
[0]].get(symbol
)))
245 llvm::append_range(attrs
, attrList
);
246 attrs
.push_back(value
);
247 decorations
[words
[0]].set(symbol
, opBuilder
.getArrayAttr(attrs
));
251 LogicalResult
spirv::Deserializer::processDecoration(ArrayRef
<uint32_t> words
) {
252 // TODO: This function should also be auto-generated. For now, since only a
253 // few decorations are processed/handled in a meaningful manner, going with a
254 // manual implementation.
255 if (words
.size() < 2) {
257 unknownLoc
, "OpDecorate must have at least result <id> and Decoration");
259 auto decorationName
=
260 stringifyDecoration(static_cast<spirv::Decoration
>(words
[1]));
261 if (decorationName
.empty()) {
262 return emitError(unknownLoc
, "invalid Decoration code : ") << words
[1];
264 auto symbol
= getSymbolDecoration(decorationName
);
265 switch (static_cast<spirv::Decoration
>(words
[1])) {
266 case spirv::Decoration::FPFastMathMode
:
267 if (words
.size() != 3) {
268 return emitError(unknownLoc
, "OpDecorate with ")
269 << decorationName
<< " needs a single integer literal";
271 decorations
[words
[0]].set(
272 symbol
, FPFastMathModeAttr::get(opBuilder
.getContext(),
273 static_cast<FPFastMathMode
>(words
[2])));
275 case spirv::Decoration::FPRoundingMode
:
276 if (words
.size() != 3) {
277 return emitError(unknownLoc
, "OpDecorate with ")
278 << decorationName
<< " needs a single integer literal";
280 decorations
[words
[0]].set(
281 symbol
, FPRoundingModeAttr::get(opBuilder
.getContext(),
282 static_cast<FPRoundingMode
>(words
[2])));
284 case spirv::Decoration::DescriptorSet
:
285 case spirv::Decoration::Binding
:
286 if (words
.size() != 3) {
287 return emitError(unknownLoc
, "OpDecorate with ")
288 << decorationName
<< " needs a single integer literal";
290 decorations
[words
[0]].set(
291 symbol
, opBuilder
.getI32IntegerAttr(static_cast<int32_t>(words
[2])));
293 case spirv::Decoration::BuiltIn
:
294 if (words
.size() != 3) {
295 return emitError(unknownLoc
, "OpDecorate with ")
296 << decorationName
<< " needs a single integer literal";
298 decorations
[words
[0]].set(
299 symbol
, opBuilder
.getStringAttr(
300 stringifyBuiltIn(static_cast<spirv::BuiltIn
>(words
[2]))));
302 case spirv::Decoration::ArrayStride
:
303 if (words
.size() != 3) {
304 return emitError(unknownLoc
, "OpDecorate with ")
305 << decorationName
<< " needs a single integer literal";
307 typeDecorations
[words
[0]] = words
[2];
309 case spirv::Decoration::LinkageAttributes
: {
310 if (words
.size() < 4) {
311 return emitError(unknownLoc
, "OpDecorate with ")
313 << " needs at least 1 string and 1 integer literal";
315 // LinkageAttributes has two parameters ["linkageName", linkageType]
316 // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
317 // "linkageName" is a stringliteral encoded as uint32_t,
318 // hence the size of name is variable length which results in words.size()
319 // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
320 // 3 + ceildiv(strlen(name), 4).
321 unsigned wordIndex
= 2;
322 auto linkageName
= spirv::decodeStringLiteral(words
, wordIndex
).str();
323 auto linkageTypeAttr
= opBuilder
.getAttr
<::mlir::spirv::LinkageTypeAttr
>(
324 static_cast<::mlir::spirv::LinkageType
>(words
[wordIndex
++]));
325 auto linkageAttr
= opBuilder
.getAttr
<::mlir::spirv::LinkageAttributesAttr
>(
326 StringAttr::get(context
, linkageName
), linkageTypeAttr
);
327 decorations
[words
[0]].set(symbol
, llvm::dyn_cast
<Attribute
>(linkageAttr
));
330 case spirv::Decoration::Aliased
:
331 case spirv::Decoration::AliasedPointer
:
332 case spirv::Decoration::Block
:
333 case spirv::Decoration::BufferBlock
:
334 case spirv::Decoration::Flat
:
335 case spirv::Decoration::NonReadable
:
336 case spirv::Decoration::NonWritable
:
337 case spirv::Decoration::NoPerspective
:
338 case spirv::Decoration::NoSignedWrap
:
339 case spirv::Decoration::NoUnsignedWrap
:
340 case spirv::Decoration::RelaxedPrecision
:
341 case spirv::Decoration::Restrict
:
342 case spirv::Decoration::RestrictPointer
:
343 case spirv::Decoration::NoContraction
:
344 case spirv::Decoration::Constant
:
345 if (words
.size() != 2) {
346 return emitError(unknownLoc
, "OpDecoration with ")
347 << decorationName
<< "needs a single target <id>";
349 // Block decoration does not affect spirv.struct type, but is still stored
351 // TODO: Update StructType to contain this information since
352 // it is needed for many validation rules.
353 decorations
[words
[0]].set(symbol
, opBuilder
.getUnitAttr());
355 case spirv::Decoration::Location
:
356 case spirv::Decoration::SpecId
:
357 if (words
.size() != 3) {
358 return emitError(unknownLoc
, "OpDecoration with ")
359 << decorationName
<< "needs a single integer literal";
361 decorations
[words
[0]].set(
362 symbol
, opBuilder
.getI32IntegerAttr(static_cast<int32_t>(words
[2])));
364 case spirv::Decoration::CacheControlLoadINTEL
: {
365 LogicalResult res
= deserializeCacheControlDecoration
<
366 CacheControlLoadINTELAttr
, LoadCacheControlAttr
, LoadCacheControl
>(
367 unknownLoc
, opBuilder
, decorations
, words
, symbol
, decorationName
,
373 case spirv::Decoration::CacheControlStoreINTEL
: {
374 LogicalResult res
= deserializeCacheControlDecoration
<
375 CacheControlStoreINTELAttr
, StoreCacheControlAttr
, StoreCacheControl
>(
376 unknownLoc
, opBuilder
, decorations
, words
, symbol
, decorationName
,
383 return emitError(unknownLoc
, "unhandled Decoration : '") << decorationName
;
389 spirv::Deserializer::processMemberDecoration(ArrayRef
<uint32_t> words
) {
390 // The binary layout of OpMemberDecorate is different comparing to OpDecorate
391 if (words
.size() < 3) {
392 return emitError(unknownLoc
,
393 "OpMemberDecorate must have at least 3 operands");
396 auto decoration
= static_cast<spirv::Decoration
>(words
[2]);
397 if (decoration
== spirv::Decoration::Offset
&& words
.size() != 4) {
398 return emitError(unknownLoc
,
399 " missing offset specification in OpMemberDecorate with "
400 "Offset decoration");
402 ArrayRef
<uint32_t> decorationOperands
;
403 if (words
.size() > 3) {
404 decorationOperands
= words
.slice(3);
406 memberDecorationMap
[words
[0]][words
[1]][decoration
] = decorationOperands
;
410 LogicalResult
spirv::Deserializer::processMemberName(ArrayRef
<uint32_t> words
) {
411 if (words
.size() < 3) {
412 return emitError(unknownLoc
, "OpMemberName must have at least 3 operands");
414 unsigned wordIndex
= 2;
415 auto name
= decodeStringLiteral(words
, wordIndex
);
416 if (wordIndex
!= words
.size()) {
417 return emitError(unknownLoc
,
418 "unexpected trailing words in OpMemberName instruction");
420 memberNameMap
[words
[0]][words
[1]] = name
;
424 LogicalResult
spirv::Deserializer::setFunctionArgAttrs(
425 uint32_t argID
, SmallVectorImpl
<Attribute
> &argAttrs
, size_t argIndex
) {
426 if (!decorations
.contains(argID
)) {
427 argAttrs
[argIndex
] = DictionaryAttr::get(context
, {});
431 spirv::DecorationAttr foundDecorationAttr
;
432 for (NamedAttribute decAttr
: decorations
[argID
]) {
433 for (auto decoration
:
434 {spirv::Decoration::Aliased
, spirv::Decoration::Restrict
,
435 spirv::Decoration::AliasedPointer
,
436 spirv::Decoration::RestrictPointer
}) {
438 if (decAttr
.getName() !=
439 getSymbolDecoration(stringifyDecoration(decoration
)))
442 if (foundDecorationAttr
)
443 return emitError(unknownLoc
,
444 "more than one Aliased/Restrict decorations for "
445 "function argument with result <id> ")
448 foundDecorationAttr
= spirv::DecorationAttr::get(context
, decoration
);
453 if (!foundDecorationAttr
)
454 return emitError(unknownLoc
, "unimplemented decoration support for "
455 "function argument with result <id> ")
458 NamedAttribute
attr(StringAttr::get(context
, spirv::DecorationAttr::name
),
459 foundDecorationAttr
);
460 argAttrs
[argIndex
] = DictionaryAttr::get(context
, attr
);
465 spirv::Deserializer::processFunction(ArrayRef
<uint32_t> operands
) {
467 return emitError(unknownLoc
, "found function inside function");
470 // Get the result type
471 if (operands
.size() != 4) {
472 return emitError(unknownLoc
, "OpFunction must have 4 parameters");
474 Type resultType
= getType(operands
[0]);
476 return emitError(unknownLoc
, "undefined result type from <id> ")
480 uint32_t fnID
= operands
[1];
481 if (funcMap
.count(fnID
)) {
482 return emitError(unknownLoc
, "duplicate function definition/declaration");
485 auto fnControl
= spirv::symbolizeFunctionControl(operands
[2]);
487 return emitError(unknownLoc
, "unknown Function Control: ") << operands
[2];
490 Type fnType
= getType(operands
[3]);
491 if (!fnType
|| !isa
<FunctionType
>(fnType
)) {
492 return emitError(unknownLoc
, "unknown function type from <id> ")
495 auto functionType
= cast
<FunctionType
>(fnType
);
497 if ((isVoidType(resultType
) && functionType
.getNumResults() != 0) ||
498 (functionType
.getNumResults() == 1 &&
499 functionType
.getResult(0) != resultType
)) {
500 return emitError(unknownLoc
, "mismatch in function type ")
501 << functionType
<< " and return type " << resultType
<< " specified";
504 std::string fnName
= getFunctionSymbol(fnID
);
505 auto funcOp
= opBuilder
.create
<spirv::FuncOp
>(
506 unknownLoc
, fnName
, functionType
, fnControl
.value());
507 // Processing other function attributes.
508 if (decorations
.count(fnID
)) {
509 for (auto attr
: decorations
[fnID
].getAttrs()) {
510 funcOp
->setAttr(attr
.getName(), attr
.getValue());
513 curFunction
= funcMap
[fnID
] = funcOp
;
514 auto *entryBlock
= funcOp
.addEntryBlock();
517 << "//===-------------------------------------------===//\n";
518 logger
.startLine() << "[fn] name: " << fnName
<< "\n";
519 logger
.startLine() << "[fn] type: " << fnType
<< "\n";
520 logger
.startLine() << "[fn] ID: " << fnID
<< "\n";
521 logger
.startLine() << "[fn] entry block: " << entryBlock
<< "\n";
525 SmallVector
<Attribute
> argAttrs
;
526 argAttrs
.resize(functionType
.getNumInputs());
528 // Parse the op argument instructions
529 if (functionType
.getNumInputs()) {
530 for (size_t i
= 0, e
= functionType
.getNumInputs(); i
!= e
; ++i
) {
531 auto argType
= functionType
.getInput(i
);
532 spirv::Opcode opcode
= spirv::Opcode::OpNop
;
533 ArrayRef
<uint32_t> operands
;
534 if (failed(sliceInstruction(opcode
, operands
,
535 spirv::Opcode::OpFunctionParameter
))) {
538 if (opcode
!= spirv::Opcode::OpFunctionParameter
) {
541 "missing OpFunctionParameter instruction for argument ")
544 if (operands
.size() != 2) {
547 "expected result type and result <id> for OpFunctionParameter");
549 auto argDefinedType
= getType(operands
[0]);
550 if (!argDefinedType
|| argDefinedType
!= argType
) {
551 return emitError(unknownLoc
,
552 "mismatch in argument type between function type "
554 << functionType
<< " and argument type definition "
555 << argDefinedType
<< " at argument " << i
;
557 if (getValue(operands
[1])) {
558 return emitError(unknownLoc
, "duplicate definition of result <id> ")
561 if (failed(setFunctionArgAttrs(operands
[1], argAttrs
, i
))) {
565 auto argValue
= funcOp
.getArgument(i
);
566 valueMap
[operands
[1]] = argValue
;
570 if (llvm::any_of(argAttrs
, [](Attribute attr
) {
571 auto argAttr
= cast
<DictionaryAttr
>(attr
);
572 return !argAttr
.empty();
574 funcOp
.setArgAttrsAttr(ArrayAttr::get(context
, argAttrs
));
576 // entryBlock is needed to access the arguments, Once that is done, we can
577 // erase the block for functions with 'Import' LinkageAttributes, since these
578 // are essentially function declarations, so they have no body.
579 auto linkageAttr
= funcOp
.getLinkageAttributes();
580 auto hasImportLinkage
=
581 linkageAttr
&& (linkageAttr
.value().getLinkageType().getValue() ==
582 spirv::LinkageType::Import
);
583 if (hasImportLinkage
)
586 // RAII guard to reset the insertion point to the module's region after
587 // deserializing the body of this function.
588 OpBuilder::InsertionGuard
moduleInsertionGuard(opBuilder
);
590 spirv::Opcode opcode
= spirv::Opcode::OpNop
;
591 ArrayRef
<uint32_t> instOperands
;
593 // Special handling for the entry block. We need to make sure it starts with
594 // an OpLabel instruction. The entry block takes the same parameters as the
595 // function. All other blocks do not take any parameter. We have already
596 // created the entry block, here we need to register it to the correct label
598 if (failed(sliceInstruction(opcode
, instOperands
,
599 spirv::Opcode::OpFunctionEnd
))) {
602 if (opcode
== spirv::Opcode::OpFunctionEnd
) {
603 return processFunctionEnd(instOperands
);
605 if (opcode
!= spirv::Opcode::OpLabel
) {
606 return emitError(unknownLoc
, "a basic block must start with OpLabel");
608 if (instOperands
.size() != 1) {
609 return emitError(unknownLoc
, "OpLabel should only have result <id>");
611 blockMap
[instOperands
[0]] = entryBlock
;
612 if (failed(processLabel(instOperands
))) {
616 // Then process all the other instructions in the function until we hit
618 while (succeeded(sliceInstruction(opcode
, instOperands
,
619 spirv::Opcode::OpFunctionEnd
)) &&
620 opcode
!= spirv::Opcode::OpFunctionEnd
) {
621 if (failed(processInstruction(opcode
, instOperands
))) {
625 if (opcode
!= spirv::Opcode::OpFunctionEnd
) {
629 return processFunctionEnd(instOperands
);
633 spirv::Deserializer::processFunctionEnd(ArrayRef
<uint32_t> operands
) {
634 // Process OpFunctionEnd.
635 if (!operands
.empty()) {
636 return emitError(unknownLoc
, "unexpected operands for OpFunctionEnd");
639 // Wire up block arguments from OpPhi instructions.
640 // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
642 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
647 curFunction
= std::nullopt
;
652 << "//===-------------------------------------------===//\n";
657 std::optional
<std::pair
<Attribute
, Type
>>
658 spirv::Deserializer::getConstant(uint32_t id
) {
659 auto constIt
= constantMap
.find(id
);
660 if (constIt
== constantMap
.end())
662 return constIt
->getSecond();
665 std::optional
<spirv::SpecConstOperationMaterializationInfo
>
666 spirv::Deserializer::getSpecConstantOperation(uint32_t id
) {
667 auto constIt
= specConstOperationMap
.find(id
);
668 if (constIt
== specConstOperationMap
.end())
670 return constIt
->getSecond();
673 std::string
spirv::Deserializer::getFunctionSymbol(uint32_t id
) {
674 auto funcName
= nameMap
.lookup(id
).str();
675 if (funcName
.empty()) {
676 funcName
= "spirv_fn_" + std::to_string(id
);
681 std::string
spirv::Deserializer::getSpecConstantSymbol(uint32_t id
) {
682 auto constName
= nameMap
.lookup(id
).str();
683 if (constName
.empty()) {
684 constName
= "spirv_spec_const_" + std::to_string(id
);
689 spirv::SpecConstantOp
690 spirv::Deserializer::createSpecConstant(Location loc
, uint32_t resultID
,
691 TypedAttr defaultValue
) {
692 auto symName
= opBuilder
.getStringAttr(getSpecConstantSymbol(resultID
));
693 auto op
= opBuilder
.create
<spirv::SpecConstantOp
>(unknownLoc
, symName
,
695 if (decorations
.count(resultID
)) {
696 for (auto attr
: decorations
[resultID
].getAttrs())
697 op
->setAttr(attr
.getName(), attr
.getValue());
699 specConstMap
[resultID
] = op
;
704 spirv::Deserializer::processGlobalVariable(ArrayRef
<uint32_t> operands
) {
705 unsigned wordIndex
= 0;
706 if (operands
.size() < 3) {
709 "OpVariable needs at least 3 operands, type, <id> and storage class");
713 auto type
= getType(operands
[wordIndex
]);
715 return emitError(unknownLoc
, "unknown result type <id> : ")
716 << operands
[wordIndex
];
718 auto ptrType
= dyn_cast
<spirv::PointerType
>(type
);
720 return emitError(unknownLoc
,
721 "expected a result type <id> to be a spirv.ptr, found : ")
727 auto variableID
= operands
[wordIndex
];
728 auto variableName
= nameMap
.lookup(variableID
).str();
729 if (variableName
.empty()) {
730 variableName
= "spirv_var_" + std::to_string(variableID
);
735 auto storageClass
= static_cast<spirv::StorageClass
>(operands
[wordIndex
]);
736 if (ptrType
.getStorageClass() != storageClass
) {
737 return emitError(unknownLoc
, "mismatch in storage class of pointer type ")
738 << type
<< " and that specified in OpVariable instruction : "
739 << stringifyStorageClass(storageClass
);
744 FlatSymbolRefAttr initializer
= nullptr;
746 if (wordIndex
< operands
.size()) {
747 Operation
*op
= nullptr;
749 if (auto initOp
= getGlobalVariable(operands
[wordIndex
]))
751 else if (auto initOp
= getSpecConstant(operands
[wordIndex
]))
753 else if (auto initOp
= getSpecConstantComposite(operands
[wordIndex
]))
756 return emitError(unknownLoc
, "unknown <id> ")
757 << operands
[wordIndex
] << "used as initializer";
759 initializer
= SymbolRefAttr::get(op
);
762 if (wordIndex
!= operands
.size()) {
763 return emitError(unknownLoc
,
764 "found more operands than expected when deserializing "
765 "OpVariable instruction, only ")
766 << wordIndex
<< " of " << operands
.size() << " processed";
768 auto loc
= createFileLineColLoc(opBuilder
);
769 auto varOp
= opBuilder
.create
<spirv::GlobalVariableOp
>(
770 loc
, TypeAttr::get(type
), opBuilder
.getStringAttr(variableName
),
774 if (decorations
.count(variableID
)) {
775 for (auto attr
: decorations
[variableID
].getAttrs())
776 varOp
->setAttr(attr
.getName(), attr
.getValue());
778 globalVariableMap
[variableID
] = varOp
;
782 IntegerAttr
spirv::Deserializer::getConstantInt(uint32_t id
) {
783 auto constInfo
= getConstant(id
);
787 return dyn_cast
<IntegerAttr
>(constInfo
->first
);
790 LogicalResult
spirv::Deserializer::processName(ArrayRef
<uint32_t> operands
) {
791 if (operands
.size() < 2) {
792 return emitError(unknownLoc
, "OpName needs at least 2 operands");
794 if (!nameMap
.lookup(operands
[0]).empty()) {
795 return emitError(unknownLoc
, "duplicate name found for result <id> ")
798 unsigned wordIndex
= 1;
799 StringRef name
= decodeStringLiteral(operands
, wordIndex
);
800 if (wordIndex
!= operands
.size()) {
801 return emitError(unknownLoc
,
802 "unexpected trailing words in OpName instruction");
804 nameMap
[operands
[0]] = name
;
808 //===----------------------------------------------------------------------===//
810 //===----------------------------------------------------------------------===//
812 LogicalResult
spirv::Deserializer::processType(spirv::Opcode opcode
,
813 ArrayRef
<uint32_t> operands
) {
814 if (operands
.empty()) {
815 return emitError(unknownLoc
, "type instruction with opcode ")
816 << spirv::stringifyOpcode(opcode
) << " needs at least one <id>";
819 /// TODO: Types might be forward declared in some instructions and need to be
820 /// handled appropriately.
821 if (typeMap
.count(operands
[0])) {
822 return emitError(unknownLoc
, "duplicate definition for result <id> ")
827 case spirv::Opcode::OpTypeVoid
:
828 if (operands
.size() != 1)
829 return emitError(unknownLoc
, "OpTypeVoid must have no parameters");
830 typeMap
[operands
[0]] = opBuilder
.getNoneType();
832 case spirv::Opcode::OpTypeBool
:
833 if (operands
.size() != 1)
834 return emitError(unknownLoc
, "OpTypeBool must have no parameters");
835 typeMap
[operands
[0]] = opBuilder
.getI1Type();
837 case spirv::Opcode::OpTypeInt
: {
838 if (operands
.size() != 3)
840 unknownLoc
, "OpTypeInt must have bitwidth and signedness parameters");
842 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
843 // to preserve or validate.
844 // 0 indicates unsigned, or no signedness semantics
845 // 1 indicates signed semantics."
847 // So we cannot differentiate signless and unsigned integers; always use
848 // signless semantics for such cases.
849 auto sign
= operands
[2] == 1 ? IntegerType::SignednessSemantics::Signed
850 : IntegerType::SignednessSemantics::Signless
;
851 typeMap
[operands
[0]] = IntegerType::get(context
, operands
[1], sign
);
853 case spirv::Opcode::OpTypeFloat
: {
854 if (operands
.size() != 2)
855 return emitError(unknownLoc
, "OpTypeFloat must have bitwidth parameter");
858 switch (operands
[1]) {
860 floatTy
= opBuilder
.getF16Type();
863 floatTy
= opBuilder
.getF32Type();
866 floatTy
= opBuilder
.getF64Type();
869 return emitError(unknownLoc
, "unsupported OpTypeFloat bitwidth: ")
872 typeMap
[operands
[0]] = floatTy
;
874 case spirv::Opcode::OpTypeVector
: {
875 if (operands
.size() != 3) {
878 "OpTypeVector must have element type and count parameters");
880 Type elementTy
= getType(operands
[1]);
882 return emitError(unknownLoc
, "OpTypeVector references undefined <id> ")
885 typeMap
[operands
[0]] = VectorType::get({operands
[2]}, elementTy
);
887 case spirv::Opcode::OpTypePointer
: {
888 return processOpTypePointer(operands
);
890 case spirv::Opcode::OpTypeArray
:
891 return processArrayType(operands
);
892 case spirv::Opcode::OpTypeCooperativeMatrixKHR
:
893 return processCooperativeMatrixTypeKHR(operands
);
894 case spirv::Opcode::OpTypeFunction
:
895 return processFunctionType(operands
);
896 case spirv::Opcode::OpTypeImage
:
897 return processImageType(operands
);
898 case spirv::Opcode::OpTypeSampledImage
:
899 return processSampledImageType(operands
);
900 case spirv::Opcode::OpTypeRuntimeArray
:
901 return processRuntimeArrayType(operands
);
902 case spirv::Opcode::OpTypeStruct
:
903 return processStructType(operands
);
904 case spirv::Opcode::OpTypeMatrix
:
905 return processMatrixType(operands
);
907 return emitError(unknownLoc
, "unhandled type instruction");
913 spirv::Deserializer::processOpTypePointer(ArrayRef
<uint32_t> operands
) {
914 if (operands
.size() != 3)
915 return emitError(unknownLoc
, "OpTypePointer must have two parameters");
917 auto pointeeType
= getType(operands
[2]);
919 return emitError(unknownLoc
, "unknown OpTypePointer pointee type <id> ")
922 uint32_t typePointerID
= operands
[0];
923 auto storageClass
= static_cast<spirv::StorageClass
>(operands
[1]);
924 typeMap
[typePointerID
] = spirv::PointerType::get(pointeeType
, storageClass
);
926 for (auto *deferredStructIt
= std::begin(deferredStructTypesInfos
);
927 deferredStructIt
!= std::end(deferredStructTypesInfos
);) {
928 for (auto *unresolvedMemberIt
=
929 std::begin(deferredStructIt
->unresolvedMemberTypes
);
930 unresolvedMemberIt
!=
931 std::end(deferredStructIt
->unresolvedMemberTypes
);) {
932 if (unresolvedMemberIt
->first
== typePointerID
) {
933 // The newly constructed pointer type can resolve one of the
934 // deferred struct type members; update the memberTypes list and
935 // clean the unresolvedMemberTypes list accordingly.
936 deferredStructIt
->memberTypes
[unresolvedMemberIt
->second
] =
937 typeMap
[typePointerID
];
939 deferredStructIt
->unresolvedMemberTypes
.erase(unresolvedMemberIt
);
941 ++unresolvedMemberIt
;
945 if (deferredStructIt
->unresolvedMemberTypes
.empty()) {
946 // All deferred struct type members are now resolved, set the struct body.
947 auto structType
= deferredStructIt
->deferredStructType
;
949 assert(structType
&& "expected a spirv::StructType");
950 assert(structType
.isIdentified() && "expected an indentified struct");
952 if (failed(structType
.trySetBody(
953 deferredStructIt
->memberTypes
, deferredStructIt
->offsetInfo
,
954 deferredStructIt
->memberDecorationsInfo
)))
957 deferredStructIt
= deferredStructTypesInfos
.erase(deferredStructIt
);
967 spirv::Deserializer::processArrayType(ArrayRef
<uint32_t> operands
) {
968 if (operands
.size() != 3) {
969 return emitError(unknownLoc
,
970 "OpTypeArray must have element type and count parameters");
973 Type elementTy
= getType(operands
[1]);
975 return emitError(unknownLoc
, "OpTypeArray references undefined <id> ")
980 // TODO: The count can also come frome a specialization constant.
981 auto countInfo
= getConstant(operands
[2]);
983 return emitError(unknownLoc
, "OpTypeArray count <id> ")
984 << operands
[2] << "can only come from normal constant right now";
987 if (auto intVal
= dyn_cast
<IntegerAttr
>(countInfo
->first
)) {
988 count
= intVal
.getValue().getZExtValue();
990 return emitError(unknownLoc
, "OpTypeArray count must come from a "
991 "scalar integer constant instruction");
994 typeMap
[operands
[0]] = spirv::ArrayType::get(
995 elementTy
, count
, typeDecorations
.lookup(operands
[0]));
1000 spirv::Deserializer::processFunctionType(ArrayRef
<uint32_t> operands
) {
1001 assert(!operands
.empty() && "No operands for processing function type");
1002 if (operands
.size() == 1) {
1003 return emitError(unknownLoc
, "missing return type for OpTypeFunction");
1005 auto returnType
= getType(operands
[1]);
1007 return emitError(unknownLoc
, "unknown return type in OpTypeFunction");
1009 SmallVector
<Type
, 1> argTypes
;
1010 for (size_t i
= 2, e
= operands
.size(); i
< e
; ++i
) {
1011 auto ty
= getType(operands
[i
]);
1013 return emitError(unknownLoc
, "unknown argument type in OpTypeFunction");
1015 argTypes
.push_back(ty
);
1017 ArrayRef
<Type
> returnTypes
;
1018 if (!isVoidType(returnType
)) {
1019 returnTypes
= llvm::ArrayRef(returnType
);
1021 typeMap
[operands
[0]] = FunctionType::get(context
, argTypes
, returnTypes
);
1025 LogicalResult
spirv::Deserializer::processCooperativeMatrixTypeKHR(
1026 ArrayRef
<uint32_t> operands
) {
1027 if (operands
.size() != 6) {
1028 return emitError(unknownLoc
,
1029 "OpTypeCooperativeMatrixKHR must have element type, "
1030 "scope, row and column parameters, and use");
1033 Type elementTy
= getType(operands
[1]);
1035 return emitError(unknownLoc
,
1036 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1040 std::optional
<spirv::Scope
> scope
=
1041 spirv::symbolizeScope(getConstantInt(operands
[2]).getInt());
1045 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1049 unsigned rows
= getConstantInt(operands
[3]).getInt();
1050 unsigned columns
= getConstantInt(operands
[4]).getInt();
1052 std::optional
<spirv::CooperativeMatrixUseKHR
> use
=
1053 spirv::symbolizeCooperativeMatrixUseKHR(
1054 getConstantInt(operands
[5]).getInt());
1058 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1062 typeMap
[operands
[0]] =
1063 spirv::CooperativeMatrixType::get(elementTy
, rows
, columns
, *scope
, *use
);
1068 spirv::Deserializer::processRuntimeArrayType(ArrayRef
<uint32_t> operands
) {
1069 if (operands
.size() != 2) {
1070 return emitError(unknownLoc
, "OpTypeRuntimeArray must have two operands");
1072 Type memberType
= getType(operands
[1]);
1074 return emitError(unknownLoc
,
1075 "OpTypeRuntimeArray references undefined <id> ")
1078 typeMap
[operands
[0]] = spirv::RuntimeArrayType::get(
1079 memberType
, typeDecorations
.lookup(operands
[0]));
1084 spirv::Deserializer::processStructType(ArrayRef
<uint32_t> operands
) {
1085 // TODO: Find a way to handle identified structs when debug info is stripped.
1087 if (operands
.empty()) {
1088 return emitError(unknownLoc
, "OpTypeStruct must have at least result <id>");
1091 if (operands
.size() == 1) {
1092 // Handle empty struct.
1093 typeMap
[operands
[0]] =
1094 spirv::StructType::getEmpty(context
, nameMap
.lookup(operands
[0]).str());
1098 // First element is operand ID, second element is member index in the struct.
1099 SmallVector
<std::pair
<uint32_t, unsigned>, 0> unresolvedMemberTypes
;
1100 SmallVector
<Type
, 4> memberTypes
;
1102 for (auto op
: llvm::drop_begin(operands
, 1)) {
1103 Type memberType
= getType(op
);
1104 bool typeForwardPtr
= (typeForwardPointerIDs
.count(op
) != 0);
1106 if (!memberType
&& !typeForwardPtr
)
1107 return emitError(unknownLoc
, "OpTypeStruct references undefined <id> ")
1111 unresolvedMemberTypes
.emplace_back(op
, memberTypes
.size());
1113 memberTypes
.push_back(memberType
);
1116 SmallVector
<spirv::StructType::OffsetInfo
, 0> offsetInfo
;
1117 SmallVector
<spirv::StructType::MemberDecorationInfo
, 0> memberDecorationsInfo
;
1118 if (memberDecorationMap
.count(operands
[0])) {
1119 auto &allMemberDecorations
= memberDecorationMap
[operands
[0]];
1120 for (auto memberIndex
: llvm::seq
<uint32_t>(0, memberTypes
.size())) {
1121 if (allMemberDecorations
.count(memberIndex
)) {
1122 for (auto &memberDecoration
: allMemberDecorations
[memberIndex
]) {
1123 // Check for offset.
1124 if (memberDecoration
.first
== spirv::Decoration::Offset
) {
1125 // If offset info is empty, resize to the number of members;
1126 if (offsetInfo
.empty()) {
1127 offsetInfo
.resize(memberTypes
.size());
1129 offsetInfo
[memberIndex
] = memberDecoration
.second
[0];
1131 if (!memberDecoration
.second
.empty()) {
1132 memberDecorationsInfo
.emplace_back(memberIndex
, /*hasValue=*/1,
1133 memberDecoration
.first
,
1134 memberDecoration
.second
[0]);
1136 memberDecorationsInfo
.emplace_back(memberIndex
, /*hasValue=*/0,
1137 memberDecoration
.first
, 0);
1145 uint32_t structID
= operands
[0];
1146 std::string structIdentifier
= nameMap
.lookup(structID
).str();
1148 if (structIdentifier
.empty()) {
1149 assert(unresolvedMemberTypes
.empty() &&
1150 "didn't expect unresolved member types");
1152 spirv::StructType::get(memberTypes
, offsetInfo
, memberDecorationsInfo
);
1154 auto structTy
= spirv::StructType::getIdentified(context
, structIdentifier
);
1155 typeMap
[structID
] = structTy
;
1157 if (!unresolvedMemberTypes
.empty())
1158 deferredStructTypesInfos
.push_back({structTy
, unresolvedMemberTypes
,
1159 memberTypes
, offsetInfo
,
1160 memberDecorationsInfo
});
1161 else if (failed(structTy
.trySetBody(memberTypes
, offsetInfo
,
1162 memberDecorationsInfo
)))
1166 // TODO: Update StructType to have member name as attribute as
1172 spirv::Deserializer::processMatrixType(ArrayRef
<uint32_t> operands
) {
1173 if (operands
.size() != 3) {
1174 // Three operands are needed: result_id, column_type, and column_count
1175 return emitError(unknownLoc
, "OpTypeMatrix must have 3 operands"
1176 " (result_id, column_type, and column_count)");
1178 // Matrix columns must be of vector type
1179 Type elementTy
= getType(operands
[1]);
1181 return emitError(unknownLoc
,
1182 "OpTypeMatrix references undefined column type.")
1186 uint32_t colsCount
= operands
[2];
1187 typeMap
[operands
[0]] = spirv::MatrixType::get(elementTy
, colsCount
);
1192 spirv::Deserializer::processTypeForwardPointer(ArrayRef
<uint32_t> operands
) {
1193 if (operands
.size() != 2)
1194 return emitError(unknownLoc
,
1195 "OpTypeForwardPointer instruction must have two operands");
1197 typeForwardPointerIDs
.insert(operands
[0]);
1198 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1199 // instruction that defines the actual type.
1205 spirv::Deserializer::processImageType(ArrayRef
<uint32_t> operands
) {
1206 // TODO: Add support for Access Qualifier.
1207 if (operands
.size() != 8)
1210 "OpTypeImage with non-eight operands are not supported yet");
1212 Type elementTy
= getType(operands
[1]);
1214 return emitError(unknownLoc
, "OpTypeImage references undefined <id>: ")
1217 auto dim
= spirv::symbolizeDim(operands
[2]);
1219 return emitError(unknownLoc
, "unknown Dim for OpTypeImage: ")
1222 auto depthInfo
= spirv::symbolizeImageDepthInfo(operands
[3]);
1224 return emitError(unknownLoc
, "unknown Depth for OpTypeImage: ")
1227 auto arrayedInfo
= spirv::symbolizeImageArrayedInfo(operands
[4]);
1229 return emitError(unknownLoc
, "unknown Arrayed for OpTypeImage: ")
1232 auto samplingInfo
= spirv::symbolizeImageSamplingInfo(operands
[5]);
1234 return emitError(unknownLoc
, "unknown MS for OpTypeImage: ") << operands
[5];
1236 auto samplerUseInfo
= spirv::symbolizeImageSamplerUseInfo(operands
[6]);
1237 if (!samplerUseInfo
)
1238 return emitError(unknownLoc
, "unknown Sampled for OpTypeImage: ")
1241 auto format
= spirv::symbolizeImageFormat(operands
[7]);
1243 return emitError(unknownLoc
, "unknown Format for OpTypeImage: ")
1246 typeMap
[operands
[0]] = spirv::ImageType::get(
1247 elementTy
, dim
.value(), depthInfo
.value(), arrayedInfo
.value(),
1248 samplingInfo
.value(), samplerUseInfo
.value(), format
.value());
1253 spirv::Deserializer::processSampledImageType(ArrayRef
<uint32_t> operands
) {
1254 if (operands
.size() != 2)
1255 return emitError(unknownLoc
, "OpTypeSampledImage must have two operands");
1257 Type elementTy
= getType(operands
[1]);
1259 return emitError(unknownLoc
,
1260 "OpTypeSampledImage references undefined <id>: ")
1263 typeMap
[operands
[0]] = spirv::SampledImageType::get(elementTy
);
//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//
1271 LogicalResult
spirv::Deserializer::processConstant(ArrayRef
<uint32_t> operands
,
1273 StringRef opname
= isSpec
? "OpSpecConstant" : "OpConstant";
1275 if (operands
.size() < 2) {
1276 return emitError(unknownLoc
)
1277 << opname
<< " must have type <id> and result <id>";
1279 if (operands
.size() < 3) {
1280 return emitError(unknownLoc
)
1281 << opname
<< " must have at least 1 more parameter";
1284 Type resultType
= getType(operands
[0]);
1286 return emitError(unknownLoc
, "undefined result type from <id> ")
1290 auto checkOperandSizeForBitwidth
= [&](unsigned bitwidth
) -> LogicalResult
{
1291 if (bitwidth
== 64) {
1292 if (operands
.size() == 4) {
1295 return emitError(unknownLoc
)
1296 << opname
<< " should have 2 parameters for 64-bit values";
1298 if (bitwidth
<= 32) {
1299 if (operands
.size() == 3) {
1303 return emitError(unknownLoc
)
1305 << " should have 1 parameter for values with no more than 32 bits";
1307 return emitError(unknownLoc
, "unsupported OpConstant bitwidth: ")
1311 auto resultID
= operands
[1];
1313 if (auto intType
= dyn_cast
<IntegerType
>(resultType
)) {
1314 auto bitwidth
= intType
.getWidth();
1315 if (failed(checkOperandSizeForBitwidth(bitwidth
))) {
1320 if (bitwidth
== 64) {
1321 // 64-bit integers are represented with two SPIR-V words. According to
1322 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1323 // literal’s low-order words appear first."
1327 } words
= {operands
[2], operands
[3]};
1328 value
= APInt(64, llvm::bit_cast
<uint64_t>(words
), /*isSigned=*/true);
1329 } else if (bitwidth
<= 32) {
1330 value
= APInt(bitwidth
, operands
[2], /*isSigned=*/true,
1331 /*implicitTrunc=*/true);
1334 auto attr
= opBuilder
.getIntegerAttr(intType
, value
);
1337 createSpecConstant(unknownLoc
, resultID
, attr
);
1339 // For normal constants, we just record the attribute (and its type) for
1340 // later materialization at use sites.
1341 constantMap
.try_emplace(resultID
, attr
, intType
);
1347 if (auto floatType
= dyn_cast
<FloatType
>(resultType
)) {
1348 auto bitwidth
= floatType
.getWidth();
1349 if (failed(checkOperandSizeForBitwidth(bitwidth
))) {
1354 if (floatType
.isF64()) {
1355 // Double values are represented with two SPIR-V words. According to
1356 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1357 // literal’s low-order words appear first."
1361 } words
= {operands
[2], operands
[3]};
1362 value
= APFloat(llvm::bit_cast
<double>(words
));
1363 } else if (floatType
.isF32()) {
1364 value
= APFloat(llvm::bit_cast
<float>(operands
[2]));
1365 } else if (floatType
.isF16()) {
1366 APInt
data(16, operands
[2]);
1367 value
= APFloat(APFloat::IEEEhalf(), data
);
1370 auto attr
= opBuilder
.getFloatAttr(floatType
, value
);
1372 createSpecConstant(unknownLoc
, resultID
, attr
);
1374 // For normal constants, we just record the attribute (and its type) for
1375 // later materialization at use sites.
1376 constantMap
.try_emplace(resultID
, attr
, floatType
);
1382 return emitError(unknownLoc
, "OpConstant can only generate values of "
1383 "scalar integer or floating-point type");
1386 LogicalResult
spirv::Deserializer::processConstantBool(
1387 bool isTrue
, ArrayRef
<uint32_t> operands
, bool isSpec
) {
1388 if (operands
.size() != 2) {
1389 return emitError(unknownLoc
, "Op")
1390 << (isSpec
? "Spec" : "") << "Constant"
1391 << (isTrue
? "True" : "False")
1392 << " must have type <id> and result <id>";
1395 auto attr
= opBuilder
.getBoolAttr(isTrue
);
1396 auto resultID
= operands
[1];
1398 createSpecConstant(unknownLoc
, resultID
, attr
);
1400 // For normal constants, we just record the attribute (and its type) for
1401 // later materialization at use sites.
1402 constantMap
.try_emplace(resultID
, attr
, opBuilder
.getI1Type());
1409 spirv::Deserializer::processConstantComposite(ArrayRef
<uint32_t> operands
) {
1410 if (operands
.size() < 2) {
1411 return emitError(unknownLoc
,
1412 "OpConstantComposite must have type <id> and result <id>");
1414 if (operands
.size() < 3) {
1415 return emitError(unknownLoc
,
1416 "OpConstantComposite must have at least 1 parameter");
1419 Type resultType
= getType(operands
[0]);
1421 return emitError(unknownLoc
, "undefined result type from <id> ")
1425 SmallVector
<Attribute
, 4> elements
;
1426 elements
.reserve(operands
.size() - 2);
1427 for (unsigned i
= 2, e
= operands
.size(); i
< e
; ++i
) {
1428 auto elementInfo
= getConstant(operands
[i
]);
1430 return emitError(unknownLoc
, "OpConstantComposite component <id> ")
1431 << operands
[i
] << " must come from a normal constant";
1433 elements
.push_back(elementInfo
->first
);
1436 auto resultID
= operands
[1];
1437 if (auto vectorType
= dyn_cast
<VectorType
>(resultType
)) {
1438 auto attr
= DenseElementsAttr::get(vectorType
, elements
);
1439 // For normal constants, we just record the attribute (and its type) for
1440 // later materialization at use sites.
1441 constantMap
.try_emplace(resultID
, attr
, resultType
);
1442 } else if (auto arrayType
= dyn_cast
<spirv::ArrayType
>(resultType
)) {
1443 auto attr
= opBuilder
.getArrayAttr(elements
);
1444 constantMap
.try_emplace(resultID
, attr
, resultType
);
1446 return emitError(unknownLoc
, "unsupported OpConstantComposite type: ")
1454 spirv::Deserializer::processSpecConstantComposite(ArrayRef
<uint32_t> operands
) {
1455 if (operands
.size() < 2) {
1456 return emitError(unknownLoc
,
1457 "OpConstantComposite must have type <id> and result <id>");
1459 if (operands
.size() < 3) {
1460 return emitError(unknownLoc
,
1461 "OpConstantComposite must have at least 1 parameter");
1464 Type resultType
= getType(operands
[0]);
1466 return emitError(unknownLoc
, "undefined result type from <id> ")
1470 auto resultID
= operands
[1];
1471 auto symName
= opBuilder
.getStringAttr(getSpecConstantSymbol(resultID
));
1473 SmallVector
<Attribute
, 4> elements
;
1474 elements
.reserve(operands
.size() - 2);
1475 for (unsigned i
= 2, e
= operands
.size(); i
< e
; ++i
) {
1476 auto elementInfo
= getSpecConstant(operands
[i
]);
1477 elements
.push_back(SymbolRefAttr::get(elementInfo
));
1480 auto op
= opBuilder
.create
<spirv::SpecConstantCompositeOp
>(
1481 unknownLoc
, TypeAttr::get(resultType
), symName
,
1482 opBuilder
.getArrayAttr(elements
));
1483 specConstCompositeMap
[resultID
] = op
;
1489 spirv::Deserializer::processSpecConstantOperation(ArrayRef
<uint32_t> operands
) {
1490 if (operands
.size() < 3)
1491 return emitError(unknownLoc
, "OpConstantOperation must have type <id>, "
1492 "result <id>, and operand opcode");
1494 uint32_t resultTypeID
= operands
[0];
1496 if (!getType(resultTypeID
))
1497 return emitError(unknownLoc
, "undefined result type from <id> ")
1500 uint32_t resultID
= operands
[1];
1501 spirv::Opcode enclosedOpcode
= static_cast<spirv::Opcode
>(operands
[2]);
1502 auto emplaceResult
= specConstOperationMap
.try_emplace(
1504 SpecConstOperationMaterializationInfo
{
1505 enclosedOpcode
, resultTypeID
,
1506 SmallVector
<uint32_t>{operands
.begin() + 3, operands
.end()}});
1508 if (!emplaceResult
.second
)
1509 return emitError(unknownLoc
, "value with <id>: ")
1510 << resultID
<< " is probably defined before.";
1515 Value
spirv::Deserializer::materializeSpecConstantOperation(
1516 uint32_t resultID
, spirv::Opcode enclosedOpcode
, uint32_t resultTypeID
,
1517 ArrayRef
<uint32_t> enclosedOpOperands
) {
1519 Type resultType
= getType(resultTypeID
);
1521 // Instructions wrapped by OpSpecConstantOp need an ID for their
1522 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1523 // dialect wrapped op. For that purpose, a new value map is created and "fake"
1524 // ID in that map is assigned to the result of the enclosed instruction. Note
1525 // that there is no need to update this fake ID since we only need to
1526 // reference the created Value for the enclosed op from the spv::YieldOp
1527 // created later in this method (both of which are the only values in their
1528 // region: the SpecConstantOperation's region). If we encounter another
1529 // SpecConstantOperation in the module, we simply re-use the fake ID since the
1530 // previous Value assigned to it isn't visible in the current scope anyway.
1531 DenseMap
<uint32_t, Value
> newValueMap
;
1532 llvm::SaveAndRestore
valueMapGuard(valueMap
, newValueMap
);
1533 constexpr uint32_t fakeID
= static_cast<uint32_t>(-3);
1535 SmallVector
<uint32_t, 4> enclosedOpResultTypeAndOperands
;
1536 enclosedOpResultTypeAndOperands
.push_back(resultTypeID
);
1537 enclosedOpResultTypeAndOperands
.push_back(fakeID
);
1538 enclosedOpResultTypeAndOperands
.append(enclosedOpOperands
.begin(),
1539 enclosedOpOperands
.end());
1541 // Process enclosed instruction before creating the enclosing
1542 // specConstantOperation (and its region). This way, references to constants,
1543 // global variables, and spec constants will be materialized outside the new
1544 // op's region. For more info, see Deserializer::getValue's implementation.
1546 processInstruction(enclosedOpcode
, enclosedOpResultTypeAndOperands
)))
1549 // Since the enclosed op is emitted in the current block, split it in a
1550 // separate new block.
1551 Block
*enclosedBlock
= curBlock
->splitBlock(&curBlock
->back());
1553 auto loc
= createFileLineColLoc(opBuilder
);
1554 auto specConstOperationOp
=
1555 opBuilder
.create
<spirv::SpecConstantOperationOp
>(loc
, resultType
);
1557 Region
&body
= specConstOperationOp
.getBody();
1558 // Move the new block into SpecConstantOperation's body.
1559 body
.getBlocks().splice(body
.end(), curBlock
->getParent()->getBlocks(),
1560 Region::iterator(enclosedBlock
));
1561 Block
&block
= body
.back();
1563 // RAII guard to reset the insertion point to the module's region after
1564 // deserializing the body of the specConstantOperation.
1565 OpBuilder::InsertionGuard
moduleInsertionGuard(opBuilder
);
1566 opBuilder
.setInsertionPointToEnd(&block
);
1568 opBuilder
.create
<spirv::YieldOp
>(loc
, block
.front().getResult(0));
1569 return specConstOperationOp
.getResult();
1573 spirv::Deserializer::processConstantNull(ArrayRef
<uint32_t> operands
) {
1574 if (operands
.size() != 2) {
1575 return emitError(unknownLoc
,
1576 "OpConstantNull must have type <id> and result <id>");
1579 Type resultType
= getType(operands
[0]);
1581 return emitError(unknownLoc
, "undefined result type from <id> ")
1585 auto resultID
= operands
[1];
1586 if (resultType
.isIntOrFloat() || isa
<VectorType
>(resultType
)) {
1587 auto attr
= opBuilder
.getZeroAttr(resultType
);
1588 // For normal constants, we just record the attribute (and its type) for
1589 // later materialization at use sites.
1590 constantMap
.try_emplace(resultID
, attr
, resultType
);
1594 return emitError(unknownLoc
, "unsupported OpConstantNull type: ")
//===----------------------------------------------------------------------===//
// Control flow
//===----------------------------------------------------------------------===//
1602 Block
*spirv::Deserializer::getOrCreateBlock(uint32_t id
) {
1603 if (auto *block
= getBlock(id
)) {
1604 LLVM_DEBUG(logger
.startLine() << "[block] got exiting block for id = " << id
1605 << " @ " << block
<< "\n");
1609 // We don't know where this block will be placed finally (in a
1610 // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1611 // function for now and sort out the proper place later.
1612 auto *block
= curFunction
->addBlock();
1613 LLVM_DEBUG(logger
.startLine() << "[block] created block for id = " << id
1614 << " @ " << block
<< "\n");
1615 return blockMap
[id
] = block
;
1618 LogicalResult
spirv::Deserializer::processBranch(ArrayRef
<uint32_t> operands
) {
1620 return emitError(unknownLoc
, "OpBranch must appear inside a block");
1623 if (operands
.size() != 1) {
1624 return emitError(unknownLoc
, "OpBranch must take exactly one target label");
1627 auto *target
= getOrCreateBlock(operands
[0]);
1628 auto loc
= createFileLineColLoc(opBuilder
);
1629 // The preceding instruction for the OpBranch instruction could be an
1630 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1631 // the same OpLine information.
1632 opBuilder
.create
<spirv::BranchOp
>(loc
, target
);
1639 spirv::Deserializer::processBranchConditional(ArrayRef
<uint32_t> operands
) {
1641 return emitError(unknownLoc
,
1642 "OpBranchConditional must appear inside a block");
1645 if (operands
.size() != 3 && operands
.size() != 5) {
1646 return emitError(unknownLoc
,
1647 "OpBranchConditional must have condition, true label, "
1648 "false label, and optionally two branch weights");
1651 auto condition
= getValue(operands
[0]);
1652 auto *trueBlock
= getOrCreateBlock(operands
[1]);
1653 auto *falseBlock
= getOrCreateBlock(operands
[2]);
1655 std::optional
<std::pair
<uint32_t, uint32_t>> weights
;
1656 if (operands
.size() == 5) {
1657 weights
= std::make_pair(operands
[3], operands
[4]);
1659 // The preceding instruction for the OpBranchConditional instruction could be
1660 // an OpSelectionMerge instruction, in this case they will have the same
1661 // OpLine information.
1662 auto loc
= createFileLineColLoc(opBuilder
);
1663 opBuilder
.create
<spirv::BranchConditionalOp
>(
1664 loc
, condition
, trueBlock
,
1665 /*trueArguments=*/ArrayRef
<Value
>(), falseBlock
,
1666 /*falseArguments=*/ArrayRef
<Value
>(), weights
);
1672 LogicalResult
spirv::Deserializer::processLabel(ArrayRef
<uint32_t> operands
) {
1674 return emitError(unknownLoc
, "OpLabel must appear inside a function");
1677 if (operands
.size() != 1) {
1678 return emitError(unknownLoc
, "OpLabel should only have result <id>");
1681 auto labelID
= operands
[0];
1682 // We may have forward declared this block.
1683 auto *block
= getOrCreateBlock(labelID
);
1684 LLVM_DEBUG(logger
.startLine()
1685 << "[block] populating block " << block
<< "\n");
1686 // If we have seen this block, make sure it was just a forward declaration.
1687 assert(block
->empty() && "re-deserialize the same block!");
1689 opBuilder
.setInsertionPointToStart(block
);
1690 blockMap
[labelID
] = curBlock
= block
;
1696 spirv::Deserializer::processSelectionMerge(ArrayRef
<uint32_t> operands
) {
1698 return emitError(unknownLoc
, "OpSelectionMerge must appear in a block");
1701 if (operands
.size() < 2) {
1704 "OpSelectionMerge must specify merge target and selection control");
1707 auto *mergeBlock
= getOrCreateBlock(operands
[0]);
1708 auto loc
= createFileLineColLoc(opBuilder
);
1709 auto selectionControl
= operands
[1];
1711 if (!blockMergeInfo
.try_emplace(curBlock
, loc
, selectionControl
, mergeBlock
)
1715 "a block cannot have more than one OpSelectionMerge instruction");
1722 spirv::Deserializer::processLoopMerge(ArrayRef
<uint32_t> operands
) {
1724 return emitError(unknownLoc
, "OpLoopMerge must appear in a block");
1727 if (operands
.size() < 3) {
1728 return emitError(unknownLoc
, "OpLoopMerge must specify merge target, "
1729 "continue target and loop control");
1732 auto *mergeBlock
= getOrCreateBlock(operands
[0]);
1733 auto *continueBlock
= getOrCreateBlock(operands
[1]);
1734 auto loc
= createFileLineColLoc(opBuilder
);
1735 uint32_t loopControl
= operands
[2];
1738 .try_emplace(curBlock
, loc
, loopControl
, mergeBlock
, continueBlock
)
1742 "a block cannot have more than one OpLoopMerge instruction");
1748 LogicalResult
spirv::Deserializer::processPhi(ArrayRef
<uint32_t> operands
) {
1750 return emitError(unknownLoc
, "OpPhi must appear in a block");
1753 if (operands
.size() < 4) {
1754 return emitError(unknownLoc
, "OpPhi must specify result type, result <id>, "
1755 "and variable-parent pairs");
1758 // Create a block argument for this OpPhi instruction.
1759 Type blockArgType
= getType(operands
[0]);
1760 BlockArgument blockArg
= curBlock
->addArgument(blockArgType
, unknownLoc
);
1761 valueMap
[operands
[1]] = blockArg
;
1762 LLVM_DEBUG(logger
.startLine()
1763 << "[phi] created block argument " << blockArg
1764 << " id = " << operands
[1] << " of type " << blockArgType
<< "\n");
1766 // For each (value, predecessor) pair, insert the value to the predecessor's
1767 // blockPhiInfo entry so later we can fix the block argument there.
1768 for (unsigned i
= 2, e
= operands
.size(); i
< e
; i
+= 2) {
1769 uint32_t value
= operands
[i
];
1770 Block
*predecessor
= getOrCreateBlock(operands
[i
+ 1]);
1771 std::pair
<Block
*, Block
*> predecessorTargetPair
{predecessor
, curBlock
};
1772 blockPhiInfo
[predecessorTargetPair
].push_back(value
);
1773 LLVM_DEBUG(logger
.startLine() << "[phi] predecessor @ " << predecessor
1774 << " with arg id = " << value
<< "\n");
1781 /// A class for putting all blocks in a structured selection/loop in a
1782 /// spirv.mlir.selection/spirv.mlir.loop op.
1783 class ControlFlowStructurizer
{
1786 ControlFlowStructurizer(Location loc
, uint32_t control
,
1787 spirv::BlockMergeInfoMap
&mergeInfo
, Block
*header
,
1788 Block
*merge
, Block
*cont
,
1789 llvm::ScopedPrinter
&logger
)
1790 : location(loc
), control(control
), blockMergeInfo(mergeInfo
),
1791 headerBlock(header
), mergeBlock(merge
), continueBlock(cont
),
1794 ControlFlowStructurizer(Location loc
, uint32_t control
,
1795 spirv::BlockMergeInfoMap
&mergeInfo
, Block
*header
,
1796 Block
*merge
, Block
*cont
)
1797 : location(loc
), control(control
), blockMergeInfo(mergeInfo
),
1798 headerBlock(header
), mergeBlock(merge
), continueBlock(cont
) {}
1801 /// Structurizes the loop at the given `headerBlock`.
1803 /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
1804 /// all blocks in the structured loop into the spirv.mlir.loop's region. All
1805 /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
1806 /// method will also update `mergeInfo` by remapping all blocks inside to the
1807 /// newly cloned ones inside structured control flow op's regions.
1808 LogicalResult
structurize();
1811 /// Creates a new spirv.mlir.selection op at the beginning of the
1813 spirv::SelectionOp
createSelectionOp(uint32_t selectionControl
);
1815 /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
1816 spirv::LoopOp
createLoopOp(uint32_t loopControl
);
1818 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
1819 void collectBlocksInConstruct();
1824 spirv::BlockMergeInfoMap
&blockMergeInfo
;
1828 Block
*continueBlock
; // nullptr for spirv.mlir.selection
1830 SetVector
<Block
*> constructBlocks
;
1833 /// A logger used to emit information during the deserialzation process.
1834 llvm::ScopedPrinter
&logger
;
1840 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl
) {
1841 // Create a builder and set the insertion point to the beginning of the
1842 // merge block so that the newly created SelectionOp will be inserted there.
1843 OpBuilder
builder(&mergeBlock
->front());
1845 auto control
= static_cast<spirv::SelectionControl
>(selectionControl
);
1846 auto selectionOp
= builder
.create
<spirv::SelectionOp
>(location
, control
);
1847 selectionOp
.addMergeBlock(builder
);
1852 spirv::LoopOp
ControlFlowStructurizer::createLoopOp(uint32_t loopControl
) {
1853 // Create a builder and set the insertion point to the beginning of the
1854 // merge block so that the newly created LoopOp will be inserted there.
1855 OpBuilder
builder(&mergeBlock
->front());
1857 auto control
= static_cast<spirv::LoopControl
>(loopControl
);
1858 auto loopOp
= builder
.create
<spirv::LoopOp
>(location
, control
);
1859 loopOp
.addEntryAndMergeBlock(builder
);
1864 void ControlFlowStructurizer::collectBlocksInConstruct() {
1865 assert(constructBlocks
.empty() && "expected empty constructBlocks");
1867 // Put the header block in the work list first.
1868 constructBlocks
.insert(headerBlock
);
1870 // For each item in the work list, add its successors excluding the merge
1872 for (unsigned i
= 0; i
< constructBlocks
.size(); ++i
) {
1873 for (auto *successor
: constructBlocks
[i
]->getSuccessors())
1874 if (successor
!= mergeBlock
)
1875 constructBlocks
.insert(successor
);
1879 LogicalResult
ControlFlowStructurizer::structurize() {
1880 Operation
*op
= nullptr;
1881 bool isLoop
= continueBlock
!= nullptr;
1883 if (auto loopOp
= createLoopOp(control
))
1884 op
= loopOp
.getOperation();
1886 if (auto selectionOp
= createSelectionOp(control
))
1887 op
= selectionOp
.getOperation();
1891 Region
&body
= op
->getRegion(0);
1894 // All references to the old merge block should be directed to the
1895 // selection/loop merge block in the SelectionOp/LoopOp's region.
1896 mapper
.map(mergeBlock
, &body
.back());
1898 collectBlocksInConstruct();
1900 // We've identified all blocks belonging to the selection/loop's region. Now
1901 // need to "move" them into the selection/loop. Instead of really moving the
1902 // blocks, in the following we copy them and remap all values and branches.
1904 // * Inserting a block into a region requires the block not in any region
1905 // before. But selections/loops can nest so we can create selection/loop ops
1906 // in a nested manner, which means some blocks may already be in a
1907 // selection/loop region when to be moved again.
1908 // * It's much trickier to fix up the branches into and out of the loop's
1909 // region: we need to treat not-moved blocks and moved blocks differently:
1910 // Not-moved blocks jumping to the loop header block need to jump to the
1911 // merge point containing the new loop op but not the loop continue block's
1912 // back edge. Moved blocks jumping out of the loop need to jump to the
1913 // merge block inside the loop region but not other not-moved blocks.
1914 // We cannot use replaceAllUsesWith clearly and it's harder to follow the
1917 // Create a corresponding block in the SelectionOp/LoopOp's region for each
1918 // block in this loop construct.
1919 OpBuilder
builder(body
);
1920 for (auto *block
: constructBlocks
) {
1921 // Create a block and insert it before the selection/loop merge block in the
1922 // SelectionOp/LoopOp's region.
1923 auto *newBlock
= builder
.createBlock(&body
.back());
1924 mapper
.map(block
, newBlock
);
1925 LLVM_DEBUG(logger
.startLine() << "[cf] cloned block " << newBlock
1926 << " from block " << block
<< "\n");
1927 if (!isFnEntryBlock(block
)) {
1928 for (BlockArgument blockArg
: block
->getArguments()) {
1930 newBlock
->addArgument(blockArg
.getType(), blockArg
.getLoc());
1931 mapper
.map(blockArg
, newArg
);
1932 LLVM_DEBUG(logger
.startLine() << "[cf] remapped block argument "
1933 << blockArg
<< " to " << newArg
<< "\n");
1936 LLVM_DEBUG(logger
.startLine()
1937 << "[cf] block " << block
<< " is a function entry block\n");
1940 for (auto &op
: *block
)
1941 newBlock
->push_back(op
.clone(mapper
));
1944 // Go through all ops and remap the operands.
1945 auto remapOperands
= [&](Operation
*op
) {
1946 for (auto &operand
: op
->getOpOperands())
1947 if (Value mappedOp
= mapper
.lookupOrNull(operand
.get()))
1948 operand
.set(mappedOp
);
1949 for (auto &succOp
: op
->getBlockOperands())
1950 if (Block
*mappedOp
= mapper
.lookupOrNull(succOp
.get()))
1951 succOp
.set(mappedOp
);
1953 for (auto &block
: body
)
1954 block
.walk(remapOperands
);
1956 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1957 // the selection/loop construct into its region. Next we need to fix the
1958 // connections between this new SelectionOp/LoopOp with existing blocks.
1960 // All existing incoming branches should go to the merge block, where the
1961 // SelectionOp/LoopOp resides right now.
1962 headerBlock
->replaceAllUsesWith(mergeBlock
);
1965 logger
.startLine() << "[cf] after cloning and fixing references:\n";
1966 headerBlock
->getParentOp()->print(logger
.getOStream());
1967 logger
.startLine() << "\n";
1971 if (!mergeBlock
->args_empty()) {
1972 return mergeBlock
->getParentOp()->emitError(
1973 "OpPhi in loop merge block unsupported");
1976 // The loop header block may have block arguments. Since now we place the
1977 // loop op inside the old merge block, we need to make sure the old merge
1978 // block has the same block argument list.
1979 for (BlockArgument blockArg
: headerBlock
->getArguments())
1980 mergeBlock
->addArgument(blockArg
.getType(), blockArg
.getLoc());
1982 // If the loop header block has block arguments, make sure the spirv.Branch
1984 SmallVector
<Value
, 4> blockArgs
;
1985 if (!headerBlock
->args_empty())
1986 blockArgs
= {mergeBlock
->args_begin(), mergeBlock
->args_end()};
1988 // The loop entry block should have a unconditional branch jumping to the
1989 // loop header block.
1990 builder
.setInsertionPointToEnd(&body
.front());
1991 builder
.create
<spirv::BranchOp
>(location
, mapper
.lookupOrNull(headerBlock
),
1992 ArrayRef
<Value
>(blockArgs
));
1995 // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1997 LLVM_DEBUG(logger
.startLine() << "[cf] cleaning up blocks after clone\n");
1998 // First we need to drop all operands' references inside all blocks. This is
1999 // needed because we can have blocks referencing SSA values from one another.
2000 for (auto *block
: constructBlocks
)
2001 block
->dropAllReferences();
2003 // Check that whether some op in the to-be-erased blocks still has uses. Those
2004 // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
2005 // region. We cannot handle such cases given that once a value is sinked into
2006 // the SelectionOp/LoopOp's region, there is no escape for it:
2007 // SelectionOp/LooOp does not support yield values right now.
2008 for (auto *block
: constructBlocks
) {
2009 for (Operation
&op
: *block
)
2010 if (!op
.use_empty())
2011 return op
.emitOpError(
2012 "failed control flow structurization: it has uses outside of the "
2013 "enclosing selection/loop construct");
2016 // Then erase all old blocks.
2017 for (auto *block
: constructBlocks
) {
2018 // We've cloned all blocks belonging to this construct into the structured
2019 // control flow op's region. Among these blocks, some may compose another
2020 // selection/loop. If so, they will be recorded within blockMergeInfo.
2021 // We need to update the pointers there to the newly remapped ones so we can
2022 // continue structurizing them later.
2023 // TODO: The asserts in the following assumes input SPIR-V blob forms
2024 // correctly nested selection/loop constructs. We should relax this and
2025 // support error cases better.
2026 auto it
= blockMergeInfo
.find(block
);
2027 if (it
!= blockMergeInfo
.end()) {
2028 // Use the original location for nested selection/loop ops.
2029 Location loc
= it
->second
.loc
;
2031 Block
*newHeader
= mapper
.lookupOrNull(block
);
2033 return emitError(loc
, "failed control flow structurization: nested "
2034 "loop header block should be remapped!");
2036 Block
*newContinue
= it
->second
.continueBlock
;
2038 newContinue
= mapper
.lookupOrNull(newContinue
);
2040 return emitError(loc
, "failed control flow structurization: nested "
2041 "loop continue block should be remapped!");
2044 Block
*newMerge
= it
->second
.mergeBlock
;
2045 if (Block
*mappedTo
= mapper
.lookupOrNull(newMerge
))
2046 newMerge
= mappedTo
;
2048 // The iterator should be erased before adding a new entry into
2049 // blockMergeInfo to avoid iterator invalidation.
2050 blockMergeInfo
.erase(it
);
2051 blockMergeInfo
.try_emplace(newHeader
, loc
, it
->second
.control
, newMerge
,
2055 // The structured selection/loop's entry block does not have arguments.
2056 // If the function's header block is also part of the structured control
2057 // flow, we cannot just simply erase it because it may contain arguments
2058 // matching the function signature and used by the cloned blocks.
2059 if (isFnEntryBlock(block
)) {
2060 LLVM_DEBUG(logger
.startLine() << "[cf] changing entry block " << block
2061 << " to only contain a spirv.Branch op\n");
2062 // Still keep the function entry block for the potential block arguments,
2063 // but replace all ops inside with a branch to the merge block.
2065 builder
.setInsertionPointToEnd(block
);
2066 builder
.create
<spirv::BranchOp
>(location
, mergeBlock
);
2068 LLVM_DEBUG(logger
.startLine() << "[cf] erasing block " << block
<< "\n");
2073 LLVM_DEBUG(logger
.startLine()
2074 << "[cf] after structurizing construct with header block "
2075 << headerBlock
<< ":\n"
2081 LogicalResult
spirv::Deserializer::wireUpBlockArgument() {
2084 << "//----- [phi] start wiring up block arguments -----//\n";
2088 OpBuilder::InsertionGuard
guard(opBuilder
);
2090 for (const auto &info
: blockPhiInfo
) {
2091 Block
*block
= info
.first
.first
;
2092 Block
*target
= info
.first
.second
;
2093 const BlockPhiInfo
&phiInfo
= info
.second
;
2095 logger
.startLine() << "[phi] block " << block
<< "\n";
2096 logger
.startLine() << "[phi] before creating block argument:\n";
2097 block
->getParentOp()->print(logger
.getOStream());
2098 logger
.startLine() << "\n";
2101 // Set insertion point to before this block's terminator early because we
2102 // may materialize ops via getValue() call.
2103 auto *op
= block
->getTerminator();
2104 opBuilder
.setInsertionPoint(op
);
2106 SmallVector
<Value
, 4> blockArgs
;
2107 blockArgs
.reserve(phiInfo
.size());
2108 for (uint32_t valueId
: phiInfo
) {
2109 if (Value value
= getValue(valueId
)) {
2110 blockArgs
.push_back(value
);
2111 LLVM_DEBUG(logger
.startLine() << "[phi] block argument " << value
2112 << " id = " << valueId
<< "\n");
2114 return emitError(unknownLoc
, "OpPhi references undefined value!");
2118 if (auto branchOp
= dyn_cast
<spirv::BranchOp
>(op
)) {
2119 // Replace the previous branch op with a new one with block arguments.
2120 opBuilder
.create
<spirv::BranchOp
>(branchOp
.getLoc(), branchOp
.getTarget(),
2123 } else if (auto branchCondOp
= dyn_cast
<spirv::BranchConditionalOp
>(op
)) {
2124 assert((branchCondOp
.getTrueBlock() == target
||
2125 branchCondOp
.getFalseBlock() == target
) &&
2126 "expected target to be either the true or false target");
2127 if (target
== branchCondOp
.getTrueTarget())
2128 opBuilder
.create
<spirv::BranchConditionalOp
>(
2129 branchCondOp
.getLoc(), branchCondOp
.getCondition(), blockArgs
,
2130 branchCondOp
.getFalseBlockArguments(),
2131 branchCondOp
.getBranchWeightsAttr(), branchCondOp
.getTrueTarget(),
2132 branchCondOp
.getFalseTarget());
2134 opBuilder
.create
<spirv::BranchConditionalOp
>(
2135 branchCondOp
.getLoc(), branchCondOp
.getCondition(),
2136 branchCondOp
.getTrueBlockArguments(), blockArgs
,
2137 branchCondOp
.getBranchWeightsAttr(), branchCondOp
.getTrueBlock(),
2138 branchCondOp
.getFalseBlock());
2140 branchCondOp
.erase();
2142 return emitError(unknownLoc
, "unimplemented terminator for Phi creation");
2146 logger
.startLine() << "[phi] after creating block argument:\n";
2147 block
->getParentOp()->print(logger
.getOStream());
2148 logger
.startLine() << "\n";
2151 blockPhiInfo
.clear();
2156 << "//--- [phi] completed wiring up block arguments ---//\n";
2161 LogicalResult
spirv::Deserializer::structurizeControlFlow() {
2164 << "//----- [cf] start structurizing control flow -----//\n";
2168 while (!blockMergeInfo
.empty()) {
2169 Block
*headerBlock
= blockMergeInfo
.begin()->first
;
2170 BlockMergeInfo mergeInfo
= blockMergeInfo
.begin()->second
;
2173 logger
.startLine() << "[cf] header block " << headerBlock
<< ":\n";
2174 headerBlock
->print(logger
.getOStream());
2175 logger
.startLine() << "\n";
2178 auto *mergeBlock
= mergeInfo
.mergeBlock
;
2179 assert(mergeBlock
&& "merge block cannot be nullptr");
2180 if (!mergeBlock
->args_empty())
2181 return emitError(unknownLoc
, "OpPhi in loop merge block unimplemented");
2183 logger
.startLine() << "[cf] merge block " << mergeBlock
<< ":\n";
2184 mergeBlock
->print(logger
.getOStream());
2185 logger
.startLine() << "\n";
2188 auto *continueBlock
= mergeInfo
.continueBlock
;
2189 LLVM_DEBUG(if (continueBlock
) {
2190 logger
.startLine() << "[cf] continue block " << continueBlock
<< ":\n";
2191 continueBlock
->print(logger
.getOStream());
2192 logger
.startLine() << "\n";
2194 // Erase this case before calling into structurizer, who will update
2196 blockMergeInfo
.erase(blockMergeInfo
.begin());
2197 ControlFlowStructurizer
structurizer(mergeInfo
.loc
, mergeInfo
.control
,
2198 blockMergeInfo
, headerBlock
,
2199 mergeBlock
, continueBlock
2205 if (failed(structurizer
.structurize()))
2212 << "//--- [cf] completed structurizing control flow ---//\n";
2217 //===----------------------------------------------------------------------===//
2219 //===----------------------------------------------------------------------===//
2221 Location
spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder
) {
2225 auto fileName
= debugInfoMap
.lookup(debugLine
->fileID
).str();
2226 if (fileName
.empty())
2227 fileName
= "<unknown>";
2228 return FileLineColLoc::get(opBuilder
.getStringAttr(fileName
), debugLine
->line
,
2233 spirv::Deserializer::processDebugLine(ArrayRef
<uint32_t> operands
) {
2234 // According to SPIR-V spec:
2235 // "This location information applies to the instructions physically
2236 // following this instruction, up to the first occurrence of any of the
2237 // following: the next end of block, the next OpLine instruction, or the next
2238 // OpNoLine instruction."
2239 if (operands
.size() != 3)
2240 return emitError(unknownLoc
, "OpLine must have 3 operands");
2241 debugLine
= DebugLine
{operands
[0], operands
[1], operands
[2]};
2245 void spirv::Deserializer::clearDebugLine() { debugLine
= std::nullopt
; }
2248 spirv::Deserializer::processDebugString(ArrayRef
<uint32_t> operands
) {
2249 if (operands
.size() < 2)
2250 return emitError(unknownLoc
, "OpString needs at least 2 operands");
2252 if (!debugInfoMap
.lookup(operands
[0]).empty())
2253 return emitError(unknownLoc
,
2254 "duplicate debug string found for result <id> ")
2257 unsigned wordIndex
= 1;
2258 StringRef debugString
= decodeStringLiteral(operands
, wordIndex
);
2259 if (wordIndex
!= operands
.size())
2260 return emitError(unknownLoc
,
2261 "unexpected trailing words in OpString instruction");
2263 debugInfoMap
[operands
[0]] = debugString
;