//===- TensorTilingInterfaceImpl.cpp - Tiling Interface models -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        padOp.getResultType().getRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    (void)reifyResultShapes(b, op, reifiedShapes);
    OpFoldResult zero = b.getIndexAttr(0);
    OpFoldResult one = b.getIndexAttr(1);
    // Initialize all the ranges to {zero, one, one}. All the `ub`s are
    // overwritten.
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> result =
        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
    if (failed(result))
      return failure();
    return result.value();
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }

  LogicalResult getIterationDomainTileFromResultTile(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    iterDomainOffsets.assign(offsets.begin(), offsets.end());
    iterDomainSizes.assign(sizes.begin(), sizes.end());
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    return getTiledImplementation(op, b, offsets, sizes);
  }
};

template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  OpBuilder::InsertionGuard g(builder);
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  OpFoldResult zero = builder.getIndexAttr(0);
  OpFoldResult one = builder.getIndexAttr(1);
  ReifiedRankedShapedTypeDims resultShape;
  (void)reifyResultShapes(builder, op, resultShape);
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  return loopBounds;
}

static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}

struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    // Note that here we only consider untiled dimensions and outer tiled data
    // dimensions, the inner tiled data dimensions are materialized when
    // building the body of the operation.
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t inputRank = packOp.getSourceRank();
    SmallVector<OpFoldResult> origOffsets(offsets);
    SmallVector<OpFoldResult> origSizes(sizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<OpFoldResult> srcDimValues =
        tensor::getMixedSizes(b, loc, packOp.getSource());
    SmallVector<OpFoldResult> inputIndices, inputSizes;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i.
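        // For example (illustrative): with an inner tile size of 8, an outer
        // tile at offset 2 with size 3 maps to source offset 2 * 8 = 16 and
        // source size 3 * 8 = 24 along this dimension (before the clamping
        // for incomplete tiles below).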
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }

      // Limit the size of the input operand for incomplete tiles.
      if (packOp.getPaddingValue()) {
        OpFoldResult dimSize = srcDimValues[dim];
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }

    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = b.create<ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides);
    tiledOperands.push_back(sourceSlice);

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return {};

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    // The iteration domain is over outer dimensions of packed layout. In this
    // context, the outer dimensions of `resultOffsets` are `offsets`. The
    // inner dimensions of `resultOffsets` are zeros because tiling is not
    // applied to them.
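    // For example (illustrative): packing into tensor<?x?x16x32xf32> with
    // outer-domain offsets (%i, %j) and sizes (4, 8) yields
    // resultOffsets = (%i, %j, 0, 0) and resultSizes = (4, 8, 16, 32), where
    // the trailing sizes are the reified inner tile sizes of the result.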
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getI64IntegerAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)reifyResultShapes(b, packOp, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);

    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    int64_t numTiles = packOp.getInnerDimsPos().size();

    // tensor.pack op is fusible (as a producer) only if full inner tiles are
    // iterated or inner dims are not tiled. Otherwise, it will generate a
    // sequence of non-trivial ops (for partial tiles).
    for (auto offset : offsets.take_back(numTiles))
      if (!isConstantIntValue(offset, 0))
        return failure();

    for (auto iter :
         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();

    FailureOr<TilingResult> tilingResult = getTiledImplementation(
        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }

  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation. In the current `tensor.pack` context, the
  /// `resultOffsets` and `resultSizes` only cover outer dimensions.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    // It is not trivial to infer the dest tile from the source tile if
    // `packOp` has padding semantics.
    if (packOp.getPaddingValue())
      return failure();

    Location loc = packOp.getLoc();

    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
      if (dimAndTileMapping.count(dim)) {
        FailureOr<int64_t> cstSize =
            ValueBoundsConstraintSet::computeConstantBound(
                presburger::BoundType::UB, sizes[dim],
                /*stopCondition=*/nullptr, /*closedUB=*/true);
        std::optional<int64_t> cstInnerSize =
            getConstantIntValue(dimAndTileMapping[dim]);
        // Currently, fusing `packOp` as a consumer only supports the perfect
        // tiling scenario because, even without padding semantics, `packOp`
        // may yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
        // where the `tileSize` from the operand of `packOp` is 5, which is not
        // exactly divisible by the `innerTile` (=6) of `packOp`. As a result:
        // 1. the first slice is extracted from (0) to (4) and inserted into
        //    (0,0)~(0,4) in the first row.
        // 2. the second slice is extracted from (5) to (9) and SHOULD BE
        //    inserted into two rows with different lengths: the first row at
        //    (0,5) and the second row at (1,0)~(1,3). It is hard to coordinate
        //    the two, so the constraint below bypasses this case for now. In
        //    other words, we can only support tiling with a consumer if the
        //    tile size for the producer is a multiple of the inner tile size
        //    for the packed dimensions at this moment.
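        // E.g. (illustrative): with innerTile = 4, an operand tile with
        // offset 8 and size 12 maps to the outer iteration domain as
        // offset = 8 floordiv 4 = 2 and size = ceildiv(12, 4) = 3.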
        if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
          return failure();
        }

        using AV = affine::AffineValueExpr;
        affine::AffineBuilder ab(b, loc);
        AffineExpr dim0, sym;
        bindDims(b.getContext(), dim0);
        bindSymbols(b.getContext(), sym);
        auto avOffset = AV(dim0).bind(offsets[dim]);
        auto avSize = AV(dim0).bind(sizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
        outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
      } else {
        outerDimOffsets.push_back(offsets[dim]);
        outerDimSizes.push_back(sizes[dim]);
      }
    }
    applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());
    resultOffsets = outerDimOffsets;
    resultSizes = outerDimSizes;
    return success();
  }

  /// Method to return the tiled implementation of tensor.pack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    int64_t inputRank = packOp.getSourceRank();
    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = b.create<ExtractSliceOp>(loc, packOp.getSource(),
                                                offsets, sizes, strides);
    tiledOperands.push_back(sourceSlice);

    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
            outerDimSizes)))
      return failure();

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
                                     outputOffsets, outputSizes)))
      return failure();

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);

    assert(!packOp.getPaddingValue() && "Expect no padding semantic");
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }
};

struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// Returns the needed information for tiling an unpack op on `tileDim` with
/// the given `tileOffset` and `tileSize`. For more details, see the comment
/// on `getTiledImplementation`.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of the packed data dimensions.
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }

  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB, tileSize,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;

    // If the tiling size equals the inner tile size, the outer dims are
    // always 1.
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because there could be an incomplete tile
    // even in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<64xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
    // affine_min op; we need ceilDiv.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }

  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  // Do not create affine ops for the expanded size because the affine
  // expression is too complicated, which would trigger an issue in affine op
  // simplification.
  info.destExpandedSize = b.createOrFold<arith::MulIOp>(
      loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
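  // Illustrative example for this non-aligned case: with innerTileSize = 8,
  // tileOffset = 15, and tileSize = 7, the tile covers indices [15, 22).
  // firstCoord = (1, 7) and lastCoord = (2, 5), so sourceOffset = 1,
  // sourceSize = 2, resultOffset = 7, and destExpandedSize = 2 * 8 = 16.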
  return info;
}

struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }

  /// There are two cases in tiling unpack ops. If the tiling size is aligned
  /// to the inner tile size, the corresponding tiles of the source are all
  /// complete. Otherwise, there are incomplete tiles and we need to expand the
  /// slice of the source to get complete tiles. The tiled unpack op unpacks
  /// more data from the source, so we need an extract_slice op to shift and
  /// truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..31]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 6), (3, 7)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to complete
  /// the rows. I.e., the input coordinates would start with (1, 0) and end
  /// with (3, 7). In this context, the tiled unpack produces (3 * n) elements
  /// because there are 3 rows in total. Followed by a tensor.extract_slice op,
  /// we can get the actual result.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiples of
    // the inner_tile_size. In this context, no extra data is needed when
    // representing the tiled unpack op.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }

    // The tiling is applied on destination dimensions. We have to apply the
    // interchange on source dimensions if outer_dims_perm is set.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    SmallVector<Operation *> generatedSlices;
    ExtractSliceOp sliceSource =
        b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                 sliceSrcSizes, sliceSrcStrides);
    generatedSlices.push_back(sliceSource);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
                                                  offsets, sizes, destStrides);
      sliceDest = destSliceOp;
      generatedSlices.push_back(destSliceOp);
    } else {
      sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                    unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = b.create<UnPackOp>(
        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults()),
                          generatedSlices};

    auto extractSlice =
        b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                 resultOffsetsFromDest, sizes, destStrides);
    return TilingResult{
        {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
  }
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }

  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    // If the operand tile is the dest, then no adjustment is needed.
    if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
      resultOffsets = llvm::to_vector(offsets);
      resultSizes = llvm::to_vector(sizes);
      return success();
    }
    Location loc = unPackOp.getLoc();

    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    auto destOffsets = offsets.drop_back(numTiles);
    auto destSizes = sizes.drop_back(numTiles);
    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t outputRank = unPackOp.getDestRank();
    ReifiedRankedShapedTypeDims reifiedReturnShapes;
    if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes)))
      return failure();
    SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front();
    SmallVector<OpFoldResult> origOffsets(destOffsets);
    SmallVector<OpFoldResult> origSizes(destSizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(unPackOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        unPackOp.getDimAndTileMapping();

    for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym0;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym0);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i. The sizes must be clamped to the sizes of the unpack result.
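        // For example (illustrative): with innerTile = 8, an outer source tile
        // at offset 3 with size 2 and an unpacked result dimension of size 30
        // gives resultOffset = 3 * 8 = 24 and
        // resultSize = min(2 * 8, 30 - 24) = 6.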
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);
        auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);
        resultOffsets.push_back(ab.mul(avOffset, avTileSize));
        auto avResultOffset = AV(dim1).bind(resultOffsets.back());
        resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),
                                      ab.sub(avResultSize, avResultOffset)}));
      } else {
        resultOffsets.push_back(origOffsets[dim]);
        resultSizes.push_back(origSizes[dim]);
      }
    }
    return success();
  }

  /// Method to return the tiled implementation of tensor.unpack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    // tensor.unpack op is fusible (as a consumer) only if inner dims are not
    // tiled.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    for (auto iter :
         llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();
    }

    Location loc = unPackOp.getLoc();

    // Fetch offset/size for creating the slice of the dest operand of
    // unpack op.
    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
            outputSizes)))
      return failure();

    auto oneAttr = b.getI64IntegerAttr(1);
    int64_t outputRank = unPackOp.getDestRank();
    SmallVector<OpFoldResult> strides(outputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    // Create slice of the dest operand.
    auto extractDestSlice = b.create<ExtractSliceOp>(
        loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractDestSlice);

    SmallVector<OpFoldResult> inputOffsets, inputSizes;
    strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
    // Create slice of the source operand.
    auto extractSourceSlice = b.create<ExtractSliceOp>(
        loc, unPackOp.getSource(), offsets, sizes, strides);
    tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
    for (auto tile : unPackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    // Create tiled unpack op.
    Operation *tiledUnPackOp =
        b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()},
                           tiledOperands, op->getAttrs());

    return TilingResult{{tiledUnPackOp},
                        SmallVector<Value>(tiledUnPackOp->getResults()),
                        llvm::to_vector(ArrayRef<Operation *>{
                            extractSourceSlice, extractDestSlice})};
  }
};

} // namespace
FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
                                                 tensor::PadOp padOp,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes,
                                                 bool generateZeroSliceGuard) {
  // Only constant padding value supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = padOp->getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, addMap, {v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2});
  };
  // Zero index-typed integer.
  OpFoldResult zero = b.getIndexAttr(0);

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<OpFoldResult> newLows, newHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = padOp.getMixedLowPad()[dim];
    bool hasLowPad = !isConstantIntValue(low, 0);
    auto high = padOp.getMixedHighPad()[dim];
    bool hasHighPad = !isConstantIntValue(high, 0);
    auto offset = offsets[dim];
    auto length = sizes[dim];
    auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim);

    // The new amount of low padding is `low - offset`. Except for the case
    // where none of the low padding is read. In that case, the new amount of
    // low padding is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    newLows.push_back(newLow);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    //   max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult newOffset = hasLowPad
                                 ? min(max(sub(offset, low), zero), srcSize)
                                 : min(offset, srcSize);
    newOffsets.push_back(newOffset);

    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    //   offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of the
    // source tensor. (Similar to newOffset.)
    //
    //   endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult endLoc =
        hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
                  : min(add(offset, length), srcSize);
    OpFoldResult newLength = sub(endLoc, newOffset);
    newLengths.push_back(newLength);

    // Check if newLength is zero. In that case, no SubTensorOp should be
    // executed.
    if (isConstantIntValue(newLength, 0)) {
      hasZeroLen = true;
    } else if (!hasZeroLen) {
      Value check = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq,
          getValueOrCreateConstantIndexOp(b, loc, newLength),
          getValueOrCreateConstantIndexOp(b, loc, zero));
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    OpFoldResult newHigh =
        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    newHighs.push_back(newHigh);
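    // Worked example (illustrative): with low = 3, srcSize = 10, high = 2, a
    // tile at offset = 2 with length = 6 gives newLow = max(0, 3 - 2) = 1,
    // newOffset = min(max(2 - 3, 0), 10) = 0, endLoc = min(max(5, 0), 10) = 5,
    // newLength = 5, and newHigh = (6 - 5) - 1 = 0, so
    // newLow + newLength + newHigh == length.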
    // Only unit stride supported.
    newStrides.push_back(b.getIndexAttr(1));
  }

  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert cast to ensure that types match. (May be folded away.)
  auto castResult = [&](Value val) -> Value {
    if (resultType == val.getType())
      return val;
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would
  // have a dimension of size 0, the semantics of which is unclear.)
  auto createGenerateOp = [&]() {
    // Create GenerateOp.
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return generateOp;
  };

  // Emit a SliceOp and a PadOp. Should not be used in cases where
  // the result shape of the new SliceOp has a zero dimension.
  auto createPadOfExtractSlice = [&]() {
    // Create pad(extract_slice(x)).
    auto newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(
        loc, Type(), newSliceOp, newLows, newHighs,
        /*nofold=*/padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));

    // Copy region to new PadOp.
    IRMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Cast result and return.
    return std::make_tuple(newPadOp, newSliceOp);
  };

  // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically known
  // that the original data source x is not used.
  if (hasZeroLen) {
    Operation *generateOp = createGenerateOp();
    return TilingResult{{generateOp},
                        {castResult(generateOp->getResult(0))},
                        /*generatedSlices=*/{}};
  }

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    Operation *thenOp;
    Operation *elseOp;
    Operation *sliceOp;
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          thenOp = createGenerateOp();
          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          std::tie(elseOp, sliceOp) = createPadOfExtractSlice();
          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
        });
    return TilingResult{
        {elseOp}, SmallVector<Value>(result->getResults()), {sliceOp}};
  }

  auto [newPadOp, sliceOp] = createPadOfExtractSlice();
  return TilingResult{
      {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}};
}

void mlir::tensor::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}

void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}