[libc++][NFC] Simplify the implementation of string and string_views operator== ...
[llvm-project.git] / mlir / lib / Target / SPIRV / Deserialization / Deserializer.cpp
blob04469f1933819bfe5ce9264f18988bb474a59ab5
1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
11 //===----------------------------------------------------------------------===//
13 #include "Deserializer.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <optional>
33 using namespace mlir;
35 #define DEBUG_TYPE "spirv-deserialization"
37 //===----------------------------------------------------------------------===//
38 // Utility Functions
39 //===----------------------------------------------------------------------===//
41 /// Returns true if the given `block` is a function entry block.
42 static inline bool isFnEntryBlock(Block *block) {
43 return block->isEntryBlock() &&
44 isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
51 spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
52 MLIRContext *context)
53 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
54 module(createModuleOp()), opBuilder(module->getRegion())
55 #ifndef NDEBUG
57 logger(llvm::dbgs())
58 #endif
62 LogicalResult spirv::Deserializer::deserialize() {
63 LLVM_DEBUG({
64 logger.resetIndent();
65 logger.startLine()
66 << "//+++---------- start deserialization ----------+++//\n";
67 });
69 if (failed(processHeader()))
70 return failure();
72 spirv::Opcode opcode = spirv::Opcode::OpNop;
73 ArrayRef<uint32_t> operands;
74 auto binarySize = binary.size();
75 while (curOffset < binarySize) {
76 // Slice the next instruction out and populate `opcode` and `operands`.
77 // Internally this also updates `curOffset`.
78 if (failed(sliceInstruction(opcode, operands)))
79 return failure();
81 if (failed(processInstruction(opcode, operands)))
82 return failure();
85 assert(curOffset == binarySize &&
86 "deserializer should never index beyond the binary end");
88 for (auto &deferred : deferredInstructions) {
89 if (failed(processInstruction(deferred.first, deferred.second, false))) {
90 return failure();
94 attachVCETriple();
96 LLVM_DEBUG(logger.startLine()
97 << "//+++-------- completed deserialization --------+++//\n");
98 return success();
101 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
102 return std::move(module);
105 //===----------------------------------------------------------------------===//
106 // Module structure
107 //===----------------------------------------------------------------------===//
109 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
110 OpBuilder builder(context);
111 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
112 spirv::ModuleOp::build(builder, state);
113 return cast<spirv::ModuleOp>(Operation::create(state));
116 LogicalResult spirv::Deserializer::processHeader() {
117 if (binary.size() < spirv::kHeaderWordCount)
118 return emitError(unknownLoc,
119 "SPIR-V binary module must have a 5-word header");
121 if (binary[0] != spirv::kMagicNumber)
122 return emitError(unknownLoc, "incorrect magic number");
124 // Version number bytes: 0 | major number | minor number | 0
125 uint32_t majorVersion = (binary[1] << 8) >> 24;
126 uint32_t minorVersion = (binary[1] << 16) >> 24;
127 if (majorVersion == 1) {
128 switch (minorVersion) {
129 #define MIN_VERSION_CASE(v) \
130 case v: \
131 version = spirv::Version::V_1_##v; \
132 break
134 MIN_VERSION_CASE(0);
135 MIN_VERSION_CASE(1);
136 MIN_VERSION_CASE(2);
137 MIN_VERSION_CASE(3);
138 MIN_VERSION_CASE(4);
139 MIN_VERSION_CASE(5);
140 #undef MIN_VERSION_CASE
141 default:
142 return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
143 << minorVersion;
145 } else {
146 return emitError(unknownLoc, "unsupported SPIR-V major version: ")
147 << majorVersion;
150 // TODO: generator number, bound, schema
151 curOffset = spirv::kHeaderWordCount;
152 return success();
155 LogicalResult
156 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
157 if (operands.size() != 1)
158 return emitError(unknownLoc, "OpMemoryModel must have one parameter");
160 auto cap = spirv::symbolizeCapability(operands[0]);
161 if (!cap)
162 return emitError(unknownLoc, "unknown capability: ") << operands[0];
164 capabilities.insert(*cap);
165 return success();
168 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
169 if (words.empty()) {
170 return emitError(
171 unknownLoc,
172 "OpExtension must have a literal string for the extension name");
175 unsigned wordIndex = 0;
176 StringRef extName = decodeStringLiteral(words, wordIndex);
177 if (wordIndex != words.size())
178 return emitError(unknownLoc,
179 "unexpected trailing words in OpExtension instruction");
180 auto ext = spirv::symbolizeExtension(extName);
181 if (!ext)
182 return emitError(unknownLoc, "unknown extension: ") << extName;
184 extensions.insert(*ext);
185 return success();
188 LogicalResult
189 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
190 if (words.size() < 2) {
191 return emitError(unknownLoc,
192 "OpExtInstImport must have a result <id> and a literal "
193 "string for the extended instruction set name");
196 unsigned wordIndex = 1;
197 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
198 if (wordIndex != words.size()) {
199 return emitError(unknownLoc,
200 "unexpected trailing words in OpExtInstImport");
202 return success();
205 void spirv::Deserializer::attachVCETriple() {
206 (*module)->setAttr(
207 spirv::ModuleOp::getVCETripleAttrName(),
208 spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
209 extensions.getArrayRef(), context));
212 LogicalResult
213 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
214 if (operands.size() != 2)
215 return emitError(unknownLoc, "OpMemoryModel must have two operands");
217 (*module)->setAttr(
218 module->getAddressingModelAttrName(),
219 opBuilder.getAttr<spirv::AddressingModelAttr>(
220 static_cast<spirv::AddressingModel>(operands.front())));
222 (*module)->setAttr(module->getMemoryModelAttrName(),
223 opBuilder.getAttr<spirv::MemoryModelAttr>(
224 static_cast<spirv::MemoryModel>(operands.back())));
226 return success();
229 template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
230 LogicalResult deserializeCacheControlDecoration(
231 Location loc, OpBuilder &opBuilder,
232 DenseMap<uint32_t, NamedAttrList> &decorations, ArrayRef<uint32_t> words,
233 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
234 if (words.size() != 4) {
235 return emitError(loc, "OpDecoration with ")
236 << decorationName << "needs a cache control integer literal and a "
237 << cacheControlKind << " cache control literal";
239 unsigned cacheLevel = words[2];
240 auto cacheControlAttr = static_cast<EnumTy>(words[3]);
241 auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
242 SmallVector<Attribute> attrs;
243 if (auto attrList =
244 llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
245 llvm::append_range(attrs, attrList);
246 attrs.push_back(value);
247 decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
248 return success();
251 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
252 // TODO: This function should also be auto-generated. For now, since only a
253 // few decorations are processed/handled in a meaningful manner, going with a
254 // manual implementation.
255 if (words.size() < 2) {
256 return emitError(
257 unknownLoc, "OpDecorate must have at least result <id> and Decoration");
259 auto decorationName =
260 stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
261 if (decorationName.empty()) {
262 return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
264 auto symbol = getSymbolDecoration(decorationName);
265 switch (static_cast<spirv::Decoration>(words[1])) {
266 case spirv::Decoration::FPFastMathMode:
267 if (words.size() != 3) {
268 return emitError(unknownLoc, "OpDecorate with ")
269 << decorationName << " needs a single integer literal";
271 decorations[words[0]].set(
272 symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
273 static_cast<FPFastMathMode>(words[2])));
274 break;
275 case spirv::Decoration::FPRoundingMode:
276 if (words.size() != 3) {
277 return emitError(unknownLoc, "OpDecorate with ")
278 << decorationName << " needs a single integer literal";
280 decorations[words[0]].set(
281 symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
282 static_cast<FPRoundingMode>(words[2])));
283 break;
284 case spirv::Decoration::DescriptorSet:
285 case spirv::Decoration::Binding:
286 if (words.size() != 3) {
287 return emitError(unknownLoc, "OpDecorate with ")
288 << decorationName << " needs a single integer literal";
290 decorations[words[0]].set(
291 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
292 break;
293 case spirv::Decoration::BuiltIn:
294 if (words.size() != 3) {
295 return emitError(unknownLoc, "OpDecorate with ")
296 << decorationName << " needs a single integer literal";
298 decorations[words[0]].set(
299 symbol, opBuilder.getStringAttr(
300 stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
301 break;
302 case spirv::Decoration::ArrayStride:
303 if (words.size() != 3) {
304 return emitError(unknownLoc, "OpDecorate with ")
305 << decorationName << " needs a single integer literal";
307 typeDecorations[words[0]] = words[2];
308 break;
309 case spirv::Decoration::LinkageAttributes: {
310 if (words.size() < 4) {
311 return emitError(unknownLoc, "OpDecorate with ")
312 << decorationName
313 << " needs at least 1 string and 1 integer literal";
315 // LinkageAttributes has two parameters ["linkageName", linkageType]
316 // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
317 // "linkageName" is a stringliteral encoded as uint32_t,
318 // hence the size of name is variable length which results in words.size()
319 // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
320 // 3 + ceildiv(strlen(name), 4).
321 unsigned wordIndex = 2;
322 auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
323 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
324 static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
325 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
326 StringAttr::get(context, linkageName), linkageTypeAttr);
327 decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
328 break;
330 case spirv::Decoration::Aliased:
331 case spirv::Decoration::AliasedPointer:
332 case spirv::Decoration::Block:
333 case spirv::Decoration::BufferBlock:
334 case spirv::Decoration::Flat:
335 case spirv::Decoration::NonReadable:
336 case spirv::Decoration::NonWritable:
337 case spirv::Decoration::NoPerspective:
338 case spirv::Decoration::NoSignedWrap:
339 case spirv::Decoration::NoUnsignedWrap:
340 case spirv::Decoration::RelaxedPrecision:
341 case spirv::Decoration::Restrict:
342 case spirv::Decoration::RestrictPointer:
343 case spirv::Decoration::NoContraction:
344 case spirv::Decoration::Constant:
345 if (words.size() != 2) {
346 return emitError(unknownLoc, "OpDecoration with ")
347 << decorationName << "needs a single target <id>";
349 // Block decoration does not affect spirv.struct type, but is still stored
350 // for verification.
351 // TODO: Update StructType to contain this information since
352 // it is needed for many validation rules.
353 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
354 break;
355 case spirv::Decoration::Location:
356 case spirv::Decoration::SpecId:
357 if (words.size() != 3) {
358 return emitError(unknownLoc, "OpDecoration with ")
359 << decorationName << "needs a single integer literal";
361 decorations[words[0]].set(
362 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
363 break;
364 case spirv::Decoration::CacheControlLoadINTEL: {
365 LogicalResult res = deserializeCacheControlDecoration<
366 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
367 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
368 "load");
369 if (failed(res))
370 return res;
371 break;
373 case spirv::Decoration::CacheControlStoreINTEL: {
374 LogicalResult res = deserializeCacheControlDecoration<
375 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
376 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
377 "store");
378 if (failed(res))
379 return res;
380 break;
382 default:
383 return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
385 return success();
388 LogicalResult
389 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
390 // The binary layout of OpMemberDecorate is different comparing to OpDecorate
391 if (words.size() < 3) {
392 return emitError(unknownLoc,
393 "OpMemberDecorate must have at least 3 operands");
396 auto decoration = static_cast<spirv::Decoration>(words[2]);
397 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
398 return emitError(unknownLoc,
399 " missing offset specification in OpMemberDecorate with "
400 "Offset decoration");
402 ArrayRef<uint32_t> decorationOperands;
403 if (words.size() > 3) {
404 decorationOperands = words.slice(3);
406 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
407 return success();
410 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
411 if (words.size() < 3) {
412 return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
414 unsigned wordIndex = 2;
415 auto name = decodeStringLiteral(words, wordIndex);
416 if (wordIndex != words.size()) {
417 return emitError(unknownLoc,
418 "unexpected trailing words in OpMemberName instruction");
420 memberNameMap[words[0]][words[1]] = name;
421 return success();
424 LogicalResult spirv::Deserializer::setFunctionArgAttrs(
425 uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
426 if (!decorations.contains(argID)) {
427 argAttrs[argIndex] = DictionaryAttr::get(context, {});
428 return success();
431 spirv::DecorationAttr foundDecorationAttr;
432 for (NamedAttribute decAttr : decorations[argID]) {
433 for (auto decoration :
434 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
435 spirv::Decoration::AliasedPointer,
436 spirv::Decoration::RestrictPointer}) {
438 if (decAttr.getName() !=
439 getSymbolDecoration(stringifyDecoration(decoration)))
440 continue;
442 if (foundDecorationAttr)
443 return emitError(unknownLoc,
444 "more than one Aliased/Restrict decorations for "
445 "function argument with result <id> ")
446 << argID;
448 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
449 break;
453 if (!foundDecorationAttr)
454 return emitError(unknownLoc, "unimplemented decoration support for "
455 "function argument with result <id> ")
456 << argID;
458 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
459 foundDecorationAttr);
460 argAttrs[argIndex] = DictionaryAttr::get(context, attr);
461 return success();
464 LogicalResult
465 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
466 if (curFunction) {
467 return emitError(unknownLoc, "found function inside function");
470 // Get the result type
471 if (operands.size() != 4) {
472 return emitError(unknownLoc, "OpFunction must have 4 parameters");
474 Type resultType = getType(operands[0]);
475 if (!resultType) {
476 return emitError(unknownLoc, "undefined result type from <id> ")
477 << operands[0];
480 uint32_t fnID = operands[1];
481 if (funcMap.count(fnID)) {
482 return emitError(unknownLoc, "duplicate function definition/declaration");
485 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
486 if (!fnControl) {
487 return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
490 Type fnType = getType(operands[3]);
491 if (!fnType || !isa<FunctionType>(fnType)) {
492 return emitError(unknownLoc, "unknown function type from <id> ")
493 << operands[3];
495 auto functionType = cast<FunctionType>(fnType);
497 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
498 (functionType.getNumResults() == 1 &&
499 functionType.getResult(0) != resultType)) {
500 return emitError(unknownLoc, "mismatch in function type ")
501 << functionType << " and return type " << resultType << " specified";
504 std::string fnName = getFunctionSymbol(fnID);
505 auto funcOp = opBuilder.create<spirv::FuncOp>(
506 unknownLoc, fnName, functionType, fnControl.value());
507 // Processing other function attributes.
508 if (decorations.count(fnID)) {
509 for (auto attr : decorations[fnID].getAttrs()) {
510 funcOp->setAttr(attr.getName(), attr.getValue());
513 curFunction = funcMap[fnID] = funcOp;
514 auto *entryBlock = funcOp.addEntryBlock();
515 LLVM_DEBUG({
516 logger.startLine()
517 << "//===-------------------------------------------===//\n";
518 logger.startLine() << "[fn] name: " << fnName << "\n";
519 logger.startLine() << "[fn] type: " << fnType << "\n";
520 logger.startLine() << "[fn] ID: " << fnID << "\n";
521 logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
522 logger.indent();
525 SmallVector<Attribute> argAttrs;
526 argAttrs.resize(functionType.getNumInputs());
528 // Parse the op argument instructions
529 if (functionType.getNumInputs()) {
530 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
531 auto argType = functionType.getInput(i);
532 spirv::Opcode opcode = spirv::Opcode::OpNop;
533 ArrayRef<uint32_t> operands;
534 if (failed(sliceInstruction(opcode, operands,
535 spirv::Opcode::OpFunctionParameter))) {
536 return failure();
538 if (opcode != spirv::Opcode::OpFunctionParameter) {
539 return emitError(
540 unknownLoc,
541 "missing OpFunctionParameter instruction for argument ")
542 << i;
544 if (operands.size() != 2) {
545 return emitError(
546 unknownLoc,
547 "expected result type and result <id> for OpFunctionParameter");
549 auto argDefinedType = getType(operands[0]);
550 if (!argDefinedType || argDefinedType != argType) {
551 return emitError(unknownLoc,
552 "mismatch in argument type between function type "
553 "definition ")
554 << functionType << " and argument type definition "
555 << argDefinedType << " at argument " << i;
557 if (getValue(operands[1])) {
558 return emitError(unknownLoc, "duplicate definition of result <id> ")
559 << operands[1];
561 if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
562 return failure();
565 auto argValue = funcOp.getArgument(i);
566 valueMap[operands[1]] = argValue;
570 if (llvm::any_of(argAttrs, [](Attribute attr) {
571 auto argAttr = cast<DictionaryAttr>(attr);
572 return !argAttr.empty();
574 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
576 // entryBlock is needed to access the arguments, Once that is done, we can
577 // erase the block for functions with 'Import' LinkageAttributes, since these
578 // are essentially function declarations, so they have no body.
579 auto linkageAttr = funcOp.getLinkageAttributes();
580 auto hasImportLinkage =
581 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
582 spirv::LinkageType::Import);
583 if (hasImportLinkage)
584 funcOp.eraseBody();
586 // RAII guard to reset the insertion point to the module's region after
587 // deserializing the body of this function.
588 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
590 spirv::Opcode opcode = spirv::Opcode::OpNop;
591 ArrayRef<uint32_t> instOperands;
593 // Special handling for the entry block. We need to make sure it starts with
594 // an OpLabel instruction. The entry block takes the same parameters as the
595 // function. All other blocks do not take any parameter. We have already
596 // created the entry block, here we need to register it to the correct label
597 // <id>.
598 if (failed(sliceInstruction(opcode, instOperands,
599 spirv::Opcode::OpFunctionEnd))) {
600 return failure();
602 if (opcode == spirv::Opcode::OpFunctionEnd) {
603 return processFunctionEnd(instOperands);
605 if (opcode != spirv::Opcode::OpLabel) {
606 return emitError(unknownLoc, "a basic block must start with OpLabel");
608 if (instOperands.size() != 1) {
609 return emitError(unknownLoc, "OpLabel should only have result <id>");
611 blockMap[instOperands[0]] = entryBlock;
612 if (failed(processLabel(instOperands))) {
613 return failure();
616 // Then process all the other instructions in the function until we hit
617 // OpFunctionEnd.
618 while (succeeded(sliceInstruction(opcode, instOperands,
619 spirv::Opcode::OpFunctionEnd)) &&
620 opcode != spirv::Opcode::OpFunctionEnd) {
621 if (failed(processInstruction(opcode, instOperands))) {
622 return failure();
625 if (opcode != spirv::Opcode::OpFunctionEnd) {
626 return failure();
629 return processFunctionEnd(instOperands);
632 LogicalResult
633 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
634 // Process OpFunctionEnd.
635 if (!operands.empty()) {
636 return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
639 // Wire up block arguments from OpPhi instructions.
640 // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
641 // ops.
642 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
643 return failure();
646 curBlock = nullptr;
647 curFunction = std::nullopt;
649 LLVM_DEBUG({
650 logger.unindent();
651 logger.startLine()
652 << "//===-------------------------------------------===//\n";
654 return success();
657 std::optional<std::pair<Attribute, Type>>
658 spirv::Deserializer::getConstant(uint32_t id) {
659 auto constIt = constantMap.find(id);
660 if (constIt == constantMap.end())
661 return std::nullopt;
662 return constIt->getSecond();
665 std::optional<spirv::SpecConstOperationMaterializationInfo>
666 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
667 auto constIt = specConstOperationMap.find(id);
668 if (constIt == specConstOperationMap.end())
669 return std::nullopt;
670 return constIt->getSecond();
673 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
674 auto funcName = nameMap.lookup(id).str();
675 if (funcName.empty()) {
676 funcName = "spirv_fn_" + std::to_string(id);
678 return funcName;
681 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
682 auto constName = nameMap.lookup(id).str();
683 if (constName.empty()) {
684 constName = "spirv_spec_const_" + std::to_string(id);
686 return constName;
689 spirv::SpecConstantOp
690 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
691 TypedAttr defaultValue) {
692 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
693 auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
694 defaultValue);
695 if (decorations.count(resultID)) {
696 for (auto attr : decorations[resultID].getAttrs())
697 op->setAttr(attr.getName(), attr.getValue());
699 specConstMap[resultID] = op;
700 return op;
703 LogicalResult
704 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
705 unsigned wordIndex = 0;
706 if (operands.size() < 3) {
707 return emitError(
708 unknownLoc,
709 "OpVariable needs at least 3 operands, type, <id> and storage class");
712 // Result Type.
713 auto type = getType(operands[wordIndex]);
714 if (!type) {
715 return emitError(unknownLoc, "unknown result type <id> : ")
716 << operands[wordIndex];
718 auto ptrType = dyn_cast<spirv::PointerType>(type);
719 if (!ptrType) {
720 return emitError(unknownLoc,
721 "expected a result type <id> to be a spirv.ptr, found : ")
722 << type;
724 wordIndex++;
726 // Result <id>.
727 auto variableID = operands[wordIndex];
728 auto variableName = nameMap.lookup(variableID).str();
729 if (variableName.empty()) {
730 variableName = "spirv_var_" + std::to_string(variableID);
732 wordIndex++;
734 // Storage class.
735 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
736 if (ptrType.getStorageClass() != storageClass) {
737 return emitError(unknownLoc, "mismatch in storage class of pointer type ")
738 << type << " and that specified in OpVariable instruction : "
739 << stringifyStorageClass(storageClass);
741 wordIndex++;
743 // Initializer.
744 FlatSymbolRefAttr initializer = nullptr;
746 if (wordIndex < operands.size()) {
747 Operation *op = nullptr;
749 if (auto initOp = getGlobalVariable(operands[wordIndex]))
750 op = initOp;
751 else if (auto initOp = getSpecConstant(operands[wordIndex]))
752 op = initOp;
753 else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
754 op = initOp;
755 else
756 return emitError(unknownLoc, "unknown <id> ")
757 << operands[wordIndex] << "used as initializer";
759 initializer = SymbolRefAttr::get(op);
760 wordIndex++;
762 if (wordIndex != operands.size()) {
763 return emitError(unknownLoc,
764 "found more operands than expected when deserializing "
765 "OpVariable instruction, only ")
766 << wordIndex << " of " << operands.size() << " processed";
768 auto loc = createFileLineColLoc(opBuilder);
769 auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
770 loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
771 initializer);
773 // Decorations.
774 if (decorations.count(variableID)) {
775 for (auto attr : decorations[variableID].getAttrs())
776 varOp->setAttr(attr.getName(), attr.getValue());
778 globalVariableMap[variableID] = varOp;
779 return success();
782 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
783 auto constInfo = getConstant(id);
784 if (!constInfo) {
785 return nullptr;
787 return dyn_cast<IntegerAttr>(constInfo->first);
790 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
791 if (operands.size() < 2) {
792 return emitError(unknownLoc, "OpName needs at least 2 operands");
794 if (!nameMap.lookup(operands[0]).empty()) {
795 return emitError(unknownLoc, "duplicate name found for result <id> ")
796 << operands[0];
798 unsigned wordIndex = 1;
799 StringRef name = decodeStringLiteral(operands, wordIndex);
800 if (wordIndex != operands.size()) {
801 return emitError(unknownLoc,
802 "unexpected trailing words in OpName instruction");
804 nameMap[operands[0]] = name;
805 return success();
808 //===----------------------------------------------------------------------===//
809 // Type
810 //===----------------------------------------------------------------------===//
812 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
813 ArrayRef<uint32_t> operands) {
814 if (operands.empty()) {
815 return emitError(unknownLoc, "type instruction with opcode ")
816 << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
819 /// TODO: Types might be forward declared in some instructions and need to be
820 /// handled appropriately.
821 if (typeMap.count(operands[0])) {
822 return emitError(unknownLoc, "duplicate definition for result <id> ")
823 << operands[0];
826 switch (opcode) {
827 case spirv::Opcode::OpTypeVoid:
828 if (operands.size() != 1)
829 return emitError(unknownLoc, "OpTypeVoid must have no parameters");
830 typeMap[operands[0]] = opBuilder.getNoneType();
831 break;
832 case spirv::Opcode::OpTypeBool:
833 if (operands.size() != 1)
834 return emitError(unknownLoc, "OpTypeBool must have no parameters");
835 typeMap[operands[0]] = opBuilder.getI1Type();
836 break;
837 case spirv::Opcode::OpTypeInt: {
838 if (operands.size() != 3)
839 return emitError(
840 unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
842 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
843 // to preserve or validate.
844 // 0 indicates unsigned, or no signedness semantics
845 // 1 indicates signed semantics."
847 // So we cannot differentiate signless and unsigned integers; always use
848 // signless semantics for such cases.
849 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
850 : IntegerType::SignednessSemantics::Signless;
851 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
852 } break;
853 case spirv::Opcode::OpTypeFloat: {
854 if (operands.size() != 2)
855 return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
857 Type floatTy;
858 switch (operands[1]) {
859 case 16:
860 floatTy = opBuilder.getF16Type();
861 break;
862 case 32:
863 floatTy = opBuilder.getF32Type();
864 break;
865 case 64:
866 floatTy = opBuilder.getF64Type();
867 break;
868 default:
869 return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
870 << operands[1];
872 typeMap[operands[0]] = floatTy;
873 } break;
874 case spirv::Opcode::OpTypeVector: {
875 if (operands.size() != 3) {
876 return emitError(
877 unknownLoc,
878 "OpTypeVector must have element type and count parameters");
880 Type elementTy = getType(operands[1]);
881 if (!elementTy) {
882 return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
883 << operands[1];
885 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
886 } break;
887 case spirv::Opcode::OpTypePointer: {
888 return processOpTypePointer(operands);
889 } break;
890 case spirv::Opcode::OpTypeArray:
891 return processArrayType(operands);
892 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
893 return processCooperativeMatrixTypeKHR(operands);
894 case spirv::Opcode::OpTypeFunction:
895 return processFunctionType(operands);
896 case spirv::Opcode::OpTypeImage:
897 return processImageType(operands);
898 case spirv::Opcode::OpTypeSampledImage:
899 return processSampledImageType(operands);
900 case spirv::Opcode::OpTypeRuntimeArray:
901 return processRuntimeArrayType(operands);
902 case spirv::Opcode::OpTypeStruct:
903 return processStructType(operands);
904 case spirv::Opcode::OpTypeMatrix:
905 return processMatrixType(operands);
906 default:
907 return emitError(unknownLoc, "unhandled type instruction");
909 return success();
912 LogicalResult
913 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
914 if (operands.size() != 3)
915 return emitError(unknownLoc, "OpTypePointer must have two parameters");
917 auto pointeeType = getType(operands[2]);
918 if (!pointeeType)
919 return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
920 << operands[2];
922 uint32_t typePointerID = operands[0];
923 auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
924 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
926 for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
927 deferredStructIt != std::end(deferredStructTypesInfos);) {
928 for (auto *unresolvedMemberIt =
929 std::begin(deferredStructIt->unresolvedMemberTypes);
930 unresolvedMemberIt !=
931 std::end(deferredStructIt->unresolvedMemberTypes);) {
932 if (unresolvedMemberIt->first == typePointerID) {
933 // The newly constructed pointer type can resolve one of the
934 // deferred struct type members; update the memberTypes list and
935 // clean the unresolvedMemberTypes list accordingly.
936 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
937 typeMap[typePointerID];
938 unresolvedMemberIt =
939 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
940 } else {
941 ++unresolvedMemberIt;
945 if (deferredStructIt->unresolvedMemberTypes.empty()) {
946 // All deferred struct type members are now resolved, set the struct body.
947 auto structType = deferredStructIt->deferredStructType;
949 assert(structType && "expected a spirv::StructType");
950 assert(structType.isIdentified() && "expected an indentified struct");
952 if (failed(structType.trySetBody(
953 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
954 deferredStructIt->memberDecorationsInfo)))
955 return failure();
957 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
958 } else {
959 ++deferredStructIt;
963 return success();
966 LogicalResult
967 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
968 if (operands.size() != 3) {
969 return emitError(unknownLoc,
970 "OpTypeArray must have element type and count parameters");
973 Type elementTy = getType(operands[1]);
974 if (!elementTy) {
975 return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
976 << operands[1];
979 unsigned count = 0;
980 // TODO: The count can also come frome a specialization constant.
981 auto countInfo = getConstant(operands[2]);
982 if (!countInfo) {
983 return emitError(unknownLoc, "OpTypeArray count <id> ")
984 << operands[2] << "can only come from normal constant right now";
987 if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
988 count = intVal.getValue().getZExtValue();
989 } else {
990 return emitError(unknownLoc, "OpTypeArray count must come from a "
991 "scalar integer constant instruction");
994 typeMap[operands[0]] = spirv::ArrayType::get(
995 elementTy, count, typeDecorations.lookup(operands[0]));
996 return success();
999 LogicalResult
1000 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
1001 assert(!operands.empty() && "No operands for processing function type");
1002 if (operands.size() == 1) {
1003 return emitError(unknownLoc, "missing return type for OpTypeFunction");
1005 auto returnType = getType(operands[1]);
1006 if (!returnType) {
1007 return emitError(unknownLoc, "unknown return type in OpTypeFunction");
1009 SmallVector<Type, 1> argTypes;
1010 for (size_t i = 2, e = operands.size(); i < e; ++i) {
1011 auto ty = getType(operands[i]);
1012 if (!ty) {
1013 return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
1015 argTypes.push_back(ty);
1017 ArrayRef<Type> returnTypes;
1018 if (!isVoidType(returnType)) {
1019 returnTypes = llvm::ArrayRef(returnType);
1021 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1022 return success();
1025 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
1026 ArrayRef<uint32_t> operands) {
1027 if (operands.size() != 6) {
1028 return emitError(unknownLoc,
1029 "OpTypeCooperativeMatrixKHR must have element type, "
1030 "scope, row and column parameters, and use");
1033 Type elementTy = getType(operands[1]);
1034 if (!elementTy) {
1035 return emitError(unknownLoc,
1036 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1037 << operands[1];
1040 std::optional<spirv::Scope> scope =
1041 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1042 if (!scope) {
1043 return emitError(
1044 unknownLoc,
1045 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1046 << operands[2];
1049 unsigned rows = getConstantInt(operands[3]).getInt();
1050 unsigned columns = getConstantInt(operands[4]).getInt();
1052 std::optional<spirv::CooperativeMatrixUseKHR> use =
1053 spirv::symbolizeCooperativeMatrixUseKHR(
1054 getConstantInt(operands[5]).getInt());
1055 if (!use) {
1056 return emitError(
1057 unknownLoc,
1058 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1059 << operands[5];
1062 typeMap[operands[0]] =
1063 spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1064 return success();
1067 LogicalResult
1068 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
1069 if (operands.size() != 2) {
1070 return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1072 Type memberType = getType(operands[1]);
1073 if (!memberType) {
1074 return emitError(unknownLoc,
1075 "OpTypeRuntimeArray references undefined <id> ")
1076 << operands[1];
1078 typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1079 memberType, typeDecorations.lookup(operands[0]));
1080 return success();
1083 LogicalResult
1084 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
1085 // TODO: Find a way to handle identified structs when debug info is stripped.
1087 if (operands.empty()) {
1088 return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1091 if (operands.size() == 1) {
1092 // Handle empty struct.
1093 typeMap[operands[0]] =
1094 spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1095 return success();
1098 // First element is operand ID, second element is member index in the struct.
1099 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1100 SmallVector<Type, 4> memberTypes;
1102 for (auto op : llvm::drop_begin(operands, 1)) {
1103 Type memberType = getType(op);
1104 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1106 if (!memberType && !typeForwardPtr)
1107 return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1108 << op;
1110 if (!memberType)
1111 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1113 memberTypes.push_back(memberType);
1116 SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
1117 SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
1118 if (memberDecorationMap.count(operands[0])) {
1119 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1120 for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1121 if (allMemberDecorations.count(memberIndex)) {
1122 for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1123 // Check for offset.
1124 if (memberDecoration.first == spirv::Decoration::Offset) {
1125 // If offset info is empty, resize to the number of members;
1126 if (offsetInfo.empty()) {
1127 offsetInfo.resize(memberTypes.size());
1129 offsetInfo[memberIndex] = memberDecoration.second[0];
1130 } else {
1131 if (!memberDecoration.second.empty()) {
1132 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
1133 memberDecoration.first,
1134 memberDecoration.second[0]);
1135 } else {
1136 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
1137 memberDecoration.first, 0);
1145 uint32_t structID = operands[0];
1146 std::string structIdentifier = nameMap.lookup(structID).str();
1148 if (structIdentifier.empty()) {
1149 assert(unresolvedMemberTypes.empty() &&
1150 "didn't expect unresolved member types");
1151 typeMap[structID] =
1152 spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
1153 } else {
1154 auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1155 typeMap[structID] = structTy;
1157 if (!unresolvedMemberTypes.empty())
1158 deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1159 memberTypes, offsetInfo,
1160 memberDecorationsInfo});
1161 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1162 memberDecorationsInfo)))
1163 return failure();
1166 // TODO: Update StructType to have member name as attribute as
1167 // well.
1168 return success();
1171 LogicalResult
1172 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
1173 if (operands.size() != 3) {
1174 // Three operands are needed: result_id, column_type, and column_count
1175 return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1176 " (result_id, column_type, and column_count)");
1178 // Matrix columns must be of vector type
1179 Type elementTy = getType(operands[1]);
1180 if (!elementTy) {
1181 return emitError(unknownLoc,
1182 "OpTypeMatrix references undefined column type.")
1183 << operands[1];
1186 uint32_t colsCount = operands[2];
1187 typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1188 return success();
1191 LogicalResult
1192 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1193 if (operands.size() != 2)
1194 return emitError(unknownLoc,
1195 "OpTypeForwardPointer instruction must have two operands");
1197 typeForwardPointerIDs.insert(operands[0]);
1198 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1199 // instruction that defines the actual type.
1201 return success();
1204 LogicalResult
1205 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1206 // TODO: Add support for Access Qualifier.
1207 if (operands.size() != 8)
1208 return emitError(
1209 unknownLoc,
1210 "OpTypeImage with non-eight operands are not supported yet");
1212 Type elementTy = getType(operands[1]);
1213 if (!elementTy)
1214 return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1215 << operands[1];
1217 auto dim = spirv::symbolizeDim(operands[2]);
1218 if (!dim)
1219 return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1220 << operands[2];
1222 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1223 if (!depthInfo)
1224 return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1225 << operands[3];
1227 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1228 if (!arrayedInfo)
1229 return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1230 << operands[4];
1232 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1233 if (!samplingInfo)
1234 return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1236 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1237 if (!samplerUseInfo)
1238 return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1239 << operands[6];
1241 auto format = spirv::symbolizeImageFormat(operands[7]);
1242 if (!format)
1243 return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1244 << operands[7];
1246 typeMap[operands[0]] = spirv::ImageType::get(
1247 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1248 samplingInfo.value(), samplerUseInfo.value(), format.value());
1249 return success();
1252 LogicalResult
1253 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1254 if (operands.size() != 2)
1255 return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1257 Type elementTy = getType(operands[1]);
1258 if (!elementTy)
1259 return emitError(unknownLoc,
1260 "OpTypeSampledImage references undefined <id>: ")
1261 << operands[1];
1263 typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1264 return success();
1267 //===----------------------------------------------------------------------===//
1268 // Constant
1269 //===----------------------------------------------------------------------===//
1271 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1272 bool isSpec) {
1273 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1275 if (operands.size() < 2) {
1276 return emitError(unknownLoc)
1277 << opname << " must have type <id> and result <id>";
1279 if (operands.size() < 3) {
1280 return emitError(unknownLoc)
1281 << opname << " must have at least 1 more parameter";
1284 Type resultType = getType(operands[0]);
1285 if (!resultType) {
1286 return emitError(unknownLoc, "undefined result type from <id> ")
1287 << operands[0];
1290 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1291 if (bitwidth == 64) {
1292 if (operands.size() == 4) {
1293 return success();
1295 return emitError(unknownLoc)
1296 << opname << " should have 2 parameters for 64-bit values";
1298 if (bitwidth <= 32) {
1299 if (operands.size() == 3) {
1300 return success();
1303 return emitError(unknownLoc)
1304 << opname
1305 << " should have 1 parameter for values with no more than 32 bits";
1307 return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1308 << bitwidth;
1311 auto resultID = operands[1];
1313 if (auto intType = dyn_cast<IntegerType>(resultType)) {
1314 auto bitwidth = intType.getWidth();
1315 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1316 return failure();
1319 APInt value;
1320 if (bitwidth == 64) {
1321 // 64-bit integers are represented with two SPIR-V words. According to
1322 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1323 // literal’s low-order words appear first."
1324 struct DoubleWord {
1325 uint32_t word1;
1326 uint32_t word2;
1327 } words = {operands[2], operands[3]};
1328 value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1329 } else if (bitwidth <= 32) {
1330 value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1331 /*implicitTrunc=*/true);
1334 auto attr = opBuilder.getIntegerAttr(intType, value);
1336 if (isSpec) {
1337 createSpecConstant(unknownLoc, resultID, attr);
1338 } else {
1339 // For normal constants, we just record the attribute (and its type) for
1340 // later materialization at use sites.
1341 constantMap.try_emplace(resultID, attr, intType);
1344 return success();
1347 if (auto floatType = dyn_cast<FloatType>(resultType)) {
1348 auto bitwidth = floatType.getWidth();
1349 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1350 return failure();
1353 APFloat value(0.f);
1354 if (floatType.isF64()) {
1355 // Double values are represented with two SPIR-V words. According to
1356 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1357 // literal’s low-order words appear first."
1358 struct DoubleWord {
1359 uint32_t word1;
1360 uint32_t word2;
1361 } words = {operands[2], operands[3]};
1362 value = APFloat(llvm::bit_cast<double>(words));
1363 } else if (floatType.isF32()) {
1364 value = APFloat(llvm::bit_cast<float>(operands[2]));
1365 } else if (floatType.isF16()) {
1366 APInt data(16, operands[2]);
1367 value = APFloat(APFloat::IEEEhalf(), data);
1370 auto attr = opBuilder.getFloatAttr(floatType, value);
1371 if (isSpec) {
1372 createSpecConstant(unknownLoc, resultID, attr);
1373 } else {
1374 // For normal constants, we just record the attribute (and its type) for
1375 // later materialization at use sites.
1376 constantMap.try_emplace(resultID, attr, floatType);
1379 return success();
1382 return emitError(unknownLoc, "OpConstant can only generate values of "
1383 "scalar integer or floating-point type");
1386 LogicalResult spirv::Deserializer::processConstantBool(
1387 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1388 if (operands.size() != 2) {
1389 return emitError(unknownLoc, "Op")
1390 << (isSpec ? "Spec" : "") << "Constant"
1391 << (isTrue ? "True" : "False")
1392 << " must have type <id> and result <id>";
1395 auto attr = opBuilder.getBoolAttr(isTrue);
1396 auto resultID = operands[1];
1397 if (isSpec) {
1398 createSpecConstant(unknownLoc, resultID, attr);
1399 } else {
1400 // For normal constants, we just record the attribute (and its type) for
1401 // later materialization at use sites.
1402 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1405 return success();
1408 LogicalResult
1409 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1410 if (operands.size() < 2) {
1411 return emitError(unknownLoc,
1412 "OpConstantComposite must have type <id> and result <id>");
1414 if (operands.size() < 3) {
1415 return emitError(unknownLoc,
1416 "OpConstantComposite must have at least 1 parameter");
1419 Type resultType = getType(operands[0]);
1420 if (!resultType) {
1421 return emitError(unknownLoc, "undefined result type from <id> ")
1422 << operands[0];
1425 SmallVector<Attribute, 4> elements;
1426 elements.reserve(operands.size() - 2);
1427 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1428 auto elementInfo = getConstant(operands[i]);
1429 if (!elementInfo) {
1430 return emitError(unknownLoc, "OpConstantComposite component <id> ")
1431 << operands[i] << " must come from a normal constant";
1433 elements.push_back(elementInfo->first);
1436 auto resultID = operands[1];
1437 if (auto vectorType = dyn_cast<VectorType>(resultType)) {
1438 auto attr = DenseElementsAttr::get(vectorType, elements);
1439 // For normal constants, we just record the attribute (and its type) for
1440 // later materialization at use sites.
1441 constantMap.try_emplace(resultID, attr, resultType);
1442 } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1443 auto attr = opBuilder.getArrayAttr(elements);
1444 constantMap.try_emplace(resultID, attr, resultType);
1445 } else {
1446 return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1447 << resultType;
1450 return success();
1453 LogicalResult
1454 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1455 if (operands.size() < 2) {
1456 return emitError(unknownLoc,
1457 "OpConstantComposite must have type <id> and result <id>");
1459 if (operands.size() < 3) {
1460 return emitError(unknownLoc,
1461 "OpConstantComposite must have at least 1 parameter");
1464 Type resultType = getType(operands[0]);
1465 if (!resultType) {
1466 return emitError(unknownLoc, "undefined result type from <id> ")
1467 << operands[0];
1470 auto resultID = operands[1];
1471 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1473 SmallVector<Attribute, 4> elements;
1474 elements.reserve(operands.size() - 2);
1475 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1476 auto elementInfo = getSpecConstant(operands[i]);
1477 elements.push_back(SymbolRefAttr::get(elementInfo));
1480 auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1481 unknownLoc, TypeAttr::get(resultType), symName,
1482 opBuilder.getArrayAttr(elements));
1483 specConstCompositeMap[resultID] = op;
1485 return success();
1488 LogicalResult
1489 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1490 if (operands.size() < 3)
1491 return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1492 "result <id>, and operand opcode");
1494 uint32_t resultTypeID = operands[0];
1496 if (!getType(resultTypeID))
1497 return emitError(unknownLoc, "undefined result type from <id> ")
1498 << resultTypeID;
1500 uint32_t resultID = operands[1];
1501 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1502 auto emplaceResult = specConstOperationMap.try_emplace(
1503 resultID,
1504 SpecConstOperationMaterializationInfo{
1505 enclosedOpcode, resultTypeID,
1506 SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1508 if (!emplaceResult.second)
1509 return emitError(unknownLoc, "value with <id>: ")
1510 << resultID << " is probably defined before.";
1512 return success();
1515 Value spirv::Deserializer::materializeSpecConstantOperation(
1516 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1517 ArrayRef<uint32_t> enclosedOpOperands) {
1519 Type resultType = getType(resultTypeID);
1521 // Instructions wrapped by OpSpecConstantOp need an ID for their
1522 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1523 // dialect wrapped op. For that purpose, a new value map is created and "fake"
1524 // ID in that map is assigned to the result of the enclosed instruction. Note
1525 // that there is no need to update this fake ID since we only need to
1526 // reference the created Value for the enclosed op from the spv::YieldOp
1527 // created later in this method (both of which are the only values in their
1528 // region: the SpecConstantOperation's region). If we encounter another
1529 // SpecConstantOperation in the module, we simply re-use the fake ID since the
1530 // previous Value assigned to it isn't visible in the current scope anyway.
1531 DenseMap<uint32_t, Value> newValueMap;
1532 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1533 constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1535 SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
1536 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1537 enclosedOpResultTypeAndOperands.push_back(fakeID);
1538 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1539 enclosedOpOperands.end());
1541 // Process enclosed instruction before creating the enclosing
1542 // specConstantOperation (and its region). This way, references to constants,
1543 // global variables, and spec constants will be materialized outside the new
1544 // op's region. For more info, see Deserializer::getValue's implementation.
1545 if (failed(
1546 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1547 return Value();
1549 // Since the enclosed op is emitted in the current block, split it in a
1550 // separate new block.
1551 Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
1553 auto loc = createFileLineColLoc(opBuilder);
1554 auto specConstOperationOp =
1555 opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1557 Region &body = specConstOperationOp.getBody();
1558 // Move the new block into SpecConstantOperation's body.
1559 body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1560 Region::iterator(enclosedBlock));
1561 Block &block = body.back();
1563 // RAII guard to reset the insertion point to the module's region after
1564 // deserializing the body of the specConstantOperation.
1565 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
1566 opBuilder.setInsertionPointToEnd(&block);
1568 opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
1569 return specConstOperationOp.getResult();
1572 LogicalResult
1573 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1574 if (operands.size() != 2) {
1575 return emitError(unknownLoc,
1576 "OpConstantNull must have type <id> and result <id>");
1579 Type resultType = getType(operands[0]);
1580 if (!resultType) {
1581 return emitError(unknownLoc, "undefined result type from <id> ")
1582 << operands[0];
1585 auto resultID = operands[1];
1586 if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
1587 auto attr = opBuilder.getZeroAttr(resultType);
1588 // For normal constants, we just record the attribute (and its type) for
1589 // later materialization at use sites.
1590 constantMap.try_emplace(resultID, attr, resultType);
1591 return success();
1594 return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1595 << resultType;
1598 //===----------------------------------------------------------------------===//
1599 // Control flow
1600 //===----------------------------------------------------------------------===//
1602 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1603 if (auto *block = getBlock(id)) {
1604 LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1605 << " @ " << block << "\n");
1606 return block;
1609 // We don't know where this block will be placed finally (in a
1610 // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1611 // function for now and sort out the proper place later.
1612 auto *block = curFunction->addBlock();
1613 LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1614 << " @ " << block << "\n");
1615 return blockMap[id] = block;
1618 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1619 if (!curBlock) {
1620 return emitError(unknownLoc, "OpBranch must appear inside a block");
1623 if (operands.size() != 1) {
1624 return emitError(unknownLoc, "OpBranch must take exactly one target label");
1627 auto *target = getOrCreateBlock(operands[0]);
1628 auto loc = createFileLineColLoc(opBuilder);
1629 // The preceding instruction for the OpBranch instruction could be an
1630 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1631 // the same OpLine information.
1632 opBuilder.create<spirv::BranchOp>(loc, target);
1634 clearDebugLine();
1635 return success();
1638 LogicalResult
1639 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1640 if (!curBlock) {
1641 return emitError(unknownLoc,
1642 "OpBranchConditional must appear inside a block");
1645 if (operands.size() != 3 && operands.size() != 5) {
1646 return emitError(unknownLoc,
1647 "OpBranchConditional must have condition, true label, "
1648 "false label, and optionally two branch weights");
1651 auto condition = getValue(operands[0]);
1652 auto *trueBlock = getOrCreateBlock(operands[1]);
1653 auto *falseBlock = getOrCreateBlock(operands[2]);
1655 std::optional<std::pair<uint32_t, uint32_t>> weights;
1656 if (operands.size() == 5) {
1657 weights = std::make_pair(operands[3], operands[4]);
1659 // The preceding instruction for the OpBranchConditional instruction could be
1660 // an OpSelectionMerge instruction, in this case they will have the same
1661 // OpLine information.
1662 auto loc = createFileLineColLoc(opBuilder);
1663 opBuilder.create<spirv::BranchConditionalOp>(
1664 loc, condition, trueBlock,
1665 /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1666 /*falseArguments=*/ArrayRef<Value>(), weights);
1668 clearDebugLine();
1669 return success();
1672 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1673 if (!curFunction) {
1674 return emitError(unknownLoc, "OpLabel must appear inside a function");
1677 if (operands.size() != 1) {
1678 return emitError(unknownLoc, "OpLabel should only have result <id>");
1681 auto labelID = operands[0];
1682 // We may have forward declared this block.
1683 auto *block = getOrCreateBlock(labelID);
1684 LLVM_DEBUG(logger.startLine()
1685 << "[block] populating block " << block << "\n");
1686 // If we have seen this block, make sure it was just a forward declaration.
1687 assert(block->empty() && "re-deserialize the same block!");
1689 opBuilder.setInsertionPointToStart(block);
1690 blockMap[labelID] = curBlock = block;
1692 return success();
1695 LogicalResult
1696 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1697 if (!curBlock) {
1698 return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1701 if (operands.size() < 2) {
1702 return emitError(
1703 unknownLoc,
1704 "OpSelectionMerge must specify merge target and selection control");
1707 auto *mergeBlock = getOrCreateBlock(operands[0]);
1708 auto loc = createFileLineColLoc(opBuilder);
1709 auto selectionControl = operands[1];
1711 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1712 .second) {
1713 return emitError(
1714 unknownLoc,
1715 "a block cannot have more than one OpSelectionMerge instruction");
1718 return success();
1721 LogicalResult
1722 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1723 if (!curBlock) {
1724 return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1727 if (operands.size() < 3) {
1728 return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1729 "continue target and loop control");
1732 auto *mergeBlock = getOrCreateBlock(operands[0]);
1733 auto *continueBlock = getOrCreateBlock(operands[1]);
1734 auto loc = createFileLineColLoc(opBuilder);
1735 uint32_t loopControl = operands[2];
1737 if (!blockMergeInfo
1738 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1739 .second) {
1740 return emitError(
1741 unknownLoc,
1742 "a block cannot have more than one OpLoopMerge instruction");
1745 return success();
1748 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1749 if (!curBlock) {
1750 return emitError(unknownLoc, "OpPhi must appear in a block");
1753 if (operands.size() < 4) {
1754 return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1755 "and variable-parent pairs");
1758 // Create a block argument for this OpPhi instruction.
1759 Type blockArgType = getType(operands[0]);
1760 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1761 valueMap[operands[1]] = blockArg;
1762 LLVM_DEBUG(logger.startLine()
1763 << "[phi] created block argument " << blockArg
1764 << " id = " << operands[1] << " of type " << blockArgType << "\n");
1766 // For each (value, predecessor) pair, insert the value to the predecessor's
1767 // blockPhiInfo entry so later we can fix the block argument there.
1768 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1769 uint32_t value = operands[i];
1770 Block *predecessor = getOrCreateBlock(operands[i + 1]);
1771 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1772 blockPhiInfo[predecessorTargetPair].push_back(value);
1773 LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1774 << " with arg id = " << value << "\n");
1777 return success();
1780 namespace {
1781 /// A class for putting all blocks in a structured selection/loop in a
1782 /// spirv.mlir.selection/spirv.mlir.loop op.
1783 class ControlFlowStructurizer {
1784 public:
1785 #ifndef NDEBUG
1786 ControlFlowStructurizer(Location loc, uint32_t control,
1787 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1788 Block *merge, Block *cont,
1789 llvm::ScopedPrinter &logger)
1790 : location(loc), control(control), blockMergeInfo(mergeInfo),
1791 headerBlock(header), mergeBlock(merge), continueBlock(cont),
1792 logger(logger) {}
1793 #else
1794 ControlFlowStructurizer(Location loc, uint32_t control,
1795 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1796 Block *merge, Block *cont)
1797 : location(loc), control(control), blockMergeInfo(mergeInfo),
1798 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1799 #endif
1801 /// Structurizes the loop at the given `headerBlock`.
1803 /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
1804 /// all blocks in the structured loop into the spirv.mlir.loop's region. All
1805 /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
1806 /// method will also update `mergeInfo` by remapping all blocks inside to the
1807 /// newly cloned ones inside structured control flow op's regions.
1808 LogicalResult structurize();
1810 private:
1811 /// Creates a new spirv.mlir.selection op at the beginning of the
1812 /// `mergeBlock`.
1813 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1815 /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
1816 spirv::LoopOp createLoopOp(uint32_t loopControl);
1818 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
1819 void collectBlocksInConstruct();
1821 Location location;
1822 uint32_t control;
1824 spirv::BlockMergeInfoMap &blockMergeInfo;
1826 Block *headerBlock;
1827 Block *mergeBlock;
1828 Block *continueBlock; // nullptr for spirv.mlir.selection
1830 SetVector<Block *> constructBlocks;
1832 #ifndef NDEBUG
1833 /// A logger used to emit information during the deserialzation process.
1834 llvm::ScopedPrinter &logger;
1835 #endif
1837 } // namespace
1839 spirv::SelectionOp
1840 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1841 // Create a builder and set the insertion point to the beginning of the
1842 // merge block so that the newly created SelectionOp will be inserted there.
1843 OpBuilder builder(&mergeBlock->front());
1845 auto control = static_cast<spirv::SelectionControl>(selectionControl);
1846 auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1847 selectionOp.addMergeBlock(builder);
1849 return selectionOp;
1852 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1853 // Create a builder and set the insertion point to the beginning of the
1854 // merge block so that the newly created LoopOp will be inserted there.
1855 OpBuilder builder(&mergeBlock->front());
1857 auto control = static_cast<spirv::LoopControl>(loopControl);
1858 auto loopOp = builder.create<spirv::LoopOp>(location, control);
1859 loopOp.addEntryAndMergeBlock(builder);
1861 return loopOp;
1864 void ControlFlowStructurizer::collectBlocksInConstruct() {
1865 assert(constructBlocks.empty() && "expected empty constructBlocks");
1867 // Put the header block in the work list first.
1868 constructBlocks.insert(headerBlock);
1870 // For each item in the work list, add its successors excluding the merge
1871 // block.
1872 for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1873 for (auto *successor : constructBlocks[i]->getSuccessors())
1874 if (successor != mergeBlock)
1875 constructBlocks.insert(successor);
1879 LogicalResult ControlFlowStructurizer::structurize() {
1880 Operation *op = nullptr;
1881 bool isLoop = continueBlock != nullptr;
1882 if (isLoop) {
1883 if (auto loopOp = createLoopOp(control))
1884 op = loopOp.getOperation();
1885 } else {
1886 if (auto selectionOp = createSelectionOp(control))
1887 op = selectionOp.getOperation();
1889 if (!op)
1890 return failure();
1891 Region &body = op->getRegion(0);
1893 IRMapping mapper;
1894 // All references to the old merge block should be directed to the
1895 // selection/loop merge block in the SelectionOp/LoopOp's region.
1896 mapper.map(mergeBlock, &body.back());
1898 collectBlocksInConstruct();
1900 // We've identified all blocks belonging to the selection/loop's region. Now
1901 // need to "move" them into the selection/loop. Instead of really moving the
1902 // blocks, in the following we copy them and remap all values and branches.
1903 // This is because:
1904 // * Inserting a block into a region requires the block not in any region
1905 // before. But selections/loops can nest so we can create selection/loop ops
1906 // in a nested manner, which means some blocks may already be in a
1907 // selection/loop region when to be moved again.
1908 // * It's much trickier to fix up the branches into and out of the loop's
1909 // region: we need to treat not-moved blocks and moved blocks differently:
1910 // Not-moved blocks jumping to the loop header block need to jump to the
1911 // merge point containing the new loop op but not the loop continue block's
1912 // back edge. Moved blocks jumping out of the loop need to jump to the
1913 // merge block inside the loop region but not other not-moved blocks.
1914 // We cannot use replaceAllUsesWith clearly and it's harder to follow the
1915 // logic.
1917 // Create a corresponding block in the SelectionOp/LoopOp's region for each
1918 // block in this loop construct.
1919 OpBuilder builder(body);
1920 for (auto *block : constructBlocks) {
1921 // Create a block and insert it before the selection/loop merge block in the
1922 // SelectionOp/LoopOp's region.
1923 auto *newBlock = builder.createBlock(&body.back());
1924 mapper.map(block, newBlock);
1925 LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
1926 << " from block " << block << "\n");
1927 if (!isFnEntryBlock(block)) {
1928 for (BlockArgument blockArg : block->getArguments()) {
1929 auto newArg =
1930 newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1931 mapper.map(blockArg, newArg);
1932 LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
1933 << blockArg << " to " << newArg << "\n");
1935 } else {
1936 LLVM_DEBUG(logger.startLine()
1937 << "[cf] block " << block << " is a function entry block\n");
1940 for (auto &op : *block)
1941 newBlock->push_back(op.clone(mapper));
1944 // Go through all ops and remap the operands.
1945 auto remapOperands = [&](Operation *op) {
1946 for (auto &operand : op->getOpOperands())
1947 if (Value mappedOp = mapper.lookupOrNull(operand.get()))
1948 operand.set(mappedOp);
1949 for (auto &succOp : op->getBlockOperands())
1950 if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
1951 succOp.set(mappedOp);
1953 for (auto &block : body)
1954 block.walk(remapOperands);
1956 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1957 // the selection/loop construct into its region. Next we need to fix the
1958 // connections between this new SelectionOp/LoopOp with existing blocks.
1960 // All existing incoming branches should go to the merge block, where the
1961 // SelectionOp/LoopOp resides right now.
1962 headerBlock->replaceAllUsesWith(mergeBlock);
1964 LLVM_DEBUG({
1965 logger.startLine() << "[cf] after cloning and fixing references:\n";
1966 headerBlock->getParentOp()->print(logger.getOStream());
1967 logger.startLine() << "\n";
1970 if (isLoop) {
1971 if (!mergeBlock->args_empty()) {
1972 return mergeBlock->getParentOp()->emitError(
1973 "OpPhi in loop merge block unsupported");
1976 // The loop header block may have block arguments. Since now we place the
1977 // loop op inside the old merge block, we need to make sure the old merge
1978 // block has the same block argument list.
1979 for (BlockArgument blockArg : headerBlock->getArguments())
1980 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1982 // If the loop header block has block arguments, make sure the spirv.Branch
1983 // op matches.
1984 SmallVector<Value, 4> blockArgs;
1985 if (!headerBlock->args_empty())
1986 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1988 // The loop entry block should have a unconditional branch jumping to the
1989 // loop header block.
1990 builder.setInsertionPointToEnd(&body.front());
1991 builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
1992 ArrayRef<Value>(blockArgs));
1995 // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1996 // cleaned up.
1997 LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
1998 // First we need to drop all operands' references inside all blocks. This is
1999 // needed because we can have blocks referencing SSA values from one another.
2000 for (auto *block : constructBlocks)
2001 block->dropAllReferences();
2003 // Check that whether some op in the to-be-erased blocks still has uses. Those
2004 // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
2005 // region. We cannot handle such cases given that once a value is sinked into
2006 // the SelectionOp/LoopOp's region, there is no escape for it:
2007 // SelectionOp/LooOp does not support yield values right now.
2008 for (auto *block : constructBlocks) {
2009 for (Operation &op : *block)
2010 if (!op.use_empty())
2011 return op.emitOpError(
2012 "failed control flow structurization: it has uses outside of the "
2013 "enclosing selection/loop construct");
2016 // Then erase all old blocks.
2017 for (auto *block : constructBlocks) {
2018 // We've cloned all blocks belonging to this construct into the structured
2019 // control flow op's region. Among these blocks, some may compose another
2020 // selection/loop. If so, they will be recorded within blockMergeInfo.
2021 // We need to update the pointers there to the newly remapped ones so we can
2022 // continue structurizing them later.
2023 // TODO: The asserts in the following assumes input SPIR-V blob forms
2024 // correctly nested selection/loop constructs. We should relax this and
2025 // support error cases better.
2026 auto it = blockMergeInfo.find(block);
2027 if (it != blockMergeInfo.end()) {
2028 // Use the original location for nested selection/loop ops.
2029 Location loc = it->second.loc;
2031 Block *newHeader = mapper.lookupOrNull(block);
2032 if (!newHeader)
2033 return emitError(loc, "failed control flow structurization: nested "
2034 "loop header block should be remapped!");
2036 Block *newContinue = it->second.continueBlock;
2037 if (newContinue) {
2038 newContinue = mapper.lookupOrNull(newContinue);
2039 if (!newContinue)
2040 return emitError(loc, "failed control flow structurization: nested "
2041 "loop continue block should be remapped!");
2044 Block *newMerge = it->second.mergeBlock;
2045 if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2046 newMerge = mappedTo;
2048 // The iterator should be erased before adding a new entry into
2049 // blockMergeInfo to avoid iterator invalidation.
2050 blockMergeInfo.erase(it);
2051 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2052 newContinue);
2055 // The structured selection/loop's entry block does not have arguments.
2056 // If the function's header block is also part of the structured control
2057 // flow, we cannot just simply erase it because it may contain arguments
2058 // matching the function signature and used by the cloned blocks.
2059 if (isFnEntryBlock(block)) {
2060 LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2061 << " to only contain a spirv.Branch op\n");
2062 // Still keep the function entry block for the potential block arguments,
2063 // but replace all ops inside with a branch to the merge block.
2064 block->clear();
2065 builder.setInsertionPointToEnd(block);
2066 builder.create<spirv::BranchOp>(location, mergeBlock);
2067 } else {
2068 LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2069 block->erase();
2073 LLVM_DEBUG(logger.startLine()
2074 << "[cf] after structurizing construct with header block "
2075 << headerBlock << ":\n"
2076 << *op << "\n");
2078 return success();
2081 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2082 LLVM_DEBUG({
2083 logger.startLine()
2084 << "//----- [phi] start wiring up block arguments -----//\n";
2085 logger.indent();
2088 OpBuilder::InsertionGuard guard(opBuilder);
2090 for (const auto &info : blockPhiInfo) {
2091 Block *block = info.first.first;
2092 Block *target = info.first.second;
2093 const BlockPhiInfo &phiInfo = info.second;
2094 LLVM_DEBUG({
2095 logger.startLine() << "[phi] block " << block << "\n";
2096 logger.startLine() << "[phi] before creating block argument:\n";
2097 block->getParentOp()->print(logger.getOStream());
2098 logger.startLine() << "\n";
2101 // Set insertion point to before this block's terminator early because we
2102 // may materialize ops via getValue() call.
2103 auto *op = block->getTerminator();
2104 opBuilder.setInsertionPoint(op);
2106 SmallVector<Value, 4> blockArgs;
2107 blockArgs.reserve(phiInfo.size());
2108 for (uint32_t valueId : phiInfo) {
2109 if (Value value = getValue(valueId)) {
2110 blockArgs.push_back(value);
2111 LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2112 << " id = " << valueId << "\n");
2113 } else {
2114 return emitError(unknownLoc, "OpPhi references undefined value!");
2118 if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2119 // Replace the previous branch op with a new one with block arguments.
2120 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2121 blockArgs);
2122 branchOp.erase();
2123 } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2124 assert((branchCondOp.getTrueBlock() == target ||
2125 branchCondOp.getFalseBlock() == target) &&
2126 "expected target to be either the true or false target");
2127 if (target == branchCondOp.getTrueTarget())
2128 opBuilder.create<spirv::BranchConditionalOp>(
2129 branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2130 branchCondOp.getFalseBlockArguments(),
2131 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2132 branchCondOp.getFalseTarget());
2133 else
2134 opBuilder.create<spirv::BranchConditionalOp>(
2135 branchCondOp.getLoc(), branchCondOp.getCondition(),
2136 branchCondOp.getTrueBlockArguments(), blockArgs,
2137 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2138 branchCondOp.getFalseBlock());
2140 branchCondOp.erase();
2141 } else {
2142 return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2145 LLVM_DEBUG({
2146 logger.startLine() << "[phi] after creating block argument:\n";
2147 block->getParentOp()->print(logger.getOStream());
2148 logger.startLine() << "\n";
2151 blockPhiInfo.clear();
2153 LLVM_DEBUG({
2154 logger.unindent();
2155 logger.startLine()
2156 << "//--- [phi] completed wiring up block arguments ---//\n";
2158 return success();
2161 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2162 LLVM_DEBUG({
2163 logger.startLine()
2164 << "//----- [cf] start structurizing control flow -----//\n";
2165 logger.indent();
2168 while (!blockMergeInfo.empty()) {
2169 Block *headerBlock = blockMergeInfo.begin()->first;
2170 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2172 LLVM_DEBUG({
2173 logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2174 headerBlock->print(logger.getOStream());
2175 logger.startLine() << "\n";
2178 auto *mergeBlock = mergeInfo.mergeBlock;
2179 assert(mergeBlock && "merge block cannot be nullptr");
2180 if (!mergeBlock->args_empty())
2181 return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2182 LLVM_DEBUG({
2183 logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2184 mergeBlock->print(logger.getOStream());
2185 logger.startLine() << "\n";
2188 auto *continueBlock = mergeInfo.continueBlock;
2189 LLVM_DEBUG(if (continueBlock) {
2190 logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2191 continueBlock->print(logger.getOStream());
2192 logger.startLine() << "\n";
2194 // Erase this case before calling into structurizer, who will update
2195 // blockMergeInfo.
2196 blockMergeInfo.erase(blockMergeInfo.begin());
2197 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2198 blockMergeInfo, headerBlock,
2199 mergeBlock, continueBlock
2200 #ifndef NDEBUG
2202 logger
2203 #endif
2205 if (failed(structurizer.structurize()))
2206 return failure();
2209 LLVM_DEBUG({
2210 logger.unindent();
2211 logger.startLine()
2212 << "//--- [cf] completed structurizing control flow ---//\n";
2214 return success();
2217 //===----------------------------------------------------------------------===//
2218 // Debug
2219 //===----------------------------------------------------------------------===//
2221 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2222 if (!debugLine)
2223 return unknownLoc;
2225 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2226 if (fileName.empty())
2227 fileName = "<unknown>";
2228 return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2229 debugLine->column);
2232 LogicalResult
2233 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2234 // According to SPIR-V spec:
2235 // "This location information applies to the instructions physically
2236 // following this instruction, up to the first occurrence of any of the
2237 // following: the next end of block, the next OpLine instruction, or the next
2238 // OpNoLine instruction."
2239 if (operands.size() != 3)
2240 return emitError(unknownLoc, "OpLine must have 3 operands");
2241 debugLine = DebugLine{operands[0], operands[1], operands[2]};
2242 return success();
2245 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2247 LogicalResult
2248 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2249 if (operands.size() < 2)
2250 return emitError(unknownLoc, "OpString needs at least 2 operands");
2252 if (!debugInfoMap.lookup(operands[0]).empty())
2253 return emitError(unknownLoc,
2254 "duplicate debug string found for result <id> ")
2255 << operands[0];
2257 unsigned wordIndex = 1;
2258 StringRef debugString = decodeStringLiteral(operands, wordIndex);
2259 if (wordIndex != operands.size())
2260 return emitError(unknownLoc,
2261 "unexpected trailing words in OpString instruction");
2263 debugInfoMap[operands[0]] = debugString;
2264 return success();