[hwasan] Omit tag check for null pointers (#122206)
[llvm-project.git] / mlir / test / lib / Interfaces / TilingInterface / TestTilingInterfaceTransformOps.cpp
blob 7380b766935ffea799fa2a38741c4c3e6c79a795
1 //===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines transform dialect operations used for testing
10 // TilingInterface
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Index/IR/IndexDialect.h"
16 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
17 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
18 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
19 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20 #include "mlir/Dialect/Utils/StaticValueUtils.h"
21 #include "mlir/IR/Dominance.h"
22 #include "mlir/IR/OpImplementation.h"
23 #include "mlir/Interfaces/TilingInterface.h"
25 #define GET_OP_CLASSES
26 #include "TestTilingInterfaceTransformOps.h.inc"
28 using namespace mlir;
29 using namespace mlir::transform;
31 //===----------------------------------------------------------------------===//
32 // TestFuseAndYieldOp
33 //===----------------------------------------------------------------------===//
35 static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
36 SmallVector<Operation *> worklist;
37 llvm::SmallDenseSet<Operation *> producers;
38 worklist.push_back(op);
39 producers.insert(op);
40 while (!worklist.empty()) {
41 Operation *current = worklist.pop_back_val();
42 for (OpOperand &operand : current->getOpOperands()) {
43 Operation *producer = operand.get().getDefiningOp();
44 if (!producer || !isa<TilingInterface>(producer) ||
45 producers.contains(producer))
46 continue;
47 worklist.push_back(producer);
48 producers.insert(producer);
51 return producers;
54 /// Apply a tile and fuse transformation to all payload ops and store both the
55 /// tiled operation as well as the created tile loops.
56 template <typename Range>
57 static LogicalResult
58 applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
59 Range &&payloadOps, unsigned numLoops,
60 ArrayRef<OpFoldResult> tileSizes,
61 ArrayRef<int64_t> interchange, bool useForall,
62 TransformResults &transformResults) {
63 SmallVector<Operation *> tiledOps;
64 SmallVector<SmallVector<Operation *>> loopOps(numLoops);
66 for (Operation *target : payloadOps) {
67 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
68 if (!tilingInterfaceOp)
69 return transformOp->emitError("only TilingInterface ops are supported");
70 DominanceInfo dominanceInfo(tilingInterfaceOp);
72 llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
73 collectTiledAndFusedOps(tilingInterfaceOp);
74 llvm::DenseSet<Operation *> yieldReplacementsFor;
75 for (auto op : tiledAndFusedOps) {
76 if (llvm::any_of(op->getUsers(), [&](Operation *user) {
77 return dominanceInfo.properlyDominates(tilingInterfaceOp, user);
78 })) {
79 yieldReplacementsFor.insert(op);
83 scf::SCFTilingOptions tilingOptions;
84 tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
85 if (useForall) {
86 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
89 scf::SCFTileAndFuseOptions tileAndFuseOptions;
90 tileAndFuseOptions.setTilingOptions(tilingOptions);
92 scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
93 [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
94 bool isDestinationOperand)
95 -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
96 Operation *owner = originalProducer.getOwner();
97 bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
98 return scf::SCFTileAndFuseOptions::ControlFnResult{
99 yieldProducerReplacement};
101 tileAndFuseOptions.setFusionControlFn(controlFn);
103 rewriter.setInsertionPoint(target);
104 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
105 scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
106 tileAndFuseOptions);
107 if (failed(tiledResults))
108 return failure();
110 // Perform the replacement of tiled and fused values.
111 SmallVector<Operation *> opsToReplace{target};
112 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
113 for (Operation *toReplace : opsToReplace) {
114 for (OpResult res : toReplace->getResults())
115 if (auto replacement = tiledResults->replacements.lookup(res)) {
116 Operation *replacementOp = replacement.getDefiningOp();
117 rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
118 Operation *user = use.getOwner();
119 return dominanceInfo.properlyDominates(replacementOp, user) &&
120 user->getParentOp() == replacementOp->getParentOp();
124 if (toReplace->use_empty()) {
125 rewriter.eraseOp(toReplace);
129 // Report back the relevant handles to the transform op.
130 tiledOps.push_back(tiledResults->tiledAndFusedOps.front());
131 assert(tiledResults->loops.size() == numLoops &&
132 "Mismatched number of loops, tile and fuse transform should have "
133 "failed");
134 for (unsigned int i = 0; i < numLoops; ++i)
135 loopOps[i].push_back(tiledResults->loops[i]);
138 transformResults.set(transformOp->getOpResult(0), tiledOps);
139 for (unsigned int i = 0; i < numLoops; ++i)
140 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
142 return success();
145 DiagnosedSilenceableFailure
146 transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
147 TransformResults &transformResults,
148 TransformState &state) {
149 SmallVector<int64_t> tileSizes =
150 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
151 SmallVector<int64_t> tileInterchange =
152 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
154 SmallVector<OpFoldResult> tileSizesOfr =
155 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
157 LogicalResult result = applyTileAndFuseToAll(
158 rewriter, getOperation(), state.getPayloadOps(getTarget()),
159 tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
160 tileInterchange, getUseForall(), transformResults);
161 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
162 : DiagnosedSilenceableFailure::success();
165 //===----------------------------------------------------------------------===//
166 // TestFuseConsumerOp
167 //===----------------------------------------------------------------------===//
169 /// Apply fusing of consumer transformation to all payload ops and store both
170 /// the original consumer operation as well as the fused consumer operation.
171 template <typename Range>
172 static LogicalResult
173 applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
174 Range &&payloadOps, uint32_t numConsumerToFuse,
175 TransformResults &transformResults) {
176 SmallVector<Operation *> originalConsumerOps;
177 SmallVector<Operation *> fusedConsumerOps;
179 for (Operation *target : payloadOps) {
180 rewriter.setInsertionPoint(target);
182 while (numConsumerToFuse--) {
183 FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
184 scf::tileAndFuseConsumerOfSlice(rewriter, target);
186 if (failed(fuseConsumerResults))
187 return failure();
189 // Report back the relevant handles to the transform op.
190 originalConsumerOps.push_back(
191 fuseConsumerResults->origConsumerOperand->getOwner());
192 fusedConsumerOps.push_back(
193 fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
197 transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
198 transformResults.set(transformOp->getOpResult(1), fusedConsumerOps);
199 return success();
202 DiagnosedSilenceableFailure
203 transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
204 TransformResults &transformResults,
205 TransformState &state) {
206 LogicalResult result = applyFuseConsumer(
207 rewriter, getOperation(), state.getPayloadOps(getTarget()),
208 getNumConsumerToFuse(), transformResults);
209 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
210 : DiagnosedSilenceableFailure::success();
213 void transform::TestFuseConsumerOp::getEffects(
214 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
215 consumesHandle(getTargetMutable(), effects);
216 producesHandle(getOperation()->getOpResults(), effects);
217 modifiesPayload(effects);
220 //===----------------------------------------------------------------------===//
221 // TestTileUsingForallOp
222 //===----------------------------------------------------------------------===//
224 /// Apply a tiling transformation to all payload ops and store both the
225 /// tiled operation as well as the created tile loops.
226 template <typename Range>
227 static LogicalResult
228 applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
229 Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes,
230 ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping,
231 TransformResults &transformResults) {
232 SmallVector<Operation *> tiledOps;
233 SmallVector<Operation *> loopOps;
235 for (Operation *target : payloadOps) {
236 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
237 if (!tilingInterfaceOp)
238 return transformOp->emitError("only TilingInterface ops are supported");
239 scf::SCFTilingOptions tilingOptions;
240 tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
241 if (mapping) {
242 tilingOptions.setMapping(mapping.value().getValue());
244 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
246 rewriter.setInsertionPoint(target);
247 FailureOr<scf::SCFTilingResult> tiledResults =
248 scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions);
249 if (failed(tiledResults))
250 return failure();
252 // Perform the replacement of tiled and fused values.
253 rewriter.replaceOp(tilingInterfaceOp,
254 tiledResults->mergeResult.replacements);
256 // Report back the relevant handles to the transform op.
257 tiledOps.push_back(tiledResults->tiledOps.front());
258 for (Operation *loop : tiledResults->loops)
259 loopOps.push_back(loop);
262 transformResults.set(transformOp->getOpResult(0), tiledOps);
263 for (auto [index, loop] : llvm::enumerate(loopOps))
264 transformResults.set(transformOp->getOpResult(index + 1), {loop});
266 return success();
269 DiagnosedSilenceableFailure
270 transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter,
271 TransformResults &transformResults,
272 TransformState &state) {
273 SmallVector<int64_t> tileSizes =
274 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
275 SmallVector<int64_t> interchange =
276 extractFromIntegerArrayAttr<int64_t>(getInterchange());
277 SmallVector<OpFoldResult> tileSizesOfr =
278 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
280 LogicalResult result =
281 applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()),
282 tileSizesOfr, interchange, getMapping(), transformResults);
283 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
284 : DiagnosedSilenceableFailure::success();
287 void transform::TestTileUsingForallOp::getEffects(
288 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
289 consumesHandle(getTargetMutable(), effects);
290 producesHandle(getOperation()->getOpResults(), effects);
291 modifiesPayload(effects);
294 //===----------------------------------------------------------------------===//
295 // TestFuseUsingForallOp
296 //===----------------------------------------------------------------------===//
298 /// Apply a tiling transformation to all payload ops and store both the
299 /// tiled operation as well as the created tile loops.
300 template <typename Range>
301 static LogicalResult applyTilingToAll(
302 RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
303 unsigned numLoops, TransformResults &transformResults,
304 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
305 applyFn) {
306 SmallVector<Operation *> tiledLinalgOps;
307 SmallVector<SmallVector<Operation *>> loopOps(1);
309 for (Operation *target : payloadOps) {
310 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
311 if (!tilingInterfaceOp)
312 return transformOp->emitError("only TilingInterface ops are supported");
314 rewriter.setInsertionPoint(target);
315 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
316 applyFn(tilingInterfaceOp);
317 if (failed(tiledResults))
318 return failure();
320 // Perform the replacement of tiled and fused values.
321 SmallVector<Operation *> opsToReplace{target};
322 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
323 for (Operation *toReplace : opsToReplace) {
324 for (OpResult res : toReplace->getResults())
325 if (auto replacement = tiledResults->replacements.lookup(res))
326 rewriter.replaceAllUsesWith(res, replacement);
327 if (toReplace->use_empty())
328 rewriter.eraseOp(toReplace);
331 // Report back the relevant handles to the transform op.
332 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
333 assert(tiledResults->loops.size() == 1 &&
334 cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops &&
335 "Mismatched number of loops, tile and fuse transform should have "
336 "failed");
337 loopOps[0] = {tiledResults->loops[0]};
340 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
341 if (!loopOps.empty())
342 transformResults.set(transformOp->getOpResult(1), loopOps[0]);
344 return success();
347 DiagnosedSilenceableFailure
348 transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter,
349 TransformResults &transformResults,
350 TransformState &state) {
351 SmallVector<int64_t> tileSizes =
352 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
353 SmallVector<int64_t> tileInterchange =
354 extractFromIntegerArrayAttr<int64_t>(getInterchange());
356 scf::SCFTilingOptions tilingOptions;
357 tilingOptions.interchangeVector = tileInterchange;
358 SmallVector<OpFoldResult> tileSizesOfr =
359 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
360 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
361 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
362 scf::SCFTileAndFuseOptions tileAndFuseOptions;
363 tileAndFuseOptions.tilingOptions = tilingOptions;
364 LogicalResult result = applyTilingToAll(
365 rewriter, getOperation(), state.getPayloadOps(getRootOp()),
366 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
367 [&](TilingInterface tilingInterfaceOp)
368 -> FailureOr<scf::SCFTileAndFuseResult> {
369 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
370 tileAndFuseOptions);
372 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
373 : DiagnosedSilenceableFailure::success();
376 void transform::TestFuseUsingForallOp::getEffects(
377 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
378 consumesHandle(getRootOpMutable(), effects);
379 producesHandle(getOperation()->getOpResults(), effects);
380 modifiesPayload(effects);
383 #define GET_OP_CLASSES
384 #include "TestTilingInterfaceTransformOps.cpp.inc"
386 namespace {
387 class TestTilingInterfaceDialectExtension
388 : public transform::TransformDialectExtension<
389 TestTilingInterfaceDialectExtension> {
390 public:
391 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
392 TestTilingInterfaceDialectExtension)
394 using Base::Base;
396 void init() {
397 declareDependentDialect<affine::AffineDialect>();
398 declareDependentDialect<index::IndexDialect>();
399 declareDependentDialect<scf::SCFDialect>();
400 declareDependentDialect<tensor::TensorDialect>();
402 registerTransformOps<
403 #define GET_OP_LIST
404 #include "TestTilingInterfaceTransformOps.cpp.inc"
405 >();
408 } // namespace
410 namespace test {
411 void registerTestTilingInterfaceTransformDialectExtension(
412 DialectRegistry &registry) {
413 registry.addExtensions<TestTilingInterfaceDialectExtension>();
415 } // namespace test