//===-- runtime/matmul.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Implements all forms of MATMUL (Fortran 2018 16.9.124)
//
// There are two main entry points; one establishes a descriptor for the
// result and allocates it, and the other expects a result descriptor that
// points to existing storage.
//
// This implementation must handle all combinations of numeric types and
// kinds (100 - 165 cases depending on the target), plus all combinations
// of logical kinds (16).  A single template undergoes many instantiations
// to cover all of the valid possibilities.
//
// Places where BLAS routines could be called are marked as TODO items.

#include "flang/Runtime/matmul.h"
#include "terminator.h"
// NOTE(review): AccumulationType, IsLogicalElementTrue and GetResultType are
// used below but are not declared by any header visible in this listing, and
// the listing skips a line right after "terminator.h" -- confirm which
// runtime-local header belongs here.
#include "flang/Common/optional.h"
#include "flang/Runtime/c-or-cpp.h"
#include "flang/Runtime/cpp-type.h"
#include "flang/Runtime/descriptor.h"
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>

using namespace Fortran::runtime;

34 // General accumulator for any type and stride; this is not used for
35 // contiguous numeric cases.
36 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
39 using Result
= AccumulationType
<RCAT
, RKIND
>;
40 RT_API_ATTRS
Accumulator(const Descriptor
&x
, const Descriptor
&y
)
42 RT_API_ATTRS
void Accumulate(
43 const SubscriptValue xAt
[], const SubscriptValue yAt
[]) {
44 if constexpr (RCAT
== TypeCategory::Logical
) {
46 (IsLogicalElementTrue(x_
, xAt
) && IsLogicalElementTrue(y_
, yAt
));
48 sum_
+= static_cast<Result
>(*x_
.Element
<XT
>(xAt
)) *
49 static_cast<Result
>(*y_
.Element
<YT
>(yAt
));
52 RT_API_ATTRS Result
GetResult() const { return sum_
; }
55 const Descriptor
&x_
, &y_
;
59 // Contiguous numeric matrix*matrix multiplication
60 // matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols)
61 // Straightforward algorithm:
66 // 1 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J)
67 // With loop distribution and transposition to avoid the inner sum
68 // reduction and to avoid non-unit strides:
75 // 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term
76 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
,
77 bool X_HAS_STRIDED_COLUMNS
, bool Y_HAS_STRIDED_COLUMNS
>
78 inline RT_API_ATTRS
void MatrixTimesMatrix(
79 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
80 SubscriptValue cols
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
81 SubscriptValue n
, std::size_t xColumnByteStride
= 0,
82 std::size_t yColumnByteStride
= 0) {
83 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
84 std::memset(product
, 0, rows
* cols
* sizeof *product
);
85 const XT
*RESTRICT xp0
{x
};
86 for (SubscriptValue k
{0}; k
< n
; ++k
) {
87 ResultType
*RESTRICT p
{product
};
88 for (SubscriptValue j
{0}; j
< cols
; ++j
) {
89 const XT
*RESTRICT xp
{xp0
};
91 if constexpr (!Y_HAS_STRIDED_COLUMNS
) {
92 yv
= static_cast<ResultType
>(y
[k
+ j
* n
]);
94 yv
= static_cast<ResultType
>(reinterpret_cast<const YT
*>(
95 reinterpret_cast<const char *>(y
) + j
* yColumnByteStride
)[k
]);
97 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
98 *p
++ += static_cast<ResultType
>(*xp
++) * yv
;
101 if constexpr (!X_HAS_STRIDED_COLUMNS
) {
104 xp0
= reinterpret_cast<const XT
*>(
105 reinterpret_cast<const char *>(xp0
) + xColumnByteStride
);
110 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
111 inline RT_API_ATTRS
void MatrixTimesMatrixHelper(
112 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
113 SubscriptValue cols
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
114 SubscriptValue n
, Fortran::common::optional
<std::size_t> xColumnByteStride
,
115 Fortran::common::optional
<std::size_t> yColumnByteStride
) {
116 if (!xColumnByteStride
) {
117 if (!yColumnByteStride
) {
118 MatrixTimesMatrix
<RCAT
, RKIND
, XT
, YT
, false, false>(
119 product
, rows
, cols
, x
, y
, n
);
121 MatrixTimesMatrix
<RCAT
, RKIND
, XT
, YT
, false, true>(
122 product
, rows
, cols
, x
, y
, n
, 0, *yColumnByteStride
);
125 if (!yColumnByteStride
) {
126 MatrixTimesMatrix
<RCAT
, RKIND
, XT
, YT
, true, false>(
127 product
, rows
, cols
, x
, y
, n
, *xColumnByteStride
);
129 MatrixTimesMatrix
<RCAT
, RKIND
, XT
, YT
, true, true>(
130 product
, rows
, cols
, x
, y
, n
, *xColumnByteStride
, *yColumnByteStride
);
135 // Contiguous numeric matrix*vector multiplication
136 // matrix(rows,n) * column vector(n) -> column vector(rows)
137 // Straightforward algorithm:
141 // 1 RES(J) = RES(J) + X(J,K)*Y(K)
142 // With loop distribution and transposition to avoid the inner
143 // sum reduction and to avoid non-unit strides:
148 // 2 RES(J) = RES(J) + X(J,K)*Y(K)
149 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
,
150 bool X_HAS_STRIDED_COLUMNS
>
151 inline RT_API_ATTRS
void MatrixTimesVector(
152 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
153 SubscriptValue n
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
154 std::size_t xColumnByteStride
= 0) {
155 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
156 std::memset(product
, 0, rows
* sizeof *product
);
157 [[maybe_unused
]] const XT
*RESTRICT xp0
{x
};
158 for (SubscriptValue k
{0}; k
< n
; ++k
) {
159 ResultType
*RESTRICT p
{product
};
160 auto yv
{static_cast<ResultType
>(*y
++)};
161 for (SubscriptValue j
{0}; j
< rows
; ++j
) {
162 *p
++ += static_cast<ResultType
>(*x
++) * yv
;
164 if constexpr (X_HAS_STRIDED_COLUMNS
) {
165 xp0
= reinterpret_cast<const XT
*>(
166 reinterpret_cast<const char *>(xp0
) + xColumnByteStride
);
172 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
173 inline RT_API_ATTRS
void MatrixTimesVectorHelper(
174 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
175 SubscriptValue n
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
176 Fortran::common::optional
<std::size_t> xColumnByteStride
) {
177 if (!xColumnByteStride
) {
178 MatrixTimesVector
<RCAT
, RKIND
, XT
, YT
, false>(product
, rows
, n
, x
, y
);
180 MatrixTimesVector
<RCAT
, RKIND
, XT
, YT
, true>(
181 product
, rows
, n
, x
, y
, *xColumnByteStride
);
185 // Contiguous numeric vector*matrix multiplication
186 // row vector(n) * matrix(n,cols) -> row vector(cols)
187 // Straightforward algorithm:
191 // 1 RES(J) = RES(J) + X(K)*Y(K,J)
192 // With loop distribution and transposition to avoid the inner
193 // sum reduction and one non-unit stride (the other remains):
198 // 2 RES(J) = RES(J) + X(K)*Y(K,J)
199 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
,
200 bool Y_HAS_STRIDED_COLUMNS
>
201 inline RT_API_ATTRS
void VectorTimesMatrix(
202 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue n
,
203 SubscriptValue cols
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
204 std::size_t yColumnByteStride
= 0) {
205 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
206 std::memset(product
, 0, cols
* sizeof *product
);
207 for (SubscriptValue k
{0}; k
< n
; ++k
) {
208 ResultType
*RESTRICT p
{product
};
209 auto xv
{static_cast<ResultType
>(*x
++)};
210 const YT
*RESTRICT yp
{&y
[k
]};
211 for (SubscriptValue j
{0}; j
< cols
; ++j
) {
212 *p
++ += xv
* static_cast<ResultType
>(*yp
);
213 if constexpr (!Y_HAS_STRIDED_COLUMNS
) {
216 yp
= reinterpret_cast<const YT
*>(
217 reinterpret_cast<const char *>(yp
) + yColumnByteStride
);
223 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
,
224 bool SPARSE_COLUMNS
= false>
225 inline RT_API_ATTRS
void VectorTimesMatrixHelper(
226 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue n
,
227 SubscriptValue cols
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
228 Fortran::common::optional
<std::size_t> yColumnByteStride
) {
229 if (!yColumnByteStride
) {
230 VectorTimesMatrix
<RCAT
, RKIND
, XT
, YT
, false>(product
, n
, cols
, x
, y
);
232 VectorTimesMatrix
<RCAT
, RKIND
, XT
, YT
, true>(
233 product
, n
, cols
, x
, y
, *yColumnByteStride
);
237 // Implements an instance of MATMUL for given argument types.
238 template <bool IS_ALLOCATING
, TypeCategory RCAT
, int RKIND
, typename XT
,
240 static inline RT_API_ATTRS
void DoMatmul(
241 std::conditional_t
<IS_ALLOCATING
, Descriptor
, const Descriptor
> &result
,
242 const Descriptor
&x
, const Descriptor
&y
, Terminator
&terminator
) {
245 int resRank
{xRank
+ yRank
- 2};
246 if (xRank
* yRank
!= 2 * resRank
) {
247 terminator
.Crash("MATMUL: bad argument ranks (%d * %d)", xRank
, yRank
);
249 SubscriptValue extent
[2]{
250 xRank
== 2 ? x
.GetDimension(0).Extent() : y
.GetDimension(1).Extent(),
251 resRank
== 2 ? y
.GetDimension(1).Extent() : 0};
252 if constexpr (IS_ALLOCATING
) {
254 RCAT
, RKIND
, nullptr, resRank
, extent
, CFI_attribute_allocatable
);
255 for (int j
{0}; j
< resRank
; ++j
) {
256 result
.GetDimension(j
).SetBounds(1, extent
[j
]);
258 if (int stat
{result
.Allocate()}) {
260 "MATMUL: could not allocate memory for result; STAT=%d", stat
);
263 RUNTIME_CHECK(terminator
, resRank
== result
.rank());
265 terminator
, result
.ElementBytes() == static_cast<std::size_t>(RKIND
));
266 RUNTIME_CHECK(terminator
, result
.GetDimension(0).Extent() == extent
[0]);
267 RUNTIME_CHECK(terminator
,
268 resRank
== 1 || result
.GetDimension(1).Extent() == extent
[1]);
270 SubscriptValue n
{x
.GetDimension(xRank
- 1).Extent()};
271 if (n
!= y
.GetDimension(0).Extent()) {
272 // At this point, we know that there's a shape error. There are three
273 // possibilities, x is rank 1, y is rank 1, or both are rank 2.
275 terminator
.Crash("MATMUL: unacceptable operand shapes (%jd, %jdx%jd)",
276 static_cast<std::intmax_t>(n
),
277 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()),
278 static_cast<std::intmax_t>(y
.GetDimension(1).Extent()));
279 } else if (yRank
== 1) {
280 terminator
.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jd)",
281 static_cast<std::intmax_t>(x
.GetDimension(0).Extent()),
282 static_cast<std::intmax_t>(n
),
283 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()));
285 terminator
.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
286 static_cast<std::intmax_t>(x
.GetDimension(0).Extent()),
287 static_cast<std::intmax_t>(n
),
288 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()),
289 static_cast<std::intmax_t>(y
.GetDimension(1).Extent()));
293 CppTypeFor
<RCAT
== TypeCategory::Logical
? TypeCategory::Integer
: RCAT
,
295 if constexpr (RCAT
!= TypeCategory::Logical
) {
296 if (x
.IsContiguous(1) && y
.IsContiguous(1) &&
297 (IS_ALLOCATING
|| result
.IsContiguous())) {
298 // Contiguous numeric matrices (maybe with columns
299 // separated by a stride).
300 Fortran::common::optional
<std::size_t> xColumnByteStride
;
301 if (!x
.IsContiguous()) {
302 // X's columns are strided.
303 SubscriptValue xAt
[2]{};
304 x
.GetLowerBounds(xAt
);
306 xColumnByteStride
= x
.SubscriptsToByteOffset(xAt
);
308 Fortran::common::optional
<std::size_t> yColumnByteStride
;
309 if (!y
.IsContiguous()) {
310 // Y's columns are strided.
311 SubscriptValue yAt
[2]{};
312 y
.GetLowerBounds(yAt
);
314 yColumnByteStride
= y
.SubscriptsToByteOffset(yAt
);
316 // Note that BLAS GEMM can be used for the strided
317 // columns by setting proper leading dimension size.
318 // This implies that the column stride is divisible
319 // by the element size, which is usually true.
320 if (resRank
== 2) { // M*M -> M
321 if (std::is_same_v
<XT
, YT
>) {
322 if constexpr (std::is_same_v
<XT
, float>) {
323 // TODO: call BLAS-3 SGEMM
324 // TODO: try using CUTLASS for device.
325 } else if constexpr (std::is_same_v
<XT
, double>) {
326 // TODO: call BLAS-3 DGEMM
327 } else if constexpr (std::is_same_v
<XT
, rtcmplx::complex<float>>) {
328 // TODO: call BLAS-3 CGEMM
329 } else if constexpr (std::is_same_v
<XT
, rtcmplx::complex<double>>) {
330 // TODO: call BLAS-3 ZGEMM
333 MatrixTimesMatrixHelper
<RCAT
, RKIND
, XT
, YT
>(
334 result
.template OffsetElement
<WriteResult
>(), extent
[0], extent
[1],
335 x
.OffsetElement
<XT
>(), y
.OffsetElement
<YT
>(), n
, xColumnByteStride
,
338 } else if (xRank
== 2) { // M*V -> V
339 if (std::is_same_v
<XT
, YT
>) {
340 if constexpr (std::is_same_v
<XT
, float>) {
341 // TODO: call BLAS-2 SGEMV(x,y)
342 } else if constexpr (std::is_same_v
<XT
, double>) {
343 // TODO: call BLAS-2 DGEMV(x,y)
344 } else if constexpr (std::is_same_v
<XT
, rtcmplx::complex<float>>) {
345 // TODO: call BLAS-2 CGEMV(x,y)
346 } else if constexpr (std::is_same_v
<XT
, rtcmplx::complex<double>>) {
347 // TODO: call BLAS-2 ZGEMV(x,y)
350 MatrixTimesVectorHelper
<RCAT
, RKIND
, XT
, YT
>(
351 result
.template OffsetElement
<WriteResult
>(), extent
[0], n
,
352 x
.OffsetElement
<XT
>(), y
.OffsetElement
<YT
>(), xColumnByteStride
);
355 if (std::is_same_v
<XT
, YT
>) {
356 if constexpr (std::is_same_v
<XT
, float>) {
357 // TODO: call BLAS-2 SGEMV(y,x)
358 } else if constexpr (std::is_same_v
<XT
, double>) {
359 // TODO: call BLAS-2 DGEMV(y,x)
360 } else if constexpr (std::is_same_v
<XT
, rtcmplx::complex<float>>) {
361 // TODO: call BLAS-2 CGEMV(y,x)
362 } else if constexpr (std::is_same_v
<XT
, rtcmplx::complex<double>>) {
363 // TODO: call BLAS-2 ZGEMV(y,x)
366 VectorTimesMatrixHelper
<RCAT
, RKIND
, XT
, YT
>(
367 result
.template OffsetElement
<WriteResult
>(), n
, extent
[0],
368 x
.OffsetElement
<XT
>(), y
.OffsetElement
<YT
>(), yColumnByteStride
);
373 // General algorithms for LOGICAL and noncontiguity
374 SubscriptValue xAt
[2], yAt
[2], resAt
[2];
375 x
.GetLowerBounds(xAt
);
376 y
.GetLowerBounds(yAt
);
377 result
.GetLowerBounds(resAt
);
378 if (resRank
== 2) { // M*M -> M
379 SubscriptValue x1
{xAt
[1]}, y0
{yAt
[0]}, y1
{yAt
[1]}, res1
{resAt
[1]};
380 for (SubscriptValue i
{0}; i
< extent
[0]; ++i
) {
381 for (SubscriptValue j
{0}; j
< extent
[1]; ++j
) {
382 Accumulator
<RCAT
, RKIND
, XT
, YT
> accumulator
{x
, y
};
384 for (SubscriptValue k
{0}; k
< n
; ++k
) {
387 accumulator
.Accumulate(xAt
, yAt
);
390 *result
.template Element
<WriteResult
>(resAt
) = accumulator
.GetResult();
395 } else if (xRank
== 2) { // M*V -> V
396 SubscriptValue x1
{xAt
[1]}, y0
{yAt
[0]};
397 for (SubscriptValue j
{0}; j
< extent
[0]; ++j
) {
398 Accumulator
<RCAT
, RKIND
, XT
, YT
> accumulator
{x
, y
};
399 for (SubscriptValue k
{0}; k
< n
; ++k
) {
402 accumulator
.Accumulate(xAt
, yAt
);
404 *result
.template Element
<WriteResult
>(resAt
) = accumulator
.GetResult();
409 SubscriptValue x0
{xAt
[0]}, y0
{yAt
[0]};
410 for (SubscriptValue j
{0}; j
< extent
[0]; ++j
) {
411 Accumulator
<RCAT
, RKIND
, XT
, YT
> accumulator
{x
, y
};
412 for (SubscriptValue k
{0}; k
< n
; ++k
) {
415 accumulator
.Accumulate(xAt
, yAt
);
417 *result
.template Element
<WriteResult
>(resAt
) = accumulator
.GetResult();
424 template <bool IS_ALLOCATING
, TypeCategory XCAT
, int XKIND
, TypeCategory YCAT
,
426 struct MatmulHelper
{
427 using ResultDescriptor
=
428 std::conditional_t
<IS_ALLOCATING
, Descriptor
, const Descriptor
>;
429 RT_API_ATTRS
void operator()(ResultDescriptor
&result
, const Descriptor
&x
,
430 const Descriptor
&y
, const char *sourceFile
, int line
) const {
431 Terminator terminator
{sourceFile
, line
};
432 auto xCatKind
{x
.type().GetCategoryAndKind()};
433 auto yCatKind
{y
.type().GetCategoryAndKind()};
434 RUNTIME_CHECK(terminator
, xCatKind
.has_value() && yCatKind
.has_value());
435 RUNTIME_CHECK(terminator
, xCatKind
->first
== XCAT
);
436 RUNTIME_CHECK(terminator
, yCatKind
->first
== YCAT
);
437 if constexpr (constexpr auto resultType
{
438 GetResultType(XCAT
, XKIND
, YCAT
, YKIND
)}) {
439 return DoMatmul
<IS_ALLOCATING
, resultType
->first
, resultType
->second
,
440 CppTypeFor
<XCAT
, XKIND
>, CppTypeFor
<YCAT
, YKIND
>>(
441 result
, x
, y
, terminator
);
443 terminator
.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
444 static_cast<int>(XCAT
), XKIND
, static_cast<int>(YCAT
), YKIND
);
449 namespace Fortran::runtime
{
451 RT_EXT_API_GROUP_BEGIN
453 #define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
454 void RTDEF(Matmul##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
455 const Descriptor &x, const Descriptor &y, const char *sourceFile, \
457 MatmulHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
458 YKIND>{}(result, x, y, sourceFile, line); \
461 #define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
462 void RTDEF(MatmulDirect##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
463 const Descriptor &x, const Descriptor &y, const char *sourceFile, \
465 MatmulHelper<false, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
466 YKIND>{}(result, x, y, sourceFile, line); \
469 #define MATMUL_FORCE_ALL_TYPES 0
471 #include "flang/Runtime/matmul-instances.inc"
475 } // namespace Fortran::runtime