1 //===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file implements a translation from SPIR-V binary module to MLIR SPIR-V
// ModuleOp, and registers the SPIR-V (de)serialization and round-trip hooks.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Dialect.h"
19 #include "mlir/IR/Verifier.h"
20 #include "mlir/Parser/Parser.h"
21 #include "mlir/Support/FileUtilities.h"
22 #include "mlir/Target/SPIRV/Deserialization.h"
23 #include "mlir/Target/SPIRV/Serialization.h"
24 #include "mlir/Tools/mlir-translate/Translation.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/Support/MemoryBuffer.h"
27 #include "llvm/Support/SMLoc.h"
28 #include "llvm/Support/SourceMgr.h"
29 #include "llvm/Support/ToolOutputFile.h"
33 //===----------------------------------------------------------------------===//
34 // Deserialization registration
35 //===----------------------------------------------------------------------===//
37 // Deserializes the SPIR-V binary module stored in the file named as
38 // `inputFilename` and returns a module containing the SPIR-V module.
39 static OwningOpRef
<Operation
*>
40 deserializeModule(const llvm::MemoryBuffer
*input
, MLIRContext
*context
) {
41 context
->loadDialect
<spirv::SPIRVDialect
>();
43 // Make sure the input stream can be treated as a stream of SPIR-V words
44 auto *start
= input
->getBufferStart();
45 auto size
= input
->getBufferSize();
46 if (size
% sizeof(uint32_t) != 0) {
47 emitError(UnknownLoc::get(context
))
48 << "SPIR-V binary module must contain integral number of 32-bit words";
52 auto binary
= llvm::ArrayRef(reinterpret_cast<const uint32_t *>(start
),
53 size
/ sizeof(uint32_t));
54 return spirv::deserialize(binary
, context
);
58 void registerFromSPIRVTranslation() {
59 TranslateToMLIRRegistration
fromBinary(
60 "deserialize-spirv", "deserializes the SPIR-V module",
61 [](llvm::SourceMgr
&sourceMgr
, MLIRContext
*context
) {
62 assert(sourceMgr
.getNumBuffers() == 1 && "expected one buffer");
63 return deserializeModule(
64 sourceMgr
.getMemoryBuffer(sourceMgr
.getMainFileID()), context
);
69 //===----------------------------------------------------------------------===//
70 // Serialization registration
71 //===----------------------------------------------------------------------===//
73 static LogicalResult
serializeModule(spirv::ModuleOp module
,
74 raw_ostream
&output
) {
75 SmallVector
<uint32_t, 0> binary
;
76 if (failed(spirv::serialize(module
, binary
)))
79 output
.write(reinterpret_cast<char *>(binary
.data()),
80 binary
.size() * sizeof(uint32_t));
82 return mlir::success();
86 void registerToSPIRVTranslation() {
87 TranslateFromMLIRRegistration
toBinary(
88 "serialize-spirv", "serialize SPIR-V dialect",
89 [](spirv::ModuleOp module
, raw_ostream
&output
) {
90 return serializeModule(module
, output
);
92 [](DialectRegistry
®istry
) {
93 registry
.insert
<spirv::SPIRVDialect
>();
98 //===----------------------------------------------------------------------===//
99 // Round-trip registration
100 //===----------------------------------------------------------------------===//
102 static LogicalResult
roundTripModule(spirv::ModuleOp module
, bool emitDebugInfo
,
103 raw_ostream
&output
) {
104 SmallVector
<uint32_t, 0> binary
;
105 MLIRContext
*context
= module
->getContext();
107 spirv::SerializationOptions options
;
108 options
.emitDebugInfo
= emitDebugInfo
;
109 if (failed(spirv::serialize(module
, binary
, options
)))
112 MLIRContext
deserializationContext(context
->getDialectRegistry());
113 // TODO: we should only load the required dialects instead of all dialects.
114 deserializationContext
.loadAllAvailableDialects();
115 // Then deserialize to get back a SPIR-V module.
116 OwningOpRef
<spirv::ModuleOp
> spirvModule
=
117 spirv::deserialize(binary
, &deserializationContext
);
120 spirvModule
->print(output
);
122 return mlir::success();
126 void registerTestRoundtripSPIRV() {
127 TranslateFromMLIRRegistration
roundtrip(
128 "test-spirv-roundtrip", "test roundtrip in SPIR-V dialect",
129 [](spirv::ModuleOp module
, raw_ostream
&output
) {
130 return roundTripModule(module
, /*emitDebugInfo=*/false, output
);
132 [](DialectRegistry
®istry
) {
133 registry
.insert
<spirv::SPIRVDialect
>();
137 void registerTestRoundtripDebugSPIRV() {
138 TranslateFromMLIRRegistration
roundtrip(
139 "test-spirv-roundtrip-debug", "test roundtrip debug in SPIR-V",
140 [](spirv::ModuleOp module
, raw_ostream
&output
) {
141 return roundTripModule(module
, /*emitDebugInfo=*/true, output
);
143 [](DialectRegistry
®istry
) {
144 registry
.insert
<spirv::SPIRVDialect
>();