1 //===- BuiltinAttributeInterfaces.cpp -------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/BuiltinAttributeInterfaces.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "llvm/ADT/Sequence.h"
15 using namespace mlir::detail
;
17 //===----------------------------------------------------------------------===//
18 /// Tablegen Interface Definitions
19 //===----------------------------------------------------------------------===//
21 #include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc"
//===----------------------------------------------------------------------===//
// ElementsAttr
//===----------------------------------------------------------------------===//
27 Type
ElementsAttr::getElementType(ElementsAttr elementsAttr
) {
28 return elementsAttr
.getShapedType().getElementType();
31 int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr
) {
32 return elementsAttr
.getShapedType().getNumElements();
35 bool ElementsAttr::isValidIndex(ShapedType type
, ArrayRef
<uint64_t> index
) {
36 // Verify that the rank of the indices matches the held type.
37 int64_t rank
= type
.getRank();
38 if (rank
== 0 && index
.size() == 1 && index
[0] == 0)
40 if (rank
!= static_cast<int64_t>(index
.size()))
43 // Verify that all of the indices are within the shape dimensions.
44 ArrayRef
<int64_t> shape
= type
.getShape();
45 return llvm::all_of(llvm::seq
<int>(0, rank
), [&](int i
) {
46 int64_t dim
= static_cast<int64_t>(index
[i
]);
47 return 0 <= dim
&& dim
< shape
[i
];
50 bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr
,
51 ArrayRef
<uint64_t> index
) {
52 return isValidIndex(elementsAttr
.getShapedType(), index
);
55 uint64_t ElementsAttr::getFlattenedIndex(Type type
, ArrayRef
<uint64_t> index
) {
56 ShapedType shapeType
= llvm::cast
<ShapedType
>(type
);
57 assert(isValidIndex(shapeType
, index
) &&
58 "expected valid multi-dimensional index");
60 // Reduce the provided multidimensional index into a flattended 1D row-major
62 auto rank
= shapeType
.getRank();
63 ArrayRef
<int64_t> shape
= shapeType
.getShape();
64 uint64_t valueIndex
= 0;
65 uint64_t dimMultiplier
= 1;
66 for (int i
= rank
- 1; i
>= 0; --i
) {
67 valueIndex
+= index
[i
] * dimMultiplier
;
68 dimMultiplier
*= shape
[i
];
73 //===----------------------------------------------------------------------===//
74 // MemRefLayoutAttrInterface
75 //===----------------------------------------------------------------------===//
77 LogicalResult
mlir::detail::verifyAffineMapAsLayout(
78 AffineMap m
, ArrayRef
<int64_t> shape
,
79 function_ref
<InFlightDiagnostic()> emitError
) {
80 if (m
.getNumDims() != shape
.size())
81 return emitError() << "memref layout mismatch between rank and affine map: "
82 << shape
.size() << " != " << m
.getNumDims();