[MemProf] Templatize CallStackRadixTreeBuilder (NFC) (#117014)
[llvm-project.git] / flang / lib / Evaluate / fold-integer.cpp
blob0ad09d76a6555d889003f56e22b791dd15d6d28b
1 //===-- lib/Evaluate/fold-integer.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "fold-implementation.h"
10 #include "fold-matmul.h"
11 #include "fold-reduction.h"
12 #include "flang/Evaluate/check-expression.h"
14 namespace Fortran::evaluate {
16 // Given a collection of ConstantSubscripts values, package them as a Constant.
17 // Return scalar value if asScalar == true and shape-dim array otherwise.
18 template <typename T>
19 Expr<T> PackageConstantBounds(
20 const ConstantSubscripts &&bounds, bool asScalar = false) {
21 if (asScalar) {
22 return Expr<T>{Constant<T>{bounds.at(0)}};
23 } else {
24 // As rank-dim array
25 const int rank{GetRank(bounds)};
26 std::vector<Scalar<T>> packed(rank);
27 std::transform(bounds.begin(), bounds.end(), packed.begin(),
28 [](ConstantSubscript x) { return Scalar<T>(x); });
29 return Expr<T>{Constant<T>{std::move(packed), ConstantSubscripts{rank}}};
33 // If a DIM= argument to LBOUND(), UBOUND(), or SIZE() exists and has a valid
34 // constant value, return in "dimVal" that value, less 1 (to make it suitable
35 // for use as a C++ vector<> index). Also check for erroneous constant values
36 // and returns false on error.
37 static bool CheckDimArg(const std::optional<ActualArgument> &dimArg,
38 const Expr<SomeType> &array, parser::ContextualMessages &messages,
39 bool isLBound, std::optional<int> &dimVal) {
40 dimVal.reset();
41 if (int rank{array.Rank()}; rank > 0 || IsAssumedRank(array)) {
42 auto named{ExtractNamedEntity(array)};
43 if (auto dim64{ToInt64(dimArg)}) {
44 if (*dim64 < 1) {
45 messages.Say("DIM=%jd dimension must be positive"_err_en_US, *dim64);
46 return false;
47 } else if (!IsAssumedRank(array) && *dim64 > rank) {
48 messages.Say(
49 "DIM=%jd dimension is out of range for rank-%d array"_err_en_US,
50 *dim64, rank);
51 return false;
52 } else if (!isLBound && named &&
53 semantics::IsAssumedSizeArray(named->GetLastSymbol()) &&
54 *dim64 == rank) {
55 messages.Say(
56 "DIM=%jd dimension is out of range for rank-%d assumed-size array"_err_en_US,
57 *dim64, rank);
58 return false;
59 } else if (IsAssumedRank(array)) {
60 if (*dim64 > common::maxRank) {
61 messages.Say(
62 "DIM=%jd dimension is too large for any array (maximum rank %d)"_err_en_US,
63 *dim64, common::maxRank);
64 return false;
66 } else {
67 dimVal = static_cast<int>(*dim64 - 1); // 1-based to 0-based
71 return true;
74 // Class to retrieve the constant bound of an expression which is an
75 // array that devolves to a type of Constant<T>
76 class GetConstantArrayBoundHelper {
77 public:
78 template <typename T>
79 static Expr<T> GetLbound(
80 const Expr<SomeType> &array, std::optional<int> dim) {
81 return PackageConstantBounds<T>(
82 GetConstantArrayBoundHelper(dim, /*getLbound=*/true).Get(array),
83 dim.has_value());
86 template <typename T>
87 static Expr<T> GetUbound(
88 const Expr<SomeType> &array, std::optional<int> dim) {
89 return PackageConstantBounds<T>(
90 GetConstantArrayBoundHelper(dim, /*getLbound=*/false).Get(array),
91 dim.has_value());
94 private:
95 GetConstantArrayBoundHelper(
96 std::optional<ConstantSubscript> dim, bool getLbound)
97 : dim_{dim}, getLbound_{getLbound} {}
99 template <typename T> ConstantSubscripts Get(const T &) {
100 // The method is needed for template expansion, but we should never get
101 // here in practice.
102 CHECK(false);
103 return {0};
106 template <typename T> ConstantSubscripts Get(const Constant<T> &x) {
107 if (getLbound_) {
108 // Return the lower bound
109 if (dim_) {
110 return {x.lbounds().at(*dim_)};
111 } else {
112 return x.lbounds();
114 } else {
115 // Return the upper bound
116 if (arrayFromParenthesesExpr) {
117 // Underlying array comes from (x) expression - return shapes
118 if (dim_) {
119 return {x.shape().at(*dim_)};
120 } else {
121 return x.shape();
123 } else {
124 return x.ComputeUbounds(dim_);
129 template <typename T> ConstantSubscripts Get(const Parentheses<T> &x) {
130 // Case of temp variable inside parentheses - return [1, ... 1] for lower
131 // bounds and shape for upper bounds
132 if (getLbound_) {
133 return ConstantSubscripts(x.Rank(), ConstantSubscript{1});
134 } else {
135 // Indicate that underlying array comes from parentheses expression.
136 // Continue to unwrap expression until we hit a constant
137 arrayFromParenthesesExpr = true;
138 return Get(x.left());
142 template <typename T> ConstantSubscripts Get(const Expr<T> &x) {
143 // recurse through Expr<T>'a until we hit a constant
144 return common::visit([&](const auto &inner) { return Get(inner); },
145 // [&](const auto &) { return 0; },
146 x.u);
149 const std::optional<ConstantSubscript> dim_;
150 const bool getLbound_;
151 bool arrayFromParenthesesExpr{false};
154 template <int KIND>
155 Expr<Type<TypeCategory::Integer, KIND>> LBOUND(FoldingContext &context,
156 FunctionRef<Type<TypeCategory::Integer, KIND>> &&funcRef) {
157 using T = Type<TypeCategory::Integer, KIND>;
158 ActualArguments &args{funcRef.arguments()};
159 if (const auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) {
160 std::optional<int> dim;
161 if (funcRef.Rank() == 0) {
162 // Optional DIM= argument is present: result is scalar.
163 if (!CheckDimArg(args[1], *array, context.messages(), true, dim)) {
164 return MakeInvalidIntrinsic<T>(std::move(funcRef));
165 } else if (!dim) {
166 // DIM= is present but not constant, or error
167 return Expr<T>{std::move(funcRef)};
170 if (IsAssumedRank(*array)) {
171 // Would like to return 1 if DIM=.. is present, but that would be
172 // hiding a runtime error if the DIM= were too large (including
173 // the case of an assumed-rank argument that's scalar).
174 } else if (int rank{array->Rank()}; rank > 0) {
175 bool lowerBoundsAreOne{true};
176 if (auto named{ExtractNamedEntity(*array)}) {
177 const Symbol &symbol{named->GetLastSymbol()};
178 if (symbol.Rank() == rank) {
179 lowerBoundsAreOne = false;
180 if (dim) {
181 if (auto lb{GetLBOUND(context, *named, *dim)}) {
182 return Fold(context, ConvertToType<T>(std::move(*lb)));
184 } else if (auto extents{
185 AsExtentArrayExpr(GetLBOUNDs(context, *named))}) {
186 return Fold(context,
187 ConvertToType<T>(Expr<ExtentType>{std::move(*extents)}));
189 } else {
190 lowerBoundsAreOne = symbol.Rank() == 0; // LBOUND(array%component)
193 if (IsActuallyConstant(*array)) {
194 return GetConstantArrayBoundHelper::GetLbound<T>(*array, dim);
196 if (lowerBoundsAreOne) {
197 ConstantSubscripts ones(rank, ConstantSubscript{1});
198 return PackageConstantBounds<T>(std::move(ones), dim.has_value());
202 return Expr<T>{std::move(funcRef)};
205 template <int KIND>
206 Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
207 FunctionRef<Type<TypeCategory::Integer, KIND>> &&funcRef) {
208 using T = Type<TypeCategory::Integer, KIND>;
209 ActualArguments &args{funcRef.arguments()};
210 if (auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) {
211 std::optional<int> dim;
212 if (funcRef.Rank() == 0) {
213 // Optional DIM= argument is present: result is scalar.
214 if (!CheckDimArg(args[1], *array, context.messages(), false, dim)) {
215 return MakeInvalidIntrinsic<T>(std::move(funcRef));
216 } else if (!dim) {
217 // DIM= is present but not constant, or error
218 return Expr<T>{std::move(funcRef)};
221 if (IsAssumedRank(*array)) {
222 } else if (int rank{array->Rank()}; rank > 0) {
223 bool takeBoundsFromShape{true};
224 if (auto named{ExtractNamedEntity(*array)}) {
225 const Symbol &symbol{named->GetLastSymbol()};
226 if (symbol.Rank() == rank) {
227 takeBoundsFromShape = false;
228 if (dim) {
229 if (auto ub{GetUBOUND(context, *named, *dim)}) {
230 return Fold(context, ConvertToType<T>(std::move(*ub)));
232 } else {
233 Shape ubounds{GetUBOUNDs(context, *named)};
234 if (semantics::IsAssumedSizeArray(symbol)) {
235 CHECK(!ubounds.back());
236 ubounds.back() = ExtentExpr{-1};
238 if (auto extents{AsExtentArrayExpr(ubounds)}) {
239 return Fold(context,
240 ConvertToType<T>(Expr<ExtentType>{std::move(*extents)}));
243 } else {
244 takeBoundsFromShape = symbol.Rank() == 0; // UBOUND(array%component)
247 if (IsActuallyConstant(*array)) {
248 return GetConstantArrayBoundHelper::GetUbound<T>(*array, dim);
250 if (takeBoundsFromShape) {
251 if (auto shape{GetContextFreeShape(context, *array)}) {
252 if (dim) {
253 if (auto &dimSize{shape->at(*dim)}) {
254 return Fold(context,
255 ConvertToType<T>(Expr<ExtentType>{std::move(*dimSize)}));
257 } else if (auto shapeExpr{AsExtentArrayExpr(*shape)}) {
258 return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
264 return Expr<T>{std::move(funcRef)};
267 // COUNT()
268 template <typename T, int MASK_KIND> class CountAccumulator {
269 using MaskT = Type<TypeCategory::Logical, MASK_KIND>;
271 public:
272 CountAccumulator(const Constant<MaskT> &mask) : mask_{mask} {}
273 void operator()(
274 Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
275 if (mask_.At(at).IsTrue()) {
276 auto incremented{element.AddSigned(Scalar<T>{1})};
277 overflow_ |= incremented.overflow;
278 element = incremented.value;
281 bool overflow() const { return overflow_; }
282 void Done(Scalar<T> &) const {}
284 private:
285 const Constant<MaskT> &mask_;
286 bool overflow_{false};
289 template <typename T, int maskKind>
290 static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
291 using KindLogical = Type<TypeCategory::Logical, maskKind>;
292 static_assert(T::category == TypeCategory::Integer);
293 std::optional<int> dim;
294 if (std::optional<ArrayAndMask<KindLogical>> arrayAndMask{
295 ProcessReductionArgs<KindLogical>(
296 context, ref.arguments(), dim, /*ARRAY=*/0, /*DIM=*/1)}) {
297 CountAccumulator<T, maskKind> accumulator{arrayAndMask->array};
298 Constant<T> result{DoReduction<T>(arrayAndMask->array, arrayAndMask->mask,
299 dim, Scalar<T>{}, accumulator)};
300 if (accumulator.overflow() &&
301 context.languageFeatures().ShouldWarn(
302 common::UsageWarning::FoldingException)) {
303 context.messages().Say(common::UsageWarning::FoldingException,
304 "Result of intrinsic function COUNT overflows its result type"_warn_en_US);
306 return Expr<T>{std::move(result)};
308 return Expr<T>{std::move(ref)};
// FINDLOC(), MAXLOC(), & MINLOC()
// Identifies which of the three location intrinsics is being folded.
enum class WhichLocation { Findloc, Maxloc, Minloc };
// Folds FINDLOC, MAXLOC, or MINLOC when its arguments are constant.
// Test<T>() does the work only for the type T that matches "type_" and
// returns std::nullopt for every other T; the helper is handed to
// common::SearchTypes() together with the "Types" list.  The result holds
// subscripts relative to lower bounds of 1 (the array's lower bounds are reset
// to 1 before searching); a subscript of 0 means no element was selected.
template <WhichLocation WHICH> class LocationHelper {
public:
  // Only references to "arg" and "context" are kept, so both must outlive
  // the helper.
  LocationHelper(
      DynamicType &&type, ActualArguments &arg, FoldingContext &context)
      : type_{type}, arg_{arg}, context_{context} {}
  using Result = std::optional<Constant<SubscriptInteger>>;
  // FINDLOC accepts any intrinsic type; MAXLOC/MINLOC only relational types.
  using Types = std::conditional_t<WHICH == WhichLocation::Findloc,
      AllIntrinsicTypes, RelationalTypes>;

  // Attempts the fold assuming the array elements have type T.  Returns
  // std::nullopt when T is not the searched-for type, when a needed argument
  // is not constant, or when DIM= is unusable.
  template <typename T> Result Test() const {
    if (T::category != type_.category() || T::kind != type_.kind()) {
      return std::nullopt;
    }
    // FINDLOC has one more argument (VALUE=) than MAXLOC/MINLOC.
    CHECK(arg_.size() == (WHICH == WhichLocation::Findloc ? 6 : 5));
    Folder<T> folder{context_};
    Constant<T> *array{folder.Folding(arg_[0])};
    if (!array) {
      return std::nullopt;
    }
    // FINDLOC's VALUE= must also be a constant.
    std::optional<Constant<T>> value;
    if constexpr (WHICH == WhichLocation::Findloc) {
      if (const Constant<T> *p{folder.Folding(arg_[1])}) {
        value.emplace(*p);
      } else {
        return std::nullopt;
      }
    }
    // Give up when MASK= is present but GetReductionMASK() produced no
    // constant for it, or when CheckReductionDIM() rejects DIM=.
    std::optional<int> dim;
    Constant<LogicalResult> *mask{
        GetReductionMASK(arg_[maskArg], array->shape(), context_)};
    if ((!mask && arg_[maskArg]) ||
        !CheckReductionDIM(dim, context_, arg_, dimArg, array->Rank())) {
      return std::nullopt;
    }
    // BACK= defaults to .FALSE.; when present it must be a constant.
    bool back{false};
    if (arg_[backArg]) {
      const auto *backConst{
          Folder<LogicalResult>{context_, /*forOptionalArgument=*/true}.Folding(
              arg_[backArg])};
      if (backConst) {
        back = backConst->GetScalarValue().value().IsTrue();
      } else {
        return std::nullopt;
      }
    }
    // FINDLOC tests for equality.  MAXLOC/MINLOC use a strict comparison so
    // that later equal elements do not replace the first extreme element,
    // and a non-strict one with BACK= so that they do.
    const RelationalOperator relation{WHICH == WhichLocation::Findloc
            ? RelationalOperator::EQ
            : WHICH == WhichLocation::Maxloc
            ? (back ? RelationalOperator::GE : RelationalOperator::GT)
            : back ? RelationalOperator::LE
                   : RelationalOperator::LT};
    // Use lower bounds of 1 exclusively.
    // NOTE(review): this modifies, in place, the Constant<T> that the Folder
    // returned for ARRAY.
    array->SetLowerBoundsToOne();
    ConstantSubscripts at{array->lbounds()}, maskAt, resultIndices, resultShape;
    if (mask) {
      if (auto scalarMask{mask->GetScalarValue()}) {
        // Convert into array in case of scalar MASK= (for
        // MAXLOC/MINLOC/FINDLOC mask should be conformable)
        ConstantSubscript n{GetSize(array->shape())};
        std::vector<Scalar<LogicalResult>> mask_elements(
            n, Scalar<LogicalResult>{scalarMask.value()});
        *mask = Constant<LogicalResult>{
            std::move(mask_elements), ConstantSubscripts{array->shape()}};
      }
      mask->SetLowerBoundsToOne();
      maskAt = mask->lbounds();
    }
    if (dim) { // DIM=
      if (*dim < 1 || *dim > array->Rank()) {
        context_.messages().Say("DIM=%d is out of range"_err_en_US, *dim);
        return std::nullopt;
      }
      // The result has ARRAY's shape with dimension DIM removed; each result
      // element comes from scanning ARRAY along dimension DIM.
      int zbDim{*dim - 1};
      resultShape = array->shape();
      resultShape.erase(
          resultShape.begin() + zbDim); // scalar if array is vector
      ConstantSubscript dimLength{array->shape()[zbDim]};
      ConstantSubscript n{GetSize(resultShape)};
      for (ConstantSubscript j{0}; j < n; ++j) {
        ConstantSubscript hit{0};
        // MAXLOC/MINLOC restart the search for an extreme value on each scan.
        if constexpr (WHICH == WhichLocation::Maxloc ||
            WHICH == WhichLocation::Minloc) {
          value.reset();
        }
        for (ConstantSubscript k{0}; k < dimLength;
             ++k, ++at[zbDim], mask && ++maskAt[zbDim]) {
          if ((!mask || mask->At(maskAt).IsTrue()) &&
              IsHit(array->At(at), value, relation, back)) {
            hit = at[zbDim];
            if constexpr (WHICH == WhichLocation::Findloc) {
              if (!back) {
                break;
              }
            }
          }
        }
        resultIndices.emplace_back(hit);
        // Put the DIM subscript on its last position (1 for a zero-extent
        // dimension) before calling IncrementSubscripts(), presumably so that
        // the increment carries into the other subscripts; then rewind it.
        at[zbDim] = std::max<ConstantSubscript>(dimLength, 1);
        array->IncrementSubscripts(at);
        at[zbDim] = 1;
        if (mask) {
          maskAt[zbDim] = mask->lbounds()[zbDim] +
              std::max<ConstantSubscript>(dimLength, 1) - 1;
          mask->IncrementSubscripts(maskAt);
          maskAt[zbDim] = mask->lbounds()[zbDim];
        }
      }
    } else { // no DIM=
      // Visit every element via IncrementSubscripts(), keeping the subscripts
      // of the most recent hit; the result is a vector of length Rank().
      resultShape = ConstantSubscripts{array->Rank()}; // always a vector
      ConstantSubscript n{GetSize(array->shape())};
      resultIndices = ConstantSubscripts(array->Rank(), 0);
      for (ConstantSubscript j{0}; j < n; ++j, array->IncrementSubscripts(at),
           mask && mask->IncrementSubscripts(maskAt)) {
        if ((!mask || mask->At(maskAt).IsTrue()) &&
            IsHit(array->At(at), value, relation, back)) {
          resultIndices = at;
          if constexpr (WHICH == WhichLocation::Findloc) {
            if (!back) {
              break;
            }
          }
        }
      }
    }
    std::vector<Scalar<SubscriptInteger>> resultElements;
    for (ConstantSubscript j : resultIndices) {
      resultElements.emplace_back(j);
    }
    return Constant<SubscriptInteger>{
        std::move(resultElements), std::move(resultShape)};
  }

private:
  // Decides whether "element" is a hit.  When "value" holds something
  // (FINDLOC's VALUE=, or the best element so far for MAXLOC/MINLOC), the
  // element is compared with it using "relation", or .EQV. for LOGICAL, and
  // the comparison is folded.  With no value yet, the element is a hit.
  // For MAXLOC/MINLOC every hit also becomes the new "value".
  template <typename T>
  bool IsHit(typename Constant<T>::Element element,
      std::optional<Constant<T>> &value,
      [[maybe_unused]] RelationalOperator relation,
      [[maybe_unused]] bool back) const {
    std::optional<Expr<LogicalResult>> cmp;
    bool result{true};
    if (value) {
      if constexpr (T::category == TypeCategory::Logical) {
        // array(at) .EQV. value?
        static_assert(WHICH == WhichLocation::Findloc);
        cmp.emplace(ConvertToType<LogicalResult>(
            Expr<T>{LogicalOperation<T::kind>{LogicalOperator::Eqv,
                Expr<T>{Constant<T>{element}}, Expr<T>{Constant<T>{*value}}}}));
      } else { // compare array(at) to value
        if constexpr (T::category == TypeCategory::Real &&
            (WHICH == WhichLocation::Maxloc ||
                WHICH == WhichLocation::Minloc)) {
          if (value && value->GetScalarValue().value().IsNotANumber() &&
              (back || !element.IsNotANumber())) {
            // Replace NaN: a NaN best-so-far always yields to a non-NaN
            // element, and with BACK= to any element.
            cmp.emplace(Constant<LogicalResult>{Scalar<LogicalResult>{true}});
          }
        }
        if (!cmp) {
          cmp.emplace(PackageRelation(relation, Expr<T>{Constant<T>{element}},
              Expr<T>{Constant<T>{*value}}));
        }
      }
      Expr<LogicalResult> folded{Fold(context_, std::move(*cmp))};
      result = GetScalarConstantValue<LogicalResult>(folded).value().IsTrue();
    } else {
      // first unmasked element for MAXLOC/MINLOC - always take it
    }
    if constexpr (WHICH == WhichLocation::Maxloc ||
        WHICH == WhichLocation::Minloc) {
      if (result) {
        value.emplace(std::move(element));
      }
    }
    return result;
  }

  // Positions of DIM=, MASK=, and BACK= in arg_.  FINDLOC's extra VALUE=
  // argument shifts them by one; BACK= follows MASK= after one intervening
  // argument (presumably KIND=).
  static constexpr int dimArg{WHICH == WhichLocation::Findloc ? 2 : 1};
  static constexpr int maskArg{dimArg + 1};
  static constexpr int backArg{maskArg + 2};

  DynamicType type_;
  ActualArguments &arg_;
  FoldingContext &context_;
};
498 template <WhichLocation which>
499 static std::optional<Constant<SubscriptInteger>> FoldLocationCall(
500 ActualArguments &arg, FoldingContext &context) {
501 if (arg[0]) {
502 if (auto type{arg[0]->GetType()}) {
503 if constexpr (which == WhichLocation::Findloc) {
504 // Both ARRAY and VALUE are susceptible to conversion to a common
505 // comparison type.
506 if (arg[1]) {
507 if (auto valType{arg[1]->GetType()}) {
508 if (auto compareType{ComparisonType(*type, *valType)}) {
509 type = compareType;
514 return common::SearchTypes(
515 LocationHelper<which>{std::move(*type), arg, context});
518 return std::nullopt;
521 template <WhichLocation which, typename T>
522 static Expr<T> FoldLocation(FoldingContext &context, FunctionRef<T> &&ref) {
523 static_assert(T::category == TypeCategory::Integer);
524 if (std::optional<Constant<SubscriptInteger>> found{
525 FoldLocationCall<which>(ref.arguments(), context)}) {
526 return Expr<T>{Fold(
527 context, ConvertToType<T>(Expr<SubscriptInteger>{std::move(*found)}))};
528 } else {
529 return Expr<T>{std::move(ref)};
533 // for IALL, IANY, & IPARITY
534 template <typename T>
535 static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
536 Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
537 Scalar<T> identity) {
538 static_assert(T::category == TypeCategory::Integer);
539 std::optional<int> dim;
540 if (std::optional<ArrayAndMask<T>> arrayAndMask{
541 ProcessReductionArgs<T>(context, ref.arguments(), dim,
542 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
543 OperationAccumulator<T> accumulator{arrayAndMask->array, operation};
544 return Expr<T>{DoReduction<T>(
545 arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)};
547 return Expr<T>{std::move(ref)};
550 template <int KIND>
551 Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
552 FoldingContext &context,
553 FunctionRef<Type<TypeCategory::Integer, KIND>> &&funcRef) {
554 using T = Type<TypeCategory::Integer, KIND>;
555 using Int4 = Type<TypeCategory::Integer, 4>;
556 ActualArguments &args{funcRef.arguments()};
557 auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)};
558 CHECK(intrinsic);
559 std::string name{intrinsic->name};
560 auto FromInt64{[&name, &context](std::int64_t n) {
561 Scalar<T> result{n};
562 if (result.ToInt64() != n &&
563 context.languageFeatures().ShouldWarn(
564 common::UsageWarning::FoldingException)) {
565 context.messages().Say(common::UsageWarning::FoldingException,
566 "Result of intrinsic function '%s' (%jd) overflows its result type"_warn_en_US,
567 name, std::intmax_t{n});
569 return result;
571 if (name == "abs") { // incl. babs, iiabs, jiaabs, & kiabs
572 return FoldElementalIntrinsic<T, T>(context, std::move(funcRef),
573 ScalarFunc<T, T>([&context](const Scalar<T> &i) -> Scalar<T> {
574 typename Scalar<T>::ValueWithOverflow j{i.ABS()};
575 if (j.overflow &&
576 context.languageFeatures().ShouldWarn(
577 common::UsageWarning::FoldingException)) {
578 context.messages().Say(common::UsageWarning::FoldingException,
579 "abs(integer(kind=%d)) folding overflowed"_warn_en_US, KIND);
581 return j.value;
582 }));
583 } else if (name == "bit_size") {
584 return Expr<T>{Scalar<T>::bits};
585 } else if (name == "ceiling" || name == "floor" || name == "nint") {
586 if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
587 // NINT rounds ties away from zero, not to even
588 common::RoundingMode mode{name == "ceiling" ? common::RoundingMode::Up
589 : name == "floor" ? common::RoundingMode::Down
590 : common::RoundingMode::TiesAwayFromZero};
591 return common::visit(
592 [&](const auto &kx) {
593 using TR = ResultType<decltype(kx)>;
594 return FoldElementalIntrinsic<T, TR>(context, std::move(funcRef),
595 ScalarFunc<T, TR>([&](const Scalar<TR> &x) {
596 auto y{x.template ToInteger<Scalar<T>>(mode)};
597 if (y.flags.test(RealFlag::Overflow) &&
598 context.languageFeatures().ShouldWarn(
599 common::UsageWarning::FoldingException)) {
600 context.messages().Say(
601 common::UsageWarning::FoldingException,
602 "%s intrinsic folding overflow"_warn_en_US, name);
604 return y.value;
605 }));
607 cx->u);
609 } else if (name == "count") {
610 int maskKind = args[0]->GetType()->kind();
611 switch (maskKind) {
612 SWITCH_COVERS_ALL_CASES
613 case 1:
614 return FoldCount<T, 1>(context, std::move(funcRef));
615 case 2:
616 return FoldCount<T, 2>(context, std::move(funcRef));
617 case 4:
618 return FoldCount<T, 4>(context, std::move(funcRef));
619 case 8:
620 return FoldCount<T, 8>(context, std::move(funcRef));
622 } else if (name == "digits") {
623 if (const auto *cx{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
624 return Expr<T>{common::visit(
625 [](const auto &kx) {
626 return Scalar<ResultType<decltype(kx)>>::DIGITS;
628 cx->u)};
629 } else if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
630 return Expr<T>{common::visit(
631 [](const auto &kx) {
632 return Scalar<ResultType<decltype(kx)>>::DIGITS;
634 cx->u)};
635 } else if (const auto *cx{UnwrapExpr<Expr<SomeComplex>>(args[0])}) {
636 return Expr<T>{common::visit(
637 [](const auto &kx) {
638 return Scalar<typename ResultType<decltype(kx)>::Part>::DIGITS;
640 cx->u)};
642 } else if (name == "dim") {
643 return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
644 ScalarFunc<T, T, T>([&context](const Scalar<T> &x,
645 const Scalar<T> &y) -> Scalar<T> {
646 auto result{x.DIM(y)};
647 if (result.overflow &&
648 context.languageFeatures().ShouldWarn(
649 common::UsageWarning::FoldingException)) {
650 context.messages().Say(common::UsageWarning::FoldingException,
651 "DIM intrinsic folding overflow"_warn_en_US);
653 return result.value;
654 }));
655 } else if (name == "dot_product") {
656 return FoldDotProduct<T>(context, std::move(funcRef));
657 } else if (name == "dshiftl" || name == "dshiftr") {
658 const auto fptr{
659 name == "dshiftl" ? &Scalar<T>::DSHIFTL : &Scalar<T>::DSHIFTR};
660 // Third argument can be of any kind. However, it must be smaller or equal
661 // than BIT_SIZE. It can be converted to Int4 to simplify.
662 if (const auto *argCon{Folder<T>(context).Folding(args[0])};
663 argCon && argCon->empty()) {
664 } else if (const auto *shiftCon{Folder<Int4>(context).Folding(args[2])}) {
665 for (const auto &scalar : shiftCon->values()) {
666 std::int64_t shiftVal{scalar.ToInt64()};
667 if (shiftVal < 0) {
668 context.messages().Say("SHIFT=%jd count for %s is negative"_err_en_US,
669 std::intmax_t{shiftVal}, name);
670 break;
671 } else if (shiftVal > T::Scalar::bits) {
672 context.messages().Say(
673 "SHIFT=%jd count for %s is greater than %d"_err_en_US,
674 std::intmax_t{shiftVal}, name, T::Scalar::bits);
675 break;
679 return FoldElementalIntrinsic<T, T, T, Int4>(context, std::move(funcRef),
680 ScalarFunc<T, T, T, Int4>(
681 [&fptr](const Scalar<T> &i, const Scalar<T> &j,
682 const Scalar<Int4> &shift) -> Scalar<T> {
683 return std::invoke(fptr, i, j, static_cast<int>(shift.ToInt64()));
684 }));
685 } else if (name == "exponent") {
686 if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
687 return common::visit(
688 [&funcRef, &context](const auto &x) -> Expr<T> {
689 using TR = typename std::decay_t<decltype(x)>::Result;
690 return FoldElementalIntrinsic<T, TR>(context, std::move(funcRef),
691 &Scalar<TR>::template EXPONENT<Scalar<T>>);
693 sx->u);
694 } else {
695 DIE("exponent argument must be real");
697 } else if (name == "findloc") {
698 return FoldLocation<WhichLocation::Findloc, T>(context, std::move(funcRef));
699 } else if (name == "huge") {
700 return Expr<T>{Scalar<T>::HUGE()};
701 } else if (name == "iachar" || name == "ichar") {
702 auto *someChar{UnwrapExpr<Expr<SomeCharacter>>(args[0])};
703 CHECK(someChar);
704 if (auto len{ToInt64(someChar->LEN())}) {
705 if (len.value() < 1) {
706 context.messages().Say(
707 "Character in intrinsic function %s must have length one"_err_en_US,
708 name);
709 } else if (len.value() > 1 &&
710 context.languageFeatures().ShouldWarn(
711 common::UsageWarning::Portability)) {
712 // Do not die, this was not checked before
713 context.messages().Say(common::UsageWarning::Portability,
714 "Character in intrinsic function %s should have length one"_port_en_US,
715 name);
716 } else {
717 return common::visit(
718 [&funcRef, &context, &FromInt64](const auto &str) -> Expr<T> {
719 using Char = typename std::decay_t<decltype(str)>::Result;
720 (void)FromInt64;
721 return FoldElementalIntrinsic<T, Char>(context,
722 std::move(funcRef),
723 ScalarFunc<T, Char>(
724 #ifndef _MSC_VER
725 [&FromInt64](const Scalar<Char> &c) {
726 return FromInt64(CharacterUtils<Char::kind>::ICHAR(
727 CharacterUtils<Char::kind>::Resize(c, 1)));
728 }));
729 #else // _MSC_VER
730 // MSVC 14 get confused by the original code above and
731 // ends up emitting an error about passing a std::string
732 // to the std::u16string instantiation of
733 // CharacterUtils<2>::ICHAR(). Can't find a work-around,
734 // so remove the FromInt64 error checking lambda that
735 // seems to have caused the proble.
736 [](const Scalar<Char> &c) {
737 return CharacterUtils<Char::kind>::ICHAR(
738 CharacterUtils<Char::kind>::Resize(c, 1));
739 }));
740 #endif // _MSC_VER
742 someChar->u);
745 } else if (name == "iand" || name == "ior" || name == "ieor") {
746 auto fptr{&Scalar<T>::IAND};
747 if (name == "iand") { // done in fptr declaration
748 } else if (name == "ior") {
749 fptr = &Scalar<T>::IOR;
750 } else if (name == "ieor") {
751 fptr = &Scalar<T>::IEOR;
752 } else {
753 common::die("missing case to fold intrinsic function %s", name.c_str());
755 return FoldElementalIntrinsic<T, T, T>(
756 context, std::move(funcRef), ScalarFunc<T, T, T>(fptr));
757 } else if (name == "iall") {
758 return FoldBitReduction(
759 context, std::move(funcRef), &Scalar<T>::IAND, Scalar<T>{}.NOT());
760 } else if (name == "iany") {
761 return FoldBitReduction(
762 context, std::move(funcRef), &Scalar<T>::IOR, Scalar<T>{});
763 } else if (name == "ibclr" || name == "ibset") {
764 // Second argument can be of any kind. However, it must be smaller
765 // than BIT_SIZE. It can be converted to Int4 to simplify.
766 auto fptr{&Scalar<T>::IBCLR};
767 if (name == "ibclr") { // done in fptr definition
768 } else if (name == "ibset") {
769 fptr = &Scalar<T>::IBSET;
770 } else {
771 common::die("missing case to fold intrinsic function %s", name.c_str());
773 if (const auto *argCon{Folder<T>(context).Folding(args[0])};
774 argCon && argCon->empty()) {
775 } else if (const auto *posCon{Folder<Int4>(context).Folding(args[1])}) {
776 for (const auto &scalar : posCon->values()) {
777 std::int64_t posVal{scalar.ToInt64()};
778 if (posVal < 0) {
779 context.messages().Say(
780 "bit position for %s (%jd) is negative"_err_en_US, name,
781 std::intmax_t{posVal});
782 break;
783 } else if (posVal >= T::Scalar::bits) {
784 context.messages().Say(
785 "bit position for %s (%jd) is not less than %d"_err_en_US, name,
786 std::intmax_t{posVal}, T::Scalar::bits);
787 break;
791 return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
792 ScalarFunc<T, T, Int4>(
793 [&](const Scalar<T> &i, const Scalar<Int4> &pos) -> Scalar<T> {
794 return std::invoke(fptr, i, static_cast<int>(pos.ToInt64()));
795 }));
796 } else if (name == "ibits") {
797 const auto *posCon{Folder<Int4>(context).Folding(args[1])};
798 const auto *lenCon{Folder<Int4>(context).Folding(args[2])};
799 if (const auto *argCon{Folder<T>(context).Folding(args[0])};
800 argCon && argCon->empty()) {
801 } else {
802 std::size_t posCt{posCon ? posCon->size() : 0};
803 std::size_t lenCt{lenCon ? lenCon->size() : 0};
804 std::size_t n{std::max(posCt, lenCt)};
805 for (std::size_t j{0}; j < n; ++j) {
806 int posVal{j < posCt || posCt == 1
807 ? static_cast<int>(posCon->values()[j % posCt].ToInt64())
808 : 0};
809 int lenVal{j < lenCt || lenCt == 1
810 ? static_cast<int>(lenCon->values()[j % lenCt].ToInt64())
811 : 0};
812 if (posVal < 0) {
813 context.messages().Say(
814 "bit position for IBITS(POS=%jd) is negative"_err_en_US,
815 std::intmax_t{posVal});
816 break;
817 } else if (lenVal < 0) {
818 context.messages().Say(
819 "bit length for IBITS(LEN=%jd) is negative"_err_en_US,
820 std::intmax_t{lenVal});
821 break;
822 } else if (posVal + lenVal > T::Scalar::bits) {
823 context.messages().Say(
824 "IBITS() must have POS+LEN (>=%jd) no greater than %d"_err_en_US,
825 std::intmax_t{posVal + lenVal}, T::Scalar::bits);
826 break;
830 return FoldElementalIntrinsic<T, T, Int4, Int4>(context, std::move(funcRef),
831 ScalarFunc<T, T, Int4, Int4>(
832 [&](const Scalar<T> &i, const Scalar<Int4> &pos,
833 const Scalar<Int4> &len) -> Scalar<T> {
834 return i.IBITS(static_cast<int>(pos.ToInt64()),
835 static_cast<int>(len.ToInt64()));
836 }));
837 } else if (name == "index" || name == "scan" || name == "verify") {
838 if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) {
839 return common::visit(
840 [&](const auto &kch) -> Expr<T> {
841 using TC = typename std::decay_t<decltype(kch)>::Result;
842 if (UnwrapExpr<Expr<SomeLogical>>(args[2])) { // BACK=
843 return FoldElementalIntrinsic<T, TC, TC, LogicalResult>(context,
844 std::move(funcRef),
845 ScalarFunc<T, TC, TC, LogicalResult>{
846 [&name, &FromInt64](const Scalar<TC> &str,
847 const Scalar<TC> &other,
848 const Scalar<LogicalResult> &back) {
849 return FromInt64(name == "index"
850 ? CharacterUtils<TC::kind>::INDEX(
851 str, other, back.IsTrue())
852 : name == "scan"
853 ? CharacterUtils<TC::kind>::SCAN(
854 str, other, back.IsTrue())
855 : CharacterUtils<TC::kind>::VERIFY(
856 str, other, back.IsTrue()));
857 }});
858 } else {
859 return FoldElementalIntrinsic<T, TC, TC>(context,
860 std::move(funcRef),
861 ScalarFunc<T, TC, TC>{
862 [&name, &FromInt64](
863 const Scalar<TC> &str, const Scalar<TC> &other) {
864 return FromInt64(name == "index"
865 ? CharacterUtils<TC::kind>::INDEX(str, other)
866 : name == "scan"
867 ? CharacterUtils<TC::kind>::SCAN(str, other)
868 : CharacterUtils<TC::kind>::VERIFY(str, other));
869 }});
872 charExpr->u);
873 } else {
874 DIE("first argument must be CHARACTER");
876 } else if (name == "int" || name == "int2" || name == "int8") {
877 if (auto *expr{UnwrapExpr<Expr<SomeType>>(args[0])}) {
878 return common::visit(
879 [&](auto &&x) -> Expr<T> {
880 using From = std::decay_t<decltype(x)>;
881 if constexpr (std::is_same_v<From, BOZLiteralConstant> ||
882 IsNumericCategoryExpr<From>()) {
883 return Fold(context, ConvertToType<T>(std::move(x)));
885 DIE("int() argument type not valid");
887 std::move(expr->u));
889 } else if (name == "int_ptr_kind") {
890 return Expr<T>{8};
891 } else if (name == "kind") {
892 // FoldOperation(FunctionRef &&) in fold-implementation.h will not
893 // have folded the argument; in the case of TypeParamInquiry,
894 // try to get the type of the parameter itself.
895 if (const auto *expr{args[0] ? args[0]->UnwrapExpr() : nullptr}) {
896 if (const auto *inquiry{UnwrapExpr<TypeParamInquiry>(*expr)}) {
897 if (const auto *typeSpec{inquiry->parameter().GetType()}) {
898 if (const auto *intrinType{typeSpec->AsIntrinsic()}) {
899 if (auto k{ToInt64(Fold(
900 context, Expr<SubscriptInteger>{intrinType->kind()}))}) {
901 return Expr<T>{*k};
905 } else if (auto dyType{expr->GetType()}) {
906 return Expr<T>{dyType->kind()};
909 } else if (name == "iparity") {
910 return FoldBitReduction(
911 context, std::move(funcRef), &Scalar<T>::IEOR, Scalar<T>{});
912 } else if (name == "ishft" || name == "ishftc") {
913 const auto *argCon{Folder<T>(context).Folding(args[0])};
914 const auto *shiftCon{Folder<Int4>(context).Folding(args[1])};
915 const auto *shiftVals{shiftCon ? &shiftCon->values() : nullptr};
916 const auto *sizeCon{args.size() == 3
917 ? Folder<Int4>{context, /*forOptionalArgument=*/true}.Folding(
918 args[2])
919 : nullptr};
920 const auto *sizeVals{sizeCon ? &sizeCon->values() : nullptr};
921 if ((argCon && argCon->empty()) || !shiftVals || shiftVals->empty() ||
922 (sizeVals && sizeVals->empty())) {
923 // size= and shift= values don't need to be checked
924 } else {
925 for (const auto &scalar : *shiftVals) {
926 std::int64_t shiftVal{scalar.ToInt64()};
927 if (shiftVal < -T::Scalar::bits) {
928 context.messages().Say(
929 "SHIFT=%jd count for %s is less than %d"_err_en_US,
930 std::intmax_t{shiftVal}, name, -T::Scalar::bits);
931 break;
932 } else if (shiftVal > T::Scalar::bits) {
933 context.messages().Say(
934 "SHIFT=%jd count for %s is greater than %d"_err_en_US,
935 std::intmax_t{shiftVal}, name, T::Scalar::bits);
936 break;
939 if (sizeVals) {
940 for (const auto &scalar : *sizeVals) {
941 std::int64_t sizeVal{scalar.ToInt64()};
942 if (sizeVal <= 0) {
943 context.messages().Say(
944 "SIZE=%jd count for ishftc is not positive"_err_en_US,
945 std::intmax_t{sizeVal}, name);
946 break;
947 } else if (sizeVal > T::Scalar::bits) {
948 context.messages().Say(
949 "SIZE=%jd count for ishftc is greater than %d"_err_en_US,
950 std::intmax_t{sizeVal}, T::Scalar::bits);
951 break;
954 if (shiftVals->size() == 1 || sizeVals->size() == 1 ||
955 shiftVals->size() == sizeVals->size()) {
956 auto iters{std::max(shiftVals->size(), sizeVals->size())};
957 for (std::size_t j{0}; j < iters; ++j) {
958 auto shiftVal{static_cast<int>(
959 (*shiftVals)[j % shiftVals->size()].ToInt64())};
960 auto sizeVal{
961 static_cast<int>((*sizeVals)[j % sizeVals->size()].ToInt64())};
962 if (sizeVal > 0 && std::abs(shiftVal) > sizeVal) {
963 context.messages().Say(
964 "SHIFT=%jd count for ishftc is greater in magnitude than SIZE=%jd"_err_en_US,
965 std::intmax_t{shiftVal}, std::intmax_t{sizeVal});
966 break;
972 if (name == "ishft") {
973 return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
974 ScalarFunc<T, T, Int4>(
975 [&](const Scalar<T> &i, const Scalar<Int4> &shift) -> Scalar<T> {
976 return i.ISHFT(static_cast<int>(shift.ToInt64()));
977 }));
978 } else if (!args.at(2)) { // ISHFTC(no SIZE=)
979 return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
980 ScalarFunc<T, T, Int4>(
981 [&](const Scalar<T> &i, const Scalar<Int4> &shift) -> Scalar<T> {
982 return i.ISHFTC(static_cast<int>(shift.ToInt64()));
983 }));
984 } else { // ISHFTC(with SIZE=)
985 return FoldElementalIntrinsic<T, T, Int4, Int4>(context,
986 std::move(funcRef),
987 ScalarFunc<T, T, Int4, Int4>(
988 [&](const Scalar<T> &i, const Scalar<Int4> &shift,
989 const Scalar<Int4> &size) -> Scalar<T> {
990 auto shiftVal{static_cast<int>(shift.ToInt64())};
991 auto sizeVal{static_cast<int>(size.ToInt64())};
992 return i.ISHFTC(shiftVal, sizeVal);
994 /*hasOptionalArgument=*/true);
996 } else if (name == "izext" || name == "jzext") {
997 if (args.size() == 1) {
998 if (auto *expr{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
999 // Rewrite to IAND(INT(n,k),255_k) for k=KIND(T)
1000 intrinsic->name = "iand";
1001 auto converted{ConvertToType<T>(std::move(*expr))};
1002 *expr = Fold(context, Expr<SomeInteger>{std::move(converted)});
1003 args.emplace_back(AsGenericExpr(Expr<T>{Scalar<T>{255}}));
1004 return FoldIntrinsicFunction(context, std::move(funcRef));
1007 } else if (name == "lbound") {
1008 return LBOUND(context, std::move(funcRef));
1009 } else if (name == "leadz" || name == "trailz" || name == "poppar" ||
1010 name == "popcnt") {
1011 if (auto *sn{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
1012 return common::visit(
1013 [&funcRef, &context, &name](const auto &n) -> Expr<T> {
1014 using TI = typename std::decay_t<decltype(n)>::Result;
1015 if (name == "poppar") {
1016 return FoldElementalIntrinsic<T, TI>(context, std::move(funcRef),
1017 ScalarFunc<T, TI>([](const Scalar<TI> &i) -> Scalar<T> {
1018 return Scalar<T>{i.POPPAR() ? 1 : 0};
1019 }));
1021 auto fptr{&Scalar<TI>::LEADZ};
1022 if (name == "leadz") { // done in fptr definition
1023 } else if (name == "trailz") {
1024 fptr = &Scalar<TI>::TRAILZ;
1025 } else if (name == "popcnt") {
1026 fptr = &Scalar<TI>::POPCNT;
1027 } else {
1028 common::die(
1029 "missing case to fold intrinsic function %s", name.c_str());
1031 return FoldElementalIntrinsic<T, TI>(context, std::move(funcRef),
1032 // `i` should be declared as `const Scalar<TI>&`.
1033 // We declare it as `auto` to workaround an msvc bug:
1034 // https://developercommunity.visualstudio.com/t/Regression:-nested-closure-assumes-wrong/10130223
1035 ScalarFunc<T, TI>([&fptr](const auto &i) -> Scalar<T> {
1036 return Scalar<T>{std::invoke(fptr, i)};
1037 }));
1039 sn->u);
1040 } else {
1041 DIE("leadz argument must be integer");
1043 } else if (name == "len") {
1044 if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) {
1045 return common::visit(
1046 [&](auto &kx) {
1047 if (auto len{kx.LEN()}) {
1048 if (IsScopeInvariantExpr(*len)) {
1049 return Fold(context, ConvertToType<T>(*std::move(len)));
1050 } else {
1051 return Expr<T>{std::move(funcRef)};
1053 } else {
1054 return Expr<T>{std::move(funcRef)};
1057 charExpr->u);
1058 } else {
1059 DIE("len() argument must be of character type");
1061 } else if (name == "len_trim") {
1062 if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) {
1063 return common::visit(
1064 [&](const auto &kch) -> Expr<T> {
1065 using TC = typename std::decay_t<decltype(kch)>::Result;
1066 return FoldElementalIntrinsic<T, TC>(context, std::move(funcRef),
1067 ScalarFunc<T, TC>{[&FromInt64](const Scalar<TC> &str) {
1068 return FromInt64(CharacterUtils<TC::kind>::LEN_TRIM(str));
1069 }});
1071 charExpr->u);
1072 } else {
1073 DIE("len_trim() argument must be of character type");
1075 } else if (name == "maskl" || name == "maskr") {
1076 // Argument can be of any kind but value has to be smaller than BIT_SIZE.
1077 // It can be safely converted to Int4 to simplify.
1078 const auto fptr{name == "maskl" ? &Scalar<T>::MASKL : &Scalar<T>::MASKR};
1079 return FoldElementalIntrinsic<T, Int4>(context, std::move(funcRef),
1080 ScalarFunc<T, Int4>([&fptr](const Scalar<Int4> &places) -> Scalar<T> {
1081 return fptr(static_cast<int>(places.ToInt64()));
1082 }));
1083 } else if (name == "matmul") {
1084 return FoldMatmul(context, std::move(funcRef));
1085 } else if (name == "max") {
1086 return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
1087 } else if (name == "max0" || name == "max1") {
1088 return RewriteSpecificMINorMAX(context, std::move(funcRef));
1089 } else if (name == "maxexponent") {
1090 if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
1091 return common::visit(
1092 [](const auto &x) {
1093 using TR = typename std::decay_t<decltype(x)>::Result;
1094 return Expr<T>{Scalar<TR>::MAXEXPONENT};
1096 sx->u);
1098 } else if (name == "maxloc") {
1099 return FoldLocation<WhichLocation::Maxloc, T>(context, std::move(funcRef));
1100 } else if (name == "maxval") {
1101 return FoldMaxvalMinval<T>(context, std::move(funcRef),
1102 RelationalOperator::GT, T::Scalar::Least());
1103 } else if (name == "merge_bits") {
1104 return FoldElementalIntrinsic<T, T, T, T>(
1105 context, std::move(funcRef), &Scalar<T>::MERGE_BITS);
1106 } else if (name == "min") {
1107 return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
1108 } else if (name == "min0" || name == "min1") {
1109 return RewriteSpecificMINorMAX(context, std::move(funcRef));
1110 } else if (name == "minexponent") {
1111 if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
1112 return common::visit(
1113 [](const auto &x) {
1114 using TR = typename std::decay_t<decltype(x)>::Result;
1115 return Expr<T>{Scalar<TR>::MINEXPONENT};
1117 sx->u);
1119 } else if (name == "minloc") {
1120 return FoldLocation<WhichLocation::Minloc, T>(context, std::move(funcRef));
1121 } else if (name == "minval") {
1122 return FoldMaxvalMinval<T>(
1123 context, std::move(funcRef), RelationalOperator::LT, T::Scalar::HUGE());
1124 } else if (name == "mod") {
1125 bool badPConst{false};
1126 if (auto *pExpr{UnwrapExpr<Expr<T>>(args[1])}) {
1127 *pExpr = Fold(context, std::move(*pExpr));
1128 if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; pConst &&
1129 pConst->IsZero() &&
1130 context.languageFeatures().ShouldWarn(
1131 common::UsageWarning::FoldingAvoidsRuntimeCrash)) {
1132 context.messages().Say(common::UsageWarning::FoldingAvoidsRuntimeCrash,
1133 "MOD: P argument is zero"_warn_en_US);
1134 badPConst = true;
1137 return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
1138 ScalarFuncWithContext<T, T, T>(
1139 [badPConst](FoldingContext &context, const Scalar<T> &x,
1140 const Scalar<T> &y) -> Scalar<T> {
1141 auto quotRem{x.DivideSigned(y)};
1142 if (context.languageFeatures().ShouldWarn(
1143 common::UsageWarning::FoldingAvoidsRuntimeCrash)) {
1144 if (!badPConst && quotRem.divisionByZero) {
1145 context.messages().Say(
1146 common::UsageWarning::FoldingAvoidsRuntimeCrash,
1147 "mod() by zero"_warn_en_US);
1148 } else if (quotRem.overflow) {
1149 context.messages().Say(
1150 common::UsageWarning::FoldingAvoidsRuntimeCrash,
1151 "mod() folding overflowed"_warn_en_US);
1154 return quotRem.remainder;
1155 }));
1156 } else if (name == "modulo") {
1157 bool badPConst{false};
1158 if (auto *pExpr{UnwrapExpr<Expr<T>>(args[1])}) {
1159 *pExpr = Fold(context, std::move(*pExpr));
1160 if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; pConst &&
1161 pConst->IsZero() &&
1162 context.languageFeatures().ShouldWarn(
1163 common::UsageWarning::FoldingAvoidsRuntimeCrash)) {
1164 context.messages().Say(common::UsageWarning::FoldingAvoidsRuntimeCrash,
1165 "MODULO: P argument is zero"_warn_en_US);
1166 badPConst = true;
1169 return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
1170 ScalarFuncWithContext<T, T, T>([badPConst](FoldingContext &context,
1171 const Scalar<T> &x,
1172 const Scalar<T> &y) -> Scalar<T> {
1173 auto result{x.MODULO(y)};
1174 if (!badPConst && result.overflow &&
1175 context.languageFeatures().ShouldWarn(
1176 common::UsageWarning::FoldingException)) {
1177 context.messages().Say(common::UsageWarning::FoldingException,
1178 "modulo() folding overflowed"_warn_en_US);
1180 return result.value;
1181 }));
1182 } else if (name == "not") {
1183 return FoldElementalIntrinsic<T, T>(
1184 context, std::move(funcRef), &Scalar<T>::NOT);
1185 } else if (name == "precision") {
1186 if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
1187 return Expr<T>{common::visit(
1188 [](const auto &kx) {
1189 return Scalar<ResultType<decltype(kx)>>::PRECISION;
1191 cx->u)};
1192 } else if (const auto *cx{UnwrapExpr<Expr<SomeComplex>>(args[0])}) {
1193 return Expr<T>{common::visit(
1194 [](const auto &kx) {
1195 return Scalar<typename ResultType<decltype(kx)>::Part>::PRECISION;
1197 cx->u)};
1199 } else if (name == "product") {
1200 return FoldProduct<T>(context, std::move(funcRef), Scalar<T>{1});
1201 } else if (name == "radix") {
1202 return Expr<T>{2};
1203 } else if (name == "range") {
1204 if (const auto *cx{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
1205 return Expr<T>{common::visit(
1206 [](const auto &kx) {
1207 return Scalar<ResultType<decltype(kx)>>::RANGE;
1209 cx->u)};
1210 } else if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
1211 return Expr<T>{common::visit(
1212 [](const auto &kx) {
1213 return Scalar<ResultType<decltype(kx)>>::RANGE;
1215 cx->u)};
1216 } else if (const auto *cx{UnwrapExpr<Expr<SomeComplex>>(args[0])}) {
1217 return Expr<T>{common::visit(
1218 [](const auto &kx) {
1219 return Scalar<typename ResultType<decltype(kx)>::Part>::RANGE;
1221 cx->u)};
1223 } else if (name == "rank") {
1224 if (args[0]) {
1225 const Symbol *symbol{nullptr};
1226 if (auto dataRef{ExtractDataRef(args[0])}) {
1227 symbol = &dataRef->GetLastSymbol();
1228 } else {
1229 symbol = args[0]->GetAssumedTypeDummy();
1231 if (symbol && IsAssumedRank(*symbol)) {
1232 // DescriptorInquiry can only be placed in expression of kind
1233 // DescriptorInquiry::Result::kind.
1234 return ConvertToType<T>(
1235 Expr<Type<TypeCategory::Integer, DescriptorInquiry::Result::kind>>{
1236 DescriptorInquiry{
1237 NamedEntity{*symbol}, DescriptorInquiry::Field::Rank}});
1239 return Expr<T>{args[0]->Rank()};
1241 } else if (name == "selected_char_kind") {
1242 if (const auto *chCon{UnwrapExpr<Constant<TypeOf<std::string>>>(args[0])}) {
1243 if (std::optional<std::string> value{chCon->GetScalarValue()}) {
1244 int defaultKind{
1245 context.defaults().GetDefaultKind(TypeCategory::Character)};
1246 return Expr<T>{SelectedCharKind(*value, defaultKind)};
1249 } else if (name == "selected_int_kind") {
1250 if (auto p{ToInt64(args[0])}) {
1251 return Expr<T>{context.targetCharacteristics().SelectedIntKind(*p)};
1253 } else if (name == "selected_logical_kind") {
1254 if (auto p{ToInt64(args[0])}) {
1255 return Expr<T>{context.targetCharacteristics().SelectedLogicalKind(*p)};
1257 } else if (name == "selected_real_kind" ||
1258 name == "__builtin_ieee_selected_real_kind") {
1259 if (auto p{GetInt64ArgOr(args[0], 0)}) {
1260 if (auto r{GetInt64ArgOr(args[1], 0)}) {
1261 if (auto radix{GetInt64ArgOr(args[2], 2)}) {
1262 return Expr<T>{
1263 context.targetCharacteristics().SelectedRealKind(*p, *r, *radix)};
1267 } else if (name == "shape") {
1268 if (auto shape{GetContextFreeShape(context, args[0])}) {
1269 if (auto shapeExpr{AsExtentArrayExpr(*shape)}) {
1270 return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
1273 } else if (name == "shifta" || name == "shiftr" || name == "shiftl") {
1274 // Second argument can be of any kind. However, it must be smaller or
1275 // equal than BIT_SIZE. It can be converted to Int4 to simplify.
1276 auto fptr{&Scalar<T>::SHIFTA};
1277 if (name == "shifta") { // done in fptr definition
1278 } else if (name == "shiftr") {
1279 fptr = &Scalar<T>::SHIFTR;
1280 } else if (name == "shiftl") {
1281 fptr = &Scalar<T>::SHIFTL;
1282 } else {
1283 common::die("missing case to fold intrinsic function %s", name.c_str());
1285 if (const auto *argCon{Folder<T>(context).Folding(args[0])};
1286 argCon && argCon->empty()) {
1287 } else if (const auto *shiftCon{Folder<Int4>(context).Folding(args[1])}) {
1288 for (const auto &scalar : shiftCon->values()) {
1289 std::int64_t shiftVal{scalar.ToInt64()};
1290 if (shiftVal < 0) {
1291 context.messages().Say("SHIFT=%jd count for %s is negative"_err_en_US,
1292 std::intmax_t{shiftVal}, name, -T::Scalar::bits);
1293 break;
1294 } else if (shiftVal > T::Scalar::bits) {
1295 context.messages().Say(
1296 "SHIFT=%jd count for %s is greater than %d"_err_en_US,
1297 std::intmax_t{shiftVal}, name, T::Scalar::bits);
1298 break;
1302 return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
1303 ScalarFunc<T, T, Int4>(
1304 [&](const Scalar<T> &i, const Scalar<Int4> &shift) -> Scalar<T> {
1305 return std::invoke(fptr, i, static_cast<int>(shift.ToInt64()));
1306 }));
1307 } else if (name == "sign") {
1308 return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
1309 ScalarFunc<T, T, T>([&context](const Scalar<T> &j,
1310 const Scalar<T> &k) -> Scalar<T> {
1311 typename Scalar<T>::ValueWithOverflow result{j.SIGN(k)};
1312 if (result.overflow &&
1313 context.languageFeatures().ShouldWarn(
1314 common::UsageWarning::FoldingException)) {
1315 context.messages().Say(common::UsageWarning::FoldingException,
1316 "sign(integer(kind=%d)) folding overflowed"_warn_en_US, KIND);
1318 return result.value;
1319 }));
1320 } else if (name == "size") {
1321 if (auto shape{GetContextFreeShape(context, args[0])}) {
1322 if (args[1]) { // DIM= is present, get one extent
1323 std::optional<int> dim;
1324 if (const auto *array{args[0].value().UnwrapExpr()}; array &&
1325 !CheckDimArg(args[1], *array, context.messages(), false, dim)) {
1326 return MakeInvalidIntrinsic<T>(std::move(funcRef));
1327 } else if (dim) {
1328 if (auto &extent{shape->at(*dim)}) {
1329 return Fold(context, ConvertToType<T>(std::move(*extent)));
1332 } else if (auto extents{common::AllElementsPresent(std::move(*shape))}) {
1333 // DIM= is absent; compute PRODUCT(SHAPE())
1334 ExtentExpr product{1};
1335 for (auto &&extent : std::move(*extents)) {
1336 product = std::move(product) * std::move(extent);
1338 return Expr<T>{ConvertToType<T>(Fold(context, std::move(product)))};
1341 } else if (name == "sizeof") { // in bytes; extension
1342 if (auto info{
1343 characteristics::TypeAndShape::Characterize(args[0], context)}) {
1344 if (auto bytes{info->MeasureSizeInBytes(context)}) {
1345 return Expr<T>{Fold(context, ConvertToType<T>(std::move(*bytes)))};
1348 } else if (name == "storage_size") { // in bits
1349 if (auto info{
1350 characteristics::TypeAndShape::Characterize(args[0], context)}) {
1351 if (auto bytes{info->MeasureElementSizeInBytes(context, true)}) {
1352 return Expr<T>{
1353 Fold(context, Expr<T>{8} * ConvertToType<T>(std::move(*bytes)))};
1356 } else if (name == "sum") {
1357 return FoldSum<T>(context, std::move(funcRef));
1358 } else if (name == "ubound") {
1359 return UBOUND(context, std::move(funcRef));
1360 } else if (name == "__builtin_numeric_storage_size") {
1361 if (!context.moduleFileName()) {
1362 // Don't fold this reference until it appears in the module file
1363 // for ISO_FORTRAN_ENV -- the value depends on the compiler options
1364 // that might be in force.
1365 } else {
1366 auto intBytes{
1367 context.targetCharacteristics().GetByteSize(TypeCategory::Integer,
1368 context.defaults().GetDefaultKind(TypeCategory::Integer))};
1369 auto realBytes{
1370 context.targetCharacteristics().GetByteSize(TypeCategory::Real,
1371 context.defaults().GetDefaultKind(TypeCategory::Real))};
1372 if (intBytes != realBytes &&
1373 context.languageFeatures().ShouldWarn(
1374 common::UsageWarning::FoldingValueChecks)) {
1375 context.messages().Say(common::UsageWarning::FoldingValueChecks,
1376 *context.moduleFileName(),
1377 "NUMERIC_STORAGE_SIZE from ISO_FORTRAN_ENV is not well-defined when default INTEGER and REAL are not consistent due to compiler options"_warn_en_US);
1379 return Expr<T>{8 * std::min(intBytes, realBytes)};
1382 return Expr<T>{std::move(funcRef)};
1385 // Substitutes a bare type parameter reference with its value if it has one now
1386 // in an instantiation. Bare LEN type parameters are substituted only when
1387 // the known value is constant.
1388 Expr<TypeParamInquiry::Result> FoldOperation(
1389 FoldingContext &context, TypeParamInquiry &&inquiry) {
1390 std::optional<NamedEntity> base{inquiry.base()};
1391 parser::CharBlock parameterName{inquiry.parameter().name()};
1392 if (base) {
1393 // Handling "designator%typeParam". Get the value of the type parameter
1394 // from the instantiation of the base
1395 if (const semantics::DeclTypeSpec *
1396 declType{base->GetLastSymbol().GetType()}) {
1397 if (const semantics::ParamValue *
1398 paramValue{
1399 declType->derivedTypeSpec().FindParameter(parameterName)}) {
1400 const semantics::MaybeIntExpr &paramExpr{paramValue->GetExplicit()};
1401 if (paramExpr && IsConstantExpr(*paramExpr)) {
1402 Expr<SomeInteger> intExpr{*paramExpr};
1403 return Fold(context,
1404 ConvertToType<TypeParamInquiry::Result>(std::move(intExpr)));
1408 } else {
1409 // A "bare" type parameter: replace with its value, if that's now known
1410 // in a current derived type instantiation.
1411 if (const auto *pdt{context.pdtInstance()}) {
1412 auto restorer{context.WithoutPDTInstance()}; // don't loop
1413 bool isLen{false};
1414 if (const semantics::Scope * scope{pdt->scope()}) {
1415 auto iter{scope->find(parameterName)};
1416 if (iter != scope->end()) {
1417 const Symbol &symbol{*iter->second};
1418 const auto *details{symbol.detailsIf<semantics::TypeParamDetails>()};
1419 if (details) {
1420 isLen = details->attr() == common::TypeParamAttr::Len;
1421 const semantics::MaybeIntExpr &initExpr{details->init()};
1422 if (initExpr && IsConstantExpr(*initExpr) &&
1423 (!isLen || ToInt64(*initExpr))) {
1424 Expr<SomeInteger> expr{*initExpr};
1425 return Fold(context,
1426 ConvertToType<TypeParamInquiry::Result>(std::move(expr)));
1431 if (const auto *value{pdt->FindParameter(parameterName)}) {
1432 if (value->isExplicit()) {
1433 auto folded{Fold(context,
1434 AsExpr(ConvertToType<TypeParamInquiry::Result>(
1435 Expr<SomeInteger>{value->GetExplicit().value()})))};
1436 if (!isLen || ToInt64(folded)) {
1437 return folded;
1443 return AsExpr(std::move(inquiry));
1446 std::optional<std::int64_t> ToInt64(const Expr<SomeInteger> &expr) {
1447 return common::visit(
1448 [](const auto &kindExpr) { return ToInt64(kindExpr); }, expr.u);
1451 std::optional<std::int64_t> ToInt64(const Expr<SomeType> &expr) {
1452 return ToInt64(UnwrapExpr<Expr<SomeInteger>>(expr));
1455 std::optional<std::int64_t> ToInt64(const ActualArgument &arg) {
1456 return ToInt64(arg.UnwrapExpr());
// Explicit instantiations of ExpressionBase for every INTEGER kind (via
// FOR_EACH_INTEGER_KIND) and for the kind-generic SomeInteger category.
#ifdef _MSC_VER // disable bogus warning about missing definitions
#pragma warning(disable : 4661)
#endif
FOR_EACH_INTEGER_KIND(template class ExpressionBase, )
template class ExpressionBase<SomeInteger>;
} // namespace Fortran::evaluate