//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace presburger;
using namespace mlir::affine;
using namespace mlir::linalg;
using namespace mlir::scf;

namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces that only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == mlir::AffineExprKind::Mul)
      assert(cast<AffineConstantExpr>(expr.getRHS()).getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled = false;
  ArrayRef<OpFoldResult> tileSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}

// Checks whether the `map` varies with respect to a non-zero `tileSize`.
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}
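
// Matches a GenericOp whose single-block region contains exactly one scalar
// binary operation on the two signless-int-or-float block arguments, followed
// by a linalg.yield of its result, e.g. (schematically):
//   ^bb0(%a: i32, %b: i32):
//     %0 = arith.addi %a, %b : i32
//     linalg.yield %0 : i32
// Currently only integer addition is recognized (BinaryOpKind::IAdd).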
std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.getRegion();
  if (!llvm::hasSingleElement(region))
    return std::nullopt;

  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(1).getType().isSignlessIntOrFloat())
    return std::nullopt;

  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(block.without_terminator()))
    return std::nullopt;

  using mlir::matchers::m_Val;
  auto a = m_Val(block.getArgument(0));
  auto b = m_Val(block.getArgument(1));

  auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;

  return std::nullopt;
}

/// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;

/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(OpBuilder &builder, Location loc,
                         ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                         SmallVectorImpl<Value> &ubs,
                         SmallVectorImpl<Value> &steps) {
  for (Range range : ranges) {
    lbs.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.offset));
    ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
    steps.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.stride));
  }
}

//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//

namespace mlir {
namespace linalg {

bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

bool isElementwise(LinalgOp op) {
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    if (!op.getMatchingIndexingMap(&opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(op->getRegion(0));
}

bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}

bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}
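
// Produces a value of `type` equal to `source` padded with `pad` up to `type`.
// If `source` is an extract_slice whose producer chain (followed through the
// DPS init operands of LinalgOps) ends in an equivalent tensor.pad (same
// result type, zero low padding, matching slice sizes and constant padding
// value), the already padded value is reused instead of creating a new
// tensor.pad.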
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Search the `source` use-def chain for padded LinalgOps.
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = cast<OpResult>(current);
    current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the matched
  // LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if `padOpSliceOp`, which defines the slice used by
  // `padOp`, is rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
  // by `padOp`.
  if (llvm::any_of(
          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
          [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
          }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
      !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}
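
// Builds a rank-polymorphic, element-wise copy between two memrefs of the same
// rank as a linalg.generic with identity indexing maps and all-parallel
// iterators. Schematically, for 2-D operands this produces:
//   linalg.generic {indexing_maps = [#id, #id],
//                   iterator_types = ["parallel", "parallel"]}
//       ins(%from : memref<?x?xf32>) outs(%to : memref<?x?xf32>) {
//   ^bb0(%in: f32, %out: f32):
//     linalg.yield %in : f32
//   }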
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = cast<MemRefType>(to.getType());
#ifndef NDEBUG
  auto memrefTypeFrom = cast<MemRefType>(from.getType());
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  return b.create<linalg::GenericOp>(
      loc,
      /*inputs=*/from,
      /*outputs=*/to,
      /*indexingMaps=*/llvm::ArrayRef({id, id}),
      /*iteratorTypes=*/iteratorTypes,
      [](OpBuilder &b, Location loc, ValueRange args) {
        b.create<linalg::YieldOp>(loc, args.front());
      });
}

/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries");
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputs();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  for (const auto &loop : llvm::enumerate(loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic) {
      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
                            procInfo[loop.index()].nprocs);
    }
  }
}

/// Specialization to build affine "for" nest.
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    auto constVal = getConstantIntValue(v);
    assert(constVal.has_value() && "Affine loops require constant steps");
    constantSteps.push_back(constVal.value());
  }

  affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                              [&](OpBuilder &b, Location loc, ValueRange ivs) {
                                bodyBuilderFn(b, loc, ivs,
                                              linalgOp->getOperands());
                              });
}

/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
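/// Specifically, this sets
///   lb   <- lb + procId * step
///   step <- nprocs * step
/// and leaves `ub` unchanged (cyclic distribution).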
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
  lb =
      affine::makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
  step = affine::makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
}

/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes`. Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
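// How the leading block of parallel loops is emitted depends on the
// distribution method of its first loop:
//  - None and Cyclic: a single scf.parallel over all processed loops;
//  - CyclicNumProcsGeNumIters: a single iteration guarded by an scf.if
//    in-bounds check;
//  - CyclicNumProcsEqNumIters: no loop or guard, the lower bounds are used
//    directly as the induction variables.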
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
    ArrayRef<linalg::ProcInfo> procInfo,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());
  assert(procInfo.empty() || (lbs.size() == procInfo.size()));

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse.
  if (!isParallelIterator(iteratorTypes.front())) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(
              b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
              iteratorTypes.drop_front(),
              procInfo.empty() ? procInfo : procInfo.drop_front(),
              bodyBuilderFn, ivStorage);
        });
    return;
  }

  unsigned nLoops = iteratorTypes.size();
  unsigned numProcessed = 0;
  DistributionMethod distributionMethod = DistributionMethod::None;
  if (procInfo.empty()) {
    numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
  } else {
    distributionMethod = procInfo.front().distributionMethod;
    numProcessed =
        nLoops - procInfo
                     .drop_while([&](linalg::ProcInfo p) {
                       return p.distributionMethod == distributionMethod;
                     })
                     .size();
  }

  auto remainderProcInfo =
      procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
  switch (distributionMethod) {
  case DistributionMethod::None: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
                               ubs.drop_front(numProcessed),
                               steps.drop_front(numProcessed),
                               iteratorTypes.drop_front(numProcessed),
                               remainderProcInfo, bodyBuilderFn, ivStorage);
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        remainderProcInfo, bodyBuilderFn, ivStorage);
    return;
  }
}

/// Specialization for generating a mix of parallel and sequential scf loops.
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected proc information for all loops when present");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  for (const auto &it : llvm::enumerate(procInfo)) {
    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
      updateBoundsForCyclicDistribution(
          b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
          ubsStorage[it.index()], stepsStorage[it.index()]);
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
      [&](OpBuilder &b, Location loc, ValueRange ivs) {
        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
      },
      ivs);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}

static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
                                        Value valueToTile,
                                        const SliceParameters &sliceParams) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case([&](MemRefType) {
                        return builder.create<memref::SubViewOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Case([&](RankedTensorType) {
                        return builder.create<tensor::ExtractSliceOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Default([](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type");
                      });
  return sliceOp;
}

Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                          ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                          ArrayRef<OpFoldResult> lbs,
                          ArrayRef<OpFoldResult> ubs,
                          ArrayRef<OpFoldResult> subShapeSizes,
                          bool omitPartialTileCheck) {
  SliceParameters sliceParams =
      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                             ubs, subShapeSizes, omitPartialTileCheck);
  return materializeTiledShape(builder, loc, valueToTile, sliceParams);
}

SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                       ArrayRef<OpFoldResult> subShapeSizes,
                       bool omitPartialTileCheck) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Compute offsets/sizes/strides for the tile.
  SliceParameters sliceParams;
  sliceParams.offsets.reserve(rank);
  sliceParams.sizes.reserve(rank);
  sliceParams.strides.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      sliceParams.offsets.push_back(builder.getIndexAttr(0));
      OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
      sliceParams.sizes.push_back(dim);
      sliceParams.strides.push_back(builder.getIndexAttr(1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");

    // Tiling creates a new slice at the proper index, the slice step is 1
    // (i.e. the op does not subsample, stepping occurs in the loop).
    auto m = map.getSubMap({r});
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
    IRRewriter rewriter(builder);
    OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
    sliceParams.offsets.push_back(offset);
    OpFoldResult closedIntSize =
        makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
    // Resulting size needs to be made half open interval again.
    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
    OpFoldResult size =
        makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: raw size: " << size << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: new offset: " << offset << "\n");
    sliceParams.strides.push_back(builder.getIndexAttr(1));

    if (omitPartialTileCheck) {
      // We statically know that the partial/boundary tile condition is
      // unnecessary.
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
      sliceParams.sizes.push_back(size);
      continue;
    }

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    std::optional<int64_t> sizeCst = getConstantIntValue(size);
    auto hasTileSizeOne = sizeCst && *sizeCst == 1;
    auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
                         ((shapeSize % *sizeCst) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");

      AffineExpr dim0, dim1, dim2;
      MLIRContext *context = builder.getContext();
      bindDims(context, dim0, dim1, dim2);

      // Get the dimension size for this dimension. We need to first calculate
      // the max index and then add one. This is important because for
      // convolution ops, the affine map of an input window dimension has the
      // form `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window
      // dimension and `s0` is the stride. Directly using the dimension size of
      // the output/filter window dimensions would give an incorrect result.
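      // For example, for a submap `(d0 * 2 + d1)` with loop upper bounds
      // (%ub0, %ub1), the accessed extent is ((%ub0 - 1) * 2 + (%ub1 - 1)) + 1
      // rather than %ub0 * 2 + %ub1.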
      AffineMap minusOneMap = AffineMap::inferFromExprList(
                                  {ArrayRef<AffineExpr>{dim0 - 1}}, context)
                                  .front();
      AffineMap plusOneMap = AffineMap::inferFromExprList(
                                 {ArrayRef<AffineExpr>{dim0 + 1}}, context)
                                 .front();
      SmallVector<OpFoldResult> maxIndices =
          llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
            return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
                                                 {ub});
          }));
      OpFoldResult maxIndex =
          makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
      OpFoldResult d =
          makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});

      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
      AffineMap minMap = AffineMap::inferFromExprList(
                             {ArrayRef<AffineExpr>{dim1 - dim2, dim0}}, context)
                             .front();
      size =
          makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    sliceParams.sizes.push_back(size);
  }
  return sliceParams;
}
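
// Returns one offset per loop: the current induction variable for loops that
// are tiled (non-zero tile size) and a constant zero for loops that are not.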
SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
                                             ArrayRef<OpFoldResult> ivs,
                                             ArrayRef<OpFoldResult> tileSizes) {
  SmallVector<OpFoldResult> offsets;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
    LLVM_DEBUG(llvm::dbgs()
               << "computeTileOffsets: " << offsets.back() << "\n");
  }
  return offsets;
}
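
// Returns one size per loop as a closed interval: the tile size (or the full
// size bound if the loop is untiled) minus one, ready to be composed with the
// indexing maps in computeSliceParameters.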
SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                           ArrayRef<OpFoldResult> tileSizes,
                                           ArrayRef<OpFoldResult> sizeBounds) {
  SmallVector<OpFoldResult> sizes;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    // Before composing, we need to make range a closed interval.
    OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
    AffineExpr d0 = getAffineDimExpr(0, b.getContext());
    IRRewriter rewriter(b);
    sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
  }
  return sizes;
}

SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
  if (op.hasPureBufferSemantics())
    return {};
  return llvm::to_vector(
      llvm::map_range(op.getDpsInitsMutable(), [&](OpOperand &opOperand) {
        return operands[opOperand.getOperandNumber()].getType();
      }));
}
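
// For every DPS init that was tiled through an extract_slice, insert the
// corresponding computed tile back into the full destination tensor with a
// matching tensor.insert_slice; results whose init was not sliced are
// returned as is.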
SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                    LinalgOp op, ValueRange operands,
                                    ValueRange results) {
  if (op.hasPureBufferSemantics())
    return {};
  SmallVector<Value> tensorResults;
  tensorResults.reserve(results.size());
  // Insert an insert_slice for each output tensor.
  unsigned resultIdx = 0;
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
    Value outputTensor = operands[opOperand.getOperandNumber()];
    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      Value inserted = builder.create<tensor::InsertSliceOp>(
          loc, sliceOp.getSource().getType(), results[resultIdx],
          sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
          sliceOp.getStrides(), sliceOp.getStaticOffsets(),
          sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
      tensorResults.push_back(inserted);
    } else {
      tensorResults.push_back(results[resultIdx]);
    }
    ++resultIdx;
  }
  return tensorResults;
}
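
// Computes SliceParameters for every value in `valuesToTile`. A std::nullopt
// entry means the corresponding value is used as is; this only happens for
// operands that are neither tiled nor output tensors, since output tensors
// are always sliced to keep extract/insert slice pairs explicit.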
SmallVector<std::optional<SliceParameters>>
computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
                          ArrayRef<OpFoldResult> tileSizes,
                          ArrayRef<OpFoldResult> sizeBounds,
                          bool omitPartialTileCheck) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](OpFoldResult v) { return !isZeroIndex(v); })) &&
         "expected as many ivs as non-zero sizes");

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
  SmallVector<OpFoldResult> lbs =
      computeTileOffsets(builder, loc, ivs, tileSizes);
  SmallVector<OpFoldResult> subShapeSizes =
      computeTileSizes(builder, loc, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) <=
             linalgOp->getNumOperands() &&
         "more values to tile than operands.");
  SmallVector<std::optional<SliceParameters>> allSliceParams;
  allSliceParams.reserve(valuesToTile.size());
  for (auto [opOperand, val] :
       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
    Value shapedOp = val;
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
    // Use `opOperand` as is if it is not tiled and not an output tensor. Having
    // an extract/insert slice pair for all output tensors simplifies follow-up
    // transformations such as padding and bufferization since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.

    Type operandType = opOperand.get().getType();
    if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
                                      linalgOp.isDpsInit(&opOperand))) {
      allSliceParams.push_back(std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    allSliceParams.push_back(computeSliceParameters(
        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
        omitPartialTileCheck));
  }

  return allSliceParams;
}

SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
                                   LinalgOp linalgOp, ValueRange valuesToTile,
                                   ArrayRef<OpFoldResult> ivs,
                                   ArrayRef<OpFoldResult> tileSizes,
                                   ArrayRef<OpFoldResult> sizeBounds,
                                   bool omitPartialTileCheck) {
  SmallVector<std::optional<SliceParameters>> allSliceParameter =
      computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
                                tileSizes, sizeBounds, omitPartialTileCheck);
  SmallVector<Value> tiledShapes;
  for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
    Value valueToTile = std::get<0>(item);
    std::optional<SliceParameters> sliceParams = std::get<1>(item);
    tiledShapes.push_back(
        sliceParams.has_value()
            ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
                  ->getResult(0)
            : valueToTile);
  }
  return tiledShapes;
}

void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  IRRewriter rewriter(b);
  offsetIndices(rewriter, linalgOp, offsets);
}
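
// Adds the given tile `offsets` to the results of the linalg.index ops inside
// `linalgOp`: every `linalg.index d` is replaced, except inside the
// materializing computation itself, by `index + offsets[d]`, folded through
// affine.apply where possible.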
void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  if (!linalgOp.hasIndexSemantics())
    return;

  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
      continue;
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    OpFoldResult applied = makeComposedFoldedAffineApply(
        b, indexOp.getLoc(), index + offset,
        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
    b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
      return use.getOwner() != materialized.getDefiningOp();
    });
  }
}

/// Get the reassociation maps to fold the result of an extract_slice (or the
/// source of an insert_slice) operation with the given offsets and sizes to
/// its rank-reduced version. This is only done for the cases where the size is
/// 1 and the offset is 0. Strictly speaking the offset 0 is not required in
/// general, but non-zero offsets are not handled by the SPIR-V backend at this
/// point (and potentially cannot be handled).
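/// For example, for `mixedSizes` [1, 1, 10, 1, 4] this returns
/// [[0, 1, 2], [3, 4]], and for [1, 1, 10, 1] it returns [[0, 1, 2, 3]]
/// (trailing unit dimensions are folded into the last group).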
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (const auto &it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = llvm::dyn_cast_if_present<Attribute>(size);
    if (attr && cast<IntegerAttr>(attr).getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociations are not empty, fold the remaining unit dimensions
  // into the last dimension. If the reassociations so far are empty, leave
  // them empty; this will fold everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

} // namespace linalg
} // namespace mlir