1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // The purpose of this file is to provide negative deserialization tests.
10 // For positive deserialization tests, please use serialization and
11 // deserialization for roundtripping.
13 //===----------------------------------------------------------------------===//
#include "mlir/Target/SPIRV/Deserialization.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
#include "gmock/gmock.h"

#include <cstdint>
#include <memory>
#include <string>
#include <utility>

using namespace mlir;

using ::testing::StrEq;
//===----------------------------------------------------------------------===//
// Test Fixture
//===----------------------------------------------------------------------===//
33 /// A deserialization test fixture providing minimal SPIR-V building and
34 /// diagnostic checking utilities.
35 class DeserializationTest
: public ::testing::Test
{
37 DeserializationTest() {
38 context
.getOrLoadDialect
<mlir::spirv::SPIRVDialect
>();
39 // Register a diagnostic handler to capture the diagnostic so that we can
41 context
.getDiagEngine().registerHandler([&](Diagnostic
&diag
) {
42 diagnostic
= std::make_unique
<Diagnostic
>(std::move(diag
));
46 /// Performs deserialization and returns the constructed spirv.module op.
47 OwningOpRef
<spirv::ModuleOp
> deserialize() {
48 return spirv::deserialize(binary
, &context
);
51 /// Checks there is a diagnostic generated with the given `errorMessage`.
52 void expectDiagnostic(StringRef errorMessage
) {
53 ASSERT_NE(nullptr, diagnostic
.get());
55 // TODO: check error location too.
56 EXPECT_THAT(diagnostic
->str(), StrEq(std::string(errorMessage
)));
59 //===--------------------------------------------------------------------===//
60 // SPIR-V builder methods
61 //===--------------------------------------------------------------------===//
63 /// Adds the SPIR-V module header to `binary`.
65 spirv::appendModuleHeader(binary
, spirv::Version::V_1_0
, /*idBound=*/0);
68 /// Adds the SPIR-V instruction into `binary`.
69 void addInstruction(spirv::Opcode op
, ArrayRef
<uint32_t> operands
) {
70 uint32_t wordCount
= 1 + operands
.size();
71 binary
.push_back(spirv::getPrefixedOpcode(wordCount
, op
));
72 binary
.append(operands
.begin(), operands
.end());
75 uint32_t addVoidType() {
77 addInstruction(spirv::Opcode::OpTypeVoid
, {id
});
81 uint32_t addIntType(uint32_t bitwidth
) {
83 addInstruction(spirv::Opcode::OpTypeInt
, {id
, bitwidth
, /*signedness=*/1});
87 uint32_t addStructType(ArrayRef
<uint32_t> memberTypes
) {
89 SmallVector
<uint32_t, 2> words
;
91 words
.append(memberTypes
.begin(), memberTypes
.end());
92 addInstruction(spirv::Opcode::OpTypeStruct
, words
);
96 uint32_t addFunctionType(uint32_t retType
, ArrayRef
<uint32_t> paramTypes
) {
98 SmallVector
<uint32_t, 4> operands
;
99 operands
.push_back(id
);
100 operands
.push_back(retType
);
101 operands
.append(paramTypes
.begin(), paramTypes
.end());
102 addInstruction(spirv::Opcode::OpTypeFunction
, operands
);
106 uint32_t addFunction(uint32_t retType
, uint32_t fnType
) {
108 addInstruction(spirv::Opcode::OpFunction
,
110 static_cast<uint32_t>(spirv::FunctionControl::None
),
115 void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd
, {}); }
117 void addReturn() { addInstruction(spirv::Opcode::OpReturn
, {}); }
120 SmallVector
<uint32_t, 5> binary
;
123 std::unique_ptr
<Diagnostic
> diagnostic
;
//===----------------------------------------------------------------------===//
// Basics
//===----------------------------------------------------------------------===//
130 TEST_F(DeserializationTest
, EmptyModuleFailure
) {
131 ASSERT_FALSE(deserialize());
132 expectDiagnostic("SPIR-V binary module must have a 5-word header");
135 TEST_F(DeserializationTest
, WrongMagicNumberFailure
) {
137 binary
.front() = 0xdeadbeef; // Change to a wrong magic number
138 ASSERT_FALSE(deserialize());
139 expectDiagnostic("incorrect magic number");
142 TEST_F(DeserializationTest
, OnlyHeaderSuccess
) {
144 EXPECT_TRUE(deserialize());
147 TEST_F(DeserializationTest
, ZeroWordCountFailure
) {
149 binary
.push_back(0); // OpNop with zero word count
151 ASSERT_FALSE(deserialize());
152 expectDiagnostic("word count cannot be zero");
155 TEST_F(DeserializationTest
, InsufficientWordFailure
) {
157 binary
.push_back((2u << 16) |
158 static_cast<uint32_t>(spirv::Opcode::OpTypeVoid
));
159 // Missing word for type <id>.
161 ASSERT_FALSE(deserialize());
162 expectDiagnostic("insufficient words for the last instruction");
//===----------------------------------------------------------------------===//
// Types
//===----------------------------------------------------------------------===//
169 TEST_F(DeserializationTest
, IntTypeMissingSignednessFailure
) {
171 addInstruction(spirv::Opcode::OpTypeInt
, {nextID
++, 32});
173 ASSERT_FALSE(deserialize());
174 expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
//===----------------------------------------------------------------------===//
// StructType
//===----------------------------------------------------------------------===//
181 TEST_F(DeserializationTest
, OpMemberNameSuccess
) {
183 SmallVector
<uint32_t, 5> typeDecl
;
184 std::swap(typeDecl
, binary
);
186 auto int32Type
= addIntType(32);
187 auto structType
= addStructType({int32Type
, int32Type
});
188 std::swap(typeDecl
, binary
);
190 SmallVector
<uint32_t, 5> operands1
= {structType
, 0};
191 (void)spirv::encodeStringLiteralInto(operands1
, "i1");
192 addInstruction(spirv::Opcode::OpMemberName
, operands1
);
194 SmallVector
<uint32_t, 5> operands2
= {structType
, 1};
195 (void)spirv::encodeStringLiteralInto(operands2
, "i2");
196 addInstruction(spirv::Opcode::OpMemberName
, operands2
);
198 binary
.append(typeDecl
.begin(), typeDecl
.end());
199 EXPECT_TRUE(deserialize());
202 TEST_F(DeserializationTest
, OpMemberNameMissingOperands
) {
204 SmallVector
<uint32_t, 5> typeDecl
;
205 std::swap(typeDecl
, binary
);
207 auto int32Type
= addIntType(32);
208 auto int64Type
= addIntType(64);
209 auto structType
= addStructType({int32Type
, int64Type
});
210 std::swap(typeDecl
, binary
);
212 SmallVector
<uint32_t, 5> operands1
= {structType
};
213 addInstruction(spirv::Opcode::OpMemberName
, operands1
);
215 binary
.append(typeDecl
.begin(), typeDecl
.end());
216 ASSERT_FALSE(deserialize());
217 expectDiagnostic("OpMemberName must have at least 3 operands");
220 TEST_F(DeserializationTest
, OpMemberNameExcessOperands
) {
222 SmallVector
<uint32_t, 5> typeDecl
;
223 std::swap(typeDecl
, binary
);
225 auto int32Type
= addIntType(32);
226 auto structType
= addStructType({int32Type
});
227 std::swap(typeDecl
, binary
);
229 SmallVector
<uint32_t, 5> operands
= {structType
, 0};
230 (void)spirv::encodeStringLiteralInto(operands
, "int32");
231 operands
.push_back(42);
232 addInstruction(spirv::Opcode::OpMemberName
, operands
);
234 binary
.append(typeDecl
.begin(), typeDecl
.end());
235 ASSERT_FALSE(deserialize());
236 expectDiagnostic("unexpected trailing words in OpMemberName instruction");
//===----------------------------------------------------------------------===//
// Functions
//===----------------------------------------------------------------------===//
243 TEST_F(DeserializationTest
, FunctionMissingEndFailure
) {
245 auto voidType
= addVoidType();
246 auto fnType
= addFunctionType(voidType
, {});
247 addFunction(voidType
, fnType
);
248 // Missing OpFunctionEnd.
250 ASSERT_FALSE(deserialize());
251 expectDiagnostic("expected OpFunctionEnd instruction");
254 TEST_F(DeserializationTest
, FunctionMissingParameterFailure
) {
256 auto voidType
= addVoidType();
257 auto i32Type
= addIntType(32);
258 auto fnType
= addFunctionType(voidType
, {i32Type
});
259 addFunction(voidType
, fnType
);
260 // Missing OpFunctionParameter.
262 ASSERT_FALSE(deserialize());
263 expectDiagnostic("expected OpFunctionParameter instruction");
266 TEST_F(DeserializationTest
, FunctionMissingLabelForFirstBlockFailure
) {
268 auto voidType
= addVoidType();
269 auto fnType
= addFunctionType(voidType
, {});
270 addFunction(voidType
, fnType
);
275 ASSERT_FALSE(deserialize());
276 expectDiagnostic("a basic block must start with OpLabel");
279 TEST_F(DeserializationTest
, FunctionMalformedLabelFailure
) {
281 auto voidType
= addVoidType();
282 auto fnType
= addFunctionType(voidType
, {});
283 addFunction(voidType
, fnType
);
284 addInstruction(spirv::Opcode::OpLabel
, {}); // Malformed OpLabel
288 ASSERT_FALSE(deserialize());
289 expectDiagnostic("OpLabel should only have result <id>");