[libc++abi] Build cxxabi with sanitizers (#119612)
[llvm-project.git] / mlir / test / lib / Dialect / Test / TestPatterns.cpp
blob8a0bc597c56bebbdfedf53b07ac2920d853528ca
1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "TestTypes.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinAttributes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Visitors.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include <cstdint>
27 using namespace mlir;
28 using namespace test;
30 // Native function for testing NativeCodeCall
31 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
32 return choice.getValue() ? input1 : input2;
35 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
36 rewriter.create<OpI>(loc, input);
39 static void handleNoResultOp(PatternRewriter &rewriter,
40 OpSymbolBindingNoResult op) {
41 // Turn the no result op to a one-result op.
42 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
43 op.getOperand());
46 static bool getFirstI32Result(Operation *op, Value &value) {
47 if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
48 return false;
49 value = op->getResult(0);
50 return true;
53 static Value bindNativeCodeCallResult(Value value) { return value; }
55 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
56 Value input2) {
57 return SmallVector<Value, 2>({input2, input1});
60 // Test that natives calls are only called once during rewrites.
61 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
62 // This let us check the number of times OpM_Test was called by inspecting
63 // the returned value in the MLIR output.
64 static int64_t opMIncreasingValue = 314159265;
65 static Attribute opMTest(PatternRewriter &rewriter, Value val) {
66 int64_t i = opMIncreasingValue++;
67 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
70 namespace {
71 #include "TestPatterns.inc"
72 } // namespace
74 //===----------------------------------------------------------------------===//
75 // Test Reduce Pattern Interface
76 //===----------------------------------------------------------------------===//
78 void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
79 populateWithGenerated(patterns);
82 //===----------------------------------------------------------------------===//
83 // Canonicalizer Driver.
84 //===----------------------------------------------------------------------===//
86 namespace {
87 struct FoldingPattern : public RewritePattern {
88 public:
89 FoldingPattern(MLIRContext *context)
90 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
91 /*benefit=*/1, context) {}
93 LogicalResult matchAndRewrite(Operation *op,
94 PatternRewriter &rewriter) const override {
95 // Exercise createOrFold API for a single-result operation that is folded
96 // upon construction. The operation being created has an in-place folder,
97 // and it should be still present in the output. Furthermore, the folder
98 // should not crash when attempting to recover the (unchanged) operation
99 // result.
100 Value result = rewriter.createOrFold<TestOpInPlaceFold>(
101 op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0));
102 assert(result);
103 rewriter.replaceOp(op, result);
104 return success();
108 /// This pattern creates a foldable operation at the entry point of the block.
109 /// This tests the situation where the operation folder will need to replace an
110 /// operation with a previously created constant that does not initially
111 /// dominate the operation to replace.
112 struct FolderInsertBeforePreviouslyFoldedConstantPattern
113 : public OpRewritePattern<TestCastOp> {
114 public:
115 using OpRewritePattern<TestCastOp>::OpRewritePattern;
117 LogicalResult matchAndRewrite(TestCastOp op,
118 PatternRewriter &rewriter) const override {
119 if (!op->hasAttr("test_fold_before_previously_folded_op"))
120 return failure();
121 rewriter.setInsertionPointToStart(op->getBlock());
123 auto constOp = rewriter.create<arith::ConstantOp>(
124 op.getLoc(), rewriter.getBoolAttr(true));
125 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
126 Value(constOp));
127 return success();
131 /// This pattern matches test.op_commutative2 with the first operand being
132 /// another test.op_commutative2 with a constant on the right side and fold it
133 /// away by propagating it as its result. This is intend to check that patterns
134 /// are applied after the commutative property moves constant to the right.
135 struct FolderCommutativeOp2WithConstant
136 : public OpRewritePattern<TestCommutative2Op> {
137 public:
138 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
140 LogicalResult matchAndRewrite(TestCommutative2Op op,
141 PatternRewriter &rewriter) const override {
142 auto operand =
143 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
144 if (!operand)
145 return failure();
146 Attribute constInput;
147 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput)))
148 return failure();
149 rewriter.replaceOp(op, operand->getOperand(1));
150 return success();
154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
155 /// attribute with value smaller than MaxVal, it increments the value by 1.
156 template <int MaxVal>
157 struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
158 using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
160 LogicalResult matchAndRewrite(AnyAttrOfOp op,
161 PatternRewriter &rewriter) const override {
162 auto intAttr = dyn_cast<IntegerAttr>(op.getAttr());
163 if (!intAttr)
164 return failure();
165 int64_t val = intAttr.getInt();
166 if (val >= MaxVal)
167 return failure();
168 rewriter.modifyOpInPlace(
169 op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
170 return success();
174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
175 struct MakeOpEligible : public RewritePattern {
176 MakeOpEligible(MLIRContext *context)
177 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {}
179 LogicalResult matchAndRewrite(Operation *op,
180 PatternRewriter &rewriter) const override {
181 if (op->hasAttr("eligible"))
182 return failure();
183 rewriter.modifyOpInPlace(
184 op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
185 return success();
189 /// This pattern hoists eligible ops out of a "test.one_region_op".
190 struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
191 using OpRewritePattern<test::OneRegionOp>::OpRewritePattern;
193 LogicalResult matchAndRewrite(test::OneRegionOp op,
194 PatternRewriter &rewriter) const override {
195 Operation *terminator = op.getRegion().front().getTerminator();
196 Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp();
197 if (toBeHoisted->getParentOp() != op)
198 return failure();
199 if (!toBeHoisted->hasAttr("eligible"))
200 return failure();
201 rewriter.moveOpBefore(toBeHoisted, op);
202 return success();
206 /// This pattern moves "test.move_before_parent_op" before the parent op.
207 struct MoveBeforeParentOp : public RewritePattern {
208 MoveBeforeParentOp(MLIRContext *context)
209 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
211 LogicalResult matchAndRewrite(Operation *op,
212 PatternRewriter &rewriter) const override {
213 // Do not hoist past functions.
214 if (isa<FunctionOpInterface>(op->getParentOp()))
215 return failure();
216 rewriter.moveOpBefore(op, op->getParentOp());
217 return success();
221 /// This pattern moves "test.move_after_parent_op" after the parent op.
222 struct MoveAfterParentOp : public RewritePattern {
223 MoveAfterParentOp(MLIRContext *context)
224 : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
226 LogicalResult matchAndRewrite(Operation *op,
227 PatternRewriter &rewriter) const override {
228 // Do not hoist past functions.
229 if (isa<FunctionOpInterface>(op->getParentOp()))
230 return failure();
232 int64_t moveForwardBy = 0;
233 if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
234 moveForwardBy = advanceBy.getInt();
236 Operation *moveAfter = op->getParentOp();
237 for (int64_t i = 0; i < moveForwardBy; ++i)
238 moveAfter = moveAfter->getNextNode();
240 rewriter.moveOpAfter(op, moveAfter);
241 return success();
245 /// This pattern inlines blocks that are nested in
246 /// "test.inline_blocks_into_parent" into the parent block.
247 struct InlineBlocksIntoParent : public RewritePattern {
248 InlineBlocksIntoParent(MLIRContext *context)
249 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
250 context) {}
252 LogicalResult matchAndRewrite(Operation *op,
253 PatternRewriter &rewriter) const override {
254 bool changed = false;
255 for (Region &r : op->getRegions()) {
256 while (!r.empty()) {
257 rewriter.inlineBlockBefore(&r.front(), op);
258 changed = true;
261 return success(changed);
265 /// This pattern splits blocks at "test.split_block_here" and replaces the op
266 /// with a new op (to prevent an infinite loop of block splitting).
267 struct SplitBlockHere : public RewritePattern {
268 SplitBlockHere(MLIRContext *context)
269 : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
271 LogicalResult matchAndRewrite(Operation *op,
272 PatternRewriter &rewriter) const override {
273 rewriter.splitBlock(op->getBlock(), op->getIterator());
274 Operation *newOp = rewriter.create(
275 op->getLoc(),
276 OperationName("test.new_op", op->getContext()).getIdentifier(),
277 op->getOperands(), op->getResultTypes());
278 rewriter.replaceOp(op, newOp);
279 return success();
283 /// This pattern clones "test.clone_me" ops.
284 struct CloneOp : public RewritePattern {
285 CloneOp(MLIRContext *context)
286 : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
288 LogicalResult matchAndRewrite(Operation *op,
289 PatternRewriter &rewriter) const override {
290 // Do not clone already cloned ops to avoid going into an infinite loop.
291 if (op->hasAttr("was_cloned"))
292 return failure();
293 Operation *cloned = rewriter.clone(*op);
294 cloned->setAttr("was_cloned", rewriter.getUnitAttr());
295 return success();
299 /// This pattern clones regions of "test.clone_region_before" ops before the
300 /// parent block.
301 struct CloneRegionBeforeOp : public RewritePattern {
302 CloneRegionBeforeOp(MLIRContext *context)
303 : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
305 LogicalResult matchAndRewrite(Operation *op,
306 PatternRewriter &rewriter) const override {
307 // Do not clone already cloned ops to avoid going into an infinite loop.
308 if (op->hasAttr("was_cloned"))
309 return failure();
310 for (Region &r : op->getRegions())
311 rewriter.cloneRegionBefore(r, op->getBlock());
312 op->setAttr("was_cloned", rewriter.getUnitAttr());
313 return success();
317 /// Replace an operation may introduce the re-visiting of its users.
318 class ReplaceWithNewOp : public RewritePattern {
319 public:
320 ReplaceWithNewOp(MLIRContext *context)
321 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
323 LogicalResult matchAndRewrite(Operation *op,
324 PatternRewriter &rewriter) const override {
325 Operation *newOp;
326 if (op->hasAttr("create_erase_op")) {
327 newOp = rewriter.create(
328 op->getLoc(),
329 OperationName("test.erase_op", op->getContext()).getIdentifier(),
330 ValueRange(), TypeRange());
331 } else {
332 newOp = rewriter.create(
333 op->getLoc(),
334 OperationName("test.new_op", op->getContext()).getIdentifier(),
335 op->getOperands(), op->getResultTypes());
337 // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
338 // A "notifyOperationReplaced" callback is triggered in either case.
339 rewriter.replaceAllOpUsesWith(op, newOp->getResults());
340 rewriter.eraseOp(op);
341 return success();
345 /// Erases the first child block of the matched "test.erase_first_block"
346 /// operation.
347 class EraseFirstBlock : public RewritePattern {
348 public:
349 EraseFirstBlock(MLIRContext *context)
350 : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
352 LogicalResult matchAndRewrite(Operation *op,
353 PatternRewriter &rewriter) const override {
354 for (Region &r : op->getRegions()) {
355 for (Block &b : r.getBlocks()) {
356 rewriter.eraseBlock(&b);
357 return success();
361 return failure();
365 struct TestGreedyPatternDriver
366 : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
367 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
369 TestGreedyPatternDriver() = default;
370 TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
371 : PassWrapper(other) {}
373 StringRef getArgument() const final { return "test-greedy-patterns"; }
374 StringRef getDescription() const final { return "Run test dialect patterns"; }
375 void runOnOperation() override {
376 mlir::RewritePatternSet patterns(&getContext());
377 populateWithGenerated(patterns);
379 // Verify named pattern is generated with expected name.
380 patterns.add<FoldingPattern, TestNamedPatternRule,
381 FolderInsertBeforePreviouslyFoldedConstantPattern,
382 FolderCommutativeOp2WithConstant, HoistEligibleOps,
383 MakeOpEligible>(&getContext());
385 // Additional patterns for testing the GreedyPatternRewriteDriver.
386 patterns.insert<IncrementIntAttribute<3>>(&getContext());
388 GreedyRewriteConfig config;
389 config.useTopDownTraversal = this->useTopDownTraversal;
390 config.maxIterations = this->maxIterations;
391 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
392 config);
395 Option<bool> useTopDownTraversal{
396 *this, "top-down",
397 llvm::cl::desc("Seed the worklist in general top-down order"),
398 llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
399 Option<int> maxIterations{
400 *this, "max-iterations",
401 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
402 llvm::cl::init(GreedyRewriteConfig().maxIterations)};
405 struct DumpNotifications : public RewriterBase::Listener {
406 void notifyBlockInserted(Block *block, Region *previous,
407 Region::iterator previousIt) override {
408 llvm::outs() << "notifyBlockInserted";
409 if (block->getParentOp()) {
410 llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
411 } else {
412 llvm::outs() << " into unknown op: ";
414 if (previous == nullptr) {
415 llvm::outs() << "was unlinked\n";
416 } else {
417 llvm::outs() << "was linked\n";
420 void notifyOperationInserted(Operation *op,
421 OpBuilder::InsertPoint previous) override {
422 llvm::outs() << "notifyOperationInserted: " << op->getName();
423 if (!previous.isSet()) {
424 llvm::outs() << ", was unlinked\n";
425 } else {
426 if (!previous.getPoint().getNodePtr()) {
427 llvm::outs() << ", was linked, exact position unknown\n";
428 } else if (previous.getPoint() == previous.getBlock()->end()) {
429 llvm::outs() << ", was last in block\n";
430 } else {
431 llvm::outs() << ", previous = " << previous.getPoint()->getName()
432 << "\n";
436 void notifyBlockErased(Block *block) override {
437 llvm::outs() << "notifyBlockErased\n";
439 void notifyOperationErased(Operation *op) override {
440 llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
442 void notifyOperationModified(Operation *op) override {
443 llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
445 void notifyOperationReplaced(Operation *op, ValueRange values) override {
446 llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
450 struct TestStrictPatternDriver
451 : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
452 public:
453 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
455 TestStrictPatternDriver() = default;
456 TestStrictPatternDriver(const TestStrictPatternDriver &other)
457 : PassWrapper(other) {
458 strictMode = other.strictMode;
461 StringRef getArgument() const final { return "test-strict-pattern-driver"; }
462 StringRef getDescription() const final {
463 return "Test strict mode of pattern driver";
466 void runOnOperation() override {
467 MLIRContext *ctx = &getContext();
468 mlir::RewritePatternSet patterns(ctx);
469 patterns.add<
470 // clang-format off
471 ChangeBlockOp,
472 CloneOp,
473 CloneRegionBeforeOp,
474 EraseOp,
475 ImplicitChangeOp,
476 InlineBlocksIntoParent,
477 InsertSameOp,
478 MoveBeforeParentOp,
479 ReplaceWithNewOp,
480 SplitBlockHere
481 // clang-format on
482 >(ctx);
483 SmallVector<Operation *> ops;
484 getOperation()->walk([&](Operation *op) {
485 StringRef opName = op->getName().getStringRef();
486 if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
487 opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
488 opName == "test.move_before_parent_op" ||
489 opName == "test.inline_blocks_into_parent" ||
490 opName == "test.split_block_here" || opName == "test.clone_me" ||
491 opName == "test.clone_region_before") {
492 ops.push_back(op);
496 DumpNotifications dumpNotifications;
497 GreedyRewriteConfig config;
498 config.listener = &dumpNotifications;
499 if (strictMode == "AnyOp") {
500 config.strictMode = GreedyRewriteStrictness::AnyOp;
501 } else if (strictMode == "ExistingAndNewOps") {
502 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
503 } else if (strictMode == "ExistingOps") {
504 config.strictMode = GreedyRewriteStrictness::ExistingOps;
505 } else {
506 llvm_unreachable("invalid strictness option");
509 // Check if these transformations introduce visiting of operations that
510 // are not in the `ops` set (The new created ops are valid). An invalid
511 // operation will trigger the assertion while processing.
512 bool changed = false;
513 bool allErased = false;
514 (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config,
515 &changed, &allErased);
516 Builder b(ctx);
517 getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
518 getOperation()->setAttr("pattern_driver_all_erased",
519 b.getBoolAttr(allErased));
522 Option<std::string> strictMode{
523 *this, "strictness",
524 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
525 llvm::cl::init("AnyOp")};
527 private:
528 // New inserted operation is valid for further transformation.
529 class InsertSameOp : public RewritePattern {
530 public:
531 InsertSameOp(MLIRContext *context)
532 : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
534 LogicalResult matchAndRewrite(Operation *op,
535 PatternRewriter &rewriter) const override {
536 if (op->hasAttr("skip"))
537 return failure();
539 Operation *newOp =
540 rewriter.create(op->getLoc(), op->getName().getIdentifier(),
541 op->getOperands(), op->getResultTypes());
542 rewriter.modifyOpInPlace(
543 op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); });
544 newOp->setAttr("skip", rewriter.getBoolAttr(true));
546 return success();
550 // Remove an operation may introduce the re-visiting of its operands.
551 class EraseOp : public RewritePattern {
552 public:
553 EraseOp(MLIRContext *context)
554 : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
555 LogicalResult matchAndRewrite(Operation *op,
556 PatternRewriter &rewriter) const override {
557 rewriter.eraseOp(op);
558 return success();
562 // The following two patterns test RewriterBase::replaceAllUsesWith.
564 // That function replaces all usages of a Block (or a Value) with another one
565 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
566 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
567 // worklist: when an op is modified, it is added to the worklist. The two
568 // patterns below make the tracking observable: ChangeBlockOp replaces all
569 // usages of a block and that pattern is applied because the corresponding ops
570 // are put on the initial worklist (see above). ImplicitChangeOp does an
571 // unrelated change but ops of the corresponding type are *not* on the initial
572 // worklist, so the effect of the second pattern is only visible if the
573 // tracking and subsequent adding to the worklist actually works.
575 // Replace all usages of the first successor with the second successor.
576 class ChangeBlockOp : public RewritePattern {
577 public:
578 ChangeBlockOp(MLIRContext *context)
579 : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
580 LogicalResult matchAndRewrite(Operation *op,
581 PatternRewriter &rewriter) const override {
582 if (op->getNumSuccessors() < 2)
583 return failure();
584 Block *firstSuccessor = op->getSuccessor(0);
585 Block *secondSuccessor = op->getSuccessor(1);
586 if (firstSuccessor == secondSuccessor)
587 return failure();
588 // This is the function being tested:
589 rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor);
590 // Using the following line instead would make the test fail:
591 // firstSuccessor->replaceAllUsesWith(secondSuccessor);
592 return success();
596 // Changes the successor to the parent block.
597 class ImplicitChangeOp : public RewritePattern {
598 public:
599 ImplicitChangeOp(MLIRContext *context)
600 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
601 LogicalResult matchAndRewrite(Operation *op,
602 PatternRewriter &rewriter) const override {
603 if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
604 return failure();
605 rewriter.modifyOpInPlace(op,
606 [&]() { op->setSuccessor(op->getBlock(), 0); });
607 return success();
612 struct TestWalkPatternDriver final
613 : PassWrapper<TestWalkPatternDriver, OperationPass<>> {
614 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
616 TestWalkPatternDriver() = default;
617 TestWalkPatternDriver(const TestWalkPatternDriver &other)
618 : PassWrapper(other) {}
620 StringRef getArgument() const override {
621 return "test-walk-pattern-rewrite-driver";
623 StringRef getDescription() const override {
624 return "Run test walk pattern rewrite driver";
626 void runOnOperation() override {
627 mlir::RewritePatternSet patterns(&getContext());
629 // Patterns for testing the WalkPatternRewriteDriver.
630 patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
631 MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
632 &getContext());
634 DumpNotifications dumpListener;
635 walkAndApplyPatterns(getOperation(), std::move(patterns),
636 dumpNotifications ? &dumpListener : nullptr);
639 Option<bool> dumpNotifications{
640 *this, "dump-notifications",
641 llvm::cl::desc("Print rewrite listener notifications"),
642 llvm::cl::init(false)};
645 } // namespace
647 //===----------------------------------------------------------------------===//
648 // ReturnType Driver.
649 //===----------------------------------------------------------------------===//
651 namespace {
652 // Generate ops for each instance where the type can be successfully inferred.
653 template <typename OpTy>
654 static void invokeCreateWithInferredReturnType(Operation *op) {
655 auto *context = op->getContext();
656 auto fop = op->getParentOfType<func::FuncOp>();
657 auto location = UnknownLoc::get(context);
658 OpBuilder b(op);
659 b.setInsertionPointAfter(op);
661 // Use permutations of 2 args as operands.
662 assert(fop.getNumArguments() >= 2);
663 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
664 for (int j = 0; j < e; ++j) {
665 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
666 SmallVector<Type, 2> inferredReturnTypes;
667 if (succeeded(OpTy::inferReturnTypes(
668 context, std::nullopt, values, op->getDiscardableAttrDictionary(),
669 op->getPropertiesStorage(), op->getRegions(),
670 inferredReturnTypes))) {
671 OperationState state(location, OpTy::getOperationName());
672 // TODO: Expand to regions.
673 OpTy::build(b, state, values, op->getAttrs());
674 (void)b.create(state);
680 static void reifyReturnShape(Operation *op) {
681 OpBuilder b(op);
683 // Use permutations of 2 args as operands.
684 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
685 SmallVector<Value, 2> shapes;
686 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
687 !llvm::hasSingleElement(shapes))
688 return;
689 for (const auto &it : llvm::enumerate(shapes)) {
690 op->emitRemark() << "value " << it.index() << ": "
691 << it.value().getDefiningOp();
695 struct TestReturnTypeDriver
696 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
697 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
699 void getDependentDialects(DialectRegistry &registry) const override {
700 registry.insert<tensor::TensorDialect>();
702 StringRef getArgument() const final { return "test-return-type"; }
703 StringRef getDescription() const final { return "Run return type functions"; }
705 void runOnOperation() override {
706 if (getOperation().getName() == "testCreateFunctions") {
707 std::vector<Operation *> ops;
708 // Collect ops to avoid triggering on inserted ops.
709 for (auto &op : getOperation().getBody().front())
710 ops.push_back(&op);
711 // Generate test patterns for each, but skip terminator.
712 for (auto *op : llvm::ArrayRef(ops).drop_back()) {
713 // Test create method of each of the Op classes below. The resultant
714 // output would be in reverse order underneath `op` from which
715 // the attributes and regions are used.
716 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
717 invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
718 op);
719 invokeCreateWithInferredReturnType<
720 OpWithShapedTypeInferTypeInterfaceOp>(op);
722 return;
724 if (getOperation().getName() == "testReifyFunctions") {
725 std::vector<Operation *> ops;
726 // Collect ops to avoid triggering on inserted ops.
727 for (auto &op : getOperation().getBody().front())
728 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
729 ops.push_back(&op);
730 // Generate test patterns for each, but skip terminator.
731 for (auto *op : ops)
732 reifyReturnShape(op);
736 } // namespace
738 namespace {
739 struct TestDerivedAttributeDriver
740 : public PassWrapper<TestDerivedAttributeDriver,
741 OperationPass<func::FuncOp>> {
742 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
744 StringRef getArgument() const final { return "test-derived-attr"; }
745 StringRef getDescription() const final {
746 return "Run test derived attributes";
748 void runOnOperation() override;
750 } // namespace
752 void TestDerivedAttributeDriver::runOnOperation() {
753 getOperation().walk([](DerivedAttributeOpInterface dOp) {
754 auto dAttr = dOp.materializeDerivedAttributes();
755 if (!dAttr)
756 return;
757 for (auto d : dAttr)
758 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
762 //===----------------------------------------------------------------------===//
763 // Legalization Driver.
764 //===----------------------------------------------------------------------===//
766 namespace {
767 //===----------------------------------------------------------------------===//
768 // Region-Block Rewrite Testing
770 /// This pattern applies a signature conversion to a block inside a detached
771 /// region.
772 struct TestDetachedSignatureConversion : public ConversionPattern {
773 TestDetachedSignatureConversion(MLIRContext *ctx)
774 : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1,
775 ctx) {}
777 LogicalResult
778 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
779 ConversionPatternRewriter &rewriter) const final {
780 if (op->getNumRegions() != 1)
781 return failure();
782 OperationState state(op->getLoc(), "test.legal_op_with_region", operands,
783 op->getResultTypes(), {}, BlockRange());
784 Region *newRegion = state.addRegion();
785 rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
786 newRegion->begin());
787 TypeConverter::SignatureConversion result(newRegion->getNumArguments());
788 for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i)
789 result.addInputs(i, rewriter.getF64Type());
790 rewriter.applySignatureConversion(&newRegion->front(), result);
791 Operation *newOp = rewriter.create(state);
792 rewriter.replaceOp(op, newOp->getResults());
793 return success();
797 /// This pattern is a simple pattern that inlines the first region of a given
798 /// operation into the parent region.
799 struct TestRegionRewriteBlockMovement : public ConversionPattern {
800 TestRegionRewriteBlockMovement(MLIRContext *ctx)
801 : ConversionPattern("test.region", 1, ctx) {}
803 LogicalResult
804 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
805 ConversionPatternRewriter &rewriter) const final {
806 // Inline this region into the parent region.
807 auto &parentRegion = *op->getParentRegion();
808 auto &opRegion = op->getRegion(0);
809 if (op->getDiscardableAttr("legalizer.should_clone"))
810 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end());
811 else
812 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end());
814 if (op->getDiscardableAttr("legalizer.erase_old_blocks")) {
815 while (!opRegion.empty())
816 rewriter.eraseBlock(&opRegion.front());
819 // Drop this operation.
820 rewriter.eraseOp(op);
821 return success();
824 /// This pattern is a simple pattern that generates a region containing an
825 /// illegal operation.
826 struct TestRegionRewriteUndo : public RewritePattern {
827 TestRegionRewriteUndo(MLIRContext *ctx)
828 : RewritePattern("test.region_builder", 1, ctx) {}
830 LogicalResult matchAndRewrite(Operation *op,
831 PatternRewriter &rewriter) const final {
832 // Create the region operation with an entry block containing arguments.
833 OperationState newRegion(op->getLoc(), "test.region");
834 newRegion.addRegion();
835 auto *regionOp = rewriter.create(newRegion);
836 auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
837 entryBlock->addArgument(rewriter.getIntegerType(64),
838 rewriter.getUnknownLoc());
840 // Add an explicitly illegal operation to ensure the conversion fails.
841 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
842 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
844 // Drop this operation.
845 rewriter.eraseOp(op);
846 return success();
849 /// A simple pattern that creates a block at the end of the parent region of the
850 /// matched operation.
851 struct TestCreateBlock : public RewritePattern {
852 TestCreateBlock(MLIRContext *ctx)
853 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
855 LogicalResult matchAndRewrite(Operation *op,
856 PatternRewriter &rewriter) const final {
857 Region &region = *op->getParentRegion();
858 Type i32Type = rewriter.getIntegerType(32);
859 Location loc = op->getLoc();
860 rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
861 rewriter.create<TerminatorOp>(loc);
862 rewriter.eraseOp(op);
863 return success();
867 /// A simple pattern that creates a block containing an invalid operation in
868 /// order to trigger the block creation undo mechanism.
869 struct TestCreateIllegalBlock : public RewritePattern {
870 TestCreateIllegalBlock(MLIRContext *ctx)
871 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
873 LogicalResult matchAndRewrite(Operation *op,
874 PatternRewriter &rewriter) const final {
875 Region &region = *op->getParentRegion();
876 Type i32Type = rewriter.getIntegerType(32);
877 Location loc = op->getLoc();
878 rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
879 // Create an illegal op to ensure the conversion fails.
880 rewriter.create<ILLegalOpF>(loc, i32Type);
881 rewriter.create<TerminatorOp>(loc);
882 rewriter.eraseOp(op);
883 return success();
887 /// A simple pattern that tests the undo mechanism when replacing the uses of a
888 /// block argument.
889 struct TestUndoBlockArgReplace : public ConversionPattern {
890 TestUndoBlockArgReplace(MLIRContext *ctx)
891 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
893 LogicalResult
894 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
895 ConversionPatternRewriter &rewriter) const final {
896 auto illegalOp =
897 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
898 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
899 illegalOp->getResult(0));
900 rewriter.modifyOpInPlace(op, [] {});
901 return success();
905 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
906 /// This is to test the rollback logic.
907 struct TestUndoMoveOpBefore : public ConversionPattern {
908 TestUndoMoveOpBefore(MLIRContext *ctx)
909 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
911 LogicalResult
912 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
913 ConversionPatternRewriter &rewriter) const override {
914 rewriter.moveOpBefore(op, op->getParentOp());
915 // Replace with an illegal op to ensure the conversion fails.
916 rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
917 return success();
921 /// A rewrite pattern that tests the undo mechanism when erasing a block.
922 struct TestUndoBlockErase : public ConversionPattern {
923 TestUndoBlockErase(MLIRContext *ctx)
924 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
926 LogicalResult
927 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
928 ConversionPatternRewriter &rewriter) const final {
929 Block *secondBlock = &*std::next(op->getRegion(0).begin());
930 rewriter.setInsertionPointToStart(secondBlock);
931 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
932 rewriter.eraseBlock(secondBlock);
933 rewriter.modifyOpInPlace(op, [] {});
934 return success();
938 /// A pattern that modifies a property in-place, but keeps the op illegal.
939 struct TestUndoPropertiesModification : public ConversionPattern {
940 TestUndoPropertiesModification(MLIRContext *ctx)
941 : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
942 LogicalResult
943 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
944 ConversionPatternRewriter &rewriter) const final {
945 if (!op->hasAttr("modify_inplace"))
946 return failure();
947 rewriter.modifyOpInPlace(
948 op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); });
949 return success();
953 //===----------------------------------------------------------------------===//
954 // Type-Conversion Rewrite Testing
956 /// This patterns erases a region operation that has had a type conversion.
957 struct TestDropOpSignatureConversion : public ConversionPattern {
958 TestDropOpSignatureConversion(MLIRContext *ctx,
959 const TypeConverter &converter)
960 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
961 LogicalResult
962 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
963 ConversionPatternRewriter &rewriter) const override {
964 Region &region = op->getRegion(0);
965 Block *entry = &region.front();
967 // Convert the original entry arguments.
968 const TypeConverter &converter = *getTypeConverter();
969 TypeConverter::SignatureConversion result(entry->getNumArguments());
970 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
971 result)) ||
972 failed(rewriter.convertRegionTypes(&region, converter, &result)))
973 return failure();
975 // Convert the region signature and just drop the operation.
976 rewriter.eraseOp(op);
977 return success();
980 /// This pattern simply updates the operands of the given operation.
981 struct TestPassthroughInvalidOp : public ConversionPattern {
982 TestPassthroughInvalidOp(MLIRContext *ctx)
983 : ConversionPattern("test.invalid", 1, ctx) {}
984 LogicalResult
985 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
986 ConversionPatternRewriter &rewriter) const final {
987 SmallVector<Value> flattened;
988 for (auto it : llvm::enumerate(operands)) {
989 ValueRange range = it.value();
990 if (range.size() == 1) {
991 flattened.push_back(range.front());
992 continue;
995 // This is a 1:N replacement. Insert a test.cast op. (That's what the
996 // argument materialization used to do.)
997 flattened.push_back(
998 rewriter
999 .create<TestCastOp>(op->getLoc(),
1000 op->getOperand(it.index()).getType(), range)
1001 .getResult());
1003 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened,
1004 std::nullopt);
1005 return success();
1008 /// Replace with valid op, but simply drop the operands. This is used in a
1009 /// regression where we used to generate circular unrealized_conversion_cast
1010 /// ops.
1011 struct TestDropAndReplaceInvalidOp : public ConversionPattern {
1012 TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
1013 : ConversionPattern(converter,
1014 "test.drop_operands_and_replace_with_valid", 1, ctx) {
1016 LogicalResult
1017 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1018 ConversionPatternRewriter &rewriter) const final {
1019 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(),
1020 std::nullopt);
1021 return success();
1024 /// This pattern handles the case of a split return value.
1025 struct TestSplitReturnType : public ConversionPattern {
1026 TestSplitReturnType(MLIRContext *ctx)
1027 : ConversionPattern("test.return", 1, ctx) {}
1028 LogicalResult
1029 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1030 ConversionPatternRewriter &rewriter) const final {
1031 // Check for a return of F32.
1032 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
1033 return failure();
1034 rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]);
1035 return success();
1039 //===----------------------------------------------------------------------===//
1040 // Multi-Level Type-Conversion Rewrite Testing
1041 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
1042 TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
1043 : ConversionPattern("test.type_producer", 1, ctx) {}
1044 LogicalResult
1045 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1046 ConversionPatternRewriter &rewriter) const final {
1047 // If the type is I32, change the type to F32.
1048 if (!Type(*op->result_type_begin()).isSignlessInteger(32))
1049 return failure();
1050 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
1051 return success();
1054 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
1055 TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
1056 : ConversionPattern("test.type_producer", 1, ctx) {}
1057 LogicalResult
1058 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1059 ConversionPatternRewriter &rewriter) const final {
1060 // If the type is F32, change the type to F64.
1061 if (!Type(*op->result_type_begin()).isF32())
1062 return rewriter.notifyMatchFailure(op, "expected single f32 operand");
1063 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
1064 return success();
1067 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
1068 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
1069 : ConversionPattern("test.type_producer", 10, ctx) {}
1070 LogicalResult
1071 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1072 ConversionPatternRewriter &rewriter) const final {
1073 // Always convert to B16, even though it is not a legal type. This tests
1074 // that values are unmapped correctly.
1075 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
1076 return success();
1079 struct TestUpdateConsumerType : public ConversionPattern {
1080 TestUpdateConsumerType(MLIRContext *ctx)
1081 : ConversionPattern("test.type_consumer", 1, ctx) {}
1082 LogicalResult
1083 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1084 ConversionPatternRewriter &rewriter) const final {
1085 // Verify that the incoming operand has been successfully remapped to F64.
1086 if (!operands[0].getType().isF64())
1087 return failure();
1088 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
1089 return success();
1093 //===----------------------------------------------------------------------===//
1094 // Non-Root Replacement Rewrite Testing
1095 /// This pattern generates an invalid operation, but replaces it before the
1096 /// pattern is finished. This checks that we don't need to legalize the
1097 /// temporary op.
1098 struct TestNonRootReplacement : public RewritePattern {
1099 TestNonRootReplacement(MLIRContext *ctx)
1100 : RewritePattern("test.replace_non_root", 1, ctx) {}
1102 LogicalResult matchAndRewrite(Operation *op,
1103 PatternRewriter &rewriter) const final {
1104 auto resultType = *op->result_type_begin();
1105 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
1106 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
1108 rewriter.replaceOp(illegalOp, legalOp);
1109 rewriter.replaceOp(op, illegalOp);
1110 return success();
1114 //===----------------------------------------------------------------------===//
1115 // Recursive Rewrite Testing
1116 /// This pattern is applied to the same operation multiple times, but has a
1117 /// bounded recursion.
1118 struct TestBoundedRecursiveRewrite
1119 : public OpRewritePattern<TestRecursiveRewriteOp> {
1120 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
1122 void initialize() {
1123 // The conversion target handles bounding the recursion of this pattern.
1124 setHasBoundedRewriteRecursion();
1127 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
1128 PatternRewriter &rewriter) const final {
1129 // Decrement the depth of the op in-place.
1130 rewriter.modifyOpInPlace(op, [&] {
1131 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
1133 return success();
1137 struct TestNestedOpCreationUndoRewrite
1138 : public OpRewritePattern<IllegalOpWithRegionAnchor> {
1139 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
1141 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
1142 PatternRewriter &rewriter) const final {
1143 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1144 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1145 return success();
1149 // This pattern matches `test.blackhole` and delete this op and its producer.
1150 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
1151 using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
1153 LogicalResult matchAndRewrite(BlackHoleOp op,
1154 PatternRewriter &rewriter) const final {
1155 Operation *producer = op.getOperand().getDefiningOp();
1156 // Always erase the user before the producer, the framework should handle
1157 // this correctly.
1158 rewriter.eraseOp(op);
1159 rewriter.eraseOp(producer);
1160 return success();
1164 // This pattern replaces explicitly illegal op with explicitly legal op,
1165 // but in addition creates unregistered operation.
1166 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
1167 using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
1169 LogicalResult matchAndRewrite(ILLegalOpG op,
1170 PatternRewriter &rewriter) const final {
1171 IntegerAttr attr = rewriter.getI32IntegerAttr(0);
1172 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
1173 rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
1174 return success();
1178 class TestEraseOp : public ConversionPattern {
1179 public:
1180 TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {}
1181 LogicalResult
1182 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1183 ConversionPatternRewriter &rewriter) const final {
1184 // Erase op without replacements.
1185 rewriter.eraseOp(op);
1186 return success();
1190 /// This pattern matches a test.duplicate_block_args op and duplicates all
1191 /// block arguments.
1192 class TestDuplicateBlockArgs
1193 : public OpConversionPattern<DuplicateBlockArgsOp> {
1194 using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern;
1196 LogicalResult
1197 matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor,
1198 ConversionPatternRewriter &rewriter) const override {
1199 if (op.getIsLegal())
1200 return failure();
1201 rewriter.startOpModification(op);
1202 Block *body = &op.getBody().front();
1203 TypeConverter::SignatureConversion result(body->getNumArguments());
1204 for (auto it : llvm::enumerate(body->getArgumentTypes()))
1205 result.addInputs(it.index(), {it.value(), it.value()});
1206 rewriter.applySignatureConversion(body, result, getTypeConverter());
1207 op.setIsLegal(true);
1208 rewriter.finalizeOpModification(op);
1209 return success();
1213 /// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid
1214 /// op. The pattern supports 1:N replacements and forwards the replacement
1215 /// values of the single operand as test.valid operands.
1216 class TestRepetitive1ToNConsumer : public ConversionPattern {
1217 public:
1218 TestRepetitive1ToNConsumer(MLIRContext *ctx)
1219 : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {}
1220 LogicalResult
1221 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1222 ConversionPatternRewriter &rewriter) const final {
1223 // A single operand is expected.
1224 if (op->getNumOperands() != 1)
1225 return failure();
1226 rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front());
1227 return success();
1231 } // namespace
1233 namespace {
1234 struct TestTypeConverter : public TypeConverter {
1235 using TypeConverter::TypeConverter;
1236 TestTypeConverter() {
1237 addConversion(convertType);
1238 addArgumentMaterialization(materializeCast);
1239 addSourceMaterialization(materializeCast);
1242 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
1243 // Drop I16 types.
1244 if (t.isSignlessInteger(16))
1245 return success();
1247 // Convert I64 to F64.
1248 if (t.isSignlessInteger(64)) {
1249 results.push_back(FloatType::getF64(t.getContext()));
1250 return success();
1253 // Convert I42 to I43.
1254 if (t.isInteger(42)) {
1255 results.push_back(IntegerType::get(t.getContext(), 43));
1256 return success();
1259 // Split F32 into F16,F16.
1260 if (t.isF32()) {
1261 results.assign(2, FloatType::getF16(t.getContext()));
1262 return success();
1265 // Drop I24 types.
1266 if (t.isInteger(24)) {
1267 return success();
1270 // Otherwise, convert the type directly.
1271 results.push_back(t);
1272 return success();
1275 /// Hook for materializing a conversion. This is necessary because we generate
1276 /// 1->N type mappings.
1277 static Value materializeCast(OpBuilder &builder, Type resultType,
1278 ValueRange inputs, Location loc) {
1279 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1283 struct TestLegalizePatternDriver
1284 : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> {
1285 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)
1287 StringRef getArgument() const final { return "test-legalize-patterns"; }
1288 StringRef getDescription() const final {
1289 return "Run test dialect legalization patterns";
1291 /// The mode of conversion to use with the driver.
1292 enum class ConversionMode { Analysis, Full, Partial };
1294 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
1296 void getDependentDialects(DialectRegistry &registry) const override {
1297 registry.insert<func::FuncDialect, test::TestDialect>();
1300 void runOnOperation() override {
1301 TestTypeConverter converter;
1302 mlir::RewritePatternSet patterns(&getContext());
1303 populateWithGenerated(patterns);
1304 patterns.add<
1305 TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1306 TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1307 TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
1308 TestSplitReturnType, TestChangeProducerTypeI32ToF32,
1309 TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
1310 TestUpdateConsumerType, TestNonRootReplacement,
1311 TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
1312 TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1313 TestUndoPropertiesModification, TestEraseOp,
1314 TestRepetitive1ToNConsumer>(&getContext());
1315 patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
1316 &getContext(), converter);
1317 patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
1318 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1319 converter);
1320 mlir::populateCallOpTypeConversionPattern(patterns, converter);
1322 // Define the conversion target used for the test.
1323 ConversionTarget target(getContext());
1324 target.addLegalOp<ModuleOp>();
1325 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
1326 TerminatorOp, OneRegionOp>();
1327 target.addLegalOp(
1328 OperationName("test.legal_op_with_region", &getContext()));
1329 target
1330 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
1331 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
1332 // Don't allow F32 operands.
1333 return llvm::none_of(op.getOperandTypes(),
1334 [](Type type) { return type.isF32(); });
1336 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1337 return converter.isSignatureLegal(op.getFunctionType()) &&
1338 converter.isLegal(&op.getBody());
1340 target.addDynamicallyLegalOp<func::CallOp>(
1341 [&](func::CallOp op) { return converter.isLegal(op); });
1343 // TestCreateUnregisteredOp creates `arith.constant` operation,
1344 // which was not added to target intentionally to test
1345 // correct error code from conversion driver.
1346 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });
1348 // Expect the type_producer/type_consumer operations to only operate on f64.
1349 target.addDynamicallyLegalOp<TestTypeProducerOp>(
1350 [](TestTypeProducerOp op) { return op.getType().isF64(); });
1351 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1352 return op.getOperand().getType().isF64();
1355 // Check support for marking certain operations as recursively legal.
1356 target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
1357 return static_cast<bool>(
1358 op->getAttrOfType<UnitAttr>("test.recursively_legal"));
1361 // Mark the bound recursion operation as dynamically legal.
1362 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
1363 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
1365 // Create a dynamically legal rule that can only be legalized by folding it.
1366 target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
1367 [](TestOpInPlaceSelfFold op) { return op.getFolded(); });
1369 target.addDynamicallyLegalOp<DuplicateBlockArgsOp>(
1370 [](DuplicateBlockArgsOp op) { return op.getIsLegal(); });
1372 // Handle a partial conversion.
1373 if (mode == ConversionMode::Partial) {
1374 DenseSet<Operation *> unlegalizedOps;
1375 ConversionConfig config;
1376 DumpNotifications dumpNotifications;
1377 config.listener = &dumpNotifications;
1378 config.unlegalizedOps = &unlegalizedOps;
1379 if (failed(applyPartialConversion(getOperation(), target,
1380 std::move(patterns), config))) {
1381 getOperation()->emitRemark() << "applyPartialConversion failed";
1383 // Emit remarks for each legalizable operation.
1384 for (auto *op : unlegalizedOps)
1385 op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1386 return;
1389 // Handle a full conversion.
1390 if (mode == ConversionMode::Full) {
1391 // Check support for marking unknown operations as dynamically legal.
1392 target.markUnknownOpDynamicallyLegal([](Operation *op) {
1393 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
1396 ConversionConfig config;
1397 DumpNotifications dumpNotifications;
1398 config.listener = &dumpNotifications;
1399 if (failed(applyFullConversion(getOperation(), target,
1400 std::move(patterns), config))) {
1401 getOperation()->emitRemark() << "applyFullConversion failed";
1403 return;
1406 // Otherwise, handle an analysis conversion.
1407 assert(mode == ConversionMode::Analysis);
1409 // Analyze the convertible operations.
1410 DenseSet<Operation *> legalizedOps;
1411 ConversionConfig config;
1412 config.legalizableOps = &legalizedOps;
1413 if (failed(applyAnalysisConversion(getOperation(), target,
1414 std::move(patterns), config)))
1415 return signalPassFailure();
1417 // Emit remarks for each legalizable operation.
1418 for (auto *op : legalizedOps)
1419 op->emitRemark() << "op '" << op->getName() << "' is legalizable";
1422 /// The mode of conversion to use.
1423 ConversionMode mode;
1425 } // namespace
// Command-line option "test-legalize-mode": selects which conversion mode
// ("analysis", "full" or "partial") the legalization test driver uses.
// Initialized to ConversionMode::Partial.
1427 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
1428 legalizerConversionMode(
1429 "test-legalize-mode",
1430 llvm::cl::desc("The legalization mode to use with the test driver"),
1431 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
1432 llvm::cl::values(
1433 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
1434 "analysis", "Perform an analysis conversion"),
1435 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
1436 "Perform a full conversion"),
1437 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
1438 "partial", "Perform a partial conversion")));
1440 //===----------------------------------------------------------------------===//
1441 // ConversionPatternRewriter::getRemappedValue testing. This method is used
1442 // to get the remapped value of an original value that was replaced using
1443 // ConversionPatternRewriter.
1444 namespace {
1445 struct TestRemapValueTypeConverter : public TypeConverter {
1446 using TypeConverter::TypeConverter;
1448 TestRemapValueTypeConverter() {
1449 addConversion(
1450 [](Float32Type type) { return Float64Type::get(type.getContext()); });
1451 addConversion([](Type type) { return type; });
1455 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
1456 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
1457 /// operand twice.
1459 /// Example:
1460 /// %1 = test.one_variadic_out_one_variadic_in1"(%0)
1461 /// is replaced with:
1462 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
1463 struct OneVResOneVOperandOp1Converter
1464 : public OpConversionPattern<OneVResOneVOperandOp1> {
1465 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
1467 LogicalResult
1468 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
1469 ConversionPatternRewriter &rewriter) const override {
1470 auto origOps = op.getOperands();
1471 assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
1472 "One operand expected");
1473 Value origOp = *origOps.begin();
1474 SmallVector<Value, 2> remappedOperands;
1475 // Replicate the remapped original operand twice. Note that we don't used
1476 // the remapped 'operand' since the goal is testing 'getRemappedValue'.
1477 remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1478 remappedOperands.push_back(rewriter.getRemappedValue(origOp));
1480 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
1481 remappedOperands);
1482 return success();
1486 /// A rewriter pattern that tests that blocks can be merged.
1487 struct TestRemapValueInRegion
1488 : public OpConversionPattern<TestRemappedValueRegionOp> {
1489 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;
1491 LogicalResult
1492 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
1493 ConversionPatternRewriter &rewriter) const final {
1494 Block &block = op.getBody().front();
1495 Operation *terminator = block.getTerminator();
1497 // Merge the block into the parent region.
1498 Block *parentBlock = op->getBlock();
1499 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
1500 rewriter.mergeBlocks(&block, parentBlock, ValueRange());
1501 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());
1503 // Replace the results of this operation with the remapped terminator
1504 // values.
1505 SmallVector<Value> terminatorOperands;
1506 if (failed(rewriter.getRemappedValues(terminator->getOperands(),
1507 terminatorOperands)))
1508 return failure();
1510 rewriter.eraseOp(terminator);
1511 rewriter.replaceOp(op, terminatorOperands);
1512 return success();
1516 struct TestRemappedValue
1517 : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> {
1518 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
1520 StringRef getArgument() const final { return "test-remapped-value"; }
1521 StringRef getDescription() const final {
1522 return "Test public remapped value mechanism in ConversionPatternRewriter";
1524 void runOnOperation() override {
1525 TestRemapValueTypeConverter typeConverter;
1527 mlir::RewritePatternSet patterns(&getContext());
1528 patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
1529 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
1530 &getContext());
1531 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
1533 mlir::ConversionTarget target(getContext());
1534 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
1536 // Expect the type_producer/type_consumer operations to only operate on f64.
1537 target.addDynamicallyLegalOp<TestTypeProducerOp>(
1538 [](TestTypeProducerOp op) { return op.getType().isF64(); });
1539 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1540 return op.getOperand().getType().isF64();
1543 // We make OneVResOneVOperandOp1 legal only when it has more that one
1544 // operand. This will trigger the conversion that will replace one-operand
1545 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1546 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
1547 [](Operation *op) { return op->getNumOperands() > 1; });
1549 if (failed(mlir::applyFullConversion(getOperation(), target,
1550 std::move(patterns)))) {
1551 signalPassFailure();
1555 } // namespace
1557 //===----------------------------------------------------------------------===//
1558 // Test patterns without a specific root operation kind
1559 //===----------------------------------------------------------------------===//
1561 namespace {
1562 /// This pattern matches and removes any operation in the test dialect.
1563 struct RemoveTestDialectOps : public RewritePattern {
1564 RemoveTestDialectOps(MLIRContext *context)
1565 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
1567 LogicalResult matchAndRewrite(Operation *op,
1568 PatternRewriter &rewriter) const override {
1569 if (!isa<TestDialect>(op->getDialect()))
1570 return failure();
1571 rewriter.eraseOp(op);
1572 return success();
1576 struct TestUnknownRootOpDriver
1577 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> {
1578 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
1580 StringRef getArgument() const final {
1581 return "test-legalize-unknown-root-patterns";
1583 StringRef getDescription() const final {
1584 return "Test public remapped value mechanism in ConversionPatternRewriter";
1586 void runOnOperation() override {
1587 mlir::RewritePatternSet patterns(&getContext());
1588 patterns.add<RemoveTestDialectOps>(&getContext());
1590 mlir::ConversionTarget target(getContext());
1591 target.addIllegalDialect<TestDialect>();
1592 if (failed(applyPartialConversion(getOperation(), target,
1593 std::move(patterns))))
1594 signalPassFailure();
1597 } // namespace
1599 //===----------------------------------------------------------------------===//
1600 // Test patterns that uses operations and types defined at runtime
1601 //===----------------------------------------------------------------------===//
1603 namespace {
1604 /// This pattern matches dynamic operations 'test.one_operand_two_results' and
1605 /// replace them with dynamic operations 'test.generic_dynamic_op'.
1606 struct RewriteDynamicOp : public RewritePattern {
1607 RewriteDynamicOp(MLIRContext *context)
1608 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
1609 context) {}
1611 LogicalResult matchAndRewrite(Operation *op,
1612 PatternRewriter &rewriter) const override {
1613 assert(op->getName().getStringRef() ==
1614 "test.dynamic_one_operand_two_results" &&
1615 "rewrite pattern should only match operations with the right name");
1617 OperationState state(op->getLoc(), "test.dynamic_generic",
1618 op->getOperands(), op->getResultTypes(),
1619 op->getAttrs());
1620 auto *newOp = rewriter.create(state);
1621 rewriter.replaceOp(op, newOp->getResults());
1622 return success();
1626 struct TestRewriteDynamicOpDriver
1627 : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> {
1628 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
1630 void getDependentDialects(DialectRegistry &registry) const override {
1631 registry.insert<TestDialect>();
1633 StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
1634 StringRef getDescription() const final {
1635 return "Test rewritting on dynamic operations";
1637 void runOnOperation() override {
1638 RewritePatternSet patterns(&getContext());
1639 patterns.add<RewriteDynamicOp>(&getContext());
1641 ConversionTarget target(getContext());
1642 target.addIllegalOp(
1643 OperationName("test.dynamic_one_operand_two_results", &getContext()));
1644 target.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
1645 if (failed(applyPartialConversion(getOperation(), target,
1646 std::move(patterns))))
1647 signalPassFailure();
1650 } // end anonymous namespace
1652 //===----------------------------------------------------------------------===//
1653 // Test type conversions
1654 //===----------------------------------------------------------------------===//
1656 namespace {
1657 struct TestTypeConversionProducer
1658 : public OpConversionPattern<TestTypeProducerOp> {
1659 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
1660 LogicalResult
1661 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
1662 ConversionPatternRewriter &rewriter) const final {
1663 Type resultType = op.getType();
1664 Type convertedType = getTypeConverter()
1665 ? getTypeConverter()->convertType(resultType)
1666 : resultType;
1667 if (isa<FloatType>(resultType))
1668 resultType = rewriter.getF64Type();
1669 else if (resultType.isInteger(16))
1670 resultType = rewriter.getIntegerType(64);
1671 else if (isa<test::TestRecursiveType>(resultType) &&
1672 convertedType != resultType)
1673 resultType = convertedType;
1674 else
1675 return failure();
1677 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
1678 return success();
1682 /// Call signature conversion and then fail the rewrite to trigger the undo
1683 /// mechanism.
1684 struct TestSignatureConversionUndo
1685 : public OpConversionPattern<TestSignatureConversionUndoOp> {
1686 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
1688 LogicalResult
1689 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
1690 ConversionPatternRewriter &rewriter) const final {
1691 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter());
1692 return failure();
1696 /// Call signature conversion without providing a type converter to handle
1697 /// materializations.
1698 struct TestTestSignatureConversionNoConverter
1699 : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
1700 TestTestSignatureConversionNoConverter(const TypeConverter &converter,
1701 MLIRContext *context)
1702 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
1703 converter(converter) {}
1705 LogicalResult
1706 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
1707 ConversionPatternRewriter &rewriter) const final {
1708 Region &region = op->getRegion(0);
1709 Block *entry = &region.front();
1711 // Convert the original entry arguments.
1712 TypeConverter::SignatureConversion result(entry->getNumArguments());
1713 if (failed(
1714 converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
1715 return failure();
1716 rewriter.modifyOpInPlace(op, [&] {
1717 rewriter.applySignatureConversion(&region.front(), result);
1719 return success();
1722 const TypeConverter &converter;
1725 /// Just forward the operands to the root op. This is essentially a no-op
1726 /// pattern that is used to trigger target materialization.
1727 struct TestTypeConsumerForward
1728 : public OpConversionPattern<TestTypeConsumerOp> {
1729 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
1731 LogicalResult
1732 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
1733 ConversionPatternRewriter &rewriter) const final {
1734 rewriter.modifyOpInPlace(op,
1735 [&] { op->setOperands(adaptor.getOperands()); });
1736 return success();
1740 struct TestTypeConversionAnotherProducer
1741 : public OpRewritePattern<TestAnotherTypeProducerOp> {
1742 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1744 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1745 PatternRewriter &rewriter) const final {
1746 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1747 return success();
1751 struct TestReplaceWithLegalOp : public ConversionPattern {
1752 TestReplaceWithLegalOp(MLIRContext *ctx)
1753 : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
1754 LogicalResult
1755 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1756 ConversionPatternRewriter &rewriter) const final {
1757 rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
1758 return success();
1762 struct TestTypeConversionDriver
1763 : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
1764 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
1766 void getDependentDialects(DialectRegistry &registry) const override {
1767 registry.insert<TestDialect>();
1769 StringRef getArgument() const final {
1770 return "test-legalize-type-conversion";
1772 StringRef getDescription() const final {
1773 return "Test various type conversion functionalities in DialectConversion";
1776 void runOnOperation() override {
1777 // Initialize the type converter.
1778 SmallVector<Type, 2> conversionCallStack;
1779 TypeConverter converter;
1781 /// Add the legal set of type conversions.
1782 converter.addConversion([](Type type) -> Type {
1783 // Treat F64 as legal.
1784 if (type.isF64())
1785 return type;
1786 // Allow converting BF16/F16/F32 to F64.
1787 if (type.isBF16() || type.isF16() || type.isF32())
1788 return FloatType::getF64(type.getContext());
1789 // Otherwise, the type is illegal.
1790 return nullptr;
1792 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
1793 // Drop all integer types.
1794 return success();
1796 converter.addConversion(
1797 // Convert a recursive self-referring type into a non-self-referring
1798 // type named "outer_converted_type" that contains a SimpleAType.
1799 [&](test::TestRecursiveType type,
1800 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
1801 // If the type is already converted, return it to indicate that it is
1802 // legal.
1803 if (type.getName() == "outer_converted_type") {
1804 results.push_back(type);
1805 return success();
1808 conversionCallStack.push_back(type);
1809 auto popConversionCallStack = llvm::make_scope_exit(
1810 [&conversionCallStack]() { conversionCallStack.pop_back(); });
1812 // If the type is on the call stack more than once (it is there at
1813 // least once because of the _current_ call, which is always the last
1814 // element on the stack), we've hit the recursive case. Just return
1815 // SimpleAType here to create a non-recursive type as a result.
1816 if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(),
1817 type)) {
1818 results.push_back(test::SimpleAType::get(type.getContext()));
1819 return success();
1822 // Convert the body recursively.
1823 auto result = test::TestRecursiveType::get(type.getContext(),
1824 "outer_converted_type");
1825 if (failed(result.setBody(converter.convertType(type.getBody()))))
1826 return failure();
1827 results.push_back(result);
1828 return success();
1831 /// Add the legal set of type materializations.
1832 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
1833 ValueRange inputs,
1834 Location loc) -> Value {
1835 // Allow casting from F64 back to F32.
1836 if (!resultType.isF16() && inputs.size() == 1 &&
1837 inputs[0].getType().isF64())
1838 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1839 // Allow producing an i32 or i64 from nothing.
1840 if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
1841 inputs.empty())
1842 return builder.create<TestTypeProducerOp>(loc, resultType);
1843 // Allow producing an i64 from an integer.
1844 if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
1845 isa<IntegerType>(inputs[0].getType()))
1846 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1847 // Otherwise, fail.
1848 return nullptr;
1851 // Initialize the conversion target.
1852 mlir::ConversionTarget target(getContext());
1853 target.addLegalOp<LegalOpD>();
1854 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
1855 auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
1856 return op.getType().isF64() || op.getType().isInteger(64) ||
1857 (recursiveType &&
1858 recursiveType.getName() == "outer_converted_type");
1860 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1861 return converter.isSignatureLegal(op.getFunctionType()) &&
1862 converter.isLegal(&op.getBody());
1864 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
1865 // Allow casts from F64 to F32.
1866 return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1868 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
1869 [&](TestSignatureConversionNoConverterOp op) {
1870 return converter.isLegal(op.getRegion().front().getArgumentTypes());
1873 // Initialize the set of rewrite patterns.
1874 RewritePatternSet patterns(&getContext());
1875 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
1876 TestSignatureConversionUndo,
1877 TestTestSignatureConversionNoConverter>(converter,
1878 &getContext());
1879 patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
1880 &getContext());
1881 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1882 converter);
1884 if (failed(applyPartialConversion(getOperation(), target,
1885 std::move(patterns))))
1886 signalPassFailure();
1889 } // namespace
1891 //===----------------------------------------------------------------------===//
1892 // Test Target Materialization With No Uses
1893 //===----------------------------------------------------------------------===//
1895 namespace {
1896 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1897 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1899 LogicalResult
1900 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1901 ConversionPatternRewriter &rewriter) const final {
1902 rewriter.replaceOp(op, adaptor.getOperands());
1903 return success();
1907 struct TestTargetMaterializationWithNoUses
1908 : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> {
1909 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1910 TestTargetMaterializationWithNoUses)
1912 StringRef getArgument() const final {
1913 return "test-target-materialization-with-no-uses";
1915 StringRef getDescription() const final {
1916 return "Test a special case of target materialization in DialectConversion";
1919 void runOnOperation() override {
1920 TypeConverter converter;
1921 converter.addConversion([](Type t) { return t; });
1922 converter.addConversion([](IntegerType intTy) -> Type {
1923 if (intTy.getWidth() == 16)
1924 return IntegerType::get(intTy.getContext(), 64);
1925 return intTy;
1927 converter.addTargetMaterialization(
1928 [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1929 return builder.create<TestCastOp>(loc, type, inputs).getResult();
1932 ConversionTarget target(getContext());
1933 target.addIllegalOp<TestTypeChangerOp>();
1935 RewritePatternSet patterns(&getContext());
1936 patterns.add<ForwardOperandPattern>(converter, &getContext());
1938 if (failed(applyPartialConversion(getOperation(), target,
1939 std::move(patterns))))
1940 signalPassFailure();
1943 } // namespace
1945 //===----------------------------------------------------------------------===//
1946 // Test Block Merging
1947 //===----------------------------------------------------------------------===//
1949 namespace {
1950 /// A rewriter pattern that tests that blocks can be merged.
1951 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1952 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1954 LogicalResult
1955 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1956 ConversionPatternRewriter &rewriter) const final {
1957 Block &firstBlock = op.getBody().front();
1958 Operation *branchOp = firstBlock.getTerminator();
1959 Block *secondBlock = &*(std::next(op.getBody().begin()));
1960 auto succOperands = branchOp->getOperands();
1961 SmallVector<Value, 2> replacements(succOperands);
1962 rewriter.eraseOp(branchOp);
1963 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1964 rewriter.modifyOpInPlace(op, [] {});
1965 return success();
1969 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1970 struct TestUndoBlocksMerge : public ConversionPattern {
1971 TestUndoBlocksMerge(MLIRContext *ctx)
1972 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1973 LogicalResult
1974 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1975 ConversionPatternRewriter &rewriter) const final {
1976 Block &firstBlock = op->getRegion(0).front();
1977 Operation *branchOp = firstBlock.getTerminator();
1978 Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
1979 rewriter.setInsertionPointToStart(secondBlock);
1980 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1981 auto succOperands = branchOp->getOperands();
1982 SmallVector<Value, 2> replacements(succOperands);
1983 rewriter.eraseOp(branchOp);
1984 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
1985 rewriter.modifyOpInPlace(op, [] {});
1986 return success();
1990 /// A rewrite mechanism to inline the body of the op into its parent, when both
1991 /// ops can have a single block.
1992 struct TestMergeSingleBlockOps
1993 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
1994 using OpConversionPattern<
1995 SingleBlockImplicitTerminatorOp>::OpConversionPattern;
1997 LogicalResult
1998 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
1999 ConversionPatternRewriter &rewriter) const final {
2000 SingleBlockImplicitTerminatorOp parentOp =
2001 op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2002 if (!parentOp)
2003 return failure();
2004 Block &innerBlock = op.getRegion().front();
2005 TerminatorOp innerTerminator =
2006 cast<TerminatorOp>(innerBlock.getTerminator());
2007 rewriter.inlineBlockBefore(&innerBlock, op);
2008 rewriter.eraseOp(innerTerminator);
2009 rewriter.eraseOp(op);
2010 return success();
2014 struct TestMergeBlocksPatternDriver
2015 : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> {
2016 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
2018 StringRef getArgument() const final { return "test-merge-blocks"; }
2019 StringRef getDescription() const final {
2020 return "Test Merging operation in ConversionPatternRewriter";
2022 void runOnOperation() override {
2023 MLIRContext *context = &getContext();
2024 mlir::RewritePatternSet patterns(context);
2025 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
2026 context);
2027 ConversionTarget target(*context);
2028 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
2029 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
2030 target.addIllegalOp<ILLegalOpF>();
2032 /// Expect the op to have a single block after legalization.
2033 target.addDynamicallyLegalOp<TestMergeBlocksOp>(
2034 [&](TestMergeBlocksOp op) -> bool {
2035 return llvm::hasSingleElement(op.getBody());
2038 /// Only allow `test.br` within test.merge_blocks op.
2039 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
2040 return op->getParentOfType<TestMergeBlocksOp>();
2043 /// Expect that all nested test.SingleBlockImplicitTerminator ops are
2044 /// inlined.
2045 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
2046 [&](SingleBlockImplicitTerminatorOp op) -> bool {
2047 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2050 DenseSet<Operation *> unlegalizedOps;
2051 ConversionConfig config;
2052 config.unlegalizedOps = &unlegalizedOps;
2053 (void)applyPartialConversion(getOperation(), target, std::move(patterns),
2054 config);
2055 for (auto *op : unlegalizedOps)
2056 op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
2059 } // namespace
2061 //===----------------------------------------------------------------------===//
2062 // Test Selective Replacement
2063 //===----------------------------------------------------------------------===//
2065 namespace {
2066 /// A rewrite mechanism to inline the body of the op into its parent, when both
2067 /// ops can have a single block.
2068 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
2069 using OpRewritePattern<TestCastOp>::OpRewritePattern;
2071 LogicalResult matchAndRewrite(TestCastOp op,
2072 PatternRewriter &rewriter) const final {
2073 if (op.getNumOperands() != 2)
2074 return failure();
2075 OperandRange operands = op.getOperands();
2077 // Replace non-terminator uses with the first operand.
2078 rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
2079 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
2081 // Replace everything else with the second operand if the operation isn't
2082 // dead.
2083 rewriter.replaceOp(op, op.getOperand(1));
2084 return success();
2088 struct TestSelectiveReplacementPatternDriver
2089 : public PassWrapper<TestSelectiveReplacementPatternDriver,
2090 OperationPass<>> {
2091 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
2092 TestSelectiveReplacementPatternDriver)
2094 StringRef getArgument() const final {
2095 return "test-pattern-selective-replacement";
2097 StringRef getDescription() const final {
2098 return "Test selective replacement in the PatternRewriter";
2100 void runOnOperation() override {
2101 MLIRContext *context = &getContext();
2102 mlir::RewritePatternSet patterns(context);
2103 patterns.add<TestSelectiveOpReplacementPattern>(context);
2104 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
2107 } // namespace
2109 //===----------------------------------------------------------------------===//
2110 // PassRegistration
2111 //===----------------------------------------------------------------------===//
2113 namespace mlir {
2114 namespace test {
2115 void registerPatternsTestPass() {
2116 PassRegistration<TestReturnTypeDriver>();
2118 PassRegistration<TestDerivedAttributeDriver>();
2120 PassRegistration<TestGreedyPatternDriver>();
2121 PassRegistration<TestStrictPatternDriver>();
2122 PassRegistration<TestWalkPatternDriver>();
2124 PassRegistration<TestLegalizePatternDriver>([] {
2125 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
2128 PassRegistration<TestRemappedValue>();
2130 PassRegistration<TestUnknownRootOpDriver>();
2132 PassRegistration<TestTypeConversionDriver>();
2133 PassRegistration<TestTargetMaterializationWithNoUses>();
2135 PassRegistration<TestRewriteDynamicOpDriver>();
2137 PassRegistration<TestMergeBlocksPatternDriver>();
2138 PassRegistration<TestSelectiveReplacementPatternDriver>();
2140 } // namespace test
2141 } // namespace mlir