[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / Examples / transform / Ch4 / multiple.mlir
blob2c2c7059055213d00f6660ed1a8a49bd6e65d2bd
1 // RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics
3 // Matmul+ReLU.
4 func.func @fc_relu_operands_00(
5     %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
6     %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
7     -> tensor<512x512xf32> {
8   // Matrix-matrix multiplication.
9   // expected-remark @below {{matmul # 0}}
10   %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
11                           outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
13   // Elementwise addition.
14   // expected-remark @below {{add # 0}}
15   %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
16     ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
17     outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
19   // Elementwise max with 0 (ReLU).
20   %c0f = arith.constant 0.0 : f32
21   // expected-remark @below {{max # 0}}
22   %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
23     ins(%biased, %c0f : tensor<512x512xf32>, f32)
24     outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
25   func.return %relued : tensor<512x512xf32>
28 // Matmul+ReLU with swapped operands.
29 func.func @fc_relu_operands_01(
30     %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
31     %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
32     -> tensor<512x512xf32> {
33   // Matrix-matrix multiplication.
34   // expected-remark @below {{matmul # 1}}
35   %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
36                           outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
38   // Elementwise addition.
39   // expected-remark @below {{add # 1}}
40   %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
41     ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
42     outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
44   // Elementwise max with 0 (ReLU).
45   %c0f = arith.constant 0.0 : f32
46   // expected-remark @below {{max # 1}}
47   %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
48     ins(%c0f, %biased : f32, tensor<512x512xf32>)
49     outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
50   func.return %relued : tensor<512x512xf32>
53 // The module containing named sequences must have an attribute allowing them
54 // to enable verification.
55 module @transforms attributes { transform.with_named_sequence } {
56   // Entry point. This takes as the only argument the root operation (typically
57   // pass root) given to the transform interpreter.
58   transform.named_sequence @__transform_main(
59       %root: !transform.any_op {transform.consumed}) {
61     // Traverses the payload IR associated with the operand handle, invoking
62     // @match_matmul_elemwise on each of the operations. If the named sequence
63     // succeeds, i.e., if none of the nested match (transform) operations
64     // produced a silenceable failure, invokes @print_matmul_elemwise and
65     // forwards the values yielded as arguments of the new invocation. If the
66     // named sequence fails with a silenceable failure, silences it (the message
67     // is forwarded to the debug stream). Definite failures are propagated
68     // immediately and unconditionally, as usual.
69     transform.foreach_match in %root
70       @match_matmul_elemwise -> @print_matmul_elemwise
71       : (!transform.any_op) -> !transform.any_op
73     transform.yield
74   }
76   // This is an action sequence.
77   transform.named_sequence @print_matmul_elemwise(
78       %matmul: !transform.any_op {transform.readonly},
79       %add: !transform.any_op {transform.readonly},
80       %max: !transform.any_op {transform.readonly},
81       %pos: !transform.param<i32> {transform.readonly}) {
82     transform.debug.emit_param_as_remark %pos, "matmul #" at %matmul
83       : !transform.param<i32>, !transform.any_op
84     transform.debug.emit_param_as_remark %pos, "add #" at %add
85       : !transform.param<i32>, !transform.any_op
86     transform.debug.emit_param_as_remark %pos, "max #" at %max
87       : !transform.param<i32>, !transform.any_op
88     transform.yield
89   }
91   // This is also a matcher sequence. It is similarly given an operation to
92   // match and nested operations must succeed in order for a match to be deemed
93   // successful. It starts matching from the last operation in the use-def chain
94   // and goes back because each operand (use) has exactly one definition.
95   transform.named_sequence @match_matmul_elemwise(
96       %last: !transform.any_op {transform.readonly}) 
97       -> (!transform.any_op, !transform.any_op, !transform.any_op,
98           !transform.param<i32>) {
99     // The last operation must be an elementwise binary.
100     transform.match.operation_name %last ["linalg.elemwise_binary"]
101       : !transform.any_op
103     // One of its operands must be defined by another operation, to which we
104     // will get a handle here. This is achieved thanks to a newly defined
105     // operation that tries to match operands one by one using the match
106     // operations nested in its region.
107     %pos, %middle = transform.match.my.has_operand_satisfying %last
108         : (!transform.any_op) -> (!transform.param<i32>, !transform.any_op) {
109     ^bb0(%operand: !transform.any_value):
110       // The operand must be defined by an operation.
111       %def = transform.get_defining_op %operand 
112         : (!transform.any_value) -> !transform.any_op
113       // The defining operation must itself be an elementwise binary.
114       transform.match.operation_name %def ["linalg.elemwise_binary"]
115         : !transform.any_op
116       transform.yield %def : !transform.any_op
117     }
118     
119     // And the first operand of that operation must be defined by yet another
120     // operation.
121     %matmul = transform.get_producer_of_operand %middle[0]
122       : (!transform.any_op) -> !transform.any_op
123     // And that operation is a matmul.
124     transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op
125     // We will yield the handles to the matmul and the two elementwise
126     // operations separately. 
127     transform.yield %matmul, %middle, %last, %pos
128       : !transform.any_op, !transform.any_op, !transform.any_op,
129         !transform.param<i32>
130   }