1 //===- StructuredOpsUtilsTest.cpp - StructuredOpsUtils unit tests ---------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "gmock/gmock.h"
13 #include "gtest/gtest.h"
// Builds three indexing maps, each created with AffineMap::get(3, 0, ...),
// whose results are A = (m, k), B = (k, n) and C = (m, n), and expects
// isRowMajorMatmul to return true for the resulting ArrayAttr.
// NOTE(review): `context`, `m`, `n` and `k` are not declared in the visible
// lines of this test -- presumably declared on lines not shown; confirm.
21 TEST(isRowMajorMatmul
, Simple
) {
25 bindDims(&context
, m
, n
, k
);
26 auto mapA
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, k
}, &context
));
27 auto mapB
= AffineMapAttr::get(AffineMap::get(3, 0, {k
, n
}, &context
));
28 auto mapC
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, n
}, &context
));
29 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
, mapC
});
// Truly(pred) matches when pred(maps) evaluates to true.
31 EXPECT_THAT(maps
, Truly(isRowMajorMatmul
));
// Uses the maps A = (m, k), B = (k, n), C = (m, n) but hands the expressions
// to bindDims in the order k, m, n. isRowMajorMatmul is still expected to
// return true for the resulting ArrayAttr.
// NOTE(review): `context`, `m`, `n` and `k` are not declared in the visible
// lines of this test -- presumably declared on lines not shown; confirm.
34 TEST(isRowMajorMatmul
, BindingShifted
) {
38 bindDims(&context
, k
, m
, n
); // bind in different order
39 auto mapA
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, k
}, &context
));
40 auto mapB
= AffineMapAttr::get(AffineMap::get(3, 0, {k
, n
}, &context
));
41 auto mapC
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, n
}, &context
));
42 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
, mapC
});
// Truly(pred) matches when pred(maps) evaluates to true.
44 EXPECT_THAT(maps
, Truly(isRowMajorMatmul
));
// Uses the maps A = (m, k), B = (k, n), C = (m, n) but hands the expressions
// to bindDims in the order k, n, m. isRowMajorMatmul is still expected to
// return true for the resulting ArrayAttr.
// NOTE(review): `context`, `m`, `n` and `k` are not declared in the visible
// lines of this test -- presumably declared on lines not shown; confirm.
47 TEST(isRowMajorMatmul
, BindingSwapped
) {
51 bindDims(&context
, k
, n
, m
); // bind in different order
52 auto mapA
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, k
}, &context
));
53 auto mapB
= AffineMapAttr::get(AffineMap::get(3, 0, {k
, n
}, &context
));
54 auto mapC
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, n
}, &context
));
55 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
, mapC
});
// Truly(pred) matches when pred(maps) evaluates to true.
57 EXPECT_THAT(maps
, Truly(isRowMajorMatmul
));
// Negative case: the maps are A = (k, n), B = (m, k), C = (m, n), and
// isRowMajorMatmul is expected to return false for them.
// NOTE(review): `context`, `m`, `n` and `k` are not declared in the visible
// lines of this test -- presumably declared on lines not shown; confirm.
60 TEST(isRowMajorMatmul
, ColumnMajor
) {
64 bindDims(&context
, m
, n
, k
);
65 auto mapA
= AffineMapAttr::get(AffineMap::get(3, 0, {k
, n
}, &context
));
66 auto mapB
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, k
}, &context
));
67 auto mapC
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, n
}, &context
));
68 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
, mapC
});
// Not(Truly(pred)) matches when pred(maps) evaluates to false.
70 EXPECT_THAT(maps
, Not(Truly(isRowMajorMatmul
)));
// Negative case: the first map's results are (k, m) while B = (k, n) and
// C = (m, n); isRowMajorMatmul is expected to return false.
// NOTE(review): `context`, `m`, `n` and `k` are not declared in the visible
// lines of this test -- presumably declared on lines not shown; confirm.
73 TEST(isRowMajorMatmul
, FirstInputSwapped
) {
77 bindDims(&context
, m
, n
, k
);
78 auto mapA
= AffineMapAttr::get(AffineMap::get(3, 0, {k
, m
}, &context
));
79 auto mapB
= AffineMapAttr::get(AffineMap::get(3, 0, {k
, n
}, &context
));
80 auto mapC
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, n
}, &context
));
81 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
, mapC
});
// Not(Truly(pred)) matches when pred(maps) evaluates to false.
83 EXPECT_THAT(maps
, Not(Truly(isRowMajorMatmul
)));
// Negative case: the ArrayAttr holds only two maps, A = (m, k) and
// B = (k, n); isRowMajorMatmul is expected to return false.
// NOTE(review): `context`, `m`, `n` and `k` are not declared in the visible
// lines of this test -- presumably declared on lines not shown; confirm.
86 TEST(isRowMajorMatmul
, TooFewMaps
) {
90 bindDims(&context
, m
, n
, k
);
91 auto mapA
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, k
}, &context
));
92 auto mapB
= AffineMapAttr::get(AffineMap::get(3, 0, {k
, n
}, &context
));
// Only two entries are placed in the array.
93 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
});
// Not(Truly(pred)) matches when pred(maps) evaluates to false.
95 EXPECT_THAT(maps
, Not(Truly(isRowMajorMatmul
)));
// Negative case: the ArrayAttr holds four maps -- A = (m, k), B = (k, n),
// C = (m, n) plus an extra D = (k, m); isRowMajorMatmul is expected to
// return false.
// NOTE(review): `context`, `m`, `n` and `k` are not declared in the visible
// lines of this test -- presumably declared on lines not shown; confirm.
98 TEST(isRowMajorMatmul
, TooManyMaps
) {
102 bindDims(&context
, m
, n
, k
);
103 auto mapA
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, k
}, &context
));
104 auto mapB
= AffineMapAttr::get(AffineMap::get(3, 0, {k
, n
}, &context
));
105 auto mapC
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, n
}, &context
));
// Fourth, surplus map appended after the A/B/C triple.
106 auto mapD
= AffineMapAttr::get(AffineMap::get(3, 0, {k
, m
}, &context
));
108 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
, mapC
, mapD
});
// Not(Truly(pred)) matches when pred(maps) evaluates to false.
110 EXPECT_THAT(maps
, Not(Truly(isRowMajorMatmul
)));
// Negative case: the first map has a single result, (m), while B = (k, n)
// and C = (m, n); isRowMajorMatmul is expected to return false.
// NOTE(review): `context`, `m`, `n` and `k` are not declared in the visible
// lines of this test -- presumably declared on lines not shown; confirm.
113 TEST(isRowMajorMatmul
, TooFewOutputs
) {
117 bindDims(&context
, m
, n
, k
);
// mapA carries only one result expression.
118 auto mapA
= AffineMapAttr::get(AffineMap::get(3, 0, {m
}, &context
));
119 auto mapB
= AffineMapAttr::get(AffineMap::get(3, 0, {k
, n
}, &context
));
120 auto mapC
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, n
}, &context
));
121 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
, mapC
});
// Not(Truly(pred)) matches when pred(maps) evaluates to false.
123 EXPECT_THAT(maps
, Not(Truly(isRowMajorMatmul
)));
// Builds three indexing maps, each created with AffineMap::get(3, 0, ...),
// whose results are A = (k, n), B = (m, k) and C = (m, n), and expects
// isColumnMajorMatmul to return true for the resulting ArrayAttr.
// NOTE(review): `context`, `m`, `n` and `k` are not declared in the visible
// lines of this test -- presumably declared on lines not shown; confirm.
126 TEST(isColumnMajorMatmul
, Simple
) {
130 bindDims(&context
, m
, n
, k
);
131 auto mapA
= AffineMapAttr::get(AffineMap::get(3, 0, {k
, n
}, &context
));
132 auto mapB
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, k
}, &context
));
133 auto mapC
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, n
}, &context
));
134 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
, mapC
});
// Truly(pred) matches when pred(maps) evaluates to true.
136 EXPECT_THAT(maps
, Truly(isColumnMajorMatmul
));
// Uses the maps A = (k, n), B = (m, k), C = (m, n) but hands the expressions
// to bindDims in the order k, m, n. isColumnMajorMatmul is still expected to
// return true for the resulting ArrayAttr.
// NOTE(review): `context`, `m`, `n` and `k` are not declared in the visible
// lines of this test -- presumably declared on lines not shown; confirm.
139 TEST(isColumnMajorMatmul
, BindingShifted
) {
143 bindDims(&context
, k
, m
, n
); // bind in different order
144 auto mapA
= AffineMapAttr::get(AffineMap::get(3, 0, {k
, n
}, &context
));
145 auto mapB
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, k
}, &context
));
146 auto mapC
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, n
}, &context
));
147 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
, mapC
});
// Truly(pred) matches when pred(maps) evaluates to true.
149 EXPECT_THAT(maps
, Truly(isColumnMajorMatmul
));
// Uses the maps A = (k, n), B = (m, k), C = (m, n) but hands the expressions
// to bindDims in the order k, n, m. isColumnMajorMatmul is still expected to
// return true for the resulting ArrayAttr.
// NOTE(review): `context`, `m`, `n` and `k` are not declared in the visible
// lines of this test -- presumably declared on lines not shown; confirm.
152 TEST(isColumnMajorMatmul
, BindingSwapped
) {
156 bindDims(&context
, k
, n
, m
); // bind in different order
157 auto mapA
= AffineMapAttr::get(AffineMap::get(3, 0, {k
, n
}, &context
));
158 auto mapB
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, k
}, &context
));
159 auto mapC
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, n
}, &context
));
160 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
, mapC
});
// Truly(pred) matches when pred(maps) evaluates to true.
162 EXPECT_THAT(maps
, Truly(isColumnMajorMatmul
));
// Negative case: the maps are A = (m, k), B = (k, n), C = (m, n), and
// isColumnMajorMatmul is expected to return false for them.
// NOTE(review): `context`, `m`, `n` and `k` are not declared in the visible
// lines of this test -- presumably declared on lines not shown; confirm.
165 TEST(isColumnMajorMatmul
, RowMajor
) {
169 bindDims(&context
, m
, n
, k
);
170 auto mapA
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, k
}, &context
));
171 auto mapB
= AffineMapAttr::get(AffineMap::get(3, 0, {k
, n
}, &context
));
172 auto mapC
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, n
}, &context
));
173 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
, mapC
});
// Not(Truly(pred)) matches when pred(maps) evaluates to false.
175 EXPECT_THAT(maps
, Not(Truly(isColumnMajorMatmul
)));
// Negative case: the first map's results are (n, k) while B = (m, k) and
// C = (m, n); isColumnMajorMatmul is expected to return false.
// NOTE(review): `context`, `m`, `n` and `k` are not declared in the visible
// lines of this test -- presumably declared on lines not shown; confirm.
178 TEST(isColumnMajorMatmul
, FirstInputSwapped
) {
182 bindDims(&context
, m
, n
, k
);
183 auto mapA
= AffineMapAttr::get(AffineMap::get(3, 0, {n
, k
}, &context
));
184 auto mapB
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, k
}, &context
));
185 auto mapC
= AffineMapAttr::get(AffineMap::get(3, 0, {m
, n
}, &context
));
186 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
, mapC
});
// Not(Truly(pred)) matches when pred(maps) evaluates to false.
188 EXPECT_THAT(maps
, Not(Truly(isColumnMajorMatmul
)));
// Builds three indexing maps, each created with AffineMap::get(4, 0, ...),
// whose results are A = (batch, m, k), B = (batch, k, n) and
// C = (batch, m, n), and expects isRowMajorBatchMatmul to return true for
// the resulting ArrayAttr.
// NOTE(review): `context` is not declared in the visible lines of this
// test -- presumably declared on a line not shown; confirm.
191 TEST(isRowMajorBatchMatmul
, Simple
) {
194 AffineExpr batch
, m
, n
, k
;
195 bindDims(&context
, batch
, m
, n
, k
);
196 auto mapA
= AffineMapAttr::get(AffineMap::get(4, 0, {batch
, m
, k
}, &context
));
197 auto mapB
= AffineMapAttr::get(AffineMap::get(4, 0, {batch
, k
, n
}, &context
));
198 auto mapC
= AffineMapAttr::get(AffineMap::get(4, 0, {batch
, m
, n
}, &context
));
199 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
, mapC
});
// Truly(pred) matches when pred(maps) evaluates to true.
201 EXPECT_THAT(maps
, Truly(isRowMajorBatchMatmul
));
// Uses the maps A = (batch, m, k), B = (batch, k, n), C = (batch, m, n) but
// hands the expressions to bindDims in the order k, batch, m, n.
// isRowMajorBatchMatmul is still expected to return true.
// NOTE(review): `context` is not declared in the visible lines of this
// test -- presumably declared on a line not shown; confirm.
204 TEST(isRowMajorBatchMatmul
, BindingShifted
) {
207 AffineExpr batch
, m
, n
, k
;
208 bindDims(&context
, k
, batch
, m
, n
); // bind in different order
209 auto mapA
= AffineMapAttr::get(AffineMap::get(4, 0, {batch
, m
, k
}, &context
));
210 auto mapB
= AffineMapAttr::get(AffineMap::get(4, 0, {batch
, k
, n
}, &context
));
211 auto mapC
= AffineMapAttr::get(AffineMap::get(4, 0, {batch
, m
, n
}, &context
));
212 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
, mapC
});
// Truly(pred) matches when pred(maps) evaluates to true.
214 EXPECT_THAT(maps
, Truly(isRowMajorBatchMatmul
));
// Uses the maps A = (batch, m, k), B = (batch, k, n), C = (batch, m, n) but
// hands the expressions to bindDims in the order batch, k, n, m.
// isRowMajorBatchMatmul is still expected to return true.
// NOTE(review): `context` is not declared in the visible lines of this
// test -- presumably declared on a line not shown; confirm.
217 TEST(isRowMajorBatchMatmul
, BindingSwapped
) {
220 AffineExpr batch
, m
, n
, k
;
221 bindDims(&context
, batch
, k
, n
, m
); // bind in different order
222 auto mapA
= AffineMapAttr::get(AffineMap::get(4, 0, {batch
, m
, k
}, &context
));
223 auto mapB
= AffineMapAttr::get(AffineMap::get(4, 0, {batch
, k
, n
}, &context
));
224 auto mapC
= AffineMapAttr::get(AffineMap::get(4, 0, {batch
, m
, n
}, &context
));
225 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
, mapC
});
// Truly(pred) matches when pred(maps) evaluates to true.
227 EXPECT_THAT(maps
, Truly(isRowMajorBatchMatmul
));
// Negative case: the first map's results are (batch, k, m) while
// B = (batch, k, n) and C = (batch, m, n); isRowMajorBatchMatmul is expected
// to return false.
// NOTE(review): `context` is not declared in the visible lines of this
// test -- presumably declared on a line not shown; confirm.
230 TEST(isRowMajorBatchMatmul
, FirstInputSwapped
) {
233 AffineExpr batch
, m
, n
, k
;
234 bindDims(&context
, batch
, m
, n
, k
);
235 auto mapA
= AffineMapAttr::get(AffineMap::get(4, 0, {batch
, k
, m
}, &context
));
236 auto mapB
= AffineMapAttr::get(AffineMap::get(4, 0, {batch
, k
, n
}, &context
));
237 auto mapC
= AffineMapAttr::get(AffineMap::get(4, 0, {batch
, m
, n
}, &context
));
238 auto maps
= ArrayAttr::get(&context
, {mapA
, mapB
, mapC
});
// Not(Truly(pred)) matches when pred(maps) evaluates to false.
240 EXPECT_THAT(maps
, Not(Truly(isRowMajorBatchMatmul
)));