Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Dialect / Test / TestPatterns.cpp
blob3df6cff3c0a60b394015e463b7abac7f9d4574ec
1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "TestTypes.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinAttributes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Visitors.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include <cstdint>
27 using namespace mlir;
28 using namespace test;
30 // Native function for testing NativeCodeCall
31 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
32 return choice.getValue() ? input1 : input2;
35 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
36 rewriter.create<OpI>(loc, input);
39 static void handleNoResultOp(PatternRewriter &rewriter,
40 OpSymbolBindingNoResult op) {
41 // Turn the no result op to a one-result op.
42 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
43 op.getOperand());
46 static bool getFirstI32Result(Operation *op, Value &value) {
47 if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
48 return false;
49 value = op->getResult(0);
50 return true;
53 static Value bindNativeCodeCallResult(Value value) { return value; }
55 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
56 Value input2) {
57 return SmallVector<Value, 2>({input2, input1});
60 // Test that natives calls are only called once during rewrites.
61 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
62 // This let us check the number of times OpM_Test was called by inspecting
63 // the returned value in the MLIR output.
64 static int64_t opMIncreasingValue = 314159265;
65 static Attribute opMTest(PatternRewriter &rewriter, Value val) {
66 int64_t i = opMIncreasingValue++;
67 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
70 namespace {
71 #include "TestPatterns.inc"
72 } // namespace
74 //===----------------------------------------------------------------------===//
75 // Test Reduce Pattern Interface
76 //===----------------------------------------------------------------------===//
78 void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
79 populateWithGenerated(patterns);
82 //===----------------------------------------------------------------------===//
83 // Canonicalizer Driver.
84 //===----------------------------------------------------------------------===//
86 namespace {
87 struct FoldingPattern : public RewritePattern {
88 public:
89 FoldingPattern(MLIRContext *context)
90 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
91 /*benefit=*/1, context) {}
93 LogicalResult matchAndRewrite(Operation *op,
94 PatternRewriter &rewriter) const override {
95 // Exercise createOrFold API for a single-result operation that is folded
96 // upon construction. The operation being created has an in-place folder,
97 // and it should be still present in the output. Furthermore, the folder
98 // should not crash when attempting to recover the (unchanged) operation
99 // result.
100 Value result = rewriter.createOrFold<TestOpInPlaceFold>(
101 op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0));
102 assert(result);
103 rewriter.replaceOp(op, result);
104 return success();
108 /// This pattern creates a foldable operation at the entry point of the block.
109 /// This tests the situation where the operation folder will need to replace an
110 /// operation with a previously created constant that does not initially
111 /// dominate the operation to replace.
112 struct FolderInsertBeforePreviouslyFoldedConstantPattern
113 : public OpRewritePattern<TestCastOp> {
114 public:
115 using OpRewritePattern<TestCastOp>::OpRewritePattern;
117 LogicalResult matchAndRewrite(TestCastOp op,
118 PatternRewriter &rewriter) const override {
119 if (!op->hasAttr("test_fold_before_previously_folded_op"))
120 return failure();
121 rewriter.setInsertionPointToStart(op->getBlock());
123 auto constOp = rewriter.create<arith::ConstantOp>(
124 op.getLoc(), rewriter.getBoolAttr(true));
125 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
126 Value(constOp));
127 return success();
131 /// This pattern matches test.op_commutative2 with the first operand being
132 /// another test.op_commutative2 with a constant on the right side and fold it
133 /// away by propagating it as its result. This is intend to check that patterns
134 /// are applied after the commutative property moves constant to the right.
135 struct FolderCommutativeOp2WithConstant
136 : public OpRewritePattern<TestCommutative2Op> {
137 public:
138 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
140 LogicalResult matchAndRewrite(TestCommutative2Op op,
141 PatternRewriter &rewriter) const override {
142 auto operand =
143 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
144 if (!operand)
145 return failure();
146 Attribute constInput;
147 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput)))
148 return failure();
149 rewriter.replaceOp(op, operand->getOperand(1));
150 return success();
154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
155 /// attribute with value smaller than MaxVal, it increments the value by 1.
156 template <int MaxVal>
157 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
158 using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
160 LogicalResult matchAndRewrite(AnyAttrOfOp op,
161 PatternRewriter &rewriter) const override {
162 auto intAttr = dyn_cast<IntegerAttr>(op.getAttr());
163 if (!intAttr)
164 return failure();
165 int64_t val = intAttr.getInt();
166 if (val >= MaxVal)
167 return failure();
168 rewriter.modifyOpInPlace(
169 op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
170 return success();
174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
175 struct MakeOpEligible : public RewritePattern {
176 MakeOpEligible(MLIRContext *context)
177 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {}
179 LogicalResult matchAndRewrite(Operation *op,
180 PatternRewriter &rewriter) const override {
181 if (op->hasAttr("eligible"))
182 return failure();
183 rewriter.modifyOpInPlace(
184 op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
185 return success();
189 /// This pattern hoists eligible ops out of a "test.one_region_op".
190 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
191 using OpRewritePattern<test::OneRegionOp>::OpRewritePattern;
193 LogicalResult matchAndRewrite(test::OneRegionOp op,
194 PatternRewriter &rewriter) const override {
195 Operation *terminator = op.getRegion().front().getTerminator();
196 Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp();
197 if (toBeHoisted->getParentOp() != op)
198 return failure();
199 if (!toBeHoisted->hasAttr("eligible"))
200 return failure();
201 rewriter.moveOpBefore(toBeHoisted, op);
202 return success();
206 /// This pattern moves "test.move_before_parent_op" before the parent op.
207 struct MoveBeforeParentOp : public RewritePattern {
208 MoveBeforeParentOp(MLIRContext *context)
209 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
211 LogicalResult matchAndRewrite(Operation *op,
212 PatternRewriter &rewriter) const override {
213 // Do not hoist past functions.
214 if (isa<FunctionOpInterface>(op->getParentOp()))
215 return failure();
216 rewriter.moveOpBefore(op, op->getParentOp());
217 return success();
221 /// This pattern moves "test.move_after_parent_op" after the parent op.
222 struct MoveAfterParentOp : public RewritePattern {
223 MoveAfterParentOp(MLIRContext *context)
224 : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
226 LogicalResult matchAndRewrite(Operation *op,
227 PatternRewriter &rewriter) const override {
228 // Do not hoist past functions.
229 if (isa<FunctionOpInterface>(op->getParentOp()))
230 return failure();
232 int64_t moveForwardBy = 0;
233 if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
234 moveForwardBy = advanceBy.getInt();
236 Operation *moveAfter = op->getParentOp();
237 for (int64_t i = 0; i < moveForwardBy; ++i)
238 moveAfter = moveAfter->getNextNode();
240 rewriter.moveOpAfter(op, moveAfter);
241 return success();
245 /// This pattern inlines blocks that are nested in
246 /// "test.inline_blocks_into_parent" into the parent block.
247 struct InlineBlocksIntoParent : public RewritePattern {
248 InlineBlocksIntoParent(MLIRContext *context)
249 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
250 context) {}
252 LogicalResult matchAndRewrite(Operation *op,
253 PatternRewriter &rewriter) const override {
254 bool changed = false;
255 for (Region &r : op->getRegions()) {
256 while (!r.empty()) {
257 rewriter.inlineBlockBefore(&r.front(), op);
258 changed = true;
261 return success(changed);
265 /// This pattern splits blocks at "test.split_block_here" and replaces the op
266 /// with a new op (to prevent an infinite loop of block splitting).
267 struct SplitBlockHere : public RewritePattern {
268 SplitBlockHere(MLIRContext *context)
269 : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
271 LogicalResult matchAndRewrite(Operation *op,
272 PatternRewriter &rewriter) const override {
273 rewriter.splitBlock(op->getBlock(), op->getIterator());
274 Operation *newOp = rewriter.create(
275 op->getLoc(),
276 OperationName("test.new_op", op->getContext()).getIdentifier(),
277 op->getOperands(), op->getResultTypes());
278 rewriter.replaceOp(op, newOp);
279 return success();
283 /// This pattern clones "test.clone_me" ops.
284 struct CloneOp : public RewritePattern {
285 CloneOp(MLIRContext *context)
286 : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
288 LogicalResult matchAndRewrite(Operation *op,
289 PatternRewriter &rewriter) const override {
290 // Do not clone already cloned ops to avoid going into an infinite loop.
291 if (op->hasAttr("was_cloned"))
292 return failure();
293 Operation *cloned = rewriter.clone(*op);
294 cloned->setAttr("was_cloned", rewriter.getUnitAttr());
295 return success();
299 /// This pattern clones regions of "test.clone_region_before" ops before the
300 /// parent block.
301 struct CloneRegionBeforeOp : public RewritePattern {
302 CloneRegionBeforeOp(MLIRContext *context)
303 : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
305 LogicalResult matchAndRewrite(Operation *op,
306 PatternRewriter &rewriter) const override {
307 // Do not clone already cloned ops to avoid going into an infinite loop.
308 if (op->hasAttr("was_cloned"))
309 return failure();
310 for (Region &r : op->getRegions())
311 rewriter.cloneRegionBefore(r, op->getBlock());
312 op->setAttr("was_cloned", rewriter.getUnitAttr());
313 return success();
317 /// Replace an operation may introduce the re-visiting of its users.
318 class ReplaceWithNewOp : public RewritePattern {
319 public:
320 ReplaceWithNewOp(MLIRContext *context)
321 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
323 LogicalResult matchAndRewrite(Operation *op,
324 PatternRewriter &rewriter) const override {
325 Operation *newOp;
326 if (op->hasAttr("create_erase_op")) {
327 newOp = rewriter.create(
328 op->getLoc(),
329 OperationName("test.erase_op", op->getContext()).getIdentifier(),
330 ValueRange(), TypeRange());
331 } else {
332 newOp = rewriter.create(
333 op->getLoc(),
334 OperationName("test.new_op", op->getContext()).getIdentifier(),
335 op->getOperands(), op->getResultTypes());
337 // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
338 // A "notifyOperationReplaced" callback is triggered in either case.
339 rewriter.replaceAllOpUsesWith(op, newOp->getResults());
340 rewriter.eraseOp(op);
341 return success();
345 /// Erases the first child block of the matched "test.erase_first_block"
346 /// operation.
347 class EraseFirstBlock : public RewritePattern {
348 public:
349 EraseFirstBlock(MLIRContext *context)
350 : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
352 LogicalResult matchAndRewrite(Operation *op,
353 PatternRewriter &rewriter) const override {
354 for (Region &r : op->getRegions()) {
355 for (Block &b : r.getBlocks()) {
356 rewriter.eraseBlock(&b);
357 return success();
361 return failure();
365 struct TestGreedyPatternDriver
366 : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
367 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
369 TestGreedyPatternDriver() = default;
370 TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
371 : PassWrapper(other) {}
373 StringRef getArgument() const final { return "test-greedy-patterns"; }
374 StringRef getDescription() const final { return "Run test dialect patterns"; }
375 void runOnOperation() override {
376 mlir::RewritePatternSet patterns(&getContext());
377 populateWithGenerated(patterns);
379 // Verify named pattern is generated with expected name.
380 patterns.add<FoldingPattern, TestNamedPatternRule,
381 FolderInsertBeforePreviouslyFoldedConstantPattern,
382 FolderCommutativeOp2WithConstant, HoistEligibleOps,
383 MakeOpEligible>(&getContext());
385 // Additional patterns for testing the GreedyPatternRewriteDriver.
386 patterns.insert<IncrementIntAttribute<3>>(&getContext());
388 GreedyRewriteConfig config;
389 config.useTopDownTraversal = this->useTopDownTraversal;
390 config.maxIterations = this->maxIterations;
391 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
392 config);
395 Option<bool> useTopDownTraversal{
396 *this, "top-down",
397 llvm::cl::desc("Seed the worklist in general top-down order"),
398 llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
399 Option<int> maxIterations{
400 *this, "max-iterations",
401 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
402 llvm::cl::init(GreedyRewriteConfig().maxIterations)};
405 struct DumpNotifications : public RewriterBase::Listener {
406 void notifyBlockInserted(Block *block, Region *previous,
407 Region::iterator previousIt) override {
408 llvm::outs() << "notifyBlockInserted";
409 if (block->getParentOp()) {
410 llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
411 } else {
412 llvm::outs() << " into unknown op: ";
414 if (previous == nullptr) {
415 llvm::outs() << "was unlinked\n";
416 } else {
417 llvm::outs() << "was linked\n";
420 void notifyOperationInserted(Operation *op,
421 OpBuilder::InsertPoint previous) override {
422 llvm::outs() << "notifyOperationInserted: " << op->getName();
423 if (!previous.isSet()) {
424 llvm::outs() << ", was unlinked\n";
425 } else {
426 if (!previous.getPoint().getNodePtr()) {
427 llvm::outs() << ", was linked, exact position unknown\n";
428 } else if (previous.getPoint() == previous.getBlock()->end()) {
429 llvm::outs() << ", was last in block\n";
430 } else {
431 llvm::outs() << ", previous = " << previous.getPoint()->getName()
432 << "\n";
436 void notifyBlockErased(Block *block) override {
437 llvm::outs() << "notifyBlockErased\n";
439 void notifyOperationErased(Operation *op) override {
440 llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
442 void notifyOperationModified(Operation *op) override {
443 llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
445 void notifyOperationReplaced(Operation *op, ValueRange values) override {
446 llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
450 struct TestStrictPatternDriver
451 : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
452 public:
453 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
455 TestStrictPatternDriver() = default;
456 TestStrictPatternDriver(const TestStrictPatternDriver &other)
457 : PassWrapper(other) {
458 strictMode = other.strictMode;
461 StringRef getArgument() const final { return "test-strict-pattern-driver"; }
462 StringRef getDescription() const final {
463 return "Test strict mode of pattern driver";
466 void runOnOperation() override {
467 MLIRContext *ctx = &getContext();
468 mlir::RewritePatternSet patterns(ctx);
469 patterns.add<
470 // clang-format off
471 ChangeBlockOp,
472 CloneOp,
473 CloneRegionBeforeOp,
474 EraseOp,
475 ImplicitChangeOp,
476 InlineBlocksIntoParent,
477 InsertSameOp,
478 MoveBeforeParentOp,
479 ReplaceWithNewOp,
480 SplitBlockHere
481 // clang-format on
482 >(ctx);
483 SmallVector<Operation *> ops;
484 getOperation()->walk([&](Operation *op) {
485 StringRef opName = op->getName().getStringRef();
486 if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
487 opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
488 opName == "test.move_before_parent_op" ||
489 opName == "test.inline_blocks_into_parent" ||
490 opName == "test.split_block_here" || opName == "test.clone_me" ||
491 opName == "test.clone_region_before") {
492 ops.push_back(op);
496 DumpNotifications dumpNotifications;
497 GreedyRewriteConfig config;
498 config.listener = &dumpNotifications;
499 if (strictMode == "AnyOp") {
500 config.strictMode = GreedyRewriteStrictness::AnyOp;
501 } else if (strictMode == "ExistingAndNewOps") {
502 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
503 } else if (strictMode == "ExistingOps") {
504 config.strictMode = GreedyRewriteStrictness::ExistingOps;
505 } else {
506 llvm_unreachable("invalid strictness option");
509 // Check if these transformations introduce visiting of operations that
510 // are not in the `ops` set (The new created ops are valid). An invalid
511 // operation will trigger the assertion while processing.
512 bool changed = false;
513 bool allErased = false;
514 (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config,
515 &changed, &allErased);
516 Builder b(ctx);
517 getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
518 getOperation()->setAttr("pattern_driver_all_erased",
519 b.getBoolAttr(allErased));
522 Option<std::string> strictMode{
523 *this, "strictness",
524 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
525 llvm::cl::init("AnyOp")};
527 private:
528 // New inserted operation is valid for further transformation.
529 class InsertSameOp : public RewritePattern {
530 public:
531 InsertSameOp(MLIRContext *context)
532 : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
534 LogicalResult matchAndRewrite(Operation *op,
535 PatternRewriter &rewriter) const override {
536 if (op->hasAttr("skip"))
537 return failure();
539 Operation *newOp =
540 rewriter.create(op->getLoc(), op->getName().getIdentifier(),
541 op->getOperands(), op->getResultTypes());
542 rewriter.modifyOpInPlace(
543 op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); });
544 newOp->setAttr("skip", rewriter.getBoolAttr(true));
546 return success();
550 // Remove an operation may introduce the re-visiting of its operands.
551 class EraseOp : public RewritePattern {
552 public:
553 EraseOp(MLIRContext *context)
554 : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
555 LogicalResult matchAndRewrite(Operation *op,
556 PatternRewriter &rewriter) const override {
557 rewriter.eraseOp(op);
558 return success();
562 // The following two patterns test RewriterBase::replaceAllUsesWith.
564 // That function replaces all usages of a Block (or a Value) with another one
565 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
566 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
567 // worklist: when an op is modified, it is added to the worklist. The two
568 // patterns below make the tracking observable: ChangeBlockOp replaces all
569 // usages of a block and that pattern is applied because the corresponding ops
570 // are put on the initial worklist (see above). ImplicitChangeOp does an
571 // unrelated change but ops of the corresponding type are *not* on the initial
572 // worklist, so the effect of the second pattern is only visible if the
573 // tracking and subsequent adding to the worklist actually works.
575 // Replace all usages of the first successor with the second successor.
576 class ChangeBlockOp : public RewritePattern {
577 public:
578 ChangeBlockOp(MLIRContext *context)
579 : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
580 LogicalResult matchAndRewrite(Operation *op,
581 PatternRewriter &rewriter) const override {
582 if (op->getNumSuccessors() < 2)
583 return failure();
584 Block *firstSuccessor = op->getSuccessor(0);
585 Block *secondSuccessor = op->getSuccessor(1);
586 if (firstSuccessor == secondSuccessor)
587 return failure();
588 // This is the function being tested:
589 rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor);
590 // Using the following line instead would make the test fail:
591 // firstSuccessor->replaceAllUsesWith(secondSuccessor);
592 return success();
596 // Changes the successor to the parent block.
597 class ImplicitChangeOp : public RewritePattern {
598 public:
599 ImplicitChangeOp(MLIRContext *context)
600 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
601 LogicalResult matchAndRewrite(Operation *op,
602 PatternRewriter &rewriter) const override {
603 if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
604 return failure();
605 rewriter.modifyOpInPlace(op,
606 [&]() { op->setSuccessor(op->getBlock(), 0); });
607 return success();
612 struct TestWalkPatternDriver final
613 : PassWrapper<TestWalkPatternDriver, OperationPass<>> {
614 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
616 TestWalkPatternDriver() = default;
617 TestWalkPatternDriver(const TestWalkPatternDriver &other)
618 : PassWrapper(other) {}
620 StringRef getArgument() const override {
621 return "test-walk-pattern-rewrite-driver";
623 StringRef getDescription() const override {
624 return "Run test walk pattern rewrite driver";
626 void runOnOperation() override {
627 mlir::RewritePatternSet patterns(&getContext());
629 // Patterns for testing the WalkPatternRewriteDriver.
630 patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
631 MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
632 &getContext());
634 DumpNotifications dumpListener;
635 walkAndApplyPatterns(getOperation(), std::move(patterns),
636 dumpNotifications ? &dumpListener : nullptr);
639 Option<bool> dumpNotifications{
640 *this, "dump-notifications",
641 llvm::cl::desc("Print rewrite listener notifications"),
642 llvm::cl::init(false)};
645 } // namespace
647 //===----------------------------------------------------------------------===//
648 // ReturnType Driver.
649 //===----------------------------------------------------------------------===//
651 namespace {
652 // Generate ops for each instance where the type can be successfully inferred.
653 template <typename OpTy>
654 static void invokeCreateWithInferredReturnType(Operation *op) {
655 auto *context = op->getContext();
656 auto fop = op->getParentOfType<func::FuncOp>();
657 auto location = UnknownLoc::get(context);
658 OpBuilder b(op);
659 b.setInsertionPointAfter(op);
661 // Use permutations of 2 args as operands.
662 assert(fop.getNumArguments() >= 2);
663 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
664 for (int j = 0; j < e; ++j) {
665 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
666 SmallVector<Type, 2> inferredReturnTypes;
667 if (succeeded(OpTy::inferReturnTypes(
668 context, std::nullopt, values, op->getDiscardableAttrDictionary(),
669 op->getPropertiesStorage(), op->getRegions(),
670 inferredReturnTypes))) {
671 OperationState state(location, OpTy::getOperationName());
672 // TODO: Expand to regions.
673 OpTy::build(b, state, values, op->getAttrs());
674 (void)b.create(state);
680 static void reifyReturnShape(Operation *op) {
681 OpBuilder b(op);
683 // Use permutations of 2 args as operands.
684 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
685 SmallVector<Value, 2> shapes;
686 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
687 !llvm::hasSingleElement(shapes))
688 return;
689 for (const auto &it : llvm::enumerate(shapes)) {
690 op->emitRemark() << "value " << it.index() << ": "
691 << it.value().getDefiningOp();
695 struct TestReturnTypeDriver
696 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
697 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
699 void getDependentDialects(DialectRegistry &registry) const override {
700 registry.insert<tensor::TensorDialect>();
702 StringRef getArgument() const final { return "test-return-type"; }
703 StringRef getDescription() const final { return "Run return type functions"; }
705 void runOnOperation() override {
706 if (getOperation().getName() == "testCreateFunctions") {
707 std::vector<Operation *> ops;
708 // Collect ops to avoid triggering on inserted ops.
709 for (auto &op : getOperation().getBody().front())
710 ops.push_back(&op);
711 // Generate test patterns for each, but skip terminator.
712 for (auto *op : llvm::ArrayRef(ops).drop_back()) {
713 // Test create method of each of the Op classes below. The resultant
714 // output would be in reverse order underneath `op` from which
715 // the attributes and regions are used.
716 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
717 invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
718 op);
719 invokeCreateWithInferredReturnType<
720 OpWithShapedTypeInferTypeInterfaceOp>(op);
722 return;
724 if (getOperation().getName() == "testReifyFunctions") {
725 std::vector<Operation *> ops;
726 // Collect ops to avoid triggering on inserted ops.
727 for (auto &op : getOperation().getBody().front())
728 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
729 ops.push_back(&op);
730 // Generate test patterns for each, but skip terminator.
731 for (auto *op : ops)
732 reifyReturnShape(op);
736 } // namespace
738 namespace {
739 struct TestDerivedAttributeDriver
740 : public PassWrapper<TestDerivedAttributeDriver,
741 OperationPass<func::FuncOp>> {
742 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
744 StringRef getArgument() const final { return "test-derived-attr"; }
745 StringRef getDescription() const final {
746 return "Run test derived attributes";
748 void runOnOperation() override;
750 } // namespace
752 void TestDerivedAttributeDriver::runOnOperation() {
753 getOperation().walk([](DerivedAttributeOpInterface dOp) {
754 auto dAttr = dOp.materializeDerivedAttributes();
755 if (!dAttr)
756 return;
757 for (auto d : dAttr)
758 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
762 //===----------------------------------------------------------------------===//
763 // Legalization Driver.
764 //===----------------------------------------------------------------------===//
766 namespace {
767 //===----------------------------------------------------------------------===//
768 // Region-Block Rewrite Testing
770 /// This pattern applies a signature conversion to a block inside a detached
771 /// region.
772 struct TestDetachedSignatureConversion : public ConversionPattern {
773 TestDetachedSignatureConversion(MLIRContext *ctx)
774 : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1,
775 ctx) {}
777 LogicalResult
778 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
779 ConversionPatternRewriter &rewriter) const final {
780 if (op->getNumRegions() != 1)
781 return failure();
782 OperationState state(op->getLoc(), "test.legal_op_with_region", operands,
783 op->getResultTypes(), {}, BlockRange());
784 Region *newRegion = state.addRegion();
785 rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
786 newRegion->begin());
787 TypeConverter::SignatureConversion result(newRegion->getNumArguments());
788 for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i)
789 result.addInputs(i, rewriter.getF64Type());
790 rewriter.applySignatureConversion(&newRegion->front(), result);
791 Operation *newOp = rewriter.create(state);
792 rewriter.replaceOp(op, newOp->getResults());
793 return success();
797 /// This pattern is a simple pattern that inlines the first region of a given
798 /// operation into the parent region.
799 struct TestRegionRewriteBlockMovement : public ConversionPattern {
800 TestRegionRewriteBlockMovement(MLIRContext *ctx)
801 : ConversionPattern("test.region", 1, ctx) {}
803 LogicalResult
804 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
805 ConversionPatternRewriter &rewriter) const final {
806 // Inline this region into the parent region.
807 auto &parentRegion = *op->getParentRegion();
808 auto &opRegion = op->getRegion(0);
809 if (op->getDiscardableAttr("legalizer.should_clone"))
810 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
811 else
812 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
814 if (op->getDiscardableAttr("legalizer.erase_old_blocks")) {
815 while (!opRegion.empty())
816 rewriter.eraseBlock(&opRegion.front());
819 // Drop this operation.
820 rewriter.eraseOp(op);
821 return success();
824 /// This pattern is a simple pattern that generates a region containing an
825 /// illegal operation.
826 struct TestRegionRewriteUndo : public RewritePattern {
827 TestRegionRewriteUndo(MLIRContext *ctx)
828 : RewritePattern("test.region_builder", 1, ctx) {}
830 LogicalResult matchAndRewrite(Operation *op,
831 PatternRewriter &rewriter) const final {
832 // Create the region operation with an entry block containing arguments.
833 OperationState newRegion(op->getLoc(), "test.region");
834 newRegion.addRegion();
835 auto *regionOp = rewriter.create(newRegion);
836 auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
837 entryBlock->addArgument(rewriter.getIntegerType(64),
838 rewriter.getUnknownLoc());
840 // Add an explicitly illegal operation to ensure the conversion fails.
841 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
842 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
844 // Drop this operation.
845 rewriter.eraseOp(op);
846 return success();
849 /// A simple pattern that creates a block at the end of the parent region of the
850 /// matched operation.
851 struct TestCreateBlock : public RewritePattern {
852 TestCreateBlock(MLIRContext *ctx)
853 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
855 LogicalResult matchAndRewrite(Operation *op,
856 PatternRewriter &rewriter) const final {
857 Region &region = *op->getParentRegion();
858 Type i32Type = rewriter.getIntegerType(32);
859 Location loc = op->getLoc();
860 rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
861 rewriter.create<TerminatorOp>(loc);
862 rewriter.eraseOp(op);
863 return success();
867 /// A simple pattern that creates a block containing an invalid operation in
868 /// order to trigger the block creation undo mechanism.
869 struct TestCreateIllegalBlock : public RewritePattern {
870 TestCreateIllegalBlock(MLIRContext *ctx)
871 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
873 LogicalResult matchAndRewrite(Operation *op,
874 PatternRewriter &rewriter) const final {
875 Region &region = *op->getParentRegion();
876 Type i32Type = rewriter.getIntegerType(32);
877 Location loc = op->getLoc();
878 rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
879 // Create an illegal op to ensure the conversion fails.
880 rewriter.create<ILLegalOpF>(loc, i32Type);
881 rewriter.create<TerminatorOp>(loc);
882 rewriter.eraseOp(op);
883 return success();
887 /// A simple pattern that tests the undo mechanism when replacing the uses of a
888 /// block argument.
889 struct TestUndoBlockArgReplace : public ConversionPattern {
890 TestUndoBlockArgReplace(MLIRContext *ctx)
891 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
893 LogicalResult
894 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
895 ConversionPatternRewriter &rewriter) const final {
896 auto illegalOp =
897 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
898 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
899 illegalOp->getResult(0));
900 rewriter.modifyOpInPlace(op, [] {});
901 return success();
905 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
906 /// This is to test the rollback logic.
907 struct TestUndoMoveOpBefore : public ConversionPattern {
908 TestUndoMoveOpBefore(MLIRContext *ctx)
909 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
911 LogicalResult
912 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
913 ConversionPatternRewriter &rewriter) const override {
914 rewriter.moveOpBefore(op, op->getParentOp());
915 // Replace with an illegal op to ensure the conversion fails.
916 rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
917 return success();
921 /// A rewrite pattern that tests the undo mechanism when erasing a block.
922 struct TestUndoBlockErase : public ConversionPattern {
923 TestUndoBlockErase(MLIRContext *ctx)
924 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
926 LogicalResult
927 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
928 ConversionPatternRewriter &rewriter) const final {
929 Block *secondBlock = &*std::next(op->getRegion(0).begin());
930 rewriter.setInsertionPointToStart(secondBlock);
931 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
932 rewriter.eraseBlock(secondBlock);
933 rewriter.modifyOpInPlace(op, [] {});
934 return success();
938 /// A pattern that modifies a property in-place, but keeps the op illegal.
939 struct TestUndoPropertiesModification : public ConversionPattern {
940 TestUndoPropertiesModification(MLIRContext *ctx)
941 : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
942 LogicalResult
943 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
944 ConversionPatternRewriter &rewriter) const final {
945 if (!op->hasAttr("modify_inplace"))
946 return failure();
947 rewriter.modifyOpInPlace(
948 op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); });
949 return success();
953 //===----------------------------------------------------------------------===//
954 // Type-Conversion Rewrite Testing
956 /// This patterns erases a region operation that has had a type conversion.
957 struct TestDropOpSignatureConversion : public ConversionPattern {
958 TestDropOpSignatureConversion(MLIRContext *ctx,
959 const TypeConverter &converter)
960 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
961 LogicalResult
962 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
963 ConversionPatternRewriter &rewriter) const override {
964 Region &region = op->getRegion(0);
965 Block *entry = &region.front();
967 // Convert the original entry arguments.
968 const TypeConverter &converter = *getTypeConverter();
969 TypeConverter::SignatureConversion result(entry->getNumArguments());
970 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
971 result)) ||
972 failed(rewriter.convertRegionTypes(&region, converter, &result)))
973 return failure();
975 // Convert the region signature and just drop the operation.
976 rewriter.eraseOp(op);
977 return success();
980 /// This pattern simply updates the operands of the given operation.
981 struct TestPassthroughInvalidOp : public ConversionPattern {
982 TestPassthroughInvalidOp(MLIRContext *ctx)
983 : ConversionPattern("test.invalid", 1, ctx) {}
984 LogicalResult
985 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
986 ConversionPatternRewriter &rewriter) const final {
987 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, operands,
988 std::nullopt);
989 return success();
992 /// Replace with valid op, but simply drop the operands. This is used in a
993 /// regression where we used to generate circular unrealized_conversion_cast
994 /// ops.
995 struct TestDropAndReplaceInvalidOp : public ConversionPattern {
996 TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
997 : ConversionPattern(converter,
998 "test.drop_operands_and_replace_with_valid", 1, ctx) {
1000 LogicalResult
1001 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1002 ConversionPatternRewriter &rewriter) const final {
1003 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(),
1004 std::nullopt);
1005 return success();
1008 /// This pattern handles the case of a split return value.
1009 struct TestSplitReturnType : public ConversionPattern {
1010 TestSplitReturnType(MLIRContext *ctx)
1011 : ConversionPattern("test.return", 1, ctx) {}
1012 LogicalResult
1013 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1014 ConversionPatternRewriter &rewriter) const final {
1015 // Check for a return of F32.
1016 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
1017 return failure();
1019 // Check if the first operation is a cast operation, if it is we use the
1020 // results directly.
1021 auto *defOp = operands[0].getDefiningOp();
1022 if (auto packerOp =
1023 llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
1024 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
1025 return success();
1028 // Otherwise, fail to match.
1029 return failure();
1033 //===----------------------------------------------------------------------===//
1034 // Multi-Level Type-Conversion Rewrite Testing
1035 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
1036 TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
1037 : ConversionPattern("test.type_producer", 1, ctx) {}
1038 LogicalResult
1039 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1040 ConversionPatternRewriter &rewriter) const final {
1041 // If the type is I32, change the type to F32.
1042 if (!Type(*op->result_type_begin()).isSignlessInteger(32))
1043 return failure();
1044 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
1045 return success();
1048 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
1049 TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
1050 : ConversionPattern("test.type_producer", 1, ctx) {}
1051 LogicalResult
1052 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1053 ConversionPatternRewriter &rewriter) const final {
1054 // If the type is F32, change the type to F64.
1055 if (!Type(*op->result_type_begin()).isF32())
1056 return rewriter.notifyMatchFailure(op, "expected single f32 operand");
1057 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
1058 return success();
1061 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
1062 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
1063 : ConversionPattern("test.type_producer", 10, ctx) {}
1064 LogicalResult
1065 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1066 ConversionPatternRewriter &rewriter) const final {
1067 // Always convert to B16, even though it is not a legal type. This tests
1068 // that values are unmapped correctly.
1069 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
1070 return success();
1073 struct TestUpdateConsumerType : public ConversionPattern {
1074 TestUpdateConsumerType(MLIRContext *ctx)
1075 : ConversionPattern("test.type_consumer", 1, ctx) {}
1076 LogicalResult
1077 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1078 ConversionPatternRewriter &rewriter) const final {
1079 // Verify that the incoming operand has been successfully remapped to F64.
1080 if (!operands[0].getType().isF64())
1081 return failure();
1082 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
1083 return success();
1087 //===----------------------------------------------------------------------===//
1088 // Non-Root Replacement Rewrite Testing
1089 /// This pattern generates an invalid operation, but replaces it before the
1090 /// pattern is finished. This checks that we don't need to legalize the
1091 /// temporary op.
1092 struct TestNonRootReplacement : public RewritePattern {
1093 TestNonRootReplacement(MLIRContext *ctx)
1094 : RewritePattern("test.replace_non_root", 1, ctx) {}
1096 LogicalResult matchAndRewrite(Operation *op,
1097 PatternRewriter &rewriter) const final {
1098 auto resultType = *op->result_type_begin();
1099 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
1100 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
1102 rewriter.replaceOp(illegalOp, legalOp);
1103 rewriter.replaceOp(op, illegalOp);
1104 return success();
1108 //===----------------------------------------------------------------------===//
1109 // Recursive Rewrite Testing
1110 /// This pattern is applied to the same operation multiple times, but has a
1111 /// bounded recursion.
1112 struct TestBoundedRecursiveRewrite
1113 : public OpRewritePattern<TestRecursiveRewriteOp> {
1114 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
1116 void initialize() {
1117 // The conversion target handles bounding the recursion of this pattern.
1118 setHasBoundedRewriteRecursion();
1121 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
1122 PatternRewriter &rewriter) const final {
1123 // Decrement the depth of the op in-place.
1124 rewriter.modifyOpInPlace(op, [&] {
1125 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
1127 return success();
1131 struct TestNestedOpCreationUndoRewrite
1132 : public OpRewritePattern<IllegalOpWithRegionAnchor> {
1133 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
1135 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
1136 PatternRewriter &rewriter) const final {
1137 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1138 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1139 return success();
1143 // This pattern matches `test.blackhole` and delete this op and its producer.
1144 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
1145 using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
1147 LogicalResult matchAndRewrite(BlackHoleOp op,
1148 PatternRewriter &rewriter) const final {
1149 Operation *producer = op.getOperand().getDefiningOp();
1150 // Always erase the user before the producer, the framework should handle
1151 // this correctly.
1152 rewriter.eraseOp(op);
1153 rewriter.eraseOp(producer);
1154 return success();
1158 // This pattern replaces explicitly illegal op with explicitly legal op,
1159 // but in addition creates unregistered operation.
1160 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
1161 using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
1163 LogicalResult matchAndRewrite(ILLegalOpG op,
1164 PatternRewriter &rewriter) const final {
1165 IntegerAttr attr = rewriter.getI32IntegerAttr(0);
1166 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
1167 rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
1168 return success();
1172 class TestEraseOp : public ConversionPattern {
1173 public:
1174 TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {}
1175 LogicalResult
1176 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1177 ConversionPatternRewriter &rewriter) const final {
1178 // Erase op without replacements.
1179 rewriter.eraseOp(op);
1180 return success();
1184 } // namespace
1186 namespace {
1187 struct TestTypeConverter : public TypeConverter {
1188 using TypeConverter::TypeConverter;
1189 TestTypeConverter() {
1190 addConversion(convertType);
1191 addArgumentMaterialization(materializeCast);
1192 addSourceMaterialization(materializeCast);
1195 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
1196 // Drop I16 types.
1197 if (t.isSignlessInteger(16))
1198 return success();
1200 // Convert I64 to F64.
1201 if (t.isSignlessInteger(64)) {
1202 results.push_back(FloatType::getF64(t.getContext()));
1203 return success();
1206 // Convert I42 to I43.
1207 if (t.isInteger(42)) {
1208 results.push_back(IntegerType::get(t.getContext(), 43));
1209 return success();
1212 // Split F32 into F16,F16.
1213 if (t.isF32()) {
1214 results.assign(2, FloatType::getF16(t.getContext()));
1215 return success();
1218 // Otherwise, convert the type directly.
1219 results.push_back(t);
1220 return success();
1223 /// Hook for materializing a conversion. This is necessary because we generate
1224 /// 1->N type mappings.
1225 static Value materializeCast(OpBuilder &builder, Type resultType,
1226 ValueRange inputs, Location loc) {
1227 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1231 struct TestLegalizePatternDriver
1232 : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> {
1233 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)
1235 StringRef getArgument() const final { return "test-legalize-patterns"; }
1236 StringRef getDescription() const final {
1237 return "Run test dialect legalization patterns";
1239 /// The mode of conversion to use with the driver.
1240 enum class ConversionMode { Analysis, Full, Partial };
1242 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
1244 void getDependentDialects(DialectRegistry &registry) const override {
1245 registry.insert<func::FuncDialect, test::TestDialect>();
1248 void runOnOperation() override {
1249 TestTypeConverter converter;
1250 mlir::RewritePatternSet patterns(&getContext());
1251 populateWithGenerated(patterns);
1252 patterns.add<
1253 TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1254 TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1255 TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
1256 TestSplitReturnType, TestChangeProducerTypeI32ToF32,
1257 TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
1258 TestUpdateConsumerType, TestNonRootReplacement,
1259 TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
1260 TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1261 TestUndoPropertiesModification, TestEraseOp>(&getContext());
1262 patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
1263 &getContext(), converter);
1264 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1265 converter);
1266 mlir::populateCallOpTypeConversionPattern(patterns, converter);
1268 // Define the conversion target used for the test.
1269 ConversionTarget target(getContext());
1270 target.addLegalOp<ModuleOp>();
1271 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
1272 TerminatorOp, OneRegionOp>();
1273 target.addLegalOp(
1274 OperationName("test.legal_op_with_region", &getContext()));
1275 target
1276 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
1277 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
1278 // Don't allow F32 operands.
1279 return llvm::none_of(op.getOperandTypes(),
1280 [](Type type) { return type.isF32(); });
1282 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1283 return converter.isSignatureLegal(op.getFunctionType()) &&
1284 converter.isLegal(&op.getBody());
1286 target.addDynamicallyLegalOp<func::CallOp>(
1287 [&](func::CallOp op) { return converter.isLegal(op); });
1289 // TestCreateUnregisteredOp creates `arith.constant` operation,
1290 // which was not added to target intentionally to test
1291 // correct error code from conversion driver.
1292 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });
1294 // Expect the type_producer/type_consumer operations to only operate on f64.
1295 target.addDynamicallyLegalOp<TestTypeProducerOp>(
1296 [](TestTypeProducerOp op) { return op.getType().isF64(); });
1297 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1298 return op.getOperand().getType().isF64();
1301 // Check support for marking certain operations as recursively legal.
1302 target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
1303 return static_cast<bool>(
1304 op->getAttrOfType<UnitAttr>("test.recursively_legal"));
1307 // Mark the bound recursion operation as dynamically legal.
1308 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
1309 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
1311 // Create a dynamically legal rule that can only be legalized by folding it.
1312 target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
1313 [](TestOpInPlaceSelfFold op) { return op.getFolded(); });
1315 // Handle a partial conversion.
1316 if (mode == ConversionMode::Partial) {
1317 DenseSet<Operation *> unlegalizedOps;
1318 ConversionConfig config;
1319 DumpNotifications dumpNotifications;
1320 config.listener = &dumpNotifications;
1321 config.unlegalizedOps = &unlegalizedOps;
1322 if (failed(applyPartialConversion(getOperation(), target,
1323 std::move(patterns), config))) {
1324 getOperation()->emitRemark() << "applyPartialConversion failed";
1326 // Emit remarks for each legalizable operation.
1327 for (auto *op : unlegalizedOps)
1328 op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1329 return;
1332 // Handle a full conversion.
1333 if (mode == ConversionMode::Full) {
1334 // Check support for marking unknown operations as dynamically legal.
1335 target.markUnknownOpDynamicallyLegal([](Operation *op) {
1336 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
1339 ConversionConfig config;
1340 DumpNotifications dumpNotifications;
1341 config.listener = &dumpNotifications;
1342 if (failed(applyFullConversion(getOperation(), target,
1343 std::move(patterns), config))) {
1344 getOperation()->emitRemark() << "applyFullConversion failed";
1346 return;
1349 // Otherwise, handle an analysis conversion.
1350 assert(mode == ConversionMode::Analysis);
1352 // Analyze the convertible operations.
1353 DenseSet<Operation *> legalizedOps;
1354 ConversionConfig config;
1355 config.legalizableOps = &legalizedOps;
1356 if (failed(applyAnalysisConversion(getOperation(), target,
1357 std::move(patterns), config)))
1358 return signalPassFailure();
1360 // Emit remarks for each legalizable operation.
1361 for (auto *op : legalizedOps)
1362 op->emitRemark() << "op '" << op->getName() << "' is legalizable";
1365 /// The mode of conversion to use.
1366 ConversionMode mode;
1368 } // namespace
1370 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
1371 legalizerConversionMode(
1372 "test-legalize-mode",
1373 llvm::cl::desc("The legalization mode to use with the test driver"),
1374 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
1375 llvm::cl::values(
1376 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
1377 "analysis", "Perform an analysis conversion"),
1378 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
1379 "Perform a full conversion"),
1380 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
1381 "partial", "Perform a partial conversion")));
1383 //===----------------------------------------------------------------------===//
1384 // ConversionPatternRewriter::getRemappedValue testing. This method is used
1385 // to get the remapped value of an original value that was replaced using
1386 // ConversionPatternRewriter.
1387 namespace {
1388 struct TestRemapValueTypeConverter : public TypeConverter {
1389 using TypeConverter::TypeConverter;
1391 TestRemapValueTypeConverter() {
1392 addConversion(
1393 [](Float32Type type) { return Float64Type::get(type.getContext()); });
1394 addConversion([](Type type) { return type; });
1398 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
1399 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
1400 /// operand twice.
1402 /// Example:
1403 /// %1 = test.one_variadic_out_one_variadic_in1"(%0)
1404 /// is replaced with:
1405 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
1406 struct OneVResOneVOperandOp1Converter
1407 : public OpConversionPattern<OneVResOneVOperandOp1> {
1408 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
1410 LogicalResult
1411 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
1412 ConversionPatternRewriter &rewriter) const override {
1413 auto origOps = op.getOperands();
1414 assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
1415 "One operand expected");
1416 Value origOp = *origOps.begin();
1417 SmallVector<Value, 2> remappedOperands;
1418 // Replicate the remapped original operand twice. Note that we don't used
1419 // the remapped 'operand' since the goal is testing 'getRemappedValue'.
1420 remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1421 remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1423 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
1424 remappedOperands);
1425 return success();
1429 /// A rewriter pattern that tests that blocks can be merged.
1430 struct TestRemapValueInRegion
1431 : public OpConversionPattern<TestRemappedValueRegionOp> {
1432 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;
1434 LogicalResult
1435 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
1436 ConversionPatternRewriter &rewriter) const final {
1437 Block &block = op.getBody().front();
1438 Operation *terminator = block.getTerminator();
1440 // Merge the block into the parent region.
1441 Block *parentBlock = op->getBlock();
1442 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
1443 rewriter.mergeBlocks(&block, parentBlock, ValueRange());
1444 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());
1446 // Replace the results of this operation with the remapped terminator
1447 // values.
1448 SmallVector<Value> terminatorOperands;
1449 if (failed(rewriter.getRemappedValues(terminator->getOperands(),
1450 terminatorOperands)))
1451 return failure();
1453 rewriter.eraseOp(terminator);
1454 rewriter.replaceOp(op, terminatorOperands);
1455 return success();
1459 struct TestRemappedValue
1460 : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> {
1461 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
1463 StringRef getArgument() const final { return "test-remapped-value"; }
1464 StringRef getDescription() const final {
1465 return "Test public remapped value mechanism in ConversionPatternRewriter";
1467 void runOnOperation() override {
1468 TestRemapValueTypeConverter typeConverter;
1470 mlir::RewritePatternSet patterns(&getContext());
1471 patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
1472 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
1473 &getContext());
1474 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
1476 mlir::ConversionTarget target(getContext());
1477 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
1479 // Expect the type_producer/type_consumer operations to only operate on f64.
1480 target.addDynamicallyLegalOp<TestTypeProducerOp>(
1481 [](TestTypeProducerOp op) { return op.getType().isF64(); });
1482 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1483 return op.getOperand().getType().isF64();
1486 // We make OneVResOneVOperandOp1 legal only when it has more that one
1487 // operand. This will trigger the conversion that will replace one-operand
1488 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1489 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
1490 [](Operation *op) { return op->getNumOperands() > 1; });
1492 if (failed(mlir::applyFullConversion(getOperation(), target,
1493 std::move(patterns)))) {
1494 signalPassFailure();
1498 } // namespace
1500 //===----------------------------------------------------------------------===//
1501 // Test patterns without a specific root operation kind
1502 //===----------------------------------------------------------------------===//
1504 namespace {
1505 /// This pattern matches and removes any operation in the test dialect.
1506 struct RemoveTestDialectOps : public RewritePattern {
1507 RemoveTestDialectOps(MLIRContext *context)
1508 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
1510 LogicalResult matchAndRewrite(Operation *op,
1511 PatternRewriter &rewriter) const override {
1512 if (!isa<TestDialect>(op->getDialect()))
1513 return failure();
1514 rewriter.eraseOp(op);
1515 return success();
1519 struct TestUnknownRootOpDriver
1520 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> {
1521 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
1523 StringRef getArgument() const final {
1524 return "test-legalize-unknown-root-patterns";
1526 StringRef getDescription() const final {
1527 return "Test public remapped value mechanism in ConversionPatternRewriter";
1529 void runOnOperation() override {
1530 mlir::RewritePatternSet patterns(&getContext());
1531 patterns.add<RemoveTestDialectOps>(&getContext());
1533 mlir::ConversionTarget target(getContext());
1534 target.addIllegalDialect<TestDialect>();
1535 if (failed(applyPartialConversion(getOperation(), target,
1536 std::move(patterns))))
1537 signalPassFailure();
1540 } // namespace
1542 //===----------------------------------------------------------------------===//
// Test patterns that use operations and types defined at runtime
1544 //===----------------------------------------------------------------------===//
1546 namespace {
1547 /// This pattern matches dynamic operations 'test.one_operand_two_results' and
1548 /// replace them with dynamic operations 'test.generic_dynamic_op'.
1549 struct RewriteDynamicOp : public RewritePattern {
1550 RewriteDynamicOp(MLIRContext *context)
1551 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
1552 context) {}
1554 LogicalResult matchAndRewrite(Operation *op,
1555 PatternRewriter &rewriter) const override {
1556 assert(op->getName().getStringRef() ==
1557 "test.dynamic_one_operand_two_results" &&
1558 "rewrite pattern should only match operations with the right name");
1560 OperationState state(op->getLoc(), "test.dynamic_generic",
1561 op->getOperands(), op->getResultTypes(),
1562 op->getAttrs());
1563 auto *newOp = rewriter.create(state);
1564 rewriter.replaceOp(op, newOp->getResults());
1565 return success();
1569 struct TestRewriteDynamicOpDriver
1570 : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> {
1571 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
1573 void getDependentDialects(DialectRegistry &registry) const override {
1574 registry.insert<TestDialect>();
1576 StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
1577 StringRef getDescription() const final {
1578 return "Test rewritting on dynamic operations";
1580 void runOnOperation() override {
1581 RewritePatternSet patterns(&getContext());
1582 patterns.add<RewriteDynamicOp>(&getContext());
1584 ConversionTarget target(getContext());
1585 target.addIllegalOp(
1586 OperationName("test.dynamic_one_operand_two_results", &getContext()));
1587 target.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
1588 if (failed(applyPartialConversion(getOperation(), target,
1589 std::move(patterns))))
1590 signalPassFailure();
1593 } // end anonymous namespace
1595 //===----------------------------------------------------------------------===//
1596 // Test type conversions
1597 //===----------------------------------------------------------------------===//
1599 namespace {
1600 struct TestTypeConversionProducer
1601 : public OpConversionPattern<TestTypeProducerOp> {
1602 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
1603 LogicalResult
1604 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
1605 ConversionPatternRewriter &rewriter) const final {
1606 Type resultType = op.getType();
1607 Type convertedType = getTypeConverter()
1608 ? getTypeConverter()->convertType(resultType)
1609 : resultType;
1610 if (isa<FloatType>(resultType))
1611 resultType = rewriter.getF64Type();
1612 else if (resultType.isInteger(16))
1613 resultType = rewriter.getIntegerType(64);
1614 else if (isa<test::TestRecursiveType>(resultType) &&
1615 convertedType != resultType)
1616 resultType = convertedType;
1617 else
1618 return failure();
1620 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
1621 return success();
1625 /// Call signature conversion and then fail the rewrite to trigger the undo
1626 /// mechanism.
1627 struct TestSignatureConversionUndo
1628 : public OpConversionPattern<TestSignatureConversionUndoOp> {
1629 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
1631 LogicalResult
1632 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
1633 ConversionPatternRewriter &rewriter) const final {
1634 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
1635 return failure();
1639 /// Call signature conversion without providing a type converter to handle
1640 /// materializations.
1641 struct TestTestSignatureConversionNoConverter
1642 : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
1643 TestTestSignatureConversionNoConverter(const TypeConverter &converter,
1644 MLIRContext *context)
1645 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
1646 converter(converter) {}
1648 LogicalResult
1649 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
1650 ConversionPatternRewriter &rewriter) const final {
1651 Region &region = op->getRegion(0);
1652 Block *entry = &region.front();
1654 // Convert the original entry arguments.
1655 TypeConverter::SignatureConversion result(entry->getNumArguments());
1656 if (failed(
1657 converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
1658 return failure();
1659 rewriter.modifyOpInPlace(op, [&] {
1660 rewriter.applySignatureConversion(&region.front(), result);
1662 return success();
1665 const TypeConverter &converter;
1668 /// Just forward the operands to the root op. This is essentially a no-op
1669 /// pattern that is used to trigger target materialization.
1670 struct TestTypeConsumerForward
1671 : public OpConversionPattern<TestTypeConsumerOp> {
1672 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
1674 LogicalResult
1675 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
1676 ConversionPatternRewriter &rewriter) const final {
1677 rewriter.modifyOpInPlace(op,
1678 [&] { op->setOperands(adaptor.getOperands()); });
1679 return success();
1683 struct TestTypeConversionAnotherProducer
1684 : public OpRewritePattern<TestAnotherTypeProducerOp> {
1685 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1687 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1688 PatternRewriter &rewriter) const final {
1689 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1690 return success();
1694 struct TestReplaceWithLegalOp : public ConversionPattern {
1695 TestReplaceWithLegalOp(MLIRContext *ctx)
1696 : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
1697 LogicalResult
1698 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1699 ConversionPatternRewriter &rewriter) const final {
1700 rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
1701 return success();
1705 struct TestTypeConversionDriver
1706 : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
1707 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
1709 void getDependentDialects(DialectRegistry &registry) const override {
1710 registry.insert<TestDialect>();
1712 StringRef getArgument() const final {
1713 return "test-legalize-type-conversion";
1715 StringRef getDescription() const final {
1716 return "Test various type conversion functionalities in DialectConversion";
1719 void runOnOperation() override {
1720 // Initialize the type converter.
1721 SmallVector<Type, 2> conversionCallStack;
1722 TypeConverter converter;
1724 /// Add the legal set of type conversions.
1725 converter.addConversion([](Type type) -> Type {
1726 // Treat F64 as legal.
1727 if (type.isF64())
1728 return type;
1729 // Allow converting BF16/F16/F32 to F64.
1730 if (type.isBF16() || type.isF16() || type.isF32())
1731 return FloatType::getF64(type.getContext());
1732 // Otherwise, the type is illegal.
1733 return nullptr;
1735 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
1736 // Drop all integer types.
1737 return success();
1739 converter.addConversion(
1740 // Convert a recursive self-referring type into a non-self-referring
1741 // type named "outer_converted_type" that contains a SimpleAType.
1742 [&](test::TestRecursiveType type,
1743 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
1744 // If the type is already converted, return it to indicate that it is
1745 // legal.
1746 if (type.getName() == "outer_converted_type") {
1747 results.push_back(type);
1748 return success();
1751 conversionCallStack.push_back(type);
1752 auto popConversionCallStack = llvm::make_scope_exit(
1753 [&conversionCallStack]() { conversionCallStack.pop_back(); });
1755 // If the type is on the call stack more than once (it is there at
1756 // least once because of the _current_ call, which is always the last
1757 // element on the stack), we've hit the recursive case. Just return
1758 // SimpleAType here to create a non-recursive type as a result.
1759 if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(),
1760 type)) {
1761 results.push_back(test::SimpleAType::get(type.getContext()));
1762 return success();
1765 // Convert the body recursively.
1766 auto result = test::TestRecursiveType::get(type.getContext(),
1767 "outer_converted_type");
1768 if (failed(result.setBody(converter.convertType(type.getBody()))))
1769 return failure();
1770 results.push_back(result);
1771 return success();
1774 /// Add the legal set of type materializations.
1775 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
1776 ValueRange inputs,
1777 Location loc) -> Value {
1778 // Allow casting from F64 back to F32.
1779 if (!resultType.isF16() && inputs.size() == 1 &&
1780 inputs[0].getType().isF64())
1781 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1782 // Allow producing an i32 or i64 from nothing.
1783 if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
1784 inputs.empty())
1785 return builder.create<TestTypeProducerOp>(loc, resultType);
1786 // Allow producing an i64 from an integer.
1787 if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
1788 isa<IntegerType>(inputs[0].getType()))
1789 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1790 // Otherwise, fail.
1791 return nullptr;
1794 // Initialize the conversion target.
1795 mlir::ConversionTarget target(getContext());
1796 target.addLegalOp<LegalOpD>();
1797 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
1798 auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
1799 return op.getType().isF64() || op.getType().isInteger(64) ||
1800 (recursiveType &&
1801 recursiveType.getName() == "outer_converted_type");
1803 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1804 return converter.isSignatureLegal(op.getFunctionType()) &&
1805 converter.isLegal(&op.getBody());
1807 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
1808 // Allow casts from F64 to F32.
1809 return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1811 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
1812 [&](TestSignatureConversionNoConverterOp op) {
1813 return converter.isLegal(op.getRegion().front().getArgumentTypes());
1816 // Initialize the set of rewrite patterns.
1817 RewritePatternSet patterns(&getContext());
1818 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
1819 TestSignatureConversionUndo,
1820 TestTestSignatureConversionNoConverter>(converter,
1821 &getContext());
1822 patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
1823 &getContext());
1824 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1825 converter);
1827 if (failed(applyPartialConversion(getOperation(), target,
1828 std::move(patterns))))
1829 signalPassFailure();
1832 } // namespace
1834 //===----------------------------------------------------------------------===//
1835 // Test Target Materialization With No Uses
1836 //===----------------------------------------------------------------------===//
1838 namespace {
1839 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1840 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1842 LogicalResult
1843 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1844 ConversionPatternRewriter &rewriter) const final {
1845 rewriter.replaceOp(op, adaptor.getOperands());
1846 return success();
1850 struct TestTargetMaterializationWithNoUses
1851 : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> {
1852 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1853 TestTargetMaterializationWithNoUses)
1855 StringRef getArgument() const final {
1856 return "test-target-materialization-with-no-uses";
1858 StringRef getDescription() const final {
1859 return "Test a special case of target materialization in DialectConversion";
1862 void runOnOperation() override {
1863 TypeConverter converter;
1864 converter.addConversion([](Type t) { return t; });
1865 converter.addConversion([](IntegerType intTy) -> Type {
1866 if (intTy.getWidth() == 16)
1867 return IntegerType::get(intTy.getContext(), 64);
1868 return intTy;
1870 converter.addTargetMaterialization(
1871 [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1872 return builder.create<TestCastOp>(loc, type, inputs).getResult();
1875 ConversionTarget target(getContext());
1876 target.addIllegalOp<TestTypeChangerOp>();
1878 RewritePatternSet patterns(&getContext());
1879 patterns.add<ForwardOperandPattern>(converter, &getContext());
1881 if (failed(applyPartialConversion(getOperation(), target,
1882 std::move(patterns))))
1883 signalPassFailure();
1886 } // namespace
1888 //===----------------------------------------------------------------------===//
1889 // Test Block Merging
1890 //===----------------------------------------------------------------------===//
1892 namespace {
1893 /// A rewriter pattern that tests that blocks can be merged.
1894 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1895 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1897 LogicalResult
1898 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1899 ConversionPatternRewriter &rewriter) const final {
1900 Block &firstBlock = op.getBody().front();
1901 Operation *branchOp = firstBlock.getTerminator();
1902 Block *secondBlock = &*(std::next(op.getBody().begin()));
1903 auto succOperands = branchOp->getOperands();
1904 SmallVector<Value, 2> replacements(succOperands);
1905 rewriter.eraseOp(branchOp);
1906 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1907 rewriter.modifyOpInPlace(op, [] {});
1908 return success();
1912 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1913 struct TestUndoBlocksMerge : public ConversionPattern {
1914 TestUndoBlocksMerge(MLIRContext *ctx)
1915 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1916 LogicalResult
1917 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1918 ConversionPatternRewriter &rewriter) const final {
1919 Block &firstBlock = op->getRegion(0).front();
1920 Operation *branchOp = firstBlock.getTerminator();
1921 Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
1922 rewriter.setInsertionPointToStart(secondBlock);
1923 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1924 auto succOperands = branchOp->getOperands();
1925 SmallVector<Value, 2> replacements(succOperands);
1926 rewriter.eraseOp(branchOp);
1927 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1928 rewriter.modifyOpInPlace(op, [] {});
1929 return success();
1933 /// A rewrite mechanism to inline the body of the op into its parent, when both
1934 /// ops can have a single block.
1935 struct TestMergeSingleBlockOps
1936 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
1937 using OpConversionPattern<
1938 SingleBlockImplicitTerminatorOp>::OpConversionPattern;
1940 LogicalResult
1941 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
1942 ConversionPatternRewriter &rewriter) const final {
1943 SingleBlockImplicitTerminatorOp parentOp =
1944 op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1945 if (!parentOp)
1946 return failure();
1947 Block &innerBlock = op.getRegion().front();
1948 TerminatorOp innerTerminator =
1949 cast<TerminatorOp>(innerBlock.getTerminator());
1950 rewriter.inlineBlockBefore(&innerBlock, op);
1951 rewriter.eraseOp(innerTerminator);
1952 rewriter.eraseOp(op);
1953 return success();
1957 struct TestMergeBlocksPatternDriver
1958 : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> {
1959 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
1961 StringRef getArgument() const final { return "test-merge-blocks"; }
1962 StringRef getDescription() const final {
1963 return "Test Merging operation in ConversionPatternRewriter";
1965 void runOnOperation() override {
1966 MLIRContext *context = &getContext();
1967 mlir::RewritePatternSet patterns(context);
1968 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
1969 context);
1970 ConversionTarget target(*context);
1971 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
1972 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
1973 target.addIllegalOp<ILLegalOpF>();
1975 /// Expect the op to have a single block after legalization.
1976 target.addDynamicallyLegalOp<TestMergeBlocksOp>(
1977 [&](TestMergeBlocksOp op) -> bool {
1978 return llvm::hasSingleElement(op.getBody());
1981 /// Only allow `test.br` within test.merge_blocks op.
1982 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
1983 return op->getParentOfType<TestMergeBlocksOp>();
1986 /// Expect that all nested test.SingleBlockImplicitTerminator ops are
1987 /// inlined.
1988 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
1989 [&](SingleBlockImplicitTerminatorOp op) -> bool {
1990 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1993 DenseSet<Operation *> unlegalizedOps;
1994 ConversionConfig config;
1995 config.unlegalizedOps = &unlegalizedOps;
1996 (void)applyPartialConversion(getOperation(), target, std::move(patterns),
1997 config);
1998 for (auto *op : unlegalizedOps)
1999 op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
2002 } // namespace
2004 //===----------------------------------------------------------------------===//
2005 // Test Selective Replacement
2006 //===----------------------------------------------------------------------===//
2008 namespace {
2009 /// A rewrite mechanism to inline the body of the op into its parent, when both
2010 /// ops can have a single block.
2011 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
2012 using OpRewritePattern<TestCastOp>::OpRewritePattern;
2014 LogicalResult matchAndRewrite(TestCastOp op,
2015 PatternRewriter &rewriter) const final {
2016 if (op.getNumOperands() != 2)
2017 return failure();
2018 OperandRange operands = op.getOperands();
2020 // Replace non-terminator uses with the first operand.
2021 rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
2022 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
2024 // Replace everything else with the second operand if the operation isn't
2025 // dead.
2026 rewriter.replaceOp(op, op.getOperand(1));
2027 return success();
2031 struct TestSelectiveReplacementPatternDriver
2032 : public PassWrapper<TestSelectiveReplacementPatternDriver,
2033 OperationPass<>> {
2034 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
2035 TestSelectiveReplacementPatternDriver)
2037 StringRef getArgument() const final {
2038 return "test-pattern-selective-replacement";
2040 StringRef getDescription() const final {
2041 return "Test selective replacement in the PatternRewriter";
2043 void runOnOperation() override {
2044 MLIRContext *context = &getContext();
2045 mlir::RewritePatternSet patterns(context);
2046 patterns.add<TestSelectiveOpReplacementPattern>(context);
2047 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
2050 } // namespace
2052 //===----------------------------------------------------------------------===//
2053 // PassRegistration
2054 //===----------------------------------------------------------------------===//
2056 namespace mlir {
2057 namespace test {
2058 void registerPatternsTestPass() {
2059 PassRegistration<TestReturnTypeDriver>();
2061 PassRegistration<TestDerivedAttributeDriver>();
2063 PassRegistration<TestGreedyPatternDriver>();
2064 PassRegistration<TestStrictPatternDriver>();
2065 PassRegistration<TestWalkPatternDriver>();
2067 PassRegistration<TestLegalizePatternDriver>([] {
2068 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
2071 PassRegistration<TestRemappedValue>();
2073 PassRegistration<TestUnknownRootOpDriver>();
2075 PassRegistration<TestTypeConversionDriver>();
2076 PassRegistration<TestTargetMaterializationWithNoUses>();
2078 PassRegistration<TestRewriteDynamicOpDriver>();
2080 PassRegistration<TestMergeBlocksPatternDriver>();
2081 PassRegistration<TestSelectiveReplacementPatternDriver>();
2083 } // namespace test
2084 } // namespace mlir