Break circular dependency between FIR dialect and utilities
[llvm-project.git] / flang / lib / Evaluate / fold-logical.cpp
blob 43406497a9e66d3b6798f94a84e9e8921097d3d2
1 //===-- lib/Evaluate/fold-logical.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "fold-implementation.h"
10 #include "fold-reduction.h"
11 #include "flang/Evaluate/check-expression.h"
13 namespace Fortran::evaluate {
15 template <typename T>
16 static std::optional<Expr<SomeType>> ZeroExtend(const Constant<T> &c) {
17 std::vector<Scalar<LargestInt>> exts;
18 for (const auto &v : c.values()) {
19 exts.push_back(Scalar<LargestInt>::ConvertUnsigned(v).value);
21 return AsGenericExpr(
22 Constant<LargestInt>(std::move(exts), ConstantSubscripts(c.shape())));
25 // for ALL, ANY & PARITY
26 template <typename T>
27 static Expr<T> FoldAllAnyParity(FoldingContext &context, FunctionRef<T> &&ref,
28 Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
29 Scalar<T> identity) {
30 static_assert(T::category == TypeCategory::Logical);
31 using Element = Scalar<T>;
32 std::optional<int> dim;
33 if (std::optional<Constant<T>> array{
34 ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
35 /*ARRAY(MASK)=*/0, /*DIM=*/1)}) {
36 auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
37 element = (element.*operation)(array->At(at));
38 }};
39 return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
41 return Expr<T>{std::move(ref)};
44 template <int KIND>
45 Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
46 FoldingContext &context,
47 FunctionRef<Type<TypeCategory::Logical, KIND>> &&funcRef) {
48 using T = Type<TypeCategory::Logical, KIND>;
49 ActualArguments &args{funcRef.arguments()};
50 auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)};
51 CHECK(intrinsic);
52 std::string name{intrinsic->name};
53 using SameInt = Type<TypeCategory::Integer, KIND>;
54 if (name == "all") {
55 return FoldAllAnyParity(
56 context, std::move(funcRef), &Scalar<T>::AND, Scalar<T>{true});
57 } else if (name == "any") {
58 return FoldAllAnyParity(
59 context, std::move(funcRef), &Scalar<T>::OR, Scalar<T>{false});
60 } else if (name == "associated") {
61 bool gotConstant{true};
62 const Expr<SomeType> *firstArgExpr{args[0]->UnwrapExpr()};
63 if (!firstArgExpr || !IsNullPointer(*firstArgExpr)) {
64 gotConstant = false;
65 } else if (args[1]) { // There's a second argument
66 const Expr<SomeType> *secondArgExpr{args[1]->UnwrapExpr()};
67 if (!secondArgExpr || !IsNullPointer(*secondArgExpr)) {
68 gotConstant = false;
71 return gotConstant ? Expr<T>{false} : Expr<T>{std::move(funcRef)};
72 } else if (name == "bge" || name == "bgt" || name == "ble" || name == "blt") {
73 static_assert(std::is_same_v<Scalar<LargestInt>, BOZLiteralConstant>);
75 // The arguments to these intrinsics can be of different types. In that
76 // case, the shorter of the two would need to be zero-extended to match
77 // the size of the other. If at least one of the operands is not a constant,
78 // the zero-extending will be done during lowering. Otherwise, the folding
79 // must be done here.
80 std::optional<Expr<SomeType>> constArgs[2];
81 for (int i{0}; i <= 1; i++) {
82 if (BOZLiteralConstant * x{UnwrapExpr<BOZLiteralConstant>(args[i])}) {
83 constArgs[i] = AsGenericExpr(Constant<LargestInt>{std::move(*x)});
84 } else if (auto *x{UnwrapExpr<Expr<SomeInteger>>(args[i])}) {
85 common::visit(
86 [&](const auto &ix) {
87 using IntT = typename std::decay_t<decltype(ix)>::Result;
88 if (auto *c{UnwrapConstantValue<IntT>(ix)}) {
89 constArgs[i] = ZeroExtend(*c);
92 x->u);
96 if (constArgs[0] && constArgs[1]) {
97 auto fptr{&Scalar<LargestInt>::BGE};
98 if (name == "bge") { // done in fptr declaration
99 } else if (name == "bgt") {
100 fptr = &Scalar<LargestInt>::BGT;
101 } else if (name == "ble") {
102 fptr = &Scalar<LargestInt>::BLE;
103 } else if (name == "blt") {
104 fptr = &Scalar<LargestInt>::BLT;
105 } else {
106 common::die("missing case to fold intrinsic function %s", name.c_str());
109 for (int i{0}; i <= 1; i++) {
110 *args[i] = std::move(constArgs[i].value());
113 return FoldElementalIntrinsic<T, LargestInt, LargestInt>(context,
114 std::move(funcRef),
115 ScalarFunc<T, LargestInt, LargestInt>(
116 [&fptr](
117 const Scalar<LargestInt> &i, const Scalar<LargestInt> &j) {
118 return Scalar<T>{std::invoke(fptr, i, j)};
119 }));
120 } else {
121 return Expr<T>{std::move(funcRef)};
123 } else if (name == "btest") {
124 if (const auto *ix{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
125 return common::visit(
126 [&](const auto &x) {
127 using IT = ResultType<decltype(x)>;
128 return FoldElementalIntrinsic<T, IT, SameInt>(context,
129 std::move(funcRef),
130 ScalarFunc<T, IT, SameInt>(
131 [&](const Scalar<IT> &x, const Scalar<SameInt> &pos) {
132 auto posVal{pos.ToInt64()};
133 if (posVal < 0 || posVal >= x.bits) {
134 context.messages().Say(
135 "POS=%jd out of range for BTEST"_err_en_US,
136 static_cast<std::intmax_t>(posVal));
138 return Scalar<T>{x.BTEST(posVal)};
139 }));
141 ix->u);
143 } else if (name == "dot_product") {
144 return FoldDotProduct<T>(context, std::move(funcRef));
145 } else if (name == "extends_type_of") {
146 // Type extension testing with EXTENDS_TYPE_OF() ignores any type
147 // parameters. Returns a constant truth value when the result is known now.
148 if (args[0] && args[1]) {
149 auto t0{args[0]->GetType()};
150 auto t1{args[1]->GetType()};
151 if (t0 && t1) {
152 if (auto result{t0->ExtendsTypeOf(*t1)}) {
153 return Expr<T>{*result};
157 } else if (name == "isnan" || name == "__builtin_ieee_is_nan") {
158 using DefaultReal = Type<TypeCategory::Real, 4>;
159 // Only replace the type of the function if we can do the fold
160 if (args[0] && args[0]->UnwrapExpr() &&
161 IsActuallyConstant(*args[0]->UnwrapExpr())) {
162 return FoldElementalIntrinsic<T, DefaultReal>(context, std::move(funcRef),
163 ScalarFunc<T, DefaultReal>([](const Scalar<DefaultReal> &x) {
164 return Scalar<T>{x.IsNotANumber()};
165 }));
167 } else if (name == "__builtin_ieee_is_negative") {
168 auto restorer{context.messages().DiscardMessages()};
169 using DefaultReal = Type<TypeCategory::Real, 4>;
170 return FoldElementalIntrinsic<T, DefaultReal>(context, std::move(funcRef),
171 ScalarFunc<T, DefaultReal>([](const Scalar<DefaultReal> &x) {
172 return Scalar<T>{x.IsNegative()};
173 }));
174 } else if (name == "__builtin_ieee_is_normal") {
175 auto restorer{context.messages().DiscardMessages()};
176 using DefaultReal = Type<TypeCategory::Real, 4>;
177 if (args[0] && args[0]->UnwrapExpr() &&
178 IsActuallyConstant(*args[0]->UnwrapExpr())) {
179 return FoldElementalIntrinsic<T, DefaultReal>(context, std::move(funcRef),
180 ScalarFunc<T, DefaultReal>([](const Scalar<DefaultReal> &x) {
181 return Scalar<T>{x.IsNormal()};
182 }));
184 } else if (name == "is_contiguous") {
185 if (args.at(0)) {
186 if (auto *expr{args[0]->UnwrapExpr()}) {
187 if (auto contiguous{IsContiguous(*expr, context)}) {
188 return Expr<T>{*contiguous};
192 } else if (name == "lge" || name == "lgt" || name == "lle" || name == "llt") {
193 // Rewrite LGE/LGT/LLE/LLT into ASCII character relations
194 auto *cx0{UnwrapExpr<Expr<SomeCharacter>>(args[0])};
195 auto *cx1{UnwrapExpr<Expr<SomeCharacter>>(args[1])};
196 if (cx0 && cx1) {
197 return Fold(context,
198 ConvertToType<T>(
199 PackageRelation(name == "lge" ? RelationalOperator::GE
200 : name == "lgt" ? RelationalOperator::GT
201 : name == "lle" ? RelationalOperator::LE
202 : RelationalOperator::LT,
203 ConvertToType<Ascii>(std::move(*cx0)),
204 ConvertToType<Ascii>(std::move(*cx1)))));
206 } else if (name == "logical") {
207 if (auto *expr{UnwrapExpr<Expr<SomeLogical>>(args[0])}) {
208 return Fold(context, ConvertToType<T>(std::move(*expr)));
210 } else if (name == "merge") {
211 return FoldMerge<T>(context, std::move(funcRef));
212 } else if (name == "parity") {
213 return FoldAllAnyParity(
214 context, std::move(funcRef), &Scalar<T>::NEQV, Scalar<T>{false});
215 } else if (name == "same_type_as") {
216 // Type equality testing with SAME_TYPE_AS() ignores any type parameters.
217 // Returns a constant truth value when the result is known now.
218 if (args[0] && args[1]) {
219 auto t0{args[0]->GetType()};
220 auto t1{args[1]->GetType()};
221 if (t0 && t1) {
222 if (auto result{t0->SameTypeAs(*t1)}) {
223 return Expr<T>{*result};
227 } else if (name == "__builtin_ieee_support_datatype" ||
228 name == "__builtin_ieee_support_denormal" ||
229 name == "__builtin_ieee_support_divide" ||
230 name == "__builtin_ieee_support_inf" ||
231 name == "__builtin_ieee_support_io" ||
232 name == "__builtin_ieee_support_nan" ||
233 name == "__builtin_ieee_support_sqrt" ||
234 name == "__builtin_ieee_support_standard" ||
235 name == "__builtin_ieee_support_subnormal" ||
236 name == "__builtin_ieee_support_underflow_control") {
237 return Expr<T>{true};
239 // TODO: is_iostat_end,
240 // is_iostat_eor, logical, matmul, out_of_range,
241 // parity
242 return Expr<T>{std::move(funcRef)};
245 template <typename T>
246 Expr<LogicalResult> FoldOperation(
247 FoldingContext &context, Relational<T> &&relation) {
248 if (auto array{ApplyElementwise(context, relation,
249 std::function<Expr<LogicalResult>(Expr<T> &&, Expr<T> &&)>{
250 [=](Expr<T> &&x, Expr<T> &&y) {
251 return Expr<LogicalResult>{Relational<SomeType>{
252 Relational<T>{relation.opr, std::move(x), std::move(y)}}};
253 }})}) {
254 return *array;
256 if (auto folded{OperandsAreConstants(relation)}) {
257 bool result{};
258 if constexpr (T::category == TypeCategory::Integer) {
259 result =
260 Satisfies(relation.opr, folded->first.CompareSigned(folded->second));
261 } else if constexpr (T::category == TypeCategory::Real) {
262 result = Satisfies(relation.opr, folded->first.Compare(folded->second));
263 } else if constexpr (T::category == TypeCategory::Complex) {
264 result = (relation.opr == RelationalOperator::EQ) ==
265 folded->first.Equals(folded->second);
266 } else if constexpr (T::category == TypeCategory::Character) {
267 result = Satisfies(relation.opr, Compare(folded->first, folded->second));
268 } else {
269 static_assert(T::category != TypeCategory::Logical);
271 return Expr<LogicalResult>{Constant<LogicalResult>{result}};
273 return Expr<LogicalResult>{Relational<SomeType>{std::move(relation)}};
276 Expr<LogicalResult> FoldOperation(
277 FoldingContext &context, Relational<SomeType> &&relation) {
278 return common::visit(
279 [&](auto &&x) {
280 return Expr<LogicalResult>{FoldOperation(context, std::move(x))};
282 std::move(relation.u));
285 template <int KIND>
286 Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(
287 FoldingContext &context, Not<KIND> &&x) {
288 if (auto array{ApplyElementwise(context, x)}) {
289 return *array;
291 using Ty = Type<TypeCategory::Logical, KIND>;
292 auto &operand{x.left()};
293 if (auto value{GetScalarConstantValue<Ty>(operand)}) {
294 return Expr<Ty>{Constant<Ty>{!value->IsTrue()}};
296 return Expr<Ty>{x};
299 template <int KIND>
300 Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(
301 FoldingContext &context, LogicalOperation<KIND> &&operation) {
302 using LOGICAL = Type<TypeCategory::Logical, KIND>;
303 if (auto array{ApplyElementwise(context, operation,
304 std::function<Expr<LOGICAL>(Expr<LOGICAL> &&, Expr<LOGICAL> &&)>{
305 [=](Expr<LOGICAL> &&x, Expr<LOGICAL> &&y) {
306 return Expr<LOGICAL>{LogicalOperation<KIND>{
307 operation.logicalOperator, std::move(x), std::move(y)}};
308 }})}) {
309 return *array;
311 if (auto folded{OperandsAreConstants(operation)}) {
312 bool xt{folded->first.IsTrue()}, yt{folded->second.IsTrue()}, result{};
313 switch (operation.logicalOperator) {
314 case LogicalOperator::And:
315 result = xt && yt;
316 break;
317 case LogicalOperator::Or:
318 result = xt || yt;
319 break;
320 case LogicalOperator::Eqv:
321 result = xt == yt;
322 break;
323 case LogicalOperator::Neqv:
324 result = xt != yt;
325 break;
326 case LogicalOperator::Not:
327 DIE("not a binary operator");
329 return Expr<LOGICAL>{Constant<LOGICAL>{result}};
331 return Expr<LOGICAL>{std::move(operation)};
334 #ifdef _MSC_VER // disable bogus warning about missing definitions
335 #pragma warning(disable : 4661)
336 #endif
337 FOR_EACH_LOGICAL_KIND(template class ExpressionBase, )
338 template class ExpressionBase<SomeLogical>;
339 } // namespace Fortran::evaluate