[Hexagon] Use llvm::children (NFC)
[llvm-project.git] / flang / lib / Evaluate / fold-logical.cpp
blob5a9596f3c274b5af879f3a8eee21dcc4c399c552
1 //===-- lib/Evaluate/fold-logical.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "fold-implementation.h"
10 #include "fold-matmul.h"
11 #include "fold-reduction.h"
12 #include "flang/Evaluate/check-expression.h"
13 #include "flang/Runtime/magic-numbers.h"
15 namespace Fortran::evaluate {
17 template <typename T>
18 static std::optional<Expr<SomeType>> ZeroExtend(const Constant<T> &c) {
19 std::vector<Scalar<LargestInt>> exts;
20 for (const auto &v : c.values()) {
21 exts.push_back(Scalar<LargestInt>::ConvertUnsigned(v).value);
23 return AsGenericExpr(
24 Constant<LargestInt>(std::move(exts), ConstantSubscripts(c.shape())));
27 // for ALL, ANY & PARITY
28 template <typename T>
29 static Expr<T> FoldAllAnyParity(FoldingContext &context, FunctionRef<T> &&ref,
30 Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
31 Scalar<T> identity) {
32 static_assert(T::category == TypeCategory::Logical);
33 std::optional<int> dim;
34 if (std::optional<ArrayAndMask<T>> arrayAndMask{
35 ProcessReductionArgs<T>(context, ref.arguments(), dim,
36 /*ARRAY(MASK)=*/0, /*DIM=*/1)}) {
37 OperationAccumulator accumulator{arrayAndMask->array, operation};
38 return Expr<T>{DoReduction<T>(
39 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)};
41 return Expr<T>{std::move(ref)};
44 template <int KIND>
45 Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
46 FoldingContext &context,
47 FunctionRef<Type<TypeCategory::Logical, KIND>> &&funcRef) {
48 using T = Type<TypeCategory::Logical, KIND>;
49 ActualArguments &args{funcRef.arguments()};
50 auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)};
51 CHECK(intrinsic);
52 std::string name{intrinsic->name};
53 using SameInt = Type<TypeCategory::Integer, KIND>;
54 if (name == "all") {
55 return FoldAllAnyParity(
56 context, std::move(funcRef), &Scalar<T>::AND, Scalar<T>{true});
57 } else if (name == "any") {
58 return FoldAllAnyParity(
59 context, std::move(funcRef), &Scalar<T>::OR, Scalar<T>{false});
60 } else if (name == "associated") {
61 bool gotConstant{true};
62 const Expr<SomeType> *firstArgExpr{args[0]->UnwrapExpr()};
63 if (!firstArgExpr || !IsNullPointer(*firstArgExpr)) {
64 gotConstant = false;
65 } else if (args[1]) { // There's a second argument
66 const Expr<SomeType> *secondArgExpr{args[1]->UnwrapExpr()};
67 if (!secondArgExpr || !IsNullPointer(*secondArgExpr)) {
68 gotConstant = false;
71 return gotConstant ? Expr<T>{false} : Expr<T>{std::move(funcRef)};
72 } else if (name == "bge" || name == "bgt" || name == "ble" || name == "blt") {
73 static_assert(std::is_same_v<Scalar<LargestInt>, BOZLiteralConstant>);
75 // The arguments to these intrinsics can be of different types. In that
76 // case, the shorter of the two would need to be zero-extended to match
77 // the size of the other. If at least one of the operands is not a constant,
78 // the zero-extending will be done during lowering. Otherwise, the folding
79 // must be done here.
80 std::optional<Expr<SomeType>> constArgs[2];
81 for (int i{0}; i <= 1; i++) {
82 if (BOZLiteralConstant * x{UnwrapExpr<BOZLiteralConstant>(args[i])}) {
83 constArgs[i] = AsGenericExpr(Constant<LargestInt>{std::move(*x)});
84 } else if (auto *x{UnwrapExpr<Expr<SomeInteger>>(args[i])}) {
85 common::visit(
86 [&](const auto &ix) {
87 using IntT = typename std::decay_t<decltype(ix)>::Result;
88 if (auto *c{UnwrapConstantValue<IntT>(ix)}) {
89 constArgs[i] = ZeroExtend(*c);
92 x->u);
96 if (constArgs[0] && constArgs[1]) {
97 auto fptr{&Scalar<LargestInt>::BGE};
98 if (name == "bge") { // done in fptr declaration
99 } else if (name == "bgt") {
100 fptr = &Scalar<LargestInt>::BGT;
101 } else if (name == "ble") {
102 fptr = &Scalar<LargestInt>::BLE;
103 } else if (name == "blt") {
104 fptr = &Scalar<LargestInt>::BLT;
105 } else {
106 common::die("missing case to fold intrinsic function %s", name.c_str());
109 for (int i{0}; i <= 1; i++) {
110 *args[i] = std::move(constArgs[i].value());
113 return FoldElementalIntrinsic<T, LargestInt, LargestInt>(context,
114 std::move(funcRef),
115 ScalarFunc<T, LargestInt, LargestInt>(
116 [&fptr](
117 const Scalar<LargestInt> &i, const Scalar<LargestInt> &j) {
118 return Scalar<T>{std::invoke(fptr, i, j)};
119 }));
120 } else {
121 return Expr<T>{std::move(funcRef)};
123 } else if (name == "btest") {
124 if (const auto *ix{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
125 return common::visit(
126 [&](const auto &x) {
127 using IT = ResultType<decltype(x)>;
128 return FoldElementalIntrinsic<T, IT, SameInt>(context,
129 std::move(funcRef),
130 ScalarFunc<T, IT, SameInt>(
131 [&](const Scalar<IT> &x, const Scalar<SameInt> &pos) {
132 auto posVal{pos.ToInt64()};
133 if (posVal < 0 || posVal >= x.bits) {
134 context.messages().Say(
135 "POS=%jd out of range for BTEST"_err_en_US,
136 static_cast<std::intmax_t>(posVal));
138 return Scalar<T>{x.BTEST(posVal)};
139 }));
141 ix->u);
143 } else if (name == "dot_product") {
144 return FoldDotProduct<T>(context, std::move(funcRef));
145 } else if (name == "extends_type_of") {
146 // Type extension testing with EXTENDS_TYPE_OF() ignores any type
147 // parameters. Returns a constant truth value when the result is known now.
148 if (args[0] && args[1]) {
149 auto t0{args[0]->GetType()};
150 auto t1{args[1]->GetType()};
151 if (t0 && t1) {
152 if (auto result{t0->ExtendsTypeOf(*t1)}) {
153 return Expr<T>{*result};
157 } else if (name == "isnan" || name == "__builtin_ieee_is_nan") {
158 // Only replace the type of the function if we can do the fold
159 if (args[0] && args[0]->UnwrapExpr() &&
160 IsActuallyConstant(*args[0]->UnwrapExpr())) {
161 auto restorer{context.messages().DiscardMessages()};
162 using DefaultReal = Type<TypeCategory::Real, 4>;
163 return FoldElementalIntrinsic<T, DefaultReal>(context, std::move(funcRef),
164 ScalarFunc<T, DefaultReal>([](const Scalar<DefaultReal> &x) {
165 return Scalar<T>{x.IsNotANumber()};
166 }));
168 } else if (name == "__builtin_ieee_is_negative") {
169 auto restorer{context.messages().DiscardMessages()};
170 using DefaultReal = Type<TypeCategory::Real, 4>;
171 if (args[0] && args[0]->UnwrapExpr() &&
172 IsActuallyConstant(*args[0]->UnwrapExpr())) {
173 return FoldElementalIntrinsic<T, DefaultReal>(context, std::move(funcRef),
174 ScalarFunc<T, DefaultReal>([](const Scalar<DefaultReal> &x) {
175 return Scalar<T>{x.IsNegative()};
176 }));
178 } else if (name == "__builtin_ieee_is_normal") {
179 auto restorer{context.messages().DiscardMessages()};
180 using DefaultReal = Type<TypeCategory::Real, 4>;
181 if (args[0] && args[0]->UnwrapExpr() &&
182 IsActuallyConstant(*args[0]->UnwrapExpr())) {
183 return FoldElementalIntrinsic<T, DefaultReal>(context, std::move(funcRef),
184 ScalarFunc<T, DefaultReal>([](const Scalar<DefaultReal> &x) {
185 return Scalar<T>{x.IsNormal()};
186 }));
188 } else if (name == "is_contiguous") {
189 if (args.at(0)) {
190 if (auto *expr{args[0]->UnwrapExpr()}) {
191 if (auto contiguous{IsContiguous(*expr, context)}) {
192 return Expr<T>{*contiguous};
194 } else if (auto *assumedType{args[0]->GetAssumedTypeDummy()}) {
195 if (auto contiguous{IsContiguous(*assumedType, context)}) {
196 return Expr<T>{*contiguous};
200 } else if (name == "is_iostat_end") {
201 if (args[0] && args[0]->UnwrapExpr() &&
202 IsActuallyConstant(*args[0]->UnwrapExpr())) {
203 using Int64 = Type<TypeCategory::Integer, 8>;
204 return FoldElementalIntrinsic<T, Int64>(context, std::move(funcRef),
205 ScalarFunc<T, Int64>([](const Scalar<Int64> &x) {
206 return Scalar<T>{x.ToInt64() == FORTRAN_RUNTIME_IOSTAT_END};
207 }));
209 } else if (name == "is_iostat_eor") {
210 if (args[0] && args[0]->UnwrapExpr() &&
211 IsActuallyConstant(*args[0]->UnwrapExpr())) {
212 using Int64 = Type<TypeCategory::Integer, 8>;
213 return FoldElementalIntrinsic<T, Int64>(context, std::move(funcRef),
214 ScalarFunc<T, Int64>([](const Scalar<Int64> &x) {
215 return Scalar<T>{x.ToInt64() == FORTRAN_RUNTIME_IOSTAT_EOR};
216 }));
218 } else if (name == "lge" || name == "lgt" || name == "lle" || name == "llt") {
219 // Rewrite LGE/LGT/LLE/LLT into ASCII character relations
220 auto *cx0{UnwrapExpr<Expr<SomeCharacter>>(args[0])};
221 auto *cx1{UnwrapExpr<Expr<SomeCharacter>>(args[1])};
222 if (cx0 && cx1) {
223 return Fold(context,
224 ConvertToType<T>(
225 PackageRelation(name == "lge" ? RelationalOperator::GE
226 : name == "lgt" ? RelationalOperator::GT
227 : name == "lle" ? RelationalOperator::LE
228 : RelationalOperator::LT,
229 ConvertToType<Ascii>(std::move(*cx0)),
230 ConvertToType<Ascii>(std::move(*cx1)))));
232 } else if (name == "logical") {
233 if (auto *expr{UnwrapExpr<Expr<SomeLogical>>(args[0])}) {
234 return Fold(context, ConvertToType<T>(std::move(*expr)));
236 } else if (name == "matmul") {
237 return FoldMatmul(context, std::move(funcRef));
238 } else if (name == "out_of_range") {
239 if (Expr<SomeType> * cx{UnwrapExpr<Expr<SomeType>>(args[0])}) {
240 auto restorer{context.messages().DiscardMessages()};
241 *args[0] = Fold(context, std::move(*cx));
242 if (Expr<SomeType> & folded{DEREF(args[0].value().UnwrapExpr())};
243 IsActuallyConstant(folded)) {
244 std::optional<std::vector<typename T::Scalar>> result;
245 if (Expr<SomeReal> * realMold{UnwrapExpr<Expr<SomeReal>>(args[1])}) {
246 if (const auto *xInt{UnwrapExpr<Expr<SomeInteger>>(folded)}) {
247 result.emplace();
248 std::visit(
249 [&](const auto &mold, const auto &x) {
250 using RealType =
251 typename std::decay_t<decltype(mold)>::Result;
252 static_assert(RealType::category == TypeCategory::Real);
253 using Scalar = typename RealType::Scalar;
254 using xType = typename std::decay_t<decltype(x)>::Result;
255 const auto &xConst{DEREF(UnwrapExpr<Constant<xType>>(x))};
256 for (const auto &elt : xConst.values()) {
257 result->emplace_back(
258 Scalar::template FromInteger(elt).flags.test(
259 RealFlag::Overflow));
262 realMold->u, xInt->u);
263 } else if (const auto *xReal{UnwrapExpr<Expr<SomeReal>>(folded)}) {
264 result.emplace();
265 std::visit(
266 [&](const auto &mold, const auto &x) {
267 using RealType =
268 typename std::decay_t<decltype(mold)>::Result;
269 static_assert(RealType::category == TypeCategory::Real);
270 using Scalar = typename RealType::Scalar;
271 using xType = typename std::decay_t<decltype(x)>::Result;
272 const auto &xConst{DEREF(UnwrapExpr<Constant<xType>>(x))};
273 for (const auto &elt : xConst.values()) {
274 result->emplace_back(elt.IsFinite() &&
275 Scalar::template Convert(elt).flags.test(
276 RealFlag::Overflow));
279 realMold->u, xReal->u);
281 } else if (Expr<SomeInteger> *
282 intMold{UnwrapExpr<Expr<SomeInteger>>(args[1])}) {
283 if (const auto *xInt{UnwrapExpr<Expr<SomeInteger>>(folded)}) {
284 result.emplace();
285 std::visit(
286 [&](const auto &mold, const auto &x) {
287 using IntType = typename std::decay_t<decltype(mold)>::Result;
288 static_assert(IntType::category == TypeCategory::Integer);
289 using Scalar = typename IntType::Scalar;
290 using xType = typename std::decay_t<decltype(x)>::Result;
291 const auto &xConst{DEREF(UnwrapExpr<Constant<xType>>(x))};
292 for (const auto &elt : xConst.values()) {
293 result->emplace_back(
294 Scalar::template ConvertSigned(elt).overflow);
297 intMold->u, xInt->u);
298 } else if (Expr<SomeLogical> *
299 cRound{args.size() >= 3
300 ? UnwrapExpr<Expr<SomeLogical>>(args[2])
301 : nullptr};
302 !cRound || IsActuallyConstant(*args[2]->UnwrapExpr())) {
303 if (const auto *xReal{UnwrapExpr<Expr<SomeReal>>(folded)}) {
304 common::RoundingMode roundingMode{common::RoundingMode::ToZero};
305 if (cRound &&
306 common::visit(
307 [](const auto &x) {
308 using xType =
309 typename std::decay_t<decltype(x)>::Result;
310 return GetScalarConstantValue<xType>(x)
311 .value()
312 .IsTrue();
314 cRound->u)) {
315 // ROUND=.TRUE. - convert with NINT()
316 roundingMode = common::RoundingMode::TiesAwayFromZero;
318 result.emplace();
319 std::visit(
320 [&](const auto &mold, const auto &x) {
321 using IntType =
322 typename std::decay_t<decltype(mold)>::Result;
323 static_assert(IntType::category == TypeCategory::Integer);
324 using Scalar = typename IntType::Scalar;
325 using xType = typename std::decay_t<decltype(x)>::Result;
326 const auto &xConst{DEREF(UnwrapExpr<Constant<xType>>(x))};
327 for (const auto &elt : xConst.values()) {
328 // Note that OUT_OF_RANGE(Inf/NaN) is .TRUE. for the
329 // real->integer case, but not for real->real.
330 result->emplace_back(!elt.IsFinite() ||
331 elt.template ToInteger<Scalar>(roundingMode)
332 .flags.test(RealFlag::Overflow));
335 intMold->u, xReal->u);
339 if (result) {
340 if (auto extents{GetConstantExtents(context, folded)}) {
341 return Expr<T>{
342 Constant<T>{std::move(*result), std::move(*extents)}};
347 } else if (name == "parity") {
348 return FoldAllAnyParity(
349 context, std::move(funcRef), &Scalar<T>::NEQV, Scalar<T>{false});
350 } else if (name == "same_type_as") {
351 // Type equality testing with SAME_TYPE_AS() ignores any type parameters.
352 // Returns a constant truth value when the result is known now.
353 if (args[0] && args[1]) {
354 auto t0{args[0]->GetType()};
355 auto t1{args[1]->GetType()};
356 if (t0 && t1) {
357 if (auto result{t0->SameTypeAs(*t1)}) {
358 return Expr<T>{*result};
362 } else if (name == "__builtin_ieee_support_datatype" ||
363 name == "__builtin_ieee_support_denormal" ||
364 name == "__builtin_ieee_support_divide" ||
365 name == "__builtin_ieee_support_inf" ||
366 name == "__builtin_ieee_support_io" ||
367 name == "__builtin_ieee_support_nan" ||
368 name == "__builtin_ieee_support_sqrt" ||
369 name == "__builtin_ieee_support_standard" ||
370 name == "__builtin_ieee_support_subnormal" ||
371 name == "__builtin_ieee_support_underflow_control") {
372 return Expr<T>{true};
374 return Expr<T>{std::move(funcRef)};
377 template <typename T>
378 Expr<LogicalResult> FoldOperation(
379 FoldingContext &context, Relational<T> &&relation) {
380 if (auto array{ApplyElementwise(context, relation,
381 std::function<Expr<LogicalResult>(Expr<T> &&, Expr<T> &&)>{
382 [=](Expr<T> &&x, Expr<T> &&y) {
383 return Expr<LogicalResult>{Relational<SomeType>{
384 Relational<T>{relation.opr, std::move(x), std::move(y)}}};
385 }})}) {
386 return *array;
388 if (auto folded{OperandsAreConstants(relation)}) {
389 bool result{};
390 if constexpr (T::category == TypeCategory::Integer) {
391 result =
392 Satisfies(relation.opr, folded->first.CompareSigned(folded->second));
393 } else if constexpr (T::category == TypeCategory::Real) {
394 result = Satisfies(relation.opr, folded->first.Compare(folded->second));
395 } else if constexpr (T::category == TypeCategory::Complex) {
396 result = (relation.opr == RelationalOperator::EQ) ==
397 folded->first.Equals(folded->second);
398 } else if constexpr (T::category == TypeCategory::Character) {
399 result = Satisfies(relation.opr, Compare(folded->first, folded->second));
400 } else {
401 static_assert(T::category != TypeCategory::Logical);
403 return Expr<LogicalResult>{Constant<LogicalResult>{result}};
405 return Expr<LogicalResult>{Relational<SomeType>{std::move(relation)}};
408 Expr<LogicalResult> FoldOperation(
409 FoldingContext &context, Relational<SomeType> &&relation) {
410 return common::visit(
411 [&](auto &&x) {
412 return Expr<LogicalResult>{FoldOperation(context, std::move(x))};
414 std::move(relation.u));
417 template <int KIND>
418 Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(
419 FoldingContext &context, Not<KIND> &&x) {
420 if (auto array{ApplyElementwise(context, x)}) {
421 return *array;
423 using Ty = Type<TypeCategory::Logical, KIND>;
424 auto &operand{x.left()};
425 if (auto value{GetScalarConstantValue<Ty>(operand)}) {
426 return Expr<Ty>{Constant<Ty>{!value->IsTrue()}};
428 return Expr<Ty>{x};
431 template <int KIND>
432 Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(
433 FoldingContext &context, LogicalOperation<KIND> &&operation) {
434 using LOGICAL = Type<TypeCategory::Logical, KIND>;
435 if (auto array{ApplyElementwise(context, operation,
436 std::function<Expr<LOGICAL>(Expr<LOGICAL> &&, Expr<LOGICAL> &&)>{
437 [=](Expr<LOGICAL> &&x, Expr<LOGICAL> &&y) {
438 return Expr<LOGICAL>{LogicalOperation<KIND>{
439 operation.logicalOperator, std::move(x), std::move(y)}};
440 }})}) {
441 return *array;
443 if (auto folded{OperandsAreConstants(operation)}) {
444 bool xt{folded->first.IsTrue()}, yt{folded->second.IsTrue()}, result{};
445 switch (operation.logicalOperator) {
446 case LogicalOperator::And:
447 result = xt && yt;
448 break;
449 case LogicalOperator::Or:
450 result = xt || yt;
451 break;
452 case LogicalOperator::Eqv:
453 result = xt == yt;
454 break;
455 case LogicalOperator::Neqv:
456 result = xt != yt;
457 break;
458 case LogicalOperator::Not:
459 DIE("not a binary operator");
461 return Expr<LOGICAL>{Constant<LOGICAL>{result}};
463 return Expr<LOGICAL>{std::move(operation)};
#ifdef _MSC_VER // disable bogus warning about missing definitions
#pragma warning(disable : 4661)
#endif
// Explicit instantiation definitions of ExpressionBase for the LOGICAL types:
// the macro presumably expands to one instantiation per LOGICAL kind
// (TODO confirm against its definition), followed by SomeLogical.
FOR_EACH_LOGICAL_KIND(template class ExpressionBase, )
template class ExpressionBase<SomeLogical>;
471 } // namespace Fortran::evaluate