1 //===-- runtime/matmul-transpose.cpp --------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Implements a fused matmul-transpose operation
11 // There are two main entry points; one establishes a descriptor for the
12 // result and allocates it, and the other expects a result descriptor that
13 // points to existing storage.
15 // This implementation must handle all combinations of numeric types and
16 // kinds (100 - 165 cases depending on the target), plus all combinations
17 // of logical kinds (16). A single template undergoes many instantiations
18 // to cover all of the valid possibilities.
20 // The usefulness of this optimization should be reviewed once Matmul is swapped
21 // to use the faster BLAS routines.
23 #include "flang/Runtime/matmul-transpose.h"
24 #include "terminator.h"
26 #include "flang/Runtime/c-or-cpp.h"
27 #include "flang/Runtime/cpp-type.h"
28 #include "flang/Runtime/descriptor.h"
32 using namespace Fortran::runtime
;
34 // Suppress the warnings about calling __host__-only std::complex operators,
35 // defined in C++ STD header files, from __device__ code.
37 RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
39 // Contiguous numeric TRANSPOSE(matrix)*matrix multiplication
40 // TRANSPOSE(matrix(n, rows)) * matrix(n,cols) ->
41 // matrix(rows, n) * matrix(n,cols) -> matrix(rows,cols)
42 // The transpose is implemented by swapping the indices of accesses into the LHS
44 // Straightforward algorithm:
49 // 1 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J)
51 // With loop distribution and transposition to avoid the inner sum
52 // reduction and to avoid non-unit strides:
59 // 2 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) ! loop-invariant last term
60 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
,
61 bool X_HAS_STRIDED_COLUMNS
, bool Y_HAS_STRIDED_COLUMNS
>
62 inline static RT_API_ATTRS
void MatrixTransposedTimesMatrix(
63 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
64 SubscriptValue cols
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
65 SubscriptValue n
, std::size_t xColumnByteStride
= 0,
66 std::size_t yColumnByteStride
= 0) {
67 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
69 std::memset(product
, 0, rows
* cols
* sizeof *product
);
70 for (SubscriptValue j
{0}; j
< cols
; ++j
) {
71 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
72 for (SubscriptValue k
{0}; k
< n
; ++k
) {
74 if constexpr (!X_HAS_STRIDED_COLUMNS
) {
75 x_ki
= static_cast<ResultType
>(x
[i
* n
+ k
]);
77 x_ki
= static_cast<ResultType
>(reinterpret_cast<const XT
*>(
78 reinterpret_cast<const char *>(x
) + i
* xColumnByteStride
)[k
]);
81 if constexpr (!Y_HAS_STRIDED_COLUMNS
) {
82 y_kj
= static_cast<ResultType
>(y
[j
* n
+ k
]);
84 y_kj
= static_cast<ResultType
>(reinterpret_cast<const YT
*>(
85 reinterpret_cast<const char *>(y
) + j
* yColumnByteStride
)[k
]);
87 product
[j
* rows
+ i
] += x_ki
* y_kj
;
95 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
96 inline static RT_API_ATTRS
void MatrixTransposedTimesMatrixHelper(
97 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
98 SubscriptValue cols
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
99 SubscriptValue n
, std::optional
<std::size_t> xColumnByteStride
,
100 std::optional
<std::size_t> yColumnByteStride
) {
101 if (!xColumnByteStride
) {
102 if (!yColumnByteStride
) {
103 MatrixTransposedTimesMatrix
<RCAT
, RKIND
, XT
, YT
, false, false>(
104 product
, rows
, cols
, x
, y
, n
);
106 MatrixTransposedTimesMatrix
<RCAT
, RKIND
, XT
, YT
, false, true>(
107 product
, rows
, cols
, x
, y
, n
, 0, *yColumnByteStride
);
110 if (!yColumnByteStride
) {
111 MatrixTransposedTimesMatrix
<RCAT
, RKIND
, XT
, YT
, true, false>(
112 product
, rows
, cols
, x
, y
, n
, *xColumnByteStride
);
114 MatrixTransposedTimesMatrix
<RCAT
, RKIND
, XT
, YT
, true, true>(
115 product
, rows
, cols
, x
, y
, n
, *xColumnByteStride
, *yColumnByteStride
);
121 RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
123 // Contiguous numeric matrix*vector multiplication
124 // matrix(rows,n) * column vector(n) -> column vector(rows)
125 // Straightforward algorithm:
129 // 1 RES(I) = RES(I) + X(K,I)*Y(K)
130 // With loop distribution and transposition to avoid the inner
131 // sum reduction and to avoid non-unit strides:
136 // 2 RES(I) = RES(I) + X(K,I)*Y(K)
137 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
,
138 bool X_HAS_STRIDED_COLUMNS
>
139 inline static RT_API_ATTRS
void MatrixTransposedTimesVector(
140 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
141 SubscriptValue n
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
142 std::size_t xColumnByteStride
= 0) {
143 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
144 std::memset(product
, 0, rows
* sizeof *product
);
145 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
146 for (SubscriptValue k
{0}; k
< n
; ++k
) {
148 if constexpr (!X_HAS_STRIDED_COLUMNS
) {
149 x_ki
= static_cast<ResultType
>(x
[i
* n
+ k
]);
151 x_ki
= static_cast<ResultType
>(reinterpret_cast<const XT
*>(
152 reinterpret_cast<const char *>(x
) + i
* xColumnByteStride
)[k
]);
154 ResultType y_k
= static_cast<ResultType
>(y
[k
]);
155 product
[i
] += x_ki
* y_k
;
162 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
163 inline static RT_API_ATTRS
void MatrixTransposedTimesVectorHelper(
164 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
165 SubscriptValue n
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
166 std::optional
<std::size_t> xColumnByteStride
) {
167 if (!xColumnByteStride
) {
168 MatrixTransposedTimesVector
<RCAT
, RKIND
, XT
, YT
, false>(
169 product
, rows
, n
, x
, y
);
171 MatrixTransposedTimesVector
<RCAT
, RKIND
, XT
, YT
, true>(
172 product
, rows
, n
, x
, y
, *xColumnByteStride
);
177 RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
179 // Implements an instance of MATMUL for given argument types.
180 template <bool IS_ALLOCATING
, TypeCategory RCAT
, int RKIND
, typename XT
,
182 inline static RT_API_ATTRS
void DoMatmulTranspose(
183 std::conditional_t
<IS_ALLOCATING
, Descriptor
, const Descriptor
> &result
,
184 const Descriptor
&x
, const Descriptor
&y
, Terminator
&terminator
) {
187 int resRank
{xRank
+ yRank
- 2};
188 if (xRank
* yRank
!= 2 * resRank
) {
190 "MATMUL-TRANSPOSE: bad argument ranks (%d * %d)", xRank
, yRank
);
192 SubscriptValue extent
[2]{x
.GetDimension(1).Extent(),
193 resRank
== 2 ? y
.GetDimension(1).Extent() : 0};
194 if constexpr (IS_ALLOCATING
) {
196 RCAT
, RKIND
, nullptr, resRank
, extent
, CFI_attribute_allocatable
);
197 for (int j
{0}; j
< resRank
; ++j
) {
198 result
.GetDimension(j
).SetBounds(1, extent
[j
]);
200 if (int stat
{result
.Allocate()}) {
202 "MATMUL-TRANSPOSE: could not allocate memory for result; STAT=%d",
206 RUNTIME_CHECK(terminator
, resRank
== result
.rank());
208 terminator
, result
.ElementBytes() == static_cast<std::size_t>(RKIND
));
209 RUNTIME_CHECK(terminator
, result
.GetDimension(0).Extent() == extent
[0]);
210 RUNTIME_CHECK(terminator
,
211 resRank
== 1 || result
.GetDimension(1).Extent() == extent
[1]);
213 SubscriptValue n
{x
.GetDimension(0).Extent()};
214 if (n
!= y
.GetDimension(0).Extent()) {
216 "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
217 static_cast<std::intmax_t>(x
.GetDimension(0).Extent()),
218 static_cast<std::intmax_t>(x
.GetDimension(1).Extent()),
219 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()),
220 static_cast<std::intmax_t>(y
.GetDimension(1).Extent()));
223 CppTypeFor
<RCAT
== TypeCategory::Logical
? TypeCategory::Integer
: RCAT
,
225 const SubscriptValue rows
{extent
[0]};
226 const SubscriptValue cols
{extent
[1]};
227 if constexpr (RCAT
!= TypeCategory::Logical
) {
228 if (x
.IsContiguous(1) && y
.IsContiguous(1) &&
229 (IS_ALLOCATING
|| result
.IsContiguous())) {
230 // Contiguous numeric matrices (maybe with columns
231 // separated by a stride).
232 std::optional
<std::size_t> xColumnByteStride
;
233 if (!x
.IsContiguous()) {
234 // X's columns are strided.
235 SubscriptValue xAt
[2]{};
236 x
.GetLowerBounds(xAt
);
238 xColumnByteStride
= x
.SubscriptsToByteOffset(xAt
);
240 std::optional
<std::size_t> yColumnByteStride
;
241 if (!y
.IsContiguous()) {
242 // Y's columns are strided.
243 SubscriptValue yAt
[2]{};
244 y
.GetLowerBounds(yAt
);
246 yColumnByteStride
= y
.SubscriptsToByteOffset(yAt
);
248 if (resRank
== 2) { // M*M -> M
249 // TODO: use BLAS-3 GEMM for supported types.
250 MatrixTransposedTimesMatrixHelper
<RCAT
, RKIND
, XT
, YT
>(
251 result
.template OffsetElement
<WriteResult
>(), rows
, cols
,
252 x
.OffsetElement
<XT
>(), y
.OffsetElement
<YT
>(), n
, xColumnByteStride
,
256 if (xRank
== 2) { // M*V -> V
257 // TODO: use BLAS-2 GEMM for supported types.
258 MatrixTransposedTimesVectorHelper
<RCAT
, RKIND
, XT
, YT
>(
259 result
.template OffsetElement
<WriteResult
>(), rows
, n
,
260 x
.OffsetElement
<XT
>(), y
.OffsetElement
<YT
>(), xColumnByteStride
);
263 // else V*M -> V (not allowed because TRANSPOSE() is only defined for rank
266 "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
267 static_cast<std::intmax_t>(x
.GetDimension(0).Extent()),
268 static_cast<std::intmax_t>(n
),
269 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()),
270 static_cast<std::intmax_t>(y
.GetDimension(1).Extent()));
274 // General algorithms for LOGICAL and noncontiguity
275 SubscriptValue xLB
[2], yLB
[2], resLB
[2];
276 x
.GetLowerBounds(xLB
);
277 y
.GetLowerBounds(yLB
);
278 result
.GetLowerBounds(resLB
);
279 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
280 if (resRank
== 2) { // M*M -> M
281 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
282 for (SubscriptValue j
{0}; j
< cols
; ++j
) {
284 if constexpr (RCAT
== TypeCategory::Logical
) {
290 for (SubscriptValue k
{0}; k
< n
; ++k
) {
291 SubscriptValue xAt
[2]{k
+ xLB
[0], i
+ xLB
[1]};
292 SubscriptValue yAt
[2]{k
+ yLB
[0], j
+ yLB
[1]};
293 if constexpr (RCAT
== TypeCategory::Logical
) {
294 ResultType x_ki
= IsLogicalElementTrue(x
, xAt
);
295 ResultType y_kj
= IsLogicalElementTrue(y
, yAt
);
296 res_ij
= res_ij
|| (x_ki
&& y_kj
);
298 ResultType x_ki
= static_cast<ResultType
>(*x
.Element
<XT
>(xAt
));
299 ResultType y_kj
= static_cast<ResultType
>(*y
.Element
<YT
>(yAt
));
300 res_ij
+= x_ki
* y_kj
;
303 SubscriptValue resAt
[2]{i
+ resLB
[0], j
+ resLB
[1]};
304 *result
.template Element
<WriteResult
>(resAt
) = res_ij
;
307 } else if (xRank
== 2) { // M*V -> V
308 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
310 if constexpr (RCAT
== TypeCategory::Logical
) {
316 for (SubscriptValue k
{0}; k
< n
; ++k
) {
317 SubscriptValue xAt
[2]{k
+ xLB
[0], i
+ xLB
[1]};
318 SubscriptValue yAt
[1]{k
+ yLB
[0]};
319 if constexpr (RCAT
== TypeCategory::Logical
) {
320 ResultType x_ki
= IsLogicalElementTrue(x
, xAt
);
321 ResultType y_k
= IsLogicalElementTrue(y
, yAt
);
322 res_i
= res_i
|| (x_ki
&& y_k
);
324 ResultType x_ki
= static_cast<ResultType
>(*x
.Element
<XT
>(xAt
));
325 ResultType y_k
= static_cast<ResultType
>(*y
.Element
<YT
>(yAt
));
329 SubscriptValue resAt
[1]{i
+ resLB
[0]};
330 *result
.template Element
<WriteResult
>(resAt
) = res_i
;
333 // TRANSPOSE(V) not allowed by fortran standard
335 "MATMUL-TRANSPOSE: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
336 static_cast<std::intmax_t>(x
.GetDimension(0).Extent()),
337 static_cast<std::intmax_t>(n
),
338 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()),
339 static_cast<std::intmax_t>(y
.GetDimension(1).Extent()));
345 // Maps the dynamic type information from the arguments' descriptors
346 // to the right instantiation of DoMatmul() for valid combinations of
348 template <bool IS_ALLOCATING
> struct MatmulTranspose
{
349 using ResultDescriptor
=
350 std::conditional_t
<IS_ALLOCATING
, Descriptor
, const Descriptor
>;
351 template <TypeCategory XCAT
, int XKIND
> struct MM1
{
352 template <TypeCategory YCAT
, int YKIND
> struct MM2
{
353 RT_API_ATTRS
void operator()(ResultDescriptor
&result
,
354 const Descriptor
&x
, const Descriptor
&y
,
355 Terminator
&terminator
) const {
356 if constexpr (constexpr auto resultType
{
357 GetResultType(XCAT
, XKIND
, YCAT
, YKIND
)}) {
358 if constexpr (Fortran::common::IsNumericTypeCategory(
359 resultType
->first
) ||
360 resultType
->first
== TypeCategory::Logical
) {
361 return DoMatmulTranspose
<IS_ALLOCATING
, resultType
->first
,
362 resultType
->second
, CppTypeFor
<XCAT
, XKIND
>,
363 CppTypeFor
<YCAT
, YKIND
>>(result
, x
, y
, terminator
);
366 terminator
.Crash("MATMUL-TRANSPOSE: bad operand types (%d(%d), %d(%d))",
367 static_cast<int>(XCAT
), XKIND
, static_cast<int>(YCAT
), YKIND
);
370 RT_API_ATTRS
void operator()(ResultDescriptor
&result
, const Descriptor
&x
,
371 const Descriptor
&y
, Terminator
&terminator
, TypeCategory yCat
,
373 ApplyType
<MM2
, void>(yCat
, yKind
, terminator
, result
, x
, y
, terminator
);
376 RT_API_ATTRS
void operator()(ResultDescriptor
&result
, const Descriptor
&x
,
377 const Descriptor
&y
, const char *sourceFile
, int line
) const {
378 Terminator terminator
{sourceFile
, line
};
379 auto xCatKind
{x
.type().GetCategoryAndKind()};
380 auto yCatKind
{y
.type().GetCategoryAndKind()};
381 RUNTIME_CHECK(terminator
, xCatKind
.has_value() && yCatKind
.has_value());
382 ApplyType
<MM1
, void>(xCatKind
->first
, xCatKind
->second
, terminator
, result
,
383 x
, y
, terminator
, yCatKind
->first
, yCatKind
->second
);
388 namespace Fortran::runtime
{
390 RT_EXT_API_GROUP_BEGIN
392 void RTDEF(MatmulTranspose
)(Descriptor
&result
, const Descriptor
&x
,
393 const Descriptor
&y
, const char *sourceFile
, int line
) {
394 MatmulTranspose
<true>{}(result
, x
, y
, sourceFile
, line
);
396 void RTDEF(MatmulTransposeDirect
)(const Descriptor
&result
, const Descriptor
&x
,
397 const Descriptor
&y
, const char *sourceFile
, int line
) {
398 MatmulTranspose
<false>{}(result
, x
, y
, sourceFile
, line
);
403 } // namespace Fortran::runtime