1 //===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file defines transform dialect operations used for testing the
// `TilingInterface`.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Index/IR/IndexDialect.h"
16 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
17 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
18 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
19 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20 #include "mlir/Dialect/Utils/StaticValueUtils.h"
21 #include "mlir/IR/Dominance.h"
22 #include "mlir/IR/OpImplementation.h"
23 #include "mlir/Interfaces/TilingInterface.h"
25 #define GET_OP_CLASSES
26 #include "TestTilingInterfaceTransformOps.h.inc"
29 using namespace mlir::transform
;
31 //===----------------------------------------------------------------------===//
// TestFuseAndYieldOp
33 //===----------------------------------------------------------------------===//
// Collects the operations reachable from `op` through operand producers,
// walking them with a worklist. The guard below presumably skips producers
// that are null, do not implement TilingInterface, or are already recorded in
// `producers`; every other producer is queued and recorded.
// NOTE(review): this listing is lossy (the embedded original line numbers skip
// 39, 46 and everything after 48). The statement controlled by the guard
// (presumably `continue;`), whether `op` itself is recorded, and the closing
// braces / final return are not visible; as shown, the push below would be
// conditional on the skip test. Confirm against the upstream source.
35 static llvm::SmallDenseSet
<Operation
*> collectTiledAndFusedOps(Operation
*op
) {
36 SmallVector
<Operation
*> worklist
;
37 llvm::SmallDenseSet
<Operation
*> producers
;
38 worklist
.push_back(op
);
40 while (!worklist
.empty()) {
41 Operation
*current
= worklist
.pop_back_val();
42 for (OpOperand
&operand
: current
->getOpOperands()) {
43 Operation
*producer
= operand
.get().getDefiningOp();
44 if (!producer
|| !isa
<TilingInterface
>(producer
) ||
45 producers
.contains(producer
))
47 worklist
.push_back(producer
);
48 producers
.insert(producer
);
54 /// Apply a tile and fuse transformation to all payload ops and store both the
55 /// tiled operation as well as the created tile loops.
///
/// Each payload op must implement TilingInterface; otherwise an error is
/// emitted on `transformOp`. Ops returned by collectTiledAndFusedOps that have
/// a user properly dominated by the payload op are recorded in
/// `yieldReplacementsFor`. The fusion control function then asks fusion to
/// yield a replacement for exactly those producers.
///
/// After tiling, a use of an original result (of the target or of a fused
/// producer) is redirected to its replacement only when the using op is
/// properly dominated by the replacement's defining op and has the same
/// parent op. Originals left without uses are erased.
///
/// Result 0 of `transformOp` receives the first tiled-and-fused op of every
/// target. Result `i + 1` receives the `i`-th entry of each target's
/// `tiledResults->loops`, for `i < numLoops` (the count is asserted).
///
/// NOTE(review): this listing is lossy (see the embedded line numbers). The
/// following are not visible here:
///   - the return-type line;
///   - the conditional that presumably gates the ForallOp loop type on
///     `useForall` (as shown, setLoopType is unconditional and `useForall` is
///     unused);
///   - the trailing argument of tileConsumerAndFuseProducersUsingSCF;
///   - the body of the `failed(...)` check;
///   - several closing braces and the final return.
/// Confirm against the upstream source.
56 template <typename Range
>
58 applyTileAndFuseToAll(RewriterBase
&rewriter
, Operation
*transformOp
,
59 Range
&&payloadOps
, unsigned numLoops
,
60 ArrayRef
<OpFoldResult
> tileSizes
,
61 ArrayRef
<int64_t> interchange
, bool useForall
,
62 TransformResults
&transformResults
) {
63 SmallVector
<Operation
*> tiledOps
;
64 SmallVector
<SmallVector
<Operation
*>> loopOps(numLoops
);
66 for (Operation
*target
: payloadOps
) {
67 auto tilingInterfaceOp
= dyn_cast
<TilingInterface
>(target
);
68 if (!tilingInterfaceOp
)
69 return transformOp
->emitError("only TilingInterface ops are supported");
70 DominanceInfo
dominanceInfo(tilingInterfaceOp
);
72 llvm::SmallDenseSet
<Operation
*> tiledAndFusedOps
=
73 collectTiledAndFusedOps(tilingInterfaceOp
);
74 llvm::DenseSet
<Operation
*> yieldReplacementsFor
;
75 for (auto op
: tiledAndFusedOps
) {
76 if (llvm::any_of(op
->getUsers(), [&](Operation
*user
) {
77 return dominanceInfo
.properlyDominates(tilingInterfaceOp
, user
);
79 yieldReplacementsFor
.insert(op
);
83 scf::SCFTilingOptions tilingOptions
;
84 tilingOptions
.setTileSizes(tileSizes
).setInterchange(interchange
);
86 tilingOptions
.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp
);
89 scf::SCFTileAndFuseOptions tileAndFuseOptions
;
90 tileAndFuseOptions
.setTilingOptions(tilingOptions
);
92 scf::SCFTileAndFuseOptions::ControlFnTy controlFn
=
93 [&](tensor::ExtractSliceOp candidateSliceOp
, OpResult originalProducer
,
94 bool isDestinationOperand
)
95 -> std::optional
<scf::SCFTileAndFuseOptions::ControlFnResult
> {
96 Operation
*owner
= originalProducer
.getOwner();
97 bool yieldProducerReplacement
= yieldReplacementsFor
.contains(owner
);
98 return scf::SCFTileAndFuseOptions::ControlFnResult
{
99 yieldProducerReplacement
};
101 tileAndFuseOptions
.setFusionControlFn(controlFn
);
103 rewriter
.setInsertionPoint(target
);
104 FailureOr
<scf::SCFTileAndFuseResult
> tiledResults
=
105 scf::tileConsumerAndFuseProducersUsingSCF(rewriter
, tilingInterfaceOp
,
107 if (failed(tiledResults
))
110 // Perform the replacement of tiled and fused values.
111 SmallVector
<Operation
*> opsToReplace
{target
};
112 llvm::append_range(opsToReplace
, tiledResults
->fusedProducers
);
113 for (Operation
*toReplace
: opsToReplace
) {
114 for (OpResult res
: toReplace
->getResults())
115 if (auto replacement
= tiledResults
->replacements
.lookup(res
)) {
116 Operation
*replacementOp
= replacement
.getDefiningOp();
117 rewriter
.replaceUsesWithIf(res
, replacement
, [&](OpOperand
&use
) {
118 Operation
*user
= use
.getOwner();
119 return dominanceInfo
.properlyDominates(replacementOp
, user
) &&
120 user
->getParentOp() == replacementOp
->getParentOp();
124 if (toReplace
->use_empty()) {
125 rewriter
.eraseOp(toReplace
);
129 // Report back the relevant handles to the transform op.
130 tiledOps
.push_back(tiledResults
->tiledAndFusedOps
.front());
131 assert(tiledResults
->loops
.size() == numLoops
&&
132 "Mismatched number of loops, tile and fuse transform should have "
134 for (unsigned int i
= 0; i
< numLoops
; ++i
)
135 loopOps
[i
].push_back(tiledResults
->loops
[i
]);
138 transformResults
.set(transformOp
->getOpResult(0), tiledOps
);
139 for (unsigned int i
= 0; i
< numLoops
; ++i
)
140 transformResults
.set(transformOp
->getOpResult(i
+ 1), loopOps
[i
]);
// Runs applyTileAndFuseToAll on every payload op of the `getTarget()` handle.
// The expected loop count passed along is the number of non-zero entries in
// the tile sizes. Any failure is reported as a definite (non-silenceable)
// failure.
145 DiagnosedSilenceableFailure
146 transform::TestFuseAndYieldOp::apply(TransformRewriter
&rewriter
,
147 TransformResults
&transformResults
,
148 TransformState
&state
) {
149 SmallVector
<int64_t> tileSizes
=
150 extractFromIntegerArrayAttr
<int64_t>(getTileSizes());
151 SmallVector
<int64_t> tileInterchange
=
152 extractFromIntegerArrayAttr
<int64_t>(getTileInterchange());
154 SmallVector
<OpFoldResult
> tileSizesOfr
=
155 getAsIndexOpFoldResult(rewriter
.getContext(), tileSizes
);
157 LogicalResult result
= applyTileAndFuseToAll(
158 rewriter
, getOperation(), state
.getPayloadOps(getTarget()),
159 tileSizes
.size() - llvm::count(tileSizes
, 0), tileSizesOfr
,
160 tileInterchange
, getUseForall(), transformResults
);
161 return failed(result
) ? DiagnosedSilenceableFailure::definiteFailure()
162 : DiagnosedSilenceableFailure::success();
165 //===----------------------------------------------------------------------===//
166 // TestFuseConsumerOp
167 //===----------------------------------------------------------------------===//
169 /// Apply fusing of consumer transformation to all payload ops and store both
170 /// the original consumer operation as well as the fused consumer operation.
///
/// Each payload op is handed to scf::tileAndFuseConsumerOfSlice, up to
/// `numConsumerToFuse` times. Result 0 of `transformOp` receives the owners of
/// the original consumer operands. Result 1 receives the owners of the
/// tiled-and-fused consumer operands.
///
/// FIXME(review): `numConsumerToFuse` is one unsigned counter shared by all
/// payload ops and post-decremented in the `while` condition. When the loop
/// for the first target exits, the counter is tested at 0 and then wraps to
/// UINT32_MAX. A second target therefore re-enters the loop and keeps calling
/// tileAndFuseConsumerOfSlice until it fails. Use a per-target count instead,
/// e.g. `for (uint32_t i = 0; i < numConsumerToFuse; ++i)`.
///
/// NOTE(review): the return-type line, the body of the `failed(...)` check,
/// and the closing braces / final return are missing from this listing.
171 template <typename Range
>
173 applyFuseConsumer(RewriterBase
&rewriter
, Operation
*transformOp
,
174 Range
&&payloadOps
, uint32_t numConsumerToFuse
,
175 TransformResults
&transformResults
) {
176 SmallVector
<Operation
*> originalConsumerOps
;
177 SmallVector
<Operation
*> fusedConsumerOps
;
179 for (Operation
*target
: payloadOps
) {
180 rewriter
.setInsertionPoint(target
);
// FIXME(review): the shared counter below is not reset per target and wraps
// to UINT32_MAX after the first target's loop ends.
182 while (numConsumerToFuse
--) {
183 FailureOr
<scf::SCFFuseConsumerOfSliceResult
> fuseConsumerResults
=
184 scf::tileAndFuseConsumerOfSlice(rewriter
, target
);
186 if (failed(fuseConsumerResults
))
189 // Report back the relevant handles to the transform op.
190 originalConsumerOps
.push_back(
191 fuseConsumerResults
->origConsumerOperand
->getOwner());
192 fusedConsumerOps
.push_back(
193 fuseConsumerResults
->tiledAndFusedConsumerOperand
->getOwner());
197 transformResults
.set(transformOp
->getOpResult(0), originalConsumerOps
);
198 transformResults
.set(transformOp
->getOpResult(1), fusedConsumerOps
);
// Runs applyFuseConsumer on every payload op of the `getTarget()` handle,
// using the op's `getNumConsumerToFuse()` count. Any failure is reported as a
// definite (non-silenceable) failure.
202 DiagnosedSilenceableFailure
203 transform::TestFuseConsumerOp::apply(TransformRewriter
&rewriter
,
204 TransformResults
&transformResults
,
205 TransformState
&state
) {
206 LogicalResult result
= applyFuseConsumer(
207 rewriter
, getOperation(), state
.getPayloadOps(getTarget()),
208 getNumConsumerToFuse(), transformResults
);
209 return failed(result
) ? DiagnosedSilenceableFailure::definiteFailure()
210 : DiagnosedSilenceableFailure::success();
// Memory effects: the target handle is consumed, every result handle is
// produced, and the payload IR is modified.
213 void transform::TestFuseConsumerOp::getEffects(
214 SmallVectorImpl
<MemoryEffects::EffectInstance
> &effects
) {
215 consumesHandle(getTargetMutable(), effects
);
216 producesHandle(getOperation()->getOpResults(), effects
);
217 modifiesPayload(effects
);
220 //===----------------------------------------------------------------------===//
221 // TestTileUsingForallOp
222 //===----------------------------------------------------------------------===//
224 /// Apply a tiling transformation to all payload ops and store both the
225 /// tiled operation as well as the created tile loops.
///
/// Each payload op must implement TilingInterface; otherwise an error is
/// emitted on `transformOp`. Each op is tiled with scf::tileUsingSCF using the
/// ForallOp loop type, then replaced by `tiledResults->mergeResult.replacements`.
/// Result 0 receives the first tiled op of every target. Result `k + 1`
/// receives the `k`-th loop in a single list collected across all targets.
///
/// NOTE(review): loops from every target are appended to one flat list. With
/// more than one payload op, `getOpResult(index + 1)` may therefore exceed the
/// number of results the op declares; confirm that the op only ever gets a
/// single target.
///
/// NOTE(review): this listing does not show the following:
///   - the return-type line;
///   - the guard around `mapping.value()` (original line 241, presumably
///     `if (mapping)`);
///   - the body of the `failed(...)` check;
///   - the closing braces / final return.
226 template <typename Range
>
228 applyTileToAll(RewriterBase
&rewriter
, Operation
*transformOp
,
229 Range
&&payloadOps
, ArrayRef
<OpFoldResult
> tileSizes
,
230 ArrayRef
<int64_t> interchange
, std::optional
<ArrayAttr
> mapping
,
231 TransformResults
&transformResults
) {
232 SmallVector
<Operation
*> tiledOps
;
233 SmallVector
<Operation
*> loopOps
;
235 for (Operation
*target
: payloadOps
) {
236 auto tilingInterfaceOp
= dyn_cast
<TilingInterface
>(target
);
237 if (!tilingInterfaceOp
)
238 return transformOp
->emitError("only TilingInterface ops are supported");
239 scf::SCFTilingOptions tilingOptions
;
240 tilingOptions
.setTileSizes(tileSizes
).setInterchange(interchange
);
242 tilingOptions
.setMapping(mapping
.value().getValue());
244 tilingOptions
.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp
);
246 rewriter
.setInsertionPoint(target
);
247 FailureOr
<scf::SCFTilingResult
> tiledResults
=
248 scf::tileUsingSCF(rewriter
, tilingInterfaceOp
, tilingOptions
);
249 if (failed(tiledResults
))
252 // Perform the replacement of tiled and fused values.
253 rewriter
.replaceOp(tilingInterfaceOp
,
254 tiledResults
->mergeResult
.replacements
);
256 // Report back the relevant handles to the transform op.
257 tiledOps
.push_back(tiledResults
->tiledOps
.front());
258 for (Operation
*loop
: tiledResults
->loops
)
259 loopOps
.push_back(loop
);
262 transformResults
.set(transformOp
->getOpResult(0), tiledOps
);
263 for (auto [index
, loop
] : llvm::enumerate(loopOps
))
264 transformResults
.set(transformOp
->getOpResult(index
+ 1), {loop
});
// Runs applyTileToAll on every payload op of the `getTarget()` handle,
// forwarding the tile sizes (as index OpFoldResults), the interchange, and the
// optional mapping. Any failure is reported as a definite (non-silenceable)
// failure.
269 DiagnosedSilenceableFailure
270 transform::TestTileUsingForallOp::apply(TransformRewriter
&rewriter
,
271 TransformResults
&transformResults
,
272 TransformState
&state
) {
273 SmallVector
<int64_t> tileSizes
=
274 extractFromIntegerArrayAttr
<int64_t>(getTileSizes());
275 SmallVector
<int64_t> interchange
=
276 extractFromIntegerArrayAttr
<int64_t>(getInterchange());
277 SmallVector
<OpFoldResult
> tileSizesOfr
=
278 getAsIndexOpFoldResult(rewriter
.getContext(), tileSizes
);
280 LogicalResult result
=
281 applyTileToAll(rewriter
, getOperation(), state
.getPayloadOps(getTarget()),
282 tileSizesOfr
, interchange
, getMapping(), transformResults
);
283 return failed(result
) ? DiagnosedSilenceableFailure::definiteFailure()
284 : DiagnosedSilenceableFailure::success();
// Memory effects: the target handle is consumed, every result handle is
// produced, and the payload IR is modified.
287 void transform::TestTileUsingForallOp::getEffects(
288 SmallVectorImpl
<MemoryEffects::EffectInstance
> &effects
) {
289 consumesHandle(getTargetMutable(), effects
);
290 producesHandle(getOperation()->getOpResults(), effects
);
291 modifiesPayload(effects
);
294 //===----------------------------------------------------------------------===//
295 // TestFuseUsingForallOp
296 //===----------------------------------------------------------------------===//
298 /// Apply a tiling transformation to all payload ops and store both the
299 /// tiled operation as well as the created tile loops.
///
/// Despite the wording above, the per-target transformation is whatever
/// `applyFn` does. The caller visible in this file passes tile-and-fuse
/// (tileConsumerAndFuseProducersUsingSCF). The result must contain exactly one
/// scf.forall of rank `numLoops`; this is asserted.
///
/// Each result of the target or of a fused producer that has an entry in
/// `replacements` is fully replaced. Ops left without uses are erased.
/// Result 0 receives the first tiled-and-fused op of every target. Result 1
/// receives the stored loop.
///
/// NOTE(review): `loopOps` is constructed with exactly one entry, so the
/// `!loopOps.empty()` check is always true. In addition, `loopOps[0]` is
/// overwritten for each target, so only the last target's loop reaches
/// result 1.
///
/// NOTE(review): this listing does not show the following:
///   - the `applyFn` parameter declaration (original line 305);
///   - the body of the `failed(...)` check;
///   - the closing braces / final return.
300 template <typename Range
>
301 static LogicalResult
applyTilingToAll(
302 RewriterBase
&rewriter
, Operation
*transformOp
, Range
&&payloadOps
,
303 unsigned numLoops
, TransformResults
&transformResults
,
304 function_ref
<FailureOr
<scf::SCFTileAndFuseResult
>(TilingInterface
)>
306 SmallVector
<Operation
*> tiledLinalgOps
;
307 SmallVector
<SmallVector
<Operation
*>> loopOps(1);
309 for (Operation
*target
: payloadOps
) {
310 auto tilingInterfaceOp
= dyn_cast
<TilingInterface
>(target
);
311 if (!tilingInterfaceOp
)
312 return transformOp
->emitError("only TilingInterface ops are supported");
314 rewriter
.setInsertionPoint(target
);
315 FailureOr
<scf::SCFTileAndFuseResult
> tiledResults
=
316 applyFn(tilingInterfaceOp
);
317 if (failed(tiledResults
))
320 // Perform the replacement of tiled and fused values.
321 SmallVector
<Operation
*> opsToReplace
{target
};
322 llvm::append_range(opsToReplace
, tiledResults
->fusedProducers
);
323 for (Operation
*toReplace
: opsToReplace
) {
324 for (OpResult res
: toReplace
->getResults())
325 if (auto replacement
= tiledResults
->replacements
.lookup(res
))
326 rewriter
.replaceAllUsesWith(res
, replacement
);
327 if (toReplace
->use_empty())
328 rewriter
.eraseOp(toReplace
);
331 // Report back the relevant handles to the transform op.
332 tiledLinalgOps
.push_back(tiledResults
->tiledAndFusedOps
.front());
333 assert(tiledResults
->loops
.size() == 1 &&
334 cast
<scf::ForallOp
>(tiledResults
->loops
[0]).getRank() == numLoops
&&
335 "Mismatched number of loops, tile and fuse transform should have "
337 loopOps
[0] = {tiledResults
->loops
[0]};
340 transformResults
.set(transformOp
->getOpResult(0), tiledLinalgOps
);
341 if (!loopOps
.empty())
342 transformResults
.set(transformOp
->getOpResult(1), loopOps
[0]);
// Tiles each payload op of the `getRootOp()` handle using the ForallOp loop
// type, then fuses its producers via tileConsumerAndFuseProducersUsingSCF
// (driven by applyTilingToAll). The expected forall rank is the number of
// non-zero tile sizes. Any failure is reported as a definite (non-silenceable)
// failure.
// NOTE(review): the trailing arguments of the fusion call and the closing of
// the lambda (original lines 370-371) are missing from this listing.
347 DiagnosedSilenceableFailure
348 transform::TestFuseUsingForallOp::apply(TransformRewriter
&rewriter
,
349 TransformResults
&transformResults
,
350 TransformState
&state
) {
351 SmallVector
<int64_t> tileSizes
=
352 extractFromIntegerArrayAttr
<int64_t>(getTileSizes());
353 SmallVector
<int64_t> tileInterchange
=
354 extractFromIntegerArrayAttr
<int64_t>(getInterchange());
356 scf::SCFTilingOptions tilingOptions
;
357 tilingOptions
.interchangeVector
= tileInterchange
;
358 SmallVector
<OpFoldResult
> tileSizesOfr
=
359 getAsIndexOpFoldResult(rewriter
.getContext(), tileSizes
);
360 tilingOptions
= tilingOptions
.setTileSizes(tileSizesOfr
);
361 tilingOptions
.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp
);
362 scf::SCFTileAndFuseOptions tileAndFuseOptions
;
363 tileAndFuseOptions
.tilingOptions
= tilingOptions
;
364 LogicalResult result
= applyTilingToAll(
365 rewriter
, getOperation(), state
.getPayloadOps(getRootOp()),
366 tileSizes
.size() - llvm::count(tileSizes
, 0), transformResults
,
367 [&](TilingInterface tilingInterfaceOp
)
368 -> FailureOr
<scf::SCFTileAndFuseResult
> {
369 return tileConsumerAndFuseProducersUsingSCF(rewriter
, tilingInterfaceOp
,
372 return failed(result
) ? DiagnosedSilenceableFailure::definiteFailure()
373 : DiagnosedSilenceableFailure::success();
// Memory effects: the root-op handle is consumed, every result handle is
// produced, and the payload IR is modified.
376 void transform::TestFuseUsingForallOp::getEffects(
377 SmallVectorImpl
<MemoryEffects::EffectInstance
> &effects
) {
378 consumesHandle(getRootOpMutable(), effects
);
379 producesHandle(getOperation()->getOpResults(), effects
);
380 modifiesPayload(effects
);
383 #define GET_OP_CLASSES
384 #include "TestTilingInterfaceTransformOps.cpp.inc"
// Transform dialect extension for the test TilingInterface transform ops. It
// declares the affine, index, scf and tensor dialects as dependent dialects
// and registers the ops listed in the generated
// TestTilingInterfaceTransformOps.cpp.inc.
// NOTE(review): this listing omits several lines of the class, including the
// lines between the type-ID macro and the dependent-dialect declarations, the
// lines around the generated op list, and the closing braces. Confirm
// against the upstream source.
387 class TestTilingInterfaceDialectExtension
388 : public transform::TransformDialectExtension
<
389 TestTilingInterfaceDialectExtension
> {
391 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
392 TestTilingInterfaceDialectExtension
)
397 declareDependentDialect
<affine::AffineDialect
>();
398 declareDependentDialect
<index::IndexDialect
>();
399 declareDependentDialect
<scf::SCFDialect
>();
400 declareDependentDialect
<tensor::TensorDialect
>();
402 registerTransformOps
<
404 #include "TestTilingInterfaceTransformOps.cpp.inc"
// Adds TestTilingInterfaceDialectExtension to `registry`.
// NOTE(review): the parameter below reads `®istry`. This is a mis-encoded
// `&registry` (an HTML `&reg;` entity was rendered as a character). The
// declaration must be `DialectRegistry &registry` to match the `registry` use
// in the body.
411 void registerTestTilingInterfaceTransformDialectExtension(
412 DialectRegistry
®istry
) {
413 registry
.addExtensions
<TestTilingInterfaceDialectExtension
>();