1 //===-- runtime/matmul-transpose.cpp --------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Implements a fused matmul-transpose operation
11 // There are two main entry points; one establishes a descriptor for the
12 // result and allocates it, and the other expects a result descriptor that
13 // points to existing storage.
15 // This implementation must handle all combinations of numeric types and
16 // kinds (100 - 165 cases depending on the target), plus all combinations
17 // of logical kinds (16). A single template undergoes many instantiations
18 // to cover all of the valid possibilities.
20 // The usefulness of this optimization should be reviewed once Matmul is swapped
21 // to use the faster BLAS routines.
23 #include "flang/Runtime/matmul-transpose.h"
24 #include "terminator.h"
26 #include "flang/Runtime/c-or-cpp.h"
27 #include "flang/Runtime/cpp-type.h"
28 #include "flang/Runtime/descriptor.h"
32 using namespace Fortran::runtime
;
34 // Contiguous numeric TRANSPOSE(matrix)*matrix multiplication
35 // TRANSPOSE(matrix(n, rows)) * matrix(n,cols) ->
36 // matrix(rows, n) * matrix(n,cols) -> matrix(rows,cols)
37 // The transpose is implemented by swapping the indices of accesses into the LHS
39 // Straightforward algorithm:
44 // 1 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J)
46 // With loop distribution and transposition to avoid the inner sum
47 // reduction and to avoid non-unit strides:
54 // 2 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) ! loop-invariant last term
55 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
56 inline static void MatrixTransposedTimesMatrix(
57 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
58 SubscriptValue cols
, const XT
*RESTRICT x
, const YT
*RESTRICT y
,
60 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
62 std::memset(product
, 0, rows
* cols
* sizeof *product
);
63 for (SubscriptValue j
{0}; j
< cols
; ++j
) {
64 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
65 for (SubscriptValue k
{0}; k
< n
; ++k
) {
66 ResultType x_ki
= static_cast<ResultType
>(x
[i
* n
+ k
]);
67 ResultType y_kj
= static_cast<ResultType
>(y
[j
* n
+ k
]);
68 product
[j
* rows
+ i
] += x_ki
* y_kj
;
74 // Contiguous numeric matrix*vector multiplication
75 // matrix(rows,n) * column vector(n) -> column vector(rows)
76 // Straightforward algorithm:
80 // 1 RES(I) = RES(I) + X(K,I)*Y(K)
81 // With loop distribution and transposition to avoid the inner
82 // sum reduction and to avoid non-unit strides:
87 // 2 RES(I) = RES(I) + X(K,I)*Y(K)
88 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
89 inline static void MatrixTransposedTimesVector(
90 CppTypeFor
<RCAT
, RKIND
> *RESTRICT product
, SubscriptValue rows
,
91 SubscriptValue n
, const XT
*RESTRICT x
, const YT
*RESTRICT y
) {
92 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
93 std::memset(product
, 0, rows
* sizeof *product
);
94 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
95 for (SubscriptValue k
{0}; k
< n
; ++k
) {
96 ResultType x_ki
= static_cast<ResultType
>(x
[i
* n
+ k
]);
97 ResultType y_k
= static_cast<ResultType
>(y
[k
]);
98 product
[i
] += x_ki
* y_k
;
103 // Implements an instance of MATMUL for given argument types.
104 template <bool IS_ALLOCATING
, TypeCategory RCAT
, int RKIND
, typename XT
,
106 inline static void DoMatmulTranspose(
107 std::conditional_t
<IS_ALLOCATING
, Descriptor
, const Descriptor
> &result
,
108 const Descriptor
&x
, const Descriptor
&y
, Terminator
&terminator
) {
111 int resRank
{xRank
+ yRank
- 2};
112 if (xRank
* yRank
!= 2 * resRank
) {
113 terminator
.Crash("MATMUL: bad argument ranks (%d * %d)", xRank
, yRank
);
115 SubscriptValue extent
[2]{x
.GetDimension(1).Extent(),
116 resRank
== 2 ? y
.GetDimension(1).Extent() : 0};
117 if constexpr (IS_ALLOCATING
) {
119 RCAT
, RKIND
, nullptr, resRank
, extent
, CFI_attribute_allocatable
);
120 for (int j
{0}; j
< resRank
; ++j
) {
121 result
.GetDimension(j
).SetBounds(1, extent
[j
]);
123 if (int stat
{result
.Allocate()}) {
125 "MATMUL: could not allocate memory for result; STAT=%d", stat
);
128 RUNTIME_CHECK(terminator
, resRank
== result
.rank());
130 terminator
, result
.ElementBytes() == static_cast<std::size_t>(RKIND
));
131 RUNTIME_CHECK(terminator
, result
.GetDimension(0).Extent() == extent
[0]);
132 RUNTIME_CHECK(terminator
,
133 resRank
== 1 || result
.GetDimension(1).Extent() == extent
[1]);
135 SubscriptValue n
{x
.GetDimension(0).Extent()};
136 if (n
!= y
.GetDimension(0).Extent()) {
137 terminator
.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
138 static_cast<std::intmax_t>(x
.GetDimension(0).Extent()),
139 static_cast<std::intmax_t>(x
.GetDimension(1).Extent()),
140 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()),
141 static_cast<std::intmax_t>(y
.GetDimension(1).Extent()));
144 CppTypeFor
<RCAT
== TypeCategory::Logical
? TypeCategory::Integer
: RCAT
,
146 const SubscriptValue rows
{extent
[0]};
147 const SubscriptValue cols
{extent
[1]};
148 if constexpr (RCAT
!= TypeCategory::Logical
) {
149 if (x
.IsContiguous() && y
.IsContiguous() &&
150 (IS_ALLOCATING
|| result
.IsContiguous())) {
151 // Contiguous numeric matrices
152 if (resRank
== 2) { // M*M -> M
153 MatrixTransposedTimesMatrix
<RCAT
, RKIND
, XT
, YT
>(
154 result
.template OffsetElement
<WriteResult
>(), rows
, cols
,
155 x
.OffsetElement
<XT
>(), y
.OffsetElement
<YT
>(), n
);
158 if (xRank
== 2) { // M*V -> V
159 MatrixTransposedTimesVector
<RCAT
, RKIND
, XT
, YT
>(
160 result
.template OffsetElement
<WriteResult
>(), rows
, n
,
161 x
.OffsetElement
<XT
>(), y
.OffsetElement
<YT
>());
164 // else V*M -> V (not allowed because TRANSPOSE() is only defined for rank
166 terminator
.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
167 static_cast<std::intmax_t>(x
.GetDimension(0).Extent()),
168 static_cast<std::intmax_t>(n
),
169 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()),
170 static_cast<std::intmax_t>(y
.GetDimension(1).Extent()));
174 // General algorithms for LOGICAL and noncontiguity
175 SubscriptValue xLB
[2], yLB
[2], resLB
[2];
176 x
.GetLowerBounds(xLB
);
177 y
.GetLowerBounds(yLB
);
178 result
.GetLowerBounds(resLB
);
179 using ResultType
= CppTypeFor
<RCAT
, RKIND
>;
180 if (resRank
== 2) { // M*M -> M
181 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
182 for (SubscriptValue j
{0}; j
< cols
; ++j
) {
184 if constexpr (RCAT
== TypeCategory::Logical
) {
190 for (SubscriptValue k
{0}; k
< n
; ++k
) {
191 SubscriptValue xAt
[2]{k
+ xLB
[0], i
+ xLB
[1]};
192 SubscriptValue yAt
[2]{k
+ yLB
[0], j
+ yLB
[1]};
193 if constexpr (RCAT
== TypeCategory::Logical
) {
194 ResultType x_ki
= IsLogicalElementTrue(x
, xAt
);
195 ResultType y_kj
= IsLogicalElementTrue(y
, yAt
);
196 res_ij
= res_ij
|| (x_ki
&& y_kj
);
198 ResultType x_ki
= static_cast<ResultType
>(*x
.Element
<XT
>(xAt
));
199 ResultType y_kj
= static_cast<ResultType
>(*y
.Element
<YT
>(yAt
));
200 res_ij
+= x_ki
* y_kj
;
203 SubscriptValue resAt
[2]{i
+ resLB
[0], j
+ resLB
[1]};
204 *result
.template Element
<WriteResult
>(resAt
) = res_ij
;
207 } else if (xRank
== 2) { // M*V -> V
208 for (SubscriptValue i
{0}; i
< rows
; ++i
) {
210 if constexpr (RCAT
== TypeCategory::Logical
) {
216 for (SubscriptValue k
{0}; k
< n
; ++k
) {
217 SubscriptValue xAt
[2]{k
+ xLB
[0], i
+ xLB
[1]};
218 SubscriptValue yAt
[1]{k
+ yLB
[0]};
219 if constexpr (RCAT
== TypeCategory::Logical
) {
220 ResultType x_ki
= IsLogicalElementTrue(x
, xAt
);
221 ResultType y_k
= IsLogicalElementTrue(y
, yAt
);
222 res_i
= res_i
|| (x_ki
&& y_k
);
224 ResultType x_ki
= static_cast<ResultType
>(*x
.Element
<XT
>(xAt
));
225 ResultType y_k
= static_cast<ResultType
>(*y
.Element
<YT
>(yAt
));
229 SubscriptValue resAt
[1]{i
+ resLB
[0]};
230 *result
.template Element
<WriteResult
>(resAt
) = res_i
;
233 // TRANSPOSE(V) not allowed by fortran standard
234 terminator
.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
235 static_cast<std::intmax_t>(x
.GetDimension(0).Extent()),
236 static_cast<std::intmax_t>(n
),
237 static_cast<std::intmax_t>(y
.GetDimension(0).Extent()),
238 static_cast<std::intmax_t>(y
.GetDimension(1).Extent()));
242 // Maps the dynamic type information from the arguments' descriptors
243 // to the right instantiation of DoMatmul() for valid combinations of
245 template <bool IS_ALLOCATING
> struct MatmulTranspose
{
246 using ResultDescriptor
=
247 std::conditional_t
<IS_ALLOCATING
, Descriptor
, const Descriptor
>;
248 template <TypeCategory XCAT
, int XKIND
> struct MM1
{
249 template <TypeCategory YCAT
, int YKIND
> struct MM2
{
250 void operator()(ResultDescriptor
&result
, const Descriptor
&x
,
251 const Descriptor
&y
, Terminator
&terminator
) const {
252 if constexpr (constexpr auto resultType
{
253 GetResultType(XCAT
, XKIND
, YCAT
, YKIND
)}) {
254 if constexpr (Fortran::common::IsNumericTypeCategory(
255 resultType
->first
) ||
256 resultType
->first
== TypeCategory::Logical
) {
257 return DoMatmulTranspose
<IS_ALLOCATING
, resultType
->first
,
258 resultType
->second
, CppTypeFor
<XCAT
, XKIND
>,
259 CppTypeFor
<YCAT
, YKIND
>>(result
, x
, y
, terminator
);
262 terminator
.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
263 static_cast<int>(XCAT
), XKIND
, static_cast<int>(YCAT
), YKIND
);
266 void operator()(ResultDescriptor
&result
, const Descriptor
&x
,
267 const Descriptor
&y
, Terminator
&terminator
, TypeCategory yCat
,
269 ApplyType
<MM2
, void>(yCat
, yKind
, terminator
, result
, x
, y
, terminator
);
272 void operator()(ResultDescriptor
&result
, const Descriptor
&x
,
273 const Descriptor
&y
, const char *sourceFile
, int line
) const {
274 Terminator terminator
{sourceFile
, line
};
275 auto xCatKind
{x
.type().GetCategoryAndKind()};
276 auto yCatKind
{y
.type().GetCategoryAndKind()};
277 RUNTIME_CHECK(terminator
, xCatKind
.has_value() && yCatKind
.has_value());
278 ApplyType
<MM1
, void>(xCatKind
->first
, xCatKind
->second
, terminator
, result
,
279 x
, y
, terminator
, yCatKind
->first
, yCatKind
->second
);
284 namespace Fortran::runtime
{
286 void RTNAME(MatmulTranspose
)(Descriptor
&result
, const Descriptor
&x
,
287 const Descriptor
&y
, const char *sourceFile
, int line
) {
288 MatmulTranspose
<true>{}(result
, x
, y
, sourceFile
, line
);
290 void RTNAME(MatmulTransposeDirect
)(const Descriptor
&result
,
291 const Descriptor
&x
, const Descriptor
&y
, const char *sourceFile
,
293 MatmulTranspose
<false>{}(result
, x
, y
, sourceFile
, line
);
296 } // namespace Fortran::runtime