1 //===-- runtime/dot-product.cpp -------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 #include "terminator.h"
12 #include "flang/Common/float128.h"
13 #include "flang/Runtime/cpp-type.h"
14 #include "flang/Runtime/descriptor.h"
15 #include "flang/Runtime/reduction.h"
19 namespace Fortran::runtime
{
21 // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
22 // argument; MATMUL does not.
24 // General accumulator for any type and stride; this is not used for
25 // contiguous numeric vectors.
// Accumulator<RCAT, RKIND, XT, YT> folds pairs of elements of two rank-1
// descriptors, addressed by subscript, into a running sum.
//   RCAT, RKIND - category and kind of the DOT_PRODUCT result; the running
//                 sum and the converted operands use Result, which is
//                 AccumulationType<RCAT, RKIND>.
//   XT, YT      - C++ element types of VECTOR_A (x_) and VECTOR_B (y_).
// x_ and y_ are stored as references, so both descriptors must outlive the
// accumulator.
// NOTE(review): the class head, the constructor's member initializers and
// the declaration/initialization of sum_ are not visible in this view.
26 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
29 using Result
= AccumulationType
<RCAT
, RKIND
>;
30 RT_API_ATTRS
Accumulator(const Descriptor
&x
, const Descriptor
&y
)
// Adds one term to the sum: element xAt of x_ combined with element yAt of y_.
//  - LOGICAL: the term is IsLogicalElementTrue(x) && IsLogicalElementTrue(y);
//    the line that merges it into sum_ is not visible here (presumably a
//    logical OR -- confirm).
//  - COMPLEX: conj(x) * y, because DOT_PRODUCT conjugates its first argument.
//  - other numeric types: x * y.
// For the numeric cases both operands are converted to Result before the
// multiplication.
32 RT_API_ATTRS
void AccumulateIndexed(SubscriptValue xAt
, SubscriptValue yAt
) {
33 if constexpr (RCAT
== TypeCategory::Logical
) {
35 (IsLogicalElementTrue(x_
, &xAt
) && IsLogicalElementTrue(y_
, &yAt
));
37 const XT
&xElement
{*x_
.Element
<XT
>(&xAt
)};
38 const YT
&yElement
{*y_
.Element
<YT
>(&yAt
)};
39 if constexpr (RCAT
== TypeCategory::Complex
) {
// Conjugate the VECTOR_A element; MATMUL does not do this.
40 sum_
+= rtcmplx::conj(static_cast<Result
>(xElement
)) *
41 static_cast<Result
>(yElement
);
43 sum_
+= static_cast<Result
>(xElement
) * static_cast<Result
>(yElement
);
// Returns the sum accumulated so far (still in the accumulation type).
47 RT_API_ATTRS Result
GetResult() const { return sum_
; }
// Non-owning references to the two operand descriptors.
50 const Descriptor
&x_
, &y_
;
// DoDotProduct<RCAT, RKIND, XT, YT> computes DOT_PRODUCT(x, y) for rank-1
// descriptors whose elements have C++ types XT (x) and YT (y). It returns a
// CppTypeFor<RCAT, RKIND>.
//  - RUNTIME_CHECKs that x and y both have rank 1.
//  - A mismatch between the extents of x and y is reported through
//    `terminator` with the message
//    "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd".
//    The reporting call itself is not visible in this view.
//  - Numeric fast path: applies when each vector's byte stride equals its
//    element size. Elements are read through raw pointers starting at
//    OffsetElement(0) and summed in AccumulationType<RCAT, RKIND>. For
//    COMPLEX the x element is conjugated. The BLAS-1 branches for float,
//    double and complex are TODO placeholders that contain no code.
//  - The remaining cases iterate by subscript from each vector's lower bound
//    using Accumulator.
// The accumulated value is converted to the result type with static_cast.
// This narrows it whenever AccumulationType<RCAT, RKIND> is wider than
// CppTypeFor<RCAT, RKIND>.
54 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
55 static inline RT_API_ATTRS CppTypeFor
<RCAT
, RKIND
> DoDotProduct(
56 const Descriptor
&x
, const Descriptor
&y
, Terminator
&terminator
) {
57 using Result
= CppTypeFor
<RCAT
, RKIND
>;
58 RUNTIME_CHECK(terminator
, x
.rank() == 1 && y
.rank() == 1);
59 SubscriptValue n
{x
.GetDimension(0).Extent()};
60 if (SubscriptValue yN
{y
.GetDimension(0).Extent()}; yN
!= n
) {
62 "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
63 static_cast<std::intmax_t>(n
), static_cast<std::intmax_t>(yN
));
65 if constexpr (RCAT
!= TypeCategory::Logical
) {
// Unit element stride on both sides allows plain pointer walking.
66 if (x
.GetDimension(0).ByteStride() == sizeof(XT
) &&
67 y
.GetDimension(0).ByteStride() == sizeof(YT
)) {
68 // Contiguous numeric vectors
69 if constexpr (std::is_same_v
<XT
, YT
>) {
70 // Contiguous homogeneous numeric vectors
71 if constexpr (std::is_same_v
<XT
, float>) {
72 // TODO: call BLAS-1 SDOT or SDSDOT
73 } else if constexpr (std::is_same_v
<XT
, double>) {
74 // TODO: call BLAS-1 DDOT
75 } else if constexpr (std::is_same_v
<XT
, rtcmplx::complex<float>>) {
76 // TODO: call BLAS-1 CDOTC
77 } else if constexpr (std::is_same_v
<XT
, rtcmplx::complex<double>>) {
78 // TODO: call BLAS-1 ZDOTC
81 XT
*xp
{x
.OffsetElement
<XT
>(0)};
82 YT
*yp
{y
.OffsetElement
<YT
>(0)};
83 using AccumType
= AccumulationType
<RCAT
, RKIND
>;
85 if constexpr (RCAT
== TypeCategory::Complex
) {
86 for (SubscriptValue j
{0}; j
< n
; ++j
) {
87 // conj() may instantiate its argument twice,
88 // so xp has to be incremented separately.
89 // This is a workaround for an alleged bug in clang,
91 // warning: multiple unsequenced modifications to 'xp'
92 accum
+= rtcmplx::conj(static_cast<AccumType
>(*xp
)) *
93 static_cast<AccumType
>(*yp
++);
97 for (SubscriptValue j
{0}; j
< n
; ++j
) {
99 static_cast<AccumType
>(*xp
++) * static_cast<AccumType
>(*yp
++);
102 return static_cast<Result
>(accum
);
// NOTE(review): the std::is_same_v<XT, YT> test above appears to guard only
// the BLAS TODO placeholders. If so, contiguous mixed-type vectors take the
// pointer loop rather than this path, and "heterogeneous" in the next comment
// applies only to non-contiguous operands. Confirm against the lines not
// visible here.
105 // Non-contiguous, heterogeneous, & LOGICAL cases
106 SubscriptValue xAt
{x
.GetDimension(0).LowerBound()};
107 SubscriptValue yAt
{y
.GetDimension(0).LowerBound()};
108 Accumulator
<RCAT
, RKIND
, XT
, YT
> accumulator
{x
, y
};
109 for (SubscriptValue j
{0}; j
< n
; ++j
) {
110 accumulator
.AccumulateIndexed(xAt
++, yAt
++);
112 return static_cast<Result
>(accumulator
.GetResult());
// DotProduct<RCAT, RKIND> is a function object that computes DOT_PRODUCT
// with result type CppTypeFor<RCAT, RKIND>. It dispatches on the run-time
// category/kind of its two descriptor arguments.
//   DP1<XCAT, XKIND>       - fixes the category/kind of VECTOR_A.
//   DP1::DP2<YCAT, YKIND>  - fixes the category/kind of VECTOR_B and calls
//                            DoDotProduct with the matching C++ types.
115 template <TypeCategory RCAT
, int RKIND
> struct DotProduct
{
116 using Result
= CppTypeFor
<RCAT
, RKIND
>;
117 template <TypeCategory XCAT
, int XKIND
> struct DP1
{
118 template <TypeCategory YCAT
, int YKIND
> struct DP2
{
// Calls DoDotProduct only when GetResultType(XCAT, XKIND, YCAT, YKIND) yields
// category RCAT and a kind no larger than RKIND; any kind is accepted for
// LOGICAL. Otherwise it reports "bad operand types" with the three
// category/kind pairs. The reporting call is not visible in this view.
119 RT_API_ATTRS Result
operator()(const Descriptor
&x
, const Descriptor
&y
,
120 Terminator
&terminator
) const {
121 if constexpr (constexpr auto resultType
{
122 GetResultType(XCAT
, XKIND
, YCAT
, YKIND
)}) {
123 if constexpr (resultType
->first
== RCAT
&&
124 (resultType
->second
<= RKIND
|| RCAT
== TypeCategory::Logical
)) {
125 return DoDotProduct
<RCAT
, RKIND
, CppTypeFor
<XCAT
, XKIND
>,
126 CppTypeFor
<YCAT
, YKIND
>>(x
, y
, terminator
);
130 "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
131 static_cast<int>(RCAT
), RKIND
, static_cast<int>(XCAT
), XKIND
,
132 static_cast<int>(YCAT
), YKIND
);
// Selects DP2 for VECTOR_B's run-time category/kind (yCat, yKind) through
// ApplyType. The trailing (x, y, terminator) arguments are forwarded to
// DP2::operator(). Presumably the leading terminator reports unsupported
// types -- confirm against ApplyType.
135 RT_API_ATTRS Result
operator()(const Descriptor
&x
, const Descriptor
&y
,
136 Terminator
&terminator
, TypeCategory yCat
, int yKind
) const {
137 return ApplyType
<DP2
, Result
>(yCat
, yKind
, terminator
, x
, y
, terminator
);
// Entry point used by the exported DotProduct* functions.
// source/line build the Terminator used for all error reports.
// For a non-LOGICAL result where x and y have the same type, the run-time
// type dispatch is skipped. Otherwise both category/kind pairs must be known
// (RUNTIME_CHECK), and dispatch goes through ApplyType<DP1>, then DP2.
140 RT_API_ATTRS Result
operator()(const Descriptor
&x
, const Descriptor
&y
,
141 const char *source
, int line
) const {
142 Terminator terminator
{source
, line
};
143 if (RCAT
!= TypeCategory::Logical
&& x
.type() == y
.type()) {
144 // No conversions needed, operands and result have same known type
// NOTE(review): only x.type() == y.type() is tested here. That this type is
// also (RCAT, RKIND) relies on the caller choosing the entry point that
// matches the operand type -- confirm.
145 return typename DP1
<RCAT
, RKIND
>::template DP2
<RCAT
, RKIND
>{}(
148 auto xCatKind
{x
.type().GetCategoryAndKind()};
149 auto yCatKind
{y
.type().GetCategoryAndKind()};
150 RUNTIME_CHECK(terminator
, xCatKind
.has_value() && yCatKind
.has_value());
151 return ApplyType
<DP1
, Result
>(xCatKind
->first
, xCatKind
->second
,
152 terminator
, x
, y
, terminator
, yCatKind
->first
, yCatKind
->second
);
// NOTE(review): RT_EXT_API_GROUP_BEGIN presumably opens the group of exported
// runtime entry points. Its matching END is not visible in this view; confirm
// against its definition.
158 RT_EXT_API_GROUP_BEGIN
// DOT_PRODUCT with an INTEGER(KIND=1) result.
// source and line are used to build the Terminator that reports runtime errors.
160 CppTypeFor
<TypeCategory::Integer
, 1> RTDEF(DotProductInteger1
)(
161 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
162 return DotProduct
<TypeCategory::Integer
, 1>{}(x
, y
, source
, line
);
// DOT_PRODUCT with an INTEGER(KIND=2) result.
// source and line are used to build the Terminator that reports runtime errors.
164 CppTypeFor
<TypeCategory::Integer
, 2> RTDEF(DotProductInteger2
)(
165 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
166 return DotProduct
<TypeCategory::Integer
, 2>{}(x
, y
, source
, line
);
// DOT_PRODUCT with an INTEGER(KIND=4) result.
// source and line are used to build the Terminator that reports runtime errors.
168 CppTypeFor
<TypeCategory::Integer
, 4> RTDEF(DotProductInteger4
)(
169 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
170 return DotProduct
<TypeCategory::Integer
, 4>{}(x
, y
, source
, line
);
// DOT_PRODUCT with an INTEGER(KIND=8) result.
// source and line are used to build the Terminator that reports runtime errors.
172 CppTypeFor
<TypeCategory::Integer
, 8> RTDEF(DotProductInteger8
)(
173 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
174 return DotProduct
<TypeCategory::Integer
, 8>{}(x
, y
, source
, line
);
// DOT_PRODUCT with an INTEGER(KIND=16) result. Compiled only when the
// compiler defines __SIZEOF_INT128__.
// source and line are used to build the Terminator that reports runtime errors.
176 #ifdef __SIZEOF_INT128__
177 CppTypeFor
<TypeCategory::Integer
, 16> RTDEF(DotProductInteger16
)(
178 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
179 return DotProduct
<TypeCategory::Integer
, 16>{}(x
, y
, source
, line
);
183 // TODO: REAL/COMPLEX(2 & 3)
184 // Intermediate results and operations are at least 64 bits
// DOT_PRODUCT with a REAL(KIND=4) result.
// source and line are used to build the Terminator that reports runtime errors.
185 CppTypeFor
<TypeCategory::Real
, 4> RTDEF(DotProductReal4
)(
186 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
187 return DotProduct
<TypeCategory::Real
, 4>{}(x
, y
, source
, line
);
// DOT_PRODUCT with a REAL(KIND=8) result.
// source and line are used to build the Terminator that reports runtime errors.
189 CppTypeFor
<TypeCategory::Real
, 8> RTDEF(DotProductReal8
)(
190 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
191 return DotProduct
<TypeCategory::Real
, 8>{}(x
, y
, source
, line
);
// DOT_PRODUCT with a REAL(KIND=10) result.
// source and line are used to build the Terminator that reports runtime errors.
194 CppTypeFor
<TypeCategory::Real
, 10> RTDEF(DotProductReal10
)(
195 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
196 return DotProduct
<TypeCategory::Real
, 10>{}(x
, y
, source
, line
);
// DOT_PRODUCT with a REAL(KIND=16) result. Compiled only when HAS_LDBL128
// or HAS_FLOAT128 is true.
// source and line are used to build the Terminator that reports runtime errors.
199 #if HAS_LDBL128 || HAS_FLOAT128
200 CppTypeFor
<TypeCategory::Real
, 16> RTDEF(DotProductReal16
)(
201 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
202 return DotProduct
<TypeCategory::Real
, 16>{}(x
, y
, source
, line
);
// DOT_PRODUCT with a COMPLEX(KIND=4) result. The value is stored into
// `result` rather than returned. Elements of x are conjugated (see
// DoDotProduct/Accumulator). source and line are used to build the
// Terminator that reports runtime errors.
206 void RTDEF(CppDotProductComplex4
)(CppTypeFor
<TypeCategory::Complex
, 4> &result
,
207 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
208 result
= DotProduct
<TypeCategory::Complex
, 4>{}(x
, y
, source
, line
);
// DOT_PRODUCT with a COMPLEX(KIND=8) result. The value is stored into
// `result` rather than returned. Elements of x are conjugated (see
// DoDotProduct/Accumulator). source and line are used to build the
// Terminator that reports runtime errors.
210 void RTDEF(CppDotProductComplex8
)(CppTypeFor
<TypeCategory::Complex
, 8> &result
,
211 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
212 result
= DotProduct
<TypeCategory::Complex
, 8>{}(x
, y
, source
, line
);
// DOT_PRODUCT with a COMPLEX(KIND=10) result. The value is stored into
// `result` rather than returned. Elements of x are conjugated (see
// DoDotProduct/Accumulator). source and line are used to build the
// Terminator that reports runtime errors.
215 void RTDEF(CppDotProductComplex10
)(
216 CppTypeFor
<TypeCategory::Complex
, 10> &result
, const Descriptor
&x
,
217 const Descriptor
&y
, const char *source
, int line
) {
218 result
= DotProduct
<TypeCategory::Complex
, 10>{}(x
, y
, source
, line
);
// DOT_PRODUCT with a COMPLEX(KIND=16) result. Compiled only when HAS_LDBL128
// or HAS_FLOAT128 is true. The value is stored into `result` rather than
// returned. source and line are used to build the Terminator that reports
// runtime errors.
221 #if HAS_LDBL128 || HAS_FLOAT128
222 void RTDEF(CppDotProductComplex16
)(
223 CppTypeFor
<TypeCategory::Complex
, 16> &result
, const Descriptor
&x
,
224 const Descriptor
&y
, const char *source
, int line
) {
225 result
= DotProduct
<TypeCategory::Complex
, 16>{}(x
, y
, source
, line
);
// DOT_PRODUCT of two LOGICAL vectors, returned as bool. It always
// instantiates DotProduct with LOGICAL kind 1 as the result type. The
// operands' actual LOGICAL kinds are dispatched at run time, and DP2 accepts
// any LOGICAL kind. source and line are used to build the Terminator that
// reports runtime errors.
229 bool RTDEF(DotProductLogical
)(
230 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
231 return DotProduct
<TypeCategory::Logical
, 1>{}(x
, y
, source
, line
);
236 } // namespace Fortran::runtime