1 // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
4 affine_map<(i) -> (i)>,
5 affine_map<(i) -> (i)>,
9 indexing_maps = #dotp_accesses,
10 iterator_types = ["reduction"]
// Contraction of two vector<4xf32> operands into an f32 accumulator using
// #dotp_trait. The expected lowering is an elementwise arith.mulf followed by
// a single vector.reduction <add> that takes the accumulator as extra operand.
13 // CHECK-LABEL: func @extract_contract1
14 // CHECK-SAME: %[[A:.*0]]: vector<4xf32>,
15 // CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
16 // CHECK-SAME: %[[C:.*2]]: f32
17 // CHECK: %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32>
18 // CHECK: %[[R:.*]] = vector.reduction <add>, %[[F]], %[[C]] : vector<4xf32> into f32
19 // CHECK: return %[[R]] : f32
21 func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
22 %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
23 : vector<4xf32>, vector<4xf32> into f32
// Masked variant of the f32 dot-product contraction. The expected lowering
// keeps the arith.mulf unmasked and wraps only the vector.reduction in
// vector.mask. The reduction operands are matched through the captured
// FileCheck variables (F = mulf result, C = accumulator argument) rather than
// printed SSA names, so the check verifies data flow and does not depend on
// how the printer numbers values.
27 // CHECK-LABEL: func @masked_extract_contract1
28 // CHECK-SAME: %[[A:.*0]]: vector<4xf32>, %[[B:.*1]]: vector<4xf32>, %[[C:.*2]]: f32
29 // CHECK-SAME: %[[M:.*]]: vector<4xi1>
30 // CHECK: %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32>
31 // CHECK: %[[R:.*]] = vector.mask %[[M]] { vector.reduction <add>, %[[F]], %[[C]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
32 // CHECK: return %[[R]] : f32
34 func.func @masked_extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32, %mask: vector<4xi1>) -> f32 {
35 %0 = vector.mask %mask { vector.contract #dotp_trait %arg0, %arg1, %arg2 : vector<4xf32>, vector<4xf32> into f32 } : vector<4xi1> -> f32
// Integer (i32) counterpart of the #dotp_trait contraction: the expected
// lowering uses arith.muli and a vector.reduction <add> that takes the
// accumulator as extra operand.
39 // CHECK-LABEL: func @extract_contract1_int
40 // CHECK-SAME: %[[A:.*0]]: vector<4xi32>,
41 // CHECK-SAME: %[[B:.*1]]: vector<4xi32>,
42 // CHECK-SAME: %[[C:.*2]]: i32
43 // CHECK: %[[F:.*]] = arith.muli %[[A]], %[[B]] : vector<4xi32>
44 // CHECK: %[[R:.*]] = vector.reduction <add>, %[[F]], %[[C]] : vector<4xi32> into i32
45 // CHECK: return %[[R]] : i32
47 func.func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 {
48 %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
49 : vector<4xi32>, vector<4xi32> into i32
54 affine_map<(i, j) -> (i, j)>,
55 affine_map<(i, j) -> (j)>,
56 affine_map<(i, j) -> (i)>
59 indexing_maps = #matvec_accesses,
60 iterator_types = ["parallel", "reduction"]
// Contraction with #matvec_trait: vector<2x3xf32> and vector<3xf32> into
// vector<2xf32>. For each of the two rows of the first operand the checks
// expect extract -> mulf with the second operand -> reduction -> insert into a
// zero-initialized result; the accumulator is added once at the end.
63 // CHECK-LABEL: func @extract_contract2
64 // CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
65 // CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
66 // CHECK-SAME: %[[C:.*2]]: vector<2xf32>
67 // CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
68 // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
69 // CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[B]] : vector<3xf32>
70 // CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xf32> into f32
71 // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
72 // CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
73 // CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
74 // CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xf32> into f32
75 // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
76 // CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32>
77 // CHECK: return %[[T10]] : vector<2xf32>
79 func.func @extract_contract2(%arg0: vector<2x3xf32>,
81 %arg2: vector<2xf32>) -> vector<2xf32> {
82 %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
83 : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
84 return %0 : vector<2xf32>
// Integer (i32) counterpart of the #matvec_trait contraction: the same
// per-row extract / multiply / reduce / insert sequence is expected, using
// arith.muli and a final arith.addi with the accumulator.
87 // CHECK-LABEL: func @extract_contract2_int
88 // CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>,
89 // CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
90 // CHECK-SAME: %[[C:.*2]]: vector<2xi32>
91 // CHECK: %[[R:.*]] = arith.constant dense<0> : vector<2xi32>
92 // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xi32> from vector<2x3xi32>
93 // CHECK: %[[T2:.*]] = arith.muli %[[T0]], %[[B]] : vector<3xi32>
94 // CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xi32> into i32
95 // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32>
96 // CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<3xi32> from vector<2x3xi32>
97 // CHECK: %[[T7:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
98 // CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xi32> into i32
99 // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32>
100 // CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[C]] : vector<2xi32>
101 // CHECK: return %[[T10]] : vector<2xi32>
102 func.func @extract_contract2_int(%arg0: vector<2x3xi32>,
103 %arg1: vector<3xi32>,
104 %arg2: vector<2xi32>) -> vector<2xi32> {
105 %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
106 : vector<2x3xi32>, vector<3xi32> into vector<2xi32>
107 return %0 : vector<2xi32>
111 affine_map<(i, j) -> (j)>,
112 affine_map<(i, j) -> (i, j)>,
113 affine_map<(i, j) -> (i)>
116 indexing_maps = #vecmat_accesses,
117 iterator_types = ["parallel", "reduction"]
// Contraction with #vecmat_trait: vector<3xf32> and vector<2x3xf32> into
// vector<2xf32>. Here the rows are extracted from the second operand (B) and
// multiplied with the first operand (A); each reduction result is inserted
// into a zero-initialized vector and the accumulator is added at the end.
120 // CHECK-LABEL: func @extract_contract3
121 // CHECK-SAME: %[[A:.*0]]: vector<3xf32>,
122 // CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
123 // CHECK-SAME: %[[C:.*2]]: vector<2xf32>
124 // CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
125 // CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<2x3xf32>
126 // CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[A]] : vector<3xf32>
127 // CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xf32> into f32
128 // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
129 // CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<2x3xf32>
130 // CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[A]] : vector<3xf32>
131 // CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xf32> into f32
132 // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
133 // CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32>
134 // CHECK: return %[[T10]] : vector<2xf32>
136 func.func @extract_contract3(%arg0: vector<3xf32>,
137 %arg1: vector<2x3xf32>,
138 %arg2: vector<2xf32>) -> vector<2xf32> {
139 %0 = vector.contract #vecmat_trait %arg0, %arg1, %arg2
140 : vector<3xf32>, vector<2x3xf32> into vector<2xf32>
141 return %0 : vector<2xf32>
145 affine_map<(i, j, k) -> (i, k)>,
146 affine_map<(i, j, k) -> (k, j)>,
147 affine_map<(i, j, k) -> (i, j)>
150 indexing_maps = #matmat_accesses,
151 iterator_types = ["parallel", "parallel", "reduction"]
// Contraction with #matmat_trait on 2x2 f32 operands. The expected lowering
// transposes the second operand once, then for every result position [i, j]
// extracts row i of A and row j of the transposed B, multiplies them, reduces
// with <add> and inserts the scalar into a zero-initialized 2x2 result. The
// accumulator is added by a single arith.addf at the end.
// The transpose operand is matched through the captured B variable rather
// than a printed SSA name, so the check does not depend on value naming.
154 // CHECK-LABEL: func @extract_contract4
155 // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
156 // CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
157 // CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
158 // CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
159 // CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
160 // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
161 // CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
162 // CHECK: %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32>
163 // CHECK: %[[T10:.*]] = vector.reduction <add>, %[[T9]] : vector<2xf32> into f32
164 // CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32>
166 // CHECK: %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
167 // CHECK: %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32>
168 // CHECK: %[[T20:.*]] = vector.reduction <add>, %[[T19]] : vector<2xf32> into f32
169 // CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32>
171 // CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
172 // CHECK: %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
173 // CHECK: %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32>
174 // CHECK: %[[T33:.*]] = vector.reduction <add>, %[[T32]] : vector<2xf32> into f32
175 // CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32>
177 // CHECK: %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
178 // CHECK: %[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32>
179 // CHECK: %[[T42:.*]] = vector.reduction <add>, %[[T41]] : vector<2xf32> into f32
180 // CHECK: %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32>
182 // CHECK: %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32>
183 // CHECK: return %[[T52]] : vector<2x2xf32>
185 func.func @extract_contract4(%arg0: vector<2x2xf32>,
186 %arg1: vector<2x2xf32>,
187 %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
188 %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
189 : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
190 return %0 : vector<2x2xf32>
// Trait for a 2-D contraction in which both iterators are reductions: both
// operands are indexed (i, j) and the result map is empty (scalar result).
194 #contraction2d_accesses = [
195 affine_map<(i, j) -> (i, j)>,
196 affine_map<(i, j) -> (i, j)>,
197 affine_map<(i, j) -> ()>
199 #contraction2d_trait = {
200 indexing_maps = #contraction2d_accesses,
201 iterator_types = ["reduction", "reduction"]
// The checks expect one mulf + reduction per row of the operands, chained
// through the accumulator: the first reduction takes the function's
// accumulator argument, the second takes the first reduction's result.
204 // CHECK-LABEL: func @full_contract1
205 // CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
206 // CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
207 // CHECK-SAME: %[[C:.*2]]: f32
208 // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
209 // CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<2x3xf32>
210 // CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] : vector<3xf32>
211 // CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]], %[[C]] : vector<3xf32> into f32
212 // CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
213 // CHECK: %[[T6:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<2x3xf32>
214 // CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[T6]] : vector<3xf32>
215 // CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T7]], %[[T3]] : vector<3xf32> into f32
216 // CHECK: return %[[T8]] : f32
218 func.func @full_contract1(%arg0: vector<2x3xf32>,
219 %arg1: vector<2x3xf32>,
221 %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2
222 : vector<2x3xf32>, vector<2x3xf32> into f32
// Same all-reduction 2-D contraction, but the second operand is indexed
// (j, i), i.e. accessed transposed relative to the first operand.
226 #contraction2d_trans_accesses = [
227 affine_map<(i, j) -> (i, j)>,
228 affine_map<(i, j) -> (j, i)>,
229 affine_map<(i, j) -> ()>
231 #contraction2d_trans_trait = {
232 indexing_maps = #contraction2d_trans_accesses,
233 iterator_types = ["reduction", "reduction"]
// The checks expect each column of B to be gathered element by element
// (scalar extract + insert into a zero-initialized vector<3xf32>) before it is
// multiplied with the matching row of A; the two reductions are chained
// through the accumulator.
236 // CHECK-LABEL: func @full_contract2
237 // CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
238 // CHECK-SAME: %[[B:.*1]]: vector<3x2xf32>,
239 // CHECK-SAME: %[[C:.*2]]: f32
240 // CHECK: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3xf32>
241 // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<3xf32> from vector<2x3xf32>
242 // CHECK: %[[T1:.*]] = vector.extract %[[B]][0, 0] : f32 from vector<3x2xf32>
243 // CHECK: %[[T3:.*]] = vector.insert %[[T1]], %[[Z]] [0] : f32 into vector<3xf32>
244 // CHECK: %[[T4:.*]] = vector.extract %[[B]][1, 0] : f32 from vector<3x2xf32>
245 // CHECK: %[[T6:.*]] = vector.insert %[[T4]], %[[T3]] [1] : f32 into vector<3xf32>
246 // CHECK: %[[T7:.*]] = vector.extract %[[B]][2, 0] : f32 from vector<3x2xf32>
247 // CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32>
248 // CHECK: %[[T10:.*]] = arith.mulf %[[T0]], %[[T9]] : vector<3xf32>
249 // CHECK: %[[T11:.*]] = vector.reduction <add>, %[[T10]], %[[C]] : vector<3xf32> into f32
251 // CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<3xf32> from vector<2x3xf32>
252 // CHECK: %[[T13:.*]] = vector.extract %[[B]][0, 1] : f32 from vector<3x2xf32>
253 // CHECK: %[[T15:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<3xf32>
254 // CHECK: %[[T16:.*]] = vector.extract %[[B]][1, 1] : f32 from vector<3x2xf32>
255 // CHECK: %[[T18:.*]] = vector.insert %[[T16]], %[[T15]] [1] : f32 into vector<3xf32>
256 // CHECK: %[[T19:.*]] = vector.extract %[[B]][2, 1] : f32 from vector<3x2xf32>
257 // CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32>
258 // CHECK: %[[T22:.*]] = arith.mulf %[[T12]], %[[T21]] : vector<3xf32>
259 // CHECK: %[[T23:.*]] = vector.reduction <add>, %[[T22]], %[[T11]] : vector<3xf32> into f32
260 // CHECK: return %[[T23]] : f32
262 func.func @full_contract2(%arg0: vector<2x3xf32>,
263 %arg1: vector<3x2xf32>,
265 %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2
266 : vector<2x3xf32>, vector<3x2xf32> into f32
// Contraction whose iterators are (reduction, parallel, reduction) and whose
// first operand is vector<1x2xi32>, so d0 has extent 1 on that operand and
// does not index the second operand at all. The checks expect the single row
// of the first operand to be extracted once and reused for each row of the
// second operand. The constant and the first two extracts are matched with
// the DAG directive, so their relative order is not constrained.
270 // CHECK-LABEL: @contract_one_sided_unit_reduction_dim
271 // CHECK-SAME: (%[[A0:.+]]: vector<1x2xi32>, %[[A1:.+]]: vector<2x2xi32>, %[[A2:.+]]: vector<2xi32>)
272 // CHECK-DAG: %[[C:.+]] = arith.constant dense<0> : vector<2xi32>
273 // CHECK-DAG: %[[E00:.+]] = vector.extract %[[A0]][0] : vector<2xi32> from vector<1x2xi32>
274 // CHECK-DAG: %[[E10:.+]] = vector.extract %[[A1]][0] : vector<2xi32> from vector<2x2xi32>
275 // CHECK: %[[M0:.+]] = arith.muli %[[E10]], %[[E00]] : vector<2xi32>
276 // CHECK: %[[R0:.+]] = vector.reduction <add>, %[[M0]] : vector<2xi32> into i32
277 // CHECK: %[[I0:.+]] = vector.insert %[[R0]], %[[C]] [0] : i32 into vector<2xi32>
278 // CHECK: %[[E11:.+]] = vector.extract %[[A1]][1] : vector<2xi32> from vector<2x2xi32>
279 // CHECK: %[[M1:.+]] = arith.muli %[[E11]], %[[E00]] : vector<2xi32>
280 // CHECK: %[[R1:.+]] = vector.reduction <add>, %[[M1]] : vector<2xi32> into i32
281 // CHECK: %[[I1:.+]] = vector.insert %[[R1]], %[[I0]] [1] : i32 into vector<2xi32>
282 // CHECK: %[[S:.+]] = arith.addi %[[I1]], %[[A2]] : vector<2xi32>
283 // CHECK: return %[[S]] : vector<2xi32>
285 func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 : vector<2x2xi32>, %arg2 : vector<2xi32>) -> vector<2xi32> {
286 %res = vector.contract {
288 affine_map<(d0, d1, d2) -> (d0, d2)>,
289 affine_map<(d0, d1, d2) -> (d1, d2)>,
290 affine_map<(d0, d1, d2) -> (d1)>
292 iterator_types = ["reduction", "parallel", "reduction"],
293 kind = #vector.kind<add>
294 } %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32>
295 return %res : vector<2xi32>
// Transform script run by --transform-interpreter: it matches every func.func
// in the payload module and applies the vector contraction lowering patterns
// with the "dot" lowering strategy to the matched functions.
298 module attributes {transform.with_named_sequence} {
299 transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
300 %f = transform.structured.match ops{["func.func"]} in %module_op
301 : (!transform.any_op) -> !transform.any_op
303 transform.apply_patterns to %f {
304 transform.apply_patterns.vector.lower_contraction lowering_strategy = "dot"
305 } : !transform.any_op