1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinDialect.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/DialectResourceBlobManager.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/IR/SymbolTable.h"
19 #include "mlir/IR/Types.h"
20 #include "llvm/ADT/APSInt.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/Endian.h"
27 #define DEBUG_TYPE "builtinattributes"
30 using namespace mlir::detail
;
32 //===----------------------------------------------------------------------===//
33 /// Tablegen Attribute Definitions
34 //===----------------------------------------------------------------------===//
36 #define GET_ATTRDEF_CLASSES
37 #include "mlir/IR/BuiltinAttributes.cpp.inc"
39 //===----------------------------------------------------------------------===//
41 //===----------------------------------------------------------------------===//
43 void BuiltinDialect::registerAttributes() {
45 #define GET_ATTRDEF_LIST
46 #include "mlir/IR/BuiltinAttributes.cpp.inc"
48 addAttributes
<DistinctAttr
>();
51 //===----------------------------------------------------------------------===//
53 //===----------------------------------------------------------------------===//
55 /// Helper function that does either an in place sort or sorts from source array
56 /// into destination. If inPlace then storage is both the source and the
57 /// destination, else value is the source and storage destination. Returns
58 /// whether source was sorted.
59 template <bool inPlace
>
60 static bool dictionaryAttrSort(ArrayRef
<NamedAttribute
> value
,
61 SmallVectorImpl
<NamedAttribute
> &storage
) {
62 // Specialize for the common case.
63 switch (value
.size()) {
65 // Zero already sorted.
70 // One already sorted but may need to be copied.
72 storage
.assign({value
[0]});
75 bool isSorted
= value
[0] < value
[1];
78 std::swap(storage
[0], storage
[1]);
79 } else if (isSorted
) {
80 storage
.assign({value
[0], value
[1]});
82 storage
.assign({value
[1], value
[0]});
88 storage
.assign(value
.begin(), value
.end());
89 // Check to see they are sorted already.
90 bool isSorted
= llvm::is_sorted(value
);
91 // If not, do a general sort.
93 llvm::array_pod_sort(storage
.begin(), storage
.end());
99 /// Returns an entry with a duplicate name from the given sorted array of named
100 /// attributes. Returns std::nullopt if all elements have unique names.
101 static std::optional
<NamedAttribute
>
102 findDuplicateElement(ArrayRef
<NamedAttribute
> value
) {
103 const std::optional
<NamedAttribute
> none
{std::nullopt
};
104 if (value
.size() < 2)
107 if (value
.size() == 2)
108 return value
[0].getName() == value
[1].getName() ? value
[0] : none
;
110 const auto *it
= std::adjacent_find(value
.begin(), value
.end(),
111 [](NamedAttribute l
, NamedAttribute r
) {
112 return l
.getName() == r
.getName();
114 return it
!= value
.end() ? *it
: none
;
117 bool DictionaryAttr::sort(ArrayRef
<NamedAttribute
> value
,
118 SmallVectorImpl
<NamedAttribute
> &storage
) {
119 bool isSorted
= dictionaryAttrSort
</*inPlace=*/false>(value
, storage
);
120 assert(!findDuplicateElement(storage
) &&
121 "DictionaryAttr element names must be unique");
125 bool DictionaryAttr::sortInPlace(SmallVectorImpl
<NamedAttribute
> &array
) {
126 bool isSorted
= dictionaryAttrSort
</*inPlace=*/true>(array
, array
);
127 assert(!findDuplicateElement(array
) &&
128 "DictionaryAttr element names must be unique");
132 std::optional
<NamedAttribute
>
133 DictionaryAttr::findDuplicate(SmallVectorImpl
<NamedAttribute
> &array
,
136 dictionaryAttrSort
</*inPlace=*/true>(array
, array
);
137 return findDuplicateElement(array
);
140 DictionaryAttr
DictionaryAttr::get(MLIRContext
*context
,
141 ArrayRef
<NamedAttribute
> value
) {
143 return DictionaryAttr::getEmpty(context
);
145 // We need to sort the element list to canonicalize it.
146 SmallVector
<NamedAttribute
, 8> storage
;
147 if (dictionaryAttrSort
</*inPlace=*/false>(value
, storage
))
149 assert(!findDuplicateElement(value
) &&
150 "DictionaryAttr element names must be unique");
151 return Base::get(context
, value
);
153 /// Construct a dictionary with an array of values that is known to already be
154 /// sorted by name and uniqued.
155 DictionaryAttr
DictionaryAttr::getWithSorted(MLIRContext
*context
,
156 ArrayRef
<NamedAttribute
> value
) {
158 return DictionaryAttr::getEmpty(context
);
159 // Ensure that the attribute elements are unique and sorted.
160 assert(llvm::is_sorted(
161 value
, [](NamedAttribute l
, NamedAttribute r
) { return l
< r
; }) &&
162 "expected attribute values to be sorted");
163 assert(!findDuplicateElement(value
) &&
164 "DictionaryAttr element names must be unique");
165 return Base::get(context
, value
);
168 /// Return the specified attribute if present, null otherwise.
169 Attribute
DictionaryAttr::get(StringRef name
) const {
170 auto it
= impl::findAttrSorted(begin(), end(), name
);
171 return it
.second
? it
.first
->getValue() : Attribute();
173 Attribute
DictionaryAttr::get(StringAttr name
) const {
174 auto it
= impl::findAttrSorted(begin(), end(), name
);
175 return it
.second
? it
.first
->getValue() : Attribute();
178 /// Return the specified named attribute if present, std::nullopt otherwise.
179 std::optional
<NamedAttribute
> DictionaryAttr::getNamed(StringRef name
) const {
180 auto it
= impl::findAttrSorted(begin(), end(), name
);
181 return it
.second
? *it
.first
: std::optional
<NamedAttribute
>();
183 std::optional
<NamedAttribute
> DictionaryAttr::getNamed(StringAttr name
) const {
184 auto it
= impl::findAttrSorted(begin(), end(), name
);
185 return it
.second
? *it
.first
: std::optional
<NamedAttribute
>();
188 /// Return whether the specified attribute is present.
189 bool DictionaryAttr::contains(StringRef name
) const {
190 return impl::findAttrSorted(begin(), end(), name
).second
;
192 bool DictionaryAttr::contains(StringAttr name
) const {
193 return impl::findAttrSorted(begin(), end(), name
).second
;
196 DictionaryAttr::iterator
DictionaryAttr::begin() const {
197 return getValue().begin();
199 DictionaryAttr::iterator
DictionaryAttr::end() const {
200 return getValue().end();
202 size_t DictionaryAttr::size() const { return getValue().size(); }
204 DictionaryAttr
DictionaryAttr::getEmptyUnchecked(MLIRContext
*context
) {
205 return Base::get(context
, ArrayRef
<NamedAttribute
>());
208 //===----------------------------------------------------------------------===//
210 //===----------------------------------------------------------------------===//
212 /// Prints a strided layout attribute.
213 void StridedLayoutAttr::print(llvm::raw_ostream
&os
) const {
214 auto printIntOrQuestion
= [&](int64_t value
) {
215 if (ShapedType::isDynamic(value
))
222 llvm::interleaveComma(getStrides(), os
, printIntOrQuestion
);
225 if (getOffset() != 0) {
227 printIntOrQuestion(getOffset());
232 /// Returns true if this layout is static, i.e. the strides and offset all have
233 /// a known value > 0.
234 bool StridedLayoutAttr::hasStaticLayout() const {
235 return !ShapedType::isDynamic(getOffset()) &&
236 !ShapedType::isDynamicShape(getStrides());
239 /// Returns the strided layout as an affine map.
240 AffineMap
StridedLayoutAttr::getAffineMap() const {
241 return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext());
244 /// Checks that the type-agnostic strided layout invariants are satisfied.
246 StridedLayoutAttr::verify(function_ref
<InFlightDiagnostic()> emitError
,
247 int64_t offset
, ArrayRef
<int64_t> strides
) {
248 if (llvm::is_contained(strides
, 0))
249 return emitError() << "strides must not be zero";
254 /// Checks that the type-specific strided layout invariants are satisfied.
255 LogicalResult
StridedLayoutAttr::verifyLayout(
256 ArrayRef
<int64_t> shape
,
257 function_ref
<InFlightDiagnostic()> emitError
) const {
258 if (shape
.size() != getStrides().size())
259 return emitError() << "expected the number of strides to match the rank";
264 //===----------------------------------------------------------------------===//
266 //===----------------------------------------------------------------------===//
268 StringAttr
StringAttr::getEmptyStringAttrUnchecked(MLIRContext
*context
) {
269 return Base::get(context
, "", NoneType::get(context
));
272 /// Twine support for StringAttr.
273 StringAttr
StringAttr::get(MLIRContext
*context
, const Twine
&twine
) {
274 // Fast-path empty twine.
275 if (twine
.isTriviallyEmpty())
277 SmallVector
<char, 32> tempStr
;
278 return Base::get(context
, twine
.toStringRef(tempStr
), NoneType::get(context
));
281 /// Twine support for StringAttr.
282 StringAttr
StringAttr::get(const Twine
&twine
, Type type
) {
283 SmallVector
<char, 32> tempStr
;
284 return Base::get(type
.getContext(), twine
.toStringRef(tempStr
), type
);
287 StringRef
StringAttr::getValue() const { return getImpl()->value
; }
289 Type
StringAttr::getType() const { return getImpl()->type
; }
291 Dialect
*StringAttr::getReferencedDialect() const {
292 return getImpl()->referencedDialect
;
295 //===----------------------------------------------------------------------===//
297 //===----------------------------------------------------------------------===//
299 double FloatAttr::getValueAsDouble() const {
300 return getValueAsDouble(getValue());
302 double FloatAttr::getValueAsDouble(APFloat value
) {
303 if (&value
.getSemantics() != &APFloat::IEEEdouble()) {
304 bool losesInfo
= false;
305 value
.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven
,
308 return value
.convertToDouble();
311 LogicalResult
FloatAttr::verify(function_ref
<InFlightDiagnostic()> emitError
,
312 Type type
, APFloat value
) {
313 // Verify that the type is correct.
314 if (!llvm::isa
<FloatType
>(type
))
315 return emitError() << "expected floating point type";
317 // Verify that the type semantics match that of the value.
318 if (&llvm::cast
<FloatType
>(type
).getFloatSemantics() !=
319 &value
.getSemantics()) {
321 << "FloatAttr type doesn't match the type implied by its value";
326 //===----------------------------------------------------------------------===//
328 //===----------------------------------------------------------------------===//
330 SymbolRefAttr
SymbolRefAttr::get(MLIRContext
*ctx
, StringRef value
,
331 ArrayRef
<FlatSymbolRefAttr
> nestedRefs
) {
332 return get(StringAttr::get(ctx
, value
), nestedRefs
);
335 FlatSymbolRefAttr
SymbolRefAttr::get(MLIRContext
*ctx
, StringRef value
) {
336 return llvm::cast
<FlatSymbolRefAttr
>(get(ctx
, value
, {}));
339 FlatSymbolRefAttr
SymbolRefAttr::get(StringAttr value
) {
340 return llvm::cast
<FlatSymbolRefAttr
>(get(value
, {}));
343 FlatSymbolRefAttr
SymbolRefAttr::get(Operation
*symbol
) {
345 symbol
->getAttrOfType
<StringAttr
>(SymbolTable::getSymbolAttrName());
346 assert(symName
&& "value does not have a valid symbol name");
347 return SymbolRefAttr::get(symName
);
350 StringAttr
SymbolRefAttr::getLeafReference() const {
351 ArrayRef
<FlatSymbolRefAttr
> nestedRefs
= getNestedReferences();
352 return nestedRefs
.empty() ? getRootReference() : nestedRefs
.back().getAttr();
355 //===----------------------------------------------------------------------===//
357 //===----------------------------------------------------------------------===//
359 int64_t IntegerAttr::getInt() const {
360 assert((getType().isIndex() || getType().isSignlessInteger()) &&
361 "must be signless integer");
362 return getValue().getSExtValue();
365 int64_t IntegerAttr::getSInt() const {
366 assert(getType().isSignedInteger() && "must be signed integer");
367 return getValue().getSExtValue();
370 uint64_t IntegerAttr::getUInt() const {
371 assert(getType().isUnsignedInteger() && "must be unsigned integer");
372 return getValue().getZExtValue();
375 /// Return the value as an APSInt which carries the signed from the type of
376 /// the attribute. This traps on signless integers types!
377 APSInt
IntegerAttr::getAPSInt() const {
378 assert(!getType().isSignlessInteger() &&
379 "Signless integers don't carry a sign for APSInt");
380 return APSInt(getValue(), getType().isUnsignedInteger());
383 LogicalResult
IntegerAttr::verify(function_ref
<InFlightDiagnostic()> emitError
,
384 Type type
, APInt value
) {
385 if (IntegerType integerType
= llvm::dyn_cast
<IntegerType
>(type
)) {
386 if (integerType
.getWidth() != value
.getBitWidth())
387 return emitError() << "integer type bit width (" << integerType
.getWidth()
388 << ") doesn't match value bit width ("
389 << value
.getBitWidth() << ")";
392 if (llvm::isa
<IndexType
>(type
)) {
393 if (value
.getBitWidth() != IndexType::kInternalStorageBitWidth
)
395 << "value bit width (" << value
.getBitWidth()
396 << ") doesn't match index type internal storage bit width ("
397 << IndexType::kInternalStorageBitWidth
<< ")";
400 return emitError() << "expected integer or index type";
403 BoolAttr
IntegerAttr::getBoolAttrUnchecked(IntegerType type
, bool value
) {
404 auto attr
= Base::get(type
.getContext(), type
, APInt(/*numBits=*/1, value
));
405 return llvm::cast
<BoolAttr
>(attr
);
408 //===----------------------------------------------------------------------===//
410 //===----------------------------------------------------------------------===//
412 bool BoolAttr::getValue() const {
413 auto *storage
= reinterpret_cast<IntegerAttrStorage
*>(impl
);
414 return storage
->value
.getBoolValue();
417 bool BoolAttr::classof(Attribute attr
) {
418 IntegerAttr intAttr
= llvm::dyn_cast
<IntegerAttr
>(attr
);
419 return intAttr
&& intAttr
.getType().isSignlessInteger(1);
422 //===----------------------------------------------------------------------===//
424 //===----------------------------------------------------------------------===//
426 LogicalResult
OpaqueAttr::verify(function_ref
<InFlightDiagnostic()> emitError
,
427 StringAttr dialect
, StringRef attrData
,
429 if (!Dialect::isValidNamespace(dialect
.strref()))
430 return emitError() << "invalid dialect namespace '" << dialect
<< "'";
432 // Check that the dialect is actually registered.
433 MLIRContext
*context
= dialect
.getContext();
434 if (!context
->allowsUnregisteredDialects() &&
435 !context
->getLoadedDialect(dialect
.strref())) {
437 << "#" << dialect
<< "<\"" << attrData
<< "\"> : " << type
438 << " attribute created with unregistered dialect. If this is "
439 "intended, please call allowUnregisteredDialects() on the "
440 "MLIRContext, or use -allow-unregistered-dialect with "
441 "the MLIR opt tool used";
447 //===----------------------------------------------------------------------===//
448 // DenseElementsAttr Utilities
449 //===----------------------------------------------------------------------===//
451 const char DenseIntOrFPElementsAttrStorage::kSplatTrue
= ~0;
452 const char DenseIntOrFPElementsAttrStorage::kSplatFalse
= 0;
/// Get the bitwidth of a dense element type within the buffer.
/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
static size_t getDenseElementStorageWidth(size_t origWidth) {
  // 1-bit elements are stored bit-packed; every other width is rounded up to
  // a whole number of 8-bit bytes.
  if (origWidth == 1)
    return origWidth;
  constexpr size_t kByteBits = 8;
  return (origWidth + kByteBits - 1) / kByteBits * kByteBits;
}
459 static size_t getDenseElementStorageWidth(Type elementType
) {
460 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType
));
/// Set bit `bitPos` of the byte buffer `rawData` to `value`. Bits are numbered
/// from the least significant bit of byte 0 upward.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  if (value)
    byte |= mask;
  else
    byte &= ~mask;
}
/// Return the value of bit `bitPos` of the byte buffer `rawData`, using the
/// same numbering as setBit (LSB of byte 0 is bit 0).
static bool getBit(const char *rawData, size_t bitPos) {
  const unsigned char byte =
      static_cast<unsigned char>(rawData[bitPos / CHAR_BIT]);
  return ((byte >> (bitPos % CHAR_BIT)) & 1u) != 0;
}
476 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
478 static void copyAPIntToArrayForBEmachine(APInt value
, size_t numBytes
,
480 assert(llvm::endianness::native
== llvm::endianness::big
);
481 assert(value
.getNumWords() * APInt::APINT_WORD_SIZE
>= numBytes
);
483 // Copy the words filled with data.
484 // For example, when `value` has 2 words, the first word is filled with data.
485 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
486 size_t numFilledWords
= (value
.getNumWords() - 1) * APInt::APINT_WORD_SIZE
;
487 std::copy_n(reinterpret_cast<const char *>(value
.getRawData()),
488 numFilledWords
, result
);
489 // Convert last word of APInt to LE format and store it in char
491 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------|
492 size_t lastWordPos
= numFilledWords
;
493 SmallVector
<char, 8> valueLE(APInt::APINT_WORD_SIZE
);
494 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
495 reinterpret_cast<const char *>(value
.getRawData()) + lastWordPos
,
496 valueLE
.begin(), APInt::APINT_BITS_PER_WORD
, 1);
497 // Extract actual APInt data from `valueLE`, convert endianness to BE format,
498 // and store it in `result`.
499 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij|
500 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
501 valueLE
.begin(), result
+ lastWordPos
,
502 (numBytes
- lastWordPos
) * CHAR_BIT
, 1);
505 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
507 static void copyArrayToAPIntForBEmachine(const char *inArray
, size_t numBytes
,
509 assert(llvm::endianness::native
== llvm::endianness::big
);
510 assert(result
.getNumWords() * APInt::APINT_WORD_SIZE
>= numBytes
);
512 // Copy the data that fills the word of `result` from `inArray`.
513 // For example, when `result` has 2 words, the first word will be filled with
514 // data. So, the first 8 bytes are copied from `inArray` here.
515 // `inArray` (10 bytes, BE): |abcdefgh|ij|
516 // ==> `result` (2 words, BE): |abcdefgh|--------|
517 size_t numFilledWords
= (result
.getNumWords() - 1) * APInt::APINT_WORD_SIZE
;
519 inArray
, numFilledWords
,
520 const_cast<char *>(reinterpret_cast<const char *>(result
.getRawData())));
522 // Convert array data which will be last word of `result` to LE format, and
523 // store it in char array(`inArrayLE`).
524 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------|
525 size_t lastWordPos
= numFilledWords
;
526 SmallVector
<char, 8> inArrayLE(APInt::APINT_WORD_SIZE
);
527 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
528 inArray
+ lastWordPos
, inArrayLE
.begin(),
529 (numBytes
- lastWordPos
) * CHAR_BIT
, 1);
531 // Convert `inArrayLE` to BE format, and store it in last word of `result`.
532 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij|
533 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
535 const_cast<char *>(reinterpret_cast<const char *>(result
.getRawData())) +
537 APInt::APINT_BITS_PER_WORD
, 1);
540 /// Writes value to the bit position `bitPos` in array `rawData`.
541 static void writeBits(char *rawData
, size_t bitPos
, APInt value
) {
542 size_t bitWidth
= value
.getBitWidth();
544 // If the bitwidth is 1 we just toggle the specific bit.
546 return setBit(rawData
, bitPos
, value
.isOne());
548 // Otherwise, the bit position is guaranteed to be byte aligned.
549 assert((bitPos
% CHAR_BIT
) == 0 && "expected bitPos to be 8-bit aligned");
550 if (llvm::endianness::native
== llvm::endianness::big
) {
551 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
552 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
553 // work correctly in BE format.
554 // ex. `value` (2 words including 10 bytes)
555 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
556 copyAPIntToArrayForBEmachine(value
, llvm::divideCeil(bitWidth
, CHAR_BIT
),
557 rawData
+ (bitPos
/ CHAR_BIT
));
559 std::copy_n(reinterpret_cast<const char *>(value
.getRawData()),
560 llvm::divideCeil(bitWidth
, CHAR_BIT
),
561 rawData
+ (bitPos
/ CHAR_BIT
));
565 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
567 static APInt
readBits(const char *rawData
, size_t bitPos
, size_t bitWidth
) {
568 // Handle a boolean bit position.
570 return APInt(1, getBit(rawData
, bitPos
) ? 1 : 0);
572 // Otherwise, the bit position must be 8-bit aligned.
573 assert((bitPos
% CHAR_BIT
) == 0 && "expected bitPos to be 8-bit aligned");
574 APInt
result(bitWidth
, 0);
575 if (llvm::endianness::native
== llvm::endianness::big
) {
576 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
577 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
578 // work correctly in BE format.
579 // ex. `result` (2 words including 10 bytes)
580 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function
581 copyArrayToAPIntForBEmachine(rawData
+ (bitPos
/ CHAR_BIT
),
582 llvm::divideCeil(bitWidth
, CHAR_BIT
), result
);
584 std::copy_n(rawData
+ (bitPos
/ CHAR_BIT
),
585 llvm::divideCeil(bitWidth
, CHAR_BIT
),
587 reinterpret_cast<const char *>(result
.getRawData())));
592 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
593 /// the same element count as 'type'.
594 template <typename Values
>
595 static bool hasSameElementsOrSplat(ShapedType type
, const Values
&values
) {
596 return (values
.size() == 1) ||
597 (type
.getNumElements() == static_cast<int64_t>(values
.size()));
600 //===----------------------------------------------------------------------===//
601 // DenseElementsAttr Iterators
602 //===----------------------------------------------------------------------===//
604 //===----------------------------------------------------------------------===//
605 // AttributeElementIterator
607 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
608 DenseElementsAttr attr
, size_t index
)
609 : llvm::indexed_accessor_iterator
<AttributeElementIterator
, const void *,
610 Attribute
, Attribute
, Attribute
>(
611 attr
.getAsOpaquePointer(), index
) {}
613 Attribute
DenseElementsAttr::AttributeElementIterator::operator*() const {
614 auto owner
= llvm::cast
<DenseElementsAttr
>(getFromOpaquePointer(base
));
615 Type eltTy
= owner
.getElementType();
616 if (llvm::dyn_cast
<IntegerType
>(eltTy
))
617 return IntegerAttr::get(eltTy
, *IntElementIterator(owner
, index
));
618 if (llvm::isa
<IndexType
>(eltTy
))
619 return IntegerAttr::get(eltTy
, *IntElementIterator(owner
, index
));
620 if (auto floatEltTy
= llvm::dyn_cast
<FloatType
>(eltTy
)) {
621 IntElementIterator
intIt(owner
, index
);
622 FloatElementIterator
floatIt(floatEltTy
.getFloatSemantics(), intIt
);
623 return FloatAttr::get(eltTy
, *floatIt
);
625 if (auto complexTy
= llvm::dyn_cast
<ComplexType
>(eltTy
)) {
626 auto complexEltTy
= complexTy
.getElementType();
627 ComplexIntElementIterator
complexIntIt(owner
, index
);
628 if (llvm::isa
<IntegerType
>(complexEltTy
)) {
629 auto value
= *complexIntIt
;
630 auto real
= IntegerAttr::get(complexEltTy
, value
.real());
631 auto imag
= IntegerAttr::get(complexEltTy
, value
.imag());
632 return ArrayAttr::get(complexTy
.getContext(),
633 ArrayRef
<Attribute
>{real
, imag
});
636 ComplexFloatElementIterator
complexFloatIt(
637 llvm::cast
<FloatType
>(complexEltTy
).getFloatSemantics(), complexIntIt
);
638 auto value
= *complexFloatIt
;
639 auto real
= FloatAttr::get(complexEltTy
, value
.real());
640 auto imag
= FloatAttr::get(complexEltTy
, value
.imag());
641 return ArrayAttr::get(complexTy
.getContext(),
642 ArrayRef
<Attribute
>{real
, imag
});
644 if (llvm::isa
<DenseStringElementsAttr
>(owner
)) {
645 ArrayRef
<StringRef
> vals
= owner
.getRawStringData();
646 return StringAttr::get(owner
.isSplat() ? vals
.front() : vals
[index
], eltTy
);
648 llvm_unreachable("unexpected element type");
651 //===----------------------------------------------------------------------===//
652 // BoolElementIterator
654 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
655 DenseElementsAttr attr
, size_t dataIndex
)
656 : DenseElementIndexedIteratorImpl
<BoolElementIterator
, bool, bool, bool>(
657 attr
.getRawData().data(), attr
.isSplat(), dataIndex
) {}
659 bool DenseElementsAttr::BoolElementIterator::operator*() const {
660 return getBit(getData(), getDataIndex());
663 //===----------------------------------------------------------------------===//
664 // IntElementIterator
666 DenseElementsAttr::IntElementIterator::IntElementIterator(
667 DenseElementsAttr attr
, size_t dataIndex
)
668 : DenseElementIndexedIteratorImpl
<IntElementIterator
, APInt
, APInt
, APInt
>(
669 attr
.getRawData().data(), attr
.isSplat(), dataIndex
),
670 bitWidth(getDenseElementBitWidth(attr
.getElementType())) {}
672 APInt
DenseElementsAttr::IntElementIterator::operator*() const {
673 return readBits(getData(),
674 getDataIndex() * getDenseElementStorageWidth(bitWidth
),
678 //===----------------------------------------------------------------------===//
679 // ComplexIntElementIterator
681 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
682 DenseElementsAttr attr
, size_t dataIndex
)
683 : DenseElementIndexedIteratorImpl
<ComplexIntElementIterator
,
684 std::complex<APInt
>, std::complex<APInt
>,
685 std::complex<APInt
>>(
686 attr
.getRawData().data(), attr
.isSplat(), dataIndex
) {
687 auto complexType
= llvm::cast
<ComplexType
>(attr
.getElementType());
688 bitWidth
= getDenseElementBitWidth(complexType
.getElementType());
692 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
693 size_t storageWidth
= getDenseElementStorageWidth(bitWidth
);
694 size_t offset
= getDataIndex() * storageWidth
* 2;
695 return {readBits(getData(), offset
, bitWidth
),
696 readBits(getData(), offset
+ storageWidth
, bitWidth
)};
699 //===----------------------------------------------------------------------===//
701 //===----------------------------------------------------------------------===//
704 DenseArrayAttr::verify(function_ref
<InFlightDiagnostic()> emitError
,
705 Type elementType
, int64_t size
, ArrayRef
<char> rawData
) {
706 if (!elementType
.isIntOrIndexOrFloat())
707 return emitError() << "expected integer or floating point element type";
708 int64_t dataSize
= rawData
.size();
709 int64_t elementSize
=
710 llvm::divideCeil(elementType
.getIntOrFloatBitWidth(), CHAR_BIT
);
711 if (size
* elementSize
!= dataSize
) {
712 return emitError() << "expected data size (" << size
<< " elements, "
714 << " bytes each) does not match: " << dataSize
721 /// Instantiations of this class provide utilities for interacting with native
722 /// data types in the context of DenseArrayAttr.
723 template <size_t width
,
724 IntegerType::SignednessSemantics signedness
= IntegerType::Signless
>
725 struct DenseArrayAttrIntUtil
{
726 static bool checkElementType(Type eltType
) {
727 auto type
= llvm::dyn_cast
<IntegerType
>(eltType
);
728 if (!type
|| type
.getWidth() != width
)
730 return type
.getSignedness() == signedness
;
733 static Type
getElementType(MLIRContext
*ctx
) {
734 return IntegerType::get(ctx
, width
, signedness
);
737 template <typename T
>
738 static void printElement(raw_ostream
&os
, T value
) {
742 template <typename T
>
743 static ParseResult
parseElement(AsmParser
&parser
, T
&value
) {
744 return parser
.parseInteger(value
);
747 template <typename T
>
748 struct DenseArrayAttrUtil
;
750 /// Specialization for boolean elements to print 'true' and 'false' literals for
753 struct DenseArrayAttrUtil
<bool> : public DenseArrayAttrIntUtil
<1> {
754 static void printElement(raw_ostream
&os
, bool value
) {
755 os
<< (value
? "true" : "false");
759 /// Specialization for 8-bit integers to ensure values are printed as integers
760 /// and not characters.
762 struct DenseArrayAttrUtil
<int8_t> : public DenseArrayAttrIntUtil
<8> {
763 static void printElement(raw_ostream
&os
, int8_t value
) {
764 os
<< static_cast<int>(value
);
768 struct DenseArrayAttrUtil
<int16_t> : public DenseArrayAttrIntUtil
<16> {};
770 struct DenseArrayAttrUtil
<int32_t> : public DenseArrayAttrIntUtil
<32> {};
772 struct DenseArrayAttrUtil
<int64_t> : public DenseArrayAttrIntUtil
<64> {};
774 /// Specialization for 32-bit floats.
776 struct DenseArrayAttrUtil
<float> {
777 static bool checkElementType(Type eltType
) { return eltType
.isF32(); }
778 static Type
getElementType(MLIRContext
*ctx
) { return Float32Type::get(ctx
); }
779 static void printElement(raw_ostream
&os
, float value
) { os
<< value
; }
781 /// Parse a double and cast it to a float.
782 static ParseResult
parseElement(AsmParser
&parser
, float &value
) {
784 if (parser
.parseFloat(doubleVal
))
791 /// Specialization for 64-bit floats.
793 struct DenseArrayAttrUtil
<double> {
794 static bool checkElementType(Type eltType
) { return eltType
.isF64(); }
795 static Type
getElementType(MLIRContext
*ctx
) { return Float64Type::get(ctx
); }
796 static void printElement(raw_ostream
&os
, float value
) { os
<< value
; }
797 static ParseResult
parseElement(AsmParser
&parser
, double &value
) {
798 return parser
.parseFloat(value
);
803 template <typename T
>
804 void DenseArrayAttrImpl
<T
>::print(AsmPrinter
&printer
) const {
805 print(printer
.getStream());
808 template <typename T
>
809 void DenseArrayAttrImpl
<T
>::printWithoutBraces(raw_ostream
&os
) const {
810 llvm::interleaveComma(asArrayRef(), os
, [&](T value
) {
811 DenseArrayAttrUtil
<T
>::printElement(os
, value
);
815 template <typename T
>
816 void DenseArrayAttrImpl
<T
>::print(raw_ostream
&os
) const {
818 printWithoutBraces(os
);
822 /// Parse a DenseArrayAttr without the braces: `1, 2, 3`
823 template <typename T
>
824 Attribute DenseArrayAttrImpl
<T
>::parseWithoutBraces(AsmParser
&parser
,
827 if (failed(parser
.parseCommaSeparatedList([&]() {
829 if (DenseArrayAttrUtil
<T
>::parseElement(parser
, value
))
831 data
.push_back(value
);
835 return get(parser
.getContext(), data
);
838 /// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
839 template <typename T
>
840 Attribute DenseArrayAttrImpl
<T
>::parse(AsmParser
&parser
, Type odsType
) {
841 if (parser
.parseLSquare())
843 // Handle empty list case.
844 if (succeeded(parser
.parseOptionalRSquare()))
845 return get(parser
.getContext(), {});
846 Attribute result
= parseWithoutBraces(parser
, odsType
);
847 if (parser
.parseRSquare())
852 /// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
853 template <typename T
>
854 DenseArrayAttrImpl
<T
>::operator ArrayRef
<T
>() const {
855 ArrayRef
<char> raw
= getRawData();
856 assert((raw
.size() % sizeof(T
)) == 0);
857 return ArrayRef
<T
>(reinterpret_cast<const T
*>(raw
.data()),
858 raw
.size() / sizeof(T
));
861 /// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
862 template <typename T
>
863 DenseArrayAttrImpl
<T
> DenseArrayAttrImpl
<T
>::get(MLIRContext
*context
,
864 ArrayRef
<T
> content
) {
865 Type elementType
= DenseArrayAttrUtil
<T
>::getElementType(context
);
866 auto rawArray
= ArrayRef
<char>(reinterpret_cast<const char *>(content
.data()),
867 content
.size() * sizeof(T
));
868 return llvm::cast
<DenseArrayAttrImpl
<T
>>(
869 Base::get(context
, elementType
, content
.size(), rawArray
));
872 template <typename T
>
873 bool DenseArrayAttrImpl
<T
>::classof(Attribute attr
) {
874 if (auto denseArray
= llvm::dyn_cast
<DenseArrayAttr
>(attr
))
875 return DenseArrayAttrUtil
<T
>::checkElementType(denseArray
.getElementType());
881 // Explicit instantiation for all the supported DenseArrayAttr.
882 template class DenseArrayAttrImpl
<bool>;
883 template class DenseArrayAttrImpl
<int8_t>;
884 template class DenseArrayAttrImpl
<int16_t>;
885 template class DenseArrayAttrImpl
<int32_t>;
886 template class DenseArrayAttrImpl
<int64_t>;
887 template class DenseArrayAttrImpl
<float>;
888 template class DenseArrayAttrImpl
<double>;
889 } // namespace detail
892 //===----------------------------------------------------------------------===//
894 //===----------------------------------------------------------------------===//
896 /// Method for support type inquiry through isa, cast and dyn_cast.
897 bool DenseElementsAttr::classof(Attribute attr
) {
898 return llvm::isa
<DenseIntOrFPElementsAttr
, DenseStringElementsAttr
>(attr
);
901 DenseElementsAttr
DenseElementsAttr::get(ShapedType type
,
902 ArrayRef
<Attribute
> values
) {
903 assert(hasSameElementsOrSplat(type
, values
));
905 Type eltType
= type
.getElementType();
907 // Take care complex type case first.
908 if (auto complexType
= llvm::dyn_cast
<ComplexType
>(eltType
)) {
909 if (complexType
.getElementType().isIntOrIndex()) {
910 SmallVector
<std::complex<APInt
>> complexValues
;
911 complexValues
.reserve(values
.size());
912 for (Attribute attr
: values
) {
913 assert(llvm::isa
<ArrayAttr
>(attr
) && "expected ArrayAttr for complex");
914 auto arrayAttr
= llvm::cast
<ArrayAttr
>(attr
);
915 assert(arrayAttr
.size() == 2 && "expected 2 element for complex");
916 auto attr0
= arrayAttr
[0];
917 auto attr1
= arrayAttr
[1];
918 complexValues
.push_back(
919 std::complex<APInt
>(llvm::cast
<IntegerAttr
>(attr0
).getValue(),
920 llvm::cast
<IntegerAttr
>(attr1
).getValue()));
922 return DenseElementsAttr::get(type
, complexValues
);
925 SmallVector
<std::complex<APFloat
>> complexValues
;
926 complexValues
.reserve(values
.size());
927 for (Attribute attr
: values
) {
928 assert(llvm::isa
<ArrayAttr
>(attr
) && "expected ArrayAttr for complex");
929 auto arrayAttr
= llvm::cast
<ArrayAttr
>(attr
);
930 assert(arrayAttr
.size() == 2 && "expected 2 element for complex");
931 auto attr0
= arrayAttr
[0];
932 auto attr1
= arrayAttr
[1];
933 complexValues
.push_back(
934 std::complex<APFloat
>(llvm::cast
<FloatAttr
>(attr0
).getValue(),
935 llvm::cast
<FloatAttr
>(attr1
).getValue()));
937 return DenseElementsAttr::get(type
, complexValues
);
940 // If the element type is not based on int/float/index, assume it is a string
942 if (!eltType
.isIntOrIndexOrFloat()) {
943 SmallVector
<StringRef
, 8> stringValues
;
944 stringValues
.reserve(values
.size());
945 for (Attribute attr
: values
) {
946 assert(llvm::isa
<StringAttr
>(attr
) &&
947 "expected string value for non integer/index/float element");
948 stringValues
.push_back(llvm::cast
<StringAttr
>(attr
).getValue());
950 return get(type
, stringValues
);
953 // Otherwise, get the raw storage width to use for the allocation.
954 size_t bitWidth
= getDenseElementBitWidth(eltType
);
955 size_t storageBitWidth
= getDenseElementStorageWidth(bitWidth
);
957 // Compress the attribute values into a character buffer.
958 SmallVector
<char, 8> data(
959 llvm::divideCeil(storageBitWidth
* values
.size(), CHAR_BIT
));
961 for (unsigned i
= 0, e
= values
.size(); i
< e
; ++i
) {
962 if (auto floatAttr
= llvm::dyn_cast
<FloatAttr
>(values
[i
])) {
963 assert(floatAttr
.getType() == eltType
&&
964 "expected float attribute type to equal element type");
965 intVal
= floatAttr
.getValue().bitcastToAPInt();
967 auto intAttr
= llvm::cast
<IntegerAttr
>(values
[i
]);
968 assert(intAttr
.getType() == eltType
&&
969 "expected integer attribute type to equal element type");
970 intVal
= intAttr
.getValue();
973 assert(intVal
.getBitWidth() == bitWidth
&&
974 "expected value to have same bitwidth as element type");
975 writeBits(data
.data(), i
* storageBitWidth
, intVal
);
978 // Handle the special encoding of splat of bool.
979 if (values
.size() == 1 && eltType
.isInteger(1))
980 data
[0] = data
[0] ? -1 : 0;
982 return DenseIntOrFPElementsAttr::getRaw(type
, data
);
985 DenseElementsAttr
DenseElementsAttr::get(ShapedType type
,
986 ArrayRef
<bool> values
) {
987 assert(hasSameElementsOrSplat(type
, values
));
988 assert(type
.getElementType().isInteger(1));
990 std::vector
<char> buff(llvm::divideCeil(values
.size(), CHAR_BIT
));
992 if (!values
.empty()) {
994 bool firstValue
= values
[0];
995 for (int i
= 0, e
= values
.size(); i
!= e
; ++i
) {
996 isSplat
&= values
[i
] == firstValue
;
997 setBit(buff
.data(), i
, values
[i
]);
1000 // Splat of bool is encoded as a byte with all-ones in it.
1003 buff
[0] = values
[0] ? -1 : 0;
1007 return DenseIntOrFPElementsAttr::getRaw(type
, buff
);
1010 DenseElementsAttr
DenseElementsAttr::get(ShapedType type
,
1011 ArrayRef
<StringRef
> values
) {
1012 assert(!type
.getElementType().isIntOrFloat());
1013 return DenseStringElementsAttr::get(type
, values
);
1016 /// Constructs a dense integer elements attribute from an array of APInt
1017 /// values. Each APInt value is expected to have the same bitwidth as the
1018 /// element type of 'type'.
1019 DenseElementsAttr
DenseElementsAttr::get(ShapedType type
,
1020 ArrayRef
<APInt
> values
) {
1021 assert(type
.getElementType().isIntOrIndex());
1022 assert(hasSameElementsOrSplat(type
, values
));
1023 size_t storageBitWidth
= getDenseElementStorageWidth(type
.getElementType());
1024 return DenseIntOrFPElementsAttr::getRaw(type
, storageBitWidth
, values
);
1026 DenseElementsAttr
DenseElementsAttr::get(ShapedType type
,
1027 ArrayRef
<std::complex<APInt
>> values
) {
1028 ComplexType
complex = llvm::cast
<ComplexType
>(type
.getElementType());
1029 assert(llvm::isa
<IntegerType
>(complex.getElementType()));
1030 assert(hasSameElementsOrSplat(type
, values
));
1031 size_t storageBitWidth
= getDenseElementStorageWidth(complex) / 2;
1032 ArrayRef
<APInt
> intVals(reinterpret_cast<const APInt
*>(values
.data()),
1034 return DenseIntOrFPElementsAttr::getRaw(type
, storageBitWidth
, intVals
);
1037 // Constructs a dense float elements attribute from an array of APFloat
1038 // values. Each APFloat value is expected to have the same bitwidth as the
1039 // element type of 'type'.
1040 DenseElementsAttr
DenseElementsAttr::get(ShapedType type
,
1041 ArrayRef
<APFloat
> values
) {
1042 assert(llvm::isa
<FloatType
>(type
.getElementType()));
1043 assert(hasSameElementsOrSplat(type
, values
));
1044 size_t storageBitWidth
= getDenseElementStorageWidth(type
.getElementType());
1045 return DenseIntOrFPElementsAttr::getRaw(type
, storageBitWidth
, values
);
1048 DenseElementsAttr::get(ShapedType type
,
1049 ArrayRef
<std::complex<APFloat
>> values
) {
1050 ComplexType
complex = llvm::cast
<ComplexType
>(type
.getElementType());
1051 assert(llvm::isa
<FloatType
>(complex.getElementType()));
1052 assert(hasSameElementsOrSplat(type
, values
));
1053 ArrayRef
<APFloat
> apVals(reinterpret_cast<const APFloat
*>(values
.data()),
1055 size_t storageBitWidth
= getDenseElementStorageWidth(complex) / 2;
1056 return DenseIntOrFPElementsAttr::getRaw(type
, storageBitWidth
, apVals
);
1059 /// Construct a dense elements attribute from a raw buffer representing the
1060 /// data for this attribute. Users should generally not use this methods as
1061 /// the expected buffer format may not be a form the user expects.
1063 DenseElementsAttr::getFromRawBuffer(ShapedType type
, ArrayRef
<char> rawBuffer
) {
1064 return DenseIntOrFPElementsAttr::getRaw(type
, rawBuffer
);
1067 /// Returns true if the given buffer is a valid raw buffer for the given type.
1068 bool DenseElementsAttr::isValidRawBuffer(ShapedType type
,
1069 ArrayRef
<char> rawBuffer
,
1070 bool &detectedSplat
) {
1071 size_t storageWidth
= getDenseElementStorageWidth(type
.getElementType());
1072 size_t rawBufferWidth
= rawBuffer
.size() * CHAR_BIT
;
1073 int64_t numElements
= type
.getNumElements();
1075 // The initializer is always a splat if the result type has a single element.
1076 detectedSplat
= numElements
== 1;
1078 // Storage width of 1 is special as it is packed by the bit.
1079 if (storageWidth
== 1) {
1080 // Check for a splat, or a buffer equal to the number of elements which
1081 // consists of either all 0's or all 1's.
1082 if (rawBuffer
.size() == 1) {
1083 auto rawByte
= static_cast<uint8_t>(rawBuffer
[0]);
1084 if (rawByte
== 0 || rawByte
== 0xff) {
1085 detectedSplat
= true;
1090 // This is a valid non-splat buffer if it has the right size.
1091 return rawBufferWidth
== llvm::alignTo
<8>(numElements
);
1094 // All other types are 8-bit aligned, so we can just check the buffer width
1095 // to know if only a single initializer element was passed in.
1096 if (rawBufferWidth
== storageWidth
) {
1097 detectedSplat
= true;
1101 // The raw buffer is valid if it has the right size.
1102 return rawBufferWidth
== storageWidth
* numElements
;
1105 /// Check the information for a C++ data type, check if this type is valid for
1106 /// the current attribute. This method is used to verify specific type
1107 /// invariants that the templatized 'getValues' method cannot.
1108 static bool isValidIntOrFloat(Type type
, int64_t dataEltSize
, bool isInt
,
1110 // Make sure that the data element size is the same as the type element width.
1111 auto denseEltBitWidth
= getDenseElementBitWidth(type
);
1112 auto dataSize
= static_cast<size_t>(dataEltSize
* CHAR_BIT
);
1113 if (denseEltBitWidth
!= dataSize
) {
1114 LLVM_DEBUG(llvm::dbgs() << "expected dense element bit width "
1115 << denseEltBitWidth
<< " to match data size "
1116 << dataSize
<< " for type " << type
<< "\n");
1120 // Check that the element type is either float or integer or index.
1122 bool valid
= llvm::isa
<FloatType
>(type
);
1124 LLVM_DEBUG(llvm::dbgs()
1125 << "expected float type when isInt is false, but found "
1132 auto intType
= llvm::dyn_cast
<IntegerType
>(type
);
1134 LLVM_DEBUG(llvm::dbgs()
1135 << "expected integer type when isInt is true, but found " << type
1140 // Make sure signedness semantics is consistent.
1141 if (intType
.isSignless())
1144 bool valid
= intType
.isSigned() == isSigned
;
1146 LLVM_DEBUG(llvm::dbgs() << "expected signedness " << isSigned
1147 << " to match type " << type
<< "\n");
1151 /// Defaults down the subclass implementation.
1152 DenseElementsAttr
DenseElementsAttr::getRawComplex(ShapedType type
,
1153 ArrayRef
<char> data
,
1154 int64_t dataEltSize
,
1155 bool isInt
, bool isSigned
) {
1156 return DenseIntOrFPElementsAttr::getRawComplex(type
, data
, dataEltSize
, isInt
,
1159 DenseElementsAttr
DenseElementsAttr::getRawIntOrFloat(ShapedType type
,
1160 ArrayRef
<char> data
,
1161 int64_t dataEltSize
,
1164 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type
, data
, dataEltSize
,
1168 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize
, bool isInt
,
1169 bool isSigned
) const {
1170 return ::isValidIntOrFloat(getElementType(), dataEltSize
, isInt
, isSigned
);
1172 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize
, bool isInt
,
1173 bool isSigned
) const {
1174 return ::isValidIntOrFloat(
1175 llvm::cast
<ComplexType
>(getElementType()).getElementType(),
1176 dataEltSize
/ 2, isInt
, isSigned
);
1179 /// Returns true if this attribute corresponds to a splat, i.e. if all element
1180 /// values are the same.
1181 bool DenseElementsAttr::isSplat() const {
1182 return static_cast<DenseElementsAttributeStorage
*>(impl
)->isSplat
;
1185 /// Return if the given complex type has an integer element type.
1186 static bool isComplexOfIntType(Type type
) {
1187 return llvm::isa
<IntegerType
>(llvm::cast
<ComplexType
>(type
).getElementType());
1190 auto DenseElementsAttr::tryGetComplexIntValues() const
1191 -> FailureOr
<iterator_range_impl
<ComplexIntElementIterator
>> {
1192 if (!isComplexOfIntType(getElementType()))
1194 return iterator_range_impl
<ComplexIntElementIterator
>(
1195 getType(), ComplexIntElementIterator(*this, 0),
1196 ComplexIntElementIterator(*this, getNumElements()));
1199 auto DenseElementsAttr::tryGetFloatValues() const
1200 -> FailureOr
<iterator_range_impl
<FloatElementIterator
>> {
1201 auto eltTy
= llvm::dyn_cast
<FloatType
>(getElementType());
1204 const auto &elementSemantics
= eltTy
.getFloatSemantics();
1205 return iterator_range_impl
<FloatElementIterator
>(
1206 getType(), FloatElementIterator(elementSemantics
, raw_int_begin()),
1207 FloatElementIterator(elementSemantics
, raw_int_end()));
1210 auto DenseElementsAttr::tryGetComplexFloatValues() const
1211 -> FailureOr
<iterator_range_impl
<ComplexFloatElementIterator
>> {
1212 auto complexTy
= llvm::dyn_cast
<ComplexType
>(getElementType());
1215 auto eltTy
= llvm::dyn_cast
<FloatType
>(complexTy
.getElementType());
1218 const auto &semantics
= eltTy
.getFloatSemantics();
1219 return iterator_range_impl
<ComplexFloatElementIterator
>(
1220 getType(), {semantics
, {*this, 0}},
1221 {semantics
, {*this, static_cast<size_t>(getNumElements())}});
1224 /// Return the raw storage data held by this attribute.
1225 ArrayRef
<char> DenseElementsAttr::getRawData() const {
1226 return static_cast<DenseIntOrFPElementsAttrStorage
*>(impl
)->data
;
1229 ArrayRef
<StringRef
> DenseElementsAttr::getRawStringData() const {
1230 return static_cast<DenseStringElementsAttrStorage
*>(impl
)->data
;
1233 /// Return a new DenseElementsAttr that has the same data as the current
1234 /// attribute, but has been reshaped to 'newType'. The new type must have the
1235 /// same total number of elements as well as element type.
1236 DenseElementsAttr
DenseElementsAttr::reshape(ShapedType newType
) {
1237 ShapedType curType
= getType();
1238 if (curType
== newType
)
1241 assert(newType
.getElementType() == curType
.getElementType() &&
1242 "expected the same element type");
1243 assert(newType
.getNumElements() == curType
.getNumElements() &&
1244 "expected the same number of elements");
1245 return DenseIntOrFPElementsAttr::getRaw(newType
, getRawData());
1248 DenseElementsAttr
DenseElementsAttr::resizeSplat(ShapedType newType
) {
1249 assert(isSplat() && "expected a splat type");
1251 ShapedType curType
= getType();
1252 if (curType
== newType
)
1255 assert(newType
.getElementType() == curType
.getElementType() &&
1256 "expected the same element type");
1257 return DenseIntOrFPElementsAttr::getRaw(newType
, getRawData());
1260 /// Return a new DenseElementsAttr that has the same data as the current
1261 /// attribute, but has bitcast elements such that it is now 'newType'. The new
1262 /// type must have the same shape and element types of the same bitwidth as the
1264 DenseElementsAttr
DenseElementsAttr::bitcast(Type newElType
) {
1265 ShapedType curType
= getType();
1266 Type curElType
= curType
.getElementType();
1267 if (curElType
== newElType
)
1270 assert(getDenseElementBitWidth(newElType
) ==
1271 getDenseElementBitWidth(curElType
) &&
1272 "expected element types with the same bitwidth");
1273 return DenseIntOrFPElementsAttr::getRaw(curType
.clone(newElType
),
1278 DenseElementsAttr::mapValues(Type newElementType
,
1279 function_ref
<APInt(const APInt
&)> mapping
) const {
1280 return llvm::cast
<DenseIntElementsAttr
>(*this).mapValues(newElementType
,
1284 DenseElementsAttr
DenseElementsAttr::mapValues(
1285 Type newElementType
, function_ref
<APInt(const APFloat
&)> mapping
) const {
1286 return llvm::cast
<DenseFPElementsAttr
>(*this).mapValues(newElementType
,
1290 ShapedType
DenseElementsAttr::getType() const {
1291 return static_cast<const DenseElementsAttributeStorage
*>(impl
)->type
;
1294 Type
DenseElementsAttr::getElementType() const {
1295 return getType().getElementType();
1298 int64_t DenseElementsAttr::getNumElements() const {
1299 return getType().getNumElements();
1302 //===----------------------------------------------------------------------===//
1303 // DenseIntOrFPElementsAttr
1304 //===----------------------------------------------------------------------===//
1306 /// Utility method to write a range of APInt values to a buffer.
1307 template <typename APRangeT
>
1308 static void writeAPIntsToBuffer(size_t storageWidth
, std::vector
<char> &data
,
1309 APRangeT
&&values
) {
1310 size_t numValues
= llvm::size(values
);
1311 data
.resize(llvm::divideCeil(storageWidth
* numValues
, CHAR_BIT
));
1313 for (auto it
= values
.begin(), e
= values
.end(); it
!= e
;
1314 ++it
, offset
+= storageWidth
) {
1315 assert((*it
).getBitWidth() <= storageWidth
);
1316 writeBits(data
.data(), offset
, *it
);
1319 // Handle the special encoding of splat of a boolean.
1320 if (numValues
== 1 && (*values
.begin()).getBitWidth() == 1)
1321 data
[0] = data
[0] ? -1 : 0;
1324 /// Constructs a dense elements attribute from an array of raw APFloat values.
1325 /// Each APFloat value is expected to have the same bitwidth as the element
1326 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1327 DenseElementsAttr
DenseIntOrFPElementsAttr::getRaw(ShapedType type
,
1328 size_t storageWidth
,
1329 ArrayRef
<APFloat
> values
) {
1330 std::vector
<char> data
;
1331 auto unwrapFloat
= [](const APFloat
&val
) { return val
.bitcastToAPInt(); };
1332 writeAPIntsToBuffer(storageWidth
, data
, llvm::map_range(values
, unwrapFloat
));
1333 return DenseIntOrFPElementsAttr::getRaw(type
, data
);
1336 /// Constructs a dense elements attribute from an array of raw APInt values.
1337 /// Each APInt value is expected to have the same bitwidth as the element type
1339 DenseElementsAttr
DenseIntOrFPElementsAttr::getRaw(ShapedType type
,
1340 size_t storageWidth
,
1341 ArrayRef
<APInt
> values
) {
1342 std::vector
<char> data
;
1343 writeAPIntsToBuffer(storageWidth
, data
, values
);
1344 return DenseIntOrFPElementsAttr::getRaw(type
, data
);
1347 DenseElementsAttr
DenseIntOrFPElementsAttr::getRaw(ShapedType type
,
1348 ArrayRef
<char> data
) {
1349 assert(type
.hasStaticShape() && "type must have static shape");
1350 bool isSplat
= false;
1351 bool isValid
= isValidRawBuffer(type
, data
, isSplat
);
1354 return Base::get(type
.getContext(), type
, data
, isSplat
);
1357 /// Overload of the raw 'get' method that asserts that the given type is of
1358 /// complex type. This method is used to verify type invariants that the
1359 /// templatized 'get' method cannot.
1360 DenseElementsAttr
DenseIntOrFPElementsAttr::getRawComplex(ShapedType type
,
1361 ArrayRef
<char> data
,
1362 int64_t dataEltSize
,
1365 assert(::isValidIntOrFloat(
1366 llvm::cast
<ComplexType
>(type
.getElementType()).getElementType(),
1367 dataEltSize
/ 2, isInt
, isSigned
) &&
1368 "Try re-running with -debug-only=builtinattributes");
1370 int64_t numElements
= data
.size() / dataEltSize
;
1372 assert(numElements
== 1 || numElements
== type
.getNumElements());
1373 return getRaw(type
, data
);
1376 /// Overload of the 'getRaw' method that asserts that the given type is of
1377 /// integer type. This method is used to verify type invariants that the
1378 /// templatized 'get' method cannot.
1380 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type
, ArrayRef
<char> data
,
1381 int64_t dataEltSize
, bool isInt
,
1383 assert(::isValidIntOrFloat(type
.getElementType(), dataEltSize
, isInt
,
1385 "Try re-running with -debug-only=builtinattributes");
1387 int64_t numElements
= data
.size() / dataEltSize
;
1388 assert(numElements
== 1 || numElements
== type
.getNumElements());
1390 return getRaw(type
, data
);
1393 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1394 const char *inRawData
, char *outRawData
, size_t elementBitWidth
,
1395 size_t numElements
) {
1396 using llvm::support::ulittle16_t
;
1397 using llvm::support::ulittle32_t
;
1398 using llvm::support::ulittle64_t
;
1400 assert(llvm::endianness::native
== llvm::endianness::big
);
1401 // NOLINT to avoid warning message about replacing by static_assert()
1403 // Following std::copy_n always converts endianness on BE machine.
1404 switch (elementBitWidth
) {
1406 const ulittle16_t
*inRawDataPos
=
1407 reinterpret_cast<const ulittle16_t
*>(inRawData
);
1408 uint16_t *outDataPos
= reinterpret_cast<uint16_t *>(outRawData
);
1409 std::copy_n(inRawDataPos
, numElements
, outDataPos
);
1413 const ulittle32_t
*inRawDataPos
=
1414 reinterpret_cast<const ulittle32_t
*>(inRawData
);
1415 uint32_t *outDataPos
= reinterpret_cast<uint32_t *>(outRawData
);
1416 std::copy_n(inRawDataPos
, numElements
, outDataPos
);
1420 const ulittle64_t
*inRawDataPos
=
1421 reinterpret_cast<const ulittle64_t
*>(inRawData
);
1422 uint64_t *outDataPos
= reinterpret_cast<uint64_t *>(outRawData
);
1423 std::copy_n(inRawDataPos
, numElements
, outDataPos
);
1427 size_t nBytes
= elementBitWidth
/ CHAR_BIT
;
1428 for (size_t i
= 0; i
< nBytes
; i
++)
1429 std::copy_n(inRawData
+ (nBytes
- 1 - i
), 1, outRawData
+ i
);
1435 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1436 ArrayRef
<char> inRawData
, MutableArrayRef
<char> outRawData
,
1438 size_t numElements
= type
.getNumElements();
1439 Type elementType
= type
.getElementType();
1440 if (ComplexType complexTy
= llvm::dyn_cast
<ComplexType
>(elementType
)) {
1441 elementType
= complexTy
.getElementType();
1442 numElements
= numElements
* 2;
1444 size_t elementBitWidth
= getDenseElementStorageWidth(elementType
);
1445 assert(numElements
* elementBitWidth
== inRawData
.size() * CHAR_BIT
&&
1446 inRawData
.size() <= outRawData
.size());
1447 if (elementBitWidth
<= CHAR_BIT
)
1448 std::memcpy(outRawData
.begin(), inRawData
.begin(), inRawData
.size());
1450 convertEndianOfCharForBEmachine(inRawData
.begin(), outRawData
.begin(),
1451 elementBitWidth
, numElements
);
1454 //===----------------------------------------------------------------------===//
1455 // DenseFPElementsAttr
1456 //===----------------------------------------------------------------------===//
1458 template <typename Fn
, typename Attr
>
1459 static ShapedType
mappingHelper(Fn mapping
, Attr
&attr
, ShapedType inType
,
1460 Type newElementType
,
1461 llvm::SmallVectorImpl
<char> &data
) {
1462 size_t bitWidth
= getDenseElementBitWidth(newElementType
);
1463 size_t storageBitWidth
= getDenseElementStorageWidth(bitWidth
);
1465 ShapedType newArrayType
= inType
.cloneWith(inType
.getShape(), newElementType
);
1467 size_t numRawElements
= attr
.isSplat() ? 1 : newArrayType
.getNumElements();
1468 data
.resize(llvm::divideCeil(storageBitWidth
* numRawElements
, CHAR_BIT
));
1470 // Functor used to process a single element value of the attribute.
1471 auto processElt
= [&](decltype(*attr
.begin()) value
, size_t index
) {
1472 auto newInt
= mapping(value
);
1473 assert(newInt
.getBitWidth() == bitWidth
);
1474 writeBits(data
.data(), index
* storageBitWidth
, newInt
);
1477 // Check for the splat case.
1478 if (attr
.isSplat()) {
1479 if (bitWidth
== 1) {
1480 // Handle the special encoding of splat of bool.
1481 data
[0] = mapping(*attr
.begin()).isZero() ? 0 : -1;
1483 processElt(*attr
.begin(), /*index=*/0);
1485 return newArrayType
;
1488 // Otherwise, process all of the element values.
1489 uint64_t elementIdx
= 0;
1490 for (auto value
: attr
)
1491 processElt(value
, elementIdx
++);
1492 return newArrayType
;
1495 DenseElementsAttr
DenseFPElementsAttr::mapValues(
1496 Type newElementType
, function_ref
<APInt(const APFloat
&)> mapping
) const {
1497 llvm::SmallVector
<char, 8> elementData
;
1499 mappingHelper(mapping
, *this, getType(), newElementType
, elementData
);
1501 return getRaw(newArrayType
, elementData
);
1504 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1505 bool DenseFPElementsAttr::classof(Attribute attr
) {
1506 if (auto denseAttr
= llvm::dyn_cast
<DenseElementsAttr
>(attr
))
1507 return llvm::isa
<FloatType
>(denseAttr
.getType().getElementType());
1511 //===----------------------------------------------------------------------===//
1512 // DenseIntElementsAttr
1513 //===----------------------------------------------------------------------===//
1515 DenseElementsAttr
DenseIntElementsAttr::mapValues(
1516 Type newElementType
, function_ref
<APInt(const APInt
&)> mapping
) const {
1517 llvm::SmallVector
<char, 8> elementData
;
1519 mappingHelper(mapping
, *this, getType(), newElementType
, elementData
);
1520 return getRaw(newArrayType
, elementData
);
1523 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1524 bool DenseIntElementsAttr::classof(Attribute attr
) {
1525 if (auto denseAttr
= llvm::dyn_cast
<DenseElementsAttr
>(attr
))
1526 return denseAttr
.getType().getElementType().isIntOrIndex();
1530 //===----------------------------------------------------------------------===//
1531 // DenseResourceElementsAttr
1532 //===----------------------------------------------------------------------===//
1534 DenseResourceElementsAttr
1535 DenseResourceElementsAttr::get(ShapedType type
,
1536 DenseResourceElementsHandle handle
) {
1537 return Base::get(type
.getContext(), type
, handle
);
1540 DenseResourceElementsAttr
DenseResourceElementsAttr::get(ShapedType type
,
1542 AsmResourceBlob blob
) {
1543 // Extract the builtin dialect resource manager from context and construct a
1544 // handle by inserting a new resource using the provided blob.
1546 DenseResourceElementsHandle::getManagerInterface(type
.getContext());
1547 return get(type
, manager
.insert(blobName
, std::move(blob
)));
1550 //===----------------------------------------------------------------------===//
1551 // DenseResourceElementsAttrBase
1554 /// Instantiations of this class provide utilities for interacting with native
1555 /// data types in the context of DenseResourceElementsAttr.
1556 template <typename T
>
1557 struct DenseResourceAttrUtil
;
1558 template <size_t width
, bool isSigned
>
1559 struct DenseResourceElementsAttrIntUtil
{
1560 static bool checkElementType(Type eltType
) {
1561 IntegerType type
= llvm::dyn_cast
<IntegerType
>(eltType
);
1562 if (!type
|| type
.getWidth() != width
)
1564 return isSigned
? !type
.isUnsigned() : !type
.isSigned();
1568 struct DenseResourceAttrUtil
<bool> {
1569 static bool checkElementType(Type eltType
) {
1570 return eltType
.isSignlessInteger(1);
1574 struct DenseResourceAttrUtil
<int8_t>
1575 : public DenseResourceElementsAttrIntUtil
<8, true> {};
1577 struct DenseResourceAttrUtil
<uint8_t>
1578 : public DenseResourceElementsAttrIntUtil
<8, false> {};
1580 struct DenseResourceAttrUtil
<int16_t>
1581 : public DenseResourceElementsAttrIntUtil
<16, true> {};
1583 struct DenseResourceAttrUtil
<uint16_t>
1584 : public DenseResourceElementsAttrIntUtil
<16, false> {};
1586 struct DenseResourceAttrUtil
<int32_t>
1587 : public DenseResourceElementsAttrIntUtil
<32, true> {};
1589 struct DenseResourceAttrUtil
<uint32_t>
1590 : public DenseResourceElementsAttrIntUtil
<32, false> {};
1592 struct DenseResourceAttrUtil
<int64_t>
1593 : public DenseResourceElementsAttrIntUtil
<64, true> {};
1595 struct DenseResourceAttrUtil
<uint64_t>
1596 : public DenseResourceElementsAttrIntUtil
<64, false> {};
1598 struct DenseResourceAttrUtil
<float> {
1599 static bool checkElementType(Type eltType
) { return eltType
.isF32(); }
1602 struct DenseResourceAttrUtil
<double> {
1603 static bool checkElementType(Type eltType
) { return eltType
.isF64(); }
1607 template <typename T
>
1608 DenseResourceElementsAttrBase
<T
>
1609 DenseResourceElementsAttrBase
<T
>::get(ShapedType type
, StringRef blobName
,
1610 AsmResourceBlob blob
) {
1611 // Check that the blob is in the form we were expecting.
1612 assert(blob
.getDataAlignment() == alignof(T
) &&
1613 "alignment mismatch between expected alignment and blob alignment");
1614 assert(((blob
.getData().size() % sizeof(T
)) == 0) &&
1615 "size mismatch between expected element width and blob size");
1616 assert(DenseResourceAttrUtil
<T
>::checkElementType(type
.getElementType()) &&
1617 "invalid shape element type for provided type `T`");
1618 return llvm::cast
<DenseResourceElementsAttrBase
<T
>>(
1619 DenseResourceElementsAttr::get(type
, blobName
, std::move(blob
)));
1622 template <typename T
>
1623 std::optional
<ArrayRef
<T
>>
1624 DenseResourceElementsAttrBase
<T
>::tryGetAsArrayRef() const {
1625 if (AsmResourceBlob
*blob
= this->getRawHandle().getBlob())
1626 return blob
->template getDataAs
<T
>();
1627 return std::nullopt
;
1630 template <typename T
>
1631 bool DenseResourceElementsAttrBase
<T
>::classof(Attribute attr
) {
1632 auto resourceAttr
= llvm::dyn_cast
<DenseResourceElementsAttr
>(attr
);
1633 return resourceAttr
&& DenseResourceAttrUtil
<T
>::checkElementType(
1634 resourceAttr
.getElementType());
1639 // Explicit instantiation for all the supported DenseResourceElementsAttr.
1640 template class DenseResourceElementsAttrBase
<bool>;
1641 template class DenseResourceElementsAttrBase
<int8_t>;
1642 template class DenseResourceElementsAttrBase
<int16_t>;
1643 template class DenseResourceElementsAttrBase
<int32_t>;
1644 template class DenseResourceElementsAttrBase
<int64_t>;
1645 template class DenseResourceElementsAttrBase
<uint8_t>;
1646 template class DenseResourceElementsAttrBase
<uint16_t>;
1647 template class DenseResourceElementsAttrBase
<uint32_t>;
1648 template class DenseResourceElementsAttrBase
<uint64_t>;
1649 template class DenseResourceElementsAttrBase
<float>;
1650 template class DenseResourceElementsAttrBase
<double>;
1651 } // namespace detail
1654 //===----------------------------------------------------------------------===//
1655 // SparseElementsAttr
1656 //===----------------------------------------------------------------------===//
1658 /// Get a zero APFloat for the given sparse attribute.
1659 APFloat
SparseElementsAttr::getZeroAPFloat() const {
1660 auto eltType
= llvm::cast
<FloatType
>(getElementType());
1661 return APFloat(eltType
.getFloatSemantics());
1664 /// Get a zero APInt for the given sparse attribute.
1665 APInt
SparseElementsAttr::getZeroAPInt() const {
1666 auto eltType
= llvm::cast
<IntegerType
>(getElementType());
1667 return APInt::getZero(eltType
.getWidth());
1670 /// Get a zero attribute for the given attribute type.
1671 Attribute
SparseElementsAttr::getZeroAttr() const {
1672 auto eltType
= getElementType();
1674 // Handle floating point elements.
1675 if (llvm::isa
<FloatType
>(eltType
))
1676 return FloatAttr::get(eltType
, 0);
1678 // Handle complex elements.
1679 if (auto complexTy
= llvm::dyn_cast
<ComplexType
>(eltType
)) {
1680 auto eltType
= complexTy
.getElementType();
1682 if (llvm::isa
<FloatType
>(eltType
))
1683 zero
= FloatAttr::get(eltType
, 0);
1684 else // must be integer
1685 zero
= IntegerAttr::get(eltType
, 0);
1686 return ArrayAttr::get(complexTy
.getContext(),
1687 ArrayRef
<Attribute
>{zero
, zero
});
1690 // Handle string type.
1691 if (llvm::isa
<DenseStringElementsAttr
>(getValues()))
1692 return StringAttr::get("", eltType
);
1694 // Otherwise, this is an integer.
1695 return IntegerAttr::get(eltType
, 0);
1698 /// Flatten, and return, all of the sparse indices in this attribute in
1699 /// row-major order.
1700 std::vector
<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1701 std::vector
<ptrdiff_t> flatSparseIndices
;
1703 // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1704 // as a 1-D index array.
1705 auto sparseIndices
= getIndices();
1706 auto sparseIndexValues
= sparseIndices
.getValues
<uint64_t>();
1707 if (sparseIndices
.isSplat()) {
1708 SmallVector
<uint64_t, 8> indices(getType().getRank(),
1709 *sparseIndexValues
.begin());
1710 flatSparseIndices
.push_back(getFlattenedIndex(indices
));
1711 return flatSparseIndices
;
1714 // Otherwise, reinterpret each index as an ArrayRef when flattening.
1715 auto numSparseIndices
= sparseIndices
.getType().getDimSize(0);
1716 size_t rank
= getType().getRank();
1717 for (size_t i
= 0, e
= numSparseIndices
; i
!= e
; ++i
)
1718 flatSparseIndices
.push_back(getFlattenedIndex(
1719 {&*std::next(sparseIndexValues
.begin(), i
* rank
), rank
}));
1720 return flatSparseIndices
;
1724 SparseElementsAttr::verify(function_ref
<InFlightDiagnostic()> emitError
,
1725 ShapedType type
, DenseIntElementsAttr sparseIndices
,
1726 DenseElementsAttr values
) {
1727 ShapedType valuesType
= values
.getType();
1728 if (valuesType
.getRank() != 1)
1729 return emitError() << "expected 1-d tensor for sparse element values";
1731 // Verify the indices and values shape.
1732 ShapedType indicesType
= sparseIndices
.getType();
1733 auto emitShapeError
= [&]() {
1734 return emitError() << "expected shape ([" << type
.getShape()
1735 << "]); inferred shape of indices literal (["
1736 << indicesType
.getShape()
1737 << "]); inferred shape of values literal (["
1738 << valuesType
.getShape() << "])";
1740 // Verify indices shape.
1741 size_t rank
= type
.getRank(), indicesRank
= indicesType
.getRank();
1742 if (indicesRank
== 2) {
1743 if (indicesType
.getDimSize(1) != static_cast<int64_t>(rank
))
1744 return emitShapeError();
1745 } else if (indicesRank
!= 1 || rank
!= 1) {
1746 return emitShapeError();
1748 // Verify the values shape.
1749 int64_t numSparseIndices
= indicesType
.getDimSize(0);
1750 if (numSparseIndices
!= valuesType
.getDimSize(0))
1751 return emitShapeError();
1753 // Verify that the sparse indices are within the value shape.
1754 auto emitIndexError
= [&](unsigned indexNum
, ArrayRef
<uint64_t> index
) {
1756 << "sparse index #" << indexNum
1757 << " is not contained within the value shape, with index=[" << index
1758 << "], and type=" << type
;
1761 // Handle the case where the index values are a splat.
1762 auto sparseIndexValues
= sparseIndices
.getValues
<uint64_t>();
1763 if (sparseIndices
.isSplat()) {
1764 SmallVector
<uint64_t> indices(rank
, *sparseIndexValues
.begin());
1765 if (!ElementsAttr::isValidIndex(type
, indices
))
1766 return emitIndexError(0, indices
);
1770 // Otherwise, reinterpret each index as an ArrayRef.
1771 for (size_t i
= 0, e
= numSparseIndices
; i
!= e
; ++i
) {
1772 ArrayRef
<uint64_t> index(&*std::next(sparseIndexValues
.begin(), i
* rank
),
1774 if (!ElementsAttr::isValidIndex(type
, index
))
1775 return emitIndexError(i
, index
);
//===----------------------------------------------------------------------===//
// DistinctAttr
//===----------------------------------------------------------------------===//
1785 DistinctAttr
DistinctAttr::create(Attribute referencedAttr
) {
1786 return Base::get(referencedAttr
.getContext(), referencedAttr
);
1789 Attribute
DistinctAttr::getReferencedAttr() const {
1790 return getImpl()->referencedAttr
;
//===----------------------------------------------------------------------===//
// Attribute Utilities
//===----------------------------------------------------------------------===//
1797 AffineMap
mlir::makeStridedLinearLayoutMap(ArrayRef
<int64_t> strides
,
1799 MLIRContext
*context
) {
1801 unsigned nSymbols
= 0;
1803 // AffineExpr for offset.
1805 if (!ShapedType::isDynamic(offset
)) {
1806 auto cst
= getAffineConstantExpr(offset
, context
);
1809 // Dynamic case, new symbol for the offset.
1810 auto sym
= getAffineSymbolExpr(nSymbols
++, context
);
1814 // AffineExpr for strides.
1815 for (const auto &en
: llvm::enumerate(strides
)) {
1816 auto dim
= en
.index();
1817 auto stride
= en
.value();
1818 assert(stride
!= 0 && "Invalid stride specification");
1819 auto d
= getAffineDimExpr(dim
, context
);
1822 if (!ShapedType::isDynamic(stride
))
1823 mult
= getAffineConstantExpr(stride
, context
);
1825 // Dynamic case, new symbol for each new stride.
1826 mult
= getAffineSymbolExpr(nSymbols
++, context
);
1827 expr
= expr
+ d
* mult
;
1830 return AffineMap::get(strides
.size(), nSymbols
, expr
);