1 // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
3 /// Tests for `vector.contract` -> `vector.outerproduct` transformations for
6 /// (A, B and C are 2-d matrices). ATM three different variants are tested:
7 /// * plain (no mask, fixed-width vectors),
8 /// * masked (fixed-width vectors),
9 /// * scalable (mask + scalable vectors).
10 /// In order for the "vector.contract -> vector.outerproduct" patterns to work,
11 /// only the non-reduction dimension can be scalable (*). For matmul operations
12 /// that is set to be the N dimension (i.e. rows of the output matrix), which
13 /// matches how matrix multiplications are normally implemented for e.g.
14 /// Arm SVE. However, making the M dimension scalable (i.e. columns of the
15 /// output matrix) should work as well.
17 /// (*) The conversion tested in this file unrolls along the reduction
18 /// dimension, which is not supported for scalable vectors.
// Access pattern 0: the first operand is indexed as (m, k), the second as
// (k, n) and the accumulator/result as (m, n). m and n are parallel
// iterators; k is the reduction iterator.
20 #matmat_accesses_0 = [
21 affine_map<(m, n, k) -> (m, k)>,
22 affine_map<(m, n, k) -> (k, n)>,
23 affine_map<(m, n, k) -> (m, n)>
26 indexing_maps = #matmat_accesses_0,
27 iterator_types = ["parallel", "parallel", "reduction"]
// Access pattern 1: the first operand is indexed as (m, k), the second as
// (n, k) and the accumulator/result as (m, n), i.e. the second operand is
// indexed with its dimensions swapped relative to access pattern 0.
30 #matmat_accesses_1 = [
31 affine_map<(m, n, k) -> (m, k)>,
32 affine_map<(m, n, k) -> (n, k)>,
33 affine_map<(m, n, k) -> (m, n)>
36 indexing_maps = #matmat_accesses_1,
37 iterator_types = ["parallel", "parallel", "reduction"]
// Access pattern 2: the first operand is indexed as (k, m), the second as
// (k, n) and the accumulator/result as (m, n), i.e. both operands have the
// reduction dimension k as their leading dimension.
40 #matmat_accesses_2 = [
41 affine_map<(m, n, k) -> (k, m)>,
42 affine_map<(m, n, k) -> (k, n)>,
43 affine_map<(m, n, k) -> (m, n)>
46 indexing_maps = #matmat_accesses_2,
47 iterator_types = ["parallel", "parallel", "reduction"]
// Access pattern 3: the first operand is indexed as (k, m), the second as
// (n, k) and the accumulator/result as (m, n).
50 #matmat_accesses_3 = [
51 affine_map<(m, n, k) -> (k, m)>,
52 affine_map<(m, n, k) -> (n, k)>,
53 affine_map<(m, n, k) -> (m, n)>
56 indexing_maps = #matmat_accesses_3,
57 iterator_types = ["parallel", "parallel", "reduction"]
// Access pattern 4: the first operand is indexed as (m, k), the second as
// (k, n) and the accumulator/result as (n, m), i.e. only the result is
// indexed with its dimensions swapped.
60 #matmat_accesses_4 = [
61 affine_map<(m, n, k) -> (m, k)>,
62 affine_map<(m, n, k) -> (k, n)>,
63 affine_map<(m, n, k) -> (n, m)>
66 indexing_maps = #matmat_accesses_4,
67 iterator_types = ["parallel", "parallel", "reduction"]
70 // ============================================================================
71 // Matmul 0 (plain + masked + scalable + mixed types)
72 // ============================================================================
73 // CHECK-LABEL: func @matmul
74 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
75 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
76 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
77 // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
78 // CHECK-SAME: : vector<2x4xf32> to vector<4x2xf32>
80 // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
81 // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<4x3xf32>
82 // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
83 // CHECK-SAME: : vector<2xf32>, vector<3xf32>
85 // CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
86 // CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<4x3xf32>
87 // CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
88 // CHECK-SAME: : vector<2xf32>, vector<3xf32>
90 // CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
91 // CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<3xf32> from vector<4x3xf32>
92 // CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
93 // CHECK-SAME: : vector<2xf32>, vector<3xf32>
95 // CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
96 // CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<3xf32> from vector<4x3xf32>
97 // CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
98 // CHECK-SAME: : vector<2xf32>, vector<3xf32>
100 // CHECK: return %[[c3]] : vector<2x3xf32>
// Plain variant: no mask, fixed-width vectors. Contracts a 2x4 operand with
// a 4x3 operand into a 2x3 accumulator using #matmat_trait_0.
101 func.func @matmul(%A: vector<2x4xf32>,
103 %C: vector<2x3xf32>) -> vector<2x3xf32> {
104 %0 = vector.contract #matmat_trait_0 %A, %B, %C
105 : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
106 return %0 : vector<2x3xf32>
109 // CHECK-LABEL: func @matmul_scalable
110 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
111 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x[3]xf32>,
112 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
113 // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
114 // CHECK-SAME: : vector<2x4xf32> to vector<4x2xf32>
116 // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
117 // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<4x[3]xf32>
118 // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
119 // CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
121 // CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
122 // CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<4x[3]xf32>
123 // CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
124 // CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
126 // CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
127 // CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<4x[3]xf32>
128 // CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
129 // CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
131 // CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
132 // CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<4x[3]xf32>
133 // CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
134 // CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
136 // CHECK: return %[[c3]] : vector<2x[3]xf32>
// Scalable variant of the plain case: the trailing dimension of %B and %C is
// scalable ([3]); the 2x4 operand %A is fixed-width. Uses #matmat_trait_0.
137 func.func @matmul_scalable(%A: vector<2x4xf32>,
138 %B: vector<4x[3]xf32>,
139 %C: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
140 %0 = vector.contract #matmat_trait_0 %A, %B, %C
141 : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
142 return %0 : vector<2x[3]xf32>
145 // CHECK-LABEL: func.func @masked_matmul(
146 // CHECK-SAME: %{{.*}}: vector<3x5xf32>,
147 // CHECK-SAME: %{{.*}}: vector<5x7xf32>,
148 // CHECK-SAME: %{{.*}}: vector<3x7xf32>,
149 // CHECK-SAME: %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
150 // CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
151 // CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
152 // CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
153 // CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
154 // CHECK: %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
155 // CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
156 // CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
157 // CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
158 // CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
159 // CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
160 // CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
// Masked variant with fixed-width vectors: the contraction (3x5 by 5x7 into
// 3x7, #matmat_trait_0) is wrapped in vector.mask. The mask %m is
// vector<3x7x5xi1>, whose shape matches the (m, n, k) sizes of the
// contraction.
162 func.func @masked_matmul(%A: vector<3x5xf32>,
165 %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
166 %0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
167 : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
168 return %0 : vector<3x7xf32>
171 // CHECK-LABEL: func.func @masked_matmul_scalable(
172 // CHECK-SAME: %{{.*}}: vector<3x5xf32>,
173 // CHECK-SAME: %{{.*}}: vector<5x[7]xf32>,
174 // CHECK-SAME: %{{.*}}: vector<3x[7]xf32>,
175 // CHECK-SAME: %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
176 // CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
177 // CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
178 // CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
179 // CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
180 // CHECK: %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
181 // CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
182 // CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
183 // CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
184 // CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
185 // CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
186 // CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
// Masked variant with a scalable dimension: same structure as the
// fixed-width masked case, but the size-7 dimension is scalable ([7]) in %B,
// %C, the mask %m and the result. The reduction dimension (5) is fixed.
188 func.func @masked_matmul_scalable(%A: vector<3x5xf32>,
189 %B: vector<5x[7]xf32>,
190 %C: vector<3x[7]xf32>,
191 %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
192 %0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
193 : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
194 return %0 : vector<3x[7]xf32>
197 // CHECK-LABEL: func @matmul_mixed
198 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
199 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
200 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
201 // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
202 // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
203 // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf16> from vector<1x3xf16>
204 // CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
205 // CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
206 // CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
207 // CHECK: return %[[c0]] : vector<2x3xf32>
// Mixed-type variant: f16 operands (2x1 and 1x3) are contracted into an f32
// accumulator/result (2x3) using #matmat_trait_0.
208 func.func @matmul_mixed(%A: vector<2x1xf16>,
210 %C: vector<2x3xf32>) -> vector<2x3xf32>
212 %0 = vector.contract #matmat_trait_0 %A, %B, %C
213 : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
214 return %0 : vector<2x3xf32>
217 // CHECK-LABEL: func @matmul_mixed_scalable
218 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
219 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf16>,
220 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
221 // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
222 // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
223 // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<1x[3]xf16>
224 // CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
225 // CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
226 // CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
227 // CHECK: return %[[c0]] : vector<2x[3]xf32>
// Mixed-type variant with a scalable dimension: f16 operands, f32
// accumulator/result, and the size-3 dimension of %B and %C is scalable
// ([3]). Uses #matmat_trait_0.
228 func.func @matmul_mixed_scalable(%A: vector<2x1xf16>,
229 %B: vector<1x[3]xf16>,
230 %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
232 %0 = vector.contract #matmat_trait_0 %A, %B, %C
233 : vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
234 return %0 : vector<2x[3]xf32>
237 // ============================================================================
238 // Matmul 1 (plain + scalable)
239 // ============================================================================
240 // CHECK-LABEL: func @matmul_1
241 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
242 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
243 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
244 // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
245 // CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
246 // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
247 // CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
248 // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
249 // CHECK: return %[[c0]] : vector<2x3xf32>
// Plain variant for #matmat_trait_1: contracts a 2x1 operand with a 3x1
// operand into a 2x3 accumulator/result.
250 func.func @matmul_1(%A: vector<2x1xf32>,
252 %C: vector<2x3xf32>) -> vector<2x3xf32>
254 %0 = vector.contract #matmat_trait_1 %A, %B, %C
255 : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
256 return %0 : vector<2x3xf32>
259 // CHECK-LABEL: func @matmul_1_scalable
260 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
261 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
262 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
263 // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
264 // CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
265 // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
266 // CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
267 // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
268 // CHECK: return %[[c0]] : vector<2x[3]xf32>
// Scalable variant for #matmat_trait_1: the size-3 dimension is scalable
// ([3]); it is the leading dimension of %B and the trailing dimension of %C.
269 func.func @matmul_1_scalable(%A: vector<2x1xf32>,
270 %B: vector<[3]x1xf32>,
271 %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
273 %0 = vector.contract #matmat_trait_1 %A, %B, %C
274 : vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
275 return %0 : vector<2x[3]xf32>
278 // ============================================================================
279 // Matmul 2 (plain + scalable)
280 // ============================================================================
281 // CHECK-LABEL: func @matmul_2
282 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
283 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
284 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
285 // CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
286 // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
287 // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
288 // CHECK: return %[[c0]] : vector<2x3xf32>
// Plain variant for #matmat_trait_2: contracts a 1x2 operand with a 1x3
// operand into a 2x3 accumulator/result.
289 func.func @matmul_2(%A: vector<1x2xf32>,
291 %C: vector<2x3xf32>) -> vector<2x3xf32>
293 %0 = vector.contract #matmat_trait_2 %A, %B, %C
294 : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
295 return %0 : vector<2x3xf32>
298 // CHECK-LABEL: func @matmul_2_scalable
299 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
300 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
301 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
302 // CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
303 // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
304 // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
305 // CHECK: return %[[c0]] : vector<2x[3]xf32>
// Scalable variant for #matmat_trait_2: the size-3 dimension is scalable
// ([3]) as the trailing dimension of both %B and %C.
306 func.func @matmul_2_scalable(%A: vector<1x2xf32>,
307 %B: vector<1x[3]xf32>,
308 %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
310 %0 = vector.contract #matmat_trait_2 %A, %B, %C
311 : vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
312 return %0 : vector<2x[3]xf32>
315 // ============================================================================
316 // Matmul 3 (plain + scalable)
317 // ============================================================================
318 // CHECK-LABEL: func @matmul_3
319 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
320 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
321 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
322 // CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
323 // CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
324 // CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
325 // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
326 // CHECK: return %[[c0]] : vector<2x3xf32>
// Plain variant for #matmat_trait_3: contracts a 1x2 operand with a 3x1
// operand into a 2x3 accumulator/result.
327 func.func @matmul_3(%A: vector<1x2xf32>,
329 %C: vector<2x3xf32>) -> vector<2x3xf32>
331 %0 = vector.contract #matmat_trait_3 %A, %B, %C
332 : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
333 return %0 : vector<2x3xf32>
336 // CHECK-LABEL: func @matmul_3_scalable
337 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
338 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
339 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
340 // CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
341 // CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
342 // CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
343 // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
344 // CHECK: return %[[c0]] : vector<2x[3]xf32>
// Scalable variant for #matmat_trait_3: the size-3 dimension is scalable
// ([3]); it is the leading dimension of %B and the trailing dimension of %C.
345 func.func @matmul_3_scalable(%A: vector<1x2xf32>,
346 %B: vector<[3]x1xf32>,
347 %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
349 %0 = vector.contract #matmat_trait_3 %A, %B, %C
350 : vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
351 return %0 : vector<2x[3]xf32>
354 // ============================================================================
355 // Matmul 4 (plain + scalable)
356 // ============================================================================
357 // CHECK-LABEL: func @matmul_4
358 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
359 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
360 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
361 // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
362 // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
363 // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
364 // CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
365 // CHECK: return %[[c0]] : vector<3x2xf32>
// Plain variant for #matmat_trait_4: contracts a 2x1 operand with a 1x3
// operand into a 3x2 accumulator/result.
366 func.func @matmul_4(%A: vector<2x1xf32>,
368 %C: vector<3x2xf32>) -> vector<3x2xf32>
370 %0 = vector.contract #matmat_trait_4 %A, %B, %C
371 : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
372 return %0 : vector<3x2xf32>
375 // CHECK-LABEL: func @matmul_4_scalable
376 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<[2]x1xf32>,
377 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
378 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x[2]xf32>
379 // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
380 // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
381 // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
382 // CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
383 // CHECK: return %[[c0]] : vector<3x[2]xf32>
// Scalable variant for #matmat_trait_4: here the size-2 dimension is the
// scalable one ([2]); it is the leading dimension of %A and the trailing
// dimension of %C. The 1x3 second operand is fixed-width.
384 func.func @matmul_4_scalable(%A: vector<[2]x1xf32>,
386 %C: vector<3x[2]xf32>) -> vector<3x[2]xf32>
388 %0 = vector.contract #matmat_trait_4 %A, %B, %C
389 : vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32>
390 return %0 : vector<3x[2]xf32>
393 // ============================================================================
395 // ============================================================================
396 module attributes {transform.with_named_sequence} {
397 transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
398 %f = transform.structured.match ops{["func.func"]} in %module_op
399 : (!transform.any_op) -> !transform.any_op
401 transform.apply_patterns to %f {
402 transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
403 } : !transform.any_op