1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "TestDialect.h"
11 #include "TestTypes.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinAttributes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Visitors.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
24 #include "llvm/ADT/ScopeExit.h"
30 // Native function for testing NativeCodeCall
31 static Value
chooseOperand(Value input1
, Value input2
, BoolAttr choice
) {
32 return choice
.getValue() ? input1
: input2
;
35 static void createOpI(PatternRewriter
&rewriter
, Location loc
, Value input
) {
36 rewriter
.create
<OpI
>(loc
, input
);
39 static void handleNoResultOp(PatternRewriter
&rewriter
,
40 OpSymbolBindingNoResult op
) {
41 // Turn the no result op to a one-result op.
42 rewriter
.create
<OpSymbolBindingB
>(op
.getLoc(), op
.getOperand().getType(),
46 static bool getFirstI32Result(Operation
*op
, Value
&value
) {
47 if (!Type(op
->getResult(0).getType()).isSignlessInteger(32))
49 value
= op
->getResult(0);
53 static Value
bindNativeCodeCallResult(Value value
) { return value
; }
55 static SmallVector
<Value
, 2> bindMultipleNativeCodeCallResult(Value input1
,
57 return SmallVector
<Value
, 2>({input2
, input1
});
60 // Test that natives calls are only called once during rewrites.
61 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
62 // This let us check the number of times OpM_Test was called by inspecting
63 // the returned value in the MLIR output.
64 static int64_t opMIncreasingValue
= 314159265;
65 static Attribute
opMTest(PatternRewriter
&rewriter
, Value val
) {
66 int64_t i
= opMIncreasingValue
++;
67 return rewriter
.getIntegerAttr(rewriter
.getIntegerType(32), i
);
71 #include "TestPatterns.inc"
74 //===----------------------------------------------------------------------===//
75 // Test Reduce Pattern Interface
76 //===----------------------------------------------------------------------===//
78 void test::populateTestReductionPatterns(RewritePatternSet
&patterns
) {
79 populateWithGenerated(patterns
);
82 //===----------------------------------------------------------------------===//
83 // Canonicalizer Driver.
84 //===----------------------------------------------------------------------===//
87 struct FoldingPattern
: public RewritePattern
{
89 FoldingPattern(MLIRContext
*context
)
90 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
91 /*benefit=*/1, context
) {}
93 LogicalResult
matchAndRewrite(Operation
*op
,
94 PatternRewriter
&rewriter
) const override
{
95 // Exercise createOrFold API for a single-result operation that is folded
96 // upon construction. The operation being created has an in-place folder,
97 // and it should be still present in the output. Furthermore, the folder
98 // should not crash when attempting to recover the (unchanged) operation
100 Value result
= rewriter
.createOrFold
<TestOpInPlaceFold
>(
101 op
->getLoc(), rewriter
.getIntegerType(32), op
->getOperand(0));
103 rewriter
.replaceOp(op
, result
);
108 /// This pattern creates a foldable operation at the entry point of the block.
109 /// This tests the situation where the operation folder will need to replace an
110 /// operation with a previously created constant that does not initially
111 /// dominate the operation to replace.
112 struct FolderInsertBeforePreviouslyFoldedConstantPattern
113 : public OpRewritePattern
<TestCastOp
> {
115 using OpRewritePattern
<TestCastOp
>::OpRewritePattern
;
117 LogicalResult
matchAndRewrite(TestCastOp op
,
118 PatternRewriter
&rewriter
) const override
{
119 if (!op
->hasAttr("test_fold_before_previously_folded_op"))
121 rewriter
.setInsertionPointToStart(op
->getBlock());
123 auto constOp
= rewriter
.create
<arith::ConstantOp
>(
124 op
.getLoc(), rewriter
.getBoolAttr(true));
125 rewriter
.replaceOpWithNewOp
<TestCastOp
>(op
, rewriter
.getI32Type(),
131 /// This pattern matches test.op_commutative2 with the first operand being
132 /// another test.op_commutative2 with a constant on the right side and fold it
133 /// away by propagating it as its result. This is intend to check that patterns
134 /// are applied after the commutative property moves constant to the right.
135 struct FolderCommutativeOp2WithConstant
136 : public OpRewritePattern
<TestCommutative2Op
> {
138 using OpRewritePattern
<TestCommutative2Op
>::OpRewritePattern
;
140 LogicalResult
matchAndRewrite(TestCommutative2Op op
,
141 PatternRewriter
&rewriter
) const override
{
143 dyn_cast_or_null
<TestCommutative2Op
>(op
->getOperand(0).getDefiningOp());
146 Attribute constInput
;
147 if (!matchPattern(operand
->getOperand(1), m_Constant(&constInput
)))
149 rewriter
.replaceOp(op
, operand
->getOperand(1));
154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
155 /// attribute with value smaller than MaxVal, it increments the value by 1.
156 template <int MaxVal
>
157 struct IncrementIntAttribute
: public OpRewritePattern
<AnyAttrOfOp
> {
158 using OpRewritePattern
<AnyAttrOfOp
>::OpRewritePattern
;
160 LogicalResult
matchAndRewrite(AnyAttrOfOp op
,
161 PatternRewriter
&rewriter
) const override
{
162 auto intAttr
= dyn_cast
<IntegerAttr
>(op
.getAttr());
165 int64_t val
= intAttr
.getInt();
168 rewriter
.modifyOpInPlace(
169 op
, [&]() { op
.setAttrAttr(rewriter
.getI32IntegerAttr(val
+ 1)); });
174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
175 struct MakeOpEligible
: public RewritePattern
{
176 MakeOpEligible(MLIRContext
*context
)
177 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context
) {}
179 LogicalResult
matchAndRewrite(Operation
*op
,
180 PatternRewriter
&rewriter
) const override
{
181 if (op
->hasAttr("eligible"))
183 rewriter
.modifyOpInPlace(
184 op
, [&]() { op
->setAttr("eligible", rewriter
.getUnitAttr()); });
189 /// This pattern hoists eligible ops out of a "test.one_region_op".
190 struct HoistEligibleOps
: public OpRewritePattern
<test::OneRegionOp
> {
191 using OpRewritePattern
<test::OneRegionOp
>::OpRewritePattern
;
193 LogicalResult
matchAndRewrite(test::OneRegionOp op
,
194 PatternRewriter
&rewriter
) const override
{
195 Operation
*terminator
= op
.getRegion().front().getTerminator();
196 Operation
*toBeHoisted
= terminator
->getOperands()[0].getDefiningOp();
197 if (toBeHoisted
->getParentOp() != op
)
199 if (!toBeHoisted
->hasAttr("eligible"))
201 rewriter
.moveOpBefore(toBeHoisted
, op
);
206 /// This pattern moves "test.move_before_parent_op" before the parent op.
207 struct MoveBeforeParentOp
: public RewritePattern
{
208 MoveBeforeParentOp(MLIRContext
*context
)
209 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context
) {}
211 LogicalResult
matchAndRewrite(Operation
*op
,
212 PatternRewriter
&rewriter
) const override
{
213 // Do not hoist past functions.
214 if (isa
<FunctionOpInterface
>(op
->getParentOp()))
216 rewriter
.moveOpBefore(op
, op
->getParentOp());
221 /// This pattern moves "test.move_after_parent_op" after the parent op.
222 struct MoveAfterParentOp
: public RewritePattern
{
223 MoveAfterParentOp(MLIRContext
*context
)
224 : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context
) {}
226 LogicalResult
matchAndRewrite(Operation
*op
,
227 PatternRewriter
&rewriter
) const override
{
228 // Do not hoist past functions.
229 if (isa
<FunctionOpInterface
>(op
->getParentOp()))
232 int64_t moveForwardBy
= 0;
233 if (auto advanceBy
= op
->getAttrOfType
<IntegerAttr
>("advance"))
234 moveForwardBy
= advanceBy
.getInt();
236 Operation
*moveAfter
= op
->getParentOp();
237 for (int64_t i
= 0; i
< moveForwardBy
; ++i
)
238 moveAfter
= moveAfter
->getNextNode();
240 rewriter
.moveOpAfter(op
, moveAfter
);
245 /// This pattern inlines blocks that are nested in
246 /// "test.inline_blocks_into_parent" into the parent block.
247 struct InlineBlocksIntoParent
: public RewritePattern
{
248 InlineBlocksIntoParent(MLIRContext
*context
)
249 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
252 LogicalResult
matchAndRewrite(Operation
*op
,
253 PatternRewriter
&rewriter
) const override
{
254 bool changed
= false;
255 for (Region
&r
: op
->getRegions()) {
257 rewriter
.inlineBlockBefore(&r
.front(), op
);
261 return success(changed
);
265 /// This pattern splits blocks at "test.split_block_here" and replaces the op
266 /// with a new op (to prevent an infinite loop of block splitting).
267 struct SplitBlockHere
: public RewritePattern
{
268 SplitBlockHere(MLIRContext
*context
)
269 : RewritePattern("test.split_block_here", /*benefit=*/1, context
) {}
271 LogicalResult
matchAndRewrite(Operation
*op
,
272 PatternRewriter
&rewriter
) const override
{
273 rewriter
.splitBlock(op
->getBlock(), op
->getIterator());
274 Operation
*newOp
= rewriter
.create(
276 OperationName("test.new_op", op
->getContext()).getIdentifier(),
277 op
->getOperands(), op
->getResultTypes());
278 rewriter
.replaceOp(op
, newOp
);
283 /// This pattern clones "test.clone_me" ops.
284 struct CloneOp
: public RewritePattern
{
285 CloneOp(MLIRContext
*context
)
286 : RewritePattern("test.clone_me", /*benefit=*/1, context
) {}
288 LogicalResult
matchAndRewrite(Operation
*op
,
289 PatternRewriter
&rewriter
) const override
{
290 // Do not clone already cloned ops to avoid going into an infinite loop.
291 if (op
->hasAttr("was_cloned"))
293 Operation
*cloned
= rewriter
.clone(*op
);
294 cloned
->setAttr("was_cloned", rewriter
.getUnitAttr());
299 /// This pattern clones regions of "test.clone_region_before" ops before the
301 struct CloneRegionBeforeOp
: public RewritePattern
{
302 CloneRegionBeforeOp(MLIRContext
*context
)
303 : RewritePattern("test.clone_region_before", /*benefit=*/1, context
) {}
305 LogicalResult
matchAndRewrite(Operation
*op
,
306 PatternRewriter
&rewriter
) const override
{
307 // Do not clone already cloned ops to avoid going into an infinite loop.
308 if (op
->hasAttr("was_cloned"))
310 for (Region
&r
: op
->getRegions())
311 rewriter
.cloneRegionBefore(r
, op
->getBlock());
312 op
->setAttr("was_cloned", rewriter
.getUnitAttr());
317 /// Replace an operation may introduce the re-visiting of its users.
318 class ReplaceWithNewOp
: public RewritePattern
{
320 ReplaceWithNewOp(MLIRContext
*context
)
321 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context
) {}
323 LogicalResult
matchAndRewrite(Operation
*op
,
324 PatternRewriter
&rewriter
) const override
{
326 if (op
->hasAttr("create_erase_op")) {
327 newOp
= rewriter
.create(
329 OperationName("test.erase_op", op
->getContext()).getIdentifier(),
330 ValueRange(), TypeRange());
332 newOp
= rewriter
.create(
334 OperationName("test.new_op", op
->getContext()).getIdentifier(),
335 op
->getOperands(), op
->getResultTypes());
337 // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
338 // A "notifyOperationReplaced" callback is triggered in either case.
339 rewriter
.replaceAllOpUsesWith(op
, newOp
->getResults());
340 rewriter
.eraseOp(op
);
345 /// Erases the first child block of the matched "test.erase_first_block"
347 class EraseFirstBlock
: public RewritePattern
{
349 EraseFirstBlock(MLIRContext
*context
)
350 : RewritePattern("test.erase_first_block", /*benefit=*/1, context
) {}
352 LogicalResult
matchAndRewrite(Operation
*op
,
353 PatternRewriter
&rewriter
) const override
{
354 for (Region
&r
: op
->getRegions()) {
355 for (Block
&b
: r
.getBlocks()) {
356 rewriter
.eraseBlock(&b
);
365 struct TestGreedyPatternDriver
366 : public PassWrapper
<TestGreedyPatternDriver
, OperationPass
<>> {
367 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver
)
369 TestGreedyPatternDriver() = default;
370 TestGreedyPatternDriver(const TestGreedyPatternDriver
&other
)
371 : PassWrapper(other
) {}
373 StringRef
getArgument() const final
{ return "test-greedy-patterns"; }
374 StringRef
getDescription() const final
{ return "Run test dialect patterns"; }
375 void runOnOperation() override
{
376 mlir::RewritePatternSet
patterns(&getContext());
377 populateWithGenerated(patterns
);
379 // Verify named pattern is generated with expected name.
380 patterns
.add
<FoldingPattern
, TestNamedPatternRule
,
381 FolderInsertBeforePreviouslyFoldedConstantPattern
,
382 FolderCommutativeOp2WithConstant
, HoistEligibleOps
,
383 MakeOpEligible
>(&getContext());
385 // Additional patterns for testing the GreedyPatternRewriteDriver.
386 patterns
.insert
<IncrementIntAttribute
<3>>(&getContext());
388 GreedyRewriteConfig config
;
389 config
.useTopDownTraversal
= this->useTopDownTraversal
;
390 config
.maxIterations
= this->maxIterations
;
391 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns
),
395 Option
<bool> useTopDownTraversal
{
397 llvm::cl::desc("Seed the worklist in general top-down order"),
398 llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal
)};
399 Option
<int> maxIterations
{
400 *this, "max-iterations",
401 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
402 llvm::cl::init(GreedyRewriteConfig().maxIterations
)};
405 struct DumpNotifications
: public RewriterBase::Listener
{
406 void notifyBlockInserted(Block
*block
, Region
*previous
,
407 Region::iterator previousIt
) override
{
408 llvm::outs() << "notifyBlockInserted";
409 if (block
->getParentOp()) {
410 llvm::outs() << " into " << block
->getParentOp()->getName() << ": ";
412 llvm::outs() << " into unknown op: ";
414 if (previous
== nullptr) {
415 llvm::outs() << "was unlinked\n";
417 llvm::outs() << "was linked\n";
420 void notifyOperationInserted(Operation
*op
,
421 OpBuilder::InsertPoint previous
) override
{
422 llvm::outs() << "notifyOperationInserted: " << op
->getName();
423 if (!previous
.isSet()) {
424 llvm::outs() << ", was unlinked\n";
426 if (!previous
.getPoint().getNodePtr()) {
427 llvm::outs() << ", was linked, exact position unknown\n";
428 } else if (previous
.getPoint() == previous
.getBlock()->end()) {
429 llvm::outs() << ", was last in block\n";
431 llvm::outs() << ", previous = " << previous
.getPoint()->getName()
436 void notifyBlockErased(Block
*block
) override
{
437 llvm::outs() << "notifyBlockErased\n";
439 void notifyOperationErased(Operation
*op
) override
{
440 llvm::outs() << "notifyOperationErased: " << op
->getName() << "\n";
442 void notifyOperationModified(Operation
*op
) override
{
443 llvm::outs() << "notifyOperationModified: " << op
->getName() << "\n";
445 void notifyOperationReplaced(Operation
*op
, ValueRange values
) override
{
446 llvm::outs() << "notifyOperationReplaced: " << op
->getName() << "\n";
450 struct TestStrictPatternDriver
451 : public PassWrapper
<TestStrictPatternDriver
, OperationPass
<func::FuncOp
>> {
453 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver
)
455 TestStrictPatternDriver() = default;
456 TestStrictPatternDriver(const TestStrictPatternDriver
&other
)
457 : PassWrapper(other
) {
458 strictMode
= other
.strictMode
;
461 StringRef
getArgument() const final
{ return "test-strict-pattern-driver"; }
462 StringRef
getDescription() const final
{
463 return "Test strict mode of pattern driver";
466 void runOnOperation() override
{
467 MLIRContext
*ctx
= &getContext();
468 mlir::RewritePatternSet
patterns(ctx
);
476 InlineBlocksIntoParent
,
483 SmallVector
<Operation
*> ops
;
484 getOperation()->walk([&](Operation
*op
) {
485 StringRef opName
= op
->getName().getStringRef();
486 if (opName
== "test.insert_same_op" || opName
== "test.change_block_op" ||
487 opName
== "test.replace_with_new_op" || opName
== "test.erase_op" ||
488 opName
== "test.move_before_parent_op" ||
489 opName
== "test.inline_blocks_into_parent" ||
490 opName
== "test.split_block_here" || opName
== "test.clone_me" ||
491 opName
== "test.clone_region_before") {
496 DumpNotifications dumpNotifications
;
497 GreedyRewriteConfig config
;
498 config
.listener
= &dumpNotifications
;
499 if (strictMode
== "AnyOp") {
500 config
.strictMode
= GreedyRewriteStrictness::AnyOp
;
501 } else if (strictMode
== "ExistingAndNewOps") {
502 config
.strictMode
= GreedyRewriteStrictness::ExistingAndNewOps
;
503 } else if (strictMode
== "ExistingOps") {
504 config
.strictMode
= GreedyRewriteStrictness::ExistingOps
;
506 llvm_unreachable("invalid strictness option");
509 // Check if these transformations introduce visiting of operations that
510 // are not in the `ops` set (The new created ops are valid). An invalid
511 // operation will trigger the assertion while processing.
512 bool changed
= false;
513 bool allErased
= false;
514 (void)applyOpPatternsAndFold(ArrayRef(ops
), std::move(patterns
), config
,
515 &changed
, &allErased
);
517 getOperation()->setAttr("pattern_driver_changed", b
.getBoolAttr(changed
));
518 getOperation()->setAttr("pattern_driver_all_erased",
519 b
.getBoolAttr(allErased
));
522 Option
<std::string
> strictMode
{
524 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
525 llvm::cl::init("AnyOp")};
528 // New inserted operation is valid for further transformation.
529 class InsertSameOp
: public RewritePattern
{
531 InsertSameOp(MLIRContext
*context
)
532 : RewritePattern("test.insert_same_op", /*benefit=*/1, context
) {}
534 LogicalResult
matchAndRewrite(Operation
*op
,
535 PatternRewriter
&rewriter
) const override
{
536 if (op
->hasAttr("skip"))
540 rewriter
.create(op
->getLoc(), op
->getName().getIdentifier(),
541 op
->getOperands(), op
->getResultTypes());
542 rewriter
.modifyOpInPlace(
543 op
, [&]() { op
->setAttr("skip", rewriter
.getBoolAttr(true)); });
544 newOp
->setAttr("skip", rewriter
.getBoolAttr(true));
550 // Remove an operation may introduce the re-visiting of its operands.
551 class EraseOp
: public RewritePattern
{
553 EraseOp(MLIRContext
*context
)
554 : RewritePattern("test.erase_op", /*benefit=*/1, context
) {}
555 LogicalResult
matchAndRewrite(Operation
*op
,
556 PatternRewriter
&rewriter
) const override
{
557 rewriter
.eraseOp(op
);
562 // The following two patterns test RewriterBase::replaceAllUsesWith.
564 // That function replaces all usages of a Block (or a Value) with another one
565 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
566 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
567 // worklist: when an op is modified, it is added to the worklist. The two
568 // patterns below make the tracking observable: ChangeBlockOp replaces all
569 // usages of a block and that pattern is applied because the corresponding ops
570 // are put on the initial worklist (see above). ImplicitChangeOp does an
571 // unrelated change but ops of the corresponding type are *not* on the initial
572 // worklist, so the effect of the second pattern is only visible if the
573 // tracking and subsequent adding to the worklist actually works.
575 // Replace all usages of the first successor with the second successor.
576 class ChangeBlockOp
: public RewritePattern
{
578 ChangeBlockOp(MLIRContext
*context
)
579 : RewritePattern("test.change_block_op", /*benefit=*/1, context
) {}
580 LogicalResult
matchAndRewrite(Operation
*op
,
581 PatternRewriter
&rewriter
) const override
{
582 if (op
->getNumSuccessors() < 2)
584 Block
*firstSuccessor
= op
->getSuccessor(0);
585 Block
*secondSuccessor
= op
->getSuccessor(1);
586 if (firstSuccessor
== secondSuccessor
)
588 // This is the function being tested:
589 rewriter
.replaceAllUsesWith(firstSuccessor
, secondSuccessor
);
590 // Using the following line instead would make the test fail:
591 // firstSuccessor->replaceAllUsesWith(secondSuccessor);
596 // Changes the successor to the parent block.
597 class ImplicitChangeOp
: public RewritePattern
{
599 ImplicitChangeOp(MLIRContext
*context
)
600 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context
) {}
601 LogicalResult
matchAndRewrite(Operation
*op
,
602 PatternRewriter
&rewriter
) const override
{
603 if (op
->getNumSuccessors() < 1 || op
->getSuccessor(0) == op
->getBlock())
605 rewriter
.modifyOpInPlace(op
,
606 [&]() { op
->setSuccessor(op
->getBlock(), 0); });
612 struct TestWalkPatternDriver final
613 : PassWrapper
<TestWalkPatternDriver
, OperationPass
<>> {
614 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver
)
616 TestWalkPatternDriver() = default;
617 TestWalkPatternDriver(const TestWalkPatternDriver
&other
)
618 : PassWrapper(other
) {}
620 StringRef
getArgument() const override
{
621 return "test-walk-pattern-rewrite-driver";
623 StringRef
getDescription() const override
{
624 return "Run test walk pattern rewrite driver";
626 void runOnOperation() override
{
627 mlir::RewritePatternSet
patterns(&getContext());
629 // Patterns for testing the WalkPatternRewriteDriver.
630 patterns
.add
<IncrementIntAttribute
<3>, MoveBeforeParentOp
,
631 MoveAfterParentOp
, CloneOp
, ReplaceWithNewOp
, EraseFirstBlock
>(
634 DumpNotifications dumpListener
;
635 walkAndApplyPatterns(getOperation(), std::move(patterns
),
636 dumpNotifications
? &dumpListener
: nullptr);
639 Option
<bool> dumpNotifications
{
640 *this, "dump-notifications",
641 llvm::cl::desc("Print rewrite listener notifications"),
642 llvm::cl::init(false)};
647 //===----------------------------------------------------------------------===//
648 // ReturnType Driver.
649 //===----------------------------------------------------------------------===//
652 // Generate ops for each instance where the type can be successfully inferred.
653 template <typename OpTy
>
654 static void invokeCreateWithInferredReturnType(Operation
*op
) {
655 auto *context
= op
->getContext();
656 auto fop
= op
->getParentOfType
<func::FuncOp
>();
657 auto location
= UnknownLoc::get(context
);
659 b
.setInsertionPointAfter(op
);
661 // Use permutations of 2 args as operands.
662 assert(fop
.getNumArguments() >= 2);
663 for (int i
= 0, e
= fop
.getNumArguments(); i
< e
; ++i
) {
664 for (int j
= 0; j
< e
; ++j
) {
665 std::array
<Value
, 2> values
= {{fop
.getArgument(i
), fop
.getArgument(j
)}};
666 SmallVector
<Type
, 2> inferredReturnTypes
;
667 if (succeeded(OpTy::inferReturnTypes(
668 context
, std::nullopt
, values
, op
->getDiscardableAttrDictionary(),
669 op
->getPropertiesStorage(), op
->getRegions(),
670 inferredReturnTypes
))) {
671 OperationState
state(location
, OpTy::getOperationName());
672 // TODO: Expand to regions.
673 OpTy::build(b
, state
, values
, op
->getAttrs());
674 (void)b
.create(state
);
680 static void reifyReturnShape(Operation
*op
) {
683 // Use permutations of 2 args as operands.
684 auto shapedOp
= cast
<OpWithShapedTypeInferTypeInterfaceOp
>(op
);
685 SmallVector
<Value
, 2> shapes
;
686 if (failed(shapedOp
.reifyReturnTypeShapes(b
, op
->getOperands(), shapes
)) ||
687 !llvm::hasSingleElement(shapes
))
689 for (const auto &it
: llvm::enumerate(shapes
)) {
690 op
->emitRemark() << "value " << it
.index() << ": "
691 << it
.value().getDefiningOp();
695 struct TestReturnTypeDriver
696 : public PassWrapper
<TestReturnTypeDriver
, OperationPass
<func::FuncOp
>> {
697 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver
)
699 void getDependentDialects(DialectRegistry
®istry
) const override
{
700 registry
.insert
<tensor::TensorDialect
>();
702 StringRef
getArgument() const final
{ return "test-return-type"; }
703 StringRef
getDescription() const final
{ return "Run return type functions"; }
705 void runOnOperation() override
{
706 if (getOperation().getName() == "testCreateFunctions") {
707 std::vector
<Operation
*> ops
;
708 // Collect ops to avoid triggering on inserted ops.
709 for (auto &op
: getOperation().getBody().front())
711 // Generate test patterns for each, but skip terminator.
712 for (auto *op
: llvm::ArrayRef(ops
).drop_back()) {
713 // Test create method of each of the Op classes below. The resultant
714 // output would be in reverse order underneath `op` from which
715 // the attributes and regions are used.
716 invokeCreateWithInferredReturnType
<OpWithInferTypeInterfaceOp
>(op
);
717 invokeCreateWithInferredReturnType
<OpWithInferTypeAdaptorInterfaceOp
>(
719 invokeCreateWithInferredReturnType
<
720 OpWithShapedTypeInferTypeInterfaceOp
>(op
);
724 if (getOperation().getName() == "testReifyFunctions") {
725 std::vector
<Operation
*> ops
;
726 // Collect ops to avoid triggering on inserted ops.
727 for (auto &op
: getOperation().getBody().front())
728 if (isa
<OpWithShapedTypeInferTypeInterfaceOp
>(op
))
730 // Generate test patterns for each, but skip terminator.
732 reifyReturnShape(op
);
739 struct TestDerivedAttributeDriver
740 : public PassWrapper
<TestDerivedAttributeDriver
,
741 OperationPass
<func::FuncOp
>> {
742 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver
)
744 StringRef
getArgument() const final
{ return "test-derived-attr"; }
745 StringRef
getDescription() const final
{
746 return "Run test derived attributes";
748 void runOnOperation() override
;
752 void TestDerivedAttributeDriver::runOnOperation() {
753 getOperation().walk([](DerivedAttributeOpInterface dOp
) {
754 auto dAttr
= dOp
.materializeDerivedAttributes();
758 dOp
.emitRemark() << d
.getName().getValue() << " = " << d
.getValue();
762 //===----------------------------------------------------------------------===//
763 // Legalization Driver.
764 //===----------------------------------------------------------------------===//
767 //===----------------------------------------------------------------------===//
768 // Region-Block Rewrite Testing
770 /// This pattern applies a signature conversion to a block inside a detached
772 struct TestDetachedSignatureConversion
: public ConversionPattern
{
773 TestDetachedSignatureConversion(MLIRContext
*ctx
)
774 : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1,
778 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
779 ConversionPatternRewriter
&rewriter
) const final
{
780 if (op
->getNumRegions() != 1)
782 OperationState
state(op
->getLoc(), "test.legal_op_with_region", operands
,
783 op
->getResultTypes(), {}, BlockRange());
784 Region
*newRegion
= state
.addRegion();
785 rewriter
.inlineRegionBefore(op
->getRegion(0), *newRegion
,
787 TypeConverter::SignatureConversion
result(newRegion
->getNumArguments());
788 for (unsigned i
= 0, e
= newRegion
->getNumArguments(); i
< e
; ++i
)
789 result
.addInputs(i
, rewriter
.getF64Type());
790 rewriter
.applySignatureConversion(&newRegion
->front(), result
);
791 Operation
*newOp
= rewriter
.create(state
);
792 rewriter
.replaceOp(op
, newOp
->getResults());
797 /// This pattern is a simple pattern that inlines the first region of a given
798 /// operation into the parent region.
799 struct TestRegionRewriteBlockMovement
: public ConversionPattern
{
800 TestRegionRewriteBlockMovement(MLIRContext
*ctx
)
801 : ConversionPattern("test.region", 1, ctx
) {}
804 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
805 ConversionPatternRewriter
&rewriter
) const final
{
806 // Inline this region into the parent region.
807 auto &parentRegion
= *op
->getParentRegion();
808 auto &opRegion
= op
->getRegion(0);
809 if (op
->getDiscardableAttr("legalizer.should_clone"))
810 rewriter
.cloneRegionBefore(opRegion
, parentRegion
, parentRegion
.end());
812 rewriter
.inlineRegionBefore(opRegion
, parentRegion
, parentRegion
.end());
814 if (op
->getDiscardableAttr("legalizer.erase_old_blocks")) {
815 while (!opRegion
.empty())
816 rewriter
.eraseBlock(&opRegion
.front());
819 // Drop this operation.
820 rewriter
.eraseOp(op
);
824 /// This pattern is a simple pattern that generates a region containing an
825 /// illegal operation.
826 struct TestRegionRewriteUndo
: public RewritePattern
{
827 TestRegionRewriteUndo(MLIRContext
*ctx
)
828 : RewritePattern("test.region_builder", 1, ctx
) {}
830 LogicalResult
matchAndRewrite(Operation
*op
,
831 PatternRewriter
&rewriter
) const final
{
832 // Create the region operation with an entry block containing arguments.
833 OperationState
newRegion(op
->getLoc(), "test.region");
834 newRegion
.addRegion();
835 auto *regionOp
= rewriter
.create(newRegion
);
836 auto *entryBlock
= rewriter
.createBlock(®ionOp
->getRegion(0));
837 entryBlock
->addArgument(rewriter
.getIntegerType(64),
838 rewriter
.getUnknownLoc());
840 // Add an explicitly illegal operation to ensure the conversion fails.
841 rewriter
.create
<ILLegalOpF
>(op
->getLoc(), rewriter
.getIntegerType(32));
842 rewriter
.create
<TestValidOp
>(op
->getLoc(), ArrayRef
<Value
>());
844 // Drop this operation.
845 rewriter
.eraseOp(op
);
849 /// A simple pattern that creates a block at the end of the parent region of the
850 /// matched operation.
851 struct TestCreateBlock
: public RewritePattern
{
852 TestCreateBlock(MLIRContext
*ctx
)
853 : RewritePattern("test.create_block", /*benefit=*/1, ctx
) {}
855 LogicalResult
matchAndRewrite(Operation
*op
,
856 PatternRewriter
&rewriter
) const final
{
857 Region
®ion
= *op
->getParentRegion();
858 Type i32Type
= rewriter
.getIntegerType(32);
859 Location loc
= op
->getLoc();
860 rewriter
.createBlock(®ion
, region
.end(), {i32Type
, i32Type
}, {loc
, loc
});
861 rewriter
.create
<TerminatorOp
>(loc
);
862 rewriter
.eraseOp(op
);
867 /// A simple pattern that creates a block containing an invalid operation in
868 /// order to trigger the block creation undo mechanism.
869 struct TestCreateIllegalBlock
: public RewritePattern
{
870 TestCreateIllegalBlock(MLIRContext
*ctx
)
871 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx
) {}
873 LogicalResult
matchAndRewrite(Operation
*op
,
874 PatternRewriter
&rewriter
) const final
{
875 Region
®ion
= *op
->getParentRegion();
876 Type i32Type
= rewriter
.getIntegerType(32);
877 Location loc
= op
->getLoc();
878 rewriter
.createBlock(®ion
, region
.end(), {i32Type
, i32Type
}, {loc
, loc
});
879 // Create an illegal op to ensure the conversion fails.
880 rewriter
.create
<ILLegalOpF
>(loc
, i32Type
);
881 rewriter
.create
<TerminatorOp
>(loc
);
882 rewriter
.eraseOp(op
);
887 /// A simple pattern that tests the undo mechanism when replacing the uses of a
889 struct TestUndoBlockArgReplace
: public ConversionPattern
{
890 TestUndoBlockArgReplace(MLIRContext
*ctx
)
891 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx
) {}
894 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
895 ConversionPatternRewriter
&rewriter
) const final
{
897 rewriter
.create
<ILLegalOpF
>(op
->getLoc(), rewriter
.getF32Type());
898 rewriter
.replaceUsesOfBlockArgument(op
->getRegion(0).getArgument(0),
899 illegalOp
->getResult(0));
900 rewriter
.modifyOpInPlace(op
, [] {});
905 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
906 /// This is to test the rollback logic.
907 struct TestUndoMoveOpBefore
: public ConversionPattern
{
908 TestUndoMoveOpBefore(MLIRContext
*ctx
)
909 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx
) {}
912 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
913 ConversionPatternRewriter
&rewriter
) const override
{
914 rewriter
.moveOpBefore(op
, op
->getParentOp());
915 // Replace with an illegal op to ensure the conversion fails.
916 rewriter
.replaceOpWithNewOp
<ILLegalOpF
>(op
, rewriter
.getF32Type());
921 /// A rewrite pattern that tests the undo mechanism when erasing a block.
922 struct TestUndoBlockErase
: public ConversionPattern
{
923 TestUndoBlockErase(MLIRContext
*ctx
)
924 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx
) {}
927 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
928 ConversionPatternRewriter
&rewriter
) const final
{
929 Block
*secondBlock
= &*std::next(op
->getRegion(0).begin());
930 rewriter
.setInsertionPointToStart(secondBlock
);
931 rewriter
.create
<ILLegalOpF
>(op
->getLoc(), rewriter
.getF32Type());
932 rewriter
.eraseBlock(secondBlock
);
933 rewriter
.modifyOpInPlace(op
, [] {});
938 /// A pattern that modifies a property in-place, but keeps the op illegal.
939 struct TestUndoPropertiesModification
: public ConversionPattern
{
940 TestUndoPropertiesModification(MLIRContext
*ctx
)
941 : ConversionPattern("test.with_properties", /*benefit=*/1, ctx
) {}
943 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
944 ConversionPatternRewriter
&rewriter
) const final
{
945 if (!op
->hasAttr("modify_inplace"))
947 rewriter
.modifyOpInPlace(
948 op
, [&]() { cast
<TestOpWithProperties
>(op
).getProperties().setA(42); });
953 //===----------------------------------------------------------------------===//
954 // Type-Conversion Rewrite Testing
956 /// This patterns erases a region operation that has had a type conversion.
957 struct TestDropOpSignatureConversion
: public ConversionPattern
{
958 TestDropOpSignatureConversion(MLIRContext
*ctx
,
959 const TypeConverter
&converter
)
960 : ConversionPattern(converter
, "test.drop_region_op", 1, ctx
) {}
962 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
963 ConversionPatternRewriter
&rewriter
) const override
{
964 Region
®ion
= op
->getRegion(0);
965 Block
*entry
= ®ion
.front();
967 // Convert the original entry arguments.
968 const TypeConverter
&converter
= *getTypeConverter();
969 TypeConverter::SignatureConversion
result(entry
->getNumArguments());
970 if (failed(converter
.convertSignatureArgs(entry
->getArgumentTypes(),
972 failed(rewriter
.convertRegionTypes(®ion
, converter
, &result
)))
975 // Convert the region signature and just drop the operation.
976 rewriter
.eraseOp(op
);
980 /// This pattern simply updates the operands of the given operation.
981 struct TestPassthroughInvalidOp
: public ConversionPattern
{
982 TestPassthroughInvalidOp(MLIRContext
*ctx
)
983 : ConversionPattern("test.invalid", 1, ctx
) {}
985 matchAndRewrite(Operation
*op
, ArrayRef
<ValueRange
> operands
,
986 ConversionPatternRewriter
&rewriter
) const final
{
987 SmallVector
<Value
> flattened
;
988 for (auto it
: llvm::enumerate(operands
)) {
989 ValueRange range
= it
.value();
990 if (range
.size() == 1) {
991 flattened
.push_back(range
.front());
995 // This is a 1:N replacement. Insert a test.cast op. (That's what the
996 // argument materialization used to do.)
999 .create
<TestCastOp
>(op
->getLoc(),
1000 op
->getOperand(it
.index()).getType(), range
)
1003 rewriter
.replaceOpWithNewOp
<TestValidOp
>(op
, std::nullopt
, flattened
,
1008 /// Replace with valid op, but simply drop the operands. This is used in a
1009 /// regression where we used to generate circular unrealized_conversion_cast
1011 struct TestDropAndReplaceInvalidOp
: public ConversionPattern
{
1012 TestDropAndReplaceInvalidOp(MLIRContext
*ctx
, const TypeConverter
&converter
)
1013 : ConversionPattern(converter
,
1014 "test.drop_operands_and_replace_with_valid", 1, ctx
) {
1017 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1018 ConversionPatternRewriter
&rewriter
) const final
{
1019 rewriter
.replaceOpWithNewOp
<TestValidOp
>(op
, std::nullopt
, ValueRange(),
1024 /// This pattern handles the case of a split return value.
1025 struct TestSplitReturnType
: public ConversionPattern
{
1026 TestSplitReturnType(MLIRContext
*ctx
)
1027 : ConversionPattern("test.return", 1, ctx
) {}
1029 matchAndRewrite(Operation
*op
, ArrayRef
<ValueRange
> operands
,
1030 ConversionPatternRewriter
&rewriter
) const final
{
1031 // Check for a return of F32.
1032 if (op
->getNumOperands() != 1 || !op
->getOperand(0).getType().isF32())
1034 rewriter
.replaceOpWithNewOp
<TestReturnOp
>(op
, operands
[0]);
1039 //===----------------------------------------------------------------------===//
1040 // Multi-Level Type-Conversion Rewrite Testing
1041 struct TestChangeProducerTypeI32ToF32
: public ConversionPattern
{
1042 TestChangeProducerTypeI32ToF32(MLIRContext
*ctx
)
1043 : ConversionPattern("test.type_producer", 1, ctx
) {}
1045 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1046 ConversionPatternRewriter
&rewriter
) const final
{
1047 // If the type is I32, change the type to F32.
1048 if (!Type(*op
->result_type_begin()).isSignlessInteger(32))
1050 rewriter
.replaceOpWithNewOp
<TestTypeProducerOp
>(op
, rewriter
.getF32Type());
1054 struct TestChangeProducerTypeF32ToF64
: public ConversionPattern
{
1055 TestChangeProducerTypeF32ToF64(MLIRContext
*ctx
)
1056 : ConversionPattern("test.type_producer", 1, ctx
) {}
1058 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1059 ConversionPatternRewriter
&rewriter
) const final
{
1060 // If the type is F32, change the type to F64.
1061 if (!Type(*op
->result_type_begin()).isF32())
1062 return rewriter
.notifyMatchFailure(op
, "expected single f32 operand");
1063 rewriter
.replaceOpWithNewOp
<TestTypeProducerOp
>(op
, rewriter
.getF64Type());
1067 struct TestChangeProducerTypeF32ToInvalid
: public ConversionPattern
{
1068 TestChangeProducerTypeF32ToInvalid(MLIRContext
*ctx
)
1069 : ConversionPattern("test.type_producer", 10, ctx
) {}
1071 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1072 ConversionPatternRewriter
&rewriter
) const final
{
1073 // Always convert to B16, even though it is not a legal type. This tests
1074 // that values are unmapped correctly.
1075 rewriter
.replaceOpWithNewOp
<TestTypeProducerOp
>(op
, rewriter
.getBF16Type());
1079 struct TestUpdateConsumerType
: public ConversionPattern
{
1080 TestUpdateConsumerType(MLIRContext
*ctx
)
1081 : ConversionPattern("test.type_consumer", 1, ctx
) {}
1083 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1084 ConversionPatternRewriter
&rewriter
) const final
{
1085 // Verify that the incoming operand has been successfully remapped to F64.
1086 if (!operands
[0].getType().isF64())
1088 rewriter
.replaceOpWithNewOp
<TestTypeConsumerOp
>(op
, operands
[0]);
1093 //===----------------------------------------------------------------------===//
1094 // Non-Root Replacement Rewrite Testing
1095 /// This pattern generates an invalid operation, but replaces it before the
1096 /// pattern is finished. This checks that we don't need to legalize the
1098 struct TestNonRootReplacement
: public RewritePattern
{
1099 TestNonRootReplacement(MLIRContext
*ctx
)
1100 : RewritePattern("test.replace_non_root", 1, ctx
) {}
1102 LogicalResult
matchAndRewrite(Operation
*op
,
1103 PatternRewriter
&rewriter
) const final
{
1104 auto resultType
= *op
->result_type_begin();
1105 auto illegalOp
= rewriter
.create
<ILLegalOpF
>(op
->getLoc(), resultType
);
1106 auto legalOp
= rewriter
.create
<LegalOpB
>(op
->getLoc(), resultType
);
1108 rewriter
.replaceOp(illegalOp
, legalOp
);
1109 rewriter
.replaceOp(op
, illegalOp
);
1114 //===----------------------------------------------------------------------===//
1115 // Recursive Rewrite Testing
1116 /// This pattern is applied to the same operation multiple times, but has a
1117 /// bounded recursion.
1118 struct TestBoundedRecursiveRewrite
1119 : public OpRewritePattern
<TestRecursiveRewriteOp
> {
1120 using OpRewritePattern
<TestRecursiveRewriteOp
>::OpRewritePattern
;
1123 // The conversion target handles bounding the recursion of this pattern.
1124 setHasBoundedRewriteRecursion();
1127 LogicalResult
matchAndRewrite(TestRecursiveRewriteOp op
,
1128 PatternRewriter
&rewriter
) const final
{
1129 // Decrement the depth of the op in-place.
1130 rewriter
.modifyOpInPlace(op
, [&] {
1131 op
->setAttr("depth", rewriter
.getI64IntegerAttr(op
.getDepth() - 1));
1137 struct TestNestedOpCreationUndoRewrite
1138 : public OpRewritePattern
<IllegalOpWithRegionAnchor
> {
1139 using OpRewritePattern
<IllegalOpWithRegionAnchor
>::OpRewritePattern
;
1141 LogicalResult
matchAndRewrite(IllegalOpWithRegionAnchor op
,
1142 PatternRewriter
&rewriter
) const final
{
1143 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1144 rewriter
.replaceOpWithNewOp
<IllegalOpWithRegion
>(op
);
1149 // This pattern matches `test.blackhole` and delete this op and its producer.
1150 struct TestReplaceEraseOp
: public OpRewritePattern
<BlackHoleOp
> {
1151 using OpRewritePattern
<BlackHoleOp
>::OpRewritePattern
;
1153 LogicalResult
matchAndRewrite(BlackHoleOp op
,
1154 PatternRewriter
&rewriter
) const final
{
1155 Operation
*producer
= op
.getOperand().getDefiningOp();
1156 // Always erase the user before the producer, the framework should handle
1158 rewriter
.eraseOp(op
);
1159 rewriter
.eraseOp(producer
);
1164 // This pattern replaces explicitly illegal op with explicitly legal op,
1165 // but in addition creates unregistered operation.
1166 struct TestCreateUnregisteredOp
: public OpRewritePattern
<ILLegalOpG
> {
1167 using OpRewritePattern
<ILLegalOpG
>::OpRewritePattern
;
1169 LogicalResult
matchAndRewrite(ILLegalOpG op
,
1170 PatternRewriter
&rewriter
) const final
{
1171 IntegerAttr attr
= rewriter
.getI32IntegerAttr(0);
1172 Value val
= rewriter
.create
<arith::ConstantOp
>(op
->getLoc(), attr
);
1173 rewriter
.replaceOpWithNewOp
<LegalOpC
>(op
, val
);
1178 class TestEraseOp
: public ConversionPattern
{
1180 TestEraseOp(MLIRContext
*ctx
) : ConversionPattern("test.erase_op", 1, ctx
) {}
1182 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1183 ConversionPatternRewriter
&rewriter
) const final
{
1184 // Erase op without replacements.
1185 rewriter
.eraseOp(op
);
1190 /// This pattern matches a test.duplicate_block_args op and duplicates all
1191 /// block arguments.
1192 class TestDuplicateBlockArgs
1193 : public OpConversionPattern
<DuplicateBlockArgsOp
> {
1194 using OpConversionPattern
<DuplicateBlockArgsOp
>::OpConversionPattern
;
1197 matchAndRewrite(DuplicateBlockArgsOp op
, OpAdaptor adaptor
,
1198 ConversionPatternRewriter
&rewriter
) const override
{
1199 if (op
.getIsLegal())
1201 rewriter
.startOpModification(op
);
1202 Block
*body
= &op
.getBody().front();
1203 TypeConverter::SignatureConversion
result(body
->getNumArguments());
1204 for (auto it
: llvm::enumerate(body
->getArgumentTypes()))
1205 result
.addInputs(it
.index(), {it
.value(), it
.value()});
1206 rewriter
.applySignatureConversion(body
, result
, getTypeConverter());
1207 op
.setIsLegal(true);
1208 rewriter
.finalizeOpModification(op
);
1213 /// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid
1214 /// op. The pattern supports 1:N replacements and forwards the replacement
1215 /// values of the single operand as test.valid operands.
1216 class TestRepetitive1ToNConsumer
: public ConversionPattern
{
1218 TestRepetitive1ToNConsumer(MLIRContext
*ctx
)
1219 : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx
) {}
1221 matchAndRewrite(Operation
*op
, ArrayRef
<ValueRange
> operands
,
1222 ConversionPatternRewriter
&rewriter
) const final
{
1223 // A single operand is expected.
1224 if (op
->getNumOperands() != 1)
1226 rewriter
.replaceOpWithNewOp
<TestValidOp
>(op
, operands
.front());
1234 struct TestTypeConverter
: public TypeConverter
{
1235 using TypeConverter::TypeConverter
;
1236 TestTypeConverter() {
1237 addConversion(convertType
);
1238 addArgumentMaterialization(materializeCast
);
1239 addSourceMaterialization(materializeCast
);
1242 static LogicalResult
convertType(Type t
, SmallVectorImpl
<Type
> &results
) {
1244 if (t
.isSignlessInteger(16))
1247 // Convert I64 to F64.
1248 if (t
.isSignlessInteger(64)) {
1249 results
.push_back(FloatType::getF64(t
.getContext()));
1253 // Convert I42 to I43.
1254 if (t
.isInteger(42)) {
1255 results
.push_back(IntegerType::get(t
.getContext(), 43));
1259 // Split F32 into F16,F16.
1261 results
.assign(2, FloatType::getF16(t
.getContext()));
1266 if (t
.isInteger(24)) {
1270 // Otherwise, convert the type directly.
1271 results
.push_back(t
);
1275 /// Hook for materializing a conversion. This is necessary because we generate
1276 /// 1->N type mappings.
1277 static Value
materializeCast(OpBuilder
&builder
, Type resultType
,
1278 ValueRange inputs
, Location loc
) {
1279 return builder
.create
<TestCastOp
>(loc
, resultType
, inputs
).getResult();
1283 struct TestLegalizePatternDriver
1284 : public PassWrapper
<TestLegalizePatternDriver
, OperationPass
<>> {
1285 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver
)
1287 StringRef
getArgument() const final
{ return "test-legalize-patterns"; }
1288 StringRef
getDescription() const final
{
1289 return "Run test dialect legalization patterns";
1291 /// The mode of conversion to use with the driver.
1292 enum class ConversionMode
{ Analysis
, Full
, Partial
};
1294 TestLegalizePatternDriver(ConversionMode mode
) : mode(mode
) {}
1296 void getDependentDialects(DialectRegistry
®istry
) const override
{
1297 registry
.insert
<func::FuncDialect
, test::TestDialect
>();
1300 void runOnOperation() override
{
1301 TestTypeConverter converter
;
1302 mlir::RewritePatternSet
patterns(&getContext());
1303 populateWithGenerated(patterns
);
1305 TestRegionRewriteBlockMovement
, TestDetachedSignatureConversion
,
1306 TestRegionRewriteUndo
, TestCreateBlock
, TestCreateIllegalBlock
,
1307 TestUndoBlockArgReplace
, TestUndoBlockErase
, TestPassthroughInvalidOp
,
1308 TestSplitReturnType
, TestChangeProducerTypeI32ToF32
,
1309 TestChangeProducerTypeF32ToF64
, TestChangeProducerTypeF32ToInvalid
,
1310 TestUpdateConsumerType
, TestNonRootReplacement
,
1311 TestBoundedRecursiveRewrite
, TestNestedOpCreationUndoRewrite
,
1312 TestReplaceEraseOp
, TestCreateUnregisteredOp
, TestUndoMoveOpBefore
,
1313 TestUndoPropertiesModification
, TestEraseOp
,
1314 TestRepetitive1ToNConsumer
>(&getContext());
1315 patterns
.add
<TestDropOpSignatureConversion
, TestDropAndReplaceInvalidOp
>(
1316 &getContext(), converter
);
1317 patterns
.add
<TestDuplicateBlockArgs
>(converter
, &getContext());
1318 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns
,
1320 mlir::populateCallOpTypeConversionPattern(patterns
, converter
);
1322 // Define the conversion target used for the test.
1323 ConversionTarget
target(getContext());
1324 target
.addLegalOp
<ModuleOp
>();
1325 target
.addLegalOp
<LegalOpA
, LegalOpB
, LegalOpC
, TestCastOp
, TestValidOp
,
1326 TerminatorOp
, OneRegionOp
>();
1328 OperationName("test.legal_op_with_region", &getContext()));
1330 .addIllegalOp
<ILLegalOpF
, TestRegionBuilderOp
, TestOpWithRegionFold
>();
1331 target
.addDynamicallyLegalOp
<TestReturnOp
>([](TestReturnOp op
) {
1332 // Don't allow F32 operands.
1333 return llvm::none_of(op
.getOperandTypes(),
1334 [](Type type
) { return type
.isF32(); });
1336 target
.addDynamicallyLegalOp
<func::FuncOp
>([&](func::FuncOp op
) {
1337 return converter
.isSignatureLegal(op
.getFunctionType()) &&
1338 converter
.isLegal(&op
.getBody());
1340 target
.addDynamicallyLegalOp
<func::CallOp
>(
1341 [&](func::CallOp op
) { return converter
.isLegal(op
); });
1343 // TestCreateUnregisteredOp creates `arith.constant` operation,
1344 // which was not added to target intentionally to test
1345 // correct error code from conversion driver.
1346 target
.addDynamicallyLegalOp
<ILLegalOpG
>([](ILLegalOpG
) { return false; });
1348 // Expect the type_producer/type_consumer operations to only operate on f64.
1349 target
.addDynamicallyLegalOp
<TestTypeProducerOp
>(
1350 [](TestTypeProducerOp op
) { return op
.getType().isF64(); });
1351 target
.addDynamicallyLegalOp
<TestTypeConsumerOp
>([](TestTypeConsumerOp op
) {
1352 return op
.getOperand().getType().isF64();
1355 // Check support for marking certain operations as recursively legal.
1356 target
.markOpRecursivelyLegal
<func::FuncOp
, ModuleOp
>([](Operation
*op
) {
1357 return static_cast<bool>(
1358 op
->getAttrOfType
<UnitAttr
>("test.recursively_legal"));
1361 // Mark the bound recursion operation as dynamically legal.
1362 target
.addDynamicallyLegalOp
<TestRecursiveRewriteOp
>(
1363 [](TestRecursiveRewriteOp op
) { return op
.getDepth() == 0; });
1365 // Create a dynamically legal rule that can only be legalized by folding it.
1366 target
.addDynamicallyLegalOp
<TestOpInPlaceSelfFold
>(
1367 [](TestOpInPlaceSelfFold op
) { return op
.getFolded(); });
1369 target
.addDynamicallyLegalOp
<DuplicateBlockArgsOp
>(
1370 [](DuplicateBlockArgsOp op
) { return op
.getIsLegal(); });
1372 // Handle a partial conversion.
1373 if (mode
== ConversionMode::Partial
) {
1374 DenseSet
<Operation
*> unlegalizedOps
;
1375 ConversionConfig config
;
1376 DumpNotifications dumpNotifications
;
1377 config
.listener
= &dumpNotifications
;
1378 config
.unlegalizedOps
= &unlegalizedOps
;
1379 if (failed(applyPartialConversion(getOperation(), target
,
1380 std::move(patterns
), config
))) {
1381 getOperation()->emitRemark() << "applyPartialConversion failed";
1383 // Emit remarks for each legalizable operation.
1384 for (auto *op
: unlegalizedOps
)
1385 op
->emitRemark() << "op '" << op
->getName() << "' is not legalizable";
1389 // Handle a full conversion.
1390 if (mode
== ConversionMode::Full
) {
1391 // Check support for marking unknown operations as dynamically legal.
1392 target
.markUnknownOpDynamicallyLegal([](Operation
*op
) {
1393 return (bool)op
->getAttrOfType
<UnitAttr
>("test.dynamically_legal");
1396 ConversionConfig config
;
1397 DumpNotifications dumpNotifications
;
1398 config
.listener
= &dumpNotifications
;
1399 if (failed(applyFullConversion(getOperation(), target
,
1400 std::move(patterns
), config
))) {
1401 getOperation()->emitRemark() << "applyFullConversion failed";
1406 // Otherwise, handle an analysis conversion.
1407 assert(mode
== ConversionMode::Analysis
);
1409 // Analyze the convertible operations.
1410 DenseSet
<Operation
*> legalizedOps
;
1411 ConversionConfig config
;
1412 config
.legalizableOps
= &legalizedOps
;
1413 if (failed(applyAnalysisConversion(getOperation(), target
,
1414 std::move(patterns
), config
)))
1415 return signalPassFailure();
1417 // Emit remarks for each legalizable operation.
1418 for (auto *op
: legalizedOps
)
1419 op
->emitRemark() << "op '" << op
->getName() << "' is legalizable";
1422 /// The mode of conversion to use.
1423 ConversionMode mode
;
1427 static llvm::cl::opt
<TestLegalizePatternDriver::ConversionMode
>
1428 legalizerConversionMode(
1429 "test-legalize-mode",
1430 llvm::cl::desc("The legalization mode to use with the test driver"),
1431 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial
),
1433 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis
,
1434 "analysis", "Perform an analysis conversion"),
1435 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full
, "full",
1436 "Perform a full conversion"),
1437 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial
,
1438 "partial", "Perform a partial conversion")));
1440 //===----------------------------------------------------------------------===//
1441 // ConversionPatternRewriter::getRemappedValue testing. This method is used
1442 // to get the remapped value of an original value that was replaced using
1443 // ConversionPatternRewriter.
1445 struct TestRemapValueTypeConverter
: public TypeConverter
{
1446 using TypeConverter::TypeConverter
;
1448 TestRemapValueTypeConverter() {
1450 [](Float32Type type
) { return Float64Type::get(type
.getContext()); });
1451 addConversion([](Type type
) { return type
; });
1455 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
1456 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
1460 /// %1 = test.one_variadic_out_one_variadic_in1"(%0)
1461 /// is replaced with:
1462 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
1463 struct OneVResOneVOperandOp1Converter
1464 : public OpConversionPattern
<OneVResOneVOperandOp1
> {
1465 using OpConversionPattern
<OneVResOneVOperandOp1
>::OpConversionPattern
;
1468 matchAndRewrite(OneVResOneVOperandOp1 op
, OpAdaptor adaptor
,
1469 ConversionPatternRewriter
&rewriter
) const override
{
1470 auto origOps
= op
.getOperands();
1471 assert(std::distance(origOps
.begin(), origOps
.end()) == 1 &&
1472 "One operand expected");
1473 Value origOp
= *origOps
.begin();
1474 SmallVector
<Value
, 2> remappedOperands
;
1475 // Replicate the remapped original operand twice. Note that we don't used
1476 // the remapped 'operand' since the goal is testing 'getRemappedValue'.
1477 remappedOperands
.push_back(rewriter
.getRemappedValue(origOp
));
1478 remappedOperands
.push_back(rewriter
.getRemappedValue(origOp
));
1480 rewriter
.replaceOpWithNewOp
<OneVResOneVOperandOp1
>(op
, op
.getResultTypes(),
1486 /// A rewriter pattern that tests that blocks can be merged.
1487 struct TestRemapValueInRegion
1488 : public OpConversionPattern
<TestRemappedValueRegionOp
> {
1489 using OpConversionPattern
<TestRemappedValueRegionOp
>::OpConversionPattern
;
1492 matchAndRewrite(TestRemappedValueRegionOp op
, OpAdaptor adaptor
,
1493 ConversionPatternRewriter
&rewriter
) const final
{
1494 Block
&block
= op
.getBody().front();
1495 Operation
*terminator
= block
.getTerminator();
1497 // Merge the block into the parent region.
1498 Block
*parentBlock
= op
->getBlock();
1499 Block
*finalBlock
= rewriter
.splitBlock(parentBlock
, op
->getIterator());
1500 rewriter
.mergeBlocks(&block
, parentBlock
, ValueRange());
1501 rewriter
.mergeBlocks(finalBlock
, parentBlock
, ValueRange());
1503 // Replace the results of this operation with the remapped terminator
1505 SmallVector
<Value
> terminatorOperands
;
1506 if (failed(rewriter
.getRemappedValues(terminator
->getOperands(),
1507 terminatorOperands
)))
1510 rewriter
.eraseOp(terminator
);
1511 rewriter
.replaceOp(op
, terminatorOperands
);
1516 struct TestRemappedValue
1517 : public mlir::PassWrapper
<TestRemappedValue
, OperationPass
<>> {
1518 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue
)
1520 StringRef
getArgument() const final
{ return "test-remapped-value"; }
1521 StringRef
getDescription() const final
{
1522 return "Test public remapped value mechanism in ConversionPatternRewriter";
1524 void runOnOperation() override
{
1525 TestRemapValueTypeConverter typeConverter
;
1527 mlir::RewritePatternSet
patterns(&getContext());
1528 patterns
.add
<OneVResOneVOperandOp1Converter
>(&getContext());
1529 patterns
.add
<TestChangeProducerTypeF32ToF64
, TestUpdateConsumerType
>(
1531 patterns
.add
<TestRemapValueInRegion
>(typeConverter
, &getContext());
1533 mlir::ConversionTarget
target(getContext());
1534 target
.addLegalOp
<ModuleOp
, func::FuncOp
, TestReturnOp
>();
1536 // Expect the type_producer/type_consumer operations to only operate on f64.
1537 target
.addDynamicallyLegalOp
<TestTypeProducerOp
>(
1538 [](TestTypeProducerOp op
) { return op
.getType().isF64(); });
1539 target
.addDynamicallyLegalOp
<TestTypeConsumerOp
>([](TestTypeConsumerOp op
) {
1540 return op
.getOperand().getType().isF64();
1543 // We make OneVResOneVOperandOp1 legal only when it has more that one
1544 // operand. This will trigger the conversion that will replace one-operand
1545 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1546 target
.addDynamicallyLegalOp
<OneVResOneVOperandOp1
>(
1547 [](Operation
*op
) { return op
->getNumOperands() > 1; });
1549 if (failed(mlir::applyFullConversion(getOperation(), target
,
1550 std::move(patterns
)))) {
1551 signalPassFailure();
1557 //===----------------------------------------------------------------------===//
1558 // Test patterns without a specific root operation kind
1559 //===----------------------------------------------------------------------===//
1562 /// This pattern matches and removes any operation in the test dialect.
1563 struct RemoveTestDialectOps
: public RewritePattern
{
1564 RemoveTestDialectOps(MLIRContext
*context
)
1565 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context
) {}
1567 LogicalResult
matchAndRewrite(Operation
*op
,
1568 PatternRewriter
&rewriter
) const override
{
1569 if (!isa
<TestDialect
>(op
->getDialect()))
1571 rewriter
.eraseOp(op
);
1576 struct TestUnknownRootOpDriver
1577 : public mlir::PassWrapper
<TestUnknownRootOpDriver
, OperationPass
<>> {
1578 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver
)
1580 StringRef
getArgument() const final
{
1581 return "test-legalize-unknown-root-patterns";
1583 StringRef
getDescription() const final
{
1584 return "Test public remapped value mechanism in ConversionPatternRewriter";
1586 void runOnOperation() override
{
1587 mlir::RewritePatternSet
patterns(&getContext());
1588 patterns
.add
<RemoveTestDialectOps
>(&getContext());
1590 mlir::ConversionTarget
target(getContext());
1591 target
.addIllegalDialect
<TestDialect
>();
1592 if (failed(applyPartialConversion(getOperation(), target
,
1593 std::move(patterns
))))
1594 signalPassFailure();
1599 //===----------------------------------------------------------------------===//
1600 // Test patterns that uses operations and types defined at runtime
1601 //===----------------------------------------------------------------------===//
1604 /// This pattern matches dynamic operations 'test.one_operand_two_results' and
1605 /// replace them with dynamic operations 'test.generic_dynamic_op'.
1606 struct RewriteDynamicOp
: public RewritePattern
{
1607 RewriteDynamicOp(MLIRContext
*context
)
1608 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
1611 LogicalResult
matchAndRewrite(Operation
*op
,
1612 PatternRewriter
&rewriter
) const override
{
1613 assert(op
->getName().getStringRef() ==
1614 "test.dynamic_one_operand_two_results" &&
1615 "rewrite pattern should only match operations with the right name");
1617 OperationState
state(op
->getLoc(), "test.dynamic_generic",
1618 op
->getOperands(), op
->getResultTypes(),
1620 auto *newOp
= rewriter
.create(state
);
1621 rewriter
.replaceOp(op
, newOp
->getResults());
1626 struct TestRewriteDynamicOpDriver
1627 : public PassWrapper
<TestRewriteDynamicOpDriver
, OperationPass
<>> {
1628 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver
)
1630 void getDependentDialects(DialectRegistry
®istry
) const override
{
1631 registry
.insert
<TestDialect
>();
1633 StringRef
getArgument() const final
{ return "test-rewrite-dynamic-op"; }
1634 StringRef
getDescription() const final
{
1635 return "Test rewritting on dynamic operations";
1637 void runOnOperation() override
{
1638 RewritePatternSet
patterns(&getContext());
1639 patterns
.add
<RewriteDynamicOp
>(&getContext());
1641 ConversionTarget
target(getContext());
1642 target
.addIllegalOp(
1643 OperationName("test.dynamic_one_operand_two_results", &getContext()));
1644 target
.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
1645 if (failed(applyPartialConversion(getOperation(), target
,
1646 std::move(patterns
))))
1647 signalPassFailure();
1650 } // end anonymous namespace
1652 //===----------------------------------------------------------------------===//
1653 // Test type conversions
1654 //===----------------------------------------------------------------------===//
1657 struct TestTypeConversionProducer
1658 : public OpConversionPattern
<TestTypeProducerOp
> {
1659 using OpConversionPattern
<TestTypeProducerOp
>::OpConversionPattern
;
1661 matchAndRewrite(TestTypeProducerOp op
, OpAdaptor adaptor
,
1662 ConversionPatternRewriter
&rewriter
) const final
{
1663 Type resultType
= op
.getType();
1664 Type convertedType
= getTypeConverter()
1665 ? getTypeConverter()->convertType(resultType
)
1667 if (isa
<FloatType
>(resultType
))
1668 resultType
= rewriter
.getF64Type();
1669 else if (resultType
.isInteger(16))
1670 resultType
= rewriter
.getIntegerType(64);
1671 else if (isa
<test::TestRecursiveType
>(resultType
) &&
1672 convertedType
!= resultType
)
1673 resultType
= convertedType
;
1677 rewriter
.replaceOpWithNewOp
<TestTypeProducerOp
>(op
, resultType
);
1682 /// Call signature conversion and then fail the rewrite to trigger the undo
1684 struct TestSignatureConversionUndo
1685 : public OpConversionPattern
<TestSignatureConversionUndoOp
> {
1686 using OpConversionPattern
<TestSignatureConversionUndoOp
>::OpConversionPattern
;
1689 matchAndRewrite(TestSignatureConversionUndoOp op
, OpAdaptor adaptor
,
1690 ConversionPatternRewriter
&rewriter
) const final
{
1691 (void)rewriter
.convertRegionTypes(&op
->getRegion(0), *getTypeConverter());
1696 /// Call signature conversion without providing a type converter to handle
1697 /// materializations.
1698 struct TestTestSignatureConversionNoConverter
1699 : public OpConversionPattern
<TestSignatureConversionNoConverterOp
> {
1700 TestTestSignatureConversionNoConverter(const TypeConverter
&converter
,
1701 MLIRContext
*context
)
1702 : OpConversionPattern
<TestSignatureConversionNoConverterOp
>(context
),
1703 converter(converter
) {}
1706 matchAndRewrite(TestSignatureConversionNoConverterOp op
, OpAdaptor adaptor
,
1707 ConversionPatternRewriter
&rewriter
) const final
{
1708 Region
®ion
= op
->getRegion(0);
1709 Block
*entry
= ®ion
.front();
1711 // Convert the original entry arguments.
1712 TypeConverter::SignatureConversion
result(entry
->getNumArguments());
1714 converter
.convertSignatureArgs(entry
->getArgumentTypes(), result
)))
1716 rewriter
.modifyOpInPlace(op
, [&] {
1717 rewriter
.applySignatureConversion(®ion
.front(), result
);
1722 const TypeConverter
&converter
;
1725 /// Just forward the operands to the root op. This is essentially a no-op
1726 /// pattern that is used to trigger target materialization.
1727 struct TestTypeConsumerForward
1728 : public OpConversionPattern
<TestTypeConsumerOp
> {
1729 using OpConversionPattern
<TestTypeConsumerOp
>::OpConversionPattern
;
1732 matchAndRewrite(TestTypeConsumerOp op
, OpAdaptor adaptor
,
1733 ConversionPatternRewriter
&rewriter
) const final
{
1734 rewriter
.modifyOpInPlace(op
,
1735 [&] { op
->setOperands(adaptor
.getOperands()); });
1740 struct TestTypeConversionAnotherProducer
1741 : public OpRewritePattern
<TestAnotherTypeProducerOp
> {
1742 using OpRewritePattern
<TestAnotherTypeProducerOp
>::OpRewritePattern
;
1744 LogicalResult
matchAndRewrite(TestAnotherTypeProducerOp op
,
1745 PatternRewriter
&rewriter
) const final
{
1746 rewriter
.replaceOpWithNewOp
<TestTypeProducerOp
>(op
, op
.getType());
1751 struct TestReplaceWithLegalOp
: public ConversionPattern
{
1752 TestReplaceWithLegalOp(MLIRContext
*ctx
)
1753 : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx
) {}
1755 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1756 ConversionPatternRewriter
&rewriter
) const final
{
1757 rewriter
.replaceOpWithNewOp
<LegalOpD
>(op
, operands
[0]);
1762 struct TestTypeConversionDriver
1763 : public PassWrapper
<TestTypeConversionDriver
, OperationPass
<>> {
1764 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver
)
1766 void getDependentDialects(DialectRegistry
®istry
) const override
{
1767 registry
.insert
<TestDialect
>();
1769 StringRef
getArgument() const final
{
1770 return "test-legalize-type-conversion";
1772 StringRef
getDescription() const final
{
1773 return "Test various type conversion functionalities in DialectConversion";
1776 void runOnOperation() override
{
1777 // Initialize the type converter.
1778 SmallVector
<Type
, 2> conversionCallStack
;
1779 TypeConverter converter
;
1781 /// Add the legal set of type conversions.
1782 converter
.addConversion([](Type type
) -> Type
{
1783 // Treat F64 as legal.
1786 // Allow converting BF16/F16/F32 to F64.
1787 if (type
.isBF16() || type
.isF16() || type
.isF32())
1788 return FloatType::getF64(type
.getContext());
1789 // Otherwise, the type is illegal.
1792 converter
.addConversion([](IntegerType type
, SmallVectorImpl
<Type
> &) {
1793 // Drop all integer types.
1796 converter
.addConversion(
1797 // Convert a recursive self-referring type into a non-self-referring
1798 // type named "outer_converted_type" that contains a SimpleAType.
1799 [&](test::TestRecursiveType type
,
1800 SmallVectorImpl
<Type
> &results
) -> std::optional
<LogicalResult
> {
1801 // If the type is already converted, return it to indicate that it is
1803 if (type
.getName() == "outer_converted_type") {
1804 results
.push_back(type
);
1808 conversionCallStack
.push_back(type
);
1809 auto popConversionCallStack
= llvm::make_scope_exit(
1810 [&conversionCallStack
]() { conversionCallStack
.pop_back(); });
1812 // If the type is on the call stack more than once (it is there at
1813 // least once because of the _current_ call, which is always the last
1814 // element on the stack), we've hit the recursive case. Just return
1815 // SimpleAType here to create a non-recursive type as a result.
1816 if (llvm::is_contained(ArrayRef(conversionCallStack
).drop_back(),
1818 results
.push_back(test::SimpleAType::get(type
.getContext()));
1822 // Convert the body recursively.
1823 auto result
= test::TestRecursiveType::get(type
.getContext(),
1824 "outer_converted_type");
1825 if (failed(result
.setBody(converter
.convertType(type
.getBody()))))
1827 results
.push_back(result
);
1831 /// Add the legal set of type materializations.
1832 converter
.addSourceMaterialization([](OpBuilder
&builder
, Type resultType
,
1834 Location loc
) -> Value
{
1835 // Allow casting from F64 back to F32.
1836 if (!resultType
.isF16() && inputs
.size() == 1 &&
1837 inputs
[0].getType().isF64())
1838 return builder
.create
<TestCastOp
>(loc
, resultType
, inputs
).getResult();
1839 // Allow producing an i32 or i64 from nothing.
1840 if ((resultType
.isInteger(32) || resultType
.isInteger(64)) &&
1842 return builder
.create
<TestTypeProducerOp
>(loc
, resultType
);
1843 // Allow producing an i64 from an integer.
1844 if (isa
<IntegerType
>(resultType
) && inputs
.size() == 1 &&
1845 isa
<IntegerType
>(inputs
[0].getType()))
1846 return builder
.create
<TestCastOp
>(loc
, resultType
, inputs
).getResult();
1851 // Initialize the conversion target.
1852 mlir::ConversionTarget
target(getContext());
1853 target
.addLegalOp
<LegalOpD
>();
1854 target
.addDynamicallyLegalOp
<TestTypeProducerOp
>([](TestTypeProducerOp op
) {
1855 auto recursiveType
= dyn_cast
<test::TestRecursiveType
>(op
.getType());
1856 return op
.getType().isF64() || op
.getType().isInteger(64) ||
1858 recursiveType
.getName() == "outer_converted_type");
1860 target
.addDynamicallyLegalOp
<func::FuncOp
>([&](func::FuncOp op
) {
1861 return converter
.isSignatureLegal(op
.getFunctionType()) &&
1862 converter
.isLegal(&op
.getBody());
1864 target
.addDynamicallyLegalOp
<TestCastOp
>([&](TestCastOp op
) {
1865 // Allow casts from F64 to F32.
1866 return (*op
.operand_type_begin()).isF64() && op
.getType().isF32();
1868 target
.addDynamicallyLegalOp
<TestSignatureConversionNoConverterOp
>(
1869 [&](TestSignatureConversionNoConverterOp op
) {
1870 return converter
.isLegal(op
.getRegion().front().getArgumentTypes());
1873 // Initialize the set of rewrite patterns.
1874 RewritePatternSet
patterns(&getContext());
1875 patterns
.add
<TestTypeConsumerForward
, TestTypeConversionProducer
,
1876 TestSignatureConversionUndo
,
1877 TestTestSignatureConversionNoConverter
>(converter
,
1879 patterns
.add
<TestTypeConversionAnotherProducer
, TestReplaceWithLegalOp
>(
1881 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns
,
1884 if (failed(applyPartialConversion(getOperation(), target
,
1885 std::move(patterns
))))
1886 signalPassFailure();
//===----------------------------------------------------------------------===//
// Test Target Materialization With No Uses
//===----------------------------------------------------------------------===//
1896 struct ForwardOperandPattern
: public OpConversionPattern
<TestTypeChangerOp
> {
1897 using OpConversionPattern
<TestTypeChangerOp
>::OpConversionPattern
;
1900 matchAndRewrite(TestTypeChangerOp op
, OpAdaptor adaptor
,
1901 ConversionPatternRewriter
&rewriter
) const final
{
1902 rewriter
.replaceOp(op
, adaptor
.getOperands());
1907 struct TestTargetMaterializationWithNoUses
1908 : public PassWrapper
<TestTargetMaterializationWithNoUses
, OperationPass
<>> {
1909 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1910 TestTargetMaterializationWithNoUses
)
1912 StringRef
getArgument() const final
{
1913 return "test-target-materialization-with-no-uses";
1915 StringRef
getDescription() const final
{
1916 return "Test a special case of target materialization in DialectConversion";
1919 void runOnOperation() override
{
1920 TypeConverter converter
;
1921 converter
.addConversion([](Type t
) { return t
; });
1922 converter
.addConversion([](IntegerType intTy
) -> Type
{
1923 if (intTy
.getWidth() == 16)
1924 return IntegerType::get(intTy
.getContext(), 64);
1927 converter
.addTargetMaterialization(
1928 [](OpBuilder
&builder
, Type type
, ValueRange inputs
, Location loc
) {
1929 return builder
.create
<TestCastOp
>(loc
, type
, inputs
).getResult();
1932 ConversionTarget
target(getContext());
1933 target
.addIllegalOp
<TestTypeChangerOp
>();
1935 RewritePatternSet
patterns(&getContext());
1936 patterns
.add
<ForwardOperandPattern
>(converter
, &getContext());
1938 if (failed(applyPartialConversion(getOperation(), target
,
1939 std::move(patterns
))))
1940 signalPassFailure();
//===----------------------------------------------------------------------===//
// Test Block Merging
//===----------------------------------------------------------------------===//
1950 /// A rewriter pattern that tests that blocks can be merged.
1951 struct TestMergeBlock
: public OpConversionPattern
<TestMergeBlocksOp
> {
1952 using OpConversionPattern
<TestMergeBlocksOp
>::OpConversionPattern
;
1955 matchAndRewrite(TestMergeBlocksOp op
, OpAdaptor adaptor
,
1956 ConversionPatternRewriter
&rewriter
) const final
{
1957 Block
&firstBlock
= op
.getBody().front();
1958 Operation
*branchOp
= firstBlock
.getTerminator();
1959 Block
*secondBlock
= &*(std::next(op
.getBody().begin()));
1960 auto succOperands
= branchOp
->getOperands();
1961 SmallVector
<Value
, 2> replacements(succOperands
);
1962 rewriter
.eraseOp(branchOp
);
1963 rewriter
.mergeBlocks(secondBlock
, &firstBlock
, replacements
);
1964 rewriter
.modifyOpInPlace(op
, [] {});
1969 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1970 struct TestUndoBlocksMerge
: public ConversionPattern
{
1971 TestUndoBlocksMerge(MLIRContext
*ctx
)
1972 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx
) {}
1974 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1975 ConversionPatternRewriter
&rewriter
) const final
{
1976 Block
&firstBlock
= op
->getRegion(0).front();
1977 Operation
*branchOp
= firstBlock
.getTerminator();
1978 Block
*secondBlock
= &*(std::next(op
->getRegion(0).begin()));
1979 rewriter
.setInsertionPointToStart(secondBlock
);
1980 rewriter
.create
<ILLegalOpF
>(op
->getLoc(), rewriter
.getF32Type());
1981 auto succOperands
= branchOp
->getOperands();
1982 SmallVector
<Value
, 2> replacements(succOperands
);
1983 rewriter
.eraseOp(branchOp
);
1984 rewriter
.mergeBlocks(secondBlock
, &firstBlock
, replacements
);
1985 rewriter
.modifyOpInPlace(op
, [] {});
1990 /// A rewrite mechanism to inline the body of the op into its parent, when both
1991 /// ops can have a single block.
1992 struct TestMergeSingleBlockOps
1993 : public OpConversionPattern
<SingleBlockImplicitTerminatorOp
> {
1994 using OpConversionPattern
<
1995 SingleBlockImplicitTerminatorOp
>::OpConversionPattern
;
1998 matchAndRewrite(SingleBlockImplicitTerminatorOp op
, OpAdaptor adaptor
,
1999 ConversionPatternRewriter
&rewriter
) const final
{
2000 SingleBlockImplicitTerminatorOp parentOp
=
2001 op
->getParentOfType
<SingleBlockImplicitTerminatorOp
>();
2004 Block
&innerBlock
= op
.getRegion().front();
2005 TerminatorOp innerTerminator
=
2006 cast
<TerminatorOp
>(innerBlock
.getTerminator());
2007 rewriter
.inlineBlockBefore(&innerBlock
, op
);
2008 rewriter
.eraseOp(innerTerminator
);
2009 rewriter
.eraseOp(op
);
2014 struct TestMergeBlocksPatternDriver
2015 : public PassWrapper
<TestMergeBlocksPatternDriver
, OperationPass
<>> {
2016 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver
)
2018 StringRef
getArgument() const final
{ return "test-merge-blocks"; }
2019 StringRef
getDescription() const final
{
2020 return "Test Merging operation in ConversionPatternRewriter";
2022 void runOnOperation() override
{
2023 MLIRContext
*context
= &getContext();
2024 mlir::RewritePatternSet
patterns(context
);
2025 patterns
.add
<TestMergeBlock
, TestUndoBlocksMerge
, TestMergeSingleBlockOps
>(
2027 ConversionTarget
target(*context
);
2028 target
.addLegalOp
<func::FuncOp
, ModuleOp
, TerminatorOp
, TestBranchOp
,
2029 TestTypeConsumerOp
, TestTypeProducerOp
, TestReturnOp
>();
2030 target
.addIllegalOp
<ILLegalOpF
>();
2032 /// Expect the op to have a single block after legalization.
2033 target
.addDynamicallyLegalOp
<TestMergeBlocksOp
>(
2034 [&](TestMergeBlocksOp op
) -> bool {
2035 return llvm::hasSingleElement(op
.getBody());
2038 /// Only allow `test.br` within test.merge_blocks op.
2039 target
.addDynamicallyLegalOp
<TestBranchOp
>([&](TestBranchOp op
) -> bool {
2040 return op
->getParentOfType
<TestMergeBlocksOp
>();
2043 /// Expect that all nested test.SingleBlockImplicitTerminator ops are
2045 target
.addDynamicallyLegalOp
<SingleBlockImplicitTerminatorOp
>(
2046 [&](SingleBlockImplicitTerminatorOp op
) -> bool {
2047 return !op
->getParentOfType
<SingleBlockImplicitTerminatorOp
>();
2050 DenseSet
<Operation
*> unlegalizedOps
;
2051 ConversionConfig config
;
2052 config
.unlegalizedOps
= &unlegalizedOps
;
2053 (void)applyPartialConversion(getOperation(), target
, std::move(patterns
),
2055 for (auto *op
: unlegalizedOps
)
2056 op
->emitRemark() << "op '" << op
->getName() << "' is not legalizable";
//===----------------------------------------------------------------------===//
// Test Selective Replacement
//===----------------------------------------------------------------------===//
2066 /// A rewrite mechanism to inline the body of the op into its parent, when both
2067 /// ops can have a single block.
2068 struct TestSelectiveOpReplacementPattern
: public OpRewritePattern
<TestCastOp
> {
2069 using OpRewritePattern
<TestCastOp
>::OpRewritePattern
;
2071 LogicalResult
matchAndRewrite(TestCastOp op
,
2072 PatternRewriter
&rewriter
) const final
{
2073 if (op
.getNumOperands() != 2)
2075 OperandRange operands
= op
.getOperands();
2077 // Replace non-terminator uses with the first operand.
2078 rewriter
.replaceUsesWithIf(op
, operands
[0], [](OpOperand
&operand
) {
2079 return operand
.getOwner()->hasTrait
<OpTrait::IsTerminator
>();
2081 // Replace everything else with the second operand if the operation isn't
2083 rewriter
.replaceOp(op
, op
.getOperand(1));
2088 struct TestSelectiveReplacementPatternDriver
2089 : public PassWrapper
<TestSelectiveReplacementPatternDriver
,
2091 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
2092 TestSelectiveReplacementPatternDriver
)
2094 StringRef
getArgument() const final
{
2095 return "test-pattern-selective-replacement";
2097 StringRef
getDescription() const final
{
2098 return "Test selective replacement in the PatternRewriter";
2100 void runOnOperation() override
{
2101 MLIRContext
*context
= &getContext();
2102 mlir::RewritePatternSet
patterns(context
);
2103 patterns
.add
<TestSelectiveOpReplacementPattern
>(context
);
2104 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns
));
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
2115 void registerPatternsTestPass() {
2116 PassRegistration
<TestReturnTypeDriver
>();
2118 PassRegistration
<TestDerivedAttributeDriver
>();
2120 PassRegistration
<TestGreedyPatternDriver
>();
2121 PassRegistration
<TestStrictPatternDriver
>();
2122 PassRegistration
<TestWalkPatternDriver
>();
2124 PassRegistration
<TestLegalizePatternDriver
>([] {
2125 return std::make_unique
<TestLegalizePatternDriver
>(legalizerConversionMode
);
2128 PassRegistration
<TestRemappedValue
>();
2130 PassRegistration
<TestUnknownRootOpDriver
>();
2132 PassRegistration
<TestTypeConversionDriver
>();
2133 PassRegistration
<TestTargetMaterializationWithNoUses
>();
2135 PassRegistration
<TestRewriteDynamicOpDriver
>();
2137 PassRegistration
<TestMergeBlocksPatternDriver
>();
2138 PassRegistration
<TestSelectiveReplacementPatternDriver
>();