[OpenACC] Create AST nodes for 'data' constructs
[llvm-project.git] / mlir / lib / Target / SPIRV / Deserialization / DeserializeOps.cpp
blobb30da773d48967308a82acd8b3583ca9bbb7f6f5
1 //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Deserializer methods for SPIR-V binary instructions.
11 //===----------------------------------------------------------------------===//
13 #include "Deserializer.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/Location.h"
19 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Debug.h"
23 #include <optional>
25 using namespace mlir;
27 #define DEBUG_TYPE "spirv-deserialization"
29 //===----------------------------------------------------------------------===//
30 // Utility Functions
31 //===----------------------------------------------------------------------===//
33 /// Extracts the opcode from the given first word of a SPIR-V instruction.
34 static inline spirv::Opcode extractOpcode(uint32_t word) {
35 return static_cast<spirv::Opcode>(word & 0xffff);
38 //===----------------------------------------------------------------------===//
39 // Instruction
40 //===----------------------------------------------------------------------===//
42 Value spirv::Deserializer::getValue(uint32_t id) {
43 if (auto constInfo = getConstant(id)) {
44 // Materialize a `spirv.Constant` op at every use site.
45 return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
46 constInfo->first);
48 if (auto varOp = getGlobalVariable(id)) {
49 auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
50 unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
51 return addressOfOp.getPointer();
53 if (auto constOp = getSpecConstant(id)) {
54 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
55 unknownLoc, constOp.getDefaultValue().getType(),
56 SymbolRefAttr::get(constOp.getOperation()));
57 return referenceOfOp.getReference();
59 if (auto constCompositeOp = getSpecConstantComposite(id)) {
60 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
61 unknownLoc, constCompositeOp.getType(),
62 SymbolRefAttr::get(constCompositeOp.getOperation()));
63 return referenceOfOp.getReference();
65 if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
66 return materializeSpecConstantOperation(
67 id, specConstOperationInfo->enclodesOpcode,
68 specConstOperationInfo->resultTypeID,
69 specConstOperationInfo->enclosedOpOperands);
71 if (auto undef = getUndefType(id)) {
72 return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
74 return valueMap.lookup(id);
77 LogicalResult spirv::Deserializer::sliceInstruction(
78 spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
79 std::optional<spirv::Opcode> expectedOpcode) {
80 auto binarySize = binary.size();
81 if (curOffset >= binarySize) {
82 return emitError(unknownLoc, "expected ")
83 << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
84 : "more")
85 << " instruction";
88 // For each instruction, get its word count from the first word to slice it
89 // from the stream properly, and then dispatch to the instruction handler.
91 uint32_t wordCount = binary[curOffset] >> 16;
93 if (wordCount == 0)
94 return emitError(unknownLoc, "word count cannot be zero");
96 uint32_t nextOffset = curOffset + wordCount;
97 if (nextOffset > binarySize)
98 return emitError(unknownLoc, "insufficient words for the last instruction");
100 opcode = extractOpcode(binary[curOffset]);
101 operands = binary.slice(curOffset + 1, wordCount - 1);
102 curOffset = nextOffset;
103 return success();
106 LogicalResult spirv::Deserializer::processInstruction(
107 spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
108 LLVM_DEBUG(logger.startLine() << "[inst] processing instruction "
109 << spirv::stringifyOpcode(opcode) << "\n");
111 // First dispatch all the instructions whose opcode does not correspond to
112 // those that have a direct mirror in the SPIR-V dialect
113 switch (opcode) {
114 case spirv::Opcode::OpCapability:
115 return processCapability(operands);
116 case spirv::Opcode::OpExtension:
117 return processExtension(operands);
118 case spirv::Opcode::OpExtInst:
119 return processExtInst(operands);
120 case spirv::Opcode::OpExtInstImport:
121 return processExtInstImport(operands);
122 case spirv::Opcode::OpMemberName:
123 return processMemberName(operands);
124 case spirv::Opcode::OpMemoryModel:
125 return processMemoryModel(operands);
126 case spirv::Opcode::OpEntryPoint:
127 case spirv::Opcode::OpExecutionMode:
128 if (deferInstructions) {
129 deferredInstructions.emplace_back(opcode, operands);
130 return success();
132 break;
133 case spirv::Opcode::OpVariable:
134 if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
135 return processGlobalVariable(operands);
137 break;
138 case spirv::Opcode::OpLine:
139 return processDebugLine(operands);
140 case spirv::Opcode::OpNoLine:
141 clearDebugLine();
142 return success();
143 case spirv::Opcode::OpName:
144 return processName(operands);
145 case spirv::Opcode::OpString:
146 return processDebugString(operands);
147 case spirv::Opcode::OpModuleProcessed:
148 case spirv::Opcode::OpSource:
149 case spirv::Opcode::OpSourceContinued:
150 case spirv::Opcode::OpSourceExtension:
151 // TODO: This is debug information embedded in the binary which should be
152 // translated into the spirv.module.
153 return success();
154 case spirv::Opcode::OpTypeVoid:
155 case spirv::Opcode::OpTypeBool:
156 case spirv::Opcode::OpTypeInt:
157 case spirv::Opcode::OpTypeFloat:
158 case spirv::Opcode::OpTypeVector:
159 case spirv::Opcode::OpTypeMatrix:
160 case spirv::Opcode::OpTypeArray:
161 case spirv::Opcode::OpTypeFunction:
162 case spirv::Opcode::OpTypeImage:
163 case spirv::Opcode::OpTypeSampledImage:
164 case spirv::Opcode::OpTypeRuntimeArray:
165 case spirv::Opcode::OpTypeStruct:
166 case spirv::Opcode::OpTypePointer:
167 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
168 return processType(opcode, operands);
169 case spirv::Opcode::OpTypeForwardPointer:
170 return processTypeForwardPointer(operands);
171 case spirv::Opcode::OpConstant:
172 return processConstant(operands, /*isSpec=*/false);
173 case spirv::Opcode::OpSpecConstant:
174 return processConstant(operands, /*isSpec=*/true);
175 case spirv::Opcode::OpConstantComposite:
176 return processConstantComposite(operands);
177 case spirv::Opcode::OpSpecConstantComposite:
178 return processSpecConstantComposite(operands);
179 case spirv::Opcode::OpSpecConstantOp:
180 return processSpecConstantOperation(operands);
181 case spirv::Opcode::OpConstantTrue:
182 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
183 case spirv::Opcode::OpSpecConstantTrue:
184 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
185 case spirv::Opcode::OpConstantFalse:
186 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
187 case spirv::Opcode::OpSpecConstantFalse:
188 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
189 case spirv::Opcode::OpConstantNull:
190 return processConstantNull(operands);
191 case spirv::Opcode::OpDecorate:
192 return processDecoration(operands);
193 case spirv::Opcode::OpMemberDecorate:
194 return processMemberDecoration(operands);
195 case spirv::Opcode::OpFunction:
196 return processFunction(operands);
197 case spirv::Opcode::OpLabel:
198 return processLabel(operands);
199 case spirv::Opcode::OpBranch:
200 return processBranch(operands);
201 case spirv::Opcode::OpBranchConditional:
202 return processBranchConditional(operands);
203 case spirv::Opcode::OpSelectionMerge:
204 return processSelectionMerge(operands);
205 case spirv::Opcode::OpLoopMerge:
206 return processLoopMerge(operands);
207 case spirv::Opcode::OpPhi:
208 return processPhi(operands);
209 case spirv::Opcode::OpUndef:
210 return processUndef(operands);
211 default:
212 break;
214 return dispatchToAutogenDeserialization(opcode, operands);
217 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
218 ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
219 unsigned numOperands) {
220 SmallVector<Type, 1> resultTypes;
221 uint32_t valueID = 0;
223 size_t wordIndex = 0;
224 if (hasResult) {
225 if (wordIndex >= words.size())
226 return emitError(unknownLoc,
227 "expected result type <id> while deserializing for ")
228 << opName;
230 // Decode the type <id>
231 auto type = getType(words[wordIndex]);
232 if (!type)
233 return emitError(unknownLoc, "unknown type result <id>: ")
234 << words[wordIndex];
235 resultTypes.push_back(type);
236 ++wordIndex;
238 // Decode the result <id>
239 if (wordIndex >= words.size())
240 return emitError(unknownLoc,
241 "expected result <id> while deserializing for ")
242 << opName;
243 valueID = words[wordIndex];
244 ++wordIndex;
247 SmallVector<Value, 4> operands;
248 SmallVector<NamedAttribute, 4> attributes;
250 // Decode operands
251 size_t operandIndex = 0;
252 for (; operandIndex < numOperands && wordIndex < words.size();
253 ++operandIndex, ++wordIndex) {
254 auto arg = getValue(words[wordIndex]);
255 if (!arg)
256 return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
257 operands.push_back(arg);
259 if (operandIndex != numOperands) {
260 return emitError(
261 unknownLoc,
262 "found less operands than expected when deserializing for ")
263 << opName << "; only " << operandIndex << " of " << numOperands
264 << " processed";
266 if (wordIndex != words.size()) {
267 return emitError(
268 unknownLoc,
269 "found more operands than expected when deserializing for ")
270 << opName << "; only " << wordIndex << " of " << words.size()
271 << " processed";
274 // Attach attributes from decorations
275 if (decorations.count(valueID)) {
276 auto attrs = decorations[valueID].getAttrs();
277 attributes.append(attrs.begin(), attrs.end());
280 // Create the op and update bookkeeping maps
281 Location loc = createFileLineColLoc(opBuilder);
282 OperationState opState(loc, opName);
283 opState.addOperands(operands);
284 if (hasResult)
285 opState.addTypes(resultTypes);
286 opState.addAttributes(attributes);
287 Operation *op = opBuilder.create(opState);
288 if (hasResult)
289 valueMap[valueID] = op->getResult(0);
291 if (op->hasTrait<OpTrait::IsTerminator>())
292 clearDebugLine();
294 return success();
297 LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
298 if (operands.size() != 2) {
299 return emitError(unknownLoc, "OpUndef instruction must have two operands");
301 auto type = getType(operands[0]);
302 if (!type) {
303 return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
305 undefMap[operands[1]] = type;
306 return success();
309 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
310 if (operands.size() < 4) {
311 return emitError(unknownLoc,
312 "OpExtInst must have at least 4 operands, result type "
313 "<id>, result <id>, set <id> and instruction opcode");
315 if (!extendedInstSets.count(operands[2])) {
316 return emitError(unknownLoc, "undefined set <id> in OpExtInst");
318 SmallVector<uint32_t, 4> slicedOperands;
319 slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
320 slicedOperands.append(std::next(operands.begin(), 4), operands.end());
321 return dispatchToExtensionSetAutogenDeserialization(
322 extendedInstSets[operands[2]], operands[3], slicedOperands);
325 namespace mlir {
326 namespace spirv {
328 template <>
329 LogicalResult
330 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
331 unsigned wordIndex = 0;
332 if (wordIndex >= words.size()) {
333 return emitError(unknownLoc,
334 "missing Execution Model specification in OpEntryPoint");
336 auto execModel = spirv::ExecutionModelAttr::get(
337 context, static_cast<spirv::ExecutionModel>(words[wordIndex++]));
338 if (wordIndex >= words.size()) {
339 return emitError(unknownLoc, "missing <id> in OpEntryPoint");
341 // Get the function <id>
342 auto fnID = words[wordIndex++];
343 // Get the function name
344 auto fnName = decodeStringLiteral(words, wordIndex);
345 // Verify that the function <id> matches the fnName
346 auto parsedFunc = getFunction(fnID);
347 if (!parsedFunc) {
348 return emitError(unknownLoc, "no function matching <id> ") << fnID;
350 if (parsedFunc.getName() != fnName) {
351 // The deserializer uses "spirv_fn_<id>" as the function name if the input
352 // SPIR-V blob does not contain a name for it. We should use a more clear
353 // indication for such case rather than relying on naming details.
354 if (!parsedFunc.getName().starts_with("spirv_fn_"))
355 return emitError(unknownLoc,
356 "function name mismatch between OpEntryPoint "
357 "and OpFunction with <id> ")
358 << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
359 parsedFunc.setName(fnName);
361 SmallVector<Attribute, 4> interface;
362 while (wordIndex < words.size()) {
363 auto arg = getGlobalVariable(words[wordIndex]);
364 if (!arg) {
365 return emitError(unknownLoc, "undefined result <id> ")
366 << words[wordIndex] << " while decoding OpEntryPoint";
368 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
369 wordIndex++;
371 opBuilder.create<spirv::EntryPointOp>(
372 unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName),
373 opBuilder.getArrayAttr(interface));
374 return success();
377 template <>
378 LogicalResult
379 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
380 unsigned wordIndex = 0;
381 if (wordIndex >= words.size()) {
382 return emitError(unknownLoc,
383 "missing function result <id> in OpExecutionMode");
385 // Get the function <id> to get the name of the function
386 auto fnID = words[wordIndex++];
387 auto fn = getFunction(fnID);
388 if (!fn) {
389 return emitError(unknownLoc, "no function matching <id> ") << fnID;
391 // Get the Execution mode
392 if (wordIndex >= words.size()) {
393 return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
395 auto execMode = spirv::ExecutionModeAttr::get(
396 context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
398 // Get the values
399 SmallVector<Attribute, 4> attrListElems;
400 while (wordIndex < words.size()) {
401 attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
403 auto values = opBuilder.getArrayAttr(attrListElems);
404 opBuilder.create<spirv::ExecutionModeOp>(
405 unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()),
406 execMode, values);
407 return success();
410 template <>
411 LogicalResult
412 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
413 if (operands.size() < 3) {
414 return emitError(unknownLoc,
415 "OpFunctionCall must have at least 3 operands");
418 Type resultType = getType(operands[0]);
419 if (!resultType) {
420 return emitError(unknownLoc, "undefined result type from <id> ")
421 << operands[0];
424 // Use null type to mean no result type.
425 if (isVoidType(resultType))
426 resultType = nullptr;
428 auto resultID = operands[1];
429 auto functionID = operands[2];
431 auto functionName = getFunctionSymbol(functionID);
433 SmallVector<Value, 4> arguments;
434 for (auto operand : llvm::drop_begin(operands, 3)) {
435 auto value = getValue(operand);
436 if (!value) {
437 return emitError(unknownLoc, "unknown <id> ")
438 << operand << " used by OpFunctionCall";
440 arguments.push_back(value);
443 auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
444 unknownLoc, resultType,
445 SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments);
447 if (resultType)
448 valueMap[resultID] = opFunctionCall.getResult(0);
449 return success();
452 template <>
453 LogicalResult
454 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
455 SmallVector<Type, 1> resultTypes;
456 size_t wordIndex = 0;
457 SmallVector<Value, 4> operands;
458 SmallVector<NamedAttribute, 4> attributes;
460 if (wordIndex < words.size()) {
461 auto arg = getValue(words[wordIndex]);
463 if (!arg) {
464 return emitError(unknownLoc, "unknown result <id> : ")
465 << words[wordIndex];
468 operands.push_back(arg);
469 wordIndex++;
472 if (wordIndex < words.size()) {
473 auto arg = getValue(words[wordIndex]);
475 if (!arg) {
476 return emitError(unknownLoc, "unknown result <id> : ")
477 << words[wordIndex];
480 operands.push_back(arg);
481 wordIndex++;
484 bool isAlignedAttr = false;
486 if (wordIndex < words.size()) {
487 auto attrValue = words[wordIndex++];
488 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
489 static_cast<spirv::MemoryAccess>(attrValue));
490 attributes.push_back(
491 opBuilder.getNamedAttr(attributeName<MemoryAccess>(), attr));
492 isAlignedAttr = (attrValue == 2);
495 if (isAlignedAttr && wordIndex < words.size()) {
496 attributes.push_back(opBuilder.getNamedAttr(
497 "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
500 if (wordIndex < words.size()) {
501 auto attrValue = words[wordIndex++];
502 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
503 static_cast<spirv::MemoryAccess>(attrValue));
504 attributes.push_back(opBuilder.getNamedAttr("source_memory_access", attr));
507 if (wordIndex < words.size()) {
508 attributes.push_back(opBuilder.getNamedAttr(
509 "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
512 if (wordIndex != words.size()) {
513 return emitError(unknownLoc,
514 "found more operands than expected when deserializing "
515 "spirv::CopyMemoryOp, only ")
516 << wordIndex << " of " << words.size() << " processed";
519 Location loc = createFileLineColLoc(opBuilder);
520 opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
522 return success();
525 template <>
526 LogicalResult Deserializer::processOp<spirv::GenericCastToPtrExplicitOp>(
527 ArrayRef<uint32_t> words) {
528 if (words.size() != 4) {
529 return emitError(unknownLoc,
530 "expected 4 words in GenericCastToPtrExplicitOp"
531 " but got : ")
532 << words.size();
534 SmallVector<Type, 1> resultTypes;
535 SmallVector<Value, 4> operands;
536 uint32_t valueID = 0;
537 auto type = getType(words[0]);
539 if (!type)
540 return emitError(unknownLoc, "unknown type result <id> : ") << words[0];
541 resultTypes.push_back(type);
543 valueID = words[1];
545 auto arg = getValue(words[2]);
546 if (!arg)
547 return emitError(unknownLoc, "unknown result <id> : ") << words[2];
548 operands.push_back(arg);
550 Location loc = createFileLineColLoc(opBuilder);
551 Operation *op = opBuilder.create<spirv::GenericCastToPtrExplicitOp>(
552 loc, resultTypes, operands);
553 valueMap[valueID] = op->getResult(0);
554 return success();
557 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
558 // various Deserializer::processOp<...>() specializations.
559 #define GET_DESERIALIZATION_FNS
560 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
562 } // namespace spirv
563 } // namespace mlir