[SampleProfileLoader] Fix integer overflow in generateMDProfMetadata (#90217)
[llvm-project.git] / mlir / lib / Target / SPIRV / Deserialization / Deserializer.cpp
blobfaaa42023a803ab4b512d5a8ca9bae3be8366365
1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
11 //===----------------------------------------------------------------------===//
13 #include "Deserializer.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/Support/LogicalResult.h"
23 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/bit.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <optional>
34 using namespace mlir;
36 #define DEBUG_TYPE "spirv-deserialization"
38 //===----------------------------------------------------------------------===//
39 // Utility Functions
40 //===----------------------------------------------------------------------===//
42 /// Returns true if the given `block` is a function entry block.
43 static inline bool isFnEntryBlock(Block *block) {
44 return block->isEntryBlock() &&
45 isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
48 //===----------------------------------------------------------------------===//
49 // Deserializer Method Definitions
50 //===----------------------------------------------------------------------===//
52 spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
53 MLIRContext *context)
54 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion())
56 #ifndef NDEBUG
58 logger(llvm::dbgs())
59 #endif
63 LogicalResult spirv::Deserializer::deserialize() {
64 LLVM_DEBUG({
65 logger.resetIndent();
66 logger.startLine()
67 << "//+++---------- start deserialization ----------+++//\n";
68 });
70 if (failed(processHeader()))
71 return failure();
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 ArrayRef<uint32_t> operands;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
77 // Slice the next instruction out and populate `opcode` and `operands`.
78 // Internally this also updates `curOffset`.
79 if (failed(sliceInstruction(opcode, operands)))
80 return failure();
82 if (failed(processInstruction(opcode, operands)))
83 return failure();
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
89 for (auto &deferred : deferredInstructions) {
90 if (failed(processInstruction(deferred.first, deferred.second, false))) {
91 return failure();
95 attachVCETriple();
97 LLVM_DEBUG(logger.startLine()
98 << "//+++-------- completed deserialization --------+++//\n");
99 return success();
102 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
103 return std::move(module);
106 //===----------------------------------------------------------------------===//
107 // Module structure
108 //===----------------------------------------------------------------------===//
110 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111 OpBuilder builder(context);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
114 return cast<spirv::ModuleOp>(Operation::create(state));
117 LogicalResult spirv::Deserializer::processHeader() {
118 if (binary.size() < spirv::kHeaderWordCount)
119 return emitError(unknownLoc,
120 "SPIR-V binary module must have a 5-word header");
122 if (binary[0] != spirv::kMagicNumber)
123 return emitError(unknownLoc, "incorrect magic number");
125 // Version number bytes: 0 | major number | minor number | 0
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130 #define MIN_VERSION_CASE(v) \
131 case v: \
132 version = spirv::Version::V_1_##v; \
133 break
135 MIN_VERSION_CASE(0);
136 MIN_VERSION_CASE(1);
137 MIN_VERSION_CASE(2);
138 MIN_VERSION_CASE(3);
139 MIN_VERSION_CASE(4);
140 MIN_VERSION_CASE(5);
141 #undef MIN_VERSION_CASE
142 default:
143 return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
144 << minorVersion;
146 } else {
147 return emitError(unknownLoc, "unsupported SPIR-V major version: ")
148 << majorVersion;
151 // TODO: generator number, bound, schema
152 curOffset = spirv::kHeaderWordCount;
153 return success();
156 LogicalResult
157 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
158 if (operands.size() != 1)
159 return emitError(unknownLoc, "OpMemoryModel must have one parameter");
161 auto cap = spirv::symbolizeCapability(operands[0]);
162 if (!cap)
163 return emitError(unknownLoc, "unknown capability: ") << operands[0];
165 capabilities.insert(*cap);
166 return success();
169 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
170 if (words.empty()) {
171 return emitError(
172 unknownLoc,
173 "OpExtension must have a literal string for the extension name");
176 unsigned wordIndex = 0;
177 StringRef extName = decodeStringLiteral(words, wordIndex);
178 if (wordIndex != words.size())
179 return emitError(unknownLoc,
180 "unexpected trailing words in OpExtension instruction");
181 auto ext = spirv::symbolizeExtension(extName);
182 if (!ext)
183 return emitError(unknownLoc, "unknown extension: ") << extName;
185 extensions.insert(*ext);
186 return success();
189 LogicalResult
190 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
191 if (words.size() < 2) {
192 return emitError(unknownLoc,
193 "OpExtInstImport must have a result <id> and a literal "
194 "string for the extended instruction set name");
197 unsigned wordIndex = 1;
198 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
199 if (wordIndex != words.size()) {
200 return emitError(unknownLoc,
201 "unexpected trailing words in OpExtInstImport");
203 return success();
206 void spirv::Deserializer::attachVCETriple() {
207 (*module)->setAttr(
208 spirv::ModuleOp::getVCETripleAttrName(),
209 spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
210 extensions.getArrayRef(), context));
213 LogicalResult
214 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
215 if (operands.size() != 2)
216 return emitError(unknownLoc, "OpMemoryModel must have two operands");
218 (*module)->setAttr(
219 module->getAddressingModelAttrName(),
220 opBuilder.getAttr<spirv::AddressingModelAttr>(
221 static_cast<spirv::AddressingModel>(operands.front())));
223 (*module)->setAttr(module->getMemoryModelAttrName(),
224 opBuilder.getAttr<spirv::MemoryModelAttr>(
225 static_cast<spirv::MemoryModel>(operands.back())));
227 return success();
230 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
231 // TODO: This function should also be auto-generated. For now, since only a
232 // few decorations are processed/handled in a meaningful manner, going with a
233 // manual implementation.
234 if (words.size() < 2) {
235 return emitError(
236 unknownLoc, "OpDecorate must have at least result <id> and Decoration");
238 auto decorationName =
239 stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
240 if (decorationName.empty()) {
241 return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
243 auto symbol = getSymbolDecoration(decorationName);
244 switch (static_cast<spirv::Decoration>(words[1])) {
245 case spirv::Decoration::FPFastMathMode:
246 if (words.size() != 3) {
247 return emitError(unknownLoc, "OpDecorate with ")
248 << decorationName << " needs a single integer literal";
250 decorations[words[0]].set(
251 symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
252 static_cast<FPFastMathMode>(words[2])));
253 break;
254 case spirv::Decoration::DescriptorSet:
255 case spirv::Decoration::Binding:
256 if (words.size() != 3) {
257 return emitError(unknownLoc, "OpDecorate with ")
258 << decorationName << " needs a single integer literal";
260 decorations[words[0]].set(
261 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
262 break;
263 case spirv::Decoration::BuiltIn:
264 if (words.size() != 3) {
265 return emitError(unknownLoc, "OpDecorate with ")
266 << decorationName << " needs a single integer literal";
268 decorations[words[0]].set(
269 symbol, opBuilder.getStringAttr(
270 stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
271 break;
272 case spirv::Decoration::ArrayStride:
273 if (words.size() != 3) {
274 return emitError(unknownLoc, "OpDecorate with ")
275 << decorationName << " needs a single integer literal";
277 typeDecorations[words[0]] = words[2];
278 break;
279 case spirv::Decoration::LinkageAttributes: {
280 if (words.size() < 4) {
281 return emitError(unknownLoc, "OpDecorate with ")
282 << decorationName
283 << " needs at least 1 string and 1 integer literal";
285 // LinkageAttributes has two parameters ["linkageName", linkageType]
286 // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
287 // "linkageName" is a stringliteral encoded as uint32_t,
288 // hence the size of name is variable length which results in words.size()
289 // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
290 // 3 + ceildiv(strlen(name), 4).
291 unsigned wordIndex = 2;
292 auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
293 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
294 static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
295 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
296 StringAttr::get(context, linkageName), linkageTypeAttr);
297 decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
298 break;
300 case spirv::Decoration::Aliased:
301 case spirv::Decoration::AliasedPointer:
302 case spirv::Decoration::Block:
303 case spirv::Decoration::BufferBlock:
304 case spirv::Decoration::Flat:
305 case spirv::Decoration::NonReadable:
306 case spirv::Decoration::NonWritable:
307 case spirv::Decoration::NoPerspective:
308 case spirv::Decoration::NoSignedWrap:
309 case spirv::Decoration::NoUnsignedWrap:
310 case spirv::Decoration::RelaxedPrecision:
311 case spirv::Decoration::Restrict:
312 case spirv::Decoration::RestrictPointer:
313 case spirv::Decoration::NoContraction:
314 if (words.size() != 2) {
315 return emitError(unknownLoc, "OpDecoration with ")
316 << decorationName << "needs a single target <id>";
318 // Block decoration does not affect spirv.struct type, but is still stored
319 // for verification.
320 // TODO: Update StructType to contain this information since
321 // it is needed for many validation rules.
322 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
323 break;
324 case spirv::Decoration::Location:
325 case spirv::Decoration::SpecId:
326 if (words.size() != 3) {
327 return emitError(unknownLoc, "OpDecoration with ")
328 << decorationName << "needs a single integer literal";
330 decorations[words[0]].set(
331 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
332 break;
333 default:
334 return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
336 return success();
339 LogicalResult
340 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
341 // The binary layout of OpMemberDecorate is different comparing to OpDecorate
342 if (words.size() < 3) {
343 return emitError(unknownLoc,
344 "OpMemberDecorate must have at least 3 operands");
347 auto decoration = static_cast<spirv::Decoration>(words[2]);
348 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
349 return emitError(unknownLoc,
350 " missing offset specification in OpMemberDecorate with "
351 "Offset decoration");
353 ArrayRef<uint32_t> decorationOperands;
354 if (words.size() > 3) {
355 decorationOperands = words.slice(3);
357 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
358 return success();
361 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
362 if (words.size() < 3) {
363 return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
365 unsigned wordIndex = 2;
366 auto name = decodeStringLiteral(words, wordIndex);
367 if (wordIndex != words.size()) {
368 return emitError(unknownLoc,
369 "unexpected trailing words in OpMemberName instruction");
371 memberNameMap[words[0]][words[1]] = name;
372 return success();
375 LogicalResult spirv::Deserializer::setFunctionArgAttrs(
376 uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
377 if (!decorations.contains(argID)) {
378 argAttrs[argIndex] = DictionaryAttr::get(context, {});
379 return success();
382 spirv::DecorationAttr foundDecorationAttr;
383 for (NamedAttribute decAttr : decorations[argID]) {
384 for (auto decoration :
385 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
386 spirv::Decoration::AliasedPointer,
387 spirv::Decoration::RestrictPointer}) {
389 if (decAttr.getName() !=
390 getSymbolDecoration(stringifyDecoration(decoration)))
391 continue;
393 if (foundDecorationAttr)
394 return emitError(unknownLoc,
395 "more than one Aliased/Restrict decorations for "
396 "function argument with result <id> ")
397 << argID;
399 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
400 break;
404 if (!foundDecorationAttr)
405 return emitError(unknownLoc, "unimplemented decoration support for "
406 "function argument with result <id> ")
407 << argID;
409 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
410 foundDecorationAttr);
411 argAttrs[argIndex] = DictionaryAttr::get(context, attr);
412 return success();
415 LogicalResult
416 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
417 if (curFunction) {
418 return emitError(unknownLoc, "found function inside function");
421 // Get the result type
422 if (operands.size() != 4) {
423 return emitError(unknownLoc, "OpFunction must have 4 parameters");
425 Type resultType = getType(operands[0]);
426 if (!resultType) {
427 return emitError(unknownLoc, "undefined result type from <id> ")
428 << operands[0];
431 uint32_t fnID = operands[1];
432 if (funcMap.count(fnID)) {
433 return emitError(unknownLoc, "duplicate function definition/declaration");
436 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
437 if (!fnControl) {
438 return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
441 Type fnType = getType(operands[3]);
442 if (!fnType || !isa<FunctionType>(fnType)) {
443 return emitError(unknownLoc, "unknown function type from <id> ")
444 << operands[3];
446 auto functionType = cast<FunctionType>(fnType);
448 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
449 (functionType.getNumResults() == 1 &&
450 functionType.getResult(0) != resultType)) {
451 return emitError(unknownLoc, "mismatch in function type ")
452 << functionType << " and return type " << resultType << " specified";
455 std::string fnName = getFunctionSymbol(fnID);
456 auto funcOp = opBuilder.create<spirv::FuncOp>(
457 unknownLoc, fnName, functionType, fnControl.value());
458 // Processing other function attributes.
459 if (decorations.count(fnID)) {
460 for (auto attr : decorations[fnID].getAttrs()) {
461 funcOp->setAttr(attr.getName(), attr.getValue());
464 curFunction = funcMap[fnID] = funcOp;
465 auto *entryBlock = funcOp.addEntryBlock();
466 LLVM_DEBUG({
467 logger.startLine()
468 << "//===-------------------------------------------===//\n";
469 logger.startLine() << "[fn] name: " << fnName << "\n";
470 logger.startLine() << "[fn] type: " << fnType << "\n";
471 logger.startLine() << "[fn] ID: " << fnID << "\n";
472 logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
473 logger.indent();
476 SmallVector<Attribute> argAttrs;
477 argAttrs.resize(functionType.getNumInputs());
479 // Parse the op argument instructions
480 if (functionType.getNumInputs()) {
481 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
482 auto argType = functionType.getInput(i);
483 spirv::Opcode opcode = spirv::Opcode::OpNop;
484 ArrayRef<uint32_t> operands;
485 if (failed(sliceInstruction(opcode, operands,
486 spirv::Opcode::OpFunctionParameter))) {
487 return failure();
489 if (opcode != spirv::Opcode::OpFunctionParameter) {
490 return emitError(
491 unknownLoc,
492 "missing OpFunctionParameter instruction for argument ")
493 << i;
495 if (operands.size() != 2) {
496 return emitError(
497 unknownLoc,
498 "expected result type and result <id> for OpFunctionParameter");
500 auto argDefinedType = getType(operands[0]);
501 if (!argDefinedType || argDefinedType != argType) {
502 return emitError(unknownLoc,
503 "mismatch in argument type between function type "
504 "definition ")
505 << functionType << " and argument type definition "
506 << argDefinedType << " at argument " << i;
508 if (getValue(operands[1])) {
509 return emitError(unknownLoc, "duplicate definition of result <id> ")
510 << operands[1];
512 if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
513 return failure();
516 auto argValue = funcOp.getArgument(i);
517 valueMap[operands[1]] = argValue;
521 if (llvm::any_of(argAttrs, [](Attribute attr) {
522 auto argAttr = cast<DictionaryAttr>(attr);
523 return !argAttr.empty();
525 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
527 // entryBlock is needed to access the arguments, Once that is done, we can
528 // erase the block for functions with 'Import' LinkageAttributes, since these
529 // are essentially function declarations, so they have no body.
530 auto linkageAttr = funcOp.getLinkageAttributes();
531 auto hasImportLinkage =
532 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
533 spirv::LinkageType::Import);
534 if (hasImportLinkage)
535 funcOp.eraseBody();
537 // RAII guard to reset the insertion point to the module's region after
538 // deserializing the body of this function.
539 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
541 spirv::Opcode opcode = spirv::Opcode::OpNop;
542 ArrayRef<uint32_t> instOperands;
544 // Special handling for the entry block. We need to make sure it starts with
545 // an OpLabel instruction. The entry block takes the same parameters as the
546 // function. All other blocks do not take any parameter. We have already
547 // created the entry block, here we need to register it to the correct label
548 // <id>.
549 if (failed(sliceInstruction(opcode, instOperands,
550 spirv::Opcode::OpFunctionEnd))) {
551 return failure();
553 if (opcode == spirv::Opcode::OpFunctionEnd) {
554 return processFunctionEnd(instOperands);
556 if (opcode != spirv::Opcode::OpLabel) {
557 return emitError(unknownLoc, "a basic block must start with OpLabel");
559 if (instOperands.size() != 1) {
560 return emitError(unknownLoc, "OpLabel should only have result <id>");
562 blockMap[instOperands[0]] = entryBlock;
563 if (failed(processLabel(instOperands))) {
564 return failure();
567 // Then process all the other instructions in the function until we hit
568 // OpFunctionEnd.
569 while (succeeded(sliceInstruction(opcode, instOperands,
570 spirv::Opcode::OpFunctionEnd)) &&
571 opcode != spirv::Opcode::OpFunctionEnd) {
572 if (failed(processInstruction(opcode, instOperands))) {
573 return failure();
576 if (opcode != spirv::Opcode::OpFunctionEnd) {
577 return failure();
580 return processFunctionEnd(instOperands);
583 LogicalResult
584 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
585 // Process OpFunctionEnd.
586 if (!operands.empty()) {
587 return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
590 // Wire up block arguments from OpPhi instructions.
591 // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
592 // ops.
593 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
594 return failure();
597 curBlock = nullptr;
598 curFunction = std::nullopt;
600 LLVM_DEBUG({
601 logger.unindent();
602 logger.startLine()
603 << "//===-------------------------------------------===//\n";
605 return success();
608 std::optional<std::pair<Attribute, Type>>
609 spirv::Deserializer::getConstant(uint32_t id) {
610 auto constIt = constantMap.find(id);
611 if (constIt == constantMap.end())
612 return std::nullopt;
613 return constIt->getSecond();
616 std::optional<spirv::SpecConstOperationMaterializationInfo>
617 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
618 auto constIt = specConstOperationMap.find(id);
619 if (constIt == specConstOperationMap.end())
620 return std::nullopt;
621 return constIt->getSecond();
624 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
625 auto funcName = nameMap.lookup(id).str();
626 if (funcName.empty()) {
627 funcName = "spirv_fn_" + std::to_string(id);
629 return funcName;
632 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
633 auto constName = nameMap.lookup(id).str();
634 if (constName.empty()) {
635 constName = "spirv_spec_const_" + std::to_string(id);
637 return constName;
640 spirv::SpecConstantOp
641 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
642 TypedAttr defaultValue) {
643 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
644 auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
645 defaultValue);
646 if (decorations.count(resultID)) {
647 for (auto attr : decorations[resultID].getAttrs())
648 op->setAttr(attr.getName(), attr.getValue());
650 specConstMap[resultID] = op;
651 return op;
654 LogicalResult
655 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
656 unsigned wordIndex = 0;
657 if (operands.size() < 3) {
658 return emitError(
659 unknownLoc,
660 "OpVariable needs at least 3 operands, type, <id> and storage class");
663 // Result Type.
664 auto type = getType(operands[wordIndex]);
665 if (!type) {
666 return emitError(unknownLoc, "unknown result type <id> : ")
667 << operands[wordIndex];
669 auto ptrType = dyn_cast<spirv::PointerType>(type);
670 if (!ptrType) {
671 return emitError(unknownLoc,
672 "expected a result type <id> to be a spirv.ptr, found : ")
673 << type;
675 wordIndex++;
677 // Result <id>.
678 auto variableID = operands[wordIndex];
679 auto variableName = nameMap.lookup(variableID).str();
680 if (variableName.empty()) {
681 variableName = "spirv_var_" + std::to_string(variableID);
683 wordIndex++;
685 // Storage class.
686 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
687 if (ptrType.getStorageClass() != storageClass) {
688 return emitError(unknownLoc, "mismatch in storage class of pointer type ")
689 << type << " and that specified in OpVariable instruction : "
690 << stringifyStorageClass(storageClass);
692 wordIndex++;
694 // Initializer.
695 FlatSymbolRefAttr initializer = nullptr;
697 if (wordIndex < operands.size()) {
698 Operation *op = nullptr;
700 if (auto initOp = getGlobalVariable(operands[wordIndex]))
701 op = initOp;
702 else if (auto initOp = getSpecConstant(operands[wordIndex]))
703 op = initOp;
704 else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
705 op = initOp;
706 else
707 return emitError(unknownLoc, "unknown <id> ")
708 << operands[wordIndex] << "used as initializer";
710 initializer = SymbolRefAttr::get(op);
711 wordIndex++;
713 if (wordIndex != operands.size()) {
714 return emitError(unknownLoc,
715 "found more operands than expected when deserializing "
716 "OpVariable instruction, only ")
717 << wordIndex << " of " << operands.size() << " processed";
719 auto loc = createFileLineColLoc(opBuilder);
720 auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
721 loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
722 initializer);
724 // Decorations.
725 if (decorations.count(variableID)) {
726 for (auto attr : decorations[variableID].getAttrs())
727 varOp->setAttr(attr.getName(), attr.getValue());
729 globalVariableMap[variableID] = varOp;
730 return success();
733 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
734 auto constInfo = getConstant(id);
735 if (!constInfo) {
736 return nullptr;
738 return dyn_cast<IntegerAttr>(constInfo->first);
741 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
742 if (operands.size() < 2) {
743 return emitError(unknownLoc, "OpName needs at least 2 operands");
745 if (!nameMap.lookup(operands[0]).empty()) {
746 return emitError(unknownLoc, "duplicate name found for result <id> ")
747 << operands[0];
749 unsigned wordIndex = 1;
750 StringRef name = decodeStringLiteral(operands, wordIndex);
751 if (wordIndex != operands.size()) {
752 return emitError(unknownLoc,
753 "unexpected trailing words in OpName instruction");
755 nameMap[operands[0]] = name;
756 return success();
759 //===----------------------------------------------------------------------===//
760 // Type
761 //===----------------------------------------------------------------------===//
/// Deserializes a type-defining instruction. `operands[0]` is the result <id>
/// under which the new type is recorded in `typeMap`; the remaining words are
/// opcode-specific parameters. Void, bool, int, float and vector types are
/// built inline, while the other handled opcodes are forwarded to dedicated
/// helper methods. An error is emitted for missing operands, a result <id>
/// that already has a type, a wrong operand count, or an opcode that is not
/// handled here.
763 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
764 ArrayRef<uint32_t> operands) {
765 if (operands.empty()) {
766 return emitError(unknownLoc, "type instruction with opcode ")
767 << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
770 /// TODO: Types might be forward declared in some instructions and need to be
771 /// handled appropriately.
// A result <id> may only be given a type once.
772 if (typeMap.count(operands[0])) {
773 return emitError(unknownLoc, "duplicate definition for result <id> ")
774 << operands[0];
777 switch (opcode) {
// OpTypeVoid is represented with the builder's none type.
778 case spirv::Opcode::OpTypeVoid:
779 if (operands.size() != 1)
780 return emitError(unknownLoc, "OpTypeVoid must have no parameters");
781 typeMap[operands[0]] = opBuilder.getNoneType();
782 break;
// OpTypeBool is represented with the builder's i1 type.
783 case spirv::Opcode::OpTypeBool:
784 if (operands.size() != 1)
785 return emitError(unknownLoc, "OpTypeBool must have no parameters");
786 typeMap[operands[0]] = opBuilder.getI1Type();
787 break;
// OpTypeInt: operands[1] is the bitwidth, operands[2] the signedness flag.
788 case spirv::Opcode::OpTypeInt: {
789 if (operands.size() != 3)
790 return emitError(
791 unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
793 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
794 // to preserve or validate.
795 // 0 indicates unsigned, or no signedness semantics
796 // 1 indicates signed semantics."
798 // So we cannot differentiate signless and unsigned integers; always use
799 // signless semantics for such cases.
800 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
801 : IntegerType::SignednessSemantics::Signless;
802 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
803 } break;
// OpTypeFloat: only the 16-, 32- and 64-bit widths are accepted.
804 case spirv::Opcode::OpTypeFloat: {
805 if (operands.size() != 2)
806 return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
808 Type floatTy;
809 switch (operands[1]) {
810 case 16:
811 floatTy = opBuilder.getF16Type();
812 break;
813 case 32:
814 floatTy = opBuilder.getF32Type();
815 break;
816 case 64:
817 floatTy = opBuilder.getF64Type();
818 break;
819 default:
820 return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
821 << operands[1];
823 typeMap[operands[0]] = floatTy;
824 } break;
// OpTypeVector: operands[1] is the element type <id> (which must already be
// known), operands[2] the element count used as the single vector dimension.
825 case spirv::Opcode::OpTypeVector: {
826 if (operands.size() != 3) {
827 return emitError(
828 unknownLoc,
829 "OpTypeVector must have element type and count parameters");
831 Type elementTy = getType(operands[1]);
832 if (!elementTy) {
833 return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
834 << operands[1];
836 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
837 } break;
// The remaining handled opcodes return the result of their helper directly.
838 case spirv::Opcode::OpTypePointer: {
839 return processOpTypePointer(operands);
840 } break;
841 case spirv::Opcode::OpTypeArray:
842 return processArrayType(operands);
843 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
844 return processCooperativeMatrixTypeKHR(operands);
845 case spirv::Opcode::OpTypeFunction:
846 return processFunctionType(operands);
847 case spirv::Opcode::OpTypeJointMatrixINTEL:
848 return processJointMatrixType(operands);
849 case spirv::Opcode::OpTypeImage:
850 return processImageType(operands);
851 case spirv::Opcode::OpTypeSampledImage:
852 return processSampledImageType(operands);
853 case spirv::Opcode::OpTypeRuntimeArray:
854 return processRuntimeArrayType(operands);
855 case spirv::Opcode::OpTypeStruct:
856 return processStructType(operands);
857 case spirv::Opcode::OpTypeMatrix:
858 return processMatrixType(operands);
859 default:
860 return emitError(unknownLoc, "unhandled type instruction");
// Reached only by the inline cases above that `break` out of the switch.
862 return success();
/// Processes an OpTypePointer instruction. `operands` must hold exactly three
/// words: the result <id>, the storage class, and the pointee type <id>
/// (which must already be known). The new pointer type is recorded in
/// `typeMap`, and any deferred struct type that was waiting on this pointer
/// <id> for one of its members gets that member filled in. A deferred struct
/// whose members are all resolved has its body set and is removed from
/// `deferredStructTypesInfos`.
865 LogicalResult
866 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
867 if (operands.size() != 3)
868 return emitError(unknownLoc, "OpTypePointer must have two parameters");
870 auto pointeeType = getType(operands[2]);
871 if (!pointeeType)
872 return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
873 << operands[2];
875 uint32_t typePointerID = operands[0];
876 auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
877 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
// Both loops below advance their iterator manually: `erase` returns the
// iterator to continue from, otherwise the iterator is incremented.
879 for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
880 deferredStructIt != std::end(deferredStructTypesInfos);) {
881 for (auto *unresolvedMemberIt =
882 std::begin(deferredStructIt->unresolvedMemberTypes);
883 unresolvedMemberIt !=
884 std::end(deferredStructIt->unresolvedMemberTypes);) {
// `first` is compared against the pointer <id>; `second` is used as the index
// into `memberTypes`.
885 if (unresolvedMemberIt->first == typePointerID) {
886 // The newly constructed pointer type can resolve one of the
887 // deferred struct type members; update the memberTypes list and
888 // clean the unresolvedMemberTypes list accordingly.
889 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
890 typeMap[typePointerID];
891 unresolvedMemberIt =
892 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
893 } else {
894 ++unresolvedMemberIt;
898 if (deferredStructIt->unresolvedMemberTypes.empty()) {
899 // All deferred struct type members are now resolved, set the struct body.
900 auto structType = deferredStructIt->deferredStructType;
902 assert(structType && "expected a spirv::StructType");
903 assert(structType.isIdentified() && "expected an indentified struct");
// A failure to set the body aborts processing of this instruction.
905 if (failed(structType.trySetBody(
906 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
907 deferredStructIt->memberDecorationsInfo)))
908 return failure();
910 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
911 } else {
912 ++deferredStructIt;
916 return success();
919 LogicalResult
920 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
921 if (operands.size() != 3) {
922 return emitError(unknownLoc,
923 "OpTypeArray must have element type and count parameters");
926 Type elementTy = getType(operands[1]);
927 if (!elementTy) {
928 return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
929 << operands[1];
932 unsigned count = 0;
933 // TODO: The count can also come frome a specialization constant.
934 auto countInfo = getConstant(operands[2]);
935 if (!countInfo) {
936 return emitError(unknownLoc, "OpTypeArray count <id> ")
937 << operands[2] << "can only come from normal constant right now";
940 if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
941 count = intVal.getValue().getZExtValue();
942 } else {
943 return emitError(unknownLoc, "OpTypeArray count must come from a "
944 "scalar integer constant instruction");
947 typeMap[operands[0]] = spirv::ArrayType::get(
948 elementTy, count, typeDecorations.lookup(operands[0]));
949 return success();
952 LogicalResult
953 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
954 assert(!operands.empty() && "No operands for processing function type");
955 if (operands.size() == 1) {
956 return emitError(unknownLoc, "missing return type for OpTypeFunction");
958 auto returnType = getType(operands[1]);
959 if (!returnType) {
960 return emitError(unknownLoc, "unknown return type in OpTypeFunction");
962 SmallVector<Type, 1> argTypes;
963 for (size_t i = 2, e = operands.size(); i < e; ++i) {
964 auto ty = getType(operands[i]);
965 if (!ty) {
966 return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
968 argTypes.push_back(ty);
970 ArrayRef<Type> returnTypes;
971 if (!isVoidType(returnType)) {
972 returnTypes = llvm::ArrayRef(returnType);
974 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
975 return success();
978 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
979 ArrayRef<uint32_t> operands) {
980 if (operands.size() != 6) {
981 return emitError(unknownLoc,
982 "OpTypeCooperativeMatrixKHR must have element type, "
983 "scope, row and column parameters, and use");
986 Type elementTy = getType(operands[1]);
987 if (!elementTy) {
988 return emitError(unknownLoc,
989 "OpTypeCooperativeMatrixKHR references undefined <id> ")
990 << operands[1];
993 std::optional<spirv::Scope> scope =
994 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
995 if (!scope) {
996 return emitError(
997 unknownLoc,
998 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
999 << operands[2];
1002 unsigned rows = getConstantInt(operands[3]).getInt();
1003 unsigned columns = getConstantInt(operands[4]).getInt();
1005 std::optional<spirv::CooperativeMatrixUseKHR> use =
1006 spirv::symbolizeCooperativeMatrixUseKHR(
1007 getConstantInt(operands[5]).getInt());
1008 if (!use) {
1009 return emitError(
1010 unknownLoc,
1011 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1012 << operands[5];
1015 typeMap[operands[0]] =
1016 spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1017 return success();
1020 LogicalResult
1021 spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
1022 if (operands.size() != 6) {
1023 return emitError(unknownLoc, "OpTypeJointMatrix must have element "
1024 "type and row x column parameters");
1027 Type elementTy = getType(operands[1]);
1028 if (!elementTy) {
1029 return emitError(unknownLoc, "OpTypeJointMatrix references undefined <id> ")
1030 << operands[1];
1033 auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt());
1034 if (!scope) {
1035 return emitError(unknownLoc,
1036 "OpTypeJointMatrix references undefined scope <id> ")
1037 << operands[5];
1039 auto matrixLayout =
1040 spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt());
1041 if (!matrixLayout) {
1042 return emitError(unknownLoc,
1043 "OpTypeJointMatrix references undefined scope <id> ")
1044 << operands[4];
1046 unsigned rows = getConstantInt(operands[2]).getInt();
1047 unsigned columns = getConstantInt(operands[3]).getInt();
1049 typeMap[operands[0]] = spirv::JointMatrixINTELType::get(
1050 elementTy, scope.value(), rows, columns, matrixLayout.value());
1051 return success();
1054 LogicalResult
1055 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
1056 if (operands.size() != 2) {
1057 return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1059 Type memberType = getType(operands[1]);
1060 if (!memberType) {
1061 return emitError(unknownLoc,
1062 "OpTypeRuntimeArray references undefined <id> ")
1063 << operands[1];
1065 typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1066 memberType, typeDecorations.lookup(operands[0]));
1067 return success();
1070 LogicalResult
1071 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
1072 // TODO: Find a way to handle identified structs when debug info is stripped.
1074 if (operands.empty()) {
1075 return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1078 if (operands.size() == 1) {
1079 // Handle empty struct.
1080 typeMap[operands[0]] =
1081 spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1082 return success();
1085 // First element is operand ID, second element is member index in the struct.
1086 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1087 SmallVector<Type, 4> memberTypes;
1089 for (auto op : llvm::drop_begin(operands, 1)) {
1090 Type memberType = getType(op);
1091 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1093 if (!memberType && !typeForwardPtr)
1094 return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1095 << op;
1097 if (!memberType)
1098 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1100 memberTypes.push_back(memberType);
1103 SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
1104 SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
1105 if (memberDecorationMap.count(operands[0])) {
1106 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1107 for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1108 if (allMemberDecorations.count(memberIndex)) {
1109 for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1110 // Check for offset.
1111 if (memberDecoration.first == spirv::Decoration::Offset) {
1112 // If offset info is empty, resize to the number of members;
1113 if (offsetInfo.empty()) {
1114 offsetInfo.resize(memberTypes.size());
1116 offsetInfo[memberIndex] = memberDecoration.second[0];
1117 } else {
1118 if (!memberDecoration.second.empty()) {
1119 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
1120 memberDecoration.first,
1121 memberDecoration.second[0]);
1122 } else {
1123 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
1124 memberDecoration.first, 0);
1132 uint32_t structID = operands[0];
1133 std::string structIdentifier = nameMap.lookup(structID).str();
1135 if (structIdentifier.empty()) {
1136 assert(unresolvedMemberTypes.empty() &&
1137 "didn't expect unresolved member types");
1138 typeMap[structID] =
1139 spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
1140 } else {
1141 auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1142 typeMap[structID] = structTy;
1144 if (!unresolvedMemberTypes.empty())
1145 deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1146 memberTypes, offsetInfo,
1147 memberDecorationsInfo});
1148 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1149 memberDecorationsInfo)))
1150 return failure();
1153 // TODO: Update StructType to have member name as attribute as
1154 // well.
1155 return success();
1158 LogicalResult
1159 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
1160 if (operands.size() != 3) {
1161 // Three operands are needed: result_id, column_type, and column_count
1162 return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1163 " (result_id, column_type, and column_count)");
1165 // Matrix columns must be of vector type
1166 Type elementTy = getType(operands[1]);
1167 if (!elementTy) {
1168 return emitError(unknownLoc,
1169 "OpTypeMatrix references undefined column type.")
1170 << operands[1];
1173 uint32_t colsCount = operands[2];
1174 typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1175 return success();
1178 LogicalResult
1179 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1180 if (operands.size() != 2)
1181 return emitError(unknownLoc,
1182 "OpTypeForwardPointer instruction must have two operands");
1184 typeForwardPointerIDs.insert(operands[0]);
1185 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1186 // instruction that defines the actual type.
1188 return success();
1191 LogicalResult
1192 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1193 // TODO: Add support for Access Qualifier.
1194 if (operands.size() != 8)
1195 return emitError(
1196 unknownLoc,
1197 "OpTypeImage with non-eight operands are not supported yet");
1199 Type elementTy = getType(operands[1]);
1200 if (!elementTy)
1201 return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1202 << operands[1];
1204 auto dim = spirv::symbolizeDim(operands[2]);
1205 if (!dim)
1206 return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1207 << operands[2];
1209 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1210 if (!depthInfo)
1211 return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1212 << operands[3];
1214 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1215 if (!arrayedInfo)
1216 return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1217 << operands[4];
1219 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1220 if (!samplingInfo)
1221 return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1223 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1224 if (!samplerUseInfo)
1225 return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1226 << operands[6];
1228 auto format = spirv::symbolizeImageFormat(operands[7]);
1229 if (!format)
1230 return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1231 << operands[7];
1233 typeMap[operands[0]] = spirv::ImageType::get(
1234 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1235 samplingInfo.value(), samplerUseInfo.value(), format.value());
1236 return success();
1239 LogicalResult
1240 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1241 if (operands.size() != 2)
1242 return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1244 Type elementTy = getType(operands[1]);
1245 if (!elementTy)
1246 return emitError(unknownLoc,
1247 "OpTypeSampledImage references undefined <id>: ")
1248 << operands[1];
1250 typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1251 return success();
1254 //===----------------------------------------------------------------------===//
1255 // Constant
1256 //===----------------------------------------------------------------------===//
1258 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1259 bool isSpec) {
1260 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1262 if (operands.size() < 2) {
1263 return emitError(unknownLoc)
1264 << opname << " must have type <id> and result <id>";
1266 if (operands.size() < 3) {
1267 return emitError(unknownLoc)
1268 << opname << " must have at least 1 more parameter";
1271 Type resultType = getType(operands[0]);
1272 if (!resultType) {
1273 return emitError(unknownLoc, "undefined result type from <id> ")
1274 << operands[0];
1277 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1278 if (bitwidth == 64) {
1279 if (operands.size() == 4) {
1280 return success();
1282 return emitError(unknownLoc)
1283 << opname << " should have 2 parameters for 64-bit values";
1285 if (bitwidth <= 32) {
1286 if (operands.size() == 3) {
1287 return success();
1290 return emitError(unknownLoc)
1291 << opname
1292 << " should have 1 parameter for values with no more than 32 bits";
1294 return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1295 << bitwidth;
1298 auto resultID = operands[1];
1300 if (auto intType = dyn_cast<IntegerType>(resultType)) {
1301 auto bitwidth = intType.getWidth();
1302 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1303 return failure();
1306 APInt value;
1307 if (bitwidth == 64) {
1308 // 64-bit integers are represented with two SPIR-V words. According to
1309 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1310 // literal’s low-order words appear first."
1311 struct DoubleWord {
1312 uint32_t word1;
1313 uint32_t word2;
1314 } words = {operands[2], operands[3]};
1315 value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1316 } else if (bitwidth <= 32) {
1317 value = APInt(bitwidth, operands[2], /*isSigned=*/true);
1320 auto attr = opBuilder.getIntegerAttr(intType, value);
1322 if (isSpec) {
1323 createSpecConstant(unknownLoc, resultID, attr);
1324 } else {
1325 // For normal constants, we just record the attribute (and its type) for
1326 // later materialization at use sites.
1327 constantMap.try_emplace(resultID, attr, intType);
1330 return success();
1333 if (auto floatType = dyn_cast<FloatType>(resultType)) {
1334 auto bitwidth = floatType.getWidth();
1335 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1336 return failure();
1339 APFloat value(0.f);
1340 if (floatType.isF64()) {
1341 // Double values are represented with two SPIR-V words. According to
1342 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1343 // literal’s low-order words appear first."
1344 struct DoubleWord {
1345 uint32_t word1;
1346 uint32_t word2;
1347 } words = {operands[2], operands[3]};
1348 value = APFloat(llvm::bit_cast<double>(words));
1349 } else if (floatType.isF32()) {
1350 value = APFloat(llvm::bit_cast<float>(operands[2]));
1351 } else if (floatType.isF16()) {
1352 APInt data(16, operands[2]);
1353 value = APFloat(APFloat::IEEEhalf(), data);
1356 auto attr = opBuilder.getFloatAttr(floatType, value);
1357 if (isSpec) {
1358 createSpecConstant(unknownLoc, resultID, attr);
1359 } else {
1360 // For normal constants, we just record the attribute (and its type) for
1361 // later materialization at use sites.
1362 constantMap.try_emplace(resultID, attr, floatType);
1365 return success();
1368 return emitError(unknownLoc, "OpConstant can only generate values of "
1369 "scalar integer or floating-point type");
1372 LogicalResult spirv::Deserializer::processConstantBool(
1373 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1374 if (operands.size() != 2) {
1375 return emitError(unknownLoc, "Op")
1376 << (isSpec ? "Spec" : "") << "Constant"
1377 << (isTrue ? "True" : "False")
1378 << " must have type <id> and result <id>";
1381 auto attr = opBuilder.getBoolAttr(isTrue);
1382 auto resultID = operands[1];
1383 if (isSpec) {
1384 createSpecConstant(unknownLoc, resultID, attr);
1385 } else {
1386 // For normal constants, we just record the attribute (and its type) for
1387 // later materialization at use sites.
1388 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1391 return success();
1394 LogicalResult
1395 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1396 if (operands.size() < 2) {
1397 return emitError(unknownLoc,
1398 "OpConstantComposite must have type <id> and result <id>");
1400 if (operands.size() < 3) {
1401 return emitError(unknownLoc,
1402 "OpConstantComposite must have at least 1 parameter");
1405 Type resultType = getType(operands[0]);
1406 if (!resultType) {
1407 return emitError(unknownLoc, "undefined result type from <id> ")
1408 << operands[0];
1411 SmallVector<Attribute, 4> elements;
1412 elements.reserve(operands.size() - 2);
1413 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1414 auto elementInfo = getConstant(operands[i]);
1415 if (!elementInfo) {
1416 return emitError(unknownLoc, "OpConstantComposite component <id> ")
1417 << operands[i] << " must come from a normal constant";
1419 elements.push_back(elementInfo->first);
1422 auto resultID = operands[1];
1423 if (auto vectorType = dyn_cast<VectorType>(resultType)) {
1424 auto attr = DenseElementsAttr::get(vectorType, elements);
1425 // For normal constants, we just record the attribute (and its type) for
1426 // later materialization at use sites.
1427 constantMap.try_emplace(resultID, attr, resultType);
1428 } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1429 auto attr = opBuilder.getArrayAttr(elements);
1430 constantMap.try_emplace(resultID, attr, resultType);
1431 } else {
1432 return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1433 << resultType;
1436 return success();
1439 LogicalResult
1440 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1441 if (operands.size() < 2) {
1442 return emitError(unknownLoc,
1443 "OpConstantComposite must have type <id> and result <id>");
1445 if (operands.size() < 3) {
1446 return emitError(unknownLoc,
1447 "OpConstantComposite must have at least 1 parameter");
1450 Type resultType = getType(operands[0]);
1451 if (!resultType) {
1452 return emitError(unknownLoc, "undefined result type from <id> ")
1453 << operands[0];
1456 auto resultID = operands[1];
1457 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1459 SmallVector<Attribute, 4> elements;
1460 elements.reserve(operands.size() - 2);
1461 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1462 auto elementInfo = getSpecConstant(operands[i]);
1463 elements.push_back(SymbolRefAttr::get(elementInfo));
1466 auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1467 unknownLoc, TypeAttr::get(resultType), symName,
1468 opBuilder.getArrayAttr(elements));
1469 specConstCompositeMap[resultID] = op;
1471 return success();
1474 LogicalResult
1475 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1476 if (operands.size() < 3)
1477 return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1478 "result <id>, and operand opcode");
1480 uint32_t resultTypeID = operands[0];
1482 if (!getType(resultTypeID))
1483 return emitError(unknownLoc, "undefined result type from <id> ")
1484 << resultTypeID;
1486 uint32_t resultID = operands[1];
1487 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1488 auto emplaceResult = specConstOperationMap.try_emplace(
1489 resultID,
1490 SpecConstOperationMaterializationInfo{
1491 enclosedOpcode, resultTypeID,
1492 SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1494 if (!emplaceResult.second)
1495 return emitError(unknownLoc, "value with <id>: ")
1496 << resultID << " is probably defined before.";
1498 return success();
1501 Value spirv::Deserializer::materializeSpecConstantOperation(
1502 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1503 ArrayRef<uint32_t> enclosedOpOperands) {
1505 Type resultType = getType(resultTypeID);
1507 // Instructions wrapped by OpSpecConstantOp need an ID for their
1508 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1509 // dialect wrapped op. For that purpose, a new value map is created and "fake"
1510 // ID in that map is assigned to the result of the enclosed instruction. Note
1511 // that there is no need to update this fake ID since we only need to
1512 // reference the created Value for the enclosed op from the spv::YieldOp
1513 // created later in this method (both of which are the only values in their
1514 // region: the SpecConstantOperation's region). If we encounter another
1515 // SpecConstantOperation in the module, we simply re-use the fake ID since the
1516 // previous Value assigned to it isn't visible in the current scope anyway.
1517 DenseMap<uint32_t, Value> newValueMap;
1518 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1519 constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1521 SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
1522 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1523 enclosedOpResultTypeAndOperands.push_back(fakeID);
1524 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1525 enclosedOpOperands.end());
1527 // Process enclosed instruction before creating the enclosing
1528 // specConstantOperation (and its region). This way, references to constants,
1529 // global variables, and spec constants will be materialized outside the new
1530 // op's region. For more info, see Deserializer::getValue's implementation.
1531 if (failed(
1532 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1533 return Value();
1535 // Since the enclosed op is emitted in the current block, split it in a
1536 // separate new block.
1537 Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
1539 auto loc = createFileLineColLoc(opBuilder);
1540 auto specConstOperationOp =
1541 opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1543 Region &body = specConstOperationOp.getBody();
1544 // Move the new block into SpecConstantOperation's body.
1545 body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1546 Region::iterator(enclosedBlock));
1547 Block &block = body.back();
1549 // RAII guard to reset the insertion point to the module's region after
1550 // deserializing the body of the specConstantOperation.
1551 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
1552 opBuilder.setInsertionPointToEnd(&block);
1554 opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
1555 return specConstOperationOp.getResult();
1558 LogicalResult
1559 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1560 if (operands.size() != 2) {
1561 return emitError(unknownLoc,
1562 "OpConstantNull must have type <id> and result <id>");
1565 Type resultType = getType(operands[0]);
1566 if (!resultType) {
1567 return emitError(unknownLoc, "undefined result type from <id> ")
1568 << operands[0];
1571 auto resultID = operands[1];
1572 if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
1573 auto attr = opBuilder.getZeroAttr(resultType);
1574 // For normal constants, we just record the attribute (and its type) for
1575 // later materialization at use sites.
1576 constantMap.try_emplace(resultID, attr, resultType);
1577 return success();
1580 return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1581 << resultType;
1584 //===----------------------------------------------------------------------===//
1585 // Control flow
1586 //===----------------------------------------------------------------------===//
1588 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1589 if (auto *block = getBlock(id)) {
1590 LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1591 << " @ " << block << "\n");
1592 return block;
1595 // We don't know where this block will be placed finally (in a
1596 // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1597 // function for now and sort out the proper place later.
1598 auto *block = curFunction->addBlock();
1599 LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1600 << " @ " << block << "\n");
1601 return blockMap[id] = block;
1604 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1605 if (!curBlock) {
1606 return emitError(unknownLoc, "OpBranch must appear inside a block");
1609 if (operands.size() != 1) {
1610 return emitError(unknownLoc, "OpBranch must take exactly one target label");
1613 auto *target = getOrCreateBlock(operands[0]);
1614 auto loc = createFileLineColLoc(opBuilder);
1615 // The preceding instruction for the OpBranch instruction could be an
1616 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1617 // the same OpLine information.
1618 opBuilder.create<spirv::BranchOp>(loc, target);
1620 clearDebugLine();
1621 return success();
1624 LogicalResult
1625 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1626 if (!curBlock) {
1627 return emitError(unknownLoc,
1628 "OpBranchConditional must appear inside a block");
1631 if (operands.size() != 3 && operands.size() != 5) {
1632 return emitError(unknownLoc,
1633 "OpBranchConditional must have condition, true label, "
1634 "false label, and optionally two branch weights");
1637 auto condition = getValue(operands[0]);
1638 auto *trueBlock = getOrCreateBlock(operands[1]);
1639 auto *falseBlock = getOrCreateBlock(operands[2]);
1641 std::optional<std::pair<uint32_t, uint32_t>> weights;
1642 if (operands.size() == 5) {
1643 weights = std::make_pair(operands[3], operands[4]);
1645 // The preceding instruction for the OpBranchConditional instruction could be
1646 // an OpSelectionMerge instruction, in this case they will have the same
1647 // OpLine information.
1648 auto loc = createFileLineColLoc(opBuilder);
1649 opBuilder.create<spirv::BranchConditionalOp>(
1650 loc, condition, trueBlock,
1651 /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1652 /*falseArguments=*/ArrayRef<Value>(), weights);
1654 clearDebugLine();
1655 return success();
1658 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1659 if (!curFunction) {
1660 return emitError(unknownLoc, "OpLabel must appear inside a function");
1663 if (operands.size() != 1) {
1664 return emitError(unknownLoc, "OpLabel should only have result <id>");
1667 auto labelID = operands[0];
1668 // We may have forward declared this block.
1669 auto *block = getOrCreateBlock(labelID);
1670 LLVM_DEBUG(logger.startLine()
1671 << "[block] populating block " << block << "\n");
1672 // If we have seen this block, make sure it was just a forward declaration.
1673 assert(block->empty() && "re-deserialize the same block!");
1675 opBuilder.setInsertionPointToStart(block);
1676 blockMap[labelID] = curBlock = block;
1678 return success();
1681 LogicalResult
1682 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1683 if (!curBlock) {
1684 return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1687 if (operands.size() < 2) {
1688 return emitError(
1689 unknownLoc,
1690 "OpSelectionMerge must specify merge target and selection control");
1693 auto *mergeBlock = getOrCreateBlock(operands[0]);
1694 auto loc = createFileLineColLoc(opBuilder);
1695 auto selectionControl = operands[1];
1697 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1698 .second) {
1699 return emitError(
1700 unknownLoc,
1701 "a block cannot have more than one OpSelectionMerge instruction");
1704 return success();
1707 LogicalResult
1708 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1709 if (!curBlock) {
1710 return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1713 if (operands.size() < 3) {
1714 return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1715 "continue target and loop control");
1718 auto *mergeBlock = getOrCreateBlock(operands[0]);
1719 auto *continueBlock = getOrCreateBlock(operands[1]);
1720 auto loc = createFileLineColLoc(opBuilder);
1721 uint32_t loopControl = operands[2];
1723 if (!blockMergeInfo
1724 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1725 .second) {
1726 return emitError(
1727 unknownLoc,
1728 "a block cannot have more than one OpLoopMerge instruction");
1731 return success();
1734 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1735 if (!curBlock) {
1736 return emitError(unknownLoc, "OpPhi must appear in a block");
1739 if (operands.size() < 4) {
1740 return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1741 "and variable-parent pairs");
1744 // Create a block argument for this OpPhi instruction.
1745 Type blockArgType = getType(operands[0]);
1746 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1747 valueMap[operands[1]] = blockArg;
1748 LLVM_DEBUG(logger.startLine()
1749 << "[phi] created block argument " << blockArg
1750 << " id = " << operands[1] << " of type " << blockArgType << "\n");
1752 // For each (value, predecessor) pair, insert the value to the predecessor's
1753 // blockPhiInfo entry so later we can fix the block argument there.
1754 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1755 uint32_t value = operands[i];
1756 Block *predecessor = getOrCreateBlock(operands[i + 1]);
1757 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1758 blockPhiInfo[predecessorTargetPair].push_back(value);
1759 LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1760 << " with arg id = " << value << "\n");
1763 return success();
1766 namespace {
1767 /// A class for putting all blocks in a structured selection/loop in a
1768 /// spirv.mlir.selection/spirv.mlir.loop op.
1769 class ControlFlowStructurizer {
1770 public:
1771 #ifndef NDEBUG
1772 ControlFlowStructurizer(Location loc, uint32_t control,
1773 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1774 Block *merge, Block *cont,
1775 llvm::ScopedPrinter &logger)
1776 : location(loc), control(control), blockMergeInfo(mergeInfo),
1777 headerBlock(header), mergeBlock(merge), continueBlock(cont),
1778 logger(logger) {}
1779 #else
1780 ControlFlowStructurizer(Location loc, uint32_t control,
1781 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1782 Block *merge, Block *cont)
1783 : location(loc), control(control), blockMergeInfo(mergeInfo),
1784 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1785 #endif
1787 /// Structurizes the loop at the given `headerBlock`.
1789 /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
1790 /// all blocks in the structured loop into the spirv.mlir.loop's region. All
1791 /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
1792 /// method will also update `mergeInfo` by remapping all blocks inside to the
1793 /// newly cloned ones inside structured control flow op's regions.
1794 LogicalResult structurize();
1796 private:
1797 /// Creates a new spirv.mlir.selection op at the beginning of the
1798 /// `mergeBlock`.
1799 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1801 /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
1802 spirv::LoopOp createLoopOp(uint32_t loopControl);
1804 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
1805 void collectBlocksInConstruct();
1807 Location location;
1808 uint32_t control;
1810 spirv::BlockMergeInfoMap &blockMergeInfo;
1812 Block *headerBlock;
1813 Block *mergeBlock;
1814 Block *continueBlock; // nullptr for spirv.mlir.selection
1816 SetVector<Block *> constructBlocks;
1818 #ifndef NDEBUG
1819 /// A logger used to emit information during the deserialzation process.
1820 llvm::ScopedPrinter &logger;
1821 #endif
1823 } // namespace
1825 spirv::SelectionOp
1826 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1827 // Create a builder and set the insertion point to the beginning of the
1828 // merge block so that the newly created SelectionOp will be inserted there.
1829 OpBuilder builder(&mergeBlock->front());
1831 auto control = static_cast<spirv::SelectionControl>(selectionControl);
1832 auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1833 selectionOp.addMergeBlock(builder);
1835 return selectionOp;
1838 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1839 // Create a builder and set the insertion point to the beginning of the
1840 // merge block so that the newly created LoopOp will be inserted there.
1841 OpBuilder builder(&mergeBlock->front());
1843 auto control = static_cast<spirv::LoopControl>(loopControl);
1844 auto loopOp = builder.create<spirv::LoopOp>(location, control);
1845 loopOp.addEntryAndMergeBlock(builder);
1847 return loopOp;
1850 void ControlFlowStructurizer::collectBlocksInConstruct() {
1851 assert(constructBlocks.empty() && "expected empty constructBlocks");
1853 // Put the header block in the work list first.
1854 constructBlocks.insert(headerBlock);
1856 // For each item in the work list, add its successors excluding the merge
1857 // block.
1858 for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1859 for (auto *successor : constructBlocks[i]->getSuccessors())
1860 if (successor != mergeBlock)
1861 constructBlocks.insert(successor);
1865 LogicalResult ControlFlowStructurizer::structurize() {
1866 Operation *op = nullptr;
1867 bool isLoop = continueBlock != nullptr;
1868 if (isLoop) {
1869 if (auto loopOp = createLoopOp(control))
1870 op = loopOp.getOperation();
1871 } else {
1872 if (auto selectionOp = createSelectionOp(control))
1873 op = selectionOp.getOperation();
1875 if (!op)
1876 return failure();
1877 Region &body = op->getRegion(0);
1879 IRMapping mapper;
1880 // All references to the old merge block should be directed to the
1881 // selection/loop merge block in the SelectionOp/LoopOp's region.
1882 mapper.map(mergeBlock, &body.back());
1884 collectBlocksInConstruct();
1886 // We've identified all blocks belonging to the selection/loop's region. Now
1887 // need to "move" them into the selection/loop. Instead of really moving the
1888 // blocks, in the following we copy them and remap all values and branches.
1889 // This is because:
1890 // * Inserting a block into a region requires the block not in any region
1891 // before. But selections/loops can nest so we can create selection/loop ops
1892 // in a nested manner, which means some blocks may already be in a
1893 // selection/loop region when to be moved again.
1894 // * It's much trickier to fix up the branches into and out of the loop's
1895 // region: we need to treat not-moved blocks and moved blocks differently:
1896 // Not-moved blocks jumping to the loop header block need to jump to the
1897 // merge point containing the new loop op but not the loop continue block's
1898 // back edge. Moved blocks jumping out of the loop need to jump to the
1899 // merge block inside the loop region but not other not-moved blocks.
1900 // We cannot use replaceAllUsesWith clearly and it's harder to follow the
1901 // logic.
1903 // Create a corresponding block in the SelectionOp/LoopOp's region for each
1904 // block in this loop construct.
1905 OpBuilder builder(body);
1906 for (auto *block : constructBlocks) {
1907 // Create a block and insert it before the selection/loop merge block in the
1908 // SelectionOp/LoopOp's region.
1909 auto *newBlock = builder.createBlock(&body.back());
1910 mapper.map(block, newBlock);
1911 LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
1912 << " from block " << block << "\n");
1913 if (!isFnEntryBlock(block)) {
1914 for (BlockArgument blockArg : block->getArguments()) {
1915 auto newArg =
1916 newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1917 mapper.map(blockArg, newArg);
1918 LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
1919 << blockArg << " to " << newArg << "\n");
1921 } else {
1922 LLVM_DEBUG(logger.startLine()
1923 << "[cf] block " << block << " is a function entry block\n");
1926 for (auto &op : *block)
1927 newBlock->push_back(op.clone(mapper));
1930 // Go through all ops and remap the operands.
1931 auto remapOperands = [&](Operation *op) {
1932 for (auto &operand : op->getOpOperands())
1933 if (Value mappedOp = mapper.lookupOrNull(operand.get()))
1934 operand.set(mappedOp);
1935 for (auto &succOp : op->getBlockOperands())
1936 if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
1937 succOp.set(mappedOp);
1939 for (auto &block : body)
1940 block.walk(remapOperands);
1942 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1943 // the selection/loop construct into its region. Next we need to fix the
1944 // connections between this new SelectionOp/LoopOp with existing blocks.
1946 // All existing incoming branches should go to the merge block, where the
1947 // SelectionOp/LoopOp resides right now.
1948 headerBlock->replaceAllUsesWith(mergeBlock);
1950 LLVM_DEBUG({
1951 logger.startLine() << "[cf] after cloning and fixing references:\n";
1952 headerBlock->getParentOp()->print(logger.getOStream());
1953 logger.startLine() << "\n";
1956 if (isLoop) {
1957 if (!mergeBlock->args_empty()) {
1958 return mergeBlock->getParentOp()->emitError(
1959 "OpPhi in loop merge block unsupported");
1962 // The loop header block may have block arguments. Since now we place the
1963 // loop op inside the old merge block, we need to make sure the old merge
1964 // block has the same block argument list.
1965 for (BlockArgument blockArg : headerBlock->getArguments())
1966 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1968 // If the loop header block has block arguments, make sure the spirv.Branch
1969 // op matches.
1970 SmallVector<Value, 4> blockArgs;
1971 if (!headerBlock->args_empty())
1972 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1974 // The loop entry block should have a unconditional branch jumping to the
1975 // loop header block.
1976 builder.setInsertionPointToEnd(&body.front());
1977 builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
1978 ArrayRef<Value>(blockArgs));
1981 // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1982 // cleaned up.
1983 LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
1984 // First we need to drop all operands' references inside all blocks. This is
1985 // needed because we can have blocks referencing SSA values from one another.
1986 for (auto *block : constructBlocks)
1987 block->dropAllReferences();
1989 // Check that whether some op in the to-be-erased blocks still has uses. Those
1990 // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
1991 // region. We cannot handle such cases given that once a value is sinked into
1992 // the SelectionOp/LoopOp's region, there is no escape for it:
1993 // SelectionOp/LooOp does not support yield values right now.
1994 for (auto *block : constructBlocks) {
1995 for (Operation &op : *block)
1996 if (!op.use_empty())
1997 return op.emitOpError(
1998 "failed control flow structurization: it has uses outside of the "
1999 "enclosing selection/loop construct");
2002 // Then erase all old blocks.
2003 for (auto *block : constructBlocks) {
2004 // We've cloned all blocks belonging to this construct into the structured
2005 // control flow op's region. Among these blocks, some may compose another
2006 // selection/loop. If so, they will be recorded within blockMergeInfo.
2007 // We need to update the pointers there to the newly remapped ones so we can
2008 // continue structurizing them later.
2009 // TODO: The asserts in the following assumes input SPIR-V blob forms
2010 // correctly nested selection/loop constructs. We should relax this and
2011 // support error cases better.
2012 auto it = blockMergeInfo.find(block);
2013 if (it != blockMergeInfo.end()) {
2014 // Use the original location for nested selection/loop ops.
2015 Location loc = it->second.loc;
2017 Block *newHeader = mapper.lookupOrNull(block);
2018 if (!newHeader)
2019 return emitError(loc, "failed control flow structurization: nested "
2020 "loop header block should be remapped!");
2022 Block *newContinue = it->second.continueBlock;
2023 if (newContinue) {
2024 newContinue = mapper.lookupOrNull(newContinue);
2025 if (!newContinue)
2026 return emitError(loc, "failed control flow structurization: nested "
2027 "loop continue block should be remapped!");
2030 Block *newMerge = it->second.mergeBlock;
2031 if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2032 newMerge = mappedTo;
2034 // The iterator should be erased before adding a new entry into
2035 // blockMergeInfo to avoid iterator invalidation.
2036 blockMergeInfo.erase(it);
2037 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2038 newContinue);
2041 // The structured selection/loop's entry block does not have arguments.
2042 // If the function's header block is also part of the structured control
2043 // flow, we cannot just simply erase it because it may contain arguments
2044 // matching the function signature and used by the cloned blocks.
2045 if (isFnEntryBlock(block)) {
2046 LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2047 << " to only contain a spirv.Branch op\n");
2048 // Still keep the function entry block for the potential block arguments,
2049 // but replace all ops inside with a branch to the merge block.
2050 block->clear();
2051 builder.setInsertionPointToEnd(block);
2052 builder.create<spirv::BranchOp>(location, mergeBlock);
2053 } else {
2054 LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2055 block->erase();
2059 LLVM_DEBUG(logger.startLine()
2060 << "[cf] after structurizing construct with header block "
2061 << headerBlock << ":\n"
2062 << *op << "\n");
2064 return success();
2067 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2068 LLVM_DEBUG({
2069 logger.startLine()
2070 << "//----- [phi] start wiring up block arguments -----//\n";
2071 logger.indent();
2074 OpBuilder::InsertionGuard guard(opBuilder);
2076 for (const auto &info : blockPhiInfo) {
2077 Block *block = info.first.first;
2078 Block *target = info.first.second;
2079 const BlockPhiInfo &phiInfo = info.second;
2080 LLVM_DEBUG({
2081 logger.startLine() << "[phi] block " << block << "\n";
2082 logger.startLine() << "[phi] before creating block argument:\n";
2083 block->getParentOp()->print(logger.getOStream());
2084 logger.startLine() << "\n";
2087 // Set insertion point to before this block's terminator early because we
2088 // may materialize ops via getValue() call.
2089 auto *op = block->getTerminator();
2090 opBuilder.setInsertionPoint(op);
2092 SmallVector<Value, 4> blockArgs;
2093 blockArgs.reserve(phiInfo.size());
2094 for (uint32_t valueId : phiInfo) {
2095 if (Value value = getValue(valueId)) {
2096 blockArgs.push_back(value);
2097 LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2098 << " id = " << valueId << "\n");
2099 } else {
2100 return emitError(unknownLoc, "OpPhi references undefined value!");
2104 if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2105 // Replace the previous branch op with a new one with block arguments.
2106 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2107 blockArgs);
2108 branchOp.erase();
2109 } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2110 assert((branchCondOp.getTrueBlock() == target ||
2111 branchCondOp.getFalseBlock() == target) &&
2112 "expected target to be either the true or false target");
2113 if (target == branchCondOp.getTrueTarget())
2114 opBuilder.create<spirv::BranchConditionalOp>(
2115 branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2116 branchCondOp.getFalseBlockArguments(),
2117 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2118 branchCondOp.getFalseTarget());
2119 else
2120 opBuilder.create<spirv::BranchConditionalOp>(
2121 branchCondOp.getLoc(), branchCondOp.getCondition(),
2122 branchCondOp.getTrueBlockArguments(), blockArgs,
2123 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2124 branchCondOp.getFalseBlock());
2126 branchCondOp.erase();
2127 } else {
2128 return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2131 LLVM_DEBUG({
2132 logger.startLine() << "[phi] after creating block argument:\n";
2133 block->getParentOp()->print(logger.getOStream());
2134 logger.startLine() << "\n";
2137 blockPhiInfo.clear();
2139 LLVM_DEBUG({
2140 logger.unindent();
2141 logger.startLine()
2142 << "//--- [phi] completed wiring up block arguments ---//\n";
2144 return success();
2147 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2148 LLVM_DEBUG({
2149 logger.startLine()
2150 << "//----- [cf] start structurizing control flow -----//\n";
2151 logger.indent();
2154 while (!blockMergeInfo.empty()) {
2155 Block *headerBlock = blockMergeInfo.begin()->first;
2156 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2158 LLVM_DEBUG({
2159 logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2160 headerBlock->print(logger.getOStream());
2161 logger.startLine() << "\n";
2164 auto *mergeBlock = mergeInfo.mergeBlock;
2165 assert(mergeBlock && "merge block cannot be nullptr");
2166 if (!mergeBlock->args_empty())
2167 return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2168 LLVM_DEBUG({
2169 logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2170 mergeBlock->print(logger.getOStream());
2171 logger.startLine() << "\n";
2174 auto *continueBlock = mergeInfo.continueBlock;
2175 LLVM_DEBUG(if (continueBlock) {
2176 logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2177 continueBlock->print(logger.getOStream());
2178 logger.startLine() << "\n";
2180 // Erase this case before calling into structurizer, who will update
2181 // blockMergeInfo.
2182 blockMergeInfo.erase(blockMergeInfo.begin());
2183 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2184 blockMergeInfo, headerBlock,
2185 mergeBlock, continueBlock
2186 #ifndef NDEBUG
2188 logger
2189 #endif
2191 if (failed(structurizer.structurize()))
2192 return failure();
2195 LLVM_DEBUG({
2196 logger.unindent();
2197 logger.startLine()
2198 << "//--- [cf] completed structurizing control flow ---//\n";
2200 return success();
2203 //===----------------------------------------------------------------------===//
2204 // Debug
2205 //===----------------------------------------------------------------------===//
2207 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2208 if (!debugLine)
2209 return unknownLoc;
2211 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2212 if (fileName.empty())
2213 fileName = "<unknown>";
2214 return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2215 debugLine->column);
2218 LogicalResult
2219 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2220 // According to SPIR-V spec:
2221 // "This location information applies to the instructions physically
2222 // following this instruction, up to the first occurrence of any of the
2223 // following: the next end of block, the next OpLine instruction, or the next
2224 // OpNoLine instruction."
2225 if (operands.size() != 3)
2226 return emitError(unknownLoc, "OpLine must have 3 operands");
2227 debugLine = DebugLine{operands[0], operands[1], operands[2]};
2228 return success();
2231 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2233 LogicalResult
2234 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2235 if (operands.size() < 2)
2236 return emitError(unknownLoc, "OpString needs at least 2 operands");
2238 if (!debugInfoMap.lookup(operands[0]).empty())
2239 return emitError(unknownLoc,
2240 "duplicate debug string found for result <id> ")
2241 << operands[0];
2243 unsigned wordIndex = 1;
2244 StringRef debugString = decodeStringLiteral(operands, wordIndex);
2245 if (wordIndex != operands.size())
2246 return emitError(unknownLoc,
2247 "unexpected trailing words in OpString instruction");
2249 debugInfoMap[operands[0]] = debugString;
2250 return success();