[flang] Accept polymorphic component element in storage_size
[llvm-project.git] / flang / runtime / reduction-templates.h
blob2aaf5c102a9ca9ef9c410ed387759b2d4012a44f
1 //===-- runtime/reduction-templates.h -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 // Generic function templates used by various reduction transformation
10 // intrinsic functions (SUM, PRODUCT, &c.)
12 // * Partial reductions (i.e., those with DIM= arguments that are not
13 // required to be 1 by the rank of the argument) return arrays that
14 // are dynamically allocated in a caller-supplied descriptor.
// * Total reductions (i.e., no DIM= argument) with FINDLOC, MAXLOC, & MINLOC
//   return integer vectors of some kind, not scalars; a caller-supplied
//   descriptor is used for the result.
18 // * Character-valued reductions (MAXVAL & MINVAL) return arbitrary
19 // length results, dynamically allocated in a caller-supplied descriptor
21 #ifndef FORTRAN_RUNTIME_REDUCTION_TEMPLATES_H_
22 #define FORTRAN_RUNTIME_REDUCTION_TEMPLATES_H_
24 #include "terminator.h"
25 #include "tools.h"
26 #include "flang/Runtime/cpp-type.h"
27 #include "flang/Runtime/descriptor.h"
29 namespace Fortran::runtime {
31 // Reductions are implemented with *accumulators*, which are instances of
32 // classes that incrementally build up the result (or an element thereof) during
33 // a traversal of the unmasked elements of an array. Each accumulator class
34 // supports a constructor (which captures a reference to the array), an
35 // AccumulateAt() member function that applies supplied subscripts to the
36 // array and does something with a scalar element, and a GetResult()
37 // member function that copies a final result into its destination.
39 // Total reduction of the array argument to a scalar (or to a vector in the
40 // cases of FINDLOC, MAXLOC, & MINLOC). These are the cases without DIM= or
41 // cases where the argument has rank 1 and DIM=, if present, must be 1.
42 template <typename TYPE, typename ACCUMULATOR>
43 inline void DoTotalReduction(const Descriptor &x, int dim,
44 const Descriptor *mask, ACCUMULATOR &accumulator, const char *intrinsic,
45 Terminator &terminator) {
46 if (dim < 0 || dim > 1) {
47 terminator.Crash("%s: bad DIM=%d for ARRAY argument with rank %d",
48 intrinsic, dim, x.rank());
50 SubscriptValue xAt[maxRank];
51 x.GetLowerBounds(xAt);
52 if (mask) {
53 CheckConformability(x, *mask, terminator, intrinsic, "ARRAY", "MASK");
54 SubscriptValue maskAt[maxRank];
55 mask->GetLowerBounds(maskAt);
56 if (mask->rank() > 0) {
57 for (auto elements{x.Elements()}; elements--;
58 x.IncrementSubscripts(xAt), mask->IncrementSubscripts(maskAt)) {
59 if (IsLogicalElementTrue(*mask, maskAt)) {
60 if (!accumulator.template AccumulateAt<TYPE>(xAt))
61 break;
64 return;
65 } else if (!IsLogicalElementTrue(*mask, maskAt)) {
66 // scalar MASK=.FALSE.: return identity value
67 return;
70 // No MASK=, or scalar MASK=.TRUE.
71 for (auto elements{x.Elements()}; elements--; x.IncrementSubscripts(xAt)) {
72 if (!accumulator.template AccumulateAt<TYPE>(xAt)) {
73 break; // cut short, result is known
78 template <TypeCategory CAT, int KIND, typename ACCUMULATOR>
79 inline CppTypeFor<CAT, KIND> GetTotalReduction(const Descriptor &x,
80 const char *source, int line, int dim, const Descriptor *mask,
81 ACCUMULATOR &&accumulator, const char *intrinsic) {
82 Terminator terminator{source, line};
83 RUNTIME_CHECK(terminator, TypeCode(CAT, KIND) == x.type());
84 using CppType = CppTypeFor<CAT, KIND>;
85 DoTotalReduction<CppType>(x, dim, mask, accumulator, intrinsic, terminator);
86 CppType result;
87 #ifdef _MSC_VER // work around MSVC spurious error
88 accumulator.GetResult(&result);
89 #else
90 accumulator.template GetResult(&result);
91 #endif
92 return result;
95 // For reductions on a dimension, e.g. SUM(array,DIM=2) where the shape
96 // of the array is [2,3,5], the shape of the result is [2,5] and
97 // result(j,k) = SUM(array(j,:,k)), possibly modified if the array has
98 // lower bounds other than one. This utility subroutine creates an
99 // array of subscripts [j,_,k] for result subscripts [j,k] so that the
100 // elements of array(j,:,k) can be reduced.
101 inline void GetExpandedSubscripts(SubscriptValue at[],
102 const Descriptor &descriptor, int zeroBasedDim,
103 const SubscriptValue from[]) {
104 descriptor.GetLowerBounds(at);
105 int rank{descriptor.rank()};
106 int j{0};
107 for (; j < zeroBasedDim; ++j) {
108 at[j] += from[j] - 1 /*lower bound*/;
110 for (++j; j < rank; ++j) {
111 at[j] += from[j - 1] - 1;
115 template <typename TYPE, typename ACCUMULATOR>
116 inline void ReduceDimToScalar(const Descriptor &x, int zeroBasedDim,
117 SubscriptValue subscripts[], TYPE *result, ACCUMULATOR &accumulator) {
118 SubscriptValue xAt[maxRank];
119 GetExpandedSubscripts(xAt, x, zeroBasedDim, subscripts);
120 const auto &dim{x.GetDimension(zeroBasedDim)};
121 SubscriptValue at{dim.LowerBound()};
122 for (auto n{dim.Extent()}; n-- > 0; ++at) {
123 xAt[zeroBasedDim] = at;
124 if (!accumulator.template AccumulateAt<TYPE>(xAt)) {
125 break;
128 #ifdef _MSC_VER // work around MSVC spurious error
129 accumulator.GetResult(result, zeroBasedDim);
130 #else
131 accumulator.template GetResult(result, zeroBasedDim);
132 #endif
135 template <typename TYPE, typename ACCUMULATOR>
136 inline void ReduceDimMaskToScalar(const Descriptor &x, int zeroBasedDim,
137 SubscriptValue subscripts[], const Descriptor &mask, TYPE *result,
138 ACCUMULATOR &accumulator) {
139 SubscriptValue xAt[maxRank], maskAt[maxRank];
140 GetExpandedSubscripts(xAt, x, zeroBasedDim, subscripts);
141 GetExpandedSubscripts(maskAt, mask, zeroBasedDim, subscripts);
142 const auto &xDim{x.GetDimension(zeroBasedDim)};
143 SubscriptValue xPos{xDim.LowerBound()};
144 const auto &maskDim{mask.GetDimension(zeroBasedDim)};
145 SubscriptValue maskPos{maskDim.LowerBound()};
146 for (auto n{x.GetDimension(zeroBasedDim).Extent()}; n-- > 0;
147 ++xPos, ++maskPos) {
148 maskAt[zeroBasedDim] = maskPos;
149 if (IsLogicalElementTrue(mask, maskAt)) {
150 xAt[zeroBasedDim] = xPos;
151 if (!accumulator.template AccumulateAt<TYPE>(xAt)) {
152 break;
156 #ifdef _MSC_VER // work around MSVC spurious error
157 accumulator.GetResult(result, zeroBasedDim);
158 #else
159 accumulator.template GetResult(result, zeroBasedDim);
160 #endif
163 // Utility: establishes & allocates the result array for a partial
164 // reduction (i.e., one with DIM=).
165 static void CreatePartialReductionResult(Descriptor &result,
166 const Descriptor &x, std::size_t resultElementSize, int dim,
167 Terminator &terminator, const char *intrinsic, TypeCode typeCode) {
168 int xRank{x.rank()};
169 if (dim < 1 || dim > xRank) {
170 terminator.Crash(
171 "%s: bad DIM=%d for ARRAY with rank %d", intrinsic, dim, xRank);
173 int zeroBasedDim{dim - 1};
174 SubscriptValue resultExtent[maxRank];
175 for (int j{0}; j < zeroBasedDim; ++j) {
176 resultExtent[j] = x.GetDimension(j).Extent();
178 for (int j{zeroBasedDim + 1}; j < xRank; ++j) {
179 resultExtent[j - 1] = x.GetDimension(j).Extent();
181 result.Establish(typeCode, resultElementSize, nullptr, xRank - 1,
182 resultExtent, CFI_attribute_allocatable);
183 for (int j{0}; j + 1 < xRank; ++j) {
184 result.GetDimension(j).SetBounds(1, resultExtent[j]);
186 if (int stat{result.Allocate()}) {
187 terminator.Crash(
188 "%s: could not allocate memory for result; STAT=%d", intrinsic, stat);
192 // Partial reductions with DIM=
194 template <typename ACCUMULATOR, TypeCategory CAT, int KIND>
195 inline void PartialReduction(Descriptor &result, const Descriptor &x,
196 std::size_t resultElementSize, int dim, const Descriptor *mask,
197 Terminator &terminator, const char *intrinsic, ACCUMULATOR &accumulator) {
198 CreatePartialReductionResult(result, x, resultElementSize, dim, terminator,
199 intrinsic, TypeCode{CAT, KIND});
200 SubscriptValue at[maxRank];
201 result.GetLowerBounds(at);
202 INTERNAL_CHECK(result.rank() == 0 || at[0] == 1);
203 using CppType = CppTypeFor<CAT, KIND>;
204 if (mask) {
205 CheckConformability(x, *mask, terminator, intrinsic, "ARRAY", "MASK");
206 SubscriptValue maskAt[maxRank]; // contents unused
207 if (mask->rank() > 0) {
208 for (auto n{result.Elements()}; n-- > 0; result.IncrementSubscripts(at)) {
209 accumulator.Reinitialize();
210 ReduceDimMaskToScalar<CppType, ACCUMULATOR>(
211 x, dim - 1, at, *mask, result.Element<CppType>(at), accumulator);
213 return;
214 } else if (!IsLogicalElementTrue(*mask, maskAt)) {
215 // scalar MASK=.FALSE.
216 accumulator.Reinitialize();
217 for (auto n{result.Elements()}; n-- > 0; result.IncrementSubscripts(at)) {
218 accumulator.GetResult(result.Element<CppType>(at));
220 return;
223 // No MASK= or scalar MASK=.TRUE.
224 for (auto n{result.Elements()}; n-- > 0; result.IncrementSubscripts(at)) {
225 accumulator.Reinitialize();
226 ReduceDimToScalar<CppType, ACCUMULATOR>(
227 x, dim - 1, at, result.Element<CppType>(at), accumulator);
231 template <template <typename> class ACCUM>
232 struct PartialIntegerReductionHelper {
233 template <int KIND> struct Functor {
234 static constexpr int Intermediate{
235 std::max(KIND, 4)}; // use at least "int" for intermediate results
236 void operator()(Descriptor &result, const Descriptor &x, int dim,
237 const Descriptor *mask, Terminator &terminator,
238 const char *intrinsic) const {
239 using Accumulator =
240 ACCUM<CppTypeFor<TypeCategory::Integer, Intermediate>>;
241 Accumulator accumulator{x};
242 // Element size of the destination descriptor is the same
243 // as the element size of the source.
244 PartialReduction<Accumulator, TypeCategory::Integer, KIND>(result, x,
245 x.ElementBytes(), dim, mask, terminator, intrinsic, accumulator);
250 template <template <typename> class INTEGER_ACCUM>
251 inline void PartialIntegerReduction(Descriptor &result, const Descriptor &x,
252 int dim, int kind, const Descriptor *mask, const char *intrinsic,
253 Terminator &terminator) {
254 ApplyIntegerKind<
255 PartialIntegerReductionHelper<INTEGER_ACCUM>::template Functor, void>(
256 kind, terminator, result, x, dim, mask, terminator, intrinsic);
259 template <TypeCategory CAT, template <typename> class ACCUM>
260 struct PartialFloatingReductionHelper {
261 template <int KIND> struct Functor {
262 static constexpr int Intermediate{
263 std::max(KIND, 8)}; // use at least "double" for intermediate results
264 void operator()(Descriptor &result, const Descriptor &x, int dim,
265 const Descriptor *mask, Terminator &terminator,
266 const char *intrinsic) const {
267 using Accumulator = ACCUM<CppTypeFor<TypeCategory::Real, Intermediate>>;
268 Accumulator accumulator{x};
269 // Element size of the destination descriptor is the same
270 // as the element size of the source.
271 PartialReduction<Accumulator, CAT, KIND>(result, x, x.ElementBytes(), dim,
272 mask, terminator, intrinsic, accumulator);
277 template <template <typename> class INTEGER_ACCUM,
278 template <typename> class REAL_ACCUM,
279 template <typename> class COMPLEX_ACCUM>
280 inline void TypedPartialNumericReduction(Descriptor &result,
281 const Descriptor &x, int dim, const char *source, int line,
282 const Descriptor *mask, const char *intrinsic) {
283 Terminator terminator{source, line};
284 auto catKind{x.type().GetCategoryAndKind()};
285 RUNTIME_CHECK(terminator, catKind.has_value());
286 switch (catKind->first) {
287 case TypeCategory::Integer:
288 PartialIntegerReduction<INTEGER_ACCUM>(
289 result, x, dim, catKind->second, mask, intrinsic, terminator);
290 break;
291 case TypeCategory::Real:
292 ApplyFloatingPointKind<PartialFloatingReductionHelper<TypeCategory::Real,
293 REAL_ACCUM>::template Functor,
294 void>(catKind->second, terminator, result, x, dim, mask, terminator,
295 intrinsic);
296 break;
297 case TypeCategory::Complex:
298 ApplyFloatingPointKind<PartialFloatingReductionHelper<TypeCategory::Complex,
299 COMPLEX_ACCUM>::template Functor,
300 void>(catKind->second, terminator, result, x, dim, mask, terminator,
301 intrinsic);
302 break;
303 default:
304 terminator.Crash("%s: bad type code %d", intrinsic, x.type().raw());
308 template <typename ACCUMULATOR> struct LocationResultHelper {
309 template <int KIND> struct Functor {
310 void operator()(ACCUMULATOR &accumulator, const Descriptor &result) const {
311 accumulator.GetResult(
312 result.OffsetElement<CppTypeFor<TypeCategory::Integer, KIND>>());
317 template <typename ACCUMULATOR> struct PartialLocationHelper {
318 template <int KIND> struct Functor {
319 void operator()(Descriptor &result, const Descriptor &x, int dim,
320 const Descriptor *mask, Terminator &terminator, const char *intrinsic,
321 ACCUMULATOR &accumulator) const {
322 // Element size of the destination descriptor is the size
323 // of {TypeCategory::Integer, KIND}.
324 PartialReduction<ACCUMULATOR, TypeCategory::Integer, KIND>(result, x,
325 Descriptor::BytesFor(TypeCategory::Integer, KIND), dim, mask,
326 terminator, intrinsic, accumulator);
331 } // namespace Fortran::runtime
332 #endif // FORTRAN_RUNTIME_REDUCTION_TEMPLATES_H_