//===- IndexingUtils.cpp - Helpers related to index computations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>
#include <optional>

using namespace mlir;
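/// Compute the suffix product (exclusive, row-major "strides") of `sizes`,
/// seeded with `unit` as the innermost stride. For example, sizes {2, 3, 4}
/// with unit 1 produce {12, 4, 1}.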
template <typename ExprType>
SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
                                               ExprType unit) {
  if (sizes.empty())
    return {};
  SmallVector<ExprType> strides(sizes.size(), unit);
  for (int64_t r = strides.size() - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * sizes[r + 1];
  return strides;
}
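/// Multiply `v1` and `v2` elementwise. Both vectors must have the same size
/// (enforced by llvm::zip_equal) unless both are empty.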
template <typename ExprType>
SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
                                                ArrayRef<ExprType> v2) {
  // Early exit if both are empty, let zip_equal fail if only 1 is empty.
  if (v1.empty() && v2.empty())
    return {};
  SmallVector<ExprType> result;
  for (auto it : llvm::zip_equal(v1, v2))
    result.push_back(std::get<0>(it) * std::get<1>(it));
  return result;
}
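/// Linearize `offsets` with respect to `basis`, i.e. compute
/// zero + offsets[0] * basis[0] + ... + offsets[n-1] * basis[n-1].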
template <typename ExprType>
ExprType linearizeImpl(ArrayRef<ExprType> offsets, ArrayRef<ExprType> basis,
                       ExprType zero) {
  assert(offsets.size() == basis.size());
  ExprType linearIndex = zero;
  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
    linearIndex = linearIndex + offsets[idx] * basis[idx];
  return linearIndex;
}
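/// Delinearize `linearIndex` into per-dimension offsets using `strides` and
/// the division operator `divOp`. For example, linearIndex 23 with strides
/// {12, 4, 1} and integer division yields offsets {1, 2, 3}.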
template <typename ExprType, typename DivOpTy>
SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
                                      ArrayRef<ExprType> strides,
                                      DivOpTy divOp) {
  int64_t rank = strides.size();
  SmallVector<ExprType> offsets(rank);
  for (int64_t r = 0; r < rank; ++r) {
    offsets[r] = divOp(linearIndex, strides[r]);
    linearIndex = linearIndex % strides[r];
  }
  return offsets;
}
//===----------------------------------------------------------------------===//
// Utils that operate on static integer values.
//===----------------------------------------------------------------------===//
SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
         "sizes must be nonnegative");
  int64_t unit = 1;
  return ::computeSuffixProductImpl(sizes, unit);
}
SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
                                                 ArrayRef<int64_t> v2) {
  return computeElementwiseMulImpl(v1, v2);
}
int64_t mlir::computeSum(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  if (basis.empty())
    return 0;
  return std::accumulate(basis.begin(), basis.end(), 0, std::plus<int64_t>());
}
int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  if (basis.empty())
    return 1;
  return std::accumulate(basis.begin(), basis.end(), 1,
                         std::multiplies<int64_t>());
}
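/// Return the linearized index of `offsets` with respect to `basis`, e.g.
/// linearize({1, 2, 3}, {12, 4, 1}) == 1 * 12 + 2 * 4 + 3 * 1 == 23.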
int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  int64_t zero = 0;
  return linearizeImpl(offsets, basis, zero);
}
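/// Inverse of linearize for in-bounds indices: delinearize(23, {12, 4, 1})
/// recovers {1, 2, 3}.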
SmallVector<int64_t> mlir::delinearize(int64_t linearIndex,
                                       ArrayRef<int64_t> strides) {
  assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) &&
         "strides must be positive");
  return delinearizeImpl(linearIndex, strides,
                         [](int64_t e1, int64_t e2) { return e1 / e2; });
}
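/// Compute the elementwise ratio shape[i] / subShape[i], aligning the two
/// shapes at their trailing dimensions, or std::nullopt if any division is
/// not integral. For example, computeShapeRatio({8, 16, 32}, {4, 8}) yields
/// {8, 4, 4}.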
std::optional<SmallVector<int64_t>>
mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
  if (shape.size() < subShape.size())
    return std::nullopt;
  assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) &&
         "shape must be positive");
  assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) &&
         "subShape must be positive");

  // Starting from the end, compute the integer divisors.
  std::vector<int64_t> result;
  result.reserve(shape.size());
  for (auto [size, subSize] :
       llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) {
    // If integral division does not occur, return and let the caller decide.
    if (size % subSize != 0)
      return std::nullopt;
    result.push_back(size / subSize);
  }
  // At this point we computed the ratio (in reverse) for the common size.
  // Fill with the remaining entries from the shape (still in reverse).
  int commonSize = subShape.size();
  std::copy(shape.rbegin() + commonSize, shape.rend(),
            std::back_inserter(result));
  // Reverse again to get it back in the proper order and return.
  return SmallVector<int64_t>{result.rbegin(), result.rend()};
}
//===----------------------------------------------------------------------===//
// Utils that operate on AffineExpr.
//===----------------------------------------------------------------------===//
SmallVector<AffineExpr> mlir::computeSuffixProduct(ArrayRef<AffineExpr> sizes) {
  if (sizes.empty())
    return {};
  AffineExpr unit = getAffineConstantExpr(1, sizes.front().getContext());
  return ::computeSuffixProductImpl(sizes, unit);
}
SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
                                                    ArrayRef<AffineExpr> v2) {
  return computeElementwiseMulImpl(v1, v2);
}
AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(0, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(0, ctx),
                         std::plus<AffineExpr>());
}
AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(1, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(1, ctx),
                         std::multiplies<AffineExpr>());
}
AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<AffineExpr> basis) {
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  return linearizeImpl(offsets, basis, zero);
}
AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<int64_t> basis) {
  return linearize(ctx, offsets, getAffineConstantExprs(basis, ctx));
}
SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<AffineExpr> strides) {
  return delinearizeImpl(
      linearIndex, strides,
      [](AffineExpr e1, AffineExpr e2) { return e1.floorDiv(e2); });
}
SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<int64_t> strides) {
  MLIRContext *ctx = linearIndex.getContext();
  return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
}
//===----------------------------------------------------------------------===//
// Permutation utils.
//===----------------------------------------------------------------------===//
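/// Return the inverse permutation: inversion[permutation[i]] == i. For
/// example, invertPermutationVector({1, 2, 0}) == {2, 0, 1}.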
SmallVector<int64_t>
mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
  assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  SmallVector<int64_t> inversion(permutation.size());
  for (const auto &pos : llvm::enumerate(permutation)) {
    inversion[pos.value()] = pos.index();
  }
  return inversion;
}
bool mlir::isIdentityPermutation(ArrayRef<int64_t> permutation) {
  for (auto i : llvm::seq<int64_t>(0, permutation.size()))
    if (permutation[i] != i)
      return false;
  return true;
}
bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
  assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  llvm::SmallDenseSet<int64_t, 4> seenVals;
  for (auto val : interchange) {
    if (seenVals.count(val))
      return false;
    seenVals.insert(val);
  }
  return seenVals.size() == interchange.size();
}
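/// Build a permutation of size `permSize` that places each `positions[i]` at
/// `desiredPositions[i]` and fills the remaining entries with the unused
/// positions in increasing order. For example,
/// computePermutationVector(4, {1, 3}, {0, 2}) yields {1, 0, 3, 2}.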
SmallVector<int64_t>
mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
                               ArrayRef<int64_t> desiredPositions) {
  SmallVector<int64_t> res(permSize, -1);
  DenseSet<int64_t> seen;
  for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
    res[desiredPos] = pos;
    seen.insert(pos);
  }
  // Fill the remaining entries with the positions that have not been used yet,
  // in increasing order.
  int64_t nextPos = 0;
  for (int64_t &entry : res) {
    if (entry != -1)
      continue;
    while (seen.contains(nextPos))
      ++nextPos;
    entry = nextPos++;
  }
  return res;
}
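/// Drop the entries listed in `dropPositions` from the permutation
/// `inputPerm` and renumber the remaining entries. For example,
/// dropDims({2, 0, 3, 1}, {1}) == {1, 0, 2}.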
SmallVector<int64_t> mlir::dropDims(ArrayRef<int64_t> inputPerm,
                                    ArrayRef<int64_t> dropPositions) {
  assert(inputPerm.size() >= dropPositions.size() &&
         "expect inputPerm size to be at least the number of positions to "
         "drop");
  SmallVector<int64_t> res;
  unsigned permSize = inputPerm.size();
  for (unsigned inputIndex = 0; inputIndex < permSize; ++inputIndex) {
    int64_t targetIndex = inputPerm[inputIndex];
    bool shouldDrop = false;
    unsigned dropSize = dropPositions.size();
    for (unsigned dropIndex = 0; dropIndex < dropSize; dropIndex++) {
      if (dropPositions[dropIndex] == inputPerm[inputIndex]) {
        shouldDrop = true;
        break;
      }
      if (dropPositions[dropIndex] < inputPerm[inputIndex]) {
        targetIndex--;
      }
    }
    if (!shouldDrop) {
      res.push_back(targetIndex);
    }
  }
  return res;
}
SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                          unsigned dropFront,
                                          unsigned dropBack) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}
// TODO: do we have any common utility for this?
static MLIRContext *getContext(OpFoldResult val) {
  assert(val && "Invalid value");
  if (auto attr = dyn_cast<Attribute>(val)) {
    return attr.getContext();
  }
  return cast<Value>(val).getContext();
}
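/// Build an affine expression that computes
///   sourceOffset + sum_i(indices[i] * strides[i])
/// in terms of symbols bound to the provided OpFoldResults. For rank 2 the
/// result is `s0 + s1 * s2 + s3 * s4` together with the value list
/// {sourceOffset, indices[0], strides[0], indices[1], strides[1]}.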
std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset,
                         ArrayRef<OpFoldResult> strides,
                         ArrayRef<OpFoldResult> indices) {
  assert(strides.size() == indices.size());
  auto sourceRank = static_cast<unsigned>(strides.size());

  // Hold the affine symbols and values for the computation of the offset.
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);

  bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
  AffineExpr expr = symbols.front();
  values[0] = sourceOffset;

  for (unsigned i = 0; i < sourceRank; ++i) {
    // Compute the stride.
    OpFoldResult origStride = strides[i];

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = indices[i];
    values[origStrideForDim] = origStride;
  }

  return {expr, values};
}
std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
                         ArrayRef<Value> indices) {
  return computeLinearIndex(
      sourceOffset, getAsIndexOpFoldResult(sourceOffset.getContext(), strides),
      getAsOpFoldResult(ValueRange(indices)));
}
//===----------------------------------------------------------------------===//
// TileOffsetRange
//===----------------------------------------------------------------------===//
/// Apply left-padding by 1 to the tile shape if required.
static SmallVector<int64_t> padTileShapeToSize(ArrayRef<int64_t> tileShape,
                                               unsigned paddedSize) {
  assert(tileShape.size() <= paddedSize &&
         "expected tileShape size to be <= paddedSize");
  if (tileShape.size() == paddedSize)
    return to_vector(tileShape);
  SmallVector<int64_t> result(paddedSize - tileShape.size(), 1);
  llvm::append_range(result, tileShape);
  return result;
}
mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl(
    ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
    ArrayRef<int64_t> loopOrder)
    : tileShape(padTileShapeToSize(tileShape, shape.size())),
      inverseLoopOrder(invertPermutationVector(loopOrder)),
      sliceStrides(shape.size()) {
  // Divide the shape by the tile shape.
  std::optional<SmallVector<int64_t>> shapeRatio =
      mlir::computeShapeRatio(shape, tileShape);
  assert(shapeRatio && shapeRatio->size() == shape.size() &&
         "target shape does not evenly divide the original shape");
  assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() &&
         "expected loop order to be a permutation of rank equal to outer "
         "shape");

  maxLinearIndex = mlir::computeMaxLinearIndex(*shapeRatio);
  mlir::applyPermutationToVector(*shapeRatio, loopOrder);
  sliceStrides = mlir::computeStrides(*shapeRatio);
}
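/// Map a linear tile index back to the static offsets of that tile. For
/// example, with shape {8, 8}, tileShape {4, 4}, and the identity loop order,
/// sliceStrides is {2, 1} and getStaticTileOffsets(3) returns {4, 4}.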
SmallVector<int64_t> mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets(
    int64_t linearIndex) const {
  SmallVector<int64_t> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return computeElementwiseMul(tileCoords, tileShape);
}
SmallVector<AffineExpr>
mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets(
    AffineExpr linearIndex) const {
  MLIRContext *ctx = linearIndex.getContext();
  SmallVector<AffineExpr> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return mlir::computeElementwiseMul(tileCoords,
                                     getAffineConstantExprs(tileShape, ctx));
}