1 //===-- runtime/matmul-transpose.cpp --------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Implements a fused matmul-transpose operation
11 // There are two main entry points; one establishes a descriptor for the
12 // result and allocates it, and the other expects a result descriptor that
13 // points to existing storage.
15 // This implementation must handle all combinations of numeric types and
16 // kinds (100 - 165 cases depending on the target), plus all combinations
17 // of logical kinds (16). A single template undergoes many instantiations
18 // to cover all of the valid possibilities.
20 // The usefulness of this optimization should be reviewed once Matmul is swapped
21 // to use the faster BLAS routines.
23 #include "flang/Runtime/matmul-transpose.h"
24 #include "terminator.h"
26 #include "flang/Common/optional.h"
27 #include "flang/Runtime/c-or-cpp.h"
28 #include "flang/Runtime/cpp-type.h"
29 #include "flang/Runtime/descriptor.h"
33 using namespace Fortran::runtime
;
35 // Contiguous numeric TRANSPOSE(matrix)*matrix multiplication
36 // TRANSPOSE(matrix(n, rows)) * matrix(n,cols) ->
37 // matrix(rows, n) * matrix(n,cols) -> matrix(rows,cols)
38 // The transpose is implemented by swapping the indices of accesses into the LHS
40 // Straightforward algorithm:
45 // 1 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J)
47 // With loop distribution and transposition to avoid the inner sum
48 // reduction and to avoid non-unit strides:
55 // 2 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) ! loop-invariant last term
56 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
,
57 bool X_HAS_STRIDED_COLUMNS
, bool Y_HAS_STRIDED_COLUMNS
>
58 inline static RT_API_ATTRS
void MatrixTransposedTimesMatrix(
59 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
60 SubscriptValue cols
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
61 SubscriptValue n
, std::size_t xColumnByteStride
= 0,
62 std::size_t yColumnByteStride
= 0) {
63 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
65 std::memset(product
, 0, rows
* cols
* sizeof *product
);
66 for (SubscriptValue j
{0}; j
< cols
; ++j
) {
67 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
68 for (SubscriptValue k
{0}; k
< n
; ++k
) {
70 if constexpr (!X_HAS_STRIDED_COLUMNS
) {
71 x_ki
= static_cast<ResultType
>(x
[i
* n
+ k
]);
73 x_ki
= static_cast<ResultType
>(reinterpret_cast<const XT
*>(
74 reinterpret_cast<const char *>(x
) + i
* xColumnByteStride
)[k
]);
77 if constexpr (!Y_HAS_STRIDED_COLUMNS
) {
78 y_kj
= static_cast<ResultType
>(y
[j
* n
+ k
]);
80 y_kj
= static_cast<ResultType
>(reinterpret_cast<const YT
*>(
81 reinterpret_cast<const char *>(y
) + j
* yColumnByteStride
)[k
]);
83 product
[j
* rows
+ i
] += x_ki
* y_kj
;
89 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
90 inline static RT_API_ATTRS
void MatrixTransposedTimesMatrixHelper(
91 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
92 SubscriptValue cols
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
93 SubscriptValue n
, Fortran::common::optional
<std::size_t> xColumnByteStride
,
94 Fortran::common::optional
<std::size_t> yColumnByteStride
) {
95 if (!xColumnByteStride
) {
96 if (!yColumnByteStride
) {
97 MatrixTransposedTimesMatrix
<RCAT
, RKIND
, XT
, YT
, false, false>(
98 product
, rows
, cols
, x
, y
, n
);
100 MatrixTransposedTimesMatrix
<RCAT
, RKIND
, XT
, YT
, false, true>(
101 product
, rows
, cols
, x
, y
, n
, 0, *yColumnByteStride
);
104 if (!yColumnByteStride
) {
105 MatrixTransposedTimesMatrix
<RCAT
, RKIND
, XT
, YT
, true, false>(
106 product
, rows
, cols
, x
, y
, n
, *xColumnByteStride
);
108 MatrixTransposedTimesMatrix
<RCAT
, RKIND
, XT
, YT
, true, true>(
109 product
, rows
, cols
, x
, y
, n
, *xColumnByteStride
, *yColumnByteStride
);
114 // Contiguous numeric matrix*vector multiplication
115 // matrix(rows,n) * column vector(n) -> column vector(rows)
116 // Straightforward algorithm:
120 // 1 RES(I) = RES(I) + X(K,I)*Y(K)
121 // With loop distribution and transposition to avoid the inner
122 // sum reduction and to avoid non-unit strides:
127 // 2 RES(I) = RES(I) + X(K,I)*Y(K)
128 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
,
129 bool X_HAS_STRIDED_COLUMNS
>
130 inline static RT_API_ATTRS
void MatrixTransposedTimesVector(
131 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
132 SubscriptValue n
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
133 std::size_t xColumnByteStride
= 0) {
134 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
135 std::memset(product
, 0, rows
* sizeof *product
);
136 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
137 for (SubscriptValue k
{0}; k
< n
; ++k
) {
139 if constexpr (!X_HAS_STRIDED_COLUMNS
) {
140 x_ki
= static_cast<ResultType
>(x
[i
* n
+ k
]);
142 x_ki
= static_cast<ResultType
>(reinterpret_cast<const XT
*>(
143 reinterpret_cast<const char *>(x
) + i
* xColumnByteStride
)[k
]);
145 ResultType y_k
= static_cast<ResultType
>(y
[k
]);
146 product
[i
] += x_ki
* y_k
;
151 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
152 inline static RT_API_ATTRS
void MatrixTransposedTimesVectorHelper(
153 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
154 SubscriptValue n
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
155 Fortran::common::optional
<std::size_t> xColumnByteStride
) {
156 if (!xColumnByteStride
) {
157 MatrixTransposedTimesVector
<RCAT
, RKIND
, XT
, YT
, false>(
158 product
, rows
, n
, x
, y
);
160 MatrixTransposedTimesVector
<RCAT
, RKIND
, XT
, YT
, true>(
161 product
, rows
, n
, x
, y
, *xColumnByteStride
);
165 // Implements an instance of MATMUL for given argument types.
166 template <bool IS_ALLOCATING
, TypeCategory RCAT
, int RKIND
, typename XT
,
168 inline static RT_API_ATTRS
void DoMatmulTranspose(
169 std::conditional_t
<IS_ALLOCATING
, Descriptor
, const Descriptor
> &result
,
170 const Descriptor
&x
, const Descriptor
&y
, Terminator
&terminator
) {
173 int resRank
{xRank
+ yRank
- 2};
174 if (xRank
* yRank
!= 2 * resRank
) {
176 "MATMUL-TRANSPOSE: bad argument ranks (%d * %d)", xRank
, yRank
);
178 SubscriptValue extent
[2]{x
.GetDimension(1).Extent(),
179 resRank
== 2 ? y
.GetDimension(1).Extent() : 0};
180 if constexpr (IS_ALLOCATING
) {
182 RCAT
, RKIND
, nullptr, resRank
, extent
, CFI_attribute_allocatable
);
183 for (int j
{0}; j
< resRank
; ++j
) {
184 result
.GetDimension(j
).SetBounds(1, extent
[j
]);
186 if (int stat
{result
.Allocate()}) {
188 "MATMUL-TRANSPOSE: could not allocate memory for result; STAT=%d",
192 RUNTIME_CHECK(terminator
, resRank
== result
.rank());
194 terminator
, result
.ElementBytes() == static_cast<std::size_t>(RKIND
));
195 RUNTIME_CHECK(terminator
, result
.GetDimension(0).Extent() == extent
[0]);
196 RUNTIME_CHECK(terminator
,
197 resRank
== 1 || result
.GetDimension(1).Extent() == extent
[1]);
199 SubscriptValue n
{x
.GetDimension(0).Extent()};
200 if (n
!= y
.GetDimension(0).Extent()) {
202 "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
203 static_cast<std::intmax_t>(x
.GetDimension(0).Extent()),
204 static_cast<std::intmax_t>(x
.GetDimension(1).Extent()),
205 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()),
206 static_cast<std::intmax_t>(y
.GetDimension(1).Extent()));
209 CppTypeFor
<RCAT
== TypeCategory::Logical
? TypeCategory::Integer
: RCAT
,
211 const SubscriptValue rows
{extent
[0]};
212 const SubscriptValue cols
{extent
[1]};
213 if constexpr (RCAT
!= TypeCategory::Logical
) {
214 if (x
.IsContiguous(1) && y
.IsContiguous(1) &&
215 (IS_ALLOCATING
|| result
.IsContiguous())) {
216 // Contiguous numeric matrices (maybe with columns
217 // separated by a stride).
218 Fortran::common::optional
<std::size_t> xColumnByteStride
;
219 if (!x
.IsContiguous()) {
220 // X's columns are strided.
221 SubscriptValue xAt
[2]{};
222 x
.GetLowerBounds(xAt
);
224 xColumnByteStride
= x
.SubscriptsToByteOffset(xAt
);
226 Fortran::common::optional
<std::size_t> yColumnByteStride
;
227 if (!y
.IsContiguous()) {
228 // Y's columns are strided.
229 SubscriptValue yAt
[2]{};
230 y
.GetLowerBounds(yAt
);
232 yColumnByteStride
= y
.SubscriptsToByteOffset(yAt
);
234 if (resRank
== 2) { // M*M -> M
235 // TODO: use BLAS-3 GEMM for supported types.
236 MatrixTransposedTimesMatrixHelper
<RCAT
, RKIND
, XT
, YT
>(
237 result
.template OffsetElement
<WriteResult
>(), rows
, cols
,
238 x
.OffsetElement
<XT
>(), y
.OffsetElement
<YT
>(), n
, xColumnByteStride
,
242 if (xRank
== 2) { // M*V -> V
243 // TODO: use BLAS-2 GEMM for supported types.
244 MatrixTransposedTimesVectorHelper
<RCAT
, RKIND
, XT
, YT
>(
245 result
.template OffsetElement
<WriteResult
>(), rows
, n
,
246 x
.OffsetElement
<XT
>(), y
.OffsetElement
<YT
>(), xColumnByteStride
);
249 // else V*M -> V (not allowed because TRANSPOSE() is only defined for rank
252 "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
253 static_cast<std::intmax_t>(x
.GetDimension(0).Extent()),
254 static_cast<std::intmax_t>(n
),
255 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()),
256 static_cast<std::intmax_t>(y
.GetDimension(1).Extent()));
260 // General algorithms for LOGICAL and noncontiguity
261 SubscriptValue xLB
[2], yLB
[2], resLB
[2];
262 x
.GetLowerBounds(xLB
);
263 y
.GetLowerBounds(yLB
);
264 result
.GetLowerBounds(resLB
);
265 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
266 if (resRank
== 2) { // M*M -> M
267 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
268 for (SubscriptValue j
{0}; j
< cols
; ++j
) {
270 if constexpr (RCAT
== TypeCategory::Logical
) {
276 for (SubscriptValue k
{0}; k
< n
; ++k
) {
277 SubscriptValue xAt
[2]{k
+ xLB
[0], i
+ xLB
[1]};
278 SubscriptValue yAt
[2]{k
+ yLB
[0], j
+ yLB
[1]};
279 if constexpr (RCAT
== TypeCategory::Logical
) {
280 ResultType x_ki
= IsLogicalElementTrue(x
, xAt
);
281 ResultType y_kj
= IsLogicalElementTrue(y
, yAt
);
282 res_ij
= res_ij
|| (x_ki
&& y_kj
);
284 ResultType x_ki
= static_cast<ResultType
>(*x
.Element
<XT
>(xAt
));
285 ResultType y_kj
= static_cast<ResultType
>(*y
.Element
<YT
>(yAt
));
286 res_ij
+= x_ki
* y_kj
;
289 SubscriptValue resAt
[2]{i
+ resLB
[0], j
+ resLB
[1]};
290 *result
.template Element
<WriteResult
>(resAt
) = res_ij
;
293 } else if (xRank
== 2) { // M*V -> V
294 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
296 if constexpr (RCAT
== TypeCategory::Logical
) {
302 for (SubscriptValue k
{0}; k
< n
; ++k
) {
303 SubscriptValue xAt
[2]{k
+ xLB
[0], i
+ xLB
[1]};
304 SubscriptValue yAt
[1]{k
+ yLB
[0]};
305 if constexpr (RCAT
== TypeCategory::Logical
) {
306 ResultType x_ki
= IsLogicalElementTrue(x
, xAt
);
307 ResultType y_k
= IsLogicalElementTrue(y
, yAt
);
308 res_i
= res_i
|| (x_ki
&& y_k
);
310 ResultType x_ki
= static_cast<ResultType
>(*x
.Element
<XT
>(xAt
));
311 ResultType y_k
= static_cast<ResultType
>(*y
.Element
<YT
>(yAt
));
315 SubscriptValue resAt
[1]{i
+ resLB
[0]};
316 *result
.template Element
<WriteResult
>(resAt
) = res_i
;
319 // TRANSPOSE(V) not allowed by fortran standard
321 "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
322 static_cast<std::intmax_t>(x
.GetDimension(0).Extent()),
323 static_cast<std::intmax_t>(n
),
324 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()),
325 static_cast<std::intmax_t>(y
.GetDimension(1).Extent()));
329 template <bool IS_ALLOCATING
, TypeCategory XCAT
, int XKIND
, TypeCategory YCAT
,
331 struct MatmulTransposeHelper
{
332 using ResultDescriptor
=
333 std::conditional_t
<IS_ALLOCATING
, Descriptor
, const Descriptor
>;
334 RT_API_ATTRS
void operator()(ResultDescriptor
&result
, const Descriptor
&x
,
335 const Descriptor
&y
, const char *sourceFile
, int line
) const {
336 Terminator terminator
{sourceFile
, line
};
337 auto xCatKind
{x
.type().GetCategoryAndKind()};
338 auto yCatKind
{y
.type().GetCategoryAndKind()};
339 RUNTIME_CHECK(terminator
, xCatKind
.has_value() && yCatKind
.has_value());
340 RUNTIME_CHECK(terminator
, xCatKind
->first
== XCAT
);
341 RUNTIME_CHECK(terminator
, yCatKind
->first
== YCAT
);
342 if constexpr (constexpr auto resultType
{
343 GetResultType(XCAT
, XKIND
, YCAT
, YKIND
)}) {
344 return DoMatmulTranspose
<IS_ALLOCATING
, resultType
->first
,
345 resultType
->second
, CppTypeFor
<XCAT
, XKIND
>, CppTypeFor
<YCAT
, YKIND
>>(
346 result
, x
, y
, terminator
);
348 terminator
.Crash("MATMUL-TRANSPOSE: bad operand types (%d(%d), %d(%d))",
349 static_cast<int>(XCAT
), XKIND
, static_cast<int>(YCAT
), YKIND
);
354 namespace Fortran::runtime
{
356 RT_EXT_API_GROUP_BEGIN
358 #define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
359 void RTDEF(MatmulTranspose##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
360 const Descriptor &x, const Descriptor &y, const char *sourceFile, \
362 MatmulTransposeHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
363 YKIND>{}(result, x, y, sourceFile, line); \
366 #define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
367 void RTDEF(MatmulTransposeDirect##XCAT##XKIND##YCAT##YKIND)( \
368 Descriptor & result, const Descriptor &x, const Descriptor &y, \
369 const char *sourceFile, int line) { \
370 MatmulTransposeHelper<false, TypeCategory::XCAT, XKIND, \
371 TypeCategory::YCAT, YKIND>{}(result, x, y, sourceFile, line); \
374 #define MATMUL_FORCE_ALL_TYPES 0
376 #include "flang/Runtime/matmul-instances.inc"
380 } // namespace Fortran::runtime