Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Interfaces / TilingInterface / TestTilingInterfaceTransformOps.cpp
blob5e903e378daf8266915a3123cf1406c4ba7aab0f
1 //===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines transform dialect operations used for testing
10 // TilingInterface
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Index/IR/IndexDialect.h"
16 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
17 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
18 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
19 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20 #include "mlir/Dialect/Utils/StaticValueUtils.h"
21 #include "mlir/IR/Dominance.h"
22 #include "mlir/IR/OpImplementation.h"
23 #include "mlir/Interfaces/TilingInterface.h"
25 #define GET_OP_CLASSES
26 #include "TestTilingInterfaceTransformOps.h.inc"
28 using namespace mlir;
29 using namespace mlir::transform;
31 //===----------------------------------------------------------------------===//
32 // TestFuseAndYieldOp
33 //===----------------------------------------------------------------------===//
35 static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
36 SmallVector<Operation *> worklist;
37 llvm::SmallDenseSet<Operation *> producers;
38 worklist.push_back(op);
39 producers.insert(op);
40 while (!worklist.empty()) {
41 Operation *current = worklist.pop_back_val();
42 for (OpOperand &operand : current->getOpOperands()) {
43 Operation *producer = operand.get().getDefiningOp();
44 if (!producer || !isa<TilingInterface>(producer) ||
45 producers.contains(producer))
46 continue;
47 worklist.push_back(producer);
48 producers.insert(producer);
51 return producers;
54 /// Apply a tile and fuse transformation to all payload ops and store both the
55 /// tiled operation as well as the created tile loops.
56 template <typename Range>
57 static LogicalResult
58 applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
59 Range &&payloadOps, unsigned numLoops,
60 ArrayRef<OpFoldResult> tileSizes,
61 ArrayRef<int64_t> interchange, bool useForall,
62 TransformResults &transformResults) {
63 SmallVector<Operation *> tiledOps;
64 SmallVector<SmallVector<Operation *>> loopOps(numLoops);
66 for (Operation *target : payloadOps) {
67 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
68 if (!tilingInterfaceOp)
69 return transformOp->emitError("only TilingInterface ops are supported");
70 DominanceInfo dominanceInfo(tilingInterfaceOp);
72 llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
73 collectTiledAndFusedOps(tilingInterfaceOp);
74 llvm::DenseSet<Operation *> yieldReplacementsFor;
75 for (auto op : tiledAndFusedOps) {
76 if (llvm::any_of(op->getUsers(), [&](Operation *user) {
77 return dominanceInfo.properlyDominates(tilingInterfaceOp, user);
78 })) {
79 yieldReplacementsFor.insert(op);
83 scf::SCFTilingOptions tilingOptions;
84 tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
85 if (useForall) {
86 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
89 scf::SCFTileAndFuseOptions tileAndFuseOptions;
90 tileAndFuseOptions.setTilingOptions(tilingOptions);
92 scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
93 [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
94 bool isDestinationOperand)
95 -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
96 Operation *owner = originalProducer.getOwner();
97 bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
98 return scf::SCFTileAndFuseOptions::ControlFnResult{
99 yieldProducerReplacement};
101 tileAndFuseOptions.setFusionControlFn(controlFn);
103 rewriter.setInsertionPoint(target);
104 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
105 scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
106 tileAndFuseOptions);
107 if (failed(tiledResults))
108 return failure();
110 // Perform the replacement of tiled and fused values.
111 SmallVector<Operation *> opsToReplace{target};
112 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
113 for (Operation *toReplace : opsToReplace) {
114 for (OpResult res : toReplace->getResults())
115 if (auto replacement = tiledResults->replacements.lookup(res)) {
116 Operation *replacementOp = replacement.getDefiningOp();
117 rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
118 Operation *user = use.getOwner();
119 return dominanceInfo.properlyDominates(replacementOp, user) &&
120 user->getParentOp() == replacementOp->getParentOp();
124 if (toReplace->use_empty()) {
125 rewriter.eraseOp(toReplace);
129 // Report back the relevant handles to the transform op.
130 tiledOps.push_back(tiledResults->tiledAndFusedOps.front());
131 assert(tiledResults->loops.size() == numLoops &&
132 "Mismatched number of loops, tile and fuse transform should have "
133 "failed");
134 for (unsigned int i = 0; i < numLoops; ++i)
135 loopOps[i].push_back(tiledResults->loops[i]);
138 transformResults.set(transformOp->getOpResult(0), tiledOps);
139 for (unsigned int i = 0; i < numLoops; ++i)
140 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
142 return success();
145 DiagnosedSilenceableFailure
146 transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
147 TransformResults &transformResults,
148 TransformState &state) {
149 SmallVector<int64_t> tileSizes =
150 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
151 SmallVector<int64_t> tileInterchange =
152 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
154 SmallVector<OpFoldResult> tileSizesOfr =
155 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
157 LogicalResult result = applyTileAndFuseToAll(
158 rewriter, getOperation(), state.getPayloadOps(getTarget()),
159 tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
160 tileInterchange, getUseForall(), transformResults);
161 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
162 : DiagnosedSilenceableFailure::success();
165 //===----------------------------------------------------------------------===//
166 // TestFuseConsumerOp
167 //===----------------------------------------------------------------------===//
169 /// Apply fusing of consumer transformation to all payload ops and store both
170 /// the original consumer operation as well as the fused consumer operation.
171 template <typename Range>
172 static LogicalResult
173 applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
174 Range &&payloadOps, uint32_t numConsumerToFuse,
175 TransformResults &transformResults) {
176 SmallVector<Operation *> originalConsumerOps;
177 SmallVector<Operation *> fusedConsumerOps;
179 for (Operation *target : payloadOps) {
180 rewriter.setInsertionPoint(target);
182 while (numConsumerToFuse--) {
183 FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
184 scf::tileAndFuseConsumerOfSlice(rewriter, target);
186 if (failed(fuseConsumerResults))
187 return failure();
189 // Report back the relevant handles to the transform op.
190 originalConsumerOps.push_back(
191 fuseConsumerResults->origConsumerOperand->getOwner());
192 fusedConsumerOps.push_back(
193 fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
197 transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
198 transformResults.set(transformOp->getOpResult(1), fusedConsumerOps);
199 return success();
202 DiagnosedSilenceableFailure
203 transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
204 TransformResults &transformResults,
205 TransformState &state) {
206 LogicalResult result = applyFuseConsumer(
207 rewriter, getOperation(), state.getPayloadOps(getTarget()),
208 getNumConsumerToFuse(), transformResults);
209 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
210 : DiagnosedSilenceableFailure::success();
213 void transform::TestFuseConsumerOp::getEffects(
214 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
215 consumesHandle(getTargetMutable(), effects);
216 producesHandle(getOperation()->getOpResults(), effects);
217 modifiesPayload(effects);
220 //===----------------------------------------------------------------------===//
221 // TestTileUsingForallOp
222 //===----------------------------------------------------------------------===//
224 /// Apply a tiling transformation to all payload ops and store both the
225 /// tiled operation as well as the created tile loops.
226 template <typename Range>
227 static LogicalResult
228 applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
229 Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes,
230 ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping,
231 TransformResults &transformResults) {
232 SmallVector<Operation *> tiledOps;
233 SmallVector<Operation *> loopOps;
235 for (Operation *target : payloadOps) {
236 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
237 if (!tilingInterfaceOp)
238 return transformOp->emitError("only TilingInterface ops are supported");
239 scf::SCFTilingOptions tilingOptions;
240 tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
241 if (mapping) {
242 tilingOptions.setMapping(mapping.value().getValue());
244 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
246 rewriter.setInsertionPoint(target);
247 FailureOr<scf::SCFTilingResult> tiledResults =
248 scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions);
249 if (failed(tiledResults))
250 return failure();
252 // Perform the replacement of tiled and fused values.
253 rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements);
255 // Report back the relevant handles to the transform op.
256 tiledOps.push_back(tiledResults->tiledOps.front());
257 for (Operation *loop : tiledResults->loops)
258 loopOps.push_back(loop);
261 transformResults.set(transformOp->getOpResult(0), tiledOps);
262 for (auto [index, loop] : llvm::enumerate(loopOps))
263 transformResults.set(transformOp->getOpResult(index + 1), {loop});
265 return success();
268 DiagnosedSilenceableFailure
269 transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter,
270 TransformResults &transformResults,
271 TransformState &state) {
272 SmallVector<int64_t> tileSizes =
273 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
274 SmallVector<int64_t> interchange =
275 extractFromIntegerArrayAttr<int64_t>(getInterchange());
276 SmallVector<OpFoldResult> tileSizesOfr =
277 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
279 LogicalResult result =
280 applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()),
281 tileSizesOfr, interchange, getMapping(), transformResults);
282 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
283 : DiagnosedSilenceableFailure::success();
286 void transform::TestTileUsingForallOp::getEffects(
287 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
288 consumesHandle(getTargetMutable(), effects);
289 producesHandle(getOperation()->getOpResults(), effects);
290 modifiesPayload(effects);
293 //===----------------------------------------------------------------------===//
294 // TestFuseUsingForallOp
295 //===----------------------------------------------------------------------===//
297 /// Apply a tiling transformation to all payload ops and store both the
298 /// tiled operation as well as the created tile loops.
299 template <typename Range>
300 static LogicalResult applyTilingToAll(
301 RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
302 unsigned numLoops, TransformResults &transformResults,
303 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
304 applyFn) {
305 SmallVector<Operation *> tiledLinalgOps;
306 SmallVector<SmallVector<Operation *>> loopOps(1);
308 for (Operation *target : payloadOps) {
309 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
310 if (!tilingInterfaceOp)
311 return transformOp->emitError("only TilingInterface ops are supported");
313 rewriter.setInsertionPoint(target);
314 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
315 applyFn(tilingInterfaceOp);
316 if (failed(tiledResults))
317 return failure();
319 // Perform the replacement of tiled and fused values.
320 SmallVector<Operation *> opsToReplace{target};
321 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
322 for (Operation *toReplace : opsToReplace) {
323 for (OpResult res : toReplace->getResults())
324 if (auto replacement = tiledResults->replacements.lookup(res))
325 rewriter.replaceAllUsesWith(res, replacement);
326 if (toReplace->use_empty())
327 rewriter.eraseOp(toReplace);
330 // Report back the relevant handles to the transform op.
331 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
332 assert(tiledResults->loops.size() == 1 &&
333 cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops &&
334 "Mismatched number of loops, tile and fuse transform should have "
335 "failed");
336 loopOps[0] = {tiledResults->loops[0]};
339 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
340 if (!loopOps.empty())
341 transformResults.set(transformOp->getOpResult(1), loopOps[0]);
343 return success();
346 DiagnosedSilenceableFailure
347 transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter,
348 TransformResults &transformResults,
349 TransformState &state) {
350 SmallVector<int64_t> tileSizes =
351 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
352 SmallVector<int64_t> tileInterchange =
353 extractFromIntegerArrayAttr<int64_t>(getInterchange());
355 scf::SCFTilingOptions tilingOptions;
356 tilingOptions.interchangeVector = tileInterchange;
357 SmallVector<OpFoldResult> tileSizesOfr =
358 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
359 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
360 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
361 scf::SCFTileAndFuseOptions tileAndFuseOptions;
362 tileAndFuseOptions.tilingOptions = tilingOptions;
363 LogicalResult result = applyTilingToAll(
364 rewriter, getOperation(), state.getPayloadOps(getRootOp()),
365 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
366 [&](TilingInterface tilingInterfaceOp)
367 -> FailureOr<scf::SCFTileAndFuseResult> {
368 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
369 tileAndFuseOptions);
371 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
372 : DiagnosedSilenceableFailure::success();
375 void transform::TestFuseUsingForallOp::getEffects(
376 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
377 consumesHandle(getRootOpMutable(), effects);
378 producesHandle(getOperation()->getOpResults(), effects);
379 modifiesPayload(effects);
382 #define GET_OP_CLASSES
383 #include "TestTilingInterfaceTransformOps.cpp.inc"
385 namespace {
386 class TestTilingInterfaceDialectExtension
387 : public transform::TransformDialectExtension<
388 TestTilingInterfaceDialectExtension> {
389 public:
390 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
391 TestTilingInterfaceDialectExtension)
393 using Base::Base;
395 void init() {
396 declareDependentDialect<affine::AffineDialect>();
397 declareDependentDialect<index::IndexDialect>();
398 declareDependentDialect<scf::SCFDialect>();
399 declareDependentDialect<tensor::TensorDialect>();
401 registerTransformOps<
402 #define GET_OP_LIST
403 #include "TestTilingInterfaceTransformOps.cpp.inc"
404 >();
407 } // namespace
409 namespace test {
410 void registerTestTilingInterfaceTransformDialectExtension(
411 DialectRegistry &registry) {
412 registry.addExtensions<TestTilingInterfaceDialectExtension>();
414 } // namespace test