[Hexagon] Use llvm::children (NFC)
[llvm-project.git] / flang / lib / Evaluate / fold-reduction.h
blob1ee957c0faebd8bc163f29531a04ee7ec730bffd
//===-- lib/Evaluate/fold-reduction.h -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
10 #define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
12 #include "fold-implementation.h"
14 namespace Fortran::evaluate {
16 // DOT_PRODUCT
17 template <typename T>
18 static Expr<T> FoldDotProduct(
19 FoldingContext &context, FunctionRef<T> &&funcRef) {
20 using Element = typename Constant<T>::Element;
21 auto args{funcRef.arguments()};
22 CHECK(args.size() == 2);
23 Folder<T> folder{context};
24 Constant<T> *va{folder.Folding(args[0])};
25 Constant<T> *vb{folder.Folding(args[1])};
26 if (va && vb) {
27 CHECK(va->Rank() == 1 && vb->Rank() == 1);
28 if (va->size() != vb->size()) {
29 context.messages().Say(
30 "Vector arguments to DOT_PRODUCT have distinct extents %zd and %zd"_err_en_US,
31 va->size(), vb->size());
32 return MakeInvalidIntrinsic(std::move(funcRef));
34 Element sum{};
35 bool overflow{false};
36 if constexpr (T::category == TypeCategory::Complex) {
37 std::vector<Element> conjugates;
38 for (const Element &x : va->values()) {
39 conjugates.emplace_back(x.CONJG());
41 Constant<T> conjgA{
42 std::move(conjugates), ConstantSubscripts{va->shape()}};
43 Expr<T> products{Fold(
44 context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
45 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
46 Element correction{}; // Use Kahan summation for greater precision.
47 const auto &rounding{context.targetCharacteristics().roundingMode()};
48 for (const Element &x : cProducts.values()) {
49 auto next{correction.Add(x, rounding)};
50 overflow |= next.flags.test(RealFlag::Overflow);
51 auto added{sum.Add(next.value, rounding)};
52 overflow |= added.flags.test(RealFlag::Overflow);
53 correction = added.value.Subtract(sum, rounding)
54 .value.Subtract(next.value, rounding)
55 .value;
56 sum = std::move(added.value);
58 } else if constexpr (T::category == TypeCategory::Logical) {
59 Expr<T> conjunctions{Fold(context,
60 Expr<T>{LogicalOperation<T::kind>{LogicalOperator::And,
61 Expr<T>{Constant<T>{*va}}, Expr<T>{Constant<T>{*vb}}}})};
62 Constant<T> &cConjunctions{DEREF(UnwrapConstantValue<T>(conjunctions))};
63 for (const Element &x : cConjunctions.values()) {
64 if (x.IsTrue()) {
65 sum = Element{true};
66 break;
69 } else if constexpr (T::category == TypeCategory::Integer) {
70 Expr<T> products{
71 Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
72 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
73 for (const Element &x : cProducts.values()) {
74 auto next{sum.AddSigned(x)};
75 overflow |= next.overflow;
76 sum = std::move(next.value);
78 } else {
79 static_assert(T::category == TypeCategory::Real);
80 Expr<T> products{
81 Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
82 Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
83 Element correction{}; // Use Kahan summation for greater precision.
84 const auto &rounding{context.targetCharacteristics().roundingMode()};
85 for (const Element &x : cProducts.values()) {
86 auto next{correction.Add(x, rounding)};
87 overflow |= next.flags.test(RealFlag::Overflow);
88 auto added{sum.Add(next.value, rounding)};
89 overflow |= added.flags.test(RealFlag::Overflow);
90 correction = added.value.Subtract(sum, rounding)
91 .value.Subtract(next.value, rounding)
92 .value;
93 sum = std::move(added.value);
96 if (overflow) {
97 context.messages().Say(
98 "DOT_PRODUCT of %s data overflowed during computation"_warn_en_US,
99 T::AsFortran());
101 return Expr<T>{Constant<T>{std::move(sum)}};
103 return Expr<T>{std::move(funcRef)};
106 // Fold and validate a DIM= argument. Returns false on error.
107 bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &,
108 ActualArguments &, std::optional<int> dimIndex, int rank);
110 // Fold and validate a MASK= argument. Return null on error, absent MASK=, or
111 // non-constant MASK=.
112 Constant<LogicalResult> *GetReductionMASK(
113 std::optional<ActualArgument> &maskArg, const ConstantSubscripts &shape,
114 FoldingContext &);
116 // Common preprocessing for reduction transformational intrinsic function
117 // folding. If the intrinsic can have DIM= &/or MASK= arguments, extract
118 // and check them. If a MASK= is present, apply it to the array data and
119 // substitute replacement values for elements corresponding to .FALSE. in
120 // the mask. If the result is present, the intrinsic call can be folded.
121 template <typename T> struct ArrayAndMask {
122 Constant<T> array;
123 Constant<LogicalResult> mask;
125 template <typename T>
126 static std::optional<ArrayAndMask<T>> ProcessReductionArgs(
127 FoldingContext &context, ActualArguments &arg, std::optional<int> &dim,
128 int arrayIndex, std::optional<int> dimIndex = std::nullopt,
129 std::optional<int> maskIndex = std::nullopt) {
130 if (arg.empty()) {
131 return std::nullopt;
133 Constant<T> *folded{Folder<T>{context}.Folding(arg[arrayIndex])};
134 if (!folded || folded->Rank() < 1) {
135 return std::nullopt;
137 if (!CheckReductionDIM(dim, context, arg, dimIndex, folded->Rank())) {
138 return std::nullopt;
140 std::size_t n{folded->size()};
141 std::vector<Scalar<LogicalResult>> maskElement;
142 if (maskIndex && static_cast<std::size_t>(*maskIndex) < arg.size() &&
143 arg[*maskIndex]) {
144 if (const Constant<LogicalResult> *origMask{
145 GetReductionMASK(arg[*maskIndex], folded->shape(), context)}) {
146 if (auto scalarMask{origMask->GetScalarValue()}) {
147 maskElement =
148 std::vector<Scalar<LogicalResult>>(n, scalarMask->IsTrue());
149 } else {
150 maskElement = origMask->values();
152 } else {
153 return std::nullopt;
155 } else {
156 maskElement = std::vector<Scalar<LogicalResult>>(n, true);
158 return ArrayAndMask<T>{Constant<T>(*folded),
159 Constant<LogicalResult>{
160 std::move(maskElement), ConstantSubscripts{folded->shape()}}};
163 // Generalized reduction to an array of one dimension fewer (w/ DIM=)
164 // or to a scalar (w/o DIM=). The ACCUMULATOR type must define
165 // operator()(Scalar<T> &, const ConstantSubscripts &, bool first)
166 // and Done(Scalar<T> &).
167 template <typename T, typename ACCUMULATOR, typename ARRAY>
168 static Constant<T> DoReduction(const Constant<ARRAY> &array,
169 const Constant<LogicalResult> &mask, std::optional<int> &dim,
170 const Scalar<T> &identity, ACCUMULATOR &accumulator) {
171 ConstantSubscripts at{array.lbounds()};
172 ConstantSubscripts maskAt{mask.lbounds()};
173 std::vector<typename Constant<T>::Element> elements;
174 ConstantSubscripts resultShape; // empty -> scalar
175 if (dim) { // DIM= is present, so result is an array
176 resultShape = array.shape();
177 resultShape.erase(resultShape.begin() + (*dim - 1));
178 ConstantSubscript dimExtent{array.shape().at(*dim - 1)};
179 CHECK(dimExtent == mask.shape().at(*dim - 1));
180 ConstantSubscript &dimAt{at[*dim - 1]};
181 ConstantSubscript dimLbound{dimAt};
182 ConstantSubscript &maskDimAt{maskAt[*dim - 1]};
183 ConstantSubscript maskDimLbound{maskDimAt};
184 for (auto n{GetSize(resultShape)}; n-- > 0;
185 IncrementSubscripts(at, array.shape()),
186 IncrementSubscripts(maskAt, mask.shape())) {
187 dimAt = dimLbound;
188 maskDimAt = maskDimLbound;
189 elements.push_back(identity);
190 bool firstUnmasked{true};
191 for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt, ++maskDimAt) {
192 if (mask.At(maskAt).IsTrue()) {
193 accumulator(elements.back(), at, firstUnmasked);
194 firstUnmasked = false;
197 accumulator.Done(elements.back());
199 } else { // no DIM=, result is scalar
200 elements.push_back(identity);
201 bool firstUnmasked{true};
202 for (auto n{array.size()}; n-- > 0; IncrementSubscripts(at, array.shape()),
203 IncrementSubscripts(maskAt, mask.shape())) {
204 if (mask.At(maskAt).IsTrue()) {
205 accumulator(elements.back(), at, firstUnmasked);
206 firstUnmasked = false;
209 accumulator.Done(elements.back());
211 if constexpr (T::category == TypeCategory::Character) {
212 return {static_cast<ConstantSubscript>(identity.size()),
213 std::move(elements), std::move(resultShape)};
214 } else {
215 return {std::move(elements), std::move(resultShape)};
219 // MAXVAL & MINVAL
220 template <typename T, bool ABS = false> class MaxvalMinvalAccumulator {
221 public:
222 MaxvalMinvalAccumulator(
223 RelationalOperator opr, FoldingContext &context, const Constant<T> &array)
224 : opr_{opr}, context_{context}, array_{array} {};
225 void operator()(Scalar<T> &element, const ConstantSubscripts &at,
226 [[maybe_unused]] bool firstUnmasked) const {
227 auto aAt{array_.At(at)};
228 if constexpr (ABS) {
229 aAt = aAt.ABS();
231 if constexpr (T::category == TypeCategory::Real) {
232 if (firstUnmasked || element.IsNotANumber()) {
233 // Return NaN if and only if all unmasked elements are NaNs and
234 // at least one unmasked element is visible.
235 element = aAt;
236 return;
239 Expr<LogicalResult> test{PackageRelation(
240 opr_, Expr<T>{Constant<T>{aAt}}, Expr<T>{Constant<T>{element}})};
241 auto folded{GetScalarConstantValue<LogicalResult>(
242 test.Rewrite(context_, std::move(test)))};
243 CHECK(folded.has_value());
244 if (folded->IsTrue()) {
245 element = aAt;
248 void Done(Scalar<T> &) const {}
250 private:
251 RelationalOperator opr_;
252 FoldingContext &context_;
253 const Constant<T> &array_;
256 template <typename T>
257 static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
258 RelationalOperator opr, const Scalar<T> &identity) {
259 static_assert(T::category == TypeCategory::Integer ||
260 T::category == TypeCategory::Real ||
261 T::category == TypeCategory::Character);
262 std::optional<int> dim;
263 if (std::optional<ArrayAndMask<T>> arrayAndMask{
264 ProcessReductionArgs<T>(context, ref.arguments(), dim,
265 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
266 MaxvalMinvalAccumulator accumulator{opr, context, arrayAndMask->array};
267 return Expr<T>{DoReduction<T>(
268 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)};
270 return Expr<T>{std::move(ref)};
273 // PRODUCT
274 template <typename T> class ProductAccumulator {
275 public:
276 ProductAccumulator(const Constant<T> &array) : array_{array} {}
277 void operator()(
278 Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
279 if constexpr (T::category == TypeCategory::Integer) {
280 auto prod{element.MultiplySigned(array_.At(at))};
281 overflow_ |= prod.SignedMultiplicationOverflowed();
282 element = prod.lower;
283 } else { // Real & Complex
284 auto prod{element.Multiply(array_.At(at))};
285 overflow_ |= prod.flags.test(RealFlag::Overflow);
286 element = prod.value;
289 bool overflow() const { return overflow_; }
290 void Done(Scalar<T> &) const {}
292 private:
293 const Constant<T> &array_;
294 bool overflow_{false};
297 template <typename T>
298 static Expr<T> FoldProduct(
299 FoldingContext &context, FunctionRef<T> &&ref, Scalar<T> identity) {
300 static_assert(T::category == TypeCategory::Integer ||
301 T::category == TypeCategory::Real ||
302 T::category == TypeCategory::Complex);
303 std::optional<int> dim;
304 if (std::optional<ArrayAndMask<T>> arrayAndMask{
305 ProcessReductionArgs<T>(context, ref.arguments(), dim,
306 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
307 ProductAccumulator accumulator{arrayAndMask->array};
308 auto result{Expr<T>{DoReduction<T>(
309 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
310 if (accumulator.overflow()) {
311 context.messages().Say(
312 "PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran());
314 return result;
316 return Expr<T>{std::move(ref)};
319 // SUM
320 template <typename T> class SumAccumulator {
321 using Element = typename Constant<T>::Element;
323 public:
324 SumAccumulator(const Constant<T> &array, Rounding rounding)
325 : array_{array}, rounding_{rounding} {}
326 void operator()(
327 Element &element, const ConstantSubscripts &at, bool /*first*/) {
328 if constexpr (T::category == TypeCategory::Integer) {
329 auto sum{element.AddSigned(array_.At(at))};
330 overflow_ |= sum.overflow;
331 element = sum.value;
332 } else { // Real & Complex: use Kahan summation
333 auto next{array_.At(at).Add(correction_, rounding_)};
334 overflow_ |= next.flags.test(RealFlag::Overflow);
335 auto sum{element.Add(next.value, rounding_)};
336 overflow_ |= sum.flags.test(RealFlag::Overflow);
337 // correction = (sum - element) - next; algebraically zero
338 correction_ = sum.value.Subtract(element, rounding_)
339 .value.Subtract(next.value, rounding_)
340 .value;
341 element = sum.value;
344 bool overflow() const { return overflow_; }
345 void Done([[maybe_unused]] Element &element) {
346 if constexpr (T::category != TypeCategory::Integer) {
347 auto corrected{element.Add(correction_, rounding_)};
348 overflow_ |= corrected.flags.test(RealFlag::Overflow);
349 correction_ = Scalar<T>{};
350 element = corrected.value;
354 private:
355 const Constant<T> &array_;
356 Rounding rounding_;
357 bool overflow_{false};
358 Element correction_{};
361 template <typename T>
362 static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
363 static_assert(T::category == TypeCategory::Integer ||
364 T::category == TypeCategory::Real ||
365 T::category == TypeCategory::Complex);
366 using Element = typename Constant<T>::Element;
367 std::optional<int> dim;
368 Element identity{};
369 if (std::optional<ArrayAndMask<T>> arrayAndMask{
370 ProcessReductionArgs<T>(context, ref.arguments(), dim,
371 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
372 SumAccumulator accumulator{
373 arrayAndMask->array, context.targetCharacteristics().roundingMode()};
374 auto result{Expr<T>{DoReduction<T>(
375 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
376 if (accumulator.overflow()) {
377 context.messages().Say(
378 "SUM() of %s data overflowed"_warn_en_US, T::AsFortran());
380 return result;
382 return Expr<T>{std::move(ref)};
385 // Utility for IALL, IANY, IPARITY, ALL, ANY, & PARITY
386 template <typename T> class OperationAccumulator {
387 public:
388 OperationAccumulator(const Constant<T> &array,
389 Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const)
390 : array_{array}, operation_{operation} {}
391 void operator()(
392 Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
393 element = (element.*operation_)(array_.At(at));
395 void Done(Scalar<T> &) const {}
397 private:
398 const Constant<T> &array_;
399 Scalar<T> (Scalar<T>::*operation_)(const Scalar<T> &) const;
402 } // namespace Fortran::evaluate
403 #endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_