//===- TensorTilingInterfaceImpl.cpp - Tiling Interface models -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {
struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        padOp.getResultType().getRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    (void)reifyResultShapes(b, op, reifiedShapes);
    OpFoldResult zero = b.getIndexAttr(0);
    OpFoldResult one = b.getIndexAttr(1);
    // Initialize all the ranges to {zero, one, one}. All the `ub`s are
    // overwritten below.
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> result =
        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
    if (failed(result))
      return failure();
    return result.value();
  }
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }
  LogicalResult getIterationDomainTileFromResultTile(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    iterDomainOffsets.assign(offsets.begin(), offsets.end());
    iterDomainSizes.assign(sizes.begin(), sizes.end());
    return success();
  }
  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    return getTiledImplementation(op, b, offsets, sizes);
  }
};
template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  OpBuilder::InsertionGuard g(builder);
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  OpFoldResult zero = builder.getIndexAttr(0);
  OpFoldResult one = builder.getIndexAttr(1);
  ReifiedRankedShapedTypeDims resultShape;
  (void)reifyResultShapes(builder, op, resultShape);
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  return loopBounds;
}
static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}
struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    // Note that here we only consider untiled dimensions and outer tiled data
    // dimensions, the inner tiled data dimensions are materialized when
    // building the body of the operation.
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
  }
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t inputRank = packOp.getSourceRank();
    SmallVector<OpFoldResult> origOffsets(offsets);
    SmallVector<OpFoldResult> origSizes(sizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));
    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<OpFoldResult> srcDimValues =
        tensor::getMixedSizes(b, loc, packOp.getSource());
    SmallVector<OpFoldResult> inputIndices, inputSizes;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i
        // and tile_i.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }
      // Limit the size of the input operand for incomplete tiles.
      if (packOp.getPaddingValue()) {
        OpFoldResult dimSize = srcDimValues[dim];
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }
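    // E.g. (values assumed for illustration): with a source dim of 30, inner
    // tile 8, tile offset 3 and tile size 1, the slice starts at 3 * 8 = 24
    // and its size is clamped to min(1 * 8, 30 - 24) = 6.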
    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = b.create<ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides);
    tiledOperands.push_back(sourceSlice);
    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return {};

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);
    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    // The iteration domain is over outer dimensions of packed layout. In this
    // context, the outer dimensions of `resultOffsets` are `offsets`. The
    // inner dimensions of `resultOffsets` are zeros because tiling is not
    // applied to them.
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getI64IntegerAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)reifyResultShapes(b, packOp, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);

    return success();
  }
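  // E.g. (values assumed for illustration): packing tensor<64x128xf32> into
  // tensor<8x4x8x32xf32> with inner tiles [8, 32], a tile at offsets [o0, o1]
  // with sizes [s0, s1] maps to the result tile at [o0, o1, 0, 0] with sizes
  // [s0, s1, 8, 32].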
  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    int64_t numTiles = packOp.getInnerDimsPos().size();

    // tensor.pack op is fusible (as a producer) only if full inner tiles are
    // iterated or inner dims are not tiled. Otherwise, it will generate a
    // sequence of non-trivial ops (for partial tiles).
    for (auto offset : offsets.take_back(numTiles))
      if (!isConstantIntValue(offset, 0))
        return failure();

    for (auto iter :
         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();

    FailureOr<TilingResult> tilingResult = getTiledImplementation(
        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }
  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation. In the current `tensor.pack` context, the
  /// `resultOffsets` and `resultSizes` only cover outer dimensions.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    // It is not trivial to infer the dest tile from the source tile if
    // `packOp` has padding semantics.
    if (packOp.getPaddingValue())
      return failure();

    Location loc = packOp.getLoc();
    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
      if (dimAndTileMapping.count(dim)) {
        FailureOr<int64_t> cstSize =
            ValueBoundsConstraintSet::computeConstantBound(
                presburger::BoundType::UB, sizes[dim],
                /*stopCondition=*/nullptr, /*closedUB=*/true);
        std::optional<int64_t> cstInnerSize =
            getConstantIntValue(dimAndTileMapping[dim]);
        // Currently, fusing `packOp` as a consumer only expects a perfect
        // tiling scenario, because even without padding semantics `packOp`
        // may yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
        // where the `tileSize` from the operand of `packOp` is 5, which is
        // not evenly divisible by the `innerTile` (=6) of `packOp`. As a
        // result:
        // 1. the first slice is extracted from (0) to (4) and inserted into
        //    (0,0)~(0,4) in the first row.
        // 2. the second slice is extracted from (5) to (9) and SHOULD BE
        //    inserted into two rows with different lengths: first row (0,5)
        //    and second row (1,0)~(1,3). It is hard to coordinate them, so
        //    the constraint below bypasses this case for now. In other words,
        //    tiling with a consumer is only supported if the tile size of the
        //    producer is a multiple of the inner tile size for the packed
        //    dimensions at this moment.
        if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
          return failure();
        }
        using AV = affine::AffineValueExpr;
        affine::AffineBuilder ab(b, loc);
        AffineExpr dim0, sym;
        bindDims(b.getContext(), dim0);
        bindSymbols(b.getContext(), sym);
        auto avOffset = AV(dim0).bind(offsets[dim]);
        auto avSize = AV(dim0).bind(sizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
        outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
      } else {
        outerDimOffsets.push_back(offsets[dim]);
        outerDimSizes.push_back(sizes[dim]);
      }
    }
    applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());
    resultOffsets = outerDimOffsets;
    resultSizes = outerDimSizes;
    return success();
  }
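  // E.g. (values assumed for illustration): with an inner tile of 4 and a
  // source tile at offset 4 with size 8, the iteration domain tile is at
  // floorDiv(4, 4) = 1 with size ceilDiv(8, 4) = 2.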
  /// Method to return the tiled implementation of tensor.pack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();
    int64_t inputRank = packOp.getSourceRank();
    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = b.create<ExtractSliceOp>(loc, packOp.getSource(),
                                                offsets, sizes, strides);
    tiledOperands.push_back(sourceSlice);
    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
            outerDimSizes)))
      return failure();

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
                                     outputOffsets, outputSizes)))
      return failure();

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);
    assert(!packOp.getPaddingValue() && "Expect no padding semantic");
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }
};
struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// Returns the needed information for tiling unpack op on `tileDim` with given
/// `tileOffset` and `tileSize`. For more details, see the comment of the
/// `getTiledImplementation`.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of the packed data dimensions.
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }
  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB, tileSize,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;
    // If the tiling size equals the inner tiling size, the outer dimension
    // size is always 1.
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }
  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because there could be an incomplete tile
    // even in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<64xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
    // affine_min op; we need ceilDiv.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }
  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  // Do not create affine ops for the expanded size because the resulting
  // affine expression is too complicated and can trigger issues in affine op
  // simplification.
  info.destExpandedSize = b.createOrFold<arith::MulIOp>(
      loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  return info;
}
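// E.g. (values assumed for illustration): for an Nn_to_N unpack with inner
// tile 8, a tile at offset 15 with size 15 gives firstCoord = (1, 7) and, for
// the last accessed element 29, lastCoord = (3, 5); hence sourceOffset = 1,
// sourceSize = 3, resultOffset = 7, and destExpandedSize = 3 * 8 = 24.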
struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }
  /// There are two cases in tiling unpack ops. If the tiling size is aligned
  /// to the inner tile size, the corresponding tiles of the source are all
  /// complete. Otherwise, there are incomplete tiles. In that case we need to
  /// expand the slice of the source to get complete tiles. The tiled unpack
  /// op unpacks more data from the source, so we need an extract_slice op to
  /// shift and truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..31]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 6), (3, 7)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to
  /// complete the rows. I.e., the input coordinates would start at (1, 0) and
  /// end at (3, 7). In this context, the tiled unpack produces (3 * n)
  /// elements because there are 3 rows in total. Followed by a
  /// tensor.extract_slice op, we can get the actual result.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiples
    // of the inner_tile_size. In this context, no extra data is needed when
    // representing the tiled unpack op.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }
    // The tiling is applied on destination dimensions. We have to apply the
    // interchange on source dimensions if outer_dims_perm is set.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    SmallVector<Operation *> generatedSlices;
    ExtractSliceOp sliceSource =
        b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                 sliceSrcSizes, sliceSrcStrides);
    generatedSlices.push_back(sliceSource);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
                                                  offsets, sizes, destStrides);
      sliceDest = destSliceOp;
      generatedSlices.push_back(destSliceOp);
    } else {
      sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                    unpackOp.getDestType().getElementType());
    }
    SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = b.create<UnPackOp>(
        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults()),
                          generatedSlices};

    auto extractSlice =
        b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                 resultOffsetsFromDest, sizes, destStrides);
    return TilingResult{
        {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
  }
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }
  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }
  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    // If the operand tile is the dest, then no adjustment is needed.
    if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
      resultOffsets = llvm::to_vector(offsets);
      resultSizes = llvm::to_vector(sizes);
      return success();
    }
    Location loc = unPackOp.getLoc();

    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    auto destOffsets = offsets.drop_back(numTiles);
    auto destSizes = sizes.drop_back(numTiles);
    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t outputRank = unPackOp.getDestRank();
    ReifiedRankedShapedTypeDims reifiedReturnShapes;
    if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes)))
      return failure();
    SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front();
    SmallVector<OpFoldResult> origOffsets(destOffsets);
    SmallVector<OpFoldResult> origSizes(destSizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(unPackOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        unPackOp.getDimAndTileMapping();

    for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym0;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym0);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i
        // and tile_i. The sizes must be clamped to the sizes of the unpack
        // result.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);
        auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);
        resultOffsets.push_back(ab.mul(avOffset, avTileSize));
        auto avResultOffset = AV(dim1).bind(resultOffsets.back());
        resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),
                                      ab.sub(avResultSize, avResultOffset)}));
      } else {
        resultOffsets.push_back(origOffsets[dim]);
        resultSizes.push_back(origSizes[dim]);
      }
    }
    return success();
  }
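  // E.g. (values assumed for illustration): with an unpack result of size 30,
  // an inner tile of 8, and a source tile at offset 3 with size 1, the
  // iteration domain tile starts at 3 * 8 = 24 with size
  // min(1 * 8, 30 - 24) = 6.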
  /// Method to return the tiled implementation of tensor.unpack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    // tensor.unpack op is fusible (as a consumer) only if inner dims are not
    // tiled.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    for (auto iter : llvm::zip_equal(unPackOp.getMixedTiles(),
                                     sizes.take_back(numTiles))) {
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();
    }

    Location loc = unPackOp.getLoc();
    // Fetch offset/size for creating the slice of the dest operand of
    // unpack op.
    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
            outputSizes)))
      return failure();

    auto oneAttr = b.getI64IntegerAttr(1);
    int64_t outputRank = unPackOp.getDestRank();
    SmallVector<OpFoldResult> strides(outputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    // Create slice of the dest operand.
    auto extractDestSlice = b.create<ExtractSliceOp>(
        loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractDestSlice);
    SmallVector<OpFoldResult> inputOffsets, inputSizes;
    strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
    // Create slice of the source operand.
    auto extractSourceSlice = b.create<ExtractSliceOp>(
        loc, unPackOp.getSource(), offsets, sizes, strides);
    tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
    for (auto tile : unPackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    // Create tiled unpack op.
    Operation *tiledUnPackOp =
        b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()},
                           tiledOperands, op->getAttrs());

    return TilingResult{{tiledUnPackOp},
                        SmallVector<Value>(tiledUnPackOp->getResults()),
                        llvm::to_vector(ArrayRef<Operation *>{
                            extractSourceSlice, extractDestSlice})};
  }
};

} // namespace
FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
                                                 tensor::PadOp padOp,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes,
                                                 bool generateZeroSliceGuard) {
  // Only constant padding value supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();
  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = padOp->getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, addMap, {v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2});
  };
  // Zero index-typed integer.
  OpFoldResult zero = b.getIndexAttr(0);
  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<OpFoldResult> newLows, newHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = padOp.getMixedLowPad()[dim];
    bool hasLowPad = !isConstantIntValue(low, 0);
    auto high = padOp.getMixedHighPad()[dim];
    bool hasHighPad = !isConstantIntValue(high, 0);
    auto offset = offsets[dim];
    auto length = sizes[dim];
    auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim);
    // The new amount of low padding is `low - offset`. Except for the case
    // where none of the low padding is read. In that case, the new amount of
    // low padding is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    newLows.push_back(newLow);
    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    //   max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult newOffset = hasLowPad
                                 ? min(max(sub(offset, low), zero), srcSize)
                                 : min(offset, srcSize);
    newOffsets.push_back(newOffset);
    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    //   offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of the
    // source tensor. (Similar to newOffset.)
    //
    //   endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult endLoc =
        hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
                  : min(add(offset, length), srcSize);
    OpFoldResult newLength = sub(endLoc, newOffset);
    newLengths.push_back(newLength);
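    // E.g. (values assumed for illustration): with srcSize = 10, low = 3,
    // offset = 2, and length = 6: newLow = max(0, 3 - 2) = 1,
    // newOffset = min(max(2 - 3, 0), 10) = 0,
    // endLoc = min(max(2 - 3 + 6, 0), 10) = 5, so newLength = 5.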
    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // created.
    if (isConstantIntValue(newLength, 0)) {
      hasZeroLen = true;
    } else if (!hasZeroLen) {
      Value check = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq,
          getValueOrCreateConstantIndexOp(b, loc, newLength),
          getValueOrCreateConstantIndexOp(b, loc, zero));
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }
    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    OpFoldResult newHigh =
        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    newHighs.push_back(newHigh);

    // Only unit stride supported.
    newStrides.push_back(b.getIndexAttr(1));
  }
  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert cast to ensure that types match. (May be folded away.)
  auto castResult = [&](Value val) -> Value {
    if (resultType == val.getType())
      return val;
    return b.create<tensor::CastOp>(loc, resultType, val);
  };
  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would
  // have a dimension of size 0, the semantics of which is unclear.)
  auto createGenerateOp = [&]() {
    // Create GenerateOp.
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return generateOp;
  };
  // Emit a SliceOp and a PadOp. Should not be used in cases where
  // the result shape of the new SliceOp has a zero dimension.
  auto createPadOfExtractSlice = [&]() {
    // Create pad(extract_slice(x)).
    auto newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(
        loc, Type(), newSliceOp, newLows, newHighs,
        /*nofold=*/padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));

    // Copy region to new PadOp.
    IRMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Cast result and return.
    return std::make_tuple(newPadOp, newSliceOp);
  };
  // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically known
  // that the original data source x is not used.
  if (hasZeroLen) {
    Operation *generateOp = createGenerateOp();
    return TilingResult{{generateOp},
                        {castResult(generateOp->getResult(0))},
                        /*generatedSlices=*/{}};
  }
  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    Operation *thenOp;
    Operation *elseOp;
    Operation *sliceOp;
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          thenOp = createGenerateOp();
          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          std::tie(elseOp, sliceOp) = createPadOfExtractSlice();
          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
        });
    return TilingResult{
        {elseOp}, SmallVector<Value>(result->getResults()), {sliceOp}};
  }
  auto [newPadOp, sliceOp] = createPadOfExtractSlice();
  return TilingResult{
      {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}};
}
void mlir::tensor::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}
void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}
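// Typical client usage (a sketch, not part of this file): register the models
// before creating the context, e.g.
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect>();
//   tensor::registerTilingInterfaceExternalModels(registry);
//   MLIRContext context(registry);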
);