//===-- runtime/matmul-transpose.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Implements a fused matmul-transpose operation, MATMUL(TRANSPOSE(X), Y),
// without forming a transposed copy of X.
//
// There are two main entry points; one establishes a descriptor for the
// result and allocates it, and the other expects a result descriptor that
// points to existing storage.
//
// This implementation must handle all combinations of numeric types and
// kinds (100 - 165 cases depending on the target), plus all combinations
// of logical kinds (16). A single template undergoes many instantiations
// to cover all of the valid possibilities.
//
// The usefulness of this optimization should be reviewed once Matmul is
// switched to use the faster BLAS routines.
#include "flang/Runtime/matmul-transpose.h"
#include "terminator.h"
#include "tools.h"
#include "flang/Runtime/c-or-cpp.h"
#include "flang/Runtime/cpp-type.h"
#include "flang/Runtime/descriptor.h"
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <optional>
#include <type_traits>
32 using namespace Fortran::runtime
;
34 // Contiguous numeric TRANSPOSE(matrix)*matrix multiplication
35 // TRANSPOSE(matrix(n, rows)) * matrix(n,cols) ->
36 // matrix(rows, n) * matrix(n,cols) -> matrix(rows,cols)
37 // The transpose is implemented by swapping the indices of accesses into the LHS
39 // Straightforward algorithm:
44 // 1 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J)
46 // With loop distribution and transposition to avoid the inner sum
47 // reduction and to avoid non-unit strides:
54 // 2 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) ! loop-invariant last term
55 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
,
56 bool X_HAS_STRIDED_COLUMNS
, bool Y_HAS_STRIDED_COLUMNS
>
57 inline static void MatrixTransposedTimesMatrix(
58 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
59 SubscriptValue cols
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
60 SubscriptValue n
, std::size_t xColumnByteStride
= 0,
61 std::size_t yColumnByteStride
= 0) {
62 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
64 std::memset(product
, 0, rows
* cols
* sizeof *product
);
65 for (SubscriptValue j
{0}; j
< cols
; ++j
) {
66 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
67 for (SubscriptValue k
{0}; k
< n
; ++k
) {
69 if constexpr (!X_HAS_STRIDED_COLUMNS
) {
70 x_ki
= static_cast<ResultType
>(x
[i
* n
+ k
]);
72 x_ki
= static_cast<ResultType
>(reinterpret_cast<const XT
*>(
73 reinterpret_cast<const char *>(x
) + i
* xColumnByteStride
)[k
]);
76 if constexpr (!Y_HAS_STRIDED_COLUMNS
) {
77 y_kj
= static_cast<ResultType
>(y
[j
* n
+ k
]);
79 y_kj
= static_cast<ResultType
>(reinterpret_cast<const YT
*>(
80 reinterpret_cast<const char *>(y
) + j
* yColumnByteStride
)[k
]);
82 product
[j
* rows
+ i
] += x_ki
* y_kj
;
88 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
89 inline static void MatrixTransposedTimesMatrixHelper(
90 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
91 SubscriptValue cols
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
92 SubscriptValue n
, std::optional
<std::size_t> xColumnByteStride
,
93 std::optional
<std::size_t> yColumnByteStride
) {
94 if (!xColumnByteStride
) {
95 if (!yColumnByteStride
) {
96 MatrixTransposedTimesMatrix
<RCAT
, RKIND
, XT
, YT
, false, false>(
97 product
, rows
, cols
, x
, y
, n
);
99 MatrixTransposedTimesMatrix
<RCAT
, RKIND
, XT
, YT
, false, true>(
100 product
, rows
, cols
, x
, y
, n
, 0, *yColumnByteStride
);
103 if (!yColumnByteStride
) {
104 MatrixTransposedTimesMatrix
<RCAT
, RKIND
, XT
, YT
, true, false>(
105 product
, rows
, cols
, x
, y
, n
, *xColumnByteStride
);
107 MatrixTransposedTimesMatrix
<RCAT
, RKIND
, XT
, YT
, true, true>(
108 product
, rows
, cols
, x
, y
, n
, *xColumnByteStride
, *yColumnByteStride
);
113 // Contiguous numeric matrix*vector multiplication
114 // matrix(rows,n) * column vector(n) -> column vector(rows)
115 // Straightforward algorithm:
119 // 1 RES(I) = RES(I) + X(K,I)*Y(K)
120 // With loop distribution and transposition to avoid the inner
121 // sum reduction and to avoid non-unit strides:
126 // 2 RES(I) = RES(I) + X(K,I)*Y(K)
127 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
,
128 bool X_HAS_STRIDED_COLUMNS
>
129 inline static void MatrixTransposedTimesVector(
130 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
131 SubscriptValue n
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
132 std::size_t xColumnByteStride
= 0) {
133 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
134 std::memset(product
, 0, rows
* sizeof *product
);
135 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
136 for (SubscriptValue k
{0}; k
< n
; ++k
) {
138 if constexpr (!X_HAS_STRIDED_COLUMNS
) {
139 x_ki
= static_cast<ResultType
>(x
[i
* n
+ k
]);
141 x_ki
= static_cast<ResultType
>(reinterpret_cast<const XT
*>(
142 reinterpret_cast<const char *>(x
) + i
* xColumnByteStride
)[k
]);
144 ResultType y_k
= static_cast<ResultType
>(y
[k
]);
145 product
[i
] += x_ki
* y_k
;
150 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
151 inline static void MatrixTransposedTimesVectorHelper(
152 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
153 SubscriptValue n
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
154 std::optional
<std::size_t> xColumnByteStride
) {
155 if (!xColumnByteStride
) {
156 MatrixTransposedTimesVector
<RCAT
, RKIND
, XT
, YT
, false>(
157 product
, rows
, n
, x
, y
);
159 MatrixTransposedTimesVector
<RCAT
, RKIND
, XT
, YT
, true>(
160 product
, rows
, n
, x
, y
, *xColumnByteStride
);
164 // Implements an instance of MATMUL for given argument types.
165 template <bool IS_ALLOCATING
, TypeCategory RCAT
, int RKIND
, typename XT
,
167 inline static void DoMatmulTranspose(
168 std::conditional_t
<IS_ALLOCATING
, Descriptor
, const Descriptor
> &result
,
169 const Descriptor
&x
, const Descriptor
&y
, Terminator
&terminator
) {
172 int resRank
{xRank
+ yRank
- 2};
173 if (xRank
* yRank
!= 2 * resRank
) {
175 "MATMUL-TRANSPOSE: bad argument ranks (%d * %d)", xRank
, yRank
);
177 SubscriptValue extent
[2]{x
.GetDimension(1).Extent(),
178 resRank
== 2 ? y
.GetDimension(1).Extent() : 0};
179 if constexpr (IS_ALLOCATING
) {
181 RCAT
, RKIND
, nullptr, resRank
, extent
, CFI_attribute_allocatable
);
182 for (int j
{0}; j
< resRank
; ++j
) {
183 result
.GetDimension(j
).SetBounds(1, extent
[j
]);
185 if (int stat
{result
.Allocate()}) {
187 "MATMUL-TRANSPOSE: could not allocate memory for result; STAT=%d",
191 RUNTIME_CHECK(terminator
, resRank
== result
.rank());
193 terminator
, result
.ElementBytes() == static_cast<std::size_t>(RKIND
));
194 RUNTIME_CHECK(terminator
, result
.GetDimension(0).Extent() == extent
[0]);
195 RUNTIME_CHECK(terminator
,
196 resRank
== 1 || result
.GetDimension(1).Extent() == extent
[1]);
198 SubscriptValue n
{x
.GetDimension(0).Extent()};
199 if (n
!= y
.GetDimension(0).Extent()) {
201 "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
202 static_cast<std::intmax_t>(x
.GetDimension(0).Extent()),
203 static_cast<std::intmax_t>(x
.GetDimension(1).Extent()),
204 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()),
205 static_cast<std::intmax_t>(y
.GetDimension(1).Extent()));
208 CppTypeFor
<RCAT
== TypeCategory::Logical
? TypeCategory::Integer
: RCAT
,
210 const SubscriptValue rows
{extent
[0]};
211 const SubscriptValue cols
{extent
[1]};
212 if constexpr (RCAT
!= TypeCategory::Logical
) {
213 if (x
.IsContiguous(1) && y
.IsContiguous(1) &&
214 (IS_ALLOCATING
|| result
.IsContiguous())) {
215 // Contiguous numeric matrices (maybe with columns
216 // separated by a stride).
217 std::optional
<std::size_t> xColumnByteStride
;
218 if (!x
.IsContiguous()) {
219 // X's columns are strided.
220 SubscriptValue xAt
[2]{};
221 x
.GetLowerBounds(xAt
);
223 xColumnByteStride
= x
.SubscriptsToByteOffset(xAt
);
225 std::optional
<std::size_t> yColumnByteStride
;
226 if (!y
.IsContiguous()) {
227 // Y's columns are strided.
228 SubscriptValue yAt
[2]{};
229 y
.GetLowerBounds(yAt
);
231 yColumnByteStride
= y
.SubscriptsToByteOffset(yAt
);
233 if (resRank
== 2) { // M*M -> M
234 // TODO: use BLAS-3 GEMM for supported types.
235 MatrixTransposedTimesMatrixHelper
<RCAT
, RKIND
, XT
, YT
>(
236 result
.template OffsetElement
<WriteResult
>(), rows
, cols
,
237 x
.OffsetElement
<XT
>(), y
.OffsetElement
<YT
>(), n
, xColumnByteStride
,
241 if (xRank
== 2) { // M*V -> V
242 // TODO: use BLAS-2 GEMM for supported types.
243 MatrixTransposedTimesVectorHelper
<RCAT
, RKIND
, XT
, YT
>(
244 result
.template OffsetElement
<WriteResult
>(), rows
, n
,
245 x
.OffsetElement
<XT
>(), y
.OffsetElement
<YT
>(), xColumnByteStride
);
248 // else V*M -> V (not allowed because TRANSPOSE() is only defined for rank
251 "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
252 static_cast<std::intmax_t>(x
.GetDimension(0).Extent()),
253 static_cast<std::intmax_t>(n
),
254 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()),
255 static_cast<std::intmax_t>(y
.GetDimension(1).Extent()));
259 // General algorithms for LOGICAL and noncontiguity
260 SubscriptValue xLB
[2], yLB
[2], resLB
[2];
261 x
.GetLowerBounds(xLB
);
262 y
.GetLowerBounds(yLB
);
263 result
.GetLowerBounds(resLB
);
264 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
265 if (resRank
== 2) { // M*M -> M
266 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
267 for (SubscriptValue j
{0}; j
< cols
; ++j
) {
269 if constexpr (RCAT
== TypeCategory::Logical
) {
275 for (SubscriptValue k
{0}; k
< n
; ++k
) {
276 SubscriptValue xAt
[2]{k
+ xLB
[0], i
+ xLB
[1]};
277 SubscriptValue yAt
[2]{k
+ yLB
[0], j
+ yLB
[1]};
278 if constexpr (RCAT
== TypeCategory::Logical
) {
279 ResultType x_ki
= IsLogicalElementTrue(x
, xAt
);
280 ResultType y_kj
= IsLogicalElementTrue(y
, yAt
);
281 res_ij
= res_ij
|| (x_ki
&& y_kj
);
283 ResultType x_ki
= static_cast<ResultType
>(*x
.Element
<XT
>(xAt
));
284 ResultType y_kj
= static_cast<ResultType
>(*y
.Element
<YT
>(yAt
));
285 res_ij
+= x_ki
* y_kj
;
288 SubscriptValue resAt
[2]{i
+ resLB
[0], j
+ resLB
[1]};
289 *result
.template Element
<WriteResult
>(resAt
) = res_ij
;
292 } else if (xRank
== 2) { // M*V -> V
293 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
295 if constexpr (RCAT
== TypeCategory::Logical
) {
301 for (SubscriptValue k
{0}; k
< n
; ++k
) {
302 SubscriptValue xAt
[2]{k
+ xLB
[0], i
+ xLB
[1]};
303 SubscriptValue yAt
[1]{k
+ yLB
[0]};
304 if constexpr (RCAT
== TypeCategory::Logical
) {
305 ResultType x_ki
= IsLogicalElementTrue(x
, xAt
);
306 ResultType y_k
= IsLogicalElementTrue(y
, yAt
);
307 res_i
= res_i
|| (x_ki
&& y_k
);
309 ResultType x_ki
= static_cast<ResultType
>(*x
.Element
<XT
>(xAt
));
310 ResultType y_k
= static_cast<ResultType
>(*y
.Element
<YT
>(yAt
));
314 SubscriptValue resAt
[1]{i
+ resLB
[0]};
315 *result
.template Element
<WriteResult
>(resAt
) = res_i
;
318 // TRANSPOSE(V) not allowed by fortran standard
320 "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
321 static_cast<std::intmax_t>(x
.GetDimension(0).Extent()),
322 static_cast<std::intmax_t>(n
),
323 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()),
324 static_cast<std::intmax_t>(y
.GetDimension(1).Extent()));
328 // Maps the dynamic type information from the arguments' descriptors
329 // to the right instantiation of DoMatmul() for valid combinations of
331 template <bool IS_ALLOCATING
> struct MatmulTranspose
{
332 using ResultDescriptor
=
333 std::conditional_t
<IS_ALLOCATING
, Descriptor
, const Descriptor
>;
334 template <TypeCategory XCAT
, int XKIND
> struct MM1
{
335 template <TypeCategory YCAT
, int YKIND
> struct MM2
{
336 void operator()(ResultDescriptor
&result
, const Descriptor
&x
,
337 const Descriptor
&y
, Terminator
&terminator
) const {
338 if constexpr (constexpr auto resultType
{
339 GetResultType(XCAT
, XKIND
, YCAT
, YKIND
)}) {
340 if constexpr (Fortran::common::IsNumericTypeCategory(
341 resultType
->first
) ||
342 resultType
->first
== TypeCategory::Logical
) {
343 return DoMatmulTranspose
<IS_ALLOCATING
, resultType
->first
,
344 resultType
->second
, CppTypeFor
<XCAT
, XKIND
>,
345 CppTypeFor
<YCAT
, YKIND
>>(result
, x
, y
, terminator
);
348 terminator
.Crash("MATMUL-TRANSPOSE: bad operand types (%d(%d), %d(%d))",
349 static_cast<int>(XCAT
), XKIND
, static_cast<int>(YCAT
), YKIND
);
352 void operator()(ResultDescriptor
&result
, const Descriptor
&x
,
353 const Descriptor
&y
, Terminator
&terminator
, TypeCategory yCat
,
355 ApplyType
<MM2
, void>(yCat
, yKind
, terminator
, result
, x
, y
, terminator
);
358 void operator()(ResultDescriptor
&result
, const Descriptor
&x
,
359 const Descriptor
&y
, const char *sourceFile
, int line
) const {
360 Terminator terminator
{sourceFile
, line
};
361 auto xCatKind
{x
.type().GetCategoryAndKind()};
362 auto yCatKind
{y
.type().GetCategoryAndKind()};
363 RUNTIME_CHECK(terminator
, xCatKind
.has_value() && yCatKind
.has_value());
364 ApplyType
<MM1
, void>(xCatKind
->first
, xCatKind
->second
, terminator
, result
,
365 x
, y
, terminator
, yCatKind
->first
, yCatKind
->second
);
370 namespace Fortran::runtime
{
372 void RTNAME(MatmulTranspose
)(Descriptor
&result
, const Descriptor
&x
,
373 const Descriptor
&y
, const char *sourceFile
, int line
) {
374 MatmulTranspose
<true>{}(result
, x
, y
, sourceFile
, line
);
376 void RTNAME(MatmulTransposeDirect
)(const Descriptor
&result
,
377 const Descriptor
&x
, const Descriptor
&y
, const char *sourceFile
,
379 MatmulTranspose
<false>{}(result
, x
, y
, sourceFile
, line
);
382 } // namespace Fortran::runtime