[Clang][SME2] Fix PSEL builtin predicates (#77097)
[llvm-project.git] / mlir / unittests / Dialect / SPIRV / DeserializationTest.cpp
blob13c83c00d1523fd01a27509767952297ec4abd80
1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The purpose of this file is to provide negative deserialization tests.
10 // For positive deserialization tests, please use serialization and
11 // deserialization for roundtripping.
13 //===----------------------------------------------------------------------===//
15 #include "mlir/Target/SPIRV/Deserialization.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/MLIRContext.h"
20 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
21 #include "gmock/gmock.h"
23 #include <memory>
25 using namespace mlir;
27 using ::testing::StrEq;
29 //===----------------------------------------------------------------------===//
30 // Test Fixture
31 //===----------------------------------------------------------------------===//
33 /// A deserialization test fixture providing minimal SPIR-V building and
34 /// diagnostic checking utilities.
35 class DeserializationTest : public ::testing::Test {
36 protected:
37 DeserializationTest() {
38 context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
39 // Register a diagnostic handler to capture the diagnostic so that we can
40 // check it later.
41 context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
42 diagnostic = std::make_unique<Diagnostic>(std::move(diag));
43 });
46 /// Performs deserialization and returns the constructed spirv.module op.
47 OwningOpRef<spirv::ModuleOp> deserialize() {
48 return spirv::deserialize(binary, &context);
51 /// Checks there is a diagnostic generated with the given `errorMessage`.
52 void expectDiagnostic(StringRef errorMessage) {
53 ASSERT_NE(nullptr, diagnostic.get());
55 // TODO: check error location too.
56 EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage)));
59 //===--------------------------------------------------------------------===//
60 // SPIR-V builder methods
61 //===--------------------------------------------------------------------===//
63 /// Adds the SPIR-V module header to `binary`.
64 void addHeader() {
65 spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0);
68 /// Adds the SPIR-V instruction into `binary`.
69 void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
70 uint32_t wordCount = 1 + operands.size();
71 binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
72 binary.append(operands.begin(), operands.end());
75 uint32_t addVoidType() {
76 auto id = nextID++;
77 addInstruction(spirv::Opcode::OpTypeVoid, {id});
78 return id;
81 uint32_t addIntType(uint32_t bitwidth) {
82 auto id = nextID++;
83 addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1});
84 return id;
87 uint32_t addStructType(ArrayRef<uint32_t> memberTypes) {
88 auto id = nextID++;
89 SmallVector<uint32_t, 2> words;
90 words.push_back(id);
91 words.append(memberTypes.begin(), memberTypes.end());
92 addInstruction(spirv::Opcode::OpTypeStruct, words);
93 return id;
96 uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
97 auto id = nextID++;
98 SmallVector<uint32_t, 4> operands;
99 operands.push_back(id);
100 operands.push_back(retType);
101 operands.append(paramTypes.begin(), paramTypes.end());
102 addInstruction(spirv::Opcode::OpTypeFunction, operands);
103 return id;
106 uint32_t addFunction(uint32_t retType, uint32_t fnType) {
107 auto id = nextID++;
108 addInstruction(spirv::Opcode::OpFunction,
109 {retType, id,
110 static_cast<uint32_t>(spirv::FunctionControl::None),
111 fnType});
112 return id;
115 void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); }
117 void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); }
119 protected:
120 SmallVector<uint32_t, 5> binary;
121 uint32_t nextID = 1;
122 MLIRContext context;
123 std::unique_ptr<Diagnostic> diagnostic;
126 //===----------------------------------------------------------------------===//
127 // Basics
128 //===----------------------------------------------------------------------===//
130 TEST_F(DeserializationTest, EmptyModuleFailure) {
131 ASSERT_FALSE(deserialize());
132 expectDiagnostic("SPIR-V binary module must have a 5-word header");
135 TEST_F(DeserializationTest, WrongMagicNumberFailure) {
136 addHeader();
137 binary.front() = 0xdeadbeef; // Change to a wrong magic number
138 ASSERT_FALSE(deserialize());
139 expectDiagnostic("incorrect magic number");
142 TEST_F(DeserializationTest, OnlyHeaderSuccess) {
143 addHeader();
144 EXPECT_TRUE(deserialize());
147 TEST_F(DeserializationTest, ZeroWordCountFailure) {
148 addHeader();
149 binary.push_back(0); // OpNop with zero word count
151 ASSERT_FALSE(deserialize());
152 expectDiagnostic("word count cannot be zero");
155 TEST_F(DeserializationTest, InsufficientWordFailure) {
156 addHeader();
157 binary.push_back((2u << 16) |
158 static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
159 // Missing word for type <id>.
161 ASSERT_FALSE(deserialize());
162 expectDiagnostic("insufficient words for the last instruction");
165 //===----------------------------------------------------------------------===//
166 // Types
167 //===----------------------------------------------------------------------===//
169 TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
170 addHeader();
171 addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
173 ASSERT_FALSE(deserialize());
174 expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
177 //===----------------------------------------------------------------------===//
178 // StructType
179 //===----------------------------------------------------------------------===//
181 TEST_F(DeserializationTest, OpMemberNameSuccess) {
182 addHeader();
183 SmallVector<uint32_t, 5> typeDecl;
184 std::swap(typeDecl, binary);
186 auto int32Type = addIntType(32);
187 auto structType = addStructType({int32Type, int32Type});
188 std::swap(typeDecl, binary);
190 SmallVector<uint32_t, 5> operands1 = {structType, 0};
191 (void)spirv::encodeStringLiteralInto(operands1, "i1");
192 addInstruction(spirv::Opcode::OpMemberName, operands1);
194 SmallVector<uint32_t, 5> operands2 = {structType, 1};
195 (void)spirv::encodeStringLiteralInto(operands2, "i2");
196 addInstruction(spirv::Opcode::OpMemberName, operands2);
198 binary.append(typeDecl.begin(), typeDecl.end());
199 EXPECT_TRUE(deserialize());
202 TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
203 addHeader();
204 SmallVector<uint32_t, 5> typeDecl;
205 std::swap(typeDecl, binary);
207 auto int32Type = addIntType(32);
208 auto int64Type = addIntType(64);
209 auto structType = addStructType({int32Type, int64Type});
210 std::swap(typeDecl, binary);
212 SmallVector<uint32_t, 5> operands1 = {structType};
213 addInstruction(spirv::Opcode::OpMemberName, operands1);
215 binary.append(typeDecl.begin(), typeDecl.end());
216 ASSERT_FALSE(deserialize());
217 expectDiagnostic("OpMemberName must have at least 3 operands");
220 TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
221 addHeader();
222 SmallVector<uint32_t, 5> typeDecl;
223 std::swap(typeDecl, binary);
225 auto int32Type = addIntType(32);
226 auto structType = addStructType({int32Type});
227 std::swap(typeDecl, binary);
229 SmallVector<uint32_t, 5> operands = {structType, 0};
230 (void)spirv::encodeStringLiteralInto(operands, "int32");
231 operands.push_back(42);
232 addInstruction(spirv::Opcode::OpMemberName, operands);
234 binary.append(typeDecl.begin(), typeDecl.end());
235 ASSERT_FALSE(deserialize());
236 expectDiagnostic("unexpected trailing words in OpMemberName instruction");
239 //===----------------------------------------------------------------------===//
240 // Functions
241 //===----------------------------------------------------------------------===//
243 TEST_F(DeserializationTest, FunctionMissingEndFailure) {
244 addHeader();
245 auto voidType = addVoidType();
246 auto fnType = addFunctionType(voidType, {});
247 addFunction(voidType, fnType);
248 // Missing OpFunctionEnd.
250 ASSERT_FALSE(deserialize());
251 expectDiagnostic("expected OpFunctionEnd instruction");
254 TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
255 addHeader();
256 auto voidType = addVoidType();
257 auto i32Type = addIntType(32);
258 auto fnType = addFunctionType(voidType, {i32Type});
259 addFunction(voidType, fnType);
260 // Missing OpFunctionParameter.
262 ASSERT_FALSE(deserialize());
263 expectDiagnostic("expected OpFunctionParameter instruction");
266 TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
267 addHeader();
268 auto voidType = addVoidType();
269 auto fnType = addFunctionType(voidType, {});
270 addFunction(voidType, fnType);
271 // Missing OpLabel.
272 addReturn();
273 addFunctionEnd();
275 ASSERT_FALSE(deserialize());
276 expectDiagnostic("a basic block must start with OpLabel");
279 TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
280 addHeader();
281 auto voidType = addVoidType();
282 auto fnType = addFunctionType(voidType, {});
283 addFunction(voidType, fnType);
284 addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel
285 addReturn();
286 addFunctionEnd();
288 ASSERT_FALSE(deserialize());
289 expectDiagnostic("OpLabel should only have result <id>");