[lldb] Add ability to hide the root name of a value
[llvm-project.git] / flang / runtime / matmul-transpose.cpp
blob345e3d8b41ac3301d7ea6949ae5922d9b4fd7485
1 //===-- runtime/matmul-transpose.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 // Implements a fused matmul-transpose operation
11 // There are two main entry points; one establishes a descriptor for the
12 // result and allocates it, and the other expects a result descriptor that
13 // points to existing storage.
15 // This implementation must handle all combinations of numeric types and
16 // kinds (100 - 165 cases depending on the target), plus all combinations
17 // of logical kinds (16). A single template undergoes many instantiations
18 // to cover all of the valid possibilities.
20 // The usefulness of this optimization should be reviewed once Matmul is swapped
21 // to use the faster BLAS routines.
23 #include "flang/Runtime/matmul-transpose.h"
24 #include "terminator.h"
25 #include "tools.h"
26 #include "flang/Runtime/c-or-cpp.h"
27 #include "flang/Runtime/cpp-type.h"
28 #include "flang/Runtime/descriptor.h"
29 #include <cstring>
31 namespace {
32 using namespace Fortran::runtime;
34 // Contiguous numeric TRANSPOSE(matrix)*matrix multiplication
35 // TRANSPOSE(matrix(n, rows)) * matrix(n,cols) ->
36 // matrix(rows, n) * matrix(n,cols) -> matrix(rows,cols)
37 // The transpose is implemented by swapping the indices of accesses into the LHS
39 // Straightforward algorithm:
40 // DO 1 I = 1, NROWS
41 // DO 1 J = 1, NCOLS
42 // RES(I,J) = 0
43 // DO 1 K = 1, N
44 // 1 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J)
46 // With loop distribution and transposition to avoid the inner sum
47 // reduction and to avoid non-unit strides:
48 // DO 1 I = 1, NROWS
49 // DO 1 J = 1, NCOLS
50 // 1 RES(I,J) = 0
51 // DO 2 J = 1, NCOLS
52 // DO 2 I = 1, NROWS
53 // DO 2 K = 1, N
54 // 2 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) ! loop-invariant last term
55 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
56 inline static void MatrixTransposedTimesMatrix(
57 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
58 SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
59 SubscriptValue n) {
60 using ResultType = CppTypeFor<RCAT, RKIND>;
62 std::memset(product, 0, rows * cols * sizeof *product);
63 for (SubscriptValue j{0}; j < cols; ++j) {
64 for (SubscriptValue i{0}; i < rows; ++i) {
65 for (SubscriptValue k{0}; k < n; ++k) {
66 ResultType x_ki = static_cast<ResultType>(x[i * n + k]);
67 ResultType y_kj = static_cast<ResultType>(y[j * n + k]);
68 product[j * rows + i] += x_ki * y_kj;
74 // Contiguous numeric matrix*vector multiplication
75 // matrix(rows,n) * column vector(n) -> column vector(rows)
76 // Straightforward algorithm:
77 // DO 1 I = 1, NROWS
78 // RES(I) = 0
79 // DO 1 K = 1, N
80 // 1 RES(I) = RES(I) + X(K,I)*Y(K)
81 // With loop distribution and transposition to avoid the inner
82 // sum reduction and to avoid non-unit strides:
83 // DO 1 I = 1, NROWS
84 // 1 RES(I) = 0
85 // DO 2 I = 1, NROWS
86 // DO 2 K = 1, N
87 // 2 RES(I) = RES(I) + X(K,I)*Y(K)
88 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
89 inline static void MatrixTransposedTimesVector(
90 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
91 SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y) {
92 using ResultType = CppTypeFor<RCAT, RKIND>;
93 std::memset(product, 0, rows * sizeof *product);
94 for (SubscriptValue i{0}; i < rows; ++i) {
95 for (SubscriptValue k{0}; k < n; ++k) {
96 ResultType x_ki = static_cast<ResultType>(x[i * n + k]);
97 ResultType y_k = static_cast<ResultType>(y[k]);
98 product[i] += x_ki * y_k;
103 // Implements an instance of MATMUL for given argument types.
104 template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
105 typename YT>
106 inline static void DoMatmulTranspose(
107 std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result,
108 const Descriptor &x, const Descriptor &y, Terminator &terminator) {
109 int xRank{x.rank()};
110 int yRank{y.rank()};
111 int resRank{xRank + yRank - 2};
112 if (xRank * yRank != 2 * resRank) {
113 terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank);
115 SubscriptValue extent[2]{x.GetDimension(1).Extent(),
116 resRank == 2 ? y.GetDimension(1).Extent() : 0};
117 if constexpr (IS_ALLOCATING) {
118 result.Establish(
119 RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable);
120 for (int j{0}; j < resRank; ++j) {
121 result.GetDimension(j).SetBounds(1, extent[j]);
123 if (int stat{result.Allocate()}) {
124 terminator.Crash(
125 "MATMUL: could not allocate memory for result; STAT=%d", stat);
127 } else {
128 RUNTIME_CHECK(terminator, resRank == result.rank());
129 RUNTIME_CHECK(
130 terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND));
131 RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]);
132 RUNTIME_CHECK(terminator,
133 resRank == 1 || result.GetDimension(1).Extent() == extent[1]);
135 SubscriptValue n{x.GetDimension(0).Extent()};
136 if (n != y.GetDimension(0).Extent()) {
137 terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
138 static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
139 static_cast<std::intmax_t>(x.GetDimension(1).Extent()),
140 static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
141 static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
143 using WriteResult =
144 CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
145 RKIND>;
146 const SubscriptValue rows{extent[0]};
147 const SubscriptValue cols{extent[1]};
148 if constexpr (RCAT != TypeCategory::Logical) {
149 if (x.IsContiguous() && y.IsContiguous() &&
150 (IS_ALLOCATING || result.IsContiguous())) {
151 // Contiguous numeric matrices
152 if (resRank == 2) { // M*M -> M
153 MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT>(
154 result.template OffsetElement<WriteResult>(), rows, cols,
155 x.OffsetElement<XT>(), y.OffsetElement<YT>(), n);
156 return;
158 if (xRank == 2) { // M*V -> V
159 MatrixTransposedTimesVector<RCAT, RKIND, XT, YT>(
160 result.template OffsetElement<WriteResult>(), rows, n,
161 x.OffsetElement<XT>(), y.OffsetElement<YT>());
162 return;
164 // else V*M -> V (not allowed because TRANSPOSE() is only defined for rank
165 // 1 matrices
166 terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
167 static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
168 static_cast<std::intmax_t>(n),
169 static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
170 static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
171 return;
174 // General algorithms for LOGICAL and noncontiguity
175 SubscriptValue xLB[2], yLB[2], resLB[2];
176 x.GetLowerBounds(xLB);
177 y.GetLowerBounds(yLB);
178 result.GetLowerBounds(resLB);
179 using ResultType = CppTypeFor<RCAT, RKIND>;
180 if (resRank == 2) { // M*M -> M
181 for (SubscriptValue i{0}; i < rows; ++i) {
182 for (SubscriptValue j{0}; j < cols; ++j) {
183 ResultType res_ij;
184 if constexpr (RCAT == TypeCategory::Logical) {
185 res_ij = false;
186 } else {
187 res_ij = 0;
190 for (SubscriptValue k{0}; k < n; ++k) {
191 SubscriptValue xAt[2]{k + xLB[0], i + xLB[1]};
192 SubscriptValue yAt[2]{k + yLB[0], j + yLB[1]};
193 if constexpr (RCAT == TypeCategory::Logical) {
194 ResultType x_ki = IsLogicalElementTrue(x, xAt);
195 ResultType y_kj = IsLogicalElementTrue(y, yAt);
196 res_ij = res_ij || (x_ki && y_kj);
197 } else {
198 ResultType x_ki = static_cast<ResultType>(*x.Element<XT>(xAt));
199 ResultType y_kj = static_cast<ResultType>(*y.Element<YT>(yAt));
200 res_ij += x_ki * y_kj;
203 SubscriptValue resAt[2]{i + resLB[0], j + resLB[1]};
204 *result.template Element<WriteResult>(resAt) = res_ij;
207 } else if (xRank == 2) { // M*V -> V
208 for (SubscriptValue i{0}; i < rows; ++i) {
209 ResultType res_i;
210 if constexpr (RCAT == TypeCategory::Logical) {
211 res_i = false;
212 } else {
213 res_i = 0;
216 for (SubscriptValue k{0}; k < n; ++k) {
217 SubscriptValue xAt[2]{k + xLB[0], i + xLB[1]};
218 SubscriptValue yAt[1]{k + yLB[0]};
219 if constexpr (RCAT == TypeCategory::Logical) {
220 ResultType x_ki = IsLogicalElementTrue(x, xAt);
221 ResultType y_k = IsLogicalElementTrue(y, yAt);
222 res_i = res_i || (x_ki && y_k);
223 } else {
224 ResultType x_ki = static_cast<ResultType>(*x.Element<XT>(xAt));
225 ResultType y_k = static_cast<ResultType>(*y.Element<YT>(yAt));
226 res_i += x_ki * y_k;
229 SubscriptValue resAt[1]{i + resLB[0]};
230 *result.template Element<WriteResult>(resAt) = res_i;
232 } else { // V*M -> V
233 // TRANSPOSE(V) not allowed by fortran standard
234 terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
235 static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
236 static_cast<std::intmax_t>(n),
237 static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
238 static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
242 // Maps the dynamic type information from the arguments' descriptors
243 // to the right instantiation of DoMatmul() for valid combinations of
244 // types.
245 template <bool IS_ALLOCATING> struct MatmulTranspose {
246 using ResultDescriptor =
247 std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
248 template <TypeCategory XCAT, int XKIND> struct MM1 {
249 template <TypeCategory YCAT, int YKIND> struct MM2 {
250 void operator()(ResultDescriptor &result, const Descriptor &x,
251 const Descriptor &y, Terminator &terminator) const {
252 if constexpr (constexpr auto resultType{
253 GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
254 if constexpr (Fortran::common::IsNumericTypeCategory(
255 resultType->first) ||
256 resultType->first == TypeCategory::Logical) {
257 return DoMatmulTranspose<IS_ALLOCATING, resultType->first,
258 resultType->second, CppTypeFor<XCAT, XKIND>,
259 CppTypeFor<YCAT, YKIND>>(result, x, y, terminator);
262 terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
263 static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
266 void operator()(ResultDescriptor &result, const Descriptor &x,
267 const Descriptor &y, Terminator &terminator, TypeCategory yCat,
268 int yKind) const {
269 ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator);
272 void operator()(ResultDescriptor &result, const Descriptor &x,
273 const Descriptor &y, const char *sourceFile, int line) const {
274 Terminator terminator{sourceFile, line};
275 auto xCatKind{x.type().GetCategoryAndKind()};
276 auto yCatKind{y.type().GetCategoryAndKind()};
277 RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
278 ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result,
279 x, y, terminator, yCatKind->first, yCatKind->second);
282 } // namespace
284 namespace Fortran::runtime {
285 extern "C" {
286 void RTNAME(MatmulTranspose)(Descriptor &result, const Descriptor &x,
287 const Descriptor &y, const char *sourceFile, int line) {
288 MatmulTranspose<true>{}(result, x, y, sourceFile, line);
290 void RTNAME(MatmulTransposeDirect)(const Descriptor &result,
291 const Descriptor &x, const Descriptor &y, const char *sourceFile,
292 int line) {
293 MatmulTranspose<false>{}(result, x, y, sourceFile, line);
295 } // extern "C"
296 } // namespace Fortran::runtime