1 # RUN: %PYTHON %s 2>&1 | FileCheck %s
3 from mlir
.passmanager
import PassManager
4 from mlir
.ir
import Context
, Location
, Module
, InsertionPoint
, UnitAttr
5 from mlir
.dialects
import scf
, pdl
, func
, arith
, linalg
6 from mlir
.dialects
.transform
import (
8 apply_patterns_canonicalization
,
12 from mlir
.dialects
.transform
.structured
import structured_match
13 from mlir
.dialects
.transform
.loop
import loop_unroll
14 from mlir
.dialects
.transform
.extras
import named_sequence
, apply_patterns
15 from mlir
.extras
import types
as T
16 from mlir
.dialects
.builtin
import module
, ModuleOp
19 def construct_and_print_in_module(f
):
20 print("\nTEST:", f
.__name
__)
21 with
Context(), Location
.unknown():
22 module
= Module
.create()
23 with
InsertionPoint(module
.body
):
25 if module
is not None:
30 # CHECK-LABEL: TEST: test_named_sequence
31 @construct_and_print_in_module
32 def test_named_sequence(module_
):
33 # CHECK-LABEL: func.func @loop_unroll_op() {
34 # CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
35 # CHECK: %[[VAL_1:.*]] = arith.constant 42 : index
36 # CHECK: %[[VAL_2:.*]] = arith.constant 5 : index
37 # CHECK: scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
38 # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
44 for i
in scf
.for_(0, 42, 5):
48 # CHECK-LABEL: module attributes {transform.with_named_sequence} {
49 # CHECK: transform.named_sequence @__transform_main(%[[VAL_0:.*]]: !transform.any_op) {
50 # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
51 # CHECK: %[[VAL_2:.*]] = transform.get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
52 # CHECK: transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
53 # CHECK: transform.yield
56 @module(attrs
={"transform.with_named_sequence": UnitAttr
.get()})
58 @named_sequence("__transform_main", [any_op_t()], [])
59 def basic(target
: any_op_t()):
60 m
= structured_match(any_op_t(), target
, ops
=["arith.addi"])
61 loop
= get_parent_op(pdl
.op_t(), m
, op_name
="scf.for")
64 # The identifier (name) of the function becomes the Operation
65 assert isinstance(mod
.opview
, ModuleOp
)
69 pm
= PassManager
.parse("builtin.module(transform-interpreter)")
70 pm
.run(module_
.operation
)
72 # CHECK-LABEL: func.func @loop_unroll_op() {
73 # CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
74 # CHECK: %[[VAL_1:.*]] = arith.constant 42 : index
75 # CHECK: %[[VAL_2:.*]] = arith.constant 5 : index
76 # CHECK: %[[VAL_6:.*]] = arith.constant 40 : index
77 # CHECK: %[[VAL_7:.*]] = arith.constant 20 : index
78 # CHECK: scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_6]] step %[[VAL_7]] {
79 # CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
80 # CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
81 # CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_2]], %[[VAL_8]] : index
82 # CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_3]], %[[VAL_9]] : index
83 # CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_10]] : index
84 # CHECK: %[[VAL_12:.*]] = arith.constant 2 : index
85 # CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_2]], %[[VAL_12]] : index
86 # CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_3]], %[[VAL_13]] : index
87 # CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_14]] : index
88 # CHECK: %[[VAL_16:.*]] = arith.constant 3 : index
89 # CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_2]], %[[VAL_16]] : index
90 # CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_3]], %[[VAL_17]] : index
91 # CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_18]] : index
93 # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
99 # CHECK-LABEL: TEST: test_apply_patterns
100 @construct_and_print_in_module
101 def test_apply_patterns(module_
):
102 b
, M
, N
, K
= 1, 3, 5, 3
104 # CHECK-LABEL: func.func @batch_reduce_matmul(
105 # CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x5xf32>,
106 # CHECK-SAME: %[[VAL_1:.*]]: tensor<1x5x3xf32>,
107 # CHECK-SAME: %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
108 # CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
109 # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : i32
110 # CHECK: %[[VAL_5:.*]] = linalg.batch_reduce_matmul ins(%[[VAL_0]], %[[VAL_1]] : tensor<1x3x5xf32>, tensor<1x5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
111 # CHECK: return %[[VAL_5]] : tensor<3x3xf32>
114 T
.tensor(b
, M
, N
, T
.f32()), T
.tensor(b
, N
, K
, T
.f32()), T
.tensor(M
, K
, T
.f32())
116 def batch_reduce_matmul(A
, B
, C
):
117 i
= arith
.constant(T
.i32(), 1)
119 return linalg
.batch_reduce_matmul(A
, B
, outs
=[C
])
121 # CHECK-LABEL: module attributes {transform.with_named_sequence} {
122 # CHECK: transform.named_sequence @__transform_main(%[[VAL_0:.*]]: !transform.any_op) {
123 # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["linalg.batch_reduce_matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
124 # CHECK: %[[VAL_2:.*]] = transform.get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
125 # CHECK: transform.apply_patterns to %[[VAL_2]] {
126 # CHECK: transform.apply_patterns.canonicalization
127 # CHECK: } : !pdl.operation
128 # CHECK: %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
129 # CHECK: transform.apply_cse to %[[VAL_3]] : !transform.any_op
130 # CHECK: transform.yield
133 @module(attrs
={"transform.with_named_sequence": UnitAttr
.get()})
135 @named_sequence("__transform_main", [any_op_t()], [])
136 def basic(variant_op
: any_op_t()):
137 matmul
= structured_match(
138 any_op_t(), variant_op
, ops
=["linalg.batch_reduce_matmul"]
140 top_func
= get_parent_op(pdl
.op_t(), matmul
, op_name
="func.func")
142 @apply_patterns(top_func
)
144 apply_patterns_canonicalization()
146 top_func
= structured_match(any_op_t(), variant_op
, ops
=["func.func"])
151 pm
= PassManager
.parse("builtin.module(transform-interpreter)")
152 pm
.run(module_
.operation
)
154 # CHECK-LABEL: func.func @batch_reduce_matmul(
155 # CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x5xf32>, %[[VAL_1:.*]]: tensor<1x5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
156 # CHECK: %[[VAL_3:.*]] = linalg.batch_reduce_matmul ins(%[[VAL_0]], %[[VAL_1]] : tensor<1x3x5xf32>, tensor<1x5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
157 # CHECK: return %[[VAL_3]] : tensor<3x3xf32>