//===- IndexingUtils.cpp - Helpers related to index computations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/IndexingUtils.h"

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"

#include <numeric>
#include <optional>

using namespace mlir;

template <typename ExprType>
SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
                                               ExprType unit) {
  if (sizes.empty())
    return {};
  SmallVector<ExprType> strides(sizes.size(), unit);
  for (int64_t r = strides.size() - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * sizes[r + 1];
  return strides;
}

template <typename ExprType>
SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
                                                ArrayRef<ExprType> v2) {
  // Early exit if both are empty, let zip_equal fail if only 1 is empty.
  if (v1.empty() && v2.empty())
    return {};
  SmallVector<ExprType> result;
  for (auto it : llvm::zip_equal(v1, v2))
    result.push_back(std::get<0>(it) * std::get<1>(it));
  return result;
}

template <typename ExprType>
ExprType linearizeImpl(ArrayRef<ExprType> offsets, ArrayRef<ExprType> basis,
                       ExprType zero) {
  assert(offsets.size() == basis.size());
  ExprType linearIndex = zero;
  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
    linearIndex = linearIndex + offsets[idx] * basis[idx];
  return linearIndex;
}

template <typename ExprType, typename DivOpTy>
SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
                                      ArrayRef<ExprType> strides,
                                      DivOpTy divOp) {
  int64_t rank = strides.size();
  SmallVector<ExprType> offsets(rank);
  for (int64_t r = 0; r < rank; ++r) {
    offsets[r] = divOp(linearIndex, strides[r]);
    linearIndex = linearIndex % strides[r];
  }
  return offsets;
}

//===----------------------------------------------------------------------===//
// Utils that operate on static integer values.
//===----------------------------------------------------------------------===//

SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
  assert(llvm::all_of(sizes, [](int64_t s) { return s > 0; }) &&
         "sizes must be nonnegative");
  int64_t unit = 1;
  return ::computeSuffixProductImpl(sizes, unit);
}

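// Example (illustrative): computeSuffixProduct({2, 3, 4}) returns {12, 4, 1},
// i.e. the row-major strides of a 2x3x4 shape.
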
SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
                                                 ArrayRef<int64_t> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

int64_t mlir::computeSum(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be nonnegative");
  if (basis.empty())
    return 0;
  return std::accumulate(basis.begin(), basis.end(), 0, std::plus<int64_t>());
}

int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be nonnegative");
  if (basis.empty())
    return 1;
  return std::accumulate(basis.begin(), basis.end(), 1,
                         std::multiplies<int64_t>());
}

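// Example (illustrative): for basis = {2, 3, 4}, computeSum returns 9 and
// computeProduct returns 24; an empty basis yields the identities 0 and 1.
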
int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be nonnegative");
  int64_t zero = 0;
  return linearizeImpl(offsets, basis, zero);
}

SmallVector<int64_t> mlir::delinearize(int64_t linearIndex,
                                       ArrayRef<int64_t> strides) {
  assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) &&
         "strides must be nonnegative");
  return delinearizeImpl(linearIndex, strides,
                         [](int64_t e1, int64_t e2) { return e1 / e2; });
}

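// Example (illustrative): with basis/strides = {12, 4, 1},
// linearize({1, 2, 3}, {12, 4, 1}) == 1 * 12 + 2 * 4 + 3 * 1 == 23, and
// delinearize(23, {12, 4, 1}) recovers {1, 2, 3}.
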
std::optional<SmallVector<int64_t>>
mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
  if (shape.size() < subShape.size())
    return std::nullopt;

  assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) &&
         "shape must be nonnegative");
  assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) &&
         "subShape must be nonnegative");

  // Starting from the end, compute the integer divisors.
  std::vector<int64_t> result;
  result.reserve(shape.size());
  for (auto [size, subSize] :
       llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) {
    // If integral division does not occur, return and let the caller decide.
    if (size % subSize != 0)
      return std::nullopt;
    result.push_back(size / subSize);
  }
  // At this point we computed the ratio (in reverse) for the common size.
  // Fill with the remaining entries from the shape (still in reverse).
  int commonSize = subShape.size();
  std::copy(shape.rbegin() + commonSize, shape.rend(),
            std::back_inserter(result));
  // Reverse again to get it back in the proper order and return.
  return SmallVector<int64_t>{result.rbegin(), result.rend()};
}

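// Example (illustrative): computeShapeRatio({8, 16, 32}, {4, 8}) returns
// {8, 4, 4}, while computeShapeRatio({8, 9}, {4, 4}) returns std::nullopt
// because 9 is not a multiple of 4.
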
//===----------------------------------------------------------------------===//
// Utils that operate on AffineExpr.
//===----------------------------------------------------------------------===//

SmallVector<AffineExpr> mlir::computeSuffixProduct(ArrayRef<AffineExpr> sizes) {
  if (sizes.empty())
    return {};
  AffineExpr unit = getAffineConstantExpr(1, sizes.front().getContext());
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
                                                    ArrayRef<AffineExpr> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(0, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(0, ctx),
                         std::plus<AffineExpr>());
}

AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(1, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(1, ctx),
                         std::multiplies<AffineExpr>());
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<AffineExpr> basis) {
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  return linearizeImpl(offsets, basis, zero);
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<int64_t> basis) {
  return linearize(ctx, offsets, getAffineConstantExprs(basis, ctx));
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<AffineExpr> strides) {
  return delinearizeImpl(
      linearIndex, strides,
      [](AffineExpr e1, AffineExpr e2) { return e1.floorDiv(e2); });
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<int64_t> strides) {
  MLIRContext *ctx = linearIndex.getContext();
  return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
}

//===----------------------------------------------------------------------===//
// Permutation utils.
//===----------------------------------------------------------------------===//

SmallVector<int64_t>
mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
  assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  SmallVector<int64_t> inversion(permutation.size());
  for (const auto &pos : llvm::enumerate(permutation)) {
    inversion[pos.value()] = pos.index();
  }
  return inversion;
}

bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
  assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  llvm::SmallDenseSet<int64_t, 4> seenVals;
  for (auto val : interchange) {
    if (seenVals.count(val))
      return false;
    seenVals.insert(val);
  }
  return seenVals.size() == interchange.size();
}

SmallVector<int64_t>
mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
                               ArrayRef<int64_t> desiredPositions) {
  SmallVector<int64_t> res(permSize, -1);
  DenseSet<int64_t> seen;
  for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
    res[desiredPos] = pos;
    seen.insert(pos);
  }
  // Fill the remaining slots with the smallest values not yet used.
  int64_t nextPos = 0;
  for (int64_t &entry : res) {
    if (entry != -1)
      continue;
    while (seen.contains(nextPos))
      ++nextPos;
    entry = nextPos;
    ++nextPos;
  }
  return res;
}

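// Example (illustrative): computePermutationVector(4, {1, 3}, {0, 2}) first
// places {1, -1, 3, -1}, then fills the holes with the smallest unused values,
// yielding the permutation {1, 0, 3, 2}.
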
SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                          unsigned dropFront,
                                          unsigned dropBack) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

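// Example (illustrative): for arrayAttr = [0, 1, 2, 3, 4], dropFront = 1 and
// dropBack = 2, getI64SubArray returns {1, 2}.
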
// TODO: do we have any common utility for this?
static MLIRContext *getContext(OpFoldResult val) {
  assert(val && "Invalid value");
  if (auto attr = dyn_cast<Attribute>(val)) {
    return attr.getContext();
  }
  return cast<Value>(val).getContext();
}

std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset,
                         ArrayRef<OpFoldResult> strides,
                         ArrayRef<OpFoldResult> indices) {
  assert(strides.size() == indices.size());
  auto sourceRank = static_cast<unsigned>(strides.size());

  // Hold the affine symbols and values for the computation of the offset.
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);

  bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
  AffineExpr expr = symbols.front();
  values[0] = sourceOffset;

  for (unsigned i = 0; i < sourceRank; ++i) {
    // Compute the stride.
    OpFoldResult origStride = strides[i];

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = indices[i];
    values[origStrideForDim] = origStride;
  }

  return {expr, values};
}

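// Example (illustrative): for a rank-2 source, computeLinearIndex returns the
// expression s0 + s1 * s2 + s3 * s4 together with values =
// {sourceOffset, indices[0], strides[0], indices[1], strides[1]}, i.e.
// offset + dot(indices, strides) expressed over affine symbols.
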
//===----------------------------------------------------------------------===//
// TileOffsetRange
//===----------------------------------------------------------------------===//

/// Apply left-padding by 1 to the tile shape if required.
static SmallVector<int64_t> padTileShapeToSize(ArrayRef<int64_t> tileShape,
                                               unsigned paddedSize) {
  assert(tileShape.size() <= paddedSize &&
         "expected tileShape to <= paddedSize");
  if (tileShape.size() == paddedSize)
    return to_vector(tileShape);
  SmallVector<int64_t> result(paddedSize - tileShape.size(), 1);
  llvm::append_range(result, tileShape);
  return result;
}

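// Example (illustrative): padTileShapeToSize({4, 8}, 4) returns {1, 1, 4, 8}.
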
mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl(
    ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
    ArrayRef<int64_t> loopOrder)
    : tileShape(padTileShapeToSize(tileShape, shape.size())),
      inverseLoopOrder(invertPermutationVector(loopOrder)),
      sliceStrides(shape.size()) {
  // Divide the shape by the tile shape.
  std::optional<SmallVector<int64_t>> shapeRatio =
      mlir::computeShapeRatio(shape, tileShape);
  assert(shapeRatio && shapeRatio->size() == shape.size() &&
         "target shape does not evenly divide the original shape");
  assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() &&
         "expected loop order to be a permutation of rank equal to outer "
         "shape");

  maxLinearIndex = mlir::computeMaxLinearIndex(*shapeRatio);
  mlir::applyPermutationToVector(*shapeRatio, loopOrder);
  sliceStrides = mlir::computeStrides(*shapeRatio);
}

SmallVector<int64_t> mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets(
    int64_t linearIndex) const {
  SmallVector<int64_t> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return computeElementwiseMul(tileCoords, tileShape);
}

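// Example (illustrative): for shape = {8, 8}, tileShape = {4, 4} and the
// identity loopOrder {0, 1}, the shape ratio is {2, 2} and sliceStrides is
// {2, 1}; getStaticTileOffsets(3) delinearizes to tile coordinates {1, 1} and
// returns the offsets {4, 4}.
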
SmallVector<AffineExpr>
mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets(
    AffineExpr linearIndex) const {
  MLIRContext *ctx = linearIndex.getContext();
  SmallVector<AffineExpr> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return mlir::computeElementwiseMul(tileCoords,
                                     getAffineConstantExprs(tileShape, ctx));
}