1 //===-- runtime/dot-product.cpp -------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 #include "terminator.h"
12 #include "flang/Common/float128.h"
13 #include "flang/Runtime/cpp-type.h"
14 #include "flang/Runtime/descriptor.h"
15 #include "flang/Runtime/reduction.h"
19 namespace Fortran::runtime
{
21 // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
22 // argument; MATMUL does not.
24 // Suppress the warnings about calling __host__-only std::complex operators,
25 // defined in C++ STD header files, from __device__ code.
27 RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
29 // General accumulator for any type and stride; this is not used for
30 // contiguous numeric vectors.
31 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
34 using Result
= AccumulationType
<RCAT
, RKIND
>;
35 RT_API_ATTRS
Accumulator(const Descriptor
&x
, const Descriptor
&y
)
37 RT_API_ATTRS
void AccumulateIndexed(SubscriptValue xAt
, SubscriptValue yAt
) {
38 if constexpr (RCAT
== TypeCategory::Logical
) {
40 (IsLogicalElementTrue(x_
, &xAt
) && IsLogicalElementTrue(y_
, &yAt
));
42 const XT
&xElement
{*x_
.Element
<XT
>(&xAt
)};
43 const YT
&yElement
{*y_
.Element
<YT
>(&yAt
)};
44 if constexpr (RCAT
== TypeCategory::Complex
) {
45 sum_
+= std::conj(static_cast<Result
>(xElement
)) *
46 static_cast<Result
>(yElement
);
48 sum_
+= static_cast<Result
>(xElement
) * static_cast<Result
>(yElement
);
52 RT_API_ATTRS Result
GetResult() const { return sum_
; }
55 const Descriptor
&x_
, &y_
;
59 template <TypeCategory RCAT
, int RKIND
, typename XT
, typename YT
>
60 static inline RT_API_ATTRS CppTypeFor
<RCAT
, RKIND
> DoDotProduct(
61 const Descriptor
&x
, const Descriptor
&y
, Terminator
&terminator
) {
62 using Result
= CppTypeFor
<RCAT
, RKIND
>;
63 RUNTIME_CHECK(terminator
, x
.rank() == 1 && y
.rank() == 1);
64 SubscriptValue n
{x
.GetDimension(0).Extent()};
65 if (SubscriptValue yN
{y
.GetDimension(0).Extent()}; yN
!= n
) {
67 "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
68 static_cast<std::intmax_t>(n
), static_cast<std::intmax_t>(yN
));
70 if constexpr (RCAT
!= TypeCategory::Logical
) {
71 if (x
.GetDimension(0).ByteStride() == sizeof(XT
) &&
72 y
.GetDimension(0).ByteStride() == sizeof(YT
)) {
73 // Contiguous numeric vectors
74 if constexpr (std::is_same_v
<XT
, YT
>) {
75 // Contiguous homogeneous numeric vectors
76 if constexpr (std::is_same_v
<XT
, float>) {
77 // TODO: call BLAS-1 SDOT or SDSDOT
78 } else if constexpr (std::is_same_v
<XT
, double>) {
79 // TODO: call BLAS-1 DDOT
80 } else if constexpr (std::is_same_v
<XT
, std::complex<float>>) {
81 // TODO: call BLAS-1 CDOTC
82 } else if constexpr (std::is_same_v
<XT
, std::complex<double>>) {
83 // TODO: call BLAS-1 ZDOTC
86 XT
*xp
{x
.OffsetElement
<XT
>(0)};
87 YT
*yp
{y
.OffsetElement
<YT
>(0)};
88 using AccumType
= AccumulationType
<RCAT
, RKIND
>;
90 if constexpr (RCAT
== TypeCategory::Complex
) {
91 for (SubscriptValue j
{0}; j
< n
; ++j
) {
92 // std::conj() may instantiate its argument twice,
93 // so xp has to be incremented separately.
94 // This is a workaround for an alleged bug in clang,
96 // warning: multiple unsequenced modifications to 'xp'
97 accum
+= std::conj(static_cast<AccumType
>(*xp
)) *
98 static_cast<AccumType
>(*yp
++);
102 for (SubscriptValue j
{0}; j
< n
; ++j
) {
104 static_cast<AccumType
>(*xp
++) * static_cast<AccumType
>(*yp
++);
107 return static_cast<Result
>(accum
);
110 // Non-contiguous, heterogeneous, & LOGICAL cases
111 SubscriptValue xAt
{x
.GetDimension(0).LowerBound()};
112 SubscriptValue yAt
{y
.GetDimension(0).LowerBound()};
113 Accumulator
<RCAT
, RKIND
, XT
, YT
> accumulator
{x
, y
};
114 for (SubscriptValue j
{0}; j
< n
; ++j
) {
115 accumulator
.AccumulateIndexed(xAt
++, yAt
++);
117 return static_cast<Result
>(accumulator
.GetResult());
122 template <TypeCategory RCAT
, int RKIND
> struct DotProduct
{
123 using Result
= CppTypeFor
<RCAT
, RKIND
>;
124 template <TypeCategory XCAT
, int XKIND
> struct DP1
{
125 template <TypeCategory YCAT
, int YKIND
> struct DP2
{
126 RT_API_ATTRS Result
operator()(const Descriptor
&x
, const Descriptor
&y
,
127 Terminator
&terminator
) const {
128 if constexpr (constexpr auto resultType
{
129 GetResultType(XCAT
, XKIND
, YCAT
, YKIND
)}) {
130 if constexpr (resultType
->first
== RCAT
&&
131 (resultType
->second
<= RKIND
|| RCAT
== TypeCategory::Logical
)) {
132 return DoDotProduct
<RCAT
, RKIND
, CppTypeFor
<XCAT
, XKIND
>,
133 CppTypeFor
<YCAT
, YKIND
>>(x
, y
, terminator
);
137 "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
138 static_cast<int>(RCAT
), RKIND
, static_cast<int>(XCAT
), XKIND
,
139 static_cast<int>(YCAT
), YKIND
);
142 RT_API_ATTRS Result
operator()(const Descriptor
&x
, const Descriptor
&y
,
143 Terminator
&terminator
, TypeCategory yCat
, int yKind
) const {
144 return ApplyType
<DP2
, Result
>(yCat
, yKind
, terminator
, x
, y
, terminator
);
147 RT_API_ATTRS Result
operator()(const Descriptor
&x
, const Descriptor
&y
,
148 const char *source
, int line
) const {
149 Terminator terminator
{source
, line
};
150 if (RCAT
!= TypeCategory::Logical
&& x
.type() == y
.type()) {
151 // No conversions needed, operands and result have same known type
152 return typename DP1
<RCAT
, RKIND
>::template DP2
<RCAT
, RKIND
>{}(
155 auto xCatKind
{x
.type().GetCategoryAndKind()};
156 auto yCatKind
{y
.type().GetCategoryAndKind()};
157 RUNTIME_CHECK(terminator
, xCatKind
.has_value() && yCatKind
.has_value());
158 return ApplyType
<DP1
, Result
>(xCatKind
->first
, xCatKind
->second
,
159 terminator
, x
, y
, terminator
, yCatKind
->first
, yCatKind
->second
);
165 RT_EXT_API_GROUP_BEGIN
167 CppTypeFor
<TypeCategory::Integer
, 1> RTDEF(DotProductInteger1
)(
168 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
169 return DotProduct
<TypeCategory::Integer
, 1>{}(x
, y
, source
, line
);
171 CppTypeFor
<TypeCategory::Integer
, 2> RTDEF(DotProductInteger2
)(
172 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
173 return DotProduct
<TypeCategory::Integer
, 2>{}(x
, y
, source
, line
);
175 CppTypeFor
<TypeCategory::Integer
, 4> RTDEF(DotProductInteger4
)(
176 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
177 return DotProduct
<TypeCategory::Integer
, 4>{}(x
, y
, source
, line
);
179 CppTypeFor
<TypeCategory::Integer
, 8> RTDEF(DotProductInteger8
)(
180 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
181 return DotProduct
<TypeCategory::Integer
, 8>{}(x
, y
, source
, line
);
183 #ifdef __SIZEOF_INT128__
184 CppTypeFor
<TypeCategory::Integer
, 16> RTDEF(DotProductInteger16
)(
185 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
186 return DotProduct
<TypeCategory::Integer
, 16>{}(x
, y
, source
, line
);
190 // TODO: REAL/COMPLEX(2 & 3)
191 // Intermediate results and operations are at least 64 bits
192 CppTypeFor
<TypeCategory::Real
, 4> RTDEF(DotProductReal4
)(
193 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
194 return DotProduct
<TypeCategory::Real
, 4>{}(x
, y
, source
, line
);
196 CppTypeFor
<TypeCategory::Real
, 8> RTDEF(DotProductReal8
)(
197 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
198 return DotProduct
<TypeCategory::Real
, 8>{}(x
, y
, source
, line
);
#if LDBL_MANT_DIG == 64
// DOT_PRODUCT entry point with a REAL(10) result; only compiled when
// long double has a 64-bit significand.
CppTypeFor<TypeCategory::Real, 10> RTDEF(DotProductReal10)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Dispatcher = DotProduct<TypeCategory::Real, 10>;
  return Dispatcher{}(x, y, source, line);
}
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
// DOT_PRODUCT entry point with a REAL(16) result; only compiled when
// long double has a 113-bit significand or HAS_FLOAT128 is set.
CppTypeFor<TypeCategory::Real, 16> RTDEF(DotProductReal16)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Dispatcher = DotProduct<TypeCategory::Real, 16>;
  return Dispatcher{}(x, y, source, line);
}
#endif
213 void RTDEF(CppDotProductComplex4
)(CppTypeFor
<TypeCategory::Complex
, 4> &result
,
214 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
215 result
= DotProduct
<TypeCategory::Complex
, 4>{}(x
, y
, source
, line
);
217 void RTDEF(CppDotProductComplex8
)(CppTypeFor
<TypeCategory::Complex
, 8> &result
,
218 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
219 result
= DotProduct
<TypeCategory::Complex
, 8>{}(x
, y
, source
, line
);
#if LDBL_MANT_DIG == 64
// DOT_PRODUCT entry point with a COMPLEX(10) result, which is stored into
// `result`; only compiled when long double has a 64-bit significand.
void RTDEF(CppDotProductComplex10)(
    CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x,
    const Descriptor &y, const char *source, int line) {
  using Dispatcher = DotProduct<TypeCategory::Complex, 10>;
  result = Dispatcher{}(x, y, source, line);
}
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
// DOT_PRODUCT entry point with a COMPLEX(16) result, which is stored into
// `result`; only compiled when long double has a 113-bit significand or
// HAS_FLOAT128 is set.
void RTDEF(CppDotProductComplex16)(
    CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x,
    const Descriptor &y, const char *source, int line) {
  using Dispatcher = DotProduct<TypeCategory::Complex, 16>;
  result = Dispatcher{}(x, y, source, line);
}
#endif
236 bool RTDEF(DotProductLogical
)(
237 const Descriptor
&x
, const Descriptor
&y
, const char *source
, int line
) {
238 return DotProduct
<TypeCategory::Logical
, 1>{}(x
, y
, source
, line
);
243 } // namespace Fortran::runtime