[Clang][SME2] Fix PSEL builtin predicates (#77097)
[llvm-project.git] / mlir / unittests / Dialect / SPIRV / SerializationTest.cpp
blob56a98cc205ab43f8ecf3b5db63cc6c5b15f99699
1 //===- SerializationTest.cpp - SPIR-V Serialization Tests -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains corner case tests for the SPIR-V serializer that are not
10 // covered by normal serialization and deserialization roundtripping.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Target/SPIRV/Serialization.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/Sequence.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "gmock/gmock.h"
31 using namespace mlir;
33 //===----------------------------------------------------------------------===//
34 // Test Fixture
35 //===----------------------------------------------------------------------===//
37 class SerializationTest : public ::testing::Test {
38 protected:
39 SerializationTest() {
40 context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
41 initModuleOp();
44 /// Initializes an empty SPIR-V module op.
45 void initModuleOp() {
46 OpBuilder builder(&context);
47 OperationState state(UnknownLoc::get(&context),
48 spirv::ModuleOp::getOperationName());
49 state.addAttribute("addressing_model",
50 builder.getAttr<spirv::AddressingModelAttr>(
51 spirv::AddressingModel::Logical));
52 state.addAttribute("memory_model", builder.getAttr<spirv::MemoryModelAttr>(
53 spirv::MemoryModel::GLSL450));
54 state.addAttribute("vce_triple",
55 spirv::VerCapExtAttr::get(
56 spirv::Version::V_1_0, ArrayRef<spirv::Capability>(),
57 ArrayRef<spirv::Extension>(), &context));
58 spirv::ModuleOp::build(builder, state);
59 module = cast<spirv::ModuleOp>(Operation::create(state));
62 /// Gets the `struct { float }` type.
63 spirv::StructType getFloatStructType() {
64 OpBuilder builder(module->getRegion());
65 llvm::SmallVector<Type, 1> elementTypes{builder.getF32Type()};
66 llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
67 return spirv::StructType::get(elementTypes, offsetInfo);
70 /// Inserts a global variable of the given `type` and `name`.
71 spirv::GlobalVariableOp addGlobalVar(Type type, llvm::StringRef name) {
72 OpBuilder builder(module->getRegion());
73 auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
74 return builder.create<spirv::GlobalVariableOp>(
75 UnknownLoc::get(&context), TypeAttr::get(ptrType),
76 builder.getStringAttr(name), nullptr);
79 // Inserts an Integer or a Vector of Integers constant of value 'val'.
80 spirv::ConstantOp AddConstInt(Type type, APInt val) {
81 OpBuilder builder(module->getRegion());
82 auto loc = UnknownLoc::get(&context);
84 if (auto intType = dyn_cast<IntegerType>(type)) {
85 return builder.create<spirv::ConstantOp>(
86 loc, type, builder.getIntegerAttr(type, val));
88 if (auto vectorType = dyn_cast<VectorType>(type)) {
89 Type elemType = vectorType.getElementType();
90 if (auto intType = dyn_cast<IntegerType>(elemType)) {
91 return builder.create<spirv::ConstantOp>(
92 loc, type,
93 DenseElementsAttr::get(vectorType,
94 IntegerAttr::get(elemType, val).getValue()));
97 llvm_unreachable("unimplemented types for AddConstInt()");
100 /// Handles a SPIR-V instruction with the given `opcode` and `operand`.
101 /// Returns true to interrupt.
102 using HandleFn = llvm::function_ref<bool(spirv::Opcode opcode,
103 ArrayRef<uint32_t> operands)>;
105 /// Returns true if we can find a matching instruction in the SPIR-V blob.
106 bool scanInstruction(HandleFn handleFn) {
107 auto binarySize = binary.size();
108 auto *begin = binary.begin();
109 auto currOffset = spirv::kHeaderWordCount;
111 while (currOffset < binarySize) {
112 auto wordCount = binary[currOffset] >> 16;
113 if (!wordCount || (currOffset + wordCount > binarySize))
114 return false;
116 spirv::Opcode opcode =
117 static_cast<spirv::Opcode>(binary[currOffset] & 0xffff);
118 llvm::ArrayRef<uint32_t> operands(begin + currOffset + 1,
119 begin + currOffset + wordCount);
120 if (handleFn(opcode, operands))
121 return true;
123 currOffset += wordCount;
125 return false;
128 protected:
129 MLIRContext context;
130 OwningOpRef<spirv::ModuleOp> module;
131 SmallVector<uint32_t, 0> binary;
134 //===----------------------------------------------------------------------===//
135 // Block decoration
136 //===----------------------------------------------------------------------===//
138 TEST_F(SerializationTest, ContainsBlockDecoration) {
139 auto structType = getFloatStructType();
140 addGlobalVar(structType, "var0");
142 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
144 auto hasBlockDecoration = [](spirv::Opcode opcode,
145 ArrayRef<uint32_t> operands) {
146 return opcode == spirv::Opcode::OpDecorate && operands.size() == 2 &&
147 operands[1] == static_cast<uint32_t>(spirv::Decoration::Block);
149 EXPECT_TRUE(scanInstruction(hasBlockDecoration));
152 TEST_F(SerializationTest, ContainsNoDuplicatedBlockDecoration) {
153 auto structType = getFloatStructType();
154 // Two global variables using the same type should not decorate the type with
155 // duplicated `Block` decorations.
156 addGlobalVar(structType, "var0");
157 addGlobalVar(structType, "var1");
159 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
161 unsigned count = 0;
162 auto countBlockDecoration = [&count](spirv::Opcode opcode,
163 ArrayRef<uint32_t> operands) {
164 if (opcode == spirv::Opcode::OpDecorate && operands.size() == 2 &&
165 operands[1] == static_cast<uint32_t>(spirv::Decoration::Block))
166 ++count;
167 return false;
169 ASSERT_FALSE(scanInstruction(countBlockDecoration));
170 EXPECT_EQ(count, 1u);
173 TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) {
175 auto signlessInt16Type =
176 IntegerType::get(&context, 16, IntegerType::Signless);
177 auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed);
178 // Check the bit extension of same value under different signedness semantics.
179 APInt signlessIntConstVal(signlessInt16Type.getWidth(), -1,
180 signlessInt16Type.getSignedness());
181 APInt signedIntConstVal(signedInt16Type.getWidth(), -1,
182 signedInt16Type.getSignedness());
184 AddConstInt(signlessInt16Type, signlessIntConstVal);
185 AddConstInt(signedInt16Type, signedIntConstVal);
186 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
188 auto hasSignlessVal = [&](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
189 return opcode == spirv::Opcode::OpConstant && operands.size() == 3 &&
190 operands[2] == 65535;
192 EXPECT_TRUE(scanInstruction(hasSignlessVal));
194 auto hasSignedVal = [&](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
195 return opcode == spirv::Opcode::OpConstant && operands.size() == 3 &&
196 operands[2] == 4294967295;
198 EXPECT_TRUE(scanInstruction(hasSignedVal));
201 TEST_F(SerializationTest, ContainsSymbolName) {
202 auto structType = getFloatStructType();
203 addGlobalVar(structType, "var0");
205 spirv::SerializationOptions options;
206 options.emitSymbolName = true;
207 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options)));
209 auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
210 unsigned index = 1; // Skip the result <id>
211 return opcode == spirv::Opcode::OpName &&
212 spirv::decodeStringLiteral(operands, index) == "var0";
214 EXPECT_TRUE(scanInstruction(hasVarName));
217 TEST_F(SerializationTest, DoesNotContainSymbolName) {
218 auto structType = getFloatStructType();
219 addGlobalVar(structType, "var0");
221 spirv::SerializationOptions options;
222 options.emitSymbolName = false;
223 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options)));
225 auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
226 unsigned index = 1; // Skip the result <id>
227 return opcode == spirv::Opcode::OpName &&
228 spirv::decodeStringLiteral(operands, index) == "var0";
230 EXPECT_FALSE(scanInstruction(hasVarName));