[Support] Remove unused includes (NFC) (#116752)
[llvm-project.git] / flang / runtime / dot-product.cpp
blob335e5929f0865e03a6b4074108bb4a1609ba47a1
1 //===-- runtime/dot-product.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "float.h"
10 #include "terminator.h"
11 #include "tools.h"
12 #include "flang/Common/float128.h"
13 #include "flang/Runtime/cpp-type.h"
14 #include "flang/Runtime/descriptor.h"
15 #include "flang/Runtime/reduction.h"
16 #include <cfloat>
17 #include <cinttypes>
19 namespace Fortran::runtime {
21 // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
22 // argument; MATMUL does not.
24 // General accumulator for any type and stride; this is not used for
25 // contiguous numeric vectors.
26 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
27 class Accumulator {
28 public:
29 using Result = AccumulationType<RCAT, RKIND>;
30 RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y)
31 : x_{x}, y_{y} {}
32 RT_API_ATTRS void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) {
33 if constexpr (RCAT == TypeCategory::Logical) {
34 sum_ = sum_ ||
35 (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
36 } else {
37 const XT &xElement{*x_.Element<XT>(&xAt)};
38 const YT &yElement{*y_.Element<YT>(&yAt)};
39 if constexpr (RCAT == TypeCategory::Complex) {
40 sum_ += rtcmplx::conj(static_cast<Result>(xElement)) *
41 static_cast<Result>(yElement);
42 } else {
43 sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement);
47 RT_API_ATTRS Result GetResult() const { return sum_; }
49 private:
50 const Descriptor &x_, &y_;
51 Result sum_{};
54 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
55 static inline RT_API_ATTRS CppTypeFor<RCAT, RKIND> DoDotProduct(
56 const Descriptor &x, const Descriptor &y, Terminator &terminator) {
57 using Result = CppTypeFor<RCAT, RKIND>;
58 RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
59 SubscriptValue n{x.GetDimension(0).Extent()};
60 if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
61 terminator.Crash(
62 "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
63 static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
65 if constexpr (RCAT != TypeCategory::Logical) {
66 if (x.GetDimension(0).ByteStride() == sizeof(XT) &&
67 y.GetDimension(0).ByteStride() == sizeof(YT)) {
68 // Contiguous numeric vectors
69 if constexpr (std::is_same_v<XT, YT>) {
70 // Contiguous homogeneous numeric vectors
71 if constexpr (std::is_same_v<XT, float>) {
72 // TODO: call BLAS-1 SDOT or SDSDOT
73 } else if constexpr (std::is_same_v<XT, double>) {
74 // TODO: call BLAS-1 DDOT
75 } else if constexpr (std::is_same_v<XT, rtcmplx::complex<float>>) {
76 // TODO: call BLAS-1 CDOTC
77 } else if constexpr (std::is_same_v<XT, rtcmplx::complex<double>>) {
78 // TODO: call BLAS-1 ZDOTC
81 XT *xp{x.OffsetElement<XT>(0)};
82 YT *yp{y.OffsetElement<YT>(0)};
83 using AccumType = AccumulationType<RCAT, RKIND>;
84 AccumType accum{};
85 if constexpr (RCAT == TypeCategory::Complex) {
86 for (SubscriptValue j{0}; j < n; ++j) {
87 // conj() may instantiate its argument twice,
88 // so xp has to be incremented separately.
89 // This is a workaround for an alleged bug in clang,
90 // that shows up as:
91 // warning: multiple unsequenced modifications to 'xp'
92 accum += rtcmplx::conj(static_cast<AccumType>(*xp)) *
93 static_cast<AccumType>(*yp++);
94 xp++;
96 } else {
97 for (SubscriptValue j{0}; j < n; ++j) {
98 accum +=
99 static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++);
102 return static_cast<Result>(accum);
105 // Non-contiguous, heterogeneous, & LOGICAL cases
106 SubscriptValue xAt{x.GetDimension(0).LowerBound()};
107 SubscriptValue yAt{y.GetDimension(0).LowerBound()};
108 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
109 for (SubscriptValue j{0}; j < n; ++j) {
110 accumulator.AccumulateIndexed(xAt++, yAt++);
112 return static_cast<Result>(accumulator.GetResult());
115 template <TypeCategory RCAT, int RKIND> struct DotProduct {
116 using Result = CppTypeFor<RCAT, RKIND>;
117 template <TypeCategory XCAT, int XKIND> struct DP1 {
118 template <TypeCategory YCAT, int YKIND> struct DP2 {
119 RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
120 Terminator &terminator) const {
121 if constexpr (constexpr auto resultType{
122 GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
123 if constexpr (resultType->first == RCAT &&
124 (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) {
125 return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>,
126 CppTypeFor<YCAT, YKIND>>(x, y, terminator);
129 terminator.Crash(
130 "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
131 static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND,
132 static_cast<int>(YCAT), YKIND);
135 RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
136 Terminator &terminator, TypeCategory yCat, int yKind) const {
137 return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator);
140 RT_API_ATTRS Result operator()(const Descriptor &x, const Descriptor &y,
141 const char *source, int line) const {
142 Terminator terminator{source, line};
143 if (RCAT != TypeCategory::Logical && x.type() == y.type()) {
144 // No conversions needed, operands and result have same known type
145 return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}(
146 x, y, terminator);
147 } else {
148 auto xCatKind{x.type().GetCategoryAndKind()};
149 auto yCatKind{y.type().GetCategoryAndKind()};
150 RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
151 return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second,
152 terminator, x, y, terminator, yCatKind->first, yCatKind->second);
157 extern "C" {
158 RT_EXT_API_GROUP_BEGIN
160 CppTypeFor<TypeCategory::Integer, 1> RTDEF(DotProductInteger1)(
161 const Descriptor &x, const Descriptor &y, const char *source, int line) {
162 return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line);
164 CppTypeFor<TypeCategory::Integer, 2> RTDEF(DotProductInteger2)(
165 const Descriptor &x, const Descriptor &y, const char *source, int line) {
166 return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line);
168 CppTypeFor<TypeCategory::Integer, 4> RTDEF(DotProductInteger4)(
169 const Descriptor &x, const Descriptor &y, const char *source, int line) {
170 return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line);
172 CppTypeFor<TypeCategory::Integer, 8> RTDEF(DotProductInteger8)(
173 const Descriptor &x, const Descriptor &y, const char *source, int line) {
174 return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
176 #ifdef __SIZEOF_INT128__
177 CppTypeFor<TypeCategory::Integer, 16> RTDEF(DotProductInteger16)(
178 const Descriptor &x, const Descriptor &y, const char *source, int line) {
179 return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
181 #endif
183 // TODO: REAL/COMPLEX(2 & 3)
184 // Intermediate results and operations are at least 64 bits
185 CppTypeFor<TypeCategory::Real, 4> RTDEF(DotProductReal4)(
186 const Descriptor &x, const Descriptor &y, const char *source, int line) {
187 return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line);
189 CppTypeFor<TypeCategory::Real, 8> RTDEF(DotProductReal8)(
190 const Descriptor &x, const Descriptor &y, const char *source, int line) {
191 return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
193 #if HAS_FLOAT80
194 CppTypeFor<TypeCategory::Real, 10> RTDEF(DotProductReal10)(
195 const Descriptor &x, const Descriptor &y, const char *source, int line) {
196 return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line);
198 #endif
199 #if HAS_LDBL128 || HAS_FLOAT128
200 CppTypeFor<TypeCategory::Real, 16> RTDEF(DotProductReal16)(
201 const Descriptor &x, const Descriptor &y, const char *source, int line) {
202 return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line);
204 #endif
206 void RTDEF(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
207 const Descriptor &x, const Descriptor &y, const char *source, int line) {
208 result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line);
210 void RTDEF(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
211 const Descriptor &x, const Descriptor &y, const char *source, int line) {
212 result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
214 #if HAS_FLOAT80
215 void RTDEF(CppDotProductComplex10)(
216 CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x,
217 const Descriptor &y, const char *source, int line) {
218 result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
220 #endif
221 #if HAS_LDBL128 || HAS_FLOAT128
222 void RTDEF(CppDotProductComplex16)(
223 CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x,
224 const Descriptor &y, const char *source, int line) {
225 result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
227 #endif
229 bool RTDEF(DotProductLogical)(
230 const Descriptor &x, const Descriptor &y, const char *source, int line) {
231 return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line);
234 RT_EXT_API_GROUP_END
235 } // extern "C"
236 } // namespace Fortran::runtime