//===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "gtest/gtest.h"
#include <complex>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "../../test/lib/Dialect/Test/TestDialect.h"

using namespace mlir;
using namespace mlir::detail;
//===----------------------------------------------------------------------===//
// DenseElementsAttr
//===----------------------------------------------------------------------===//
25 template <typename EltTy
>
26 static void testSplat(Type eltType
, const EltTy
&splatElt
) {
27 RankedTensorType shape
= RankedTensorType::get({2, 1}, eltType
);
29 // Check that the generated splat is the same for 1 element and N elements.
30 DenseElementsAttr splat
= DenseElementsAttr::get(shape
, splatElt
);
31 EXPECT_TRUE(splat
.isSplat());
34 DenseElementsAttr::get(shape
, llvm::ArrayRef({splatElt
, splatElt
}));
35 EXPECT_EQ(detectedSplat
, splat
);
37 for (auto newValue
: detectedSplat
.template getValues
<EltTy
>())
38 EXPECT_TRUE(newValue
== splatElt
);
42 TEST(DenseSplatTest
, BoolSplat
) {
44 IntegerType boolTy
= IntegerType::get(&context
, 1);
45 RankedTensorType shape
= RankedTensorType::get({2, 2}, boolTy
);
47 // Check that splat is automatically detected for boolean values.
49 DenseElementsAttr trueSplat
= DenseElementsAttr::get(shape
, true);
50 EXPECT_TRUE(trueSplat
.isSplat());
52 DenseElementsAttr falseSplat
= DenseElementsAttr::get(shape
, false);
53 EXPECT_TRUE(falseSplat
.isSplat());
54 EXPECT_NE(falseSplat
, trueSplat
);
56 /// Detect and handle splat within 8 elements (bool values are bit-packed).
58 auto detectedSplat
= DenseElementsAttr::get(shape
, {true, true, true, true});
59 EXPECT_EQ(detectedSplat
, trueSplat
);
61 detectedSplat
= DenseElementsAttr::get(shape
, {false, false, false, false});
62 EXPECT_EQ(detectedSplat
, falseSplat
);
64 TEST(DenseSplatTest
, BoolSplatRawRoundtrip
) {
66 IntegerType boolTy
= IntegerType::get(&context
, 1);
67 RankedTensorType shape
= RankedTensorType::get({2, 2}, boolTy
);
69 // Check that splat booleans properly round trip via the raw API.
70 DenseElementsAttr trueSplat
= DenseElementsAttr::get(shape
, true);
71 EXPECT_TRUE(trueSplat
.isSplat());
72 DenseElementsAttr trueSplatFromRaw
=
73 DenseElementsAttr::getFromRawBuffer(shape
, trueSplat
.getRawData());
74 EXPECT_TRUE(trueSplatFromRaw
.isSplat());
76 EXPECT_EQ(trueSplat
, trueSplatFromRaw
);
79 TEST(DenseSplatTest
, BoolSplatSmall
) {
81 Builder
builder(&context
);
83 // Check that splats that don't fill entire byte are handled properly.
84 auto tensorType
= RankedTensorType::get({4}, builder
.getI1Type());
85 std::vector
<char> data
{0b00001111};
86 auto trueSplatFromRaw
=
87 DenseIntOrFPElementsAttr::getFromRawBuffer(tensorType
, data
);
88 EXPECT_TRUE(trueSplatFromRaw
.isSplat());
89 DenseElementsAttr trueSplat
= DenseElementsAttr::get(tensorType
, true);
90 EXPECT_EQ(trueSplat
, trueSplatFromRaw
);
93 TEST(DenseSplatTest
, LargeBoolSplat
) {
94 constexpr int64_t boolCount
= 56;
97 IntegerType boolTy
= IntegerType::get(&context
, 1);
98 RankedTensorType shape
= RankedTensorType::get({boolCount
}, boolTy
);
100 // Check that splat is automatically detected for boolean values.
102 DenseElementsAttr trueSplat
= DenseElementsAttr::get(shape
, true);
103 DenseElementsAttr falseSplat
= DenseElementsAttr::get(shape
, false);
104 EXPECT_TRUE(trueSplat
.isSplat());
105 EXPECT_TRUE(falseSplat
.isSplat());
107 /// Detect that the large boolean arrays are properly splatted.
109 SmallVector
<bool, 64> trueValues(boolCount
, true);
110 auto detectedSplat
= DenseElementsAttr::get(shape
, trueValues
);
111 EXPECT_EQ(detectedSplat
, trueSplat
);
113 SmallVector
<bool, 64> falseValues(boolCount
, false);
114 detectedSplat
= DenseElementsAttr::get(shape
, falseValues
);
115 EXPECT_EQ(detectedSplat
, falseSplat
);
118 TEST(DenseSplatTest
, BoolNonSplat
) {
120 IntegerType boolTy
= IntegerType::get(&context
, 1);
121 RankedTensorType shape
= RankedTensorType::get({6}, boolTy
);
123 // Check that we properly handle non-splat values.
124 DenseElementsAttr nonSplat
=
125 DenseElementsAttr::get(shape
, {false, false, true, false, false, true});
126 EXPECT_FALSE(nonSplat
.isSplat());
129 TEST(DenseSplatTest
, OddIntSplat
) {
130 // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
132 constexpr size_t intWidth
= 19;
133 IntegerType intTy
= IntegerType::get(&context
, intWidth
);
134 APInt
value(intWidth
, 10);
136 testSplat(intTy
, value
);
139 TEST(DenseSplatTest
, Int32Splat
) {
141 IntegerType intTy
= IntegerType::get(&context
, 32);
144 testSplat(intTy
, value
);
147 TEST(DenseSplatTest
, IntAttrSplat
) {
149 IntegerType intTy
= IntegerType::get(&context
, 85);
150 Attribute value
= IntegerAttr::get(intTy
, 109);
152 testSplat(intTy
, value
);
155 TEST(DenseSplatTest
, F32Splat
) {
157 FloatType floatTy
= FloatType::getF32(&context
);
160 testSplat(floatTy
, value
);
163 TEST(DenseSplatTest
, F64Splat
) {
165 FloatType floatTy
= FloatType::getF64(&context
);
168 testSplat(floatTy
, APFloat(value
));
171 TEST(DenseSplatTest
, FloatAttrSplat
) {
173 FloatType floatTy
= FloatType::getF32(&context
);
174 Attribute value
= FloatAttr::get(floatTy
, 10.0);
176 testSplat(floatTy
, value
);
179 TEST(DenseSplatTest
, BF16Splat
) {
181 FloatType floatTy
= FloatType::getBF16(&context
);
182 Attribute value
= FloatAttr::get(floatTy
, 10.0);
184 testSplat(floatTy
, value
);
187 TEST(DenseSplatTest
, StringSplat
) {
189 context
.allowUnregisteredDialects();
191 OpaqueType::get(StringAttr::get(&context
, "test"), "string");
192 StringRef value
= "test-string";
193 testSplat(stringType
, value
);
196 TEST(DenseSplatTest
, StringAttrSplat
) {
198 context
.allowUnregisteredDialects();
200 OpaqueType::get(StringAttr::get(&context
, "test"), "string");
201 Attribute stringAttr
= StringAttr::get("test-string", stringType
);
202 testSplat(stringType
, stringAttr
);
205 TEST(DenseComplexTest
, ComplexFloatSplat
) {
207 ComplexType complexType
= ComplexType::get(FloatType::getF32(&context
));
208 std::complex<float> value(10.0, 15.0);
209 testSplat(complexType
, value
);
212 TEST(DenseComplexTest
, ComplexIntSplat
) {
214 ComplexType complexType
= ComplexType::get(IntegerType::get(&context
, 64));
215 std::complex<int64_t> value(10, 15);
216 testSplat(complexType
, value
);
219 TEST(DenseComplexTest
, ComplexAPFloatSplat
) {
221 ComplexType complexType
= ComplexType::get(FloatType::getF32(&context
));
222 std::complex<APFloat
> value(APFloat(10.0f
), APFloat(15.0f
));
223 testSplat(complexType
, value
);
226 TEST(DenseComplexTest
, ComplexAPIntSplat
) {
228 ComplexType complexType
= ComplexType::get(IntegerType::get(&context
, 64));
229 std::complex<APInt
> value(APInt(64, 10), APInt(64, 15));
230 testSplat(complexType
, value
);
233 TEST(DenseScalarTest
, ExtractZeroRankElement
) {
235 const int elementValue
= 12;
236 IntegerType intTy
= IntegerType::get(&context
, 32);
237 Attribute value
= IntegerAttr::get(intTy
, elementValue
);
238 RankedTensorType shape
= RankedTensorType::get({}, intTy
);
240 auto attr
= DenseElementsAttr::get(shape
, llvm::ArrayRef({elementValue
}));
241 EXPECT_TRUE(attr
.getValues
<Attribute
>()[0] == value
);
244 TEST(DenseSplatMapValuesTest
, I32ToTrue
) {
246 const int elementValue
= 12;
247 IntegerType boolTy
= IntegerType::get(&context
, 1);
248 IntegerType intTy
= IntegerType::get(&context
, 32);
249 RankedTensorType shape
= RankedTensorType::get({4}, intTy
);
252 DenseElementsAttr::get(shape
, llvm::ArrayRef({elementValue
}))
253 .mapValues(boolTy
, [](const APInt
&x
) {
254 return x
.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
256 EXPECT_EQ(attr
.getNumElements(), 4);
257 EXPECT_TRUE(attr
.isSplat());
258 EXPECT_TRUE(attr
.getSplatValue
<BoolAttr
>().getValue());
261 TEST(DenseSplatMapValuesTest
, I32ToFalse
) {
263 const int elementValue
= 0;
264 IntegerType boolTy
= IntegerType::get(&context
, 1);
265 IntegerType intTy
= IntegerType::get(&context
, 32);
266 RankedTensorType shape
= RankedTensorType::get({4}, intTy
);
269 DenseElementsAttr::get(shape
, llvm::ArrayRef({elementValue
}))
270 .mapValues(boolTy
, [](const APInt
&x
) {
271 return x
.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
273 EXPECT_EQ(attr
.getNumElements(), 4);
274 EXPECT_TRUE(attr
.isSplat());
275 EXPECT_FALSE(attr
.getSplatValue
<BoolAttr
>().getValue());
//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr
//===----------------------------------------------------------------------===//
283 template <typename AttrT
, typename T
>
284 static void checkNativeAccess(MLIRContext
*ctx
, ArrayRef
<T
> data
,
286 auto type
= RankedTensorType::get(data
.size(), elementType
);
287 auto attr
= AttrT::get(type
, "resource",
288 UnmanagedAsmResourceBlob::allocateInferAlign(data
));
290 // Check that we can access and iterate the data properly.
291 std::optional
<ArrayRef
<T
>> attrData
= attr
.tryGetAsArrayRef();
292 EXPECT_TRUE(attrData
.has_value());
293 EXPECT_EQ(*attrData
, data
);
295 // Check that we cast to this attribute when possible.
296 Attribute genericAttr
= attr
;
297 EXPECT_TRUE(isa
<AttrT
>(genericAttr
));
299 template <typename AttrT
, typename T
>
300 static void checkNativeIntAccess(Builder
&builder
, size_t intWidth
) {
301 T data
[] = {0, 1, 2};
302 checkNativeAccess
<AttrT
, T
>(builder
.getContext(), llvm::ArrayRef(data
),
303 builder
.getIntegerType(intWidth
));
307 TEST(DenseResourceElementsAttrTest
, CheckNativeAccess
) {
309 Builder
builder(&context
);
312 bool boolData
[] = {true, false, true};
313 checkNativeAccess
<DenseBoolResourceElementsAttr
>(
314 &context
, llvm::ArrayRef(boolData
), builder
.getI1Type());
317 checkNativeIntAccess
<DenseUI8ResourceElementsAttr
, uint8_t>(builder
, 8);
318 checkNativeIntAccess
<DenseUI16ResourceElementsAttr
, uint16_t>(builder
, 16);
319 checkNativeIntAccess
<DenseUI32ResourceElementsAttr
, uint32_t>(builder
, 32);
320 checkNativeIntAccess
<DenseUI64ResourceElementsAttr
, uint64_t>(builder
, 64);
323 checkNativeIntAccess
<DenseI8ResourceElementsAttr
, int8_t>(builder
, 8);
324 checkNativeIntAccess
<DenseI16ResourceElementsAttr
, int16_t>(builder
, 16);
325 checkNativeIntAccess
<DenseI32ResourceElementsAttr
, int32_t>(builder
, 32);
326 checkNativeIntAccess
<DenseI64ResourceElementsAttr
, int64_t>(builder
, 64);
329 float floatData
[] = {0, 1, 2};
330 checkNativeAccess
<DenseF32ResourceElementsAttr
>(
331 &context
, llvm::ArrayRef(floatData
), builder
.getF32Type());
334 double doubleData
[] = {0, 1, 2};
335 checkNativeAccess
<DenseF64ResourceElementsAttr
>(
336 &context
, llvm::ArrayRef(doubleData
), builder
.getF64Type());
339 TEST(DenseResourceElementsAttrTest
, CheckNoCast
) {
341 Builder
builder(&context
);
343 // Create a i32 attribute.
344 ArrayRef
<uint32_t> data
;
345 auto type
= RankedTensorType::get(data
.size(), builder
.getI32Type());
346 Attribute i32ResourceAttr
= DenseI32ResourceElementsAttr::get(
347 type
, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data
));
349 EXPECT_TRUE(isa
<DenseI32ResourceElementsAttr
>(i32ResourceAttr
));
350 EXPECT_FALSE(isa
<DenseF32ResourceElementsAttr
>(i32ResourceAttr
));
351 EXPECT_FALSE(isa
<DenseBoolResourceElementsAttr
>(i32ResourceAttr
));
354 TEST(DenseResourceElementsAttrTest
, CheckNotMutableAllocateAndCopy
) {
356 Builder
builder(&context
);
358 // Create a i32 attribute.
359 std::vector
<int32_t> data
= {10, 20, 30};
360 auto type
= RankedTensorType::get(data
.size(), builder
.getI32Type());
361 Attribute i32ResourceAttr
= DenseI32ResourceElementsAttr::get(
363 HeapAsmResourceBlob::allocateAndCopyInferAlign
<int32_t>(
364 data
, /*is_mutable=*/false));
366 EXPECT_TRUE(isa
<DenseI32ResourceElementsAttr
>(i32ResourceAttr
));
369 TEST(DenseResourceElementsAttrTest
, CheckInvalidData
) {
371 Builder
builder(&context
);
373 // Create a bool attribute with data of the incorrect type.
374 ArrayRef
<uint32_t> data
;
375 auto type
= RankedTensorType::get(data
.size(), builder
.getI32Type());
378 DenseBoolResourceElementsAttr::get(
380 UnmanagedAsmResourceBlob::allocateInferAlign(data
));
382 "alignment mismatch between expected alignment and blob alignment");
385 TEST(DenseResourceElementsAttrTest
, CheckInvalidType
) {
387 Builder
builder(&context
);
389 // Create a bool attribute with incorrect type.
391 auto type
= RankedTensorType::get(data
.size(), builder
.getI32Type());
394 DenseBoolResourceElementsAttr::get(
396 UnmanagedAsmResourceBlob::allocateInferAlign(data
));
398 "invalid shape element type for provided type `T`");
//===----------------------------------------------------------------------===//
// SparseElementsAttr
//===----------------------------------------------------------------------===//
407 TEST(SparseElementsAttrTest
, GetZero
) {
409 context
.allowUnregisteredDialects();
411 IntegerType intTy
= IntegerType::get(&context
, 32);
412 FloatType floatTy
= FloatType::getF32(&context
);
413 Type stringTy
= OpaqueType::get(StringAttr::get(&context
, "test"), "string");
415 ShapedType tensorI32
= RankedTensorType::get({2, 2}, intTy
);
416 ShapedType tensorF32
= RankedTensorType::get({2, 2}, floatTy
);
417 ShapedType tensorString
= RankedTensorType::get({2, 2}, stringTy
);
420 RankedTensorType::get({1, 2}, IntegerType::get(&context
, 64));
422 DenseIntElementsAttr::get(indicesType
, {APInt(64, 0), APInt(64, 0)});
424 RankedTensorType intValueTy
= RankedTensorType::get({1}, intTy
);
425 auto intValue
= DenseIntElementsAttr::get(intValueTy
, {1});
427 RankedTensorType floatValueTy
= RankedTensorType::get({1}, floatTy
);
428 auto floatValue
= DenseFPElementsAttr::get(floatValueTy
, {1.0f
});
430 RankedTensorType stringValueTy
= RankedTensorType::get({1}, stringTy
);
431 auto stringValue
= DenseElementsAttr::get(stringValueTy
, {StringRef("foo")});
433 auto sparseInt
= SparseElementsAttr::get(tensorI32
, indices
, intValue
);
434 auto sparseFloat
= SparseElementsAttr::get(tensorF32
, indices
, floatValue
);
436 SparseElementsAttr::get(tensorString
, indices
, stringValue
);
438 // Only index (0, 0) contains an element, others are supposed to return
439 // the zero/empty value.
441 cast
<IntegerAttr
>(sparseInt
.getValues
<Attribute
>()[{1, 1}]);
442 EXPECT_EQ(zeroIntValue
.getInt(), 0);
443 EXPECT_TRUE(zeroIntValue
.getType() == intTy
);
445 auto zeroFloatValue
=
446 cast
<FloatAttr
>(sparseFloat
.getValues
<Attribute
>()[{1, 1}]);
447 EXPECT_EQ(zeroFloatValue
.getValueAsDouble(), 0.0f
);
448 EXPECT_TRUE(zeroFloatValue
.getType() == floatTy
);
450 auto zeroStringValue
=
451 cast
<StringAttr
>(sparseString
.getValues
<Attribute
>()[{1, 1}]);
452 EXPECT_TRUE(zeroStringValue
.empty());
453 EXPECT_TRUE(zeroStringValue
.getType() == stringTy
);
//===----------------------------------------------------------------------===//
// SubElements
//===----------------------------------------------------------------------===//
460 TEST(SubElementTest
, Nested
) {
462 Builder
builder(&context
);
464 BoolAttr trueAttr
= builder
.getBoolAttr(true);
465 BoolAttr falseAttr
= builder
.getBoolAttr(false);
466 ArrayAttr boolArrayAttr
=
467 builder
.getArrayAttr({trueAttr
, falseAttr
, trueAttr
});
468 StringAttr strAttr
= builder
.getStringAttr("array");
469 DictionaryAttr dictAttr
=
470 builder
.getDictionaryAttr(builder
.getNamedAttr(strAttr
, boolArrayAttr
));
472 SmallVector
<Attribute
> subAttrs
;
473 dictAttr
.walk([&](Attribute attr
) { subAttrs
.push_back(attr
); });
474 // Note that trueAttr appears only once, identical subattributes are skipped.
475 EXPECT_EQ(llvm::ArrayRef(subAttrs
),
477 {strAttr
, trueAttr
, falseAttr
, boolArrayAttr
, dictAttr
}));
480 // Test how many times we call copy-ctor when building an attribute.
481 TEST(CopyCountAttr
, CopyCount
) {
483 context
.loadDialect
<test::TestDialect
>();
485 test::CopyCount::counter
= 0;
486 test::CopyCount
copyCount("hello");
487 test::TestCopyCountAttr::get(&context
, std::move(copyCount
));
488 int counter1
= test::CopyCount::counter
;
489 test::CopyCount::counter
= 0;
490 test::TestCopyCountAttr::get(&context
, std::move(copyCount
));
492 // One verification enabled only in assert-mode requires a copy.
493 EXPECT_EQ(counter1
, 1);
494 EXPECT_EQ(test::CopyCount::counter
, 1);
496 EXPECT_EQ(counter1
, 0);
497 EXPECT_EQ(test::CopyCount::counter
, 0);
501 // Test stripped printing using test dialect attribute.
502 TEST(CopyCountAttr
, PrintStripped
) {
504 context
.loadDialect
<test::TestDialect
>();
505 // Doesn't matter which dialect attribute is used, just chose TestCopyCount
507 test::CopyCount::counter
= 0;
508 test::CopyCount
copyCount("hello");
509 Attribute res
= test::TestCopyCountAttr::get(&context
, std::move(copyCount
));
512 llvm::raw_string_ostream
os(str
);
513 os
<< "|" << res
<< "|";
514 res
.printStripped(os
<< "[");
516 EXPECT_EQ(str
, "|#test.copy_count<hello>|[copy_count<hello>]");