1 //===-- runtime/dot-product.cpp -------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 #include "terminator.h"
12 #include "flang/Common/float128.h"
13 #include "flang/Runtime/cpp-type.h"
14 #include "flang/Runtime/descriptor.h"
15 #include "flang/Runtime/reduction.h"
19 namespace Fortran::runtime
{
// Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
// argument; MATMUL does not.
24 // General accumulator for any type and stride; this is not used for
25 // contiguous numeric vectors.
26 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
29 using Result
= AccumulationType
<RCAT
, RKIND
>;
30 Accumulator(const Descriptor
&x
, const Descriptor
&y
) : x_
{x
}, y_
{y
} {}
31 void AccumulateIndexed(SubscriptValue xAt
, SubscriptValue yAt
) {
32 if constexpr (RCAT
== TypeCategory::Logical
) {
34 (IsLogicalElementTrue(x_
, &xAt
) && IsLogicalElementTrue(y_
, &yAt
));
36 const XT
&xElement
{*x_
.Element
<XT
>(&xAt
)};
37 const YT
&yElement
{*y_
.Element
<YT
>(&yAt
)};
38 if constexpr (RCAT
== TypeCategory::Complex
) {
39 sum_
+= std::conj(static_cast<Result
>(xElement
)) *
40 static_cast<Result
>(yElement
);
42 sum_
+= static_cast<Result
>(xElement
) * static_cast<Result
>(yElement
);
46 Result
GetResult() const { return sum_
; }
49 const Descriptor
&x_
, &y_
;
53 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
54 static inline CppTypeFor
<RCAT
, RKIND
> DoDotProduct(
55 const Descriptor
&x
, const Descriptor
&y
, Terminator
&terminator
) {
56 using Result
= CppTypeFor
<RCAT
, RKIND
>;
57 RUNTIME_CHECK(terminator
, x
.rank() == 1 && y
.rank() == 1);
58 SubscriptValue n
{x
.GetDimension(0).Extent()};
59 if (SubscriptValue yN
{y
.GetDimension(0).Extent()}; yN
!= n
) {
61 "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
62 static_cast<std::intmax_t>(n
), static_cast<std::intmax_t>(yN
));
64 if constexpr (RCAT
!= TypeCategory::Logical
) {
65 if (x
.GetDimension(0).ByteStride() == sizeof(XT
) &&
66 y
.GetDimension(0).ByteStride() == sizeof(YT
)) {
67 // Contiguous numeric vectors
68 if constexpr (std::is_same_v
<XT
, YT
>) {
69 // Contiguous homogeneous numeric vectors
70 if constexpr (std::is_same_v
<XT
, float>) {
71 // TODO: call BLAS-1 SDOT or SDSDOT
72 } else if constexpr (std::is_same_v
<XT
, double>) {
73 // TODO: call BLAS-1 DDOT
74 } else if constexpr (std::is_same_v
<XT
, std::complex<float>>) {
75 // TODO: call BLAS-1 CDOTC
76 } else if constexpr (std::is_same_v
<XT
, std::complex<double>>) {
77 // TODO: call BLAS-1 ZDOTC
80 XT
*xp
{x
.OffsetElement
<XT
>(0)};
81 YT
*yp
{y
.OffsetElement
<YT
>(0)};
82 using AccumType
= AccumulationType
<RCAT
, RKIND
>;
84 if constexpr (RCAT
== TypeCategory::Complex
) {
85 for (SubscriptValue j
{0}; j
< n
; ++j
) {
86 accum
+= std::conj(static_cast<AccumType
>(*xp
++)) *
87 static_cast<AccumType
>(*yp
++);
90 for (SubscriptValue j
{0}; j
< n
; ++j
) {
92 static_cast<AccumType
>(*xp
++) * static_cast<AccumType
>(*yp
++);
95 return static_cast<Result
>(accum
);
98 // Non-contiguous, heterogeneous, & LOGICAL cases
99 SubscriptValue xAt
{x
.GetDimension(0).LowerBound()};
100 SubscriptValue yAt
{y
.GetDimension(0).LowerBound()};
101 Accumulator
<RCAT
, RKIND
, XT
, YT
> accumulator
{x
, y
};
102 for (SubscriptValue j
{0}; j
< n
; ++j
) {
103 accumulator
.AccumulateIndexed(xAt
++, yAt
++);
105 return static_cast<Result
>(accumulator
.GetResult());
108 template <TypeCategory RCAT
, int RKIND
> struct DotProduct
{
109 using Result
= CppTypeFor
<RCAT
, RKIND
>;
110 template <TypeCategory XCAT
, int XKIND
> struct DP1
{
111 template <TypeCategory YCAT
, int YKIND
> struct DP2
{
112 Result
operator()(const Descriptor
&x
, const Descriptor
&y
,
113 Terminator
&terminator
) const {
114 if constexpr (constexpr auto resultType
{
115 GetResultType(XCAT
, XKIND
, YCAT
, YKIND
)}) {
116 if constexpr (resultType
->first
== RCAT
&&
117 (resultType
->second
<= RKIND
|| RCAT
== TypeCategory::Logical
)) {
118 return DoDotProduct
<RCAT
, RKIND
, CppTypeFor
<XCAT
, XKIND
>,
119 CppTypeFor
<YCAT
, YKIND
>>(x
, y
, terminator
);
123 "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
124 static_cast<int>(RCAT
), RKIND
, static_cast<int>(XCAT
), XKIND
,
125 static_cast<int>(YCAT
), YKIND
);
128 Result
operator()(const Descriptor
&x
, const Descriptor
&y
,
129 Terminator
&terminator
, TypeCategory yCat
, int yKind
) const {
130 return ApplyType
<DP2
, Result
>(yCat
, yKind
, terminator
, x
, y
, terminator
);
133 Result
operator()(const Descriptor
&x
, const Descriptor
&y
,
134 const char *source
, int line
) const {
135 Terminator terminator
{source
, line
};
136 if (RCAT
!= TypeCategory::Logical
&& x
.type() == y
.type()) {
137 // No conversions needed, operands and result have same known type
138 return typename DP1
<RCAT
, RKIND
>::template DP2
<RCAT
, RKIND
>{}(
141 auto xCatKind
{x
.type().GetCategoryAndKind()};
142 auto yCatKind
{y
.type().GetCategoryAndKind()};
143 RUNTIME_CHECK(terminator
, xCatKind
.has_value() && yCatKind
.has_value());
144 return ApplyType
<DP1
, Result
>(xCatKind
->first
, xCatKind
->second
,
145 terminator
, x
, y
, terminator
, yCatKind
->first
, yCatKind
->second
);
151 CppTypeFor
<TypeCategory::Integer
, 1> RTNAME(DotProductInteger1
)(
152 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
153 return DotProduct
<TypeCategory::Integer
, 1>{}(x
, y
, source
, line
);
155 CppTypeFor
<TypeCategory::Integer
, 2> RTNAME(DotProductInteger2
)(
156 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
157 return DotProduct
<TypeCategory::Integer
, 2>{}(x
, y
, source
, line
);
159 CppTypeFor
<TypeCategory::Integer
, 4> RTNAME(DotProductInteger4
)(
160 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
161 return DotProduct
<TypeCategory::Integer
, 4>{}(x
, y
, source
, line
);
163 CppTypeFor
<TypeCategory::Integer
, 8> RTNAME(DotProductInteger8
)(
164 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
165 return DotProduct
<TypeCategory::Integer
, 8>{}(x
, y
, source
, line
);
167 #ifdef __SIZEOF_INT128__
168 CppTypeFor
<TypeCategory::Integer
, 16> RTNAME(DotProductInteger16
)(
169 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
170 return DotProduct
<TypeCategory::Integer
, 16>{}(x
, y
, source
, line
);
// TODO: REAL/COMPLEX(2 & 3)
// Intermediate results and operations are at least 64 bits
176 CppTypeFor
<TypeCategory::Real
, 4> RTNAME(DotProductReal4
)(
177 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
178 return DotProduct
<TypeCategory::Real
, 4>{}(x
, y
, source
, line
);
180 CppTypeFor
<TypeCategory::Real
, 8> RTNAME(DotProductReal8
)(
181 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
182 return DotProduct
<TypeCategory::Real
, 8>{}(x
, y
, source
, line
);
184 #if LDBL_MANT_DIG == 64
185 CppTypeFor
<TypeCategory::Real
, 10> RTNAME(DotProductReal10
)(
186 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
187 return DotProduct
<TypeCategory::Real
, 10>{}(x
, y
, source
, line
);
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
// DOT_PRODUCT entry point with a REAL(16) result; compiled only when
// long double has a 113-bit significand or HAS_FLOAT128 is set.  source and
// line are forwarded to DotProduct, which builds its Terminator from them.
CppTypeFor<TypeCategory::Real, 16> RTNAME(DotProductReal16)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line);
}
#endif
197 void RTNAME(CppDotProductComplex4
)(CppTypeFor
<TypeCategory::Complex
, 4> &result
,
198 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
199 result
= DotProduct
<TypeCategory::Complex
, 4>{}(x
, y
, source
, line
);
201 void RTNAME(CppDotProductComplex8
)(CppTypeFor
<TypeCategory::Complex
, 8> &result
,
202 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
203 result
= DotProduct
<TypeCategory::Complex
, 8>{}(x
, y
, source
, line
);
205 #if LDBL_MANT_DIG == 64
206 void RTNAME(CppDotProductComplex10
)(
207 CppTypeFor
<TypeCategory::Complex
, 10> &result
, const Descriptor
&x
,
208 const Descriptor
&y
, const char *source
, int line
) {
209 result
= DotProduct
<TypeCategory::Complex
, 10>{}(x
, y
, source
, line
);
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
// DOT_PRODUCT entry point for COMPLEX(16); compiled only when long double
// has a 113-bit significand or HAS_FLOAT128 is set.  The value is stored
// through `result` instead of being returned.  source and line are forwarded
// to DotProduct, which builds its Terminator from them.
void RTNAME(CppDotProductComplex16)(
    CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x,
    const Descriptor &y, const char *source, int line) {
  result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
}
#endif
220 bool RTNAME(DotProductLogical
)(
221 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
222 return DotProduct
<TypeCategory::Logical
, 1>{}(x
, y
, source
, line
);
225 } // namespace Fortran::runtime