[flang] Accept polymorphic component element in storage_size
[llvm-project.git] / flang / runtime / dot-product.cpp
blob857ed6759817aaadf9e276eab9d8a8afe9b2c8e8
1 //===-- runtime/dot-product.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "float.h"
10 #include "terminator.h"
11 #include "tools.h"
12 #include "flang/Runtime/cpp-type.h"
13 #include "flang/Runtime/descriptor.h"
14 #include "flang/Runtime/reduction.h"
15 #include <cfloat>
16 #include <cinttypes>
18 namespace Fortran::runtime {
20 // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
21 // argument; MATMUL does not.
23 // General accumulator for any type and stride; this is not used for
24 // contiguous numeric vectors.
25 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
26 class Accumulator {
27 public:
28 using Result = AccumulationType<RCAT, RKIND>;
29 Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
30 void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) {
31 if constexpr (RCAT == TypeCategory::Logical) {
32 sum_ = sum_ ||
33 (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
34 } else {
35 const XT &xElement{*x_.Element<XT>(&xAt)};
36 const YT &yElement{*y_.Element<YT>(&yAt)};
37 if constexpr (RCAT == TypeCategory::Complex) {
38 sum_ += std::conj(static_cast<Result>(xElement)) *
39 static_cast<Result>(yElement);
40 } else {
41 sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement);
45 Result GetResult() const { return sum_; }
47 private:
48 const Descriptor &x_, &y_;
49 Result sum_{};
52 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
53 static inline CppTypeFor<RCAT, RKIND> DoDotProduct(
54 const Descriptor &x, const Descriptor &y, Terminator &terminator) {
55 using Result = CppTypeFor<RCAT, RKIND>;
56 RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
57 SubscriptValue n{x.GetDimension(0).Extent()};
58 if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
59 terminator.Crash(
60 "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
61 static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
63 if constexpr (RCAT != TypeCategory::Logical) {
64 if (x.GetDimension(0).ByteStride() == sizeof(XT) &&
65 y.GetDimension(0).ByteStride() == sizeof(YT)) {
66 // Contiguous numeric vectors
67 if constexpr (std::is_same_v<XT, YT>) {
68 // Contiguous homogeneous numeric vectors
69 if constexpr (std::is_same_v<XT, float>) {
70 // TODO: call BLAS-1 SDOT or SDSDOT
71 } else if constexpr (std::is_same_v<XT, double>) {
72 // TODO: call BLAS-1 DDOT
73 } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
74 // TODO: call BLAS-1 CDOTC
75 } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
76 // TODO: call BLAS-1 ZDOTC
79 XT *xp{x.OffsetElement<XT>(0)};
80 YT *yp{y.OffsetElement<YT>(0)};
81 using AccumType = AccumulationType<RCAT, RKIND>;
82 AccumType accum{};
83 if constexpr (RCAT == TypeCategory::Complex) {
84 for (SubscriptValue j{0}; j < n; ++j) {
85 accum += std::conj(static_cast<AccumType>(*xp++)) *
86 static_cast<AccumType>(*yp++);
88 } else {
89 for (SubscriptValue j{0}; j < n; ++j) {
90 accum +=
91 static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++);
94 return static_cast<Result>(accum);
97 // Non-contiguous, heterogeneous, & LOGICAL cases
98 SubscriptValue xAt{x.GetDimension(0).LowerBound()};
99 SubscriptValue yAt{y.GetDimension(0).LowerBound()};
100 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
101 for (SubscriptValue j{0}; j < n; ++j) {
102 accumulator.AccumulateIndexed(xAt++, yAt++);
104 return static_cast<Result>(accumulator.GetResult());
107 template <TypeCategory RCAT, int RKIND> struct DotProduct {
108 using Result = CppTypeFor<RCAT, RKIND>;
109 template <TypeCategory XCAT, int XKIND> struct DP1 {
110 template <TypeCategory YCAT, int YKIND> struct DP2 {
111 Result operator()(const Descriptor &x, const Descriptor &y,
112 Terminator &terminator) const {
113 if constexpr (constexpr auto resultType{
114 GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
115 if constexpr (resultType->first == RCAT &&
116 (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) {
117 return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>,
118 CppTypeFor<YCAT, YKIND>>(x, y, terminator);
121 terminator.Crash(
122 "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
123 static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND,
124 static_cast<int>(YCAT), YKIND);
127 Result operator()(const Descriptor &x, const Descriptor &y,
128 Terminator &terminator, TypeCategory yCat, int yKind) const {
129 return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator);
132 Result operator()(const Descriptor &x, const Descriptor &y,
133 const char *source, int line) const {
134 Terminator terminator{source, line};
135 if (RCAT != TypeCategory::Logical && x.type() == y.type()) {
136 // No conversions needed, operands and result have same known type
137 return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}(
138 x, y, terminator);
139 } else {
140 auto xCatKind{x.type().GetCategoryAndKind()};
141 auto yCatKind{y.type().GetCategoryAndKind()};
142 RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
143 return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second,
144 terminator, x, y, terminator, yCatKind->first, yCatKind->second);
149 extern "C" {
150 CppTypeFor<TypeCategory::Integer, 1> RTNAME(DotProductInteger1)(
151 const Descriptor &x, const Descriptor &y, const char *source, int line) {
152 return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line);
154 CppTypeFor<TypeCategory::Integer, 2> RTNAME(DotProductInteger2)(
155 const Descriptor &x, const Descriptor &y, const char *source, int line) {
156 return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line);
158 CppTypeFor<TypeCategory::Integer, 4> RTNAME(DotProductInteger4)(
159 const Descriptor &x, const Descriptor &y, const char *source, int line) {
160 return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line);
162 CppTypeFor<TypeCategory::Integer, 8> RTNAME(DotProductInteger8)(
163 const Descriptor &x, const Descriptor &y, const char *source, int line) {
164 return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
166 #ifdef __SIZEOF_INT128__
167 CppTypeFor<TypeCategory::Integer, 16> RTNAME(DotProductInteger16)(
168 const Descriptor &x, const Descriptor &y, const char *source, int line) {
169 return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
171 #endif
173 // TODO: REAL/COMPLEX(2 & 3)
174 // Intermediate results and operations are at least 64 bits
175 CppTypeFor<TypeCategory::Real, 4> RTNAME(DotProductReal4)(
176 const Descriptor &x, const Descriptor &y, const char *source, int line) {
177 return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line);
179 CppTypeFor<TypeCategory::Real, 8> RTNAME(DotProductReal8)(
180 const Descriptor &x, const Descriptor &y, const char *source, int line) {
181 return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
183 #if LDBL_MANT_DIG == 64
184 CppTypeFor<TypeCategory::Real, 10> RTNAME(DotProductReal10)(
185 const Descriptor &x, const Descriptor &y, const char *source, int line) {
186 return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line);
188 #endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
// DOT_PRODUCT entry point with a REAL(16) result; compiled only when
// LDBL_MANT_DIG is 113 or HAS_FLOAT128 is set.  source and line are passed
// through to DotProduct, which uses them to set up its Terminator.
CppTypeFor<TypeCategory::Real, 16> RTNAME(DotProductReal16)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Kernel = DotProduct<TypeCategory::Real, 16>;
  return Kernel{}(x, y, source, line);
}
#endif
196 void RTNAME(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
197 const Descriptor &x, const Descriptor &y, const char *source, int line) {
198 result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line);
200 void RTNAME(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
201 const Descriptor &x, const Descriptor &y, const char *source, int line) {
202 result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
204 #if LDBL_MANT_DIG == 64
205 void RTNAME(CppDotProductComplex10)(
206 CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x,
207 const Descriptor &y, const char *source, int line) {
208 result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
210 #endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
// DOT_PRODUCT entry point for a COMPLEX(16) result, which is stored into
// `result` instead of being returned; compiled only when LDBL_MANT_DIG is 113
// or HAS_FLOAT128 is set.  source and line are passed through to DotProduct,
// which uses them to set up its Terminator.
void RTNAME(CppDotProductComplex16)(
    CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x,
    const Descriptor &y, const char *source, int line) {
  using Kernel = DotProduct<TypeCategory::Complex, 16>;
  result = Kernel{}(x, y, source, line);
}
#endif
219 bool RTNAME(DotProductLogical)(
220 const Descriptor &x, const Descriptor &y, const char *source, int line) {
221 return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line);
223 } // extern "C"
224 } // namespace Fortran::runtime