//===- IndexingUtils.cpp - Helpers related to index computations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/IndexingUtils.h"

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"

#include <numeric>
#include <optional>

using namespace mlir;
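
/// Compute the suffix product of `sizes` with innermost stride `unit`, i.e.
/// the row-major strides. For example, sizes {2, 3, 4} with unit 1 give
/// strides {12, 4, 1}.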
template <typename ExprType>
SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
                                               ExprType unit) {
  if (sizes.empty())
    return {};
  SmallVector<ExprType> strides(sizes.size(), unit);
  for (int64_t r = strides.size() - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * sizes[r + 1];
  return strides;
}
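
/// Multiply `v1` and `v2` elementwise, e.g. {2, 3} * {4, 5} is {8, 15}.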
template <typename ExprType>
SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
                                                ArrayRef<ExprType> v2) {
  // Early exit if both are empty; let zip_equal fail if only one is empty.
  if (v1.empty() && v2.empty())
    return {};
  SmallVector<ExprType> result;
  for (auto it : llvm::zip_equal(v1, v2))
    result.push_back(std::get<0>(it) * std::get<1>(it));
  return result;
}
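
/// Linearize `offsets` with respect to `basis`:
///   linearIndex = zero + sum_i offsets[i] * basis[i].
/// For example, offsets {1, 2, 3} with basis {12, 4, 1} linearize to 23.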
template <typename ExprType>
ExprType linearizeImpl(ArrayRef<ExprType> offsets, ArrayRef<ExprType> basis,
                       ExprType zero) {
  assert(offsets.size() == basis.size());
  ExprType linearIndex = zero;
  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
    linearIndex = linearIndex + offsets[idx] * basis[idx];
  return linearIndex;
}
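
/// Delinearize `linearIndex` with respect to `strides`, using `divOp` as the
/// (floor) division. For example, delinearizing 23 with strides {12, 4, 1}
/// gives offsets {1, 2, 3}.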
template <typename ExprType, typename DivOpTy>
SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
                                      ArrayRef<ExprType> strides,
                                      DivOpTy divOp) {
  int64_t rank = strides.size();
  SmallVector<ExprType> offsets(rank);
  for (int64_t r = 0; r < rank; ++r) {
    offsets[r] = divOp(linearIndex, strides[r]);
    linearIndex = linearIndex % strides[r];
  }
  return offsets;
}

//===----------------------------------------------------------------------===//
// Utils that operate on static integer values.
//===----------------------------------------------------------------------===//

SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
  assert(llvm::all_of(sizes, [](int64_t s) { return s > 0; }) &&
         "sizes must be positive");
  int64_t unit = 1;
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
                                                 ArrayRef<int64_t> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

int64_t mlir::computeSum(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  if (basis.empty())
    return 0;
  return std::accumulate(basis.begin(), basis.end(), int64_t(0),
                         std::plus<int64_t>());
}

int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  if (basis.empty())
    return 0;
  return std::accumulate(basis.begin(), basis.end(), int64_t(1),
                         std::multiplies<int64_t>());
}

int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  int64_t zero = 0;
  return linearizeImpl(offsets, basis, zero);
}
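
/// When `strides` is the suffix product of some sizes and the offsets are in
/// range, delinearize inverts linearize: e.g. delinearize(23, {12, 4, 1})
/// returns {1, 2, 3}.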
SmallVector<int64_t> mlir::delinearize(int64_t linearIndex,
                                       ArrayRef<int64_t> strides) {
  assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) &&
         "strides must be positive");
  return delinearizeImpl(linearIndex, strides,
                         [](int64_t e1, int64_t e2) { return e1 / e2; });
}
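
/// For example, computeShapeRatio({8, 16, 32}, {4, 8}) returns {8, 4, 4}: the
/// suffix of `shape` is divided elementwise by `subShape` and the leading
/// entries of `shape` are carried over unchanged.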
std::optional<SmallVector<int64_t>>
mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
  if (shape.size() < subShape.size())
    return std::nullopt;
  assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) &&
         "shape must be positive");
  assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) &&
         "subShape must be positive");

  // Starting from the end, compute the integer divisors.
  std::vector<int64_t> result;
  result.reserve(shape.size());
  for (auto [size, subSize] :
       llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) {
    // If integral division does not occur, return and let the caller decide.
    if (size % subSize != 0)
      return std::nullopt;
    result.push_back(size / subSize);
  }

  // At this point we computed the ratio (in reverse) for the common size.
  // Fill with the remaining entries from the shape (still in reverse).
  int commonSize = subShape.size();
  std::copy(shape.rbegin() + commonSize, shape.rend(),
            std::back_inserter(result));

  // Reverse again to get it back in the proper order and return.
  return SmallVector<int64_t>{result.rbegin(), result.rend()};
}

//===----------------------------------------------------------------------===//
// Utils that operate on AffineExpr.
//===----------------------------------------------------------------------===//

SmallVector<AffineExpr> mlir::computeSuffixProduct(ArrayRef<AffineExpr> sizes) {
  if (sizes.empty())
    return {};
  AffineExpr unit = getAffineConstantExpr(1, sizes.front().getContext());
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
                                                    ArrayRef<AffineExpr> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(0, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(0, ctx),
                         std::plus<AffineExpr>());
}

AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(1, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(1, ctx),
                         std::multiplies<AffineExpr>());
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<AffineExpr> basis) {
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  return linearizeImpl(offsets, basis, zero);
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<int64_t> basis) {
  return linearize(ctx, offsets, getAffineConstantExprs(basis, ctx));
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<AffineExpr> strides) {
  return delinearizeImpl(
      linearIndex, strides,
      [](AffineExpr e1, AffineExpr e2) { return e1.floorDiv(e2); });
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<int64_t> strides) {
  MLIRContext *ctx = linearIndex.getContext();
  return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
}

//===----------------------------------------------------------------------===//
// Permutation utils.
//===----------------------------------------------------------------------===//
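
/// For example, invertPermutationVector({1, 2, 0}) returns {2, 0, 1}: entry i
/// of the result is the position at which i occurs in the input.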
SmallVector<int64_t>
mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
  assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  SmallVector<int64_t> inversion(permutation.size());
  for (const auto &pos : llvm::enumerate(permutation)) {
    inversion[pos.value()] = pos.index();
  }
  return inversion;
}
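
/// For example, {1, 2, 0} is a permutation vector, while {0, 0, 1} (which
/// contains a duplicate) is not.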
bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
  assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  llvm::SmallDenseSet<int64_t, 4> seenVals;
  for (auto val : interchange) {
    if (seenVals.count(val))
      return false;
    seenVals.insert(val);
  }
  return seenVals.size() == interchange.size();
}
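
/// For example, computePermutationVector(5, {2, 4}, {0, 1}) returns
/// {2, 4, 0, 1, 3}: values 2 and 4 are pinned to positions 0 and 1, and the
/// remaining values fill the other slots in ascending order.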
SmallVector<int64_t>
mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
                               ArrayRef<int64_t> desiredPositions) {
  SmallVector<int64_t> res(permSize, -1);
  DenseSet<int64_t> seen;
  for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
    res[desiredPos] = pos;
    seen.insert(pos);
  }
  int64_t nextPos = 0;
  for (int64_t &entry : res) {
    if (entry != -1)
      continue;
    while (seen.contains(nextPos))
      ++nextPos;
    entry = nextPos;
    ++nextPos;
  }
  return res;
}
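
/// Extract the values of `arrayAttr` as int64_t (sign-extended), dropping
/// `dropFront` leading and `dropBack` trailing entries.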
SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                          unsigned dropFront,
                                          unsigned dropBack) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// TODO: do we have any common utility for this?
static MLIRContext *getContext(OpFoldResult val) {
  assert(val && "Invalid value");
  if (auto attr = dyn_cast<Attribute>(val)) {
    return attr.getContext();
  } else {
    return cast<Value>(val).getContext();
  }
}
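
/// Bind one symbol for the source offset and two symbols (index, stride) per
/// dimension, producing expr = s0 + sum_i s_{2i+1} * s_{2i+2} together with
/// the matching operand list {offset, index_0, stride_0, index_1, ...}.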
std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset,
                         ArrayRef<OpFoldResult> strides,
                         ArrayRef<OpFoldResult> indices) {
  assert(strides.size() == indices.size());
  auto sourceRank = static_cast<unsigned>(strides.size());

  // Hold the affine symbols and values for the computation of the offset.
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);

  bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
  AffineExpr expr = symbols.front();
  values[0] = sourceOffset;

  for (unsigned i = 0; i < sourceRank; ++i) {
    // Compute the stride.
    OpFoldResult origStride = strides[i];

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = indices[i];
    values[origStrideForDim] = origStride;
  }

  return {expr, values};
}

//===----------------------------------------------------------------------===//
// TileOffsetRange
//===----------------------------------------------------------------------===//

/// Apply left-padding by 1 to the tile shape if required.
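/// For example, padding {2, 3} to size 4 yields {1, 1, 2, 3}.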
static SmallVector<int64_t> padTileShapeToSize(ArrayRef<int64_t> tileShape,
                                               unsigned paddedSize) {
  assert(tileShape.size() <= paddedSize &&
         "expected tileShape size to be <= paddedSize");
  if (tileShape.size() == paddedSize)
    return to_vector(tileShape);
  SmallVector<int64_t> result(paddedSize - tileShape.size(), 1);
  llvm::append_range(result, tileShape);
  return result;
}

mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl(
    ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
    ArrayRef<int64_t> loopOrder)
    : tileShape(padTileShapeToSize(tileShape, shape.size())),
      inverseLoopOrder(invertPermutationVector(loopOrder)),
      sliceStrides(shape.size()) {
  // Divide the shape by the tile shape.
  std::optional<SmallVector<int64_t>> shapeRatio =
      mlir::computeShapeRatio(shape, tileShape);
  assert(shapeRatio && shapeRatio->size() == shape.size() &&
         "target shape does not evenly divide the original shape");
  assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() &&
         "expected loop order to be a permutation of rank equal to outer "
         "shape");

  maxLinearIndex = mlir::computeMaxLinearIndex(*shapeRatio);
  mlir::applyPermutationToVector(*shapeRatio, loopOrder);
  sliceStrides = mlir::computeStrides(*shapeRatio);
}
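
/// For example, with shape {4, 8}, tileShape {2, 4} and loopOrder {0, 1}, the
/// shape ratio is {2, 2} and linear indices 0..3 map to the tile offsets
/// {0, 0}, {0, 4}, {2, 0}, {2, 4}.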
SmallVector<int64_t> mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets(
    int64_t linearIndex) const {
  SmallVector<int64_t> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return computeElementwiseMul(tileCoords, tileShape);
}

SmallVector<AffineExpr>
mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets(
    AffineExpr linearIndex) const {
  MLIRContext *ctx = linearIndex.getContext();
  SmallVector<AffineExpr> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return mlir::computeElementwiseMul(tileCoords,
                                     getAffineConstantExprs(tileShape, ctx));
}