[X86] AMD Zen 5 Initial enablement
[llvm-project.git] / flang / runtime / matmul.cpp
blob252557e2f9e7adfc3df214bff0d249c7f5af10c4
1 //===-- runtime/matmul.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 // Implements all forms of MATMUL (Fortran 2018 16.9.124)
11 // There are two main entry points; one establishes a descriptor for the
12 // result and allocates it, and the other expects a result descriptor that
13 // points to existing storage.
15 // This implementation must handle all combinations of numeric types and
16 // kinds (100 - 165 cases depending on the target), plus all combinations
17 // of logical kinds (16). A single template undergoes many instantiations
18 // to cover all of the valid possibilities.
20 // Places where BLAS routines could be called are marked as TODO items.
22 #include "flang/Runtime/matmul.h"
23 #include "terminator.h"
24 #include "tools.h"
25 #include "flang/Common/optional.h"
26 #include "flang/Runtime/c-or-cpp.h"
27 #include "flang/Runtime/cpp-type.h"
28 #include "flang/Runtime/descriptor.h"
29 #include <cstring>
31 namespace {
32 using namespace Fortran::runtime;
34 // Suppress the warnings about calling __host__-only std::complex operators,
35 // defined in C++ STD header files, from __device__ code.
36 RT_DIAG_PUSH
37 RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
39 // General accumulator for any type and stride; this is not used for
40 // contiguous numeric cases.
41 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
42 class Accumulator {
43 public:
44 using Result = AccumulationType<RCAT, RKIND>;
45 RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y)
46 : x_{x}, y_{y} {}
47 RT_API_ATTRS void Accumulate(
48 const SubscriptValue xAt[], const SubscriptValue yAt[]) {
49 if constexpr (RCAT == TypeCategory::Logical) {
50 sum_ = sum_ ||
51 (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt));
52 } else {
53 sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) *
54 static_cast<Result>(*y_.Element<YT>(yAt));
57 RT_API_ATTRS Result GetResult() const { return sum_; }
59 private:
60 const Descriptor &x_, &y_;
61 Result sum_{};
64 // Contiguous numeric matrix*matrix multiplication
65 // matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols)
66 // Straightforward algorithm:
67 // DO 1 I = 1, NROWS
68 // DO 1 J = 1, NCOLS
69 // RES(I,J) = 0
70 // DO 1 K = 1, N
71 // 1 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J)
72 // With loop distribution and transposition to avoid the inner sum
73 // reduction and to avoid non-unit strides:
74 // DO 1 I = 1, NROWS
75 // DO 1 J = 1, NCOLS
76 // 1 RES(I,J) = 0
77 // DO 2 K = 1, N
78 // DO 2 J = 1, NCOLS
79 // DO 2 I = 1, NROWS
80 // 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term
81 template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
82 bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS>
83 inline RT_API_ATTRS void MatrixTimesMatrix(
84 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
85 SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
86 SubscriptValue n, std::size_t xColumnByteStride = 0,
87 std::size_t yColumnByteStride = 0) {
88 using ResultType = CppTypeFor<RCAT, RKIND>;
89 std::memset(product, 0, rows * cols * sizeof *product);
90 const XT *RESTRICT xp0{x};
91 for (SubscriptValue k{0}; k < n; ++k) {
92 ResultType *RESTRICT p{product};
93 for (SubscriptValue j{0}; j < cols; ++j) {
94 const XT *RESTRICT xp{xp0};
95 ResultType yv;
96 if constexpr (!Y_HAS_STRIDED_COLUMNS) {
97 yv = static_cast<ResultType>(y[k + j * n]);
98 } else {
99 yv = static_cast<ResultType>(reinterpret_cast<const YT *>(
100 reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]);
102 for (SubscriptValue i{0}; i < rows; ++i) {
103 *p++ += static_cast<ResultType>(*xp++) * yv;
106 if constexpr (!X_HAS_STRIDED_COLUMNS) {
107 xp0 += rows;
108 } else {
109 xp0 = reinterpret_cast<const XT *>(
110 reinterpret_cast<const char *>(xp0) + xColumnByteStride);
115 RT_DIAG_POP
117 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
118 inline RT_API_ATTRS void MatrixTimesMatrixHelper(
119 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
120 SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
121 SubscriptValue n, Fortran::common::optional<std::size_t> xColumnByteStride,
122 Fortran::common::optional<std::size_t> yColumnByteStride) {
123 if (!xColumnByteStride) {
124 if (!yColumnByteStride) {
125 MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, false>(
126 product, rows, cols, x, y, n);
127 } else {
128 MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, true>(
129 product, rows, cols, x, y, n, 0, *yColumnByteStride);
131 } else {
132 if (!yColumnByteStride) {
133 MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, false>(
134 product, rows, cols, x, y, n, *xColumnByteStride);
135 } else {
136 MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, true>(
137 product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride);
142 RT_DIAG_PUSH
143 RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
145 // Contiguous numeric matrix*vector multiplication
146 // matrix(rows,n) * column vector(n) -> column vector(rows)
147 // Straightforward algorithm:
148 // DO 1 J = 1, NROWS
149 // RES(J) = 0
150 // DO 1 K = 1, N
151 // 1 RES(J) = RES(J) + X(J,K)*Y(K)
152 // With loop distribution and transposition to avoid the inner
153 // sum reduction and to avoid non-unit strides:
154 // DO 1 J = 1, NROWS
155 // 1 RES(J) = 0
156 // DO 2 K = 1, N
157 // DO 2 J = 1, NROWS
158 // 2 RES(J) = RES(J) + X(J,K)*Y(K)
159 template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
160 bool X_HAS_STRIDED_COLUMNS>
161 inline RT_API_ATTRS void MatrixTimesVector(
162 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
163 SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y,
164 std::size_t xColumnByteStride = 0) {
165 using ResultType = CppTypeFor<RCAT, RKIND>;
166 std::memset(product, 0, rows * sizeof *product);
167 [[maybe_unused]] const XT *RESTRICT xp0{x};
168 for (SubscriptValue k{0}; k < n; ++k) {
169 ResultType *RESTRICT p{product};
170 auto yv{static_cast<ResultType>(*y++)};
171 for (SubscriptValue j{0}; j < rows; ++j) {
172 *p++ += static_cast<ResultType>(*x++) * yv;
174 if constexpr (X_HAS_STRIDED_COLUMNS) {
175 xp0 = reinterpret_cast<const XT *>(
176 reinterpret_cast<const char *>(xp0) + xColumnByteStride);
177 x = xp0;
182 RT_DIAG_POP
184 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
185 inline RT_API_ATTRS void MatrixTimesVectorHelper(
186 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
187 SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y,
188 Fortran::common::optional<std::size_t> xColumnByteStride) {
189 if (!xColumnByteStride) {
190 MatrixTimesVector<RCAT, RKIND, XT, YT, false>(product, rows, n, x, y);
191 } else {
192 MatrixTimesVector<RCAT, RKIND, XT, YT, true>(
193 product, rows, n, x, y, *xColumnByteStride);
197 RT_DIAG_PUSH
198 RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
200 // Contiguous numeric vector*matrix multiplication
201 // row vector(n) * matrix(n,cols) -> row vector(cols)
202 // Straightforward algorithm:
203 // DO 1 J = 1, NCOLS
204 // RES(J) = 0
205 // DO 1 K = 1, N
206 // 1 RES(J) = RES(J) + X(K)*Y(K,J)
207 // With loop distribution and transposition to avoid the inner
208 // sum reduction and one non-unit stride (the other remains):
209 // DO 1 J = 1, NCOLS
210 // 1 RES(J) = 0
211 // DO 2 K = 1, N
212 // DO 2 J = 1, NCOLS
213 // 2 RES(J) = RES(J) + X(K)*Y(K,J)
214 template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
215 bool Y_HAS_STRIDED_COLUMNS>
216 inline RT_API_ATTRS void VectorTimesMatrix(
217 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n,
218 SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
219 std::size_t yColumnByteStride = 0) {
220 using ResultType = CppTypeFor<RCAT, RKIND>;
221 std::memset(product, 0, cols * sizeof *product);
222 for (SubscriptValue k{0}; k < n; ++k) {
223 ResultType *RESTRICT p{product};
224 auto xv{static_cast<ResultType>(*x++)};
225 const YT *RESTRICT yp{&y[k]};
226 for (SubscriptValue j{0}; j < cols; ++j) {
227 *p++ += xv * static_cast<ResultType>(*yp);
228 if constexpr (!Y_HAS_STRIDED_COLUMNS) {
229 yp += n;
230 } else {
231 yp = reinterpret_cast<const YT *>(
232 reinterpret_cast<const char *>(yp) + yColumnByteStride);
238 RT_DIAG_POP
240 template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
241 bool SPARSE_COLUMNS = false>
242 inline RT_API_ATTRS void VectorTimesMatrixHelper(
243 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n,
244 SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
245 Fortran::common::optional<std::size_t> yColumnByteStride) {
246 if (!yColumnByteStride) {
247 VectorTimesMatrix<RCAT, RKIND, XT, YT, false>(product, n, cols, x, y);
248 } else {
249 VectorTimesMatrix<RCAT, RKIND, XT, YT, true>(
250 product, n, cols, x, y, *yColumnByteStride);
254 RT_DIAG_PUSH
255 RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
257 // Implements an instance of MATMUL for given argument types.
258 template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
259 typename YT>
260 static inline RT_API_ATTRS void DoMatmul(
261 std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result,
262 const Descriptor &x, const Descriptor &y, Terminator &terminator) {
263 int xRank{x.rank()};
264 int yRank{y.rank()};
265 int resRank{xRank + yRank - 2};
266 if (xRank * yRank != 2 * resRank) {
267 terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank);
269 SubscriptValue extent[2]{
270 xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(),
271 resRank == 2 ? y.GetDimension(1).Extent() : 0};
272 if constexpr (IS_ALLOCATING) {
273 result.Establish(
274 RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable);
275 for (int j{0}; j < resRank; ++j) {
276 result.GetDimension(j).SetBounds(1, extent[j]);
278 if (int stat{result.Allocate()}) {
279 terminator.Crash(
280 "MATMUL: could not allocate memory for result; STAT=%d", stat);
282 } else {
283 RUNTIME_CHECK(terminator, resRank == result.rank());
284 RUNTIME_CHECK(
285 terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND));
286 RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]);
287 RUNTIME_CHECK(terminator,
288 resRank == 1 || result.GetDimension(1).Extent() == extent[1]);
290 SubscriptValue n{x.GetDimension(xRank - 1).Extent()};
291 if (n != y.GetDimension(0).Extent()) {
292 // At this point, we know that there's a shape error. There are three
293 // possibilities, x is rank 1, y is rank 1, or both are rank 2.
294 if (xRank == 1) {
295 terminator.Crash("MATMUL: unacceptable operand shapes (%jd, %jdx%jd)",
296 static_cast<std::intmax_t>(n),
297 static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
298 static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
299 } else if (yRank == 1) {
300 terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jd)",
301 static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
302 static_cast<std::intmax_t>(n),
303 static_cast<std::intmax_t>(y.GetDimension(0).Extent()));
304 } else {
305 terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
306 static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
307 static_cast<std::intmax_t>(n),
308 static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
309 static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
312 using WriteResult =
313 CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
314 RKIND>;
315 if constexpr (RCAT != TypeCategory::Logical) {
316 if (x.IsContiguous(1) && y.IsContiguous(1) &&
317 (IS_ALLOCATING || result.IsContiguous())) {
318 // Contiguous numeric matrices (maybe with columns
319 // separated by a stride).
320 Fortran::common::optional<std::size_t> xColumnByteStride;
321 if (!x.IsContiguous()) {
322 // X's columns are strided.
323 SubscriptValue xAt[2]{};
324 x.GetLowerBounds(xAt);
325 xAt[1]++;
326 xColumnByteStride = x.SubscriptsToByteOffset(xAt);
328 Fortran::common::optional<std::size_t> yColumnByteStride;
329 if (!y.IsContiguous()) {
330 // Y's columns are strided.
331 SubscriptValue yAt[2]{};
332 y.GetLowerBounds(yAt);
333 yAt[1]++;
334 yColumnByteStride = y.SubscriptsToByteOffset(yAt);
336 // Note that BLAS GEMM can be used for the strided
337 // columns by setting proper leading dimension size.
338 // This implies that the column stride is divisible
339 // by the element size, which is usually true.
340 if (resRank == 2) { // M*M -> M
341 if (std::is_same_v<XT, YT>) {
342 if constexpr (std::is_same_v<XT, float>) {
343 // TODO: call BLAS-3 SGEMM
344 // TODO: try using CUTLASS for device.
345 } else if constexpr (std::is_same_v<XT, double>) {
346 // TODO: call BLAS-3 DGEMM
347 } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
348 // TODO: call BLAS-3 CGEMM
349 } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
350 // TODO: call BLAS-3 ZGEMM
353 MatrixTimesMatrixHelper<RCAT, RKIND, XT, YT>(
354 result.template OffsetElement<WriteResult>(), extent[0], extent[1],
355 x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride,
356 yColumnByteStride);
357 return;
358 } else if (xRank == 2) { // M*V -> V
359 if (std::is_same_v<XT, YT>) {
360 if constexpr (std::is_same_v<XT, float>) {
361 // TODO: call BLAS-2 SGEMV(x,y)
362 } else if constexpr (std::is_same_v<XT, double>) {
363 // TODO: call BLAS-2 DGEMV(x,y)
364 } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
365 // TODO: call BLAS-2 CGEMV(x,y)
366 } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
367 // TODO: call BLAS-2 ZGEMV(x,y)
370 MatrixTimesVectorHelper<RCAT, RKIND, XT, YT>(
371 result.template OffsetElement<WriteResult>(), extent[0], n,
372 x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride);
373 return;
374 } else { // V*M -> V
375 if (std::is_same_v<XT, YT>) {
376 if constexpr (std::is_same_v<XT, float>) {
377 // TODO: call BLAS-2 SGEMV(y,x)
378 } else if constexpr (std::is_same_v<XT, double>) {
379 // TODO: call BLAS-2 DGEMV(y,x)
380 } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
381 // TODO: call BLAS-2 CGEMV(y,x)
382 } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
383 // TODO: call BLAS-2 ZGEMV(y,x)
386 VectorTimesMatrixHelper<RCAT, RKIND, XT, YT>(
387 result.template OffsetElement<WriteResult>(), n, extent[0],
388 x.OffsetElement<XT>(), y.OffsetElement<YT>(), yColumnByteStride);
389 return;
393 // General algorithms for LOGICAL and noncontiguity
394 SubscriptValue xAt[2], yAt[2], resAt[2];
395 x.GetLowerBounds(xAt);
396 y.GetLowerBounds(yAt);
397 result.GetLowerBounds(resAt);
398 if (resRank == 2) { // M*M -> M
399 SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]};
400 for (SubscriptValue i{0}; i < extent[0]; ++i) {
401 for (SubscriptValue j{0}; j < extent[1]; ++j) {
402 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
403 yAt[1] = y1 + j;
404 for (SubscriptValue k{0}; k < n; ++k) {
405 xAt[1] = x1 + k;
406 yAt[0] = y0 + k;
407 accumulator.Accumulate(xAt, yAt);
409 resAt[1] = res1 + j;
410 *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
412 ++resAt[0];
413 ++xAt[0];
415 } else if (xRank == 2) { // M*V -> V
416 SubscriptValue x1{xAt[1]}, y0{yAt[0]};
417 for (SubscriptValue j{0}; j < extent[0]; ++j) {
418 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
419 for (SubscriptValue k{0}; k < n; ++k) {
420 xAt[1] = x1 + k;
421 yAt[0] = y0 + k;
422 accumulator.Accumulate(xAt, yAt);
424 *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
425 ++resAt[0];
426 ++xAt[0];
428 } else { // V*M -> V
429 SubscriptValue x0{xAt[0]}, y0{yAt[0]};
430 for (SubscriptValue j{0}; j < extent[0]; ++j) {
431 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
432 for (SubscriptValue k{0}; k < n; ++k) {
433 xAt[0] = x0 + k;
434 yAt[0] = y0 + k;
435 accumulator.Accumulate(xAt, yAt);
437 *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
438 ++resAt[0];
439 ++yAt[1];
444 RT_DIAG_POP
446 template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT,
447 int YKIND>
448 struct MatmulHelper {
449 using ResultDescriptor =
450 std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
451 RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
452 const Descriptor &y, const char *sourceFile, int line) const {
453 Terminator terminator{sourceFile, line};
454 auto xCatKind{x.type().GetCategoryAndKind()};
455 auto yCatKind{y.type().GetCategoryAndKind()};
456 RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
457 RUNTIME_CHECK(terminator, xCatKind->first == XCAT);
458 RUNTIME_CHECK(terminator, yCatKind->first == YCAT);
459 if constexpr (constexpr auto resultType{
460 GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
461 return DoMatmul<IS_ALLOCATING, resultType->first, resultType->second,
462 CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>(
463 result, x, y, terminator);
465 terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
466 static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
469 } // namespace
471 namespace Fortran::runtime {
472 extern "C" {
473 RT_EXT_API_GROUP_BEGIN
// Expands to the allocating MATMUL entry point for one operand type pairing;
// driven by flang/Runtime/matmul-instances.inc below.
#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
  void RTDEF(Matmul##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
      const Descriptor &x, const Descriptor &y, const char *sourceFile, \
      int line) { \
    MatmulHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
        YKIND>{}(result, x, y, sourceFile, line); \
  }

// Expands to the non-allocating ("direct") MATMUL entry point, which writes
// into a caller-supplied result descriptor.
#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
  void RTDEF(MatmulDirect##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
      const Descriptor &x, const Descriptor &y, const char *sourceFile, \
      int line) { \
    MatmulHelper<false, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
        YKIND>{}(result, x, y, sourceFile, line); \
  }
491 #define MATMUL_FORCE_ALL_TYPES 0
493 #include "flang/Runtime/matmul-instances.inc"
495 RT_EXT_API_GROUP_END
496 } // extern "C"
497 } // namespace Fortran::runtime