1 //===- SerializationTest.cpp - SPIR-V Serialization Tests -----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file contains corner case tests for the SPIR-V serializer that are not
10 // covered by normal serialization and deserialization roundtripping.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Target/SPIRV/Serialization.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/Sequence.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "gmock/gmock.h"
//===----------------------------------------------------------------------===//
// Test Fixture
//===----------------------------------------------------------------------===//
37 class SerializationTest
: public ::testing::Test
{
40 context
.getOrLoadDialect
<mlir::spirv::SPIRVDialect
>();
44 /// Initializes an empty SPIR-V module op.
46 OpBuilder
builder(&context
);
47 OperationState
state(UnknownLoc::get(&context
),
48 spirv::ModuleOp::getOperationName());
49 state
.addAttribute("addressing_model",
50 builder
.getAttr
<spirv::AddressingModelAttr
>(
51 spirv::AddressingModel::Logical
));
52 state
.addAttribute("memory_model", builder
.getAttr
<spirv::MemoryModelAttr
>(
53 spirv::MemoryModel::GLSL450
));
54 state
.addAttribute("vce_triple",
55 spirv::VerCapExtAttr::get(
56 spirv::Version::V_1_0
, ArrayRef
<spirv::Capability
>(),
57 ArrayRef
<spirv::Extension
>(), &context
));
58 spirv::ModuleOp::build(builder
, state
);
59 module
= cast
<spirv::ModuleOp
>(Operation::create(state
));
62 /// Gets the `struct { float }` type.
63 spirv::StructType
getFloatStructType() {
64 OpBuilder
builder(module
->getRegion());
65 llvm::SmallVector
<Type
, 1> elementTypes
{builder
.getF32Type()};
66 llvm::SmallVector
<spirv::StructType::OffsetInfo
, 1> offsetInfo
{0};
67 return spirv::StructType::get(elementTypes
, offsetInfo
);
70 /// Inserts a global variable of the given `type` and `name`.
71 spirv::GlobalVariableOp
addGlobalVar(Type type
, llvm::StringRef name
) {
72 OpBuilder
builder(module
->getRegion());
73 auto ptrType
= spirv::PointerType::get(type
, spirv::StorageClass::Uniform
);
74 return builder
.create
<spirv::GlobalVariableOp
>(
75 UnknownLoc::get(&context
), TypeAttr::get(ptrType
),
76 builder
.getStringAttr(name
), nullptr);
79 // Inserts an Integer or a Vector of Integers constant of value 'val'.
80 spirv::ConstantOp
AddConstInt(Type type
, APInt val
) {
81 OpBuilder
builder(module
->getRegion());
82 auto loc
= UnknownLoc::get(&context
);
84 if (auto intType
= dyn_cast
<IntegerType
>(type
)) {
85 return builder
.create
<spirv::ConstantOp
>(
86 loc
, type
, builder
.getIntegerAttr(type
, val
));
88 if (auto vectorType
= dyn_cast
<VectorType
>(type
)) {
89 Type elemType
= vectorType
.getElementType();
90 if (auto intType
= dyn_cast
<IntegerType
>(elemType
)) {
91 return builder
.create
<spirv::ConstantOp
>(
93 DenseElementsAttr::get(vectorType
,
94 IntegerAttr::get(elemType
, val
).getValue()));
97 llvm_unreachable("unimplemented types for AddConstInt()");
100 /// Handles a SPIR-V instruction with the given `opcode` and `operand`.
101 /// Returns true to interrupt.
102 using HandleFn
= llvm::function_ref
<bool(spirv::Opcode opcode
,
103 ArrayRef
<uint32_t> operands
)>;
105 /// Returns true if we can find a matching instruction in the SPIR-V blob.
106 bool scanInstruction(HandleFn handleFn
) {
107 auto binarySize
= binary
.size();
108 auto *begin
= binary
.begin();
109 auto currOffset
= spirv::kHeaderWordCount
;
111 while (currOffset
< binarySize
) {
112 auto wordCount
= binary
[currOffset
] >> 16;
113 if (!wordCount
|| (currOffset
+ wordCount
> binarySize
))
116 spirv::Opcode opcode
=
117 static_cast<spirv::Opcode
>(binary
[currOffset
] & 0xffff);
118 llvm::ArrayRef
<uint32_t> operands(begin
+ currOffset
+ 1,
119 begin
+ currOffset
+ wordCount
);
120 if (handleFn(opcode
, operands
))
123 currOffset
+= wordCount
;
130 OwningOpRef
<spirv::ModuleOp
> module
;
131 SmallVector
<uint32_t, 0> binary
;
//===----------------------------------------------------------------------===//
// Serialization corner-case tests
//===----------------------------------------------------------------------===//
138 TEST_F(SerializationTest
, ContainsBlockDecoration
) {
139 auto structType
= getFloatStructType();
140 addGlobalVar(structType
, "var0");
142 ASSERT_TRUE(succeeded(spirv::serialize(module
.get(), binary
)));
144 auto hasBlockDecoration
= [](spirv::Opcode opcode
,
145 ArrayRef
<uint32_t> operands
) {
146 return opcode
== spirv::Opcode::OpDecorate
&& operands
.size() == 2 &&
147 operands
[1] == static_cast<uint32_t>(spirv::Decoration::Block
);
149 EXPECT_TRUE(scanInstruction(hasBlockDecoration
));
152 TEST_F(SerializationTest
, ContainsNoDuplicatedBlockDecoration
) {
153 auto structType
= getFloatStructType();
154 // Two global variables using the same type should not decorate the type with
155 // duplicated `Block` decorations.
156 addGlobalVar(structType
, "var0");
157 addGlobalVar(structType
, "var1");
159 ASSERT_TRUE(succeeded(spirv::serialize(module
.get(), binary
)));
162 auto countBlockDecoration
= [&count
](spirv::Opcode opcode
,
163 ArrayRef
<uint32_t> operands
) {
164 if (opcode
== spirv::Opcode::OpDecorate
&& operands
.size() == 2 &&
165 operands
[1] == static_cast<uint32_t>(spirv::Decoration::Block
))
169 ASSERT_FALSE(scanInstruction(countBlockDecoration
));
170 EXPECT_EQ(count
, 1u);
173 TEST_F(SerializationTest
, SignlessVsSignedIntegerConstantBitExtension
) {
175 auto signlessInt16Type
=
176 IntegerType::get(&context
, 16, IntegerType::Signless
);
177 auto signedInt16Type
= IntegerType::get(&context
, 16, IntegerType::Signed
);
178 // Check the bit extension of same value under different signedness semantics.
179 APInt
signlessIntConstVal(signlessInt16Type
.getWidth(), -1,
180 signlessInt16Type
.getSignedness());
181 APInt
signedIntConstVal(signedInt16Type
.getWidth(), -1,
182 signedInt16Type
.getSignedness());
184 AddConstInt(signlessInt16Type
, signlessIntConstVal
);
185 AddConstInt(signedInt16Type
, signedIntConstVal
);
186 ASSERT_TRUE(succeeded(spirv::serialize(module
.get(), binary
)));
188 auto hasSignlessVal
= [&](spirv::Opcode opcode
, ArrayRef
<uint32_t> operands
) {
189 return opcode
== spirv::Opcode::OpConstant
&& operands
.size() == 3 &&
190 operands
[2] == 65535;
192 EXPECT_TRUE(scanInstruction(hasSignlessVal
));
194 auto hasSignedVal
= [&](spirv::Opcode opcode
, ArrayRef
<uint32_t> operands
) {
195 return opcode
== spirv::Opcode::OpConstant
&& operands
.size() == 3 &&
196 operands
[2] == 4294967295;
198 EXPECT_TRUE(scanInstruction(hasSignedVal
));
201 TEST_F(SerializationTest
, ContainsSymbolName
) {
202 auto structType
= getFloatStructType();
203 addGlobalVar(structType
, "var0");
205 spirv::SerializationOptions options
;
206 options
.emitSymbolName
= true;
207 ASSERT_TRUE(succeeded(spirv::serialize(module
.get(), binary
, options
)));
209 auto hasVarName
= [](spirv::Opcode opcode
, ArrayRef
<uint32_t> operands
) {
210 unsigned index
= 1; // Skip the result <id>
211 return opcode
== spirv::Opcode::OpName
&&
212 spirv::decodeStringLiteral(operands
, index
) == "var0";
214 EXPECT_TRUE(scanInstruction(hasVarName
));
217 TEST_F(SerializationTest
, DoesNotContainSymbolName
) {
218 auto structType
= getFloatStructType();
219 addGlobalVar(structType
, "var0");
221 spirv::SerializationOptions options
;
222 options
.emitSymbolName
= false;
223 ASSERT_TRUE(succeeded(spirv::serialize(module
.get(), binary
, options
)));
225 auto hasVarName
= [](spirv::Opcode opcode
, ArrayRef
<uint32_t> operands
) {
226 unsigned index
= 1; // Skip the result <id>
227 return opcode
== spirv::Opcode::OpName
&&
228 spirv::decodeStringLiteral(operands
, index
) == "var0";
230 EXPECT_FALSE(scanInstruction(hasVarName
));