//===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "gtest/gtest.h"
#include <cstdint>

using namespace mlir;
using namespace mlir::detail;
23 TEST(ShapedTypeTest
, CloneMemref
) {
26 Type i32
= IntegerType::get(&context
, 32);
27 Type f32
= FloatType::getF32(&context
);
28 Attribute memSpace
= IntegerAttr::get(IntegerType::get(&context
, 64), 7);
29 Type memrefOriginalType
= i32
;
30 llvm::SmallVector
<int64_t> memrefOriginalShape({10, 20});
31 AffineMap map
= makeStridedLinearLayoutMap({2, 3}, 5, &context
);
33 ShapedType memrefType
=
34 (ShapedType
)MemRefType::Builder(memrefOriginalShape
, memrefOriginalType
)
35 .setMemorySpace(memSpace
)
36 .setLayout(AffineMapAttr::get(map
));
38 llvm::SmallVector
<int64_t> memrefNewShape({30, 40});
39 ASSERT_NE(memrefOriginalShape
, memrefNewShape
);
40 ASSERT_EQ(memrefType
.clone(memrefNewShape
),
41 (ShapedType
)MemRefType::Builder(memrefNewShape
, memrefOriginalType
)
42 .setMemorySpace(memSpace
)
43 .setLayout(AffineMapAttr::get(map
)));
45 Type memrefNewType
= f32
;
46 ASSERT_NE(memrefOriginalType
, memrefNewType
);
47 ASSERT_EQ(memrefType
.clone(memrefNewType
),
48 (MemRefType
)MemRefType::Builder(memrefOriginalShape
, memrefNewType
)
49 .setMemorySpace(memSpace
)
50 .setLayout(AffineMapAttr::get(map
)));
52 ASSERT_EQ(memrefType
.clone(memrefNewShape
, memrefNewType
),
53 (MemRefType
)MemRefType::Builder(memrefNewShape
, memrefNewType
)
54 .setMemorySpace(memSpace
)
55 .setLayout(AffineMapAttr::get(map
)));
57 // Test unranked memref cloning.
58 ShapedType unrankedTensorType
=
59 UnrankedMemRefType::get(memrefOriginalType
, memSpace
);
60 ASSERT_EQ(unrankedTensorType
.clone(memrefNewShape
),
61 (MemRefType
)MemRefType::Builder(memrefNewShape
, memrefOriginalType
)
62 .setMemorySpace(memSpace
));
63 ASSERT_EQ(unrankedTensorType
.clone(memrefNewType
),
64 UnrankedMemRefType::get(memrefNewType
, memSpace
));
65 ASSERT_EQ(unrankedTensorType
.clone(memrefNewShape
, memrefNewType
),
66 (MemRefType
)MemRefType::Builder(memrefNewShape
, memrefNewType
)
67 .setMemorySpace(memSpace
));
70 TEST(ShapedTypeTest
, CloneTensor
) {
73 Type i32
= IntegerType::get(&context
, 32);
74 Type f32
= FloatType::getF32(&context
);
76 Type tensorOriginalType
= i32
;
77 llvm::SmallVector
<int64_t> tensorOriginalShape({10, 20});
79 // Test ranked tensor cloning.
80 ShapedType tensorType
=
81 RankedTensorType::get(tensorOriginalShape
, tensorOriginalType
);
83 llvm::SmallVector
<int64_t> tensorNewShape({30, 40});
84 ASSERT_NE(tensorOriginalShape
, tensorNewShape
);
86 tensorType
.clone(tensorNewShape
),
87 (ShapedType
)RankedTensorType::get(tensorNewShape
, tensorOriginalType
));
89 Type tensorNewType
= f32
;
90 ASSERT_NE(tensorOriginalType
, tensorNewType
);
92 tensorType
.clone(tensorNewType
),
93 (ShapedType
)RankedTensorType::get(tensorOriginalShape
, tensorNewType
));
95 ASSERT_EQ(tensorType
.clone(tensorNewShape
, tensorNewType
),
96 (ShapedType
)RankedTensorType::get(tensorNewShape
, tensorNewType
));
98 // Test unranked tensor cloning.
99 ShapedType unrankedTensorType
= UnrankedTensorType::get(tensorOriginalType
);
101 unrankedTensorType
.clone(tensorNewShape
),
102 (ShapedType
)RankedTensorType::get(tensorNewShape
, tensorOriginalType
));
103 ASSERT_EQ(unrankedTensorType
.clone(tensorNewType
),
104 (ShapedType
)UnrankedTensorType::get(tensorNewType
));
106 unrankedTensorType
.clone(tensorNewShape
),
107 (ShapedType
)RankedTensorType::get(tensorNewShape
, tensorOriginalType
));
110 TEST(ShapedTypeTest
, CloneVector
) {
113 Type i32
= IntegerType::get(&context
, 32);
114 Type f32
= FloatType::getF32(&context
);
116 Type vectorOriginalType
= i32
;
117 llvm::SmallVector
<int64_t> vectorOriginalShape({10, 20});
118 ShapedType vectorType
=
119 VectorType::get(vectorOriginalShape
, vectorOriginalType
);
121 llvm::SmallVector
<int64_t> vectorNewShape({30, 40});
122 ASSERT_NE(vectorOriginalShape
, vectorNewShape
);
123 ASSERT_EQ(vectorType
.clone(vectorNewShape
),
124 VectorType::get(vectorNewShape
, vectorOriginalType
));
126 Type vectorNewType
= f32
;
127 ASSERT_NE(vectorOriginalType
, vectorNewType
);
128 ASSERT_EQ(vectorType
.clone(vectorNewType
),
129 VectorType::get(vectorOriginalShape
, vectorNewType
));
131 ASSERT_EQ(vectorType
.clone(vectorNewShape
, vectorNewType
),
132 VectorType::get(vectorNewShape
, vectorNewType
));
135 TEST(ShapedTypeTest
, VectorTypeBuilder
) {
137 Type f32
= FloatType::getF32(&context
);
139 SmallVector
<int64_t> shape
{2, 4, 8, 9, 1};
140 SmallVector
<bool> scalableDims
{true, false, true, false, false};
141 VectorType vectorType
= VectorType::get(shape
, f32
, scalableDims
);
145 VectorType dropFrontTwoDims
=
146 VectorType::Builder(vectorType
).dropDim(0).dropDim(0);
147 ASSERT_EQ(vectorType
.getElementType(), dropFrontTwoDims
.getElementType());
148 ASSERT_EQ(vectorType
.getShape().drop_front(2), dropFrontTwoDims
.getShape());
149 ASSERT_EQ(vectorType
.getScalableDims().drop_front(2),
150 dropFrontTwoDims
.getScalableDims());
155 VectorType setTwoDims
=
156 VectorType::Builder(vectorType
).setDim(0, 10).setDim(3, 12);
157 ASSERT_EQ(setTwoDims
.getShape(), ArrayRef
<int64_t>({10, 4, 8, 12, 1}));
158 ASSERT_EQ(vectorType
.getElementType(), setTwoDims
.getElementType());
159 ASSERT_EQ(vectorType
.getScalableDims(), setTwoDims
.getScalableDims());
163 // Test for bug from:
164 // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
165 // Constructs a temporary builder, modifies it, copies it to `builder`.
166 // This used to lead to a use-after-free. Running under sanitizers will
168 VectorType::Builder builder
= VectorType::Builder(vectorType
).setDim(0, 16);
169 VectorType newVectorType
= VectorType(builder
);
170 ASSERT_EQ(newVectorType
.getDimSize(0), 16);
174 // Make builder from scratch (without scalable dims) -- this use to lead to
175 // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969.
176 // Running under sanitizers will catch any issues.
177 SmallVector
<int64_t> shape
{1, 2, 3, 4};
178 VectorType::Builder
builder(shape
, f32
);
179 ASSERT_EQ(VectorType(builder
).getShape(), ArrayRef(shape
));
183 // Set vector shape (without scalable dims) -- this use to lead to
184 // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969.
185 // Running under sanitizers will catch any issues.
186 VectorType::Builder
builder(vectorType
);
187 SmallVector
<int64_t> newShape
{2, 2};
188 builder
.setShape(newShape
);
189 ASSERT_EQ(VectorType(builder
).getShape(), ArrayRef(newShape
));
193 TEST(ShapedTypeTest
, RankedTensorTypeBuilder
) {
195 Type f32
= FloatType::getF32(&context
);
197 SmallVector
<int64_t> shape
{2, 4, 8, 16, 32};
198 RankedTensorType tensorType
= RankedTensorType::get(shape
, f32
);
202 RankedTensorType dropFrontTwoDims
=
203 RankedTensorType::Builder(tensorType
).dropDim(0).dropDim(1).dropDim(0);
204 ASSERT_EQ(tensorType
.getElementType(), dropFrontTwoDims
.getElementType());
205 ASSERT_EQ(dropFrontTwoDims
.getShape(), ArrayRef
<int64_t>({16, 32}));
210 RankedTensorType insertTwoDims
=
211 RankedTensorType::Builder(tensorType
).insertDim(7, 2).insertDim(9, 3);
212 ASSERT_EQ(tensorType
.getElementType(), insertTwoDims
.getElementType());
213 ASSERT_EQ(insertTwoDims
.getShape(),
214 ArrayRef
<int64_t>({2, 4, 7, 9, 8, 16, 32}));
218 // Test for bug from:
219 // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
220 // Constructs a temporary builder, modifies it, copies it to `builder`.
221 // This used to lead to a use-after-free. Running under sanitizers will
223 RankedTensorType::Builder builder
=
224 RankedTensorType::Builder(tensorType
).dropDim(0);
225 RankedTensorType newTensorType
= RankedTensorType(builder
);
226 ASSERT_EQ(tensorType
.getShape().drop_front(), newTensorType
.getShape());
230 /// Simple wrapper class to enable "isa querying" and simple accessing of
232 class TensorWithString
: public RankedTensorType
{
234 using RankedTensorType::RankedTensorType
;
236 static TensorWithString
get(ArrayRef
<int64_t> shape
, Type elementType
,
238 return mlir::cast
<TensorWithString
>(RankedTensorType::get(
239 shape
, elementType
, StringAttr::get(elementType
.getContext(), name
)));
242 StringRef
getName() const {
243 if (Attribute enc
= getEncoding())
244 return mlir::cast
<StringAttr
>(enc
).getValue();
248 static bool classof(Type type
) {
249 if (auto rt
= mlir::dyn_cast_or_null
<RankedTensorType
>(type
))
250 return mlir::isa_and_present
<StringAttr
>(rt
.getEncoding());
255 TEST(ShapedTypeTest
, RankedTensorTypeView
) {
257 Type f32
= FloatType::getF32(&context
);
259 Type noEncodingRankedTensorType
= RankedTensorType::get({10, 20}, f32
);
261 UnitAttr unitAttr
= UnitAttr::get(&context
);
262 Type unitEncodingRankedTensorType
=
263 RankedTensorType::get({10, 20}, f32
, unitAttr
);
265 StringAttr stringAttr
= StringAttr::get(&context
, "app");
266 Type stringEncodingRankedTensorType
=
267 RankedTensorType::get({10, 20}, f32
, stringAttr
);
269 EXPECT_FALSE(mlir::isa
<TensorWithString
>(noEncodingRankedTensorType
));
270 EXPECT_FALSE(mlir::isa
<TensorWithString
>(unitEncodingRankedTensorType
));
271 ASSERT_TRUE(mlir::isa
<TensorWithString
>(stringEncodingRankedTensorType
));
273 // Cast to TensorWithString view.
274 auto view
= mlir::cast
<TensorWithString
>(stringEncodingRankedTensorType
);
275 ASSERT_TRUE(mlir::isa
<TensorWithString
>(view
));
276 EXPECT_EQ(view
.getName(), "app");
277 // Verify one could cast view type back to base type.
278 ASSERT_TRUE(mlir::isa
<RankedTensorType
>(view
));
280 Type viewCreated
= TensorWithString::get({10, 20}, f32
, "bob");
281 ASSERT_TRUE(mlir::isa
<TensorWithString
>(viewCreated
));
282 ASSERT_TRUE(mlir::isa
<RankedTensorType
>(viewCreated
));
283 view
= mlir::cast
<TensorWithString
>(viewCreated
);
284 EXPECT_EQ(view
.getName(), "bob");