[C++20][modules] Fix std::initializer_list recognition if it's exported out of a...
[llvm-project.git] / mlir / test / python / integration / dialects / pdl.py
blob923af29a71ad757cee46cfbe06aa80e9091abad9
1 # RUN: %PYTHON %s 2>&1 | FileCheck %s
3 from mlir.dialects import arith, func, pdl
4 from mlir.dialects.builtin import module
5 from mlir.ir import *
6 from mlir.rewrite import *
def construct_and_print_in_module(f):
    """Run ``f`` once against a freshly created module and print its result.

    A ``TEST: <function name>`` banner is printed first. Then, under a new
    ``Context`` and an unknown ``Location``, an empty ``Module`` is created
    and ``f`` is called with it while the insertion point is set to the
    module body. Whatever ``f`` returns is printed unless it is ``None``.

    Returns ``f`` itself, so this can be applied as a decorator; note that
    the decorated function is executed immediately at decoration time.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        mod = Module.create()
        with InsertionPoint(mod.body):
            # Rebind to whatever the test hands back; that is what gets printed.
            mod = f(mod)
        if mod is not None:
            print(mod)
    return f
20 # CHECK-LABEL: TEST: test_add_to_mul
21 # CHECK: arith.muli
@construct_and_print_in_module
def test_add_to_mul(module_):
    """Rewrite ``arith.addi`` on index operands into ``arith.muli`` using PDL.

    Builds a payload function that adds two index values, builds a separate
    module holding a PDL pattern that replaces ``arith.addi`` with
    ``arith.muli``, freezes that pattern module, and applies it greedily to
    ``module_``.

    Args:
        module_: the module passed in by ``construct_and_print_in_module``;
            the payload IR is built while its body is the insertion point.

    Returns:
        ``module_`` after the patterns have been applied.
    """
    index_type = IndexType.get()

    # Create a test case.
    @module(sym_name="ir")
    def ir():
        @func.func(index_type, index_type)
        def add_func(a, b):
            return arith.addi(a, b)

    # Create a rewrite from add to mul. This will match
    # - operation name is arith.addi
    # - operands are index types.
    # - there are two operands.
    with Location.unknown():
        m = Module.create()
        with InsertionPoint(m.body):
            # Change all arith.addi with index types to arith.muli.
            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
            def pat():
                # Match arith.addi with index types. The PDL type handle gets
                # its own name so it is not confused with the builtin
                # `index_type` used for the payload function above.
                pdl_index_type = pdl.TypeOp(IndexType.get())
                operand0 = pdl.OperandOp(pdl_index_type)
                operand1 = pdl.OperandOp(pdl_index_type)
                op0 = pdl.OperationOp(
                    name="arith.addi",
                    args=[operand0, operand1],
                    types=[pdl_index_type],
                )

                # Replace the matched op with arith.muli.
                @pdl.rewrite()
                def rew():
                    new_op = pdl.OperationOp(
                        name="arith.muli",
                        args=[operand0, operand1],
                        types=[pdl_index_type],
                    )
                    pdl.ReplaceOp(op0, with_op=new_op)

    # Create a PDL module from module and freeze it. At this point the ownership
    # of the module is transferred to the PDL module. This ownership transfer is
    # not yet captured Python side/has sharp edges. So best to construct the
    # module and PDL module in same scope.
    # FIXME: This should be made more robust.
    frozen = PDLModule(m).freeze()
    # Could apply frozen pattern set multiple times.
    apply_patterns_and_fold_greedily(module_, frozen)
    return module_