[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / Examples / transform / Ch4 / features.mlir
blobd23b6e8435ef53db62c3c0b587481d59bfd6a82e
1 // RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics
3 // Matmul as a named operation.
4 func.func @named(
5     %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
6     %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
7     -> tensor<512x512xf32> {
8   // expected-remark @below {{matmul}}
9   %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
10                           outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
11   func.return %matmul : tensor<512x512xf32>
14 // Matmul as a generic operation.
15 func.func @generic(
16     %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
17     %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
18     -> tensor<512x512xf32> {
19   // expected-remark @below {{matmul}}
20   %matmul = linalg.generic {
21     iterator_types = ["parallel", "parallel", "reduction"],
22     indexing_maps = [
23       affine_map<(d0, d1, d2) -> (d0, d2)>,
24       affine_map<(d0, d1, d2) -> (d2, d1)>,
25       affine_map<(d0, d1, d2) -> (d0, d1)>]
26   } ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
27     outs(%output: tensor<512x512xf32>) {
28   ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
29     %0 = arith.mulf %arg0, %arg1 : f32
30     %1 = arith.addf %0, %arg2 : f32
31     linalg.yield %1 : f32
32   } -> tensor<512x512xf32>
33   return %matmul : tensor<512x512xf32>
36 // The module containing named sequences must have an attribute allowing them
37 // to enable verification.
38 module @transforms attributes { transform.with_named_sequence } {
39   // Entry point. This takes as the only argument the root operation (typically
40   // pass root) given to the transform interpreter.
41   transform.named_sequence @__transform_main(
42       %root: !transform.any_op {transform.consumed}) {
44     // Traverses the payload IR associated with the operand handle, invoking
45     // @match_matmul_elemwise on each of the operations. If the named sequence
46     // succeeds, i.e., if none of the nested match (transform) operations
47     // produced a silenceable failure, invokes @print_matmul_elemwise and
48     // forwards the values yielded as arguments of the new invocation. If the
49     // named sequence fails with a silenceable failure, silences it (the message
50     // is forwarded to the debug stream). Definite failures are propagated
51     // immediately and unconditionally, as usual.
52     transform.foreach_match in %root
53       @match_generic_matmul -> @print_generic_matmul
54       : (!transform.any_op) -> !transform.any_op
56     transform.yield
57   }
59   // This is an action sequence.
60   transform.named_sequence @print_generic_matmul(
61       %matmul: !transform.any_op {transform.readonly}) {
62     transform.debug.emit_remark_at %matmul, "matmul" : !transform.any_op
63     transform.yield
64   }
66   transform.named_sequence @match_generic_matmul(
67       %candidate: !transform.any_op {transform.readonly}) -> !transform.any_op {
68     // Match a structured linear algebra operation.
69     transform.match.structured %candidate : !transform.any_op {
70     ^bb0(%c: !transform.any_op):
71       // With a rank equal to 3.
72       %rank = transform.match.structured.rank %c
73         : (!transform.any_op) -> !transform.param<i64>
74       %c3 = transform.param.constant 3 : i64 -> !transform.param<i64>
75       transform.match.param.cmpi eq %rank, %c3 : !transform.param<i64>
77       // With 2 inputs.
78       %n_ins = transform.match.structured.num_inputs %c
79         : (!transform.any_op) -> !transform.param<i64>
80       %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
81       transform.match.param.cmpi eq %n_ins, %c2 : !transform.param<i64>
83       // With 1 output (note that structured ops in destination passing style
84       // has as many inits as outputs).
85       %n_inits = transform.match.structured.num_inits %c
86         : (!transform.any_op) -> !transform.param<i64>
87       %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
88       transform.match.param.cmpi eq %n_inits, %c1 : !transform.param<i64>
90       // All inputs and inits are accessed with a projected permutation.
91       transform.match.structured.input %c[all] {projected_permutation}
92         : !transform.any_op
93       transform.match.structured.init %c[0] {projected_permutation}
94         : !transform.any_op
96       // The body is a mulf/addf contraction with appropriate dimensions.
97       transform.match.structured.body %c 
98         { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op
99       %batch, %lhs, %rhs, %reduction =
100       transform.match.structured.classify_contraction_dims %c
101         : (!transform.any_op)
102         -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
103             !transform.param<i64>)
105       // There is one of lhs, rhs and reduction dimensions and zero batch
106       // dimensions.
107       %n_batch = transform.num_associations %batch
108         : (!transform.param<i64>) -> !transform.param<i64>
109       %n_lhs = transform.num_associations %lhs
110         : (!transform.param<i64>) -> !transform.param<i64>
111       %n_rhs = transform.num_associations %rhs
112         : (!transform.param<i64>) -> !transform.param<i64>
113       %n_reduction = transform.num_associations %reduction
114         : (!transform.param<i64>) -> !transform.param<i64>
115       %c0 = transform.param.constant 0 : i64 -> !transform.param<i64>
116       transform.match.param.cmpi eq %n_batch, %c0 : !transform.param<i64>
117       transform.match.param.cmpi eq %n_lhs, %c1 : !transform.param<i64>
118       transform.match.param.cmpi eq %n_rhs, %c1 : !transform.param<i64>
119       transform.match.param.cmpi eq %n_reduction, %c1 : !transform.param<i64>
120     }
121     transform.yield %candidate : !transform.any_op
122   }