//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace presburger;
using namespace mlir::affine;
using namespace mlir::linalg;
using namespace mlir::scf;
namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == mlir::AffineExprKind::Mul)
      assert(cast<AffineConstantExpr>(expr.getRHS()).getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled = false;
  ArrayRef<OpFoldResult> tileSizes;
};

} // namespace
static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}

// Checks whether the `map` varies with respect to a non-zero `tileSize`.
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}
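// Returns BinaryOpKind::IAdd when the generic op's region is a single block
// with exactly two signless int/float arguments whose only non-terminator
// operation is an arith.addi of those arguments that is directly yielded.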
std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.getRegion();
  if (!llvm::hasSingleElement(region))
    return std::nullopt;

  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(1).getType().isSignlessIntOrFloat())
    return std::nullopt;

  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(block.without_terminator()))
    return std::nullopt;

  using mlir::matchers::m_Val;
  auto a = m_Val(block.getArgument(0));
  auto b = m_Val(block.getArgument(1));

  auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;

  return std::nullopt;
}
/// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(OpBuilder &builder, Location loc,
                         ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                         SmallVectorImpl<Value> &ubs,
                         SmallVectorImpl<Value> &steps) {
  for (Range range : ranges) {
    lbs.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.offset));
    ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
    steps.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.stride));
  }
}
//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//

namespace mlir {
namespace linalg {

bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}
bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}
bool isElementwise(LinalgOp op) {
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    if (!op.getMatchingIndexingMap(&opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(op->getRegion(0));
}
bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}

bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}
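// Tries to reuse an existing high pad instead of creating a new one: walks the
// use-def chain source -> tensor.extract_slice -> (chain of LinalgOps) ->
// tensor.pad, and returns the already padded value when the padded type,
// high-only padding, slice sizes, and padding value all match; otherwise a
// fresh pad is created via tensor::createPadHighOp.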
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Search the `source` use-def chain for padded LinalgOps.
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = cast<OpResult>(current);
    current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the matched
  // LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if `padOpSliceOp`, which defines the slice used by `padOp`, is
  // rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
  // by `padOp`.
  if (llvm::any_of(
          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
          [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
          }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
      !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}
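// Builds an elementwise copy between two memrefs of equal rank as a
// linalg.generic with identity indexing maps and all-parallel iterators.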
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = cast<MemRefType>(to.getType());
#ifndef NDEBUG
  auto memrefTypeFrom = cast<MemRefType>(from.getType());
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  return b.create<linalg::GenericOp>(
      loc,
      /*inputs=*/from,
      /*outputs=*/to,
      /*indexingMaps=*/llvm::ArrayRef({id, id}),
      /*iteratorTypes=*/iteratorTypes,
      [](OpBuilder &b, Location loc, ValueRange args) {
        b.create<linalg::YieldOp>(loc, args.front());
      });
}
/// Specialization to build an scf "for" nest.
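/// Tensor DPS inits are threaded through the nest as scf.for iter_args and
/// passed to `bodyBuilderFn` in place of the original outputs; loops whose
/// ProcInfo requests cyclic distribution are mapped to processor ids
/// afterwards.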
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries");
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputs();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  for (const auto &loop : llvm::enumerate(loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic) {
      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
                            procInfo[loop.index()].nprocs);
    }
  }
}
/// Specialization to build affine "for" nest.
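/// Unlike the scf.for variant, this expects pure buffer semantics (no
/// iter_args) and requires every loop step to fold to a constant.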
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    auto constVal = getConstantIntValue(v);
    assert(constVal.has_value() && "Affine loops require constant steps");
    constantSteps.push_back(constVal.value());
  }

  affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                              [&](OpBuilder &b, Location loc, ValueRange ivs) {
                                bodyBuilderFn(b, loc, ivs,
                                              linalgOp->getOperands());
                              });
}
/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
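/// For cyclic distribution this computes `lb = lb + procId * step` and
/// `step = nprocs * step`; `ub` is left unchanged.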
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
  lb =
      affine::makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
  step = affine::makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
}
/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes`. Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
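// Distribution handling for the leading run of parallel loops:
//   None                     -> plain scf.parallel over the full range.
//   Cyclic                   -> scf.parallel over bounds already rewritten to
//                               be per-processor.
//   CyclicNumProcsGeNumIters -> the single per-processor iteration is guarded
//                               by an scf.if bounds check.
//   CyclicNumProcsEqNumIters -> no loop or check; the lower bound is used as
//                               the induction variable directly.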
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
    ArrayRef<linalg::ProcInfo> procInfo,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());
  assert(procInfo.empty() || (lbs.size() == procInfo.size()));

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse.
  if (!isParallelIterator(iteratorTypes.front())) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(
              b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
              iteratorTypes.drop_front(),
              procInfo.empty() ? procInfo : procInfo.drop_front(),
              bodyBuilderFn, ivStorage);
        });
    return;
  }

  unsigned nLoops = iteratorTypes.size();
  unsigned numProcessed = 0;
  DistributionMethod distributionMethod = DistributionMethod::None;
  if (procInfo.empty()) {
    numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
  } else {
    distributionMethod = procInfo.front().distributionMethod;
    numProcessed =
        nLoops - procInfo
                     .drop_while([&](linalg::ProcInfo p) {
                       return p.distributionMethod == distributionMethod;
                     })
                     .size();
  }

  auto remainderProcInfo =
      procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
  switch (distributionMethod) {
  case DistributionMethod::None: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
                               ubs.drop_front(numProcessed),
                               steps.drop_front(numProcessed),
                               iteratorTypes.drop_front(numProcessed),
                               remainderProcInfo, bodyBuilderFn, ivStorage);
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        remainderProcInfo, bodyBuilderFn, ivStorage);
    return;
  }
}
/// Specialization for generating a mix of parallel and sequential scf loops.
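/// Bounds of distributed dimensions are first rewritten per processor via
/// updateBoundsForCyclicDistribution; the nest itself is then emitted by
/// generateParallelLoopNest.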
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected proc information for all loops when present");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  for (const auto &it : llvm::enumerate(procInfo)) {
    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
      updateBoundsForCyclicDistribution(
          b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
          ubsStorage[it.index()], stepsStorage[it.index()]);
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
      [&](OpBuilder &b, Location loc, ValueRange ivs) {
        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
      },
      ivs);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}
static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
                                        Value valueToTile,
                                        const SliceParameters &sliceParams) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case([&](MemRefType) {
                        return builder.create<memref::SubViewOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Case([&](RankedTensorType) {
                        return builder.create<tensor::ExtractSliceOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Default([](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type");
                      });
  return sliceOp;
}
Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                          ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                          ArrayRef<OpFoldResult> lbs,
                          ArrayRef<OpFoldResult> ubs,
                          ArrayRef<OpFoldResult> subShapeSizes,
                          bool omitPartialTileCheck) {
  SliceParameters sliceParams =
      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                             ubs, subShapeSizes, omitPartialTileCheck);
  return materializeTiledShape(builder, loc, valueToTile, sliceParams);
}
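// Note: `lbs` are the tile offsets (typically the loop induction variables)
// and `subShapeSizes` are the tile sizes minus one (closed intervals); the
// size is re-opened with `+ 1` below. Unless `omitPartialTileCheck` is set,
// sizes are additionally clamped with an affine.min so boundary tiles stay in
// bounds.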
SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                       ArrayRef<OpFoldResult> subShapeSizes,
                       bool omitPartialTileCheck) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Compute offsets/sizes/strides for the tile.
  SliceParameters sliceParams;
  sliceParams.offsets.reserve(rank);
  sliceParams.sizes.reserve(rank);
  sliceParams.strides.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      sliceParams.offsets.push_back(builder.getIndexAttr(0));
      OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
      sliceParams.sizes.push_back(dim);
      sliceParams.strides.push_back(builder.getIndexAttr(1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");

    // Tiling creates a new slice at the proper index, the slice step is 1
    // (i.e. the op does not subsample, stepping occurs in the loop).
    auto m = map.getSubMap({r});
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
    IRRewriter rewriter(builder);
    OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
    sliceParams.offsets.push_back(offset);
    OpFoldResult closedIntSize =
        makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
    // The resulting size needs to be made a half-open interval again.
    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
    OpFoldResult size =
        makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: raw size: " << size << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: new offset: " << offset << "\n");
    sliceParams.strides.push_back(builder.getIndexAttr(1));

    if (omitPartialTileCheck) {
      // We statically know that the partial/boundary tile condition is
      // unnecessary.
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
      sliceParams.sizes.push_back(size);
      continue;
    }

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    std::optional<int64_t> sizeCst = getConstantIntValue(size);
    auto hasTileSizeOne = sizeCst && *sizeCst == 1;
    auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
                         ((shapeSize % *sizeCst) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");

      AffineExpr dim0, dim1, dim2;
      MLIRContext *context = builder.getContext();
      bindDims(context, dim0, dim1, dim2);

      // Get the dimension size for this dimension. We need to first calculate
      // the max index and then add one. This is important because for
      // convolution ops, the affine map of an input window dimension has the
      // form `(d0 * s0 + d1)`, where `d0`/`d1` are output/filter window
      // dimensions and `s0` is the stride. Directly using the dimension size
      // of the output/filter window dimensions would yield an incorrect
      // result.
      AffineMap minusOneMap = AffineMap::inferFromExprList(
                                  {ArrayRef<AffineExpr>{dim0 - 1}}, context)
                                  .front();
      AffineMap plusOneMap = AffineMap::inferFromExprList(
                                 {ArrayRef<AffineExpr>{dim0 + 1}}, context)
                                 .front();
      SmallVector<OpFoldResult> maxIndices =
          llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
            return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
                                                 {ub});
          }));
      OpFoldResult maxIndex =
          makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
      OpFoldResult d =
          makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});

      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
      AffineMap minMap = AffineMap::inferFromExprList(
                             {ArrayRef<AffineExpr>{dim1 - dim2, dim0}}, context)
                             .front();
      size =
          makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    sliceParams.sizes.push_back(size);
  }
  return sliceParams;
}
SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
                                             ArrayRef<OpFoldResult> ivs,
                                             ArrayRef<OpFoldResult> tileSizes) {
  SmallVector<OpFoldResult> offsets;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
    LLVM_DEBUG(llvm::dbgs()
               << "computeTileOffsets: " << offsets.back() << "\n");
  }
  return offsets;
}
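// Note: the returned sizes are `tileSize - 1` (closed intervals) so they can
// be composed through indexing maps in computeSliceParameters, which adds the
// `+ 1` back once composition is done.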
SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                           ArrayRef<OpFoldResult> tileSizes,
                                           ArrayRef<OpFoldResult> sizeBounds) {
  SmallVector<OpFoldResult> sizes;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    // Before composing, we need to make the range a closed interval.
    OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
    AffineExpr d0 = getAffineDimExpr(0, b.getContext());
    IRRewriter rewriter(b);
    sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
  }
  return sizes;
}
SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
  if (op.hasPureBufferSemantics())
    return {};
  return llvm::to_vector(
      llvm::map_range(op.getDpsInitsMutable(), [&](OpOperand &opOperand) {
        return operands[opOperand.getOperandNumber()].getType();
      }));
}
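// For each tiled output that was produced through a tensor.extract_slice,
// write the computed tile back into the original destination tensor with a
// matching tensor.insert_slice; results that were not sliced are forwarded
// unchanged.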
SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                    LinalgOp op, ValueRange operands,
                                    ValueRange results) {
  if (op.hasPureBufferSemantics())
    return {};
  SmallVector<Value> tensorResults;
  tensorResults.reserve(results.size());
  // Insert an insert_slice for each output tensor.
  unsigned resultIdx = 0;
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
    Value outputTensor = operands[opOperand.getOperandNumber()];
    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      Value inserted = builder.create<tensor::InsertSliceOp>(
          loc, sliceOp.getSource().getType(), results[resultIdx],
          sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
          sliceOp.getStrides(), sliceOp.getStaticOffsets(),
          sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
      tensorResults.push_back(inserted);
    } else {
      tensorResults.push_back(results[resultIdx]);
    }
    ++resultIdx;
  }
  return tensorResults;
}
SmallVector<std::optional<SliceParameters>>
computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
                          ArrayRef<OpFoldResult> tileSizes,
                          ArrayRef<OpFoldResult> sizeBounds,
                          bool omitPartialTileCheck) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](OpFoldResult v) { return !isZeroIndex(v); })) &&
         "expected as many ivs as non-zero sizes");

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
  SmallVector<OpFoldResult> lbs =
      computeTileOffsets(builder, loc, ivs, tileSizes);
  SmallVector<OpFoldResult> subShapeSizes =
      computeTileSizes(builder, loc, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) <=
             linalgOp->getNumOperands() &&
         "more values to tile than operands.");
  SmallVector<std::optional<SliceParameters>> allSliceParams;
  allSliceParams.reserve(valuesToTile.size());
  for (auto [opOperand, val] :
       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
    Value shapedOp = val;
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
    // Use `opOperand` as is if it is not tiled and not an output tensor. Having
    // an extract/insert slice pair for all output tensors simplifies follow up
    // transformations such as padding and bufferization since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.

    Type operandType = opOperand.get().getType();
    if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
                                      linalgOp.isDpsInit(&opOperand))) {
      allSliceParams.push_back(std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    allSliceParams.push_back(computeSliceParameters(
        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
        omitPartialTileCheck));
  }

  return allSliceParams;
}
SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
                                   LinalgOp linalgOp, ValueRange valuesToTile,
                                   ArrayRef<OpFoldResult> ivs,
                                   ArrayRef<OpFoldResult> tileSizes,
                                   ArrayRef<OpFoldResult> sizeBounds,
                                   bool omitPartialTileCheck) {
  SmallVector<std::optional<SliceParameters>> allSliceParameter =
      computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
                                tileSizes, sizeBounds, omitPartialTileCheck);
  SmallVector<Value> tiledShapes;
  for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
    Value valueToTile = std::get<0>(item);
    std::optional<SliceParameters> sliceParams = std::get<1>(item);
    tiledShapes.push_back(
        sliceParams.has_value()
            ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
                  ->getResult(0)
            : valueToTile);
  }
  return tiledShapes;
}
void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  IRRewriter rewriter(b);
  offsetIndices(rewriter, linalgOp, offsets);
}

void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  if (!linalgOp.hasIndexSemantics())
    return;

  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
      continue;
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    OpFoldResult applied = makeComposedFoldedAffineApply(
        b, indexOp.getLoc(), index + offset,
        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
    b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
      return use.getOwner() != materialized.getDefiningOp();
    });
  }
}
/// Get the reassociation maps to fold the result of an extract_slice (or the
/// source of an insert_slice) operation with the given offsets and sizes to
/// its rank-reduced version. This is only done for the cases where the size
/// is 1 and the offset is 0. Strictly speaking the offset 0 is not required
/// in general, but non-zero offsets are not handled by the SPIR-V backend at
/// this point (and potentially cannot be handled).
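/// For example, static sizes [1, 128, 1, 32] fold to the reassociation
/// [[0, 1], [2, 3]], and [128, 1] folds to [[0, 1]].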
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (const auto &it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = llvm::dyn_cast_if_present<Attribute>(size);
    if (attr && cast<IntegerAttr>(attr).getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociations are not empty, fold the remaining unit dimensions
  // into the last dimension. If the reassociations so far are empty, leave
  // them empty: this folds everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

} // namespace linalg
} // namespace mlir