[OpenACC] Implement 'device_type' for 'data' construct
[llvm-project.git] / mlir / lib / Target / SPIRV / Serialization / SerializeOps.cpp
blob4c15523a05fa821d33d519fa490067605296f37b
1 //===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the serialization methods for MLIR SPIR-V module ops.
11 //===----------------------------------------------------------------------===//
13 #include "Serializer.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17 #include "mlir/IR/RegionGraphTraits.h"
18 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
19 #include "llvm/ADT/DepthFirstIterator.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "spirv-serialization"
25 using namespace mlir;
27 /// A pre-order depth-first visitor function for processing basic blocks.
28 ///
29 /// Visits the basic blocks starting from the given `headerBlock` in pre-order
30 /// depth-first manner and calls `blockHandler` on each block. Skips handling
31 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
32 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
33 /// successors.
34 ///
35 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
36 /// of blocks in a function must satisfy the rule that blocks appear before
37 /// all blocks they dominate." This can be achieved by a pre-order CFG
38 /// traversal algorithm. To make the serialization output more logical and
39 /// readable to human, we perform depth-first CFG traversal and delay the
40 /// serialization of the merge block and the continue block, if exists, until
41 /// after all other blocks have been processed.
42 static LogicalResult
43 visitInPrettyBlockOrder(Block *headerBlock,
44 function_ref<LogicalResult(Block *)> blockHandler,
45 bool skipHeader = false, BlockRange skipBlocks = {}) {
46 llvm::df_iterator_default_set<Block *, 4> doneBlocks;
47 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
49 for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
50 if (skipHeader && block == headerBlock)
51 continue;
52 if (failed(blockHandler(block)))
53 return failure();
55 return success();
58 namespace mlir {
59 namespace spirv {
60 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
61 if (auto resultID =
62 prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
63 valueIDMap[op.getResult()] = resultID;
64 return success();
66 return failure();
69 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
70 if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
71 /*isSpec=*/true)) {
72 // Emit the OpDecorate instruction for SpecId.
73 if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
74 auto val = static_cast<uint32_t>(specID.getInt());
75 if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
76 return failure();
79 specConstIDMap[op.getSymName()] = resultID;
80 return processName(resultID, op.getSymName());
82 return failure();
85 LogicalResult
86 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
87 uint32_t typeID = 0;
88 if (failed(processType(op.getLoc(), op.getType(), typeID))) {
89 return failure();
92 auto resultID = getNextID();
94 SmallVector<uint32_t, 8> operands;
95 operands.push_back(typeID);
96 operands.push_back(resultID);
98 auto constituents = op.getConstituents();
100 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
101 auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
103 auto constituentName = constituent.getValue();
104 auto constituentID = getSpecConstID(constituentName);
106 if (!constituentID) {
107 return op.emitError("unknown result <id> for specialization constant ")
108 << constituentName;
111 operands.push_back(constituentID);
114 encodeInstructionInto(typesGlobalValues,
115 spirv::Opcode::OpSpecConstantComposite, operands);
116 specConstIDMap[op.getSymName()] = resultID;
118 return processName(resultID, op.getSymName());
121 LogicalResult
122 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
123 uint32_t typeID = 0;
124 if (failed(processType(op.getLoc(), op.getType(), typeID))) {
125 return failure();
128 auto resultID = getNextID();
130 SmallVector<uint32_t, 8> operands;
131 operands.push_back(typeID);
132 operands.push_back(resultID);
134 Block &block = op.getRegion().getBlocks().front();
135 Operation &enclosedOp = block.getOperations().front();
137 std::string enclosedOpName;
138 llvm::raw_string_ostream rss(enclosedOpName);
139 rss << "Op" << enclosedOp.getName().stripDialect();
140 auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName);
142 if (!enclosedOpcode) {
143 op.emitError("Couldn't find op code for op ")
144 << enclosedOp.getName().getStringRef();
145 return failure();
148 operands.push_back(static_cast<uint32_t>(*enclosedOpcode));
150 // Append operands to the enclosed op to the list of operands.
151 for (Value operand : enclosedOp.getOperands()) {
152 uint32_t id = getValueID(operand);
153 assert(id && "use before def!");
154 operands.push_back(id);
157 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp,
158 operands);
159 valueIDMap[op.getResult()] = resultID;
161 return success();
164 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
165 auto undefType = op.getType();
166 auto &id = undefValIDMap[undefType];
167 if (!id) {
168 id = getNextID();
169 uint32_t typeID = 0;
170 if (failed(processType(op.getLoc(), undefType, typeID)))
171 return failure();
172 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
173 {typeID, id});
175 valueIDMap[op.getResult()] = id;
176 return success();
179 LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
180 for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
181 uint32_t argTypeID = 0;
182 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
183 return failure();
185 auto argValueID = getNextID();
187 // Process decoration attributes of arguments.
188 auto funcOp = cast<FunctionOpInterface>(*op);
189 for (auto argAttr : funcOp.getArgAttrs(idx)) {
190 if (argAttr.getName() != DecorationAttr::name)
191 continue;
193 if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
194 if (failed(processDecorationAttr(op->getLoc(), argValueID,
195 decAttr.getValue(), decAttr)))
196 return failure();
200 valueIDMap[arg] = argValueID;
201 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
202 {argTypeID, argValueID});
204 return success();
207 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
208 LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
209 assert(functionHeader.empty() && functionBody.empty());
211 uint32_t fnTypeID = 0;
212 // Generate type of the function.
213 if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
214 return failure();
216 // Add the function definition.
217 SmallVector<uint32_t, 4> operands;
218 uint32_t resTypeID = 0;
219 auto resultTypes = op.getFunctionType().getResults();
220 if (resultTypes.size() > 1) {
221 return op.emitError("cannot serialize function with multiple return types");
223 if (failed(processType(op.getLoc(),
224 (resultTypes.empty() ? getVoidType() : resultTypes[0]),
225 resTypeID))) {
226 return failure();
228 operands.push_back(resTypeID);
229 auto funcID = getOrCreateFunctionID(op.getName());
230 operands.push_back(funcID);
231 operands.push_back(static_cast<uint32_t>(op.getFunctionControl()));
232 operands.push_back(fnTypeID);
233 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
235 // Add function name.
236 if (failed(processName(funcID, op.getName()))) {
237 return failure();
239 // Handle external functions with linkage_attributes(LinkageAttributes)
240 // differently.
241 auto linkageAttr = op.getLinkageAttributes();
242 auto hasImportLinkage =
243 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
244 spirv::LinkageType::Import);
245 if (op.isExternal() && !hasImportLinkage) {
246 return op.emitError(
247 "'spirv.module' cannot contain external functions "
248 "without 'Import' linkage_attributes (LinkageAttributes)");
250 if (op.isExternal() && hasImportLinkage) {
251 // Add an entry block to set up the block arguments
252 // to match the signature of the function.
253 // This is to generate OpFunctionParameter for functions with
254 // LinkageAttributes.
255 // WARNING: This operation has side-effect, it essentially adds a body
256 // to the func. Hence, making it not external anymore (isExternal()
257 // is going to return false for this function from now on)
258 // Hence, we'll remove the body once we are done with the serialization.
259 op.addEntryBlock();
260 if (failed(processFuncParameter(op)))
261 return failure();
262 // Don't need to process the added block, there is nothing to process,
263 // the fake body was added just to get the arguments, remove the body,
264 // since it's use is done.
265 op.eraseBody();
266 } else {
267 if (failed(processFuncParameter(op)))
268 return failure();
270 // Some instructions (e.g., OpVariable) in a function must be in the first
271 // block in the function. These instructions will be put in
272 // functionHeader. Thus, we put the label in functionHeader first, and
273 // omit it from the first block. OpLabel only needs to be added for
274 // functions with body (including empty body). Since, we added a fake body
275 // for functions with 'Import' Linkage attributes, these functions are
276 // essentially function delcaration, so they should not have OpLabel and a
277 // terminating instruction. That's why we skipped it for those functions.
278 encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
279 {getOrCreateBlockID(&op.front())});
280 if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
281 return failure();
282 if (failed(visitInPrettyBlockOrder(
283 &op.front(), [&](Block *block) { return processBlock(block); },
284 /*skipHeader=*/true))) {
285 return failure();
288 // There might be OpPhi instructions who have value references needing to
289 // fix.
290 for (const auto &deferredValue : deferredPhiValues) {
291 Value value = deferredValue.first;
292 uint32_t id = getValueID(value);
293 LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
294 << " to id = " << id << '\n');
295 assert(id && "OpPhi references undefined value!");
296 for (size_t offset : deferredValue.second)
297 functionBody[offset] = id;
299 deferredPhiValues.clear();
301 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
302 << "' --\n");
303 // Insert Decorations based on Function Attributes.
304 // Only attributes we should be considering for decoration are the
305 // ::mlir::spirv::Decoration attributes.
307 for (auto attr : op->getAttrs()) {
308 // Only generate OpDecorate op for spirv::Decoration attributes.
309 auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
310 llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
311 /*capitalizeFirst=*/true));
312 if (isValidDecoration != std::nullopt) {
313 if (failed(processDecoration(op.getLoc(), funcID, attr))) {
314 return failure();
318 // Insert OpFunctionEnd.
319 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {});
321 functions.append(functionHeader.begin(), functionHeader.end());
322 functions.append(functionBody.begin(), functionBody.end());
323 functionHeader.clear();
324 functionBody.clear();
326 return success();
329 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
330 SmallVector<uint32_t, 4> operands;
331 SmallVector<StringRef, 2> elidedAttrs;
332 uint32_t resultID = 0;
333 uint32_t resultTypeID = 0;
334 if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
335 return failure();
337 operands.push_back(resultTypeID);
338 resultID = getNextID();
339 valueIDMap[op.getResult()] = resultID;
340 operands.push_back(resultID);
341 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
342 if (attr) {
343 operands.push_back(
344 static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue()));
346 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
347 for (auto arg : op.getODSOperands(0)) {
348 auto argID = getValueID(arg);
349 if (!argID) {
350 return emitError(op.getLoc(), "operand 0 has a use before def");
352 operands.push_back(argID);
354 if (failed(emitDebugLine(functionHeader, op.getLoc())))
355 return failure();
356 encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);
357 for (auto attr : op->getAttrs()) {
358 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
359 return attr.getName() == elided;
360 })) {
361 continue;
363 if (failed(processDecoration(op.getLoc(), resultID, attr))) {
364 return failure();
367 return success();
370 LogicalResult
371 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
372 // Get TypeID.
373 uint32_t resultTypeID = 0;
374 SmallVector<StringRef, 4> elidedAttrs;
375 if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
376 return failure();
379 elidedAttrs.push_back("type");
380 SmallVector<uint32_t, 4> operands;
381 operands.push_back(resultTypeID);
382 auto resultID = getNextID();
384 // Encode the name.
385 auto varName = varOp.getSymName();
386 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
387 if (failed(processName(resultID, varName))) {
388 return failure();
390 globalVarIDMap[varName] = resultID;
391 operands.push_back(resultID);
393 // Encode StorageClass.
394 operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
396 // Encode initialization.
397 StringRef initAttrName = varOp.getInitializerAttrName().getValue();
398 if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
399 uint32_t initializerID = 0;
400 auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
401 Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
402 varOp->getParentOp(), initRef.getAttr());
404 // Check if initializer is GlobalVariable or SpecConstant* cases.
405 if (isa<spirv::GlobalVariableOp>(initOp))
406 initializerID = getVariableID(*initSymbolName);
407 else
408 initializerID = getSpecConstID(*initSymbolName);
410 if (!initializerID)
411 return emitError(varOp.getLoc(),
412 "invalid usage of undefined variable as initializer");
414 operands.push_back(initializerID);
415 elidedAttrs.push_back(initAttrName);
418 if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
419 return failure();
420 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
421 elidedAttrs.push_back(initAttrName);
423 // Encode decorations.
424 for (auto attr : varOp->getAttrs()) {
425 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
426 return attr.getName() == elided;
427 })) {
428 continue;
430 if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
431 return failure();
434 return success();
437 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
438 // Assign <id>s to all blocks so that branches inside the SelectionOp can
439 // resolve properly.
440 auto &body = selectionOp.getBody();
441 for (Block &block : body)
442 getOrCreateBlockID(&block);
444 auto *headerBlock = selectionOp.getHeaderBlock();
445 auto *mergeBlock = selectionOp.getMergeBlock();
446 auto headerID = getBlockID(headerBlock);
447 auto mergeID = getBlockID(mergeBlock);
448 auto loc = selectionOp.getLoc();
450 // This SelectionOp is in some MLIR block with preceding and following ops. In
451 // the binary format, it should reside in separate SPIR-V blocks from its
452 // preceding and following ops. So we need to emit unconditional branches to
453 // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal
454 // flow afterwards.
455 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
457 // Emit the selection header block, which dominates all other blocks, first.
458 // We need to emit an OpSelectionMerge instruction before the selection header
459 // block's terminator.
460 auto emitSelectionMerge = [&]() {
461 if (failed(emitDebugLine(functionBody, loc)))
462 return failure();
463 lastProcessedWasMergeInst = true;
464 encodeInstructionInto(
465 functionBody, spirv::Opcode::OpSelectionMerge,
466 {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())});
467 return success();
469 if (failed(
470 processBlock(headerBlock, /*omitLabel=*/false, emitSelectionMerge)))
471 return failure();
473 // Process all blocks with a depth-first visitor starting from the header
474 // block. The selection header block and merge block are skipped by this
475 // visitor.
476 if (failed(visitInPrettyBlockOrder(
477 headerBlock, [&](Block *block) { return processBlock(block); },
478 /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
479 return failure();
481 // There is nothing to do for the merge block in the selection, which just
482 // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel
483 // instruction to start a new SPIR-V block for ops following this SelectionOp.
484 // The block should use the <id> for the merge block.
485 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
486 LLVM_DEBUG(llvm::dbgs() << "done merge ");
487 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
488 LLVM_DEBUG(llvm::dbgs() << "\n");
489 return success();
492 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
493 // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
494 // properly. We don't need to assign for the entry block, which is just for
495 // satisfying MLIR region's structural requirement.
496 auto &body = loopOp.getBody();
497 for (Block &block : llvm::drop_begin(body))
498 getOrCreateBlockID(&block);
500 auto *headerBlock = loopOp.getHeaderBlock();
501 auto *continueBlock = loopOp.getContinueBlock();
502 auto *mergeBlock = loopOp.getMergeBlock();
503 auto headerID = getBlockID(headerBlock);
504 auto continueID = getBlockID(continueBlock);
505 auto mergeID = getBlockID(mergeBlock);
506 auto loc = loopOp.getLoc();
508 // This LoopOp is in some MLIR block with preceding and following ops. In the
509 // binary format, it should reside in separate SPIR-V blocks from its
510 // preceding and following ops. So we need to emit unconditional branches to
511 // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
512 // afterwards.
513 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
515 // LoopOp's entry block is just there for satisfying MLIR's structural
516 // requirements so we omit it and start serialization from the loop header
517 // block.
519 // Emit the loop header block, which dominates all other blocks, first. We
520 // need to emit an OpLoopMerge instruction before the loop header block's
521 // terminator.
522 auto emitLoopMerge = [&]() {
523 if (failed(emitDebugLine(functionBody, loc)))
524 return failure();
525 lastProcessedWasMergeInst = true;
526 encodeInstructionInto(
527 functionBody, spirv::Opcode::OpLoopMerge,
528 {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())});
529 return success();
531 if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
532 return failure();
534 // Process all blocks with a depth-first visitor starting from the header
535 // block. The loop header block, loop continue block, and loop merge block are
536 // skipped by this visitor and handled later in this function.
537 if (failed(visitInPrettyBlockOrder(
538 headerBlock, [&](Block *block) { return processBlock(block); },
539 /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
540 return failure();
542 // We have handled all other blocks. Now get to the loop continue block.
543 if (failed(processBlock(continueBlock)))
544 return failure();
546 // There is nothing to do for the merge block in the loop, which just contains
547 // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction
548 // to start a new SPIR-V block for ops following this LoopOp. The block should
549 // use the <id> for the merge block.
550 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
551 LLVM_DEBUG(llvm::dbgs() << "done merge ");
552 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
553 LLVM_DEBUG(llvm::dbgs() << "\n");
554 return success();
557 LogicalResult Serializer::processBranchConditionalOp(
558 spirv::BranchConditionalOp condBranchOp) {
559 auto conditionID = getValueID(condBranchOp.getCondition());
560 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
561 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
562 SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
564 if (auto weights = condBranchOp.getBranchWeights()) {
565 for (auto val : weights->getValue())
566 arguments.push_back(cast<IntegerAttr>(val).getInt());
569 if (failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
570 return failure();
571 encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
572 arguments);
573 return success();
576 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
577 if (failed(emitDebugLine(functionBody, branchOp.getLoc())))
578 return failure();
579 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
580 {getOrCreateBlockID(branchOp.getTarget())});
581 return success();
584 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
585 auto varName = addressOfOp.getVariable();
586 auto variableID = getVariableID(varName);
587 if (!variableID) {
588 return addressOfOp.emitError("unknown result <id> for variable ")
589 << varName;
591 valueIDMap[addressOfOp.getPointer()] = variableID;
592 return success();
595 LogicalResult
596 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
597 auto constName = referenceOfOp.getSpecConst();
598 auto constID = getSpecConstID(constName);
599 if (!constID) {
600 return referenceOfOp.emitError(
601 "unknown result <id> for specialization constant ")
602 << constName;
604 valueIDMap[referenceOfOp.getReference()] = constID;
605 return success();
608 template <>
609 LogicalResult
610 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
611 SmallVector<uint32_t, 4> operands;
612 // Add the ExecutionModel.
613 operands.push_back(static_cast<uint32_t>(op.getExecutionModel()));
614 // Add the function <id>.
615 auto funcID = getFunctionID(op.getFn());
616 if (!funcID) {
617 return op.emitError("missing <id> for function ")
618 << op.getFn()
619 << "; function needs to be defined before spirv.EntryPoint is "
620 "serialized";
622 operands.push_back(funcID);
623 // Add the name of the function.
624 spirv::encodeStringLiteralInto(operands, op.getFn());
626 // Add the interface values.
627 if (auto interface = op.getInterface()) {
628 for (auto var : interface.getValue()) {
629 auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
630 if (!id) {
631 return op.emitError(
632 "referencing undefined global variable."
633 "spirv.EntryPoint is at the end of spirv.module. All "
634 "referenced variables should already be defined");
636 operands.push_back(id);
639 encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands);
640 return success();
643 template <>
644 LogicalResult
645 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
646 SmallVector<uint32_t, 4> operands;
647 // Add the function <id>.
648 auto funcID = getFunctionID(op.getFn());
649 if (!funcID) {
650 return op.emitError("missing <id> for function ")
651 << op.getFn()
652 << "; function needs to be serialized before ExecutionModeOp is "
653 "serialized";
655 operands.push_back(funcID);
656 // Add the ExecutionMode.
657 operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
659 // Serialize values if any.
660 auto values = op.getValues();
661 if (values) {
662 for (auto &intVal : values.getValue()) {
663 operands.push_back(static_cast<uint32_t>(
664 llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
667 encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
668 operands);
669 return success();
672 template <>
673 LogicalResult
674 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
675 auto funcName = op.getCallee();
676 uint32_t resTypeID = 0;
678 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
679 if (failed(processType(op.getLoc(), resultTy, resTypeID)))
680 return failure();
682 auto funcID = getOrCreateFunctionID(funcName);
683 auto funcCallID = getNextID();
684 SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
686 for (auto value : op.getArguments()) {
687 auto valueID = getValueID(value);
688 assert(valueID && "cannot find a value for spirv.FunctionCall");
689 operands.push_back(valueID);
692 if (!isa<NoneType>(resultTy))
693 valueIDMap[op.getResult(0)] = funcCallID;
695 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands);
696 return success();
699 template <>
700 LogicalResult
701 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
702 SmallVector<uint32_t, 4> operands;
703 SmallVector<StringRef, 2> elidedAttrs;
705 for (Value operand : op->getOperands()) {
706 auto id = getValueID(operand);
707 assert(id && "use before def!");
708 operands.push_back(id);
711 StringAttr memoryAccess = op.getMemoryAccessAttrName();
712 if (auto attr = op->getAttr(memoryAccess)) {
713 operands.push_back(
714 static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
717 elidedAttrs.push_back(memoryAccess.strref());
719 StringAttr alignment = op.getAlignmentAttrName();
720 if (auto attr = op->getAttr(alignment)) {
721 operands.push_back(static_cast<uint32_t>(
722 cast<IntegerAttr>(attr).getValue().getZExtValue()));
725 elidedAttrs.push_back(alignment.strref());
727 StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
728 if (auto attr = op->getAttr(sourceMemoryAccess)) {
729 operands.push_back(
730 static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
733 elidedAttrs.push_back(sourceMemoryAccess.strref());
735 StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
736 if (auto attr = op->getAttr(sourceAlignment)) {
737 operands.push_back(static_cast<uint32_t>(
738 cast<IntegerAttr>(attr).getValue().getZExtValue()));
741 elidedAttrs.push_back(sourceAlignment.strref());
742 if (failed(emitDebugLine(functionBody, op.getLoc())))
743 return failure();
744 encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
746 return success();
748 template <>
749 LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
750 spirv::GenericCastToPtrExplicitOp op) {
751 SmallVector<uint32_t, 4> operands;
752 Type resultTy;
753 Location loc = op->getLoc();
754 uint32_t resultTypeID = 0;
755 uint32_t resultID = 0;
756 resultTy = op->getResult(0).getType();
757 if (failed(processType(loc, resultTy, resultTypeID)))
758 return failure();
759 operands.push_back(resultTypeID);
761 resultID = getNextID();
762 operands.push_back(resultID);
763 valueIDMap[op->getResult(0)] = resultID;
765 for (Value operand : op->getOperands())
766 operands.push_back(getValueID(operand));
767 spirv::StorageClass resultStorage =
768 cast<spirv::PointerType>(resultTy).getStorageClass();
769 operands.push_back(static_cast<uint32_t>(resultStorage));
770 encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit,
771 operands);
772 return success();
775 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
776 // various Serializer::processOp<...>() specializations.
777 #define GET_SERIALIZATION_FNS
778 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
780 } // namespace spirv
781 } // namespace mlir