1 //===-- runtime/matmul.cpp ------------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Implements all forms of MATMUL (Fortran 2018 16.9.124)
11 // There are two main entry points; one establishes a descriptor for the
12 // result and allocates it, and the other expects a result descriptor that
13 // points to existing storage.
15 // This implementation must handle all combinations of numeric types and
16 // kinds (100 - 165 cases depending on the target), plus all combinations
17 // of logical kinds (16). A single template undergoes many instantiations
18 // to cover all of the valid possibilities.
20 // Places where BLAS routines could be called are marked as TODO items.
#include "flang/Runtime/matmul.h"
#include "terminator.h"
#include "flang/Common/optional.h"
#include "flang/Runtime/c-or-cpp.h"
#include "flang/Runtime/cpp-type.h"
#include "flang/Runtime/descriptor.h"
#include <complex>
#include <cstdint>
#include <cstring>
#include <type_traits>
32 using namespace Fortran::runtime
;
34 // Suppress the warnings about calling __host__-only std::complex operators,
35 // defined in C++ STD header files, from __device__ code.
37 RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
39 // General accumulator for any type and stride; this is not used for
40 // contiguous numeric cases.
41 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
44 using Result
= AccumulationType
<RCAT
, RKIND
>;
45 RT_API_ATTRS
Accumulator(const Descriptor
&x
, const Descriptor
&y
)
47 RT_API_ATTRS
void Accumulate(
48 const SubscriptValue xAt
[], const SubscriptValue yAt
[]) {
49 if constexpr (RCAT
== TypeCategory::Logical
) {
51 (IsLogicalElementTrue(x_
, xAt
) && IsLogicalElementTrue(y_
, yAt
));
53 sum_
+= static_cast<Result
>(*x_
.Element
<XT
>(xAt
)) *
54 static_cast<Result
>(*y_
.Element
<YT
>(yAt
));
57 RT_API_ATTRS Result
GetResult() const { return sum_
; }
60 const Descriptor
&x_
, &y_
;
64 // Contiguous numeric matrix*matrix multiplication
65 // matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols)
66 // Straightforward algorithm:
71 // 1 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J)
72 // With loop distribution and transposition to avoid the inner sum
73 // reduction and to avoid non-unit strides:
80 // 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term
81 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
,
82 bool X_HAS_STRIDED_COLUMNS
, bool Y_HAS_STRIDED_COLUMNS
>
83 inline RT_API_ATTRS
void MatrixTimesMatrix(
84 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
85 SubscriptValue cols
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
86 SubscriptValue n
, std::size_t xColumnByteStride
= 0,
87 std::size_t yColumnByteStride
= 0) {
88 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
89 std::memset(product
, 0, rows
* cols
* sizeof *product
);
90 const XT
*RESTRICT xp0
{x
};
91 for (SubscriptValue k
{0}; k
< n
; ++k
) {
92 ResultType
*RESTRICT p
{product
};
93 for (SubscriptValue j
{0}; j
< cols
; ++j
) {
94 const XT
*RESTRICT xp
{xp0
};
96 if constexpr (!Y_HAS_STRIDED_COLUMNS
) {
97 yv
= static_cast<ResultType
>(y
[k
+ j
* n
]);
99 yv
= static_cast<ResultType
>(reinterpret_cast<const YT
*>(
100 reinterpret_cast<const char *>(y
) + j
* yColumnByteStride
)[k
]);
102 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
103 *p
++ += static_cast<ResultType
>(*xp
++) * yv
;
106 if constexpr (!X_HAS_STRIDED_COLUMNS
) {
109 xp0
= reinterpret_cast<const XT
*>(
110 reinterpret_cast<const char *>(xp0
) + xColumnByteStride
);
117 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
118 inline RT_API_ATTRS
void MatrixTimesMatrixHelper(
119 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
120 SubscriptValue cols
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
121 SubscriptValue n
, Fortran::common::optional
<std::size_t> xColumnByteStride
,
122 Fortran::common::optional
<std::size_t> yColumnByteStride
) {
123 if (!xColumnByteStride
) {
124 if (!yColumnByteStride
) {
125 MatrixTimesMatrix
<RCAT
, RKIND
, XT
, YT
, false, false>(
126 product
, rows
, cols
, x
, y
, n
);
128 MatrixTimesMatrix
<RCAT
, RKIND
, XT
, YT
, false, true>(
129 product
, rows
, cols
, x
, y
, n
, 0, *yColumnByteStride
);
132 if (!yColumnByteStride
) {
133 MatrixTimesMatrix
<RCAT
, RKIND
, XT
, YT
, true, false>(
134 product
, rows
, cols
, x
, y
, n
, *xColumnByteStride
);
136 MatrixTimesMatrix
<RCAT
, RKIND
, XT
, YT
, true, true>(
137 product
, rows
, cols
, x
, y
, n
, *xColumnByteStride
, *yColumnByteStride
);
143 RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
145 // Contiguous numeric matrix*vector multiplication
146 // matrix(rows,n) * column vector(n) -> column vector(rows)
147 // Straightforward algorithm:
151 // 1 RES(J) = RES(J) + X(J,K)*Y(K)
152 // With loop distribution and transposition to avoid the inner
153 // sum reduction and to avoid non-unit strides:
158 // 2 RES(J) = RES(J) + X(J,K)*Y(K)
159 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
,
160 bool X_HAS_STRIDED_COLUMNS
>
161 inline RT_API_ATTRS
void MatrixTimesVector(
162 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
163 SubscriptValue n
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
164 std::size_t xColumnByteStride
= 0) {
165 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
166 std::memset(product
, 0, rows
* sizeof *product
);
167 [[maybe_unused
]] const XT
*RESTRICT xp0
{x
};
168 for (SubscriptValue k
{0}; k
< n
; ++k
) {
169 ResultType
*RESTRICT p
{product
};
170 auto yv
{static_cast<ResultType
>(*y
++)};
171 for (SubscriptValue j
{0}; j
< rows
; ++j
) {
172 *p
++ += static_cast<ResultType
>(*x
++) * yv
;
174 if constexpr (X_HAS_STRIDED_COLUMNS
) {
175 xp0
= reinterpret_cast<const XT
*>(
176 reinterpret_cast<const char *>(xp0
) + xColumnByteStride
);
184 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
185 inline RT_API_ATTRS
void MatrixTimesVectorHelper(
186 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
187 SubscriptValue n
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
188 Fortran::common::optional
<std::size_t> xColumnByteStride
) {
189 if (!xColumnByteStride
) {
190 MatrixTimesVector
<RCAT
, RKIND
, XT
, YT
, false>(product
, rows
, n
, x
, y
);
192 MatrixTimesVector
<RCAT
, RKIND
, XT
, YT
, true>(
193 product
, rows
, n
, x
, y
, *xColumnByteStride
);
198 RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
200 // Contiguous numeric vector*matrix multiplication
201 // row vector(n) * matrix(n,cols) -> row vector(cols)
202 // Straightforward algorithm:
206 // 1 RES(J) = RES(J) + X(K)*Y(K,J)
207 // With loop distribution and transposition to avoid the inner
208 // sum reduction and one non-unit stride (the other remains):
213 // 2 RES(J) = RES(J) + X(K)*Y(K,J)
214 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
,
215 bool Y_HAS_STRIDED_COLUMNS
>
216 inline RT_API_ATTRS
void VectorTimesMatrix(
217 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue n
,
218 SubscriptValue cols
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
219 std::size_t yColumnByteStride
= 0) {
220 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
221 std::memset(product
, 0, cols
* sizeof *product
);
222 for (SubscriptValue k
{0}; k
< n
; ++k
) {
223 ResultType
*RESTRICT p
{product
};
224 auto xv
{static_cast<ResultType
>(*x
++)};
225 const YT
*RESTRICT yp
{&y
[k
]};
226 for (SubscriptValue j
{0}; j
< cols
; ++j
) {
227 *p
++ += xv
* static_cast<ResultType
>(*yp
);
228 if constexpr (!Y_HAS_STRIDED_COLUMNS
) {
231 yp
= reinterpret_cast<const YT
*>(
232 reinterpret_cast<const char *>(yp
) + yColumnByteStride
);
240 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
,
241 bool SPARSE_COLUMNS
= false>
242 inline RT_API_ATTRS
void VectorTimesMatrixHelper(
243 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue n
,
244 SubscriptValue cols
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
245 Fortran::common::optional
<std::size_t> yColumnByteStride
) {
246 if (!yColumnByteStride
) {
247 VectorTimesMatrix
<RCAT
, RKIND
, XT
, YT
, false>(product
, n
, cols
, x
, y
);
249 VectorTimesMatrix
<RCAT
, RKIND
, XT
, YT
, true>(
250 product
, n
, cols
, x
, y
, *yColumnByteStride
);
255 RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
257 // Implements an instance of MATMUL for given argument types.
258 template <bool IS_ALLOCATING
, TypeCategory RCAT
, int RKIND
, typename XT
,
260 static inline RT_API_ATTRS
void DoMatmul(
261 std::conditional_t
<IS_ALLOCATING
, Descriptor
, const Descriptor
> &result
,
262 const Descriptor
&x
, const Descriptor
&y
, Terminator
&terminator
) {
265 int resRank
{xRank
+ yRank
- 2};
266 if (xRank
* yRank
!= 2 * resRank
) {
267 terminator
.Crash("MATMUL: bad argument ranks (%d * %d)", xRank
, yRank
);
269 SubscriptValue extent
[2]{
270 xRank
== 2 ? x
.GetDimension(0).Extent() : y
.GetDimension(1).Extent(),
271 resRank
== 2 ? y
.GetDimension(1).Extent() : 0};
272 if constexpr (IS_ALLOCATING
) {
274 RCAT
, RKIND
, nullptr, resRank
, extent
, CFI_attribute_allocatable
);
275 for (int j
{0}; j
< resRank
; ++j
) {
276 result
.GetDimension(j
).SetBounds(1, extent
[j
]);
278 if (int stat
{result
.Allocate()}) {
280 "MATMUL: could not allocate memory for result; STAT=%d", stat
);
283 RUNTIME_CHECK(terminator
, resRank
== result
.rank());
285 terminator
, result
.ElementBytes() == static_cast<std::size_t>(RKIND
));
286 RUNTIME_CHECK(terminator
, result
.GetDimension(0).Extent() == extent
[0]);
287 RUNTIME_CHECK(terminator
,
288 resRank
== 1 || result
.GetDimension(1).Extent() == extent
[1]);
290 SubscriptValue n
{x
.GetDimension(xRank
- 1).Extent()};
291 if (n
!= y
.GetDimension(0).Extent()) {
292 // At this point, we know that there's a shape error. There are three
293 // possibilities, x is rank 1, y is rank 1, or both are rank 2.
295 terminator
.Crash("MATMUL: unacceptable operand shapes (%jd, %jdx%jd)",
296 static_cast<std::intmax_t>(n
),
297 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()),
298 static_cast<std::intmax_t>(y
.GetDimension(1).Extent()));
299 } else if (yRank
== 1) {
300 terminator
.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jd)",
301 static_cast<std::intmax_t>(x
.GetDimension(0).Extent()),
302 static_cast<std::intmax_t>(n
),
303 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()));
305 terminator
.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
306 static_cast<std::intmax_t>(x
.GetDimension(0).Extent()),
307 static_cast<std::intmax_t>(n
),
308 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()),
309 static_cast<std::intmax_t>(y
.GetDimension(1).Extent()));
313 CppTypeFor
<RCAT
== TypeCategory::Logical
? TypeCategory::Integer
: RCAT
,
315 if constexpr (RCAT
!= TypeCategory::Logical
) {
316 if (x
.IsContiguous(1) && y
.IsContiguous(1) &&
317 (IS_ALLOCATING
|| result
.IsContiguous())) {
318 // Contiguous numeric matrices (maybe with columns
319 // separated by a stride).
320 Fortran::common::optional
<std::size_t> xColumnByteStride
;
321 if (!x
.IsContiguous()) {
322 // X's columns are strided.
323 SubscriptValue xAt
[2]{};
324 x
.GetLowerBounds(xAt
);
326 xColumnByteStride
= x
.SubscriptsToByteOffset(xAt
);
328 Fortran::common::optional
<std::size_t> yColumnByteStride
;
329 if (!y
.IsContiguous()) {
330 // Y's columns are strided.
331 SubscriptValue yAt
[2]{};
332 y
.GetLowerBounds(yAt
);
334 yColumnByteStride
= y
.SubscriptsToByteOffset(yAt
);
336 // Note that BLAS GEMM can be used for the strided
337 // columns by setting proper leading dimension size.
338 // This implies that the column stride is divisible
339 // by the element size, which is usually true.
340 if (resRank
== 2) { // M*M -> M
341 if (std::is_same_v
<XT
, YT
>) {
342 if constexpr (std::is_same_v
<XT
, float>) {
343 // TODO: call BLAS-3 SGEMM
344 // TODO: try using CUTLASS for device.
345 } else if constexpr (std::is_same_v
<XT
, double>) {
346 // TODO: call BLAS-3 DGEMM
347 } else if constexpr (std::is_same_v
<XT
, std::complex<float>>) {
348 // TODO: call BLAS-3 CGEMM
349 } else if constexpr (std::is_same_v
<XT
, std::complex<double>>) {
350 // TODO: call BLAS-3 ZGEMM
353 MatrixTimesMatrixHelper
<RCAT
, RKIND
, XT
, YT
>(
354 result
.template OffsetElement
<WriteResult
>(), extent
[0], extent
[1],
355 x
.OffsetElement
<XT
>(), y
.OffsetElement
<YT
>(), n
, xColumnByteStride
,
358 } else if (xRank
== 2) { // M*V -> V
359 if (std::is_same_v
<XT
, YT
>) {
360 if constexpr (std::is_same_v
<XT
, float>) {
361 // TODO: call BLAS-2 SGEMV(x,y)
362 } else if constexpr (std::is_same_v
<XT
, double>) {
363 // TODO: call BLAS-2 DGEMV(x,y)
364 } else if constexpr (std::is_same_v
<XT
, std::complex<float>>) {
365 // TODO: call BLAS-2 CGEMV(x,y)
366 } else if constexpr (std::is_same_v
<XT
, std::complex<double>>) {
367 // TODO: call BLAS-2 ZGEMV(x,y)
370 MatrixTimesVectorHelper
<RCAT
, RKIND
, XT
, YT
>(
371 result
.template OffsetElement
<WriteResult
>(), extent
[0], n
,
372 x
.OffsetElement
<XT
>(), y
.OffsetElement
<YT
>(), xColumnByteStride
);
375 if (std::is_same_v
<XT
, YT
>) {
376 if constexpr (std::is_same_v
<XT
, float>) {
377 // TODO: call BLAS-2 SGEMV(y,x)
378 } else if constexpr (std::is_same_v
<XT
, double>) {
379 // TODO: call BLAS-2 DGEMV(y,x)
380 } else if constexpr (std::is_same_v
<XT
, std::complex<float>>) {
381 // TODO: call BLAS-2 CGEMV(y,x)
382 } else if constexpr (std::is_same_v
<XT
, std::complex<double>>) {
383 // TODO: call BLAS-2 ZGEMV(y,x)
386 VectorTimesMatrixHelper
<RCAT
, RKIND
, XT
, YT
>(
387 result
.template OffsetElement
<WriteResult
>(), n
, extent
[0],
388 x
.OffsetElement
<XT
>(), y
.OffsetElement
<YT
>(), yColumnByteStride
);
393 // General algorithms for LOGICAL and noncontiguity
394 SubscriptValue xAt
[2], yAt
[2], resAt
[2];
395 x
.GetLowerBounds(xAt
);
396 y
.GetLowerBounds(yAt
);
397 result
.GetLowerBounds(resAt
);
398 if (resRank
== 2) { // M*M -> M
399 SubscriptValue x1
{xAt
[1]}, y0
{yAt
[0]}, y1
{yAt
[1]}, res1
{resAt
[1]};
400 for (SubscriptValue i
{0}; i
< extent
[0]; ++i
) {
401 for (SubscriptValue j
{0}; j
< extent
[1]; ++j
) {
402 Accumulator
<RCAT
, RKIND
, XT
, YT
> accumulator
{x
, y
};
404 for (SubscriptValue k
{0}; k
< n
; ++k
) {
407 accumulator
.Accumulate(xAt
, yAt
);
410 *result
.template Element
<WriteResult
>(resAt
) = accumulator
.GetResult();
415 } else if (xRank
== 2) { // M*V -> V
416 SubscriptValue x1
{xAt
[1]}, y0
{yAt
[0]};
417 for (SubscriptValue j
{0}; j
< extent
[0]; ++j
) {
418 Accumulator
<RCAT
, RKIND
, XT
, YT
> accumulator
{x
, y
};
419 for (SubscriptValue k
{0}; k
< n
; ++k
) {
422 accumulator
.Accumulate(xAt
, yAt
);
424 *result
.template Element
<WriteResult
>(resAt
) = accumulator
.GetResult();
429 SubscriptValue x0
{xAt
[0]}, y0
{yAt
[0]};
430 for (SubscriptValue j
{0}; j
< extent
[0]; ++j
) {
431 Accumulator
<RCAT
, RKIND
, XT
, YT
> accumulator
{x
, y
};
432 for (SubscriptValue k
{0}; k
< n
; ++k
) {
435 accumulator
.Accumulate(xAt
, yAt
);
437 *result
.template Element
<WriteResult
>(resAt
) = accumulator
.GetResult();
446 template <bool IS_ALLOCATING
, TypeCategory XCAT
, int XKIND
, TypeCategory YCAT
,
448 struct MatmulHelper
{
449 using ResultDescriptor
=
450 std::conditional_t
<IS_ALLOCATING
, Descriptor
, const Descriptor
>;
451 RT_API_ATTRS
void operator()(ResultDescriptor
&result
, const Descriptor
&x
,
452 const Descriptor
&y
, const char *sourceFile
, int line
) const {
453 Terminator terminator
{sourceFile
, line
};
454 auto xCatKind
{x
.type().GetCategoryAndKind()};
455 auto yCatKind
{y
.type().GetCategoryAndKind()};
456 RUNTIME_CHECK(terminator
, xCatKind
.has_value() && yCatKind
.has_value());
457 RUNTIME_CHECK(terminator
, xCatKind
->first
== XCAT
);
458 RUNTIME_CHECK(terminator
, yCatKind
->first
== YCAT
);
459 if constexpr (constexpr auto resultType
{
460 GetResultType(XCAT
, XKIND
, YCAT
, YKIND
)}) {
461 return DoMatmul
<IS_ALLOCATING
, resultType
->first
, resultType
->second
,
462 CppTypeFor
<XCAT
, XKIND
>, CppTypeFor
<YCAT
, YKIND
>>(
463 result
, x
, y
, terminator
);
465 terminator
.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
466 static_cast<int>(XCAT
), XKIND
, static_cast<int>(YCAT
), YKIND
);
471 namespace Fortran::runtime
{
473 RT_EXT_API_GROUP_BEGIN
475 #define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
476 void RTDEF(Matmul##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
477 const Descriptor &x, const Descriptor &y, const char *sourceFile, \
479 MatmulHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
480 YKIND>{}(result, x, y, sourceFile, line); \
483 #define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
484 void RTDEF(MatmulDirect##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
485 const Descriptor &x, const Descriptor &y, const char *sourceFile, \
487 MatmulHelper<false, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
488 YKIND>{}(result, x, y, sourceFile, line); \
491 #define MATMUL_FORCE_ALL_TYPES 0
493 #include "flang/Runtime/matmul-instances.inc"
497 } // namespace Fortran::runtime