[AArch64,ELF] Restrict MOVZ/MOVK to non-PIC large code model (#70178)
[llvm-project.git] / flang / runtime / dot-product.cpp
blob58382863a5006752be0374312fe80650f716f6d8
//===-- runtime/dot-product.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "float.h"
10 #include "terminator.h"
11 #include "tools.h"
12 #include "flang/Common/float128.h"
13 #include "flang/Runtime/cpp-type.h"
14 #include "flang/Runtime/descriptor.h"
15 #include "flang/Runtime/reduction.h"
16 #include <cfloat>
17 #include <cinttypes>
19 namespace Fortran::runtime {
// Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
// argument; MATMUL does not.
24 // General accumulator for any type and stride; this is not used for
25 // contiguous numeric vectors.
26 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
27 class Accumulator {
28 public:
29 using Result = AccumulationType<RCAT, RKIND>;
30 Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
31 void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) {
32 if constexpr (RCAT == TypeCategory::Logical) {
33 sum_ = sum_ ||
34 (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
35 } else {
36 const XT &xElement{*x_.Element<XT>(&xAt)};
37 const YT &yElement{*y_.Element<YT>(&yAt)};
38 if constexpr (RCAT == TypeCategory::Complex) {
39 sum_ += std::conj(static_cast<Result>(xElement)) *
40 static_cast<Result>(yElement);
41 } else {
42 sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement);
46 Result GetResult() const { return sum_; }
48 private:
49 const Descriptor &x_, &y_;
50 Result sum_{};
53 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
54 static inline CppTypeFor<RCAT, RKIND> DoDotProduct(
55 const Descriptor &x, const Descriptor &y, Terminator &terminator) {
56 using Result = CppTypeFor<RCAT, RKIND>;
57 RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
58 SubscriptValue n{x.GetDimension(0).Extent()};
59 if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
60 terminator.Crash(
61 "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
62 static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
64 if constexpr (RCAT != TypeCategory::Logical) {
65 if (x.GetDimension(0).ByteStride() == sizeof(XT) &&
66 y.GetDimension(0).ByteStride() == sizeof(YT)) {
67 // Contiguous numeric vectors
68 if constexpr (std::is_same_v<XT, YT>) {
69 // Contiguous homogeneous numeric vectors
70 if constexpr (std::is_same_v<XT, float>) {
71 // TODO: call BLAS-1 SDOT or SDSDOT
72 } else if constexpr (std::is_same_v<XT, double>) {
73 // TODO: call BLAS-1 DDOT
74 } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
75 // TODO: call BLAS-1 CDOTC
76 } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
77 // TODO: call BLAS-1 ZDOTC
80 XT *xp{x.OffsetElement<XT>(0)};
81 YT *yp{y.OffsetElement<YT>(0)};
82 using AccumType = AccumulationType<RCAT, RKIND>;
83 AccumType accum{};
84 if constexpr (RCAT == TypeCategory::Complex) {
85 for (SubscriptValue j{0}; j < n; ++j) {
86 accum += std::conj(static_cast<AccumType>(*xp++)) *
87 static_cast<AccumType>(*yp++);
89 } else {
90 for (SubscriptValue j{0}; j < n; ++j) {
91 accum +=
92 static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++);
95 return static_cast<Result>(accum);
98 // Non-contiguous, heterogeneous, & LOGICAL cases
99 SubscriptValue xAt{x.GetDimension(0).LowerBound()};
100 SubscriptValue yAt{y.GetDimension(0).LowerBound()};
101 Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
102 for (SubscriptValue j{0}; j < n; ++j) {
103 accumulator.AccumulateIndexed(xAt++, yAt++);
105 return static_cast<Result>(accumulator.GetResult());
108 template <TypeCategory RCAT, int RKIND> struct DotProduct {
109 using Result = CppTypeFor<RCAT, RKIND>;
110 template <TypeCategory XCAT, int XKIND> struct DP1 {
111 template <TypeCategory YCAT, int YKIND> struct DP2 {
112 Result operator()(const Descriptor &x, const Descriptor &y,
113 Terminator &terminator) const {
114 if constexpr (constexpr auto resultType{
115 GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
116 if constexpr (resultType->first == RCAT &&
117 (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) {
118 return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>,
119 CppTypeFor<YCAT, YKIND>>(x, y, terminator);
122 terminator.Crash(
123 "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
124 static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND,
125 static_cast<int>(YCAT), YKIND);
128 Result operator()(const Descriptor &x, const Descriptor &y,
129 Terminator &terminator, TypeCategory yCat, int yKind) const {
130 return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator);
133 Result operator()(const Descriptor &x, const Descriptor &y,
134 const char *source, int line) const {
135 Terminator terminator{source, line};
136 if (RCAT != TypeCategory::Logical && x.type() == y.type()) {
137 // No conversions needed, operands and result have same known type
138 return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}(
139 x, y, terminator);
140 } else {
141 auto xCatKind{x.type().GetCategoryAndKind()};
142 auto yCatKind{y.type().GetCategoryAndKind()};
143 RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
144 return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second,
145 terminator, x, y, terminator, yCatKind->first, yCatKind->second);
150 extern "C" {
151 CppTypeFor<TypeCategory::Integer, 1> RTNAME(DotProductInteger1)(
152 const Descriptor &x, const Descriptor &y, const char *source, int line) {
153 return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line);
155 CppTypeFor<TypeCategory::Integer, 2> RTNAME(DotProductInteger2)(
156 const Descriptor &x, const Descriptor &y, const char *source, int line) {
157 return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line);
159 CppTypeFor<TypeCategory::Integer, 4> RTNAME(DotProductInteger4)(
160 const Descriptor &x, const Descriptor &y, const char *source, int line) {
161 return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line);
163 CppTypeFor<TypeCategory::Integer, 8> RTNAME(DotProductInteger8)(
164 const Descriptor &x, const Descriptor &y, const char *source, int line) {
165 return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
167 #ifdef __SIZEOF_INT128__
168 CppTypeFor<TypeCategory::Integer, 16> RTNAME(DotProductInteger16)(
169 const Descriptor &x, const Descriptor &y, const char *source, int line) {
170 return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
172 #endif
// TODO: REAL/COMPLEX(2 & 3)
// Intermediate results and operations are at least 64 bits
176 CppTypeFor<TypeCategory::Real, 4> RTNAME(DotProductReal4)(
177 const Descriptor &x, const Descriptor &y, const char *source, int line) {
178 return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line);
180 CppTypeFor<TypeCategory::Real, 8> RTNAME(DotProductReal8)(
181 const Descriptor &x, const Descriptor &y, const char *source, int line) {
182 return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
184 #if LDBL_MANT_DIG == 64
185 CppTypeFor<TypeCategory::Real, 10> RTNAME(DotProductReal10)(
186 const Descriptor &x, const Descriptor &y, const char *source, int line) {
187 return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line);
189 #endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
// DOT_PRODUCT entry point with a result of category Real, kind 16.
// Compiled only when long double has a 113-bit significand or HAS_FLOAT128
// is set (presumably by flang/Common/float128.h -- confirm).
CppTypeFor<TypeCategory::Real, 16> RTNAME(DotProductReal16)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line);
}
#endif
197 void RTNAME(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
198 const Descriptor &x, const Descriptor &y, const char *source, int line) {
199 result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line);
201 void RTNAME(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
202 const Descriptor &x, const Descriptor &y, const char *source, int line) {
203 result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
205 #if LDBL_MANT_DIG == 64
206 void RTNAME(CppDotProductComplex10)(
207 CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x,
208 const Descriptor &y, const char *source, int line) {
209 result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
211 #endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
// DOT_PRODUCT entry point with a result of category Complex, kind 16.
// The value is stored through `result` instead of being returned.
// Compiled only when long double has a 113-bit significand or HAS_FLOAT128
// is set (presumably by flang/Common/float128.h -- confirm).
void RTNAME(CppDotProductComplex16)(
    CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x,
    const Descriptor &y, const char *source, int line) {
  result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
}
#endif
220 bool RTNAME(DotProductLogical)(
221 const Descriptor &x, const Descriptor &y, const char *source, int line) {
222 return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line);
224 } // extern "C"
225 } // namespace Fortran::runtime