//===-- runtime/dot-product.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "terminator.h"
#include "flang/Runtime/cpp-type.h"
#include "flang/Runtime/descriptor.h"
#include "flang/Runtime/reduction.h"
#include <cfloat>
#include <complex>
#include <cstdint>
#include <type_traits>
18 namespace Fortran::runtime
{
20 // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
21 // argument; MATMUL does not.
23 // General accumulator for any type and stride; this is not used for
24 // contiguous numeric vectors.
25 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
28 using Result
= AccumulationType
<RCAT
, RKIND
>;
29 Accumulator(const Descriptor
&x
, const Descriptor
&y
) : x_
{x
}, y_
{y
} {}
30 void AccumulateIndexed(SubscriptValue xAt
, SubscriptValue yAt
) {
31 if constexpr (RCAT
== TypeCategory::Logical
) {
33 (IsLogicalElementTrue(x_
, &xAt
) && IsLogicalElementTrue(y_
, &yAt
));
35 const XT
&xElement
{*x_
.Element
<XT
>(&xAt
)};
36 const YT
&yElement
{*y_
.Element
<YT
>(&yAt
)};
37 if constexpr (RCAT
== TypeCategory::Complex
) {
38 sum_
+= std::conj(static_cast<Result
>(xElement
)) *
39 static_cast<Result
>(yElement
);
41 sum_
+= static_cast<Result
>(xElement
) * static_cast<Result
>(yElement
);
45 Result
GetResult() const { return sum_
; }
48 const Descriptor
&x_
, &y_
;
52 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
53 static inline CppTypeFor
<RCAT
, RKIND
> DoDotProduct(
54 const Descriptor
&x
, const Descriptor
&y
, Terminator
&terminator
) {
55 using Result
= CppTypeFor
<RCAT
, RKIND
>;
56 RUNTIME_CHECK(terminator
, x
.rank() == 1 && y
.rank() == 1);
57 SubscriptValue n
{x
.GetDimension(0).Extent()};
58 if (SubscriptValue yN
{y
.GetDimension(0).Extent()}; yN
!= n
) {
60 "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
61 static_cast<std::intmax_t>(n
), static_cast<std::intmax_t>(yN
));
63 if constexpr (RCAT
!= TypeCategory::Logical
) {
64 if (x
.GetDimension(0).ByteStride() == sizeof(XT
) &&
65 y
.GetDimension(0).ByteStride() == sizeof(YT
)) {
66 // Contiguous numeric vectors
67 if constexpr (std::is_same_v
<XT
, YT
>) {
68 // Contiguous homogeneous numeric vectors
69 if constexpr (std::is_same_v
<XT
, float>) {
70 // TODO: call BLAS-1 SDOT or SDSDOT
71 } else if constexpr (std::is_same_v
<XT
, double>) {
72 // TODO: call BLAS-1 DDOT
73 } else if constexpr (std::is_same_v
<XT
, std::complex<float>>) {
74 // TODO: call BLAS-1 CDOTC
75 } else if constexpr (std::is_same_v
<XT
, std::complex<double>>) {
76 // TODO: call BLAS-1 ZDOTC
79 XT
*xp
{x
.OffsetElement
<XT
>(0)};
80 YT
*yp
{y
.OffsetElement
<YT
>(0)};
81 using AccumType
= AccumulationType
<RCAT
, RKIND
>;
83 if constexpr (RCAT
== TypeCategory::Complex
) {
84 for (SubscriptValue j
{0}; j
< n
; ++j
) {
85 accum
+= std::conj(static_cast<AccumType
>(*xp
++)) *
86 static_cast<AccumType
>(*yp
++);
89 for (SubscriptValue j
{0}; j
< n
; ++j
) {
91 static_cast<AccumType
>(*xp
++) * static_cast<AccumType
>(*yp
++);
94 return static_cast<Result
>(accum
);
97 // Non-contiguous, heterogeneous, & LOGICAL cases
98 SubscriptValue xAt
{x
.GetDimension(0).LowerBound()};
99 SubscriptValue yAt
{y
.GetDimension(0).LowerBound()};
100 Accumulator
<RCAT
, RKIND
, XT
, YT
> accumulator
{x
, y
};
101 for (SubscriptValue j
{0}; j
< n
; ++j
) {
102 accumulator
.AccumulateIndexed(xAt
++, yAt
++);
104 return static_cast<Result
>(accumulator
.GetResult());
107 template <TypeCategory RCAT
, int RKIND
> struct DotProduct
{
108 using Result
= CppTypeFor
<RCAT
, RKIND
>;
109 template <TypeCategory XCAT
, int XKIND
> struct DP1
{
110 template <TypeCategory YCAT
, int YKIND
> struct DP2
{
111 Result
operator()(const Descriptor
&x
, const Descriptor
&y
,
112 Terminator
&terminator
) const {
113 if constexpr (constexpr auto resultType
{
114 GetResultType(XCAT
, XKIND
, YCAT
, YKIND
)}) {
115 if constexpr (resultType
->first
== RCAT
&&
116 (resultType
->second
<= RKIND
|| RCAT
== TypeCategory::Logical
)) {
117 return DoDotProduct
<RCAT
, RKIND
, CppTypeFor
<XCAT
, XKIND
>,
118 CppTypeFor
<YCAT
, YKIND
>>(x
, y
, terminator
);
122 "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
123 static_cast<int>(RCAT
), RKIND
, static_cast<int>(XCAT
), XKIND
,
124 static_cast<int>(YCAT
), YKIND
);
127 Result
operator()(const Descriptor
&x
, const Descriptor
&y
,
128 Terminator
&terminator
, TypeCategory yCat
, int yKind
) const {
129 return ApplyType
<DP2
, Result
>(yCat
, yKind
, terminator
, x
, y
, terminator
);
132 Result
operator()(const Descriptor
&x
, const Descriptor
&y
,
133 const char *source
, int line
) const {
134 Terminator terminator
{source
, line
};
135 if (RCAT
!= TypeCategory::Logical
&& x
.type() == y
.type()) {
136 // No conversions needed, operands and result have same known type
137 return typename DP1
<RCAT
, RKIND
>::template DP2
<RCAT
, RKIND
>{}(
140 auto xCatKind
{x
.type().GetCategoryAndKind()};
141 auto yCatKind
{y
.type().GetCategoryAndKind()};
142 RUNTIME_CHECK(terminator
, xCatKind
.has_value() && yCatKind
.has_value());
143 return ApplyType
<DP1
, Result
>(xCatKind
->first
, xCatKind
->second
,
144 terminator
, x
, y
, terminator
, yCatKind
->first
, yCatKind
->second
);
150 CppTypeFor
<TypeCategory::Integer
, 1> RTNAME(DotProductInteger1
)(
151 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
152 return DotProduct
<TypeCategory::Integer
, 1>{}(x
, y
, source
, line
);
154 CppTypeFor
<TypeCategory::Integer
, 2> RTNAME(DotProductInteger2
)(
155 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
156 return DotProduct
<TypeCategory::Integer
, 2>{}(x
, y
, source
, line
);
158 CppTypeFor
<TypeCategory::Integer
, 4> RTNAME(DotProductInteger4
)(
159 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
160 return DotProduct
<TypeCategory::Integer
, 4>{}(x
, y
, source
, line
);
162 CppTypeFor
<TypeCategory::Integer
, 8> RTNAME(DotProductInteger8
)(
163 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
164 return DotProduct
<TypeCategory::Integer
, 8>{}(x
, y
, source
, line
);
166 #ifdef __SIZEOF_INT128__
167 CppTypeFor
<TypeCategory::Integer
, 16> RTNAME(DotProductInteger16
)(
168 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
169 return DotProduct
<TypeCategory::Integer
, 16>{}(x
, y
, source
, line
);
173 // TODO: REAL/COMPLEX(2 & 3)
174 // Intermediate results and operations are at least 64 bits
175 CppTypeFor
<TypeCategory::Real
, 4> RTNAME(DotProductReal4
)(
176 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
177 return DotProduct
<TypeCategory::Real
, 4>{}(x
, y
, source
, line
);
179 CppTypeFor
<TypeCategory::Real
, 8> RTNAME(DotProductReal8
)(
180 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
181 return DotProduct
<TypeCategory::Real
, 8>{}(x
, y
, source
, line
);
183 #if LDBL_MANT_DIG == 64
184 CppTypeFor
<TypeCategory::Real
, 10> RTNAME(DotProductReal10
)(
185 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
186 return DotProduct
<TypeCategory::Real
, 10>{}(x
, y
, source
, line
);
189 #if LDBL_MANT_DIG == 113 || HAS_FLOAT128
190 CppTypeFor
<TypeCategory::Real
, 16> RTNAME(DotProductReal16
)(
191 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
192 return DotProduct
<TypeCategory::Real
, 16>{}(x
, y
, source
, line
);
196 void RTNAME(CppDotProductComplex4
)(CppTypeFor
<TypeCategory::Complex
, 4> &result
,
197 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
198 result
= DotProduct
<TypeCategory::Complex
, 4>{}(x
, y
, source
, line
);
200 void RTNAME(CppDotProductComplex8
)(CppTypeFor
<TypeCategory::Complex
, 8> &result
,
201 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
202 result
= DotProduct
<TypeCategory::Complex
, 8>{}(x
, y
, source
, line
);
204 #if LDBL_MANT_DIG == 64
205 void RTNAME(CppDotProductComplex10
)(
206 CppTypeFor
<TypeCategory::Complex
, 10> &result
, const Descriptor
&x
,
207 const Descriptor
&y
, const char *source
, int line
) {
208 result
= DotProduct
<TypeCategory::Complex
, 10>{}(x
, y
, source
, line
);
211 #if LDBL_MANT_DIG == 113 || HAS_FLOAT128
212 void RTNAME(CppDotProductComplex16
)(
213 CppTypeFor
<TypeCategory::Complex
, 16> &result
, const Descriptor
&x
,
214 const Descriptor
&y
, const char *source
, int line
) {
215 result
= DotProduct
<TypeCategory::Complex
, 16>{}(x
, y
, source
, line
);
219 bool RTNAME(DotProductLogical
)(
220 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
221 return DotProduct
<TypeCategory::Logical
, 1>{}(x
, y
, source
, line
);
224 } // namespace Fortran::runtime