//===-- lib/Evaluate/fold-reduction.h -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
10 #define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
12 #include "fold-implementation.h"
14 namespace Fortran::evaluate
{
18 static Expr
<T
> FoldDotProduct(
19 FoldingContext
&context
, FunctionRef
<T
> &&funcRef
) {
20 using Element
= typename Constant
<T
>::Element
;
21 auto args
{funcRef
.arguments()};
22 CHECK(args
.size() == 2);
23 Folder
<T
> folder
{context
};
24 Constant
<T
> *va
{folder
.Folding(args
[0])};
25 Constant
<T
> *vb
{folder
.Folding(args
[1])};
27 CHECK(va
->Rank() == 1 && vb
->Rank() == 1);
28 if (va
->size() != vb
->size()) {
29 context
.messages().Say(
30 "Vector arguments to DOT_PRODUCT have distinct extents %zd and %zd"_err_en_US
,
31 va
->size(), vb
->size());
32 return MakeInvalidIntrinsic(std::move(funcRef
));
36 if constexpr (T::category
== TypeCategory::Complex
) {
37 std::vector
<Element
> conjugates
;
38 for (const Element
&x
: va
->values()) {
39 conjugates
.emplace_back(x
.CONJG());
42 std::move(conjugates
), ConstantSubscripts
{va
->shape()}};
43 Expr
<T
> products
{Fold(
44 context
, Expr
<T
>{std::move(conjgA
)} * Expr
<T
>{Constant
<T
>{*vb
}})};
45 Constant
<T
> &cProducts
{DEREF(UnwrapConstantValue
<T
>(products
))};
46 Element correction
{}; // Use Kahan summation for greater precision.
47 const auto &rounding
{context
.targetCharacteristics().roundingMode()};
48 for (const Element
&x
: cProducts
.values()) {
49 auto next
{correction
.Add(x
, rounding
)};
50 overflow
|= next
.flags
.test(RealFlag::Overflow
);
51 auto added
{sum
.Add(next
.value
, rounding
)};
52 overflow
|= added
.flags
.test(RealFlag::Overflow
);
53 correction
= added
.value
.Subtract(sum
, rounding
)
54 .value
.Subtract(next
.value
, rounding
)
56 sum
= std::move(added
.value
);
58 } else if constexpr (T::category
== TypeCategory::Logical
) {
59 Expr
<T
> conjunctions
{Fold(context
,
60 Expr
<T
>{LogicalOperation
<T::kind
>{LogicalOperator::And
,
61 Expr
<T
>{Constant
<T
>{*va
}}, Expr
<T
>{Constant
<T
>{*vb
}}}})};
62 Constant
<T
> &cConjunctions
{DEREF(UnwrapConstantValue
<T
>(conjunctions
))};
63 for (const Element
&x
: cConjunctions
.values()) {
69 } else if constexpr (T::category
== TypeCategory::Integer
) {
71 Fold(context
, Expr
<T
>{Constant
<T
>{*va
}} * Expr
<T
>{Constant
<T
>{*vb
}})};
72 Constant
<T
> &cProducts
{DEREF(UnwrapConstantValue
<T
>(products
))};
73 for (const Element
&x
: cProducts
.values()) {
74 auto next
{sum
.AddSigned(x
)};
75 overflow
|= next
.overflow
;
76 sum
= std::move(next
.value
);
79 static_assert(T::category
== TypeCategory::Real
);
81 Fold(context
, Expr
<T
>{Constant
<T
>{*va
}} * Expr
<T
>{Constant
<T
>{*vb
}})};
82 Constant
<T
> &cProducts
{DEREF(UnwrapConstantValue
<T
>(products
))};
83 Element correction
{}; // Use Kahan summation for greater precision.
84 const auto &rounding
{context
.targetCharacteristics().roundingMode()};
85 for (const Element
&x
: cProducts
.values()) {
86 auto next
{correction
.Add(x
, rounding
)};
87 overflow
|= next
.flags
.test(RealFlag::Overflow
);
88 auto added
{sum
.Add(next
.value
, rounding
)};
89 overflow
|= added
.flags
.test(RealFlag::Overflow
);
90 correction
= added
.value
.Subtract(sum
, rounding
)
91 .value
.Subtract(next
.value
, rounding
)
93 sum
= std::move(added
.value
);
97 context
.messages().Say(
98 "DOT_PRODUCT of %s data overflowed during computation"_warn_en_US
,
101 return Expr
<T
>{Constant
<T
>{std::move(sum
)}};
103 return Expr
<T
>{std::move(funcRef
)};
106 // Fold and validate a DIM= argument. Returns false on error.
107 bool CheckReductionDIM(std::optional
<int> &dim
, FoldingContext
&,
108 ActualArguments
&, std::optional
<int> dimIndex
, int rank
);
110 // Fold and validate a MASK= argument. Return null on error, absent MASK=, or
111 // non-constant MASK=.
112 Constant
<LogicalResult
> *GetReductionMASK(
113 std::optional
<ActualArgument
> &maskArg
, const ConstantSubscripts
&shape
,
116 // Common preprocessing for reduction transformational intrinsic function
117 // folding. If the intrinsic can have DIM= &/or MASK= arguments, extract
118 // and check them. If a MASK= is present, apply it to the array data and
119 // substitute replacement values for elements corresponding to .FALSE. in
120 // the mask. If the result is present, the intrinsic call can be folded.
121 template <typename T
> struct ArrayAndMask
{
123 Constant
<LogicalResult
> mask
;
125 template <typename T
>
126 static std::optional
<ArrayAndMask
<T
>> ProcessReductionArgs(
127 FoldingContext
&context
, ActualArguments
&arg
, std::optional
<int> &dim
,
128 int arrayIndex
, std::optional
<int> dimIndex
= std::nullopt
,
129 std::optional
<int> maskIndex
= std::nullopt
) {
133 Constant
<T
> *folded
{Folder
<T
>{context
}.Folding(arg
[arrayIndex
])};
134 if (!folded
|| folded
->Rank() < 1) {
137 if (!CheckReductionDIM(dim
, context
, arg
, dimIndex
, folded
->Rank())) {
140 std::size_t n
{folded
->size()};
141 std::vector
<Scalar
<LogicalResult
>> maskElement
;
142 if (maskIndex
&& static_cast<std::size_t>(*maskIndex
) < arg
.size() &&
144 if (const Constant
<LogicalResult
> *origMask
{
145 GetReductionMASK(arg
[*maskIndex
], folded
->shape(), context
)}) {
146 if (auto scalarMask
{origMask
->GetScalarValue()}) {
148 std::vector
<Scalar
<LogicalResult
>>(n
, scalarMask
->IsTrue());
150 maskElement
= origMask
->values();
156 maskElement
= std::vector
<Scalar
<LogicalResult
>>(n
, true);
158 return ArrayAndMask
<T
>{Constant
<T
>(*folded
),
159 Constant
<LogicalResult
>{
160 std::move(maskElement
), ConstantSubscripts
{folded
->shape()}}};
163 // Generalized reduction to an array of one dimension fewer (w/ DIM=)
164 // or to a scalar (w/o DIM=). The ACCUMULATOR type must define
165 // operator()(Scalar<T> &, const ConstantSubscripts &, bool first)
166 // and Done(Scalar<T> &).
167 template <typename T
, typename ACCUMULATOR
, typename ARRAY
>
168 static Constant
<T
> DoReduction(const Constant
<ARRAY
> &array
,
169 const Constant
<LogicalResult
> &mask
, std::optional
<int> &dim
,
170 const Scalar
<T
> &identity
, ACCUMULATOR
&accumulator
) {
171 ConstantSubscripts at
{array
.lbounds()};
172 ConstantSubscripts maskAt
{mask
.lbounds()};
173 std::vector
<typename Constant
<T
>::Element
> elements
;
174 ConstantSubscripts resultShape
; // empty -> scalar
175 if (dim
) { // DIM= is present, so result is an array
176 resultShape
= array
.shape();
177 resultShape
.erase(resultShape
.begin() + (*dim
- 1));
178 ConstantSubscript dimExtent
{array
.shape().at(*dim
- 1)};
179 CHECK(dimExtent
== mask
.shape().at(*dim
- 1));
180 ConstantSubscript
&dimAt
{at
[*dim
- 1]};
181 ConstantSubscript dimLbound
{dimAt
};
182 ConstantSubscript
&maskDimAt
{maskAt
[*dim
- 1]};
183 ConstantSubscript maskDimLbound
{maskDimAt
};
184 for (auto n
{GetSize(resultShape
)}; n
-- > 0;
185 IncrementSubscripts(at
, array
.shape()),
186 IncrementSubscripts(maskAt
, mask
.shape())) {
188 maskDimAt
= maskDimLbound
;
189 elements
.push_back(identity
);
190 bool firstUnmasked
{true};
191 for (ConstantSubscript j
{0}; j
< dimExtent
; ++j
, ++dimAt
, ++maskDimAt
) {
192 if (mask
.At(maskAt
).IsTrue()) {
193 accumulator(elements
.back(), at
, firstUnmasked
);
194 firstUnmasked
= false;
197 accumulator
.Done(elements
.back());
199 } else { // no DIM=, result is scalar
200 elements
.push_back(identity
);
201 bool firstUnmasked
{true};
202 for (auto n
{array
.size()}; n
-- > 0; IncrementSubscripts(at
, array
.shape()),
203 IncrementSubscripts(maskAt
, mask
.shape())) {
204 if (mask
.At(maskAt
).IsTrue()) {
205 accumulator(elements
.back(), at
, firstUnmasked
);
206 firstUnmasked
= false;
209 accumulator
.Done(elements
.back());
211 if constexpr (T::category
== TypeCategory::Character
) {
212 return {static_cast<ConstantSubscript
>(identity
.size()),
213 std::move(elements
), std::move(resultShape
)};
215 return {std::move(elements
), std::move(resultShape
)};
220 template <typename T
, bool ABS
= false> class MaxvalMinvalAccumulator
{
222 MaxvalMinvalAccumulator(
223 RelationalOperator opr
, FoldingContext
&context
, const Constant
<T
> &array
)
224 : opr_
{opr
}, context_
{context
}, array_
{array
} {};
225 void operator()(Scalar
<T
> &element
, const ConstantSubscripts
&at
,
226 [[maybe_unused
]] bool firstUnmasked
) const {
227 auto aAt
{array_
.At(at
)};
231 if constexpr (T::category
== TypeCategory::Real
) {
232 if (firstUnmasked
|| element
.IsNotANumber()) {
233 // Return NaN if and only if all unmasked elements are NaNs and
234 // at least one unmasked element is visible.
239 Expr
<LogicalResult
> test
{PackageRelation(
240 opr_
, Expr
<T
>{Constant
<T
>{aAt
}}, Expr
<T
>{Constant
<T
>{element
}})};
241 auto folded
{GetScalarConstantValue
<LogicalResult
>(
242 test
.Rewrite(context_
, std::move(test
)))};
243 CHECK(folded
.has_value());
244 if (folded
->IsTrue()) {
248 void Done(Scalar
<T
> &) const {}
251 RelationalOperator opr_
;
252 FoldingContext
&context_
;
253 const Constant
<T
> &array_
;
256 template <typename T
>
257 static Expr
<T
> FoldMaxvalMinval(FoldingContext
&context
, FunctionRef
<T
> &&ref
,
258 RelationalOperator opr
, const Scalar
<T
> &identity
) {
259 static_assert(T::category
== TypeCategory::Integer
||
260 T::category
== TypeCategory::Real
||
261 T::category
== TypeCategory::Character
);
262 std::optional
<int> dim
;
263 if (std::optional
<ArrayAndMask
<T
>> arrayAndMask
{
264 ProcessReductionArgs
<T
>(context
, ref
.arguments(), dim
,
265 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
266 MaxvalMinvalAccumulator accumulator
{opr
, context
, arrayAndMask
->array
};
267 return Expr
<T
>{DoReduction
<T
>(
268 arrayAndMask
->array
, arrayAndMask
->mask
, dim
, identity
, accumulator
)};
270 return Expr
<T
>{std::move(ref
)};
274 template <typename T
> class ProductAccumulator
{
276 ProductAccumulator(const Constant
<T
> &array
) : array_
{array
} {}
278 Scalar
<T
> &element
, const ConstantSubscripts
&at
, bool /*first*/) {
279 if constexpr (T::category
== TypeCategory::Integer
) {
280 auto prod
{element
.MultiplySigned(array_
.At(at
))};
281 overflow_
|= prod
.SignedMultiplicationOverflowed();
282 element
= prod
.lower
;
283 } else { // Real & Complex
284 auto prod
{element
.Multiply(array_
.At(at
))};
285 overflow_
|= prod
.flags
.test(RealFlag::Overflow
);
286 element
= prod
.value
;
289 bool overflow() const { return overflow_
; }
290 void Done(Scalar
<T
> &) const {}
293 const Constant
<T
> &array_
;
294 bool overflow_
{false};
297 template <typename T
>
298 static Expr
<T
> FoldProduct(
299 FoldingContext
&context
, FunctionRef
<T
> &&ref
, Scalar
<T
> identity
) {
300 static_assert(T::category
== TypeCategory::Integer
||
301 T::category
== TypeCategory::Real
||
302 T::category
== TypeCategory::Complex
);
303 std::optional
<int> dim
;
304 if (std::optional
<ArrayAndMask
<T
>> arrayAndMask
{
305 ProcessReductionArgs
<T
>(context
, ref
.arguments(), dim
,
306 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
307 ProductAccumulator accumulator
{arrayAndMask
->array
};
308 auto result
{Expr
<T
>{DoReduction
<T
>(
309 arrayAndMask
->array
, arrayAndMask
->mask
, dim
, identity
, accumulator
)}};
310 if (accumulator
.overflow()) {
311 context
.messages().Say(
312 "PRODUCT() of %s data overflowed"_warn_en_US
, T::AsFortran());
316 return Expr
<T
>{std::move(ref
)};
320 template <typename T
> class SumAccumulator
{
321 using Element
= typename Constant
<T
>::Element
;
324 SumAccumulator(const Constant
<T
> &array
, Rounding rounding
)
325 : array_
{array
}, rounding_
{rounding
} {}
327 Element
&element
, const ConstantSubscripts
&at
, bool /*first*/) {
328 if constexpr (T::category
== TypeCategory::Integer
) {
329 auto sum
{element
.AddSigned(array_
.At(at
))};
330 overflow_
|= sum
.overflow
;
332 } else { // Real & Complex: use Kahan summation
333 auto next
{array_
.At(at
).Add(correction_
, rounding_
)};
334 overflow_
|= next
.flags
.test(RealFlag::Overflow
);
335 auto sum
{element
.Add(next
.value
, rounding_
)};
336 overflow_
|= sum
.flags
.test(RealFlag::Overflow
);
337 // correction = (sum - element) - next; algebraically zero
338 correction_
= sum
.value
.Subtract(element
, rounding_
)
339 .value
.Subtract(next
.value
, rounding_
)
344 bool overflow() const { return overflow_
; }
345 void Done([[maybe_unused
]] Element
&element
) {
346 if constexpr (T::category
!= TypeCategory::Integer
) {
347 auto corrected
{element
.Add(correction_
, rounding_
)};
348 overflow_
|= corrected
.flags
.test(RealFlag::Overflow
);
349 correction_
= Scalar
<T
>{};
350 element
= corrected
.value
;
355 const Constant
<T
> &array_
;
357 bool overflow_
{false};
358 Element correction_
{};
361 template <typename T
>
362 static Expr
<T
> FoldSum(FoldingContext
&context
, FunctionRef
<T
> &&ref
) {
363 static_assert(T::category
== TypeCategory::Integer
||
364 T::category
== TypeCategory::Real
||
365 T::category
== TypeCategory::Complex
);
366 using Element
= typename Constant
<T
>::Element
;
367 std::optional
<int> dim
;
369 if (std::optional
<ArrayAndMask
<T
>> arrayAndMask
{
370 ProcessReductionArgs
<T
>(context
, ref
.arguments(), dim
,
371 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
372 SumAccumulator accumulator
{
373 arrayAndMask
->array
, context
.targetCharacteristics().roundingMode()};
374 auto result
{Expr
<T
>{DoReduction
<T
>(
375 arrayAndMask
->array
, arrayAndMask
->mask
, dim
, identity
, accumulator
)}};
376 if (accumulator
.overflow()) {
377 context
.messages().Say(
378 "SUM() of %s data overflowed"_warn_en_US
, T::AsFortran());
382 return Expr
<T
>{std::move(ref
)};
385 // Utility for IALL, IANY, IPARITY, ALL, ANY, & PARITY
386 template <typename T
> class OperationAccumulator
{
388 OperationAccumulator(const Constant
<T
> &array
,
389 Scalar
<T
> (Scalar
<T
>::*operation
)(const Scalar
<T
> &) const)
390 : array_
{array
}, operation_
{operation
} {}
392 Scalar
<T
> &element
, const ConstantSubscripts
&at
, bool /*first*/) {
393 element
= (element
.*operation_
)(array_
.At(at
));
395 void Done(Scalar
<T
> &) const {}
398 const Constant
<T
> &array_
;
399 Scalar
<T
> (Scalar
<T
>::*operation_
)(const Scalar
<T
> &) const;
402 } // namespace Fortran::evaluate
403 #endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_