[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / Dialect / Vector / vector-contract-to-outerproduct-matmul-transforms.mlir
blob7a60ff8ea85897b593fe73da3a596e1653f0bdd6
1 // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
3 /// Tests for `vector.contract` -> `vector.outerproduct` transformations for
4 /// matmul operations:
5 ///   C += A * B.
6 /// (A, B and C are 2-d matrices). ATM three different variants / are tested:
7 ///   * plain (no mask, fixed-wdith vectors),
8 ///   * masked (fixed-width vectors,
9 ///   * scalable (mask + scalable vectors).
10 /// In order for the "vector.contract -> vector.outerproduct" patterns to work,
11 /// only the non-reduction dimension can be scalable (*). For matmul operations
12 /// that is set to be the N dimension (i.e. rows of the output matrix), which
13 /// matches how matrix multiplication are normally implemented for e.g.
14 /// Arm SVE. However, making the M dimension scalable (i.e. columns of the
15 /// output matrix) should work as well.
16 ///
17 /// (*) The conversion tested in this file unrolls along the reduction
18 /// dimension, which is not supported for scalable vectors.
20 #matmat_accesses_0 = [
21   affine_map<(m, n, k) -> (m, k)>,
22   affine_map<(m, n, k) -> (k, n)>,
23   affine_map<(m, n, k) -> (m, n)>
25 #matmat_trait_0 = {
26   indexing_maps = #matmat_accesses_0,
27   iterator_types = ["parallel", "parallel", "reduction"]
30 #matmat_accesses_1 = [
31   affine_map<(m, n, k) -> (m, k)>,
32   affine_map<(m, n, k) -> (n, k)>,
33   affine_map<(m, n, k) -> (m, n)>
35 #matmat_trait_1 = {
36   indexing_maps = #matmat_accesses_1,
37   iterator_types = ["parallel", "parallel", "reduction"]
40 #matmat_accesses_2 = [
41   affine_map<(m, n, k) -> (k, m)>,
42   affine_map<(m, n, k) -> (k, n)>,
43   affine_map<(m, n, k) -> (m, n)>
45 #matmat_trait_2 = {
46   indexing_maps = #matmat_accesses_2,
47   iterator_types = ["parallel", "parallel", "reduction"]
50 #matmat_accesses_3 = [
51   affine_map<(m, n, k) -> (k, m)>,
52   affine_map<(m, n, k) -> (n, k)>,
53   affine_map<(m, n, k) -> (m, n)>
55 #matmat_trait_3 = {
56   indexing_maps = #matmat_accesses_3,
57   iterator_types = ["parallel", "parallel", "reduction"]
60 #matmat_accesses_4 = [
61   affine_map<(m, n, k) -> (m, k)>,
62   affine_map<(m, n, k) -> (k, n)>,
63   affine_map<(m, n, k) -> (n, m)>
65 #matmat_trait_4 = {
66   indexing_maps = #matmat_accesses_4,
67   iterator_types = ["parallel", "parallel", "reduction"]
70 // ============================================================================
71 //  Matmul 0 (plain + masked + mixed types)
72 // ============================================================================
73 // CHECK-LABEL: func @matmul
74 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
75 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
76 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
77 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
78 // CHECK-SAME:  : vector<2x4xf32> to vector<4x2xf32>
80 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
81 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<4x3xf32>
82 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
83 // CHECK-SAME:  : vector<2xf32>, vector<3xf32>
85 //      CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
86 //      CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<4x3xf32>
87 //      CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
88 // CHECK-SAME:  : vector<2xf32>, vector<3xf32>
90 //      CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
91 //      CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<3xf32> from vector<4x3xf32>
92 //      CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
93 // CHECK-SAME:  : vector<2xf32>, vector<3xf32>
95 //      CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
96 //      CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<3xf32> from vector<4x3xf32>
97 //      CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
98 // CHECK-SAME:  : vector<2xf32>, vector<3xf32>
100 //      CHECK: return %[[c3]] : vector<2x3xf32>
101 func.func @matmul(%A: vector<2x4xf32>,
102                   %B: vector<4x3xf32>,
103                   %C: vector<2x3xf32>) -> vector<2x3xf32> {
104   %0 = vector.contract #matmat_trait_0 %A, %B, %C
105     : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
106   return %0 : vector<2x3xf32>
109 // CHECK-LABEL: func @matmul_scalable
110 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
111 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x[3]xf32>,
112 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
113 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
114 // CHECK-SAME:  : vector<2x4xf32> to vector<4x2xf32>
116 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
117 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<4x[3]xf32>
118 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
119 // CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
121 //      CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
122 //      CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<4x[3]xf32>
123 //      CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
124 // CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
126 //      CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
127 //      CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<4x[3]xf32>
128 //      CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
129 // CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
131 //      CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
132 //      CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<4x[3]xf32>
133 //      CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
134 // CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
136 //      CHECK: return %[[c3]] : vector<2x[3]xf32>
137 func.func @matmul_scalable(%A: vector<2x4xf32>,
138                            %B: vector<4x[3]xf32>,
139                            %C: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
140   %0 = vector.contract #matmat_trait_0 %A, %B, %C
141     : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
142   return %0 : vector<2x[3]xf32>
145 // CHECK-LABEL: func.func @masked_matmul(
146 // CHECK-SAME:    %{{.*}}: vector<3x5xf32>,
147 // CHECK-SAME:    %{{.*}}: vector<5x7xf32>,
148 // CHECK-SAME:    %{{.*}}: vector<3x7xf32>,
149 // CHECK-SAME:    %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
150 // CHECK:         %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
151 // CHECK:         %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
152 // CHECK:         %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
153 // CHECK:         %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
154 // CHECK:         %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
155 // CHECK:         %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
156 // CHECK:         %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
157 // CHECK:         %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
158 // CHECK:         %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
159 // CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
160 // CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
162 func.func @masked_matmul(%A: vector<3x5xf32>,
163                          %B: vector<5x7xf32>,
164                          %C: vector<3x7xf32>,
165                          %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
166   %0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
167   : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
168   return %0 : vector<3x7xf32>
171 // CHECK-LABEL: func.func @masked_matmul_scalable(
172 // CHECK-SAME:    %{{.*}}: vector<3x5xf32>,
173 // CHECK-SAME:    %{{.*}}: vector<5x[7]xf32>,
174 // CHECK-SAME:    %{{.*}}: vector<3x[7]xf32>,
175 // CHECK-SAME:    %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
176 // CHECK:         %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
177 // CHECK:         %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
178 // CHECK:         %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
179 // CHECK:         %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
180 // CHECK:         %[[VAL_13:.*]] = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
181 // CHECK:         %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
182 // CHECK:         %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
183 // CHECK:         %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
184 // CHECK:         %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
185 // CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
186 // CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
188 func.func @masked_matmul_scalable(%A: vector<3x5xf32>,
189                                   %B: vector<5x[7]xf32>,
190                                   %C: vector<3x[7]xf32>,
191                                   %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
192   %0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
193   : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
194   return %0 : vector<3x[7]xf32>
197 // CHECK-LABEL: func @matmul_mixed
198 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
199 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
200 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
201 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
202 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
203 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf16> from vector<1x3xf16>
204 //      CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
205 //      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
206 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
207 //      CHECK: return %[[c0]] : vector<2x3xf32>
208 func.func @matmul_mixed(%A: vector<2x1xf16>,
209                         %B: vector<1x3xf16>,
210                         %C: vector<2x3xf32>) -> vector<2x3xf32>
212   %0 = vector.contract #matmat_trait_0 %A, %B, %C
213     : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
214   return %0 : vector<2x3xf32>
217 // CHECK-LABEL: func @matmul_mixed_scalable
218 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
219 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf16>,
220 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
221 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
222 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
223 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<1x[3]xf16>
224 //      CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
225 //      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
226 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
227 //      CHECK: return %[[c0]] : vector<2x[3]xf32>
228 func.func @matmul_mixed_scalable(%A: vector<2x1xf16>,
229                                  %B: vector<1x[3]xf16>,
230                                  %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
232   %0 = vector.contract #matmat_trait_0 %A, %B, %C
233     : vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
234   return %0 : vector<2x[3]xf32>
237 // ============================================================================
238 //  Matmul 1 (plain + scalable)
239 // ============================================================================
240 // CHECK-LABEL: func @matmul_1
241 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
242 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
243 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
244 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
245 //      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
246 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
247 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
248 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
249 //      CHECK: return %[[c0]] : vector<2x3xf32>
250 func.func @matmul_1(%A: vector<2x1xf32>,
251                     %B: vector<3x1xf32>,
252                     %C: vector<2x3xf32>) -> vector<2x3xf32>
254   %0 = vector.contract #matmat_trait_1 %A, %B, %C
255     : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
256   return %0 : vector<2x3xf32>
259 // CHECK-LABEL: func @matmul_1_scalable
260 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
261 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
262 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
263 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
264 //      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
265 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
266 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
267 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
268 //      CHECK: return %[[c0]] : vector<2x[3]xf32>
269 func.func @matmul_1_scalable(%A: vector<2x1xf32>,
270                              %B: vector<[3]x1xf32>,
271                              %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
273   %0 = vector.contract #matmat_trait_1 %A, %B, %C
274     : vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
275   return %0 : vector<2x[3]xf32>
278 // ============================================================================
279 //  Matmul 2 (plain + scalable)
280 // ============================================================================
281 // CHECK-LABEL: func @matmul_2
282 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
283 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
284 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
285 //      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
286 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
287 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
288 //      CHECK: return %[[c0]] : vector<2x3xf32>
289 func.func @matmul_2(%A: vector<1x2xf32>,
290                     %B: vector<1x3xf32>,
291                     %C: vector<2x3xf32>) -> vector<2x3xf32>
293   %0 = vector.contract #matmat_trait_2 %A, %B, %C
294     : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
295   return %0 : vector<2x3xf32>
298 // CHECK-LABEL: func @matmul_2_scalable
299 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
300 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
301 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
302 //      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
303 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
304 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
305 //      CHECK: return %[[c0]] : vector<2x[3]xf32>
306 func.func @matmul_2_scalable(%A: vector<1x2xf32>,
307                              %B: vector<1x[3]xf32>,
308                              %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
310   %0 = vector.contract #matmat_trait_2 %A, %B, %C
311     : vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
312   return %0 : vector<2x[3]xf32>
315 // ============================================================================
316 //  Matmul 3 (plain + scalable)
317 // ============================================================================
318 // CHECK-LABEL: func @matmul_3
319 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
320 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
321 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
322 //      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
323 //      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
324 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
325 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
326 //      CHECK: return %[[c0]] : vector<2x3xf32>
327 func.func @matmul_3(%A: vector<1x2xf32>,
328                     %B: vector<3x1xf32>,
329                     %C: vector<2x3xf32>) -> vector<2x3xf32>
331   %0 = vector.contract #matmat_trait_3 %A, %B, %C
332     : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
333   return %0 : vector<2x3xf32>
336 // CHECK-LABEL: func @matmul_3_scalable
337 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
338 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
339 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
340 //      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
341 //      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
342 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
343 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
344 //      CHECK: return %[[c0]] : vector<2x[3]xf32>
345 func.func @matmul_3_scalable(%A: vector<1x2xf32>,
346                              %B: vector<[3]x1xf32>,
347                              %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
349   %0 = vector.contract #matmat_trait_3 %A, %B, %C
350     : vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
351   return %0 : vector<2x[3]xf32>
354 // ============================================================================
355 //  Matmul 4 (plain + scalable)
356 // ============================================================================
357 // CHECK-LABEL: func @matmul_4
358 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
359 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
360 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
361 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
362 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
363 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
364 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
365 //      CHECK: return %[[c0]] : vector<3x2xf32>
366 func.func @matmul_4(%A: vector<2x1xf32>,
367                     %B: vector<1x3xf32>,
368                     %C: vector<3x2xf32>) -> vector<3x2xf32>
370   %0 = vector.contract #matmat_trait_4 %A, %B, %C
371     : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
372   return %0 : vector<3x2xf32>
375 // CHECK-LABEL: func @matmul_4_scalable
376 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<[2]x1xf32>,
377 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
378 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x[2]xf32>
379 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
380 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
381 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
382 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
383 //      CHECK: return %[[c0]] : vector<3x[2]xf32>
384 func.func @matmul_4_scalable(%A: vector<[2]x1xf32>,
385                              %B: vector<1x3xf32>,
386                              %C: vector<3x[2]xf32>) -> vector<3x[2]xf32>
388   %0 = vector.contract #matmat_trait_4 %A, %B, %C
389     : vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32>
390   return %0 : vector<3x[2]xf32>
393 // ============================================================================
394 //  TD sequence
395 // ============================================================================
396 module attributes {transform.with_named_sequence} {
397   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
398     %f = transform.structured.match ops{["func.func"]} in %module_op
399       : (!transform.any_op) -> !transform.any_op
401     transform.apply_patterns to %f {
402       transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
403     } : !transform.any_op
404     transform.yield
405   }