1 //===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file defines transform dialect operations used for testing
// TilingInterface.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Index/IR/IndexDialect.h"
16 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
17 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
18 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
19 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20 #include "mlir/Dialect/Utils/StaticValueUtils.h"
21 #include "mlir/IR/Dominance.h"
22 #include "mlir/IR/OpImplementation.h"
23 #include "mlir/Interfaces/TilingInterface.h"
25 #define GET_OP_CLASSES
26 #include "TestTilingInterfaceTransformOps.h.inc"
29 using namespace mlir::transform
;
//===----------------------------------------------------------------------===//
// TestFuseAndYieldOp
//===----------------------------------------------------------------------===//
35 static llvm::SmallDenseSet
<Operation
*> collectTiledAndFusedOps(Operation
*op
) {
36 SmallVector
<Operation
*> worklist
;
37 llvm::SmallDenseSet
<Operation
*> producers
;
38 worklist
.push_back(op
);
40 while (!worklist
.empty()) {
41 Operation
*current
= worklist
.pop_back_val();
42 for (OpOperand
&operand
: current
->getOpOperands()) {
43 Operation
*producer
= operand
.get().getDefiningOp();
44 if (!producer
|| !isa
<TilingInterface
>(producer
) ||
45 producers
.contains(producer
))
47 worklist
.push_back(producer
);
48 producers
.insert(producer
);
54 /// Apply a tile and fuse transformation to all payload ops and store both the
55 /// tiled operation as well as the created tile loops.
56 template <typename Range
>
58 applyTileAndFuseToAll(RewriterBase
&rewriter
, Operation
*transformOp
,
59 Range
&&payloadOps
, unsigned numLoops
,
60 ArrayRef
<OpFoldResult
> tileSizes
,
61 ArrayRef
<int64_t> interchange
, bool useForall
,
62 TransformResults
&transformResults
) {
63 SmallVector
<Operation
*> tiledOps
;
64 SmallVector
<SmallVector
<Operation
*>> loopOps(numLoops
);
66 for (Operation
*target
: payloadOps
) {
67 auto tilingInterfaceOp
= dyn_cast
<TilingInterface
>(target
);
68 if (!tilingInterfaceOp
)
69 return transformOp
->emitError("only TilingInterface ops are supported");
70 DominanceInfo
dominanceInfo(tilingInterfaceOp
);
72 llvm::SmallDenseSet
<Operation
*> tiledAndFusedOps
=
73 collectTiledAndFusedOps(tilingInterfaceOp
);
74 llvm::DenseSet
<Operation
*> yieldReplacementsFor
;
75 for (auto op
: tiledAndFusedOps
) {
76 if (llvm::any_of(op
->getUsers(), [&](Operation
*user
) {
77 return dominanceInfo
.properlyDominates(tilingInterfaceOp
, user
);
79 yieldReplacementsFor
.insert(op
);
83 scf::SCFTilingOptions tilingOptions
;
84 tilingOptions
.setTileSizes(tileSizes
).setInterchange(interchange
);
86 tilingOptions
.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp
);
89 scf::SCFTileAndFuseOptions tileAndFuseOptions
;
90 tileAndFuseOptions
.setTilingOptions(tilingOptions
);
92 scf::SCFTileAndFuseOptions::ControlFnTy controlFn
=
93 [&](tensor::ExtractSliceOp candidateSliceOp
, OpResult originalProducer
,
94 bool isDestinationOperand
)
95 -> std::optional
<scf::SCFTileAndFuseOptions::ControlFnResult
> {
96 Operation
*owner
= originalProducer
.getOwner();
97 bool yieldProducerReplacement
= yieldReplacementsFor
.contains(owner
);
98 return scf::SCFTileAndFuseOptions::ControlFnResult
{
99 yieldProducerReplacement
};
101 tileAndFuseOptions
.setFusionControlFn(controlFn
);
103 rewriter
.setInsertionPoint(target
);
104 FailureOr
<scf::SCFTileAndFuseResult
> tiledResults
=
105 scf::tileConsumerAndFuseProducersUsingSCF(rewriter
, tilingInterfaceOp
,
107 if (failed(tiledResults
))
110 // Perform the replacement of tiled and fused values.
111 SmallVector
<Operation
*> opsToReplace
{target
};
112 llvm::append_range(opsToReplace
, tiledResults
->fusedProducers
);
113 for (Operation
*toReplace
: opsToReplace
) {
114 for (OpResult res
: toReplace
->getResults())
115 if (auto replacement
= tiledResults
->replacements
.lookup(res
)) {
116 Operation
*replacementOp
= replacement
.getDefiningOp();
117 rewriter
.replaceUsesWithIf(res
, replacement
, [&](OpOperand
&use
) {
118 Operation
*user
= use
.getOwner();
119 return dominanceInfo
.properlyDominates(replacementOp
, user
) &&
120 user
->getParentOp() == replacementOp
->getParentOp();
124 if (toReplace
->use_empty()) {
125 rewriter
.eraseOp(toReplace
);
129 // Report back the relevant handles to the transform op.
130 tiledOps
.push_back(tiledResults
->tiledAndFusedOps
.front());
131 assert(tiledResults
->loops
.size() == numLoops
&&
132 "Mismatched number of loops, tile and fuse transform should have "
134 for (unsigned int i
= 0; i
< numLoops
; ++i
)
135 loopOps
[i
].push_back(tiledResults
->loops
[i
]);
138 transformResults
.set(transformOp
->getOpResult(0), tiledOps
);
139 for (unsigned int i
= 0; i
< numLoops
; ++i
)
140 transformResults
.set(transformOp
->getOpResult(i
+ 1), loopOps
[i
]);
145 DiagnosedSilenceableFailure
146 transform::TestFuseAndYieldOp::apply(TransformRewriter
&rewriter
,
147 TransformResults
&transformResults
,
148 TransformState
&state
) {
149 SmallVector
<int64_t> tileSizes
=
150 extractFromIntegerArrayAttr
<int64_t>(getTileSizes());
151 SmallVector
<int64_t> tileInterchange
=
152 extractFromIntegerArrayAttr
<int64_t>(getTileInterchange());
154 SmallVector
<OpFoldResult
> tileSizesOfr
=
155 getAsIndexOpFoldResult(rewriter
.getContext(), tileSizes
);
157 LogicalResult result
= applyTileAndFuseToAll(
158 rewriter
, getOperation(), state
.getPayloadOps(getTarget()),
159 tileSizes
.size() - llvm::count(tileSizes
, 0), tileSizesOfr
,
160 tileInterchange
, getUseForall(), transformResults
);
161 return failed(result
) ? DiagnosedSilenceableFailure::definiteFailure()
162 : DiagnosedSilenceableFailure::success();
165 //===----------------------------------------------------------------------===//
166 // TestFuseConsumerOp
167 //===----------------------------------------------------------------------===//
169 /// Apply fusing of consumer transformation to all payload ops and store both
170 /// the original consumer operation as well as the fused consumer operation.
171 template <typename Range
>
173 applyFuseConsumer(RewriterBase
&rewriter
, Operation
*transformOp
,
174 Range
&&payloadOps
, uint32_t numConsumerToFuse
,
175 TransformResults
&transformResults
) {
176 SmallVector
<Operation
*> originalConsumerOps
;
177 SmallVector
<Operation
*> fusedConsumerOps
;
179 for (Operation
*target
: payloadOps
) {
180 rewriter
.setInsertionPoint(target
);
182 while (numConsumerToFuse
--) {
183 FailureOr
<scf::SCFFuseConsumerOfSliceResult
> fuseConsumerResults
=
184 scf::tileAndFuseConsumerOfSlice(rewriter
, target
);
186 if (failed(fuseConsumerResults
))
189 // Report back the relevant handles to the transform op.
190 originalConsumerOps
.push_back(
191 fuseConsumerResults
->origConsumerOperand
->getOwner());
192 fusedConsumerOps
.push_back(
193 fuseConsumerResults
->tiledAndFusedConsumerOperand
->getOwner());
197 transformResults
.set(transformOp
->getOpResult(0), originalConsumerOps
);
198 transformResults
.set(transformOp
->getOpResult(1), fusedConsumerOps
);
202 DiagnosedSilenceableFailure
203 transform::TestFuseConsumerOp::apply(TransformRewriter
&rewriter
,
204 TransformResults
&transformResults
,
205 TransformState
&state
) {
206 LogicalResult result
= applyFuseConsumer(
207 rewriter
, getOperation(), state
.getPayloadOps(getTarget()),
208 getNumConsumerToFuse(), transformResults
);
209 return failed(result
) ? DiagnosedSilenceableFailure::definiteFailure()
210 : DiagnosedSilenceableFailure::success();
213 void transform::TestFuseConsumerOp::getEffects(
214 SmallVectorImpl
<MemoryEffects::EffectInstance
> &effects
) {
215 consumesHandle(getTargetMutable(), effects
);
216 producesHandle(getOperation()->getOpResults(), effects
);
217 modifiesPayload(effects
);
220 //===----------------------------------------------------------------------===//
221 // TestTileUsingForallOp
222 //===----------------------------------------------------------------------===//
224 /// Apply a tiling transformation to all payload ops and store both the
225 /// tiled operation as well as the created tile loops.
226 template <typename Range
>
228 applyTileToAll(RewriterBase
&rewriter
, Operation
*transformOp
,
229 Range
&&payloadOps
, ArrayRef
<OpFoldResult
> tileSizes
,
230 ArrayRef
<int64_t> interchange
, std::optional
<ArrayAttr
> mapping
,
231 TransformResults
&transformResults
) {
232 SmallVector
<Operation
*> tiledOps
;
233 SmallVector
<Operation
*> loopOps
;
235 for (Operation
*target
: payloadOps
) {
236 auto tilingInterfaceOp
= dyn_cast
<TilingInterface
>(target
);
237 if (!tilingInterfaceOp
)
238 return transformOp
->emitError("only TilingInterface ops are supported");
239 scf::SCFTilingOptions tilingOptions
;
240 tilingOptions
.setTileSizes(tileSizes
).setInterchange(interchange
);
242 tilingOptions
.setMapping(mapping
.value().getValue());
244 tilingOptions
.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp
);
246 rewriter
.setInsertionPoint(target
);
247 FailureOr
<scf::SCFTilingResult
> tiledResults
=
248 scf::tileUsingSCF(rewriter
, tilingInterfaceOp
, tilingOptions
);
249 if (failed(tiledResults
))
252 // Perform the replacement of tiled and fused values.
253 rewriter
.replaceOp(tilingInterfaceOp
, tiledResults
->replacements
);
255 // Report back the relevant handles to the transform op.
256 tiledOps
.push_back(tiledResults
->tiledOps
.front());
257 for (Operation
*loop
: tiledResults
->loops
)
258 loopOps
.push_back(loop
);
261 transformResults
.set(transformOp
->getOpResult(0), tiledOps
);
262 for (auto [index
, loop
] : llvm::enumerate(loopOps
))
263 transformResults
.set(transformOp
->getOpResult(index
+ 1), {loop
});
268 DiagnosedSilenceableFailure
269 transform::TestTileUsingForallOp::apply(TransformRewriter
&rewriter
,
270 TransformResults
&transformResults
,
271 TransformState
&state
) {
272 SmallVector
<int64_t> tileSizes
=
273 extractFromIntegerArrayAttr
<int64_t>(getTileSizes());
274 SmallVector
<int64_t> interchange
=
275 extractFromIntegerArrayAttr
<int64_t>(getInterchange());
276 SmallVector
<OpFoldResult
> tileSizesOfr
=
277 getAsIndexOpFoldResult(rewriter
.getContext(), tileSizes
);
279 LogicalResult result
=
280 applyTileToAll(rewriter
, getOperation(), state
.getPayloadOps(getTarget()),
281 tileSizesOfr
, interchange
, getMapping(), transformResults
);
282 return failed(result
) ? DiagnosedSilenceableFailure::definiteFailure()
283 : DiagnosedSilenceableFailure::success();
286 void transform::TestTileUsingForallOp::getEffects(
287 SmallVectorImpl
<MemoryEffects::EffectInstance
> &effects
) {
288 consumesHandle(getTargetMutable(), effects
);
289 producesHandle(getOperation()->getOpResults(), effects
);
290 modifiesPayload(effects
);
293 //===----------------------------------------------------------------------===//
294 // TestFuseUsingForallOp
295 //===----------------------------------------------------------------------===//
297 /// Apply a tiling transformation to all payload ops and store both the
298 /// tiled operation as well as the created tile loops.
299 template <typename Range
>
300 static LogicalResult
applyTilingToAll(
301 RewriterBase
&rewriter
, Operation
*transformOp
, Range
&&payloadOps
,
302 unsigned numLoops
, TransformResults
&transformResults
,
303 function_ref
<FailureOr
<scf::SCFTileAndFuseResult
>(TilingInterface
)>
305 SmallVector
<Operation
*> tiledLinalgOps
;
306 SmallVector
<SmallVector
<Operation
*>> loopOps(1);
308 for (Operation
*target
: payloadOps
) {
309 auto tilingInterfaceOp
= dyn_cast
<TilingInterface
>(target
);
310 if (!tilingInterfaceOp
)
311 return transformOp
->emitError("only TilingInterface ops are supported");
313 rewriter
.setInsertionPoint(target
);
314 FailureOr
<scf::SCFTileAndFuseResult
> tiledResults
=
315 applyFn(tilingInterfaceOp
);
316 if (failed(tiledResults
))
319 // Perform the replacement of tiled and fused values.
320 SmallVector
<Operation
*> opsToReplace
{target
};
321 llvm::append_range(opsToReplace
, tiledResults
->fusedProducers
);
322 for (Operation
*toReplace
: opsToReplace
) {
323 for (OpResult res
: toReplace
->getResults())
324 if (auto replacement
= tiledResults
->replacements
.lookup(res
))
325 rewriter
.replaceAllUsesWith(res
, replacement
);
326 if (toReplace
->use_empty())
327 rewriter
.eraseOp(toReplace
);
330 // Report back the relevant handles to the transform op.
331 tiledLinalgOps
.push_back(tiledResults
->tiledAndFusedOps
.front());
332 assert(tiledResults
->loops
.size() == 1 &&
333 cast
<scf::ForallOp
>(tiledResults
->loops
[0]).getRank() == numLoops
&&
334 "Mismatched number of loops, tile and fuse transform should have "
336 loopOps
[0] = {tiledResults
->loops
[0]};
339 transformResults
.set(transformOp
->getOpResult(0), tiledLinalgOps
);
340 if (!loopOps
.empty())
341 transformResults
.set(transformOp
->getOpResult(1), loopOps
[0]);
346 DiagnosedSilenceableFailure
347 transform::TestFuseUsingForallOp::apply(TransformRewriter
&rewriter
,
348 TransformResults
&transformResults
,
349 TransformState
&state
) {
350 SmallVector
<int64_t> tileSizes
=
351 extractFromIntegerArrayAttr
<int64_t>(getTileSizes());
352 SmallVector
<int64_t> tileInterchange
=
353 extractFromIntegerArrayAttr
<int64_t>(getInterchange());
355 scf::SCFTilingOptions tilingOptions
;
356 tilingOptions
.interchangeVector
= tileInterchange
;
357 SmallVector
<OpFoldResult
> tileSizesOfr
=
358 getAsIndexOpFoldResult(rewriter
.getContext(), tileSizes
);
359 tilingOptions
= tilingOptions
.setTileSizes(tileSizesOfr
);
360 tilingOptions
.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp
);
361 scf::SCFTileAndFuseOptions tileAndFuseOptions
;
362 tileAndFuseOptions
.tilingOptions
= tilingOptions
;
363 LogicalResult result
= applyTilingToAll(
364 rewriter
, getOperation(), state
.getPayloadOps(getRootOp()),
365 tileSizes
.size() - llvm::count(tileSizes
, 0), transformResults
,
366 [&](TilingInterface tilingInterfaceOp
)
367 -> FailureOr
<scf::SCFTileAndFuseResult
> {
368 return tileConsumerAndFuseProducersUsingSCF(rewriter
, tilingInterfaceOp
,
371 return failed(result
) ? DiagnosedSilenceableFailure::definiteFailure()
372 : DiagnosedSilenceableFailure::success();
375 void transform::TestFuseUsingForallOp::getEffects(
376 SmallVectorImpl
<MemoryEffects::EffectInstance
> &effects
) {
377 consumesHandle(getRootOpMutable(), effects
);
378 producesHandle(getOperation()->getOpResults(), effects
);
379 modifiesPayload(effects
);
382 #define GET_OP_CLASSES
383 #include "TestTilingInterfaceTransformOps.cpp.inc"
386 class TestTilingInterfaceDialectExtension
387 : public transform::TransformDialectExtension
<
388 TestTilingInterfaceDialectExtension
> {
390 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
391 TestTilingInterfaceDialectExtension
)
396 declareDependentDialect
<affine::AffineDialect
>();
397 declareDependentDialect
<index::IndexDialect
>();
398 declareDependentDialect
<scf::SCFDialect
>();
399 declareDependentDialect
<tensor::TensorDialect
>();
401 registerTransformOps
<
403 #include "TestTilingInterfaceTransformOps.cpp.inc"
410 void registerTestTilingInterfaceTransformDialectExtension(
411 DialectRegistry
®istry
) {
412 registry
.addExtensions
<TestTilingInterfaceDialectExtension
>();