//===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "TestOps.h"
#include "TestTypes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include <cstdint>

using namespace mlir;
using namespace test;
30 // Native function for testing NativeCodeCall
31 static Value
chooseOperand(Value input1
, Value input2
, BoolAttr choice
) {
32 return choice
.getValue() ? input1
: input2
;
35 static void createOpI(PatternRewriter
&rewriter
, Location loc
, Value input
) {
36 rewriter
.create
<OpI
>(loc
, input
);
39 static void handleNoResultOp(PatternRewriter
&rewriter
,
40 OpSymbolBindingNoResult op
) {
41 // Turn the no result op to a one-result op.
42 rewriter
.create
<OpSymbolBindingB
>(op
.getLoc(), op
.getOperand().getType(),
46 static bool getFirstI32Result(Operation
*op
, Value
&value
) {
47 if (!Type(op
->getResult(0).getType()).isSignlessInteger(32))
49 value
= op
->getResult(0);
53 static Value
bindNativeCodeCallResult(Value value
) { return value
; }
55 static SmallVector
<Value
, 2> bindMultipleNativeCodeCallResult(Value input1
,
57 return SmallVector
<Value
, 2>({input2
, input1
});
60 // Test that natives calls are only called once during rewrites.
61 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
62 // This let us check the number of times OpM_Test was called by inspecting
63 // the returned value in the MLIR output.
64 static int64_t opMIncreasingValue
= 314159265;
65 static Attribute
opMTest(PatternRewriter
&rewriter
, Value val
) {
66 int64_t i
= opMIncreasingValue
++;
67 return rewriter
.getIntegerAttr(rewriter
.getIntegerType(32), i
);
71 #include "TestPatterns.inc"
74 //===----------------------------------------------------------------------===//
75 // Test Reduce Pattern Interface
76 //===----------------------------------------------------------------------===//
78 void test::populateTestReductionPatterns(RewritePatternSet
&patterns
) {
79 populateWithGenerated(patterns
);
82 //===----------------------------------------------------------------------===//
83 // Canonicalizer Driver.
84 //===----------------------------------------------------------------------===//
87 struct FoldingPattern
: public RewritePattern
{
89 FoldingPattern(MLIRContext
*context
)
90 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
91 /*benefit=*/1, context
) {}
93 LogicalResult
matchAndRewrite(Operation
*op
,
94 PatternRewriter
&rewriter
) const override
{
95 // Exercise createOrFold API for a single-result operation that is folded
96 // upon construction. The operation being created has an in-place folder,
97 // and it should be still present in the output. Furthermore, the folder
98 // should not crash when attempting to recover the (unchanged) operation
100 Value result
= rewriter
.createOrFold
<TestOpInPlaceFold
>(
101 op
->getLoc(), rewriter
.getIntegerType(32), op
->getOperand(0));
103 rewriter
.replaceOp(op
, result
);
108 /// This pattern creates a foldable operation at the entry point of the block.
109 /// This tests the situation where the operation folder will need to replace an
110 /// operation with a previously created constant that does not initially
111 /// dominate the operation to replace.
112 struct FolderInsertBeforePreviouslyFoldedConstantPattern
113 : public OpRewritePattern
<TestCastOp
> {
115 using OpRewritePattern
<TestCastOp
>::OpRewritePattern
;
117 LogicalResult
matchAndRewrite(TestCastOp op
,
118 PatternRewriter
&rewriter
) const override
{
119 if (!op
->hasAttr("test_fold_before_previously_folded_op"))
121 rewriter
.setInsertionPointToStart(op
->getBlock());
123 auto constOp
= rewriter
.create
<arith::ConstantOp
>(
124 op
.getLoc(), rewriter
.getBoolAttr(true));
125 rewriter
.replaceOpWithNewOp
<TestCastOp
>(op
, rewriter
.getI32Type(),
131 /// This pattern matches test.op_commutative2 with the first operand being
132 /// another test.op_commutative2 with a constant on the right side and fold it
133 /// away by propagating it as its result. This is intend to check that patterns
134 /// are applied after the commutative property moves constant to the right.
135 struct FolderCommutativeOp2WithConstant
136 : public OpRewritePattern
<TestCommutative2Op
> {
138 using OpRewritePattern
<TestCommutative2Op
>::OpRewritePattern
;
140 LogicalResult
matchAndRewrite(TestCommutative2Op op
,
141 PatternRewriter
&rewriter
) const override
{
143 dyn_cast_or_null
<TestCommutative2Op
>(op
->getOperand(0).getDefiningOp());
146 Attribute constInput
;
147 if (!matchPattern(operand
->getOperand(1), m_Constant(&constInput
)))
149 rewriter
.replaceOp(op
, operand
->getOperand(1));
154 /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
155 /// attribute with value smaller than MaxVal, it increments the value by 1.
156 template <int MaxVal
>
157 struct IncrementIntAttribute
: public OpRewritePattern
<AnyAttrOfOp
> {
158 using OpRewritePattern
<AnyAttrOfOp
>::OpRewritePattern
;
160 LogicalResult
matchAndRewrite(AnyAttrOfOp op
,
161 PatternRewriter
&rewriter
) const override
{
162 auto intAttr
= dyn_cast
<IntegerAttr
>(op
.getAttr());
165 int64_t val
= intAttr
.getInt();
168 rewriter
.modifyOpInPlace(
169 op
, [&]() { op
.setAttrAttr(rewriter
.getI32IntegerAttr(val
+ 1)); });
174 /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
175 struct MakeOpEligible
: public RewritePattern
{
176 MakeOpEligible(MLIRContext
*context
)
177 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context
) {}
179 LogicalResult
matchAndRewrite(Operation
*op
,
180 PatternRewriter
&rewriter
) const override
{
181 if (op
->hasAttr("eligible"))
183 rewriter
.modifyOpInPlace(
184 op
, [&]() { op
->setAttr("eligible", rewriter
.getUnitAttr()); });
189 /// This pattern hoists eligible ops out of a "test.one_region_op".
190 struct HoistEligibleOps
: public OpRewritePattern
<test::OneRegionOp
> {
191 using OpRewritePattern
<test::OneRegionOp
>::OpRewritePattern
;
193 LogicalResult
matchAndRewrite(test::OneRegionOp op
,
194 PatternRewriter
&rewriter
) const override
{
195 Operation
*terminator
= op
.getRegion().front().getTerminator();
196 Operation
*toBeHoisted
= terminator
->getOperands()[0].getDefiningOp();
197 if (toBeHoisted
->getParentOp() != op
)
199 if (!toBeHoisted
->hasAttr("eligible"))
201 rewriter
.moveOpBefore(toBeHoisted
, op
);
206 /// This pattern moves "test.move_before_parent_op" before the parent op.
207 struct MoveBeforeParentOp
: public RewritePattern
{
208 MoveBeforeParentOp(MLIRContext
*context
)
209 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context
) {}
211 LogicalResult
matchAndRewrite(Operation
*op
,
212 PatternRewriter
&rewriter
) const override
{
213 // Do not hoist past functions.
214 if (isa
<FunctionOpInterface
>(op
->getParentOp()))
216 rewriter
.moveOpBefore(op
, op
->getParentOp());
221 /// This pattern moves "test.move_after_parent_op" after the parent op.
222 struct MoveAfterParentOp
: public RewritePattern
{
223 MoveAfterParentOp(MLIRContext
*context
)
224 : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context
) {}
226 LogicalResult
matchAndRewrite(Operation
*op
,
227 PatternRewriter
&rewriter
) const override
{
228 // Do not hoist past functions.
229 if (isa
<FunctionOpInterface
>(op
->getParentOp()))
232 int64_t moveForwardBy
= 0;
233 if (auto advanceBy
= op
->getAttrOfType
<IntegerAttr
>("advance"))
234 moveForwardBy
= advanceBy
.getInt();
236 Operation
*moveAfter
= op
->getParentOp();
237 for (int64_t i
= 0; i
< moveForwardBy
; ++i
)
238 moveAfter
= moveAfter
->getNextNode();
240 rewriter
.moveOpAfter(op
, moveAfter
);
245 /// This pattern inlines blocks that are nested in
246 /// "test.inline_blocks_into_parent" into the parent block.
247 struct InlineBlocksIntoParent
: public RewritePattern
{
248 InlineBlocksIntoParent(MLIRContext
*context
)
249 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
252 LogicalResult
matchAndRewrite(Operation
*op
,
253 PatternRewriter
&rewriter
) const override
{
254 bool changed
= false;
255 for (Region
&r
: op
->getRegions()) {
257 rewriter
.inlineBlockBefore(&r
.front(), op
);
261 return success(changed
);
265 /// This pattern splits blocks at "test.split_block_here" and replaces the op
266 /// with a new op (to prevent an infinite loop of block splitting).
267 struct SplitBlockHere
: public RewritePattern
{
268 SplitBlockHere(MLIRContext
*context
)
269 : RewritePattern("test.split_block_here", /*benefit=*/1, context
) {}
271 LogicalResult
matchAndRewrite(Operation
*op
,
272 PatternRewriter
&rewriter
) const override
{
273 rewriter
.splitBlock(op
->getBlock(), op
->getIterator());
274 Operation
*newOp
= rewriter
.create(
276 OperationName("test.new_op", op
->getContext()).getIdentifier(),
277 op
->getOperands(), op
->getResultTypes());
278 rewriter
.replaceOp(op
, newOp
);
283 /// This pattern clones "test.clone_me" ops.
284 struct CloneOp
: public RewritePattern
{
285 CloneOp(MLIRContext
*context
)
286 : RewritePattern("test.clone_me", /*benefit=*/1, context
) {}
288 LogicalResult
matchAndRewrite(Operation
*op
,
289 PatternRewriter
&rewriter
) const override
{
290 // Do not clone already cloned ops to avoid going into an infinite loop.
291 if (op
->hasAttr("was_cloned"))
293 Operation
*cloned
= rewriter
.clone(*op
);
294 cloned
->setAttr("was_cloned", rewriter
.getUnitAttr());
299 /// This pattern clones regions of "test.clone_region_before" ops before the
301 struct CloneRegionBeforeOp
: public RewritePattern
{
302 CloneRegionBeforeOp(MLIRContext
*context
)
303 : RewritePattern("test.clone_region_before", /*benefit=*/1, context
) {}
305 LogicalResult
matchAndRewrite(Operation
*op
,
306 PatternRewriter
&rewriter
) const override
{
307 // Do not clone already cloned ops to avoid going into an infinite loop.
308 if (op
->hasAttr("was_cloned"))
310 for (Region
&r
: op
->getRegions())
311 rewriter
.cloneRegionBefore(r
, op
->getBlock());
312 op
->setAttr("was_cloned", rewriter
.getUnitAttr());
317 /// Replace an operation may introduce the re-visiting of its users.
318 class ReplaceWithNewOp
: public RewritePattern
{
320 ReplaceWithNewOp(MLIRContext
*context
)
321 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context
) {}
323 LogicalResult
matchAndRewrite(Operation
*op
,
324 PatternRewriter
&rewriter
) const override
{
326 if (op
->hasAttr("create_erase_op")) {
327 newOp
= rewriter
.create(
329 OperationName("test.erase_op", op
->getContext()).getIdentifier(),
330 ValueRange(), TypeRange());
332 newOp
= rewriter
.create(
334 OperationName("test.new_op", op
->getContext()).getIdentifier(),
335 op
->getOperands(), op
->getResultTypes());
337 // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
338 // A "notifyOperationReplaced" callback is triggered in either case.
339 rewriter
.replaceAllOpUsesWith(op
, newOp
->getResults());
340 rewriter
.eraseOp(op
);
345 /// Erases the first child block of the matched "test.erase_first_block"
347 class EraseFirstBlock
: public RewritePattern
{
349 EraseFirstBlock(MLIRContext
*context
)
350 : RewritePattern("test.erase_first_block", /*benefit=*/1, context
) {}
352 LogicalResult
matchAndRewrite(Operation
*op
,
353 PatternRewriter
&rewriter
) const override
{
354 for (Region
&r
: op
->getRegions()) {
355 for (Block
&b
: r
.getBlocks()) {
356 rewriter
.eraseBlock(&b
);
365 struct TestGreedyPatternDriver
366 : public PassWrapper
<TestGreedyPatternDriver
, OperationPass
<>> {
367 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver
)
369 TestGreedyPatternDriver() = default;
370 TestGreedyPatternDriver(const TestGreedyPatternDriver
&other
)
371 : PassWrapper(other
) {}
373 StringRef
getArgument() const final
{ return "test-greedy-patterns"; }
374 StringRef
getDescription() const final
{ return "Run test dialect patterns"; }
375 void runOnOperation() override
{
376 mlir::RewritePatternSet
patterns(&getContext());
377 populateWithGenerated(patterns
);
379 // Verify named pattern is generated with expected name.
380 patterns
.add
<FoldingPattern
, TestNamedPatternRule
,
381 FolderInsertBeforePreviouslyFoldedConstantPattern
,
382 FolderCommutativeOp2WithConstant
, HoistEligibleOps
,
383 MakeOpEligible
>(&getContext());
385 // Additional patterns for testing the GreedyPatternRewriteDriver.
386 patterns
.insert
<IncrementIntAttribute
<3>>(&getContext());
388 GreedyRewriteConfig config
;
389 config
.useTopDownTraversal
= this->useTopDownTraversal
;
390 config
.maxIterations
= this->maxIterations
;
391 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns
),
395 Option
<bool> useTopDownTraversal
{
397 llvm::cl::desc("Seed the worklist in general top-down order"),
398 llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal
)};
399 Option
<int> maxIterations
{
400 *this, "max-iterations",
401 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
402 llvm::cl::init(GreedyRewriteConfig().maxIterations
)};
405 struct DumpNotifications
: public RewriterBase::Listener
{
406 void notifyBlockInserted(Block
*block
, Region
*previous
,
407 Region::iterator previousIt
) override
{
408 llvm::outs() << "notifyBlockInserted";
409 if (block
->getParentOp()) {
410 llvm::outs() << " into " << block
->getParentOp()->getName() << ": ";
412 llvm::outs() << " into unknown op: ";
414 if (previous
== nullptr) {
415 llvm::outs() << "was unlinked\n";
417 llvm::outs() << "was linked\n";
420 void notifyOperationInserted(Operation
*op
,
421 OpBuilder::InsertPoint previous
) override
{
422 llvm::outs() << "notifyOperationInserted: " << op
->getName();
423 if (!previous
.isSet()) {
424 llvm::outs() << ", was unlinked\n";
426 if (!previous
.getPoint().getNodePtr()) {
427 llvm::outs() << ", was linked, exact position unknown\n";
428 } else if (previous
.getPoint() == previous
.getBlock()->end()) {
429 llvm::outs() << ", was last in block\n";
431 llvm::outs() << ", previous = " << previous
.getPoint()->getName()
436 void notifyBlockErased(Block
*block
) override
{
437 llvm::outs() << "notifyBlockErased\n";
439 void notifyOperationErased(Operation
*op
) override
{
440 llvm::outs() << "notifyOperationErased: " << op
->getName() << "\n";
442 void notifyOperationModified(Operation
*op
) override
{
443 llvm::outs() << "notifyOperationModified: " << op
->getName() << "\n";
445 void notifyOperationReplaced(Operation
*op
, ValueRange values
) override
{
446 llvm::outs() << "notifyOperationReplaced: " << op
->getName() << "\n";
450 struct TestStrictPatternDriver
451 : public PassWrapper
<TestStrictPatternDriver
, OperationPass
<func::FuncOp
>> {
453 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver
)
455 TestStrictPatternDriver() = default;
456 TestStrictPatternDriver(const TestStrictPatternDriver
&other
)
457 : PassWrapper(other
) {
458 strictMode
= other
.strictMode
;
461 StringRef
getArgument() const final
{ return "test-strict-pattern-driver"; }
462 StringRef
getDescription() const final
{
463 return "Test strict mode of pattern driver";
466 void runOnOperation() override
{
467 MLIRContext
*ctx
= &getContext();
468 mlir::RewritePatternSet
patterns(ctx
);
476 InlineBlocksIntoParent
,
483 SmallVector
<Operation
*> ops
;
484 getOperation()->walk([&](Operation
*op
) {
485 StringRef opName
= op
->getName().getStringRef();
486 if (opName
== "test.insert_same_op" || opName
== "test.change_block_op" ||
487 opName
== "test.replace_with_new_op" || opName
== "test.erase_op" ||
488 opName
== "test.move_before_parent_op" ||
489 opName
== "test.inline_blocks_into_parent" ||
490 opName
== "test.split_block_here" || opName
== "test.clone_me" ||
491 opName
== "test.clone_region_before") {
496 DumpNotifications dumpNotifications
;
497 GreedyRewriteConfig config
;
498 config
.listener
= &dumpNotifications
;
499 if (strictMode
== "AnyOp") {
500 config
.strictMode
= GreedyRewriteStrictness::AnyOp
;
501 } else if (strictMode
== "ExistingAndNewOps") {
502 config
.strictMode
= GreedyRewriteStrictness::ExistingAndNewOps
;
503 } else if (strictMode
== "ExistingOps") {
504 config
.strictMode
= GreedyRewriteStrictness::ExistingOps
;
506 llvm_unreachable("invalid strictness option");
509 // Check if these transformations introduce visiting of operations that
510 // are not in the `ops` set (The new created ops are valid). An invalid
511 // operation will trigger the assertion while processing.
512 bool changed
= false;
513 bool allErased
= false;
514 (void)applyOpPatternsAndFold(ArrayRef(ops
), std::move(patterns
), config
,
515 &changed
, &allErased
);
517 getOperation()->setAttr("pattern_driver_changed", b
.getBoolAttr(changed
));
518 getOperation()->setAttr("pattern_driver_all_erased",
519 b
.getBoolAttr(allErased
));
522 Option
<std::string
> strictMode
{
524 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
525 llvm::cl::init("AnyOp")};
528 // New inserted operation is valid for further transformation.
529 class InsertSameOp
: public RewritePattern
{
531 InsertSameOp(MLIRContext
*context
)
532 : RewritePattern("test.insert_same_op", /*benefit=*/1, context
) {}
534 LogicalResult
matchAndRewrite(Operation
*op
,
535 PatternRewriter
&rewriter
) const override
{
536 if (op
->hasAttr("skip"))
540 rewriter
.create(op
->getLoc(), op
->getName().getIdentifier(),
541 op
->getOperands(), op
->getResultTypes());
542 rewriter
.modifyOpInPlace(
543 op
, [&]() { op
->setAttr("skip", rewriter
.getBoolAttr(true)); });
544 newOp
->setAttr("skip", rewriter
.getBoolAttr(true));
550 // Remove an operation may introduce the re-visiting of its operands.
551 class EraseOp
: public RewritePattern
{
553 EraseOp(MLIRContext
*context
)
554 : RewritePattern("test.erase_op", /*benefit=*/1, context
) {}
555 LogicalResult
matchAndRewrite(Operation
*op
,
556 PatternRewriter
&rewriter
) const override
{
557 rewriter
.eraseOp(op
);
562 // The following two patterns test RewriterBase::replaceAllUsesWith.
564 // That function replaces all usages of a Block (or a Value) with another one
565 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
566 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
567 // worklist: when an op is modified, it is added to the worklist. The two
568 // patterns below make the tracking observable: ChangeBlockOp replaces all
569 // usages of a block and that pattern is applied because the corresponding ops
570 // are put on the initial worklist (see above). ImplicitChangeOp does an
571 // unrelated change but ops of the corresponding type are *not* on the initial
572 // worklist, so the effect of the second pattern is only visible if the
573 // tracking and subsequent adding to the worklist actually works.
575 // Replace all usages of the first successor with the second successor.
576 class ChangeBlockOp
: public RewritePattern
{
578 ChangeBlockOp(MLIRContext
*context
)
579 : RewritePattern("test.change_block_op", /*benefit=*/1, context
) {}
580 LogicalResult
matchAndRewrite(Operation
*op
,
581 PatternRewriter
&rewriter
) const override
{
582 if (op
->getNumSuccessors() < 2)
584 Block
*firstSuccessor
= op
->getSuccessor(0);
585 Block
*secondSuccessor
= op
->getSuccessor(1);
586 if (firstSuccessor
== secondSuccessor
)
588 // This is the function being tested:
589 rewriter
.replaceAllUsesWith(firstSuccessor
, secondSuccessor
);
590 // Using the following line instead would make the test fail:
591 // firstSuccessor->replaceAllUsesWith(secondSuccessor);
596 // Changes the successor to the parent block.
597 class ImplicitChangeOp
: public RewritePattern
{
599 ImplicitChangeOp(MLIRContext
*context
)
600 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context
) {}
601 LogicalResult
matchAndRewrite(Operation
*op
,
602 PatternRewriter
&rewriter
) const override
{
603 if (op
->getNumSuccessors() < 1 || op
->getSuccessor(0) == op
->getBlock())
605 rewriter
.modifyOpInPlace(op
,
606 [&]() { op
->setSuccessor(op
->getBlock(), 0); });
612 struct TestWalkPatternDriver final
613 : PassWrapper
<TestWalkPatternDriver
, OperationPass
<>> {
614 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver
)
616 TestWalkPatternDriver() = default;
617 TestWalkPatternDriver(const TestWalkPatternDriver
&other
)
618 : PassWrapper(other
) {}
620 StringRef
getArgument() const override
{
621 return "test-walk-pattern-rewrite-driver";
623 StringRef
getDescription() const override
{
624 return "Run test walk pattern rewrite driver";
626 void runOnOperation() override
{
627 mlir::RewritePatternSet
patterns(&getContext());
629 // Patterns for testing the WalkPatternRewriteDriver.
630 patterns
.add
<IncrementIntAttribute
<3>, MoveBeforeParentOp
,
631 MoveAfterParentOp
, CloneOp
, ReplaceWithNewOp
, EraseFirstBlock
>(
634 DumpNotifications dumpListener
;
635 walkAndApplyPatterns(getOperation(), std::move(patterns
),
636 dumpNotifications
? &dumpListener
: nullptr);
639 Option
<bool> dumpNotifications
{
640 *this, "dump-notifications",
641 llvm::cl::desc("Print rewrite listener notifications"),
642 llvm::cl::init(false)};
647 //===----------------------------------------------------------------------===//
648 // ReturnType Driver.
649 //===----------------------------------------------------------------------===//
652 // Generate ops for each instance where the type can be successfully inferred.
653 template <typename OpTy
>
654 static void invokeCreateWithInferredReturnType(Operation
*op
) {
655 auto *context
= op
->getContext();
656 auto fop
= op
->getParentOfType
<func::FuncOp
>();
657 auto location
= UnknownLoc::get(context
);
659 b
.setInsertionPointAfter(op
);
661 // Use permutations of 2 args as operands.
662 assert(fop
.getNumArguments() >= 2);
663 for (int i
= 0, e
= fop
.getNumArguments(); i
< e
; ++i
) {
664 for (int j
= 0; j
< e
; ++j
) {
665 std::array
<Value
, 2> values
= {{fop
.getArgument(i
), fop
.getArgument(j
)}};
666 SmallVector
<Type
, 2> inferredReturnTypes
;
667 if (succeeded(OpTy::inferReturnTypes(
668 context
, std::nullopt
, values
, op
->getDiscardableAttrDictionary(),
669 op
->getPropertiesStorage(), op
->getRegions(),
670 inferredReturnTypes
))) {
671 OperationState
state(location
, OpTy::getOperationName());
672 // TODO: Expand to regions.
673 OpTy::build(b
, state
, values
, op
->getAttrs());
674 (void)b
.create(state
);
680 static void reifyReturnShape(Operation
*op
) {
683 // Use permutations of 2 args as operands.
684 auto shapedOp
= cast
<OpWithShapedTypeInferTypeInterfaceOp
>(op
);
685 SmallVector
<Value
, 2> shapes
;
686 if (failed(shapedOp
.reifyReturnTypeShapes(b
, op
->getOperands(), shapes
)) ||
687 !llvm::hasSingleElement(shapes
))
689 for (const auto &it
: llvm::enumerate(shapes
)) {
690 op
->emitRemark() << "value " << it
.index() << ": "
691 << it
.value().getDefiningOp();
695 struct TestReturnTypeDriver
696 : public PassWrapper
<TestReturnTypeDriver
, OperationPass
<func::FuncOp
>> {
697 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver
)
699 void getDependentDialects(DialectRegistry
®istry
) const override
{
700 registry
.insert
<tensor::TensorDialect
>();
702 StringRef
getArgument() const final
{ return "test-return-type"; }
703 StringRef
getDescription() const final
{ return "Run return type functions"; }
705 void runOnOperation() override
{
706 if (getOperation().getName() == "testCreateFunctions") {
707 std::vector
<Operation
*> ops
;
708 // Collect ops to avoid triggering on inserted ops.
709 for (auto &op
: getOperation().getBody().front())
711 // Generate test patterns for each, but skip terminator.
712 for (auto *op
: llvm::ArrayRef(ops
).drop_back()) {
713 // Test create method of each of the Op classes below. The resultant
714 // output would be in reverse order underneath `op` from which
715 // the attributes and regions are used.
716 invokeCreateWithInferredReturnType
<OpWithInferTypeInterfaceOp
>(op
);
717 invokeCreateWithInferredReturnType
<OpWithInferTypeAdaptorInterfaceOp
>(
719 invokeCreateWithInferredReturnType
<
720 OpWithShapedTypeInferTypeInterfaceOp
>(op
);
724 if (getOperation().getName() == "testReifyFunctions") {
725 std::vector
<Operation
*> ops
;
726 // Collect ops to avoid triggering on inserted ops.
727 for (auto &op
: getOperation().getBody().front())
728 if (isa
<OpWithShapedTypeInferTypeInterfaceOp
>(op
))
730 // Generate test patterns for each, but skip terminator.
732 reifyReturnShape(op
);
739 struct TestDerivedAttributeDriver
740 : public PassWrapper
<TestDerivedAttributeDriver
,
741 OperationPass
<func::FuncOp
>> {
742 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver
)
744 StringRef
getArgument() const final
{ return "test-derived-attr"; }
745 StringRef
getDescription() const final
{
746 return "Run test derived attributes";
748 void runOnOperation() override
;
752 void TestDerivedAttributeDriver::runOnOperation() {
753 getOperation().walk([](DerivedAttributeOpInterface dOp
) {
754 auto dAttr
= dOp
.materializeDerivedAttributes();
758 dOp
.emitRemark() << d
.getName().getValue() << " = " << d
.getValue();
762 //===----------------------------------------------------------------------===//
763 // Legalization Driver.
764 //===----------------------------------------------------------------------===//
767 //===----------------------------------------------------------------------===//
768 // Region-Block Rewrite Testing
770 /// This pattern applies a signature conversion to a block inside a detached
772 struct TestDetachedSignatureConversion
: public ConversionPattern
{
773 TestDetachedSignatureConversion(MLIRContext
*ctx
)
774 : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1,
778 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
779 ConversionPatternRewriter
&rewriter
) const final
{
780 if (op
->getNumRegions() != 1)
782 OperationState
state(op
->getLoc(), "test.legal_op_with_region", operands
,
783 op
->getResultTypes(), {}, BlockRange());
784 Region
*newRegion
= state
.addRegion();
785 rewriter
.inlineRegionBefore(op
->getRegion(0), *newRegion
,
787 TypeConverter::SignatureConversion
result(newRegion
->getNumArguments());
788 for (unsigned i
= 0, e
= newRegion
->getNumArguments(); i
< e
; ++i
)
789 result
.addInputs(i
, rewriter
.getF64Type());
790 rewriter
.applySignatureConversion(&newRegion
->front(), result
);
791 Operation
*newOp
= rewriter
.create(state
);
792 rewriter
.replaceOp(op
, newOp
->getResults());
797 /// This pattern is a simple pattern that inlines the first region of a given
798 /// operation into the parent region.
799 struct TestRegionRewriteBlockMovement
: public ConversionPattern
{
800 TestRegionRewriteBlockMovement(MLIRContext
*ctx
)
801 : ConversionPattern("test.region", 1, ctx
) {}
804 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
805 ConversionPatternRewriter
&rewriter
) const final
{
806 // Inline this region into the parent region.
807 auto &parentRegion
= *op
->getParentRegion();
808 auto &opRegion
= op
->getRegion(0);
809 if (op
->getDiscardableAttr("legalizer.should_clone"))
810 rewriter
.cloneRegionBefore(opRegion
, parentRegion
, parentRegion
.end());
812 rewriter
.inlineRegionBefore(opRegion
, parentRegion
, parentRegion
.end());
814 if (op
->getDiscardableAttr("legalizer.erase_old_blocks")) {
815 while (!opRegion
.empty())
816 rewriter
.eraseBlock(&opRegion
.front());
819 // Drop this operation.
820 rewriter
.eraseOp(op
);
824 /// This pattern is a simple pattern that generates a region containing an
825 /// illegal operation.
826 struct TestRegionRewriteUndo
: public RewritePattern
{
827 TestRegionRewriteUndo(MLIRContext
*ctx
)
828 : RewritePattern("test.region_builder", 1, ctx
) {}
830 LogicalResult
matchAndRewrite(Operation
*op
,
831 PatternRewriter
&rewriter
) const final
{
832 // Create the region operation with an entry block containing arguments.
833 OperationState
newRegion(op
->getLoc(), "test.region");
834 newRegion
.addRegion();
835 auto *regionOp
= rewriter
.create(newRegion
);
836 auto *entryBlock
= rewriter
.createBlock(®ionOp
->getRegion(0));
837 entryBlock
->addArgument(rewriter
.getIntegerType(64),
838 rewriter
.getUnknownLoc());
840 // Add an explicitly illegal operation to ensure the conversion fails.
841 rewriter
.create
<ILLegalOpF
>(op
->getLoc(), rewriter
.getIntegerType(32));
842 rewriter
.create
<TestValidOp
>(op
->getLoc(), ArrayRef
<Value
>());
844 // Drop this operation.
845 rewriter
.eraseOp(op
);
849 /// A simple pattern that creates a block at the end of the parent region of the
850 /// matched operation.
851 struct TestCreateBlock
: public RewritePattern
{
852 TestCreateBlock(MLIRContext
*ctx
)
853 : RewritePattern("test.create_block", /*benefit=*/1, ctx
) {}
855 LogicalResult
matchAndRewrite(Operation
*op
,
856 PatternRewriter
&rewriter
) const final
{
857 Region
®ion
= *op
->getParentRegion();
858 Type i32Type
= rewriter
.getIntegerType(32);
859 Location loc
= op
->getLoc();
860 rewriter
.createBlock(®ion
, region
.end(), {i32Type
, i32Type
}, {loc
, loc
});
861 rewriter
.create
<TerminatorOp
>(loc
);
862 rewriter
.eraseOp(op
);
867 /// A simple pattern that creates a block containing an invalid operation in
868 /// order to trigger the block creation undo mechanism.
869 struct TestCreateIllegalBlock
: public RewritePattern
{
870 TestCreateIllegalBlock(MLIRContext
*ctx
)
871 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx
) {}
873 LogicalResult
matchAndRewrite(Operation
*op
,
874 PatternRewriter
&rewriter
) const final
{
875 Region
®ion
= *op
->getParentRegion();
876 Type i32Type
= rewriter
.getIntegerType(32);
877 Location loc
= op
->getLoc();
878 rewriter
.createBlock(®ion
, region
.end(), {i32Type
, i32Type
}, {loc
, loc
});
879 // Create an illegal op to ensure the conversion fails.
880 rewriter
.create
<ILLegalOpF
>(loc
, i32Type
);
881 rewriter
.create
<TerminatorOp
>(loc
);
882 rewriter
.eraseOp(op
);
887 /// A simple pattern that tests the undo mechanism when replacing the uses of a
889 struct TestUndoBlockArgReplace
: public ConversionPattern
{
890 TestUndoBlockArgReplace(MLIRContext
*ctx
)
891 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx
) {}
894 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
895 ConversionPatternRewriter
&rewriter
) const final
{
897 rewriter
.create
<ILLegalOpF
>(op
->getLoc(), rewriter
.getF32Type());
898 rewriter
.replaceUsesOfBlockArgument(op
->getRegion(0).getArgument(0),
899 illegalOp
->getResult(0));
900 rewriter
.modifyOpInPlace(op
, [] {});
905 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
906 /// This is to test the rollback logic.
907 struct TestUndoMoveOpBefore
: public ConversionPattern
{
908 TestUndoMoveOpBefore(MLIRContext
*ctx
)
909 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx
) {}
912 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
913 ConversionPatternRewriter
&rewriter
) const override
{
914 rewriter
.moveOpBefore(op
, op
->getParentOp());
915 // Replace with an illegal op to ensure the conversion fails.
916 rewriter
.replaceOpWithNewOp
<ILLegalOpF
>(op
, rewriter
.getF32Type());
921 /// A rewrite pattern that tests the undo mechanism when erasing a block.
922 struct TestUndoBlockErase
: public ConversionPattern
{
923 TestUndoBlockErase(MLIRContext
*ctx
)
924 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx
) {}
927 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
928 ConversionPatternRewriter
&rewriter
) const final
{
929 Block
*secondBlock
= &*std::next(op
->getRegion(0).begin());
930 rewriter
.setInsertionPointToStart(secondBlock
);
931 rewriter
.create
<ILLegalOpF
>(op
->getLoc(), rewriter
.getF32Type());
932 rewriter
.eraseBlock(secondBlock
);
933 rewriter
.modifyOpInPlace(op
, [] {});
938 /// A pattern that modifies a property in-place, but keeps the op illegal.
939 struct TestUndoPropertiesModification
: public ConversionPattern
{
940 TestUndoPropertiesModification(MLIRContext
*ctx
)
941 : ConversionPattern("test.with_properties", /*benefit=*/1, ctx
) {}
943 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
944 ConversionPatternRewriter
&rewriter
) const final
{
945 if (!op
->hasAttr("modify_inplace"))
947 rewriter
.modifyOpInPlace(
948 op
, [&]() { cast
<TestOpWithProperties
>(op
).getProperties().setA(42); });
953 //===----------------------------------------------------------------------===//
954 // Type-Conversion Rewrite Testing
956 /// This patterns erases a region operation that has had a type conversion.
957 struct TestDropOpSignatureConversion
: public ConversionPattern
{
958 TestDropOpSignatureConversion(MLIRContext
*ctx
,
959 const TypeConverter
&converter
)
960 : ConversionPattern(converter
, "test.drop_region_op", 1, ctx
) {}
962 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
963 ConversionPatternRewriter
&rewriter
) const override
{
964 Region
®ion
= op
->getRegion(0);
965 Block
*entry
= ®ion
.front();
967 // Convert the original entry arguments.
968 const TypeConverter
&converter
= *getTypeConverter();
969 TypeConverter::SignatureConversion
result(entry
->getNumArguments());
970 if (failed(converter
.convertSignatureArgs(entry
->getArgumentTypes(),
972 failed(rewriter
.convertRegionTypes(®ion
, converter
, &result
)))
975 // Convert the region signature and just drop the operation.
976 rewriter
.eraseOp(op
);
980 /// This pattern simply updates the operands of the given operation.
981 struct TestPassthroughInvalidOp
: public ConversionPattern
{
982 TestPassthroughInvalidOp(MLIRContext
*ctx
)
983 : ConversionPattern("test.invalid", 1, ctx
) {}
985 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
986 ConversionPatternRewriter
&rewriter
) const final
{
987 rewriter
.replaceOpWithNewOp
<TestValidOp
>(op
, std::nullopt
, operands
,
992 /// Replace with valid op, but simply drop the operands. This is used in a
993 /// regression where we used to generate circular unrealized_conversion_cast
995 struct TestDropAndReplaceInvalidOp
: public ConversionPattern
{
996 TestDropAndReplaceInvalidOp(MLIRContext
*ctx
, const TypeConverter
&converter
)
997 : ConversionPattern(converter
,
998 "test.drop_operands_and_replace_with_valid", 1, ctx
) {
1001 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1002 ConversionPatternRewriter
&rewriter
) const final
{
1003 rewriter
.replaceOpWithNewOp
<TestValidOp
>(op
, std::nullopt
, ValueRange(),
1008 /// This pattern handles the case of a split return value.
1009 struct TestSplitReturnType
: public ConversionPattern
{
1010 TestSplitReturnType(MLIRContext
*ctx
)
1011 : ConversionPattern("test.return", 1, ctx
) {}
1013 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1014 ConversionPatternRewriter
&rewriter
) const final
{
1015 // Check for a return of F32.
1016 if (op
->getNumOperands() != 1 || !op
->getOperand(0).getType().isF32())
1019 // Check if the first operation is a cast operation, if it is we use the
1020 // results directly.
1021 auto *defOp
= operands
[0].getDefiningOp();
1023 llvm::dyn_cast_or_null
<UnrealizedConversionCastOp
>(defOp
)) {
1024 rewriter
.replaceOpWithNewOp
<TestReturnOp
>(op
, packerOp
.getOperands());
1028 // Otherwise, fail to match.
1033 //===----------------------------------------------------------------------===//
1034 // Multi-Level Type-Conversion Rewrite Testing
1035 struct TestChangeProducerTypeI32ToF32
: public ConversionPattern
{
1036 TestChangeProducerTypeI32ToF32(MLIRContext
*ctx
)
1037 : ConversionPattern("test.type_producer", 1, ctx
) {}
1039 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1040 ConversionPatternRewriter
&rewriter
) const final
{
1041 // If the type is I32, change the type to F32.
1042 if (!Type(*op
->result_type_begin()).isSignlessInteger(32))
1044 rewriter
.replaceOpWithNewOp
<TestTypeProducerOp
>(op
, rewriter
.getF32Type());
1048 struct TestChangeProducerTypeF32ToF64
: public ConversionPattern
{
1049 TestChangeProducerTypeF32ToF64(MLIRContext
*ctx
)
1050 : ConversionPattern("test.type_producer", 1, ctx
) {}
1052 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1053 ConversionPatternRewriter
&rewriter
) const final
{
1054 // If the type is F32, change the type to F64.
1055 if (!Type(*op
->result_type_begin()).isF32())
1056 return rewriter
.notifyMatchFailure(op
, "expected single f32 operand");
1057 rewriter
.replaceOpWithNewOp
<TestTypeProducerOp
>(op
, rewriter
.getF64Type());
1061 struct TestChangeProducerTypeF32ToInvalid
: public ConversionPattern
{
1062 TestChangeProducerTypeF32ToInvalid(MLIRContext
*ctx
)
1063 : ConversionPattern("test.type_producer", 10, ctx
) {}
1065 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1066 ConversionPatternRewriter
&rewriter
) const final
{
1067 // Always convert to B16, even though it is not a legal type. This tests
1068 // that values are unmapped correctly.
1069 rewriter
.replaceOpWithNewOp
<TestTypeProducerOp
>(op
, rewriter
.getBF16Type());
1073 struct TestUpdateConsumerType
: public ConversionPattern
{
1074 TestUpdateConsumerType(MLIRContext
*ctx
)
1075 : ConversionPattern("test.type_consumer", 1, ctx
) {}
1077 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1078 ConversionPatternRewriter
&rewriter
) const final
{
1079 // Verify that the incoming operand has been successfully remapped to F64.
1080 if (!operands
[0].getType().isF64())
1082 rewriter
.replaceOpWithNewOp
<TestTypeConsumerOp
>(op
, operands
[0]);
1087 //===----------------------------------------------------------------------===//
1088 // Non-Root Replacement Rewrite Testing
1089 /// This pattern generates an invalid operation, but replaces it before the
1090 /// pattern is finished. This checks that we don't need to legalize the
1092 struct TestNonRootReplacement
: public RewritePattern
{
1093 TestNonRootReplacement(MLIRContext
*ctx
)
1094 : RewritePattern("test.replace_non_root", 1, ctx
) {}
1096 LogicalResult
matchAndRewrite(Operation
*op
,
1097 PatternRewriter
&rewriter
) const final
{
1098 auto resultType
= *op
->result_type_begin();
1099 auto illegalOp
= rewriter
.create
<ILLegalOpF
>(op
->getLoc(), resultType
);
1100 auto legalOp
= rewriter
.create
<LegalOpB
>(op
->getLoc(), resultType
);
1102 rewriter
.replaceOp(illegalOp
, legalOp
);
1103 rewriter
.replaceOp(op
, illegalOp
);
1108 //===----------------------------------------------------------------------===//
1109 // Recursive Rewrite Testing
1110 /// This pattern is applied to the same operation multiple times, but has a
1111 /// bounded recursion.
1112 struct TestBoundedRecursiveRewrite
1113 : public OpRewritePattern
<TestRecursiveRewriteOp
> {
1114 using OpRewritePattern
<TestRecursiveRewriteOp
>::OpRewritePattern
;
1117 // The conversion target handles bounding the recursion of this pattern.
1118 setHasBoundedRewriteRecursion();
1121 LogicalResult
matchAndRewrite(TestRecursiveRewriteOp op
,
1122 PatternRewriter
&rewriter
) const final
{
1123 // Decrement the depth of the op in-place.
1124 rewriter
.modifyOpInPlace(op
, [&] {
1125 op
->setAttr("depth", rewriter
.getI64IntegerAttr(op
.getDepth() - 1));
1131 struct TestNestedOpCreationUndoRewrite
1132 : public OpRewritePattern
<IllegalOpWithRegionAnchor
> {
1133 using OpRewritePattern
<IllegalOpWithRegionAnchor
>::OpRewritePattern
;
1135 LogicalResult
matchAndRewrite(IllegalOpWithRegionAnchor op
,
1136 PatternRewriter
&rewriter
) const final
{
1137 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1138 rewriter
.replaceOpWithNewOp
<IllegalOpWithRegion
>(op
);
1143 // This pattern matches `test.blackhole` and delete this op and its producer.
1144 struct TestReplaceEraseOp
: public OpRewritePattern
<BlackHoleOp
> {
1145 using OpRewritePattern
<BlackHoleOp
>::OpRewritePattern
;
1147 LogicalResult
matchAndRewrite(BlackHoleOp op
,
1148 PatternRewriter
&rewriter
) const final
{
1149 Operation
*producer
= op
.getOperand().getDefiningOp();
1150 // Always erase the user before the producer, the framework should handle
1152 rewriter
.eraseOp(op
);
1153 rewriter
.eraseOp(producer
);
1158 // This pattern replaces explicitly illegal op with explicitly legal op,
1159 // but in addition creates unregistered operation.
1160 struct TestCreateUnregisteredOp
: public OpRewritePattern
<ILLegalOpG
> {
1161 using OpRewritePattern
<ILLegalOpG
>::OpRewritePattern
;
1163 LogicalResult
matchAndRewrite(ILLegalOpG op
,
1164 PatternRewriter
&rewriter
) const final
{
1165 IntegerAttr attr
= rewriter
.getI32IntegerAttr(0);
1166 Value val
= rewriter
.create
<arith::ConstantOp
>(op
->getLoc(), attr
);
1167 rewriter
.replaceOpWithNewOp
<LegalOpC
>(op
, val
);
1172 class TestEraseOp
: public ConversionPattern
{
1174 TestEraseOp(MLIRContext
*ctx
) : ConversionPattern("test.erase_op", 1, ctx
) {}
1176 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1177 ConversionPatternRewriter
&rewriter
) const final
{
1178 // Erase op without replacements.
1179 rewriter
.eraseOp(op
);
1187 struct TestTypeConverter
: public TypeConverter
{
1188 using TypeConverter::TypeConverter
;
1189 TestTypeConverter() {
1190 addConversion(convertType
);
1191 addArgumentMaterialization(materializeCast
);
1192 addSourceMaterialization(materializeCast
);
1195 static LogicalResult
convertType(Type t
, SmallVectorImpl
<Type
> &results
) {
1197 if (t
.isSignlessInteger(16))
1200 // Convert I64 to F64.
1201 if (t
.isSignlessInteger(64)) {
1202 results
.push_back(FloatType::getF64(t
.getContext()));
1206 // Convert I42 to I43.
1207 if (t
.isInteger(42)) {
1208 results
.push_back(IntegerType::get(t
.getContext(), 43));
1212 // Split F32 into F16,F16.
1214 results
.assign(2, FloatType::getF16(t
.getContext()));
1218 // Otherwise, convert the type directly.
1219 results
.push_back(t
);
1223 /// Hook for materializing a conversion. This is necessary because we generate
1224 /// 1->N type mappings.
1225 static Value
materializeCast(OpBuilder
&builder
, Type resultType
,
1226 ValueRange inputs
, Location loc
) {
1227 return builder
.create
<TestCastOp
>(loc
, resultType
, inputs
).getResult();
1231 struct TestLegalizePatternDriver
1232 : public PassWrapper
<TestLegalizePatternDriver
, OperationPass
<>> {
1233 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver
)
1235 StringRef
getArgument() const final
{ return "test-legalize-patterns"; }
1236 StringRef
getDescription() const final
{
1237 return "Run test dialect legalization patterns";
1239 /// The mode of conversion to use with the driver.
1240 enum class ConversionMode
{ Analysis
, Full
, Partial
};
1242 TestLegalizePatternDriver(ConversionMode mode
) : mode(mode
) {}
1244 void getDependentDialects(DialectRegistry
®istry
) const override
{
1245 registry
.insert
<func::FuncDialect
, test::TestDialect
>();
1248 void runOnOperation() override
{
1249 TestTypeConverter converter
;
1250 mlir::RewritePatternSet
patterns(&getContext());
1251 populateWithGenerated(patterns
);
1253 TestRegionRewriteBlockMovement
, TestDetachedSignatureConversion
,
1254 TestRegionRewriteUndo
, TestCreateBlock
, TestCreateIllegalBlock
,
1255 TestUndoBlockArgReplace
, TestUndoBlockErase
, TestPassthroughInvalidOp
,
1256 TestSplitReturnType
, TestChangeProducerTypeI32ToF32
,
1257 TestChangeProducerTypeF32ToF64
, TestChangeProducerTypeF32ToInvalid
,
1258 TestUpdateConsumerType
, TestNonRootReplacement
,
1259 TestBoundedRecursiveRewrite
, TestNestedOpCreationUndoRewrite
,
1260 TestReplaceEraseOp
, TestCreateUnregisteredOp
, TestUndoMoveOpBefore
,
1261 TestUndoPropertiesModification
, TestEraseOp
>(&getContext());
1262 patterns
.add
<TestDropOpSignatureConversion
, TestDropAndReplaceInvalidOp
>(
1263 &getContext(), converter
);
1264 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns
,
1266 mlir::populateCallOpTypeConversionPattern(patterns
, converter
);
1268 // Define the conversion target used for the test.
1269 ConversionTarget
target(getContext());
1270 target
.addLegalOp
<ModuleOp
>();
1271 target
.addLegalOp
<LegalOpA
, LegalOpB
, LegalOpC
, TestCastOp
, TestValidOp
,
1272 TerminatorOp
, OneRegionOp
>();
1274 OperationName("test.legal_op_with_region", &getContext()));
1276 .addIllegalOp
<ILLegalOpF
, TestRegionBuilderOp
, TestOpWithRegionFold
>();
1277 target
.addDynamicallyLegalOp
<TestReturnOp
>([](TestReturnOp op
) {
1278 // Don't allow F32 operands.
1279 return llvm::none_of(op
.getOperandTypes(),
1280 [](Type type
) { return type
.isF32(); });
1282 target
.addDynamicallyLegalOp
<func::FuncOp
>([&](func::FuncOp op
) {
1283 return converter
.isSignatureLegal(op
.getFunctionType()) &&
1284 converter
.isLegal(&op
.getBody());
1286 target
.addDynamicallyLegalOp
<func::CallOp
>(
1287 [&](func::CallOp op
) { return converter
.isLegal(op
); });
1289 // TestCreateUnregisteredOp creates `arith.constant` operation,
1290 // which was not added to target intentionally to test
1291 // correct error code from conversion driver.
1292 target
.addDynamicallyLegalOp
<ILLegalOpG
>([](ILLegalOpG
) { return false; });
1294 // Expect the type_producer/type_consumer operations to only operate on f64.
1295 target
.addDynamicallyLegalOp
<TestTypeProducerOp
>(
1296 [](TestTypeProducerOp op
) { return op
.getType().isF64(); });
1297 target
.addDynamicallyLegalOp
<TestTypeConsumerOp
>([](TestTypeConsumerOp op
) {
1298 return op
.getOperand().getType().isF64();
1301 // Check support for marking certain operations as recursively legal.
1302 target
.markOpRecursivelyLegal
<func::FuncOp
, ModuleOp
>([](Operation
*op
) {
1303 return static_cast<bool>(
1304 op
->getAttrOfType
<UnitAttr
>("test.recursively_legal"));
1307 // Mark the bound recursion operation as dynamically legal.
1308 target
.addDynamicallyLegalOp
<TestRecursiveRewriteOp
>(
1309 [](TestRecursiveRewriteOp op
) { return op
.getDepth() == 0; });
1311 // Create a dynamically legal rule that can only be legalized by folding it.
1312 target
.addDynamicallyLegalOp
<TestOpInPlaceSelfFold
>(
1313 [](TestOpInPlaceSelfFold op
) { return op
.getFolded(); });
1315 // Handle a partial conversion.
1316 if (mode
== ConversionMode::Partial
) {
1317 DenseSet
<Operation
*> unlegalizedOps
;
1318 ConversionConfig config
;
1319 DumpNotifications dumpNotifications
;
1320 config
.listener
= &dumpNotifications
;
1321 config
.unlegalizedOps
= &unlegalizedOps
;
1322 if (failed(applyPartialConversion(getOperation(), target
,
1323 std::move(patterns
), config
))) {
1324 getOperation()->emitRemark() << "applyPartialConversion failed";
1326 // Emit remarks for each legalizable operation.
1327 for (auto *op
: unlegalizedOps
)
1328 op
->emitRemark() << "op '" << op
->getName() << "' is not legalizable";
1332 // Handle a full conversion.
1333 if (mode
== ConversionMode::Full
) {
1334 // Check support for marking unknown operations as dynamically legal.
1335 target
.markUnknownOpDynamicallyLegal([](Operation
*op
) {
1336 return (bool)op
->getAttrOfType
<UnitAttr
>("test.dynamically_legal");
1339 ConversionConfig config
;
1340 DumpNotifications dumpNotifications
;
1341 config
.listener
= &dumpNotifications
;
1342 if (failed(applyFullConversion(getOperation(), target
,
1343 std::move(patterns
), config
))) {
1344 getOperation()->emitRemark() << "applyFullConversion failed";
1349 // Otherwise, handle an analysis conversion.
1350 assert(mode
== ConversionMode::Analysis
);
1352 // Analyze the convertible operations.
1353 DenseSet
<Operation
*> legalizedOps
;
1354 ConversionConfig config
;
1355 config
.legalizableOps
= &legalizedOps
;
1356 if (failed(applyAnalysisConversion(getOperation(), target
,
1357 std::move(patterns
), config
)))
1358 return signalPassFailure();
1360 // Emit remarks for each legalizable operation.
1361 for (auto *op
: legalizedOps
)
1362 op
->emitRemark() << "op '" << op
->getName() << "' is legalizable";
1365 /// The mode of conversion to use.
1366 ConversionMode mode
;
1370 static llvm::cl::opt
<TestLegalizePatternDriver::ConversionMode
>
1371 legalizerConversionMode(
1372 "test-legalize-mode",
1373 llvm::cl::desc("The legalization mode to use with the test driver"),
1374 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial
),
1376 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis
,
1377 "analysis", "Perform an analysis conversion"),
1378 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full
, "full",
1379 "Perform a full conversion"),
1380 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial
,
1381 "partial", "Perform a partial conversion")));
1383 //===----------------------------------------------------------------------===//
1384 // ConversionPatternRewriter::getRemappedValue testing. This method is used
1385 // to get the remapped value of an original value that was replaced using
1386 // ConversionPatternRewriter.
1388 struct TestRemapValueTypeConverter
: public TypeConverter
{
1389 using TypeConverter::TypeConverter
;
1391 TestRemapValueTypeConverter() {
1393 [](Float32Type type
) { return Float64Type::get(type
.getContext()); });
1394 addConversion([](Type type
) { return type
; });
1398 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
1399 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
1403 /// %1 = test.one_variadic_out_one_variadic_in1"(%0)
1404 /// is replaced with:
1405 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
1406 struct OneVResOneVOperandOp1Converter
1407 : public OpConversionPattern
<OneVResOneVOperandOp1
> {
1408 using OpConversionPattern
<OneVResOneVOperandOp1
>::OpConversionPattern
;
1411 matchAndRewrite(OneVResOneVOperandOp1 op
, OpAdaptor adaptor
,
1412 ConversionPatternRewriter
&rewriter
) const override
{
1413 auto origOps
= op
.getOperands();
1414 assert(std::distance(origOps
.begin(), origOps
.end()) == 1 &&
1415 "One operand expected");
1416 Value origOp
= *origOps
.begin();
1417 SmallVector
<Value
, 2> remappedOperands
;
1418 // Replicate the remapped original operand twice. Note that we don't used
1419 // the remapped 'operand' since the goal is testing 'getRemappedValue'.
1420 remappedOperands
.push_back(rewriter
.getRemappedValue(origOp
));
1421 remappedOperands
.push_back(rewriter
.getRemappedValue(origOp
));
1423 rewriter
.replaceOpWithNewOp
<OneVResOneVOperandOp1
>(op
, op
.getResultTypes(),
1429 /// A rewriter pattern that tests that blocks can be merged.
1430 struct TestRemapValueInRegion
1431 : public OpConversionPattern
<TestRemappedValueRegionOp
> {
1432 using OpConversionPattern
<TestRemappedValueRegionOp
>::OpConversionPattern
;
1435 matchAndRewrite(TestRemappedValueRegionOp op
, OpAdaptor adaptor
,
1436 ConversionPatternRewriter
&rewriter
) const final
{
1437 Block
&block
= op
.getBody().front();
1438 Operation
*terminator
= block
.getTerminator();
1440 // Merge the block into the parent region.
1441 Block
*parentBlock
= op
->getBlock();
1442 Block
*finalBlock
= rewriter
.splitBlock(parentBlock
, op
->getIterator());
1443 rewriter
.mergeBlocks(&block
, parentBlock
, ValueRange());
1444 rewriter
.mergeBlocks(finalBlock
, parentBlock
, ValueRange());
1446 // Replace the results of this operation with the remapped terminator
1448 SmallVector
<Value
> terminatorOperands
;
1449 if (failed(rewriter
.getRemappedValues(terminator
->getOperands(),
1450 terminatorOperands
)))
1453 rewriter
.eraseOp(terminator
);
1454 rewriter
.replaceOp(op
, terminatorOperands
);
1459 struct TestRemappedValue
1460 : public mlir::PassWrapper
<TestRemappedValue
, OperationPass
<>> {
1461 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue
)
1463 StringRef
getArgument() const final
{ return "test-remapped-value"; }
1464 StringRef
getDescription() const final
{
1465 return "Test public remapped value mechanism in ConversionPatternRewriter";
1467 void runOnOperation() override
{
1468 TestRemapValueTypeConverter typeConverter
;
1470 mlir::RewritePatternSet
patterns(&getContext());
1471 patterns
.add
<OneVResOneVOperandOp1Converter
>(&getContext());
1472 patterns
.add
<TestChangeProducerTypeF32ToF64
, TestUpdateConsumerType
>(
1474 patterns
.add
<TestRemapValueInRegion
>(typeConverter
, &getContext());
1476 mlir::ConversionTarget
target(getContext());
1477 target
.addLegalOp
<ModuleOp
, func::FuncOp
, TestReturnOp
>();
1479 // Expect the type_producer/type_consumer operations to only operate on f64.
1480 target
.addDynamicallyLegalOp
<TestTypeProducerOp
>(
1481 [](TestTypeProducerOp op
) { return op
.getType().isF64(); });
1482 target
.addDynamicallyLegalOp
<TestTypeConsumerOp
>([](TestTypeConsumerOp op
) {
1483 return op
.getOperand().getType().isF64();
1486 // We make OneVResOneVOperandOp1 legal only when it has more that one
1487 // operand. This will trigger the conversion that will replace one-operand
1488 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1489 target
.addDynamicallyLegalOp
<OneVResOneVOperandOp1
>(
1490 [](Operation
*op
) { return op
->getNumOperands() > 1; });
1492 if (failed(mlir::applyFullConversion(getOperation(), target
,
1493 std::move(patterns
)))) {
1494 signalPassFailure();
1500 //===----------------------------------------------------------------------===//
1501 // Test patterns without a specific root operation kind
1502 //===----------------------------------------------------------------------===//
1505 /// This pattern matches and removes any operation in the test dialect.
1506 struct RemoveTestDialectOps
: public RewritePattern
{
1507 RemoveTestDialectOps(MLIRContext
*context
)
1508 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context
) {}
1510 LogicalResult
matchAndRewrite(Operation
*op
,
1511 PatternRewriter
&rewriter
) const override
{
1512 if (!isa
<TestDialect
>(op
->getDialect()))
1514 rewriter
.eraseOp(op
);
1519 struct TestUnknownRootOpDriver
1520 : public mlir::PassWrapper
<TestUnknownRootOpDriver
, OperationPass
<>> {
1521 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver
)
1523 StringRef
getArgument() const final
{
1524 return "test-legalize-unknown-root-patterns";
1526 StringRef
getDescription() const final
{
1527 return "Test public remapped value mechanism in ConversionPatternRewriter";
1529 void runOnOperation() override
{
1530 mlir::RewritePatternSet
patterns(&getContext());
1531 patterns
.add
<RemoveTestDialectOps
>(&getContext());
1533 mlir::ConversionTarget
target(getContext());
1534 target
.addIllegalDialect
<TestDialect
>();
1535 if (failed(applyPartialConversion(getOperation(), target
,
1536 std::move(patterns
))))
1537 signalPassFailure();
1542 //===----------------------------------------------------------------------===//
1543 // Test patterns that uses operations and types defined at runtime
1544 //===----------------------------------------------------------------------===//
1547 /// This pattern matches dynamic operations 'test.one_operand_two_results' and
1548 /// replace them with dynamic operations 'test.generic_dynamic_op'.
1549 struct RewriteDynamicOp
: public RewritePattern
{
1550 RewriteDynamicOp(MLIRContext
*context
)
1551 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
1554 LogicalResult
matchAndRewrite(Operation
*op
,
1555 PatternRewriter
&rewriter
) const override
{
1556 assert(op
->getName().getStringRef() ==
1557 "test.dynamic_one_operand_two_results" &&
1558 "rewrite pattern should only match operations with the right name");
1560 OperationState
state(op
->getLoc(), "test.dynamic_generic",
1561 op
->getOperands(), op
->getResultTypes(),
1563 auto *newOp
= rewriter
.create(state
);
1564 rewriter
.replaceOp(op
, newOp
->getResults());
1569 struct TestRewriteDynamicOpDriver
1570 : public PassWrapper
<TestRewriteDynamicOpDriver
, OperationPass
<>> {
1571 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver
)
1573 void getDependentDialects(DialectRegistry
®istry
) const override
{
1574 registry
.insert
<TestDialect
>();
1576 StringRef
getArgument() const final
{ return "test-rewrite-dynamic-op"; }
1577 StringRef
getDescription() const final
{
1578 return "Test rewritting on dynamic operations";
1580 void runOnOperation() override
{
1581 RewritePatternSet
patterns(&getContext());
1582 patterns
.add
<RewriteDynamicOp
>(&getContext());
1584 ConversionTarget
target(getContext());
1585 target
.addIllegalOp(
1586 OperationName("test.dynamic_one_operand_two_results", &getContext()));
1587 target
.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
1588 if (failed(applyPartialConversion(getOperation(), target
,
1589 std::move(patterns
))))
1590 signalPassFailure();
1593 } // end anonymous namespace
1595 //===----------------------------------------------------------------------===//
1596 // Test type conversions
1597 //===----------------------------------------------------------------------===//
1600 struct TestTypeConversionProducer
1601 : public OpConversionPattern
<TestTypeProducerOp
> {
1602 using OpConversionPattern
<TestTypeProducerOp
>::OpConversionPattern
;
1604 matchAndRewrite(TestTypeProducerOp op
, OpAdaptor adaptor
,
1605 ConversionPatternRewriter
&rewriter
) const final
{
1606 Type resultType
= op
.getType();
1607 Type convertedType
= getTypeConverter()
1608 ? getTypeConverter()->convertType(resultType
)
1610 if (isa
<FloatType
>(resultType
))
1611 resultType
= rewriter
.getF64Type();
1612 else if (resultType
.isInteger(16))
1613 resultType
= rewriter
.getIntegerType(64);
1614 else if (isa
<test::TestRecursiveType
>(resultType
) &&
1615 convertedType
!= resultType
)
1616 resultType
= convertedType
;
1620 rewriter
.replaceOpWithNewOp
<TestTypeProducerOp
>(op
, resultType
);
1625 /// Call signature conversion and then fail the rewrite to trigger the undo
1627 struct TestSignatureConversionUndo
1628 : public OpConversionPattern
<TestSignatureConversionUndoOp
> {
1629 using OpConversionPattern
<TestSignatureConversionUndoOp
>::OpConversionPattern
;
1632 matchAndRewrite(TestSignatureConversionUndoOp op
, OpAdaptor adaptor
,
1633 ConversionPatternRewriter
&rewriter
) const final
{
1634 (void)rewriter
.convertRegionTypes(&op
->getRegion(0), *getTypeConverter());
1639 /// Call signature conversion without providing a type converter to handle
1640 /// materializations.
1641 struct TestTestSignatureConversionNoConverter
1642 : public OpConversionPattern
<TestSignatureConversionNoConverterOp
> {
1643 TestTestSignatureConversionNoConverter(const TypeConverter
&converter
,
1644 MLIRContext
*context
)
1645 : OpConversionPattern
<TestSignatureConversionNoConverterOp
>(context
),
1646 converter(converter
) {}
1649 matchAndRewrite(TestSignatureConversionNoConverterOp op
, OpAdaptor adaptor
,
1650 ConversionPatternRewriter
&rewriter
) const final
{
1651 Region
®ion
= op
->getRegion(0);
1652 Block
*entry
= ®ion
.front();
1654 // Convert the original entry arguments.
1655 TypeConverter::SignatureConversion
result(entry
->getNumArguments());
1657 converter
.convertSignatureArgs(entry
->getArgumentTypes(), result
)))
1659 rewriter
.modifyOpInPlace(op
, [&] {
1660 rewriter
.applySignatureConversion(®ion
.front(), result
);
1665 const TypeConverter
&converter
;
1668 /// Just forward the operands to the root op. This is essentially a no-op
1669 /// pattern that is used to trigger target materialization.
1670 struct TestTypeConsumerForward
1671 : public OpConversionPattern
<TestTypeConsumerOp
> {
1672 using OpConversionPattern
<TestTypeConsumerOp
>::OpConversionPattern
;
1675 matchAndRewrite(TestTypeConsumerOp op
, OpAdaptor adaptor
,
1676 ConversionPatternRewriter
&rewriter
) const final
{
1677 rewriter
.modifyOpInPlace(op
,
1678 [&] { op
->setOperands(adaptor
.getOperands()); });
1683 struct TestTypeConversionAnotherProducer
1684 : public OpRewritePattern
<TestAnotherTypeProducerOp
> {
1685 using OpRewritePattern
<TestAnotherTypeProducerOp
>::OpRewritePattern
;
1687 LogicalResult
matchAndRewrite(TestAnotherTypeProducerOp op
,
1688 PatternRewriter
&rewriter
) const final
{
1689 rewriter
.replaceOpWithNewOp
<TestTypeProducerOp
>(op
, op
.getType());
1694 struct TestReplaceWithLegalOp
: public ConversionPattern
{
1695 TestReplaceWithLegalOp(MLIRContext
*ctx
)
1696 : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx
) {}
1698 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1699 ConversionPatternRewriter
&rewriter
) const final
{
1700 rewriter
.replaceOpWithNewOp
<LegalOpD
>(op
, operands
[0]);
1705 struct TestTypeConversionDriver
1706 : public PassWrapper
<TestTypeConversionDriver
, OperationPass
<>> {
1707 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver
)
1709 void getDependentDialects(DialectRegistry
®istry
) const override
{
1710 registry
.insert
<TestDialect
>();
1712 StringRef
getArgument() const final
{
1713 return "test-legalize-type-conversion";
1715 StringRef
getDescription() const final
{
1716 return "Test various type conversion functionalities in DialectConversion";
1719 void runOnOperation() override
{
1720 // Initialize the type converter.
1721 SmallVector
<Type
, 2> conversionCallStack
;
1722 TypeConverter converter
;
1724 /// Add the legal set of type conversions.
1725 converter
.addConversion([](Type type
) -> Type
{
1726 // Treat F64 as legal.
1729 // Allow converting BF16/F16/F32 to F64.
1730 if (type
.isBF16() || type
.isF16() || type
.isF32())
1731 return FloatType::getF64(type
.getContext());
1732 // Otherwise, the type is illegal.
1735 converter
.addConversion([](IntegerType type
, SmallVectorImpl
<Type
> &) {
1736 // Drop all integer types.
1739 converter
.addConversion(
1740 // Convert a recursive self-referring type into a non-self-referring
1741 // type named "outer_converted_type" that contains a SimpleAType.
1742 [&](test::TestRecursiveType type
,
1743 SmallVectorImpl
<Type
> &results
) -> std::optional
<LogicalResult
> {
1744 // If the type is already converted, return it to indicate that it is
1746 if (type
.getName() == "outer_converted_type") {
1747 results
.push_back(type
);
1751 conversionCallStack
.push_back(type
);
1752 auto popConversionCallStack
= llvm::make_scope_exit(
1753 [&conversionCallStack
]() { conversionCallStack
.pop_back(); });
1755 // If the type is on the call stack more than once (it is there at
1756 // least once because of the _current_ call, which is always the last
1757 // element on the stack), we've hit the recursive case. Just return
1758 // SimpleAType here to create a non-recursive type as a result.
1759 if (llvm::is_contained(ArrayRef(conversionCallStack
).drop_back(),
1761 results
.push_back(test::SimpleAType::get(type
.getContext()));
1765 // Convert the body recursively.
1766 auto result
= test::TestRecursiveType::get(type
.getContext(),
1767 "outer_converted_type");
1768 if (failed(result
.setBody(converter
.convertType(type
.getBody()))))
1770 results
.push_back(result
);
1774 /// Add the legal set of type materializations.
1775 converter
.addSourceMaterialization([](OpBuilder
&builder
, Type resultType
,
1777 Location loc
) -> Value
{
1778 // Allow casting from F64 back to F32.
1779 if (!resultType
.isF16() && inputs
.size() == 1 &&
1780 inputs
[0].getType().isF64())
1781 return builder
.create
<TestCastOp
>(loc
, resultType
, inputs
).getResult();
1782 // Allow producing an i32 or i64 from nothing.
1783 if ((resultType
.isInteger(32) || resultType
.isInteger(64)) &&
1785 return builder
.create
<TestTypeProducerOp
>(loc
, resultType
);
1786 // Allow producing an i64 from an integer.
1787 if (isa
<IntegerType
>(resultType
) && inputs
.size() == 1 &&
1788 isa
<IntegerType
>(inputs
[0].getType()))
1789 return builder
.create
<TestCastOp
>(loc
, resultType
, inputs
).getResult();
1794 // Initialize the conversion target.
1795 mlir::ConversionTarget
target(getContext());
1796 target
.addLegalOp
<LegalOpD
>();
1797 target
.addDynamicallyLegalOp
<TestTypeProducerOp
>([](TestTypeProducerOp op
) {
1798 auto recursiveType
= dyn_cast
<test::TestRecursiveType
>(op
.getType());
1799 return op
.getType().isF64() || op
.getType().isInteger(64) ||
1801 recursiveType
.getName() == "outer_converted_type");
1803 target
.addDynamicallyLegalOp
<func::FuncOp
>([&](func::FuncOp op
) {
1804 return converter
.isSignatureLegal(op
.getFunctionType()) &&
1805 converter
.isLegal(&op
.getBody());
1807 target
.addDynamicallyLegalOp
<TestCastOp
>([&](TestCastOp op
) {
1808 // Allow casts from F64 to F32.
1809 return (*op
.operand_type_begin()).isF64() && op
.getType().isF32();
1811 target
.addDynamicallyLegalOp
<TestSignatureConversionNoConverterOp
>(
1812 [&](TestSignatureConversionNoConverterOp op
) {
1813 return converter
.isLegal(op
.getRegion().front().getArgumentTypes());
1816 // Initialize the set of rewrite patterns.
1817 RewritePatternSet
patterns(&getContext());
1818 patterns
.add
<TestTypeConsumerForward
, TestTypeConversionProducer
,
1819 TestSignatureConversionUndo
,
1820 TestTestSignatureConversionNoConverter
>(converter
,
1822 patterns
.add
<TestTypeConversionAnotherProducer
, TestReplaceWithLegalOp
>(
1824 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns
,
1827 if (failed(applyPartialConversion(getOperation(), target
,
1828 std::move(patterns
))))
1829 signalPassFailure();
//===----------------------------------------------------------------------===//
// Test Target Materialization With No Uses
//===----------------------------------------------------------------------===//
1839 struct ForwardOperandPattern
: public OpConversionPattern
<TestTypeChangerOp
> {
1840 using OpConversionPattern
<TestTypeChangerOp
>::OpConversionPattern
;
1843 matchAndRewrite(TestTypeChangerOp op
, OpAdaptor adaptor
,
1844 ConversionPatternRewriter
&rewriter
) const final
{
1845 rewriter
.replaceOp(op
, adaptor
.getOperands());
1850 struct TestTargetMaterializationWithNoUses
1851 : public PassWrapper
<TestTargetMaterializationWithNoUses
, OperationPass
<>> {
1852 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1853 TestTargetMaterializationWithNoUses
)
1855 StringRef
getArgument() const final
{
1856 return "test-target-materialization-with-no-uses";
1858 StringRef
getDescription() const final
{
1859 return "Test a special case of target materialization in DialectConversion";
1862 void runOnOperation() override
{
1863 TypeConverter converter
;
1864 converter
.addConversion([](Type t
) { return t
; });
1865 converter
.addConversion([](IntegerType intTy
) -> Type
{
1866 if (intTy
.getWidth() == 16)
1867 return IntegerType::get(intTy
.getContext(), 64);
1870 converter
.addTargetMaterialization(
1871 [](OpBuilder
&builder
, Type type
, ValueRange inputs
, Location loc
) {
1872 return builder
.create
<TestCastOp
>(loc
, type
, inputs
).getResult();
1875 ConversionTarget
target(getContext());
1876 target
.addIllegalOp
<TestTypeChangerOp
>();
1878 RewritePatternSet
patterns(&getContext());
1879 patterns
.add
<ForwardOperandPattern
>(converter
, &getContext());
1881 if (failed(applyPartialConversion(getOperation(), target
,
1882 std::move(patterns
))))
1883 signalPassFailure();
//===----------------------------------------------------------------------===//
// Test Block Merging
//===----------------------------------------------------------------------===//
1893 /// A rewriter pattern that tests that blocks can be merged.
1894 struct TestMergeBlock
: public OpConversionPattern
<TestMergeBlocksOp
> {
1895 using OpConversionPattern
<TestMergeBlocksOp
>::OpConversionPattern
;
1898 matchAndRewrite(TestMergeBlocksOp op
, OpAdaptor adaptor
,
1899 ConversionPatternRewriter
&rewriter
) const final
{
1900 Block
&firstBlock
= op
.getBody().front();
1901 Operation
*branchOp
= firstBlock
.getTerminator();
1902 Block
*secondBlock
= &*(std::next(op
.getBody().begin()));
1903 auto succOperands
= branchOp
->getOperands();
1904 SmallVector
<Value
, 2> replacements(succOperands
);
1905 rewriter
.eraseOp(branchOp
);
1906 rewriter
.mergeBlocks(secondBlock
, &firstBlock
, replacements
);
1907 rewriter
.modifyOpInPlace(op
, [] {});
1912 /// A rewrite pattern to tests the undo mechanism of blocks being merged.
1913 struct TestUndoBlocksMerge
: public ConversionPattern
{
1914 TestUndoBlocksMerge(MLIRContext
*ctx
)
1915 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx
) {}
1917 matchAndRewrite(Operation
*op
, ArrayRef
<Value
> operands
,
1918 ConversionPatternRewriter
&rewriter
) const final
{
1919 Block
&firstBlock
= op
->getRegion(0).front();
1920 Operation
*branchOp
= firstBlock
.getTerminator();
1921 Block
*secondBlock
= &*(std::next(op
->getRegion(0).begin()));
1922 rewriter
.setInsertionPointToStart(secondBlock
);
1923 rewriter
.create
<ILLegalOpF
>(op
->getLoc(), rewriter
.getF32Type());
1924 auto succOperands
= branchOp
->getOperands();
1925 SmallVector
<Value
, 2> replacements(succOperands
);
1926 rewriter
.eraseOp(branchOp
);
1927 rewriter
.mergeBlocks(secondBlock
, &firstBlock
, replacements
);
1928 rewriter
.modifyOpInPlace(op
, [] {});
1933 /// A rewrite mechanism to inline the body of the op into its parent, when both
1934 /// ops can have a single block.
1935 struct TestMergeSingleBlockOps
1936 : public OpConversionPattern
<SingleBlockImplicitTerminatorOp
> {
1937 using OpConversionPattern
<
1938 SingleBlockImplicitTerminatorOp
>::OpConversionPattern
;
1941 matchAndRewrite(SingleBlockImplicitTerminatorOp op
, OpAdaptor adaptor
,
1942 ConversionPatternRewriter
&rewriter
) const final
{
1943 SingleBlockImplicitTerminatorOp parentOp
=
1944 op
->getParentOfType
<SingleBlockImplicitTerminatorOp
>();
1947 Block
&innerBlock
= op
.getRegion().front();
1948 TerminatorOp innerTerminator
=
1949 cast
<TerminatorOp
>(innerBlock
.getTerminator());
1950 rewriter
.inlineBlockBefore(&innerBlock
, op
);
1951 rewriter
.eraseOp(innerTerminator
);
1952 rewriter
.eraseOp(op
);
1957 struct TestMergeBlocksPatternDriver
1958 : public PassWrapper
<TestMergeBlocksPatternDriver
, OperationPass
<>> {
1959 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver
)
1961 StringRef
getArgument() const final
{ return "test-merge-blocks"; }
1962 StringRef
getDescription() const final
{
1963 return "Test Merging operation in ConversionPatternRewriter";
1965 void runOnOperation() override
{
1966 MLIRContext
*context
= &getContext();
1967 mlir::RewritePatternSet
patterns(context
);
1968 patterns
.add
<TestMergeBlock
, TestUndoBlocksMerge
, TestMergeSingleBlockOps
>(
1970 ConversionTarget
target(*context
);
1971 target
.addLegalOp
<func::FuncOp
, ModuleOp
, TerminatorOp
, TestBranchOp
,
1972 TestTypeConsumerOp
, TestTypeProducerOp
, TestReturnOp
>();
1973 target
.addIllegalOp
<ILLegalOpF
>();
1975 /// Expect the op to have a single block after legalization.
1976 target
.addDynamicallyLegalOp
<TestMergeBlocksOp
>(
1977 [&](TestMergeBlocksOp op
) -> bool {
1978 return llvm::hasSingleElement(op
.getBody());
1981 /// Only allow `test.br` within test.merge_blocks op.
1982 target
.addDynamicallyLegalOp
<TestBranchOp
>([&](TestBranchOp op
) -> bool {
1983 return op
->getParentOfType
<TestMergeBlocksOp
>();
1986 /// Expect that all nested test.SingleBlockImplicitTerminator ops are
1988 target
.addDynamicallyLegalOp
<SingleBlockImplicitTerminatorOp
>(
1989 [&](SingleBlockImplicitTerminatorOp op
) -> bool {
1990 return !op
->getParentOfType
<SingleBlockImplicitTerminatorOp
>();
1993 DenseSet
<Operation
*> unlegalizedOps
;
1994 ConversionConfig config
;
1995 config
.unlegalizedOps
= &unlegalizedOps
;
1996 (void)applyPartialConversion(getOperation(), target
, std::move(patterns
),
1998 for (auto *op
: unlegalizedOps
)
1999 op
->emitRemark() << "op '" << op
->getName() << "' is not legalizable";
//===----------------------------------------------------------------------===//
// Test Selective Replacement
//===----------------------------------------------------------------------===//
2009 /// A rewrite mechanism to inline the body of the op into its parent, when both
2010 /// ops can have a single block.
2011 struct TestSelectiveOpReplacementPattern
: public OpRewritePattern
<TestCastOp
> {
2012 using OpRewritePattern
<TestCastOp
>::OpRewritePattern
;
2014 LogicalResult
matchAndRewrite(TestCastOp op
,
2015 PatternRewriter
&rewriter
) const final
{
2016 if (op
.getNumOperands() != 2)
2018 OperandRange operands
= op
.getOperands();
2020 // Replace non-terminator uses with the first operand.
2021 rewriter
.replaceUsesWithIf(op
, operands
[0], [](OpOperand
&operand
) {
2022 return operand
.getOwner()->hasTrait
<OpTrait::IsTerminator
>();
2024 // Replace everything else with the second operand if the operation isn't
2026 rewriter
.replaceOp(op
, op
.getOperand(1));
2031 struct TestSelectiveReplacementPatternDriver
2032 : public PassWrapper
<TestSelectiveReplacementPatternDriver
,
2034 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
2035 TestSelectiveReplacementPatternDriver
)
2037 StringRef
getArgument() const final
{
2038 return "test-pattern-selective-replacement";
2040 StringRef
getDescription() const final
{
2041 return "Test selective replacement in the PatternRewriter";
2043 void runOnOperation() override
{
2044 MLIRContext
*context
= &getContext();
2045 mlir::RewritePatternSet
patterns(context
);
2046 patterns
.add
<TestSelectiveOpReplacementPattern
>(context
);
2047 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns
));
//===----------------------------------------------------------------------===//
// Pass Registration
//===----------------------------------------------------------------------===//
2058 void registerPatternsTestPass() {
2059 PassRegistration
<TestReturnTypeDriver
>();
2061 PassRegistration
<TestDerivedAttributeDriver
>();
2063 PassRegistration
<TestGreedyPatternDriver
>();
2064 PassRegistration
<TestStrictPatternDriver
>();
2065 PassRegistration
<TestWalkPatternDriver
>();
2067 PassRegistration
<TestLegalizePatternDriver
>([] {
2068 return std::make_unique
<TestLegalizePatternDriver
>(legalizerConversionMode
);
2071 PassRegistration
<TestRemappedValue
>();
2073 PassRegistration
<TestUnknownRootOpDriver
>();
2075 PassRegistration
<TestTypeConversionDriver
>();
2076 PassRegistration
<TestTargetMaterializationWithNoUses
>();
2078 PassRegistration
<TestRewriteDynamicOpDriver
>();
2080 PassRegistration
<TestMergeBlocksPatternDriver
>();
2081 PassRegistration
<TestSelectiveReplacementPatternDriver
>();