1 //===-- lib/Semantics/check-case.cpp --------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "check-case.h"
10 #include "flang/Common/idioms.h"
11 #include "flang/Common/reference.h"
12 #include "flang/Common/template.h"
13 #include "flang/Evaluate/fold.h"
14 #include "flang/Evaluate/type.h"
15 #include "flang/Parser/parse-tree.h"
16 #include "flang/Semantics/semantics.h"
17 #include "flang/Semantics/tools.h"
20 namespace Fortran::semantics
{
// NOTE(review): this listing is a lossy extraction of check-case.cpp. Each
// line is prefixed with its original line number, tokens are split across
// lines, and many original lines are absent (the embedded numbers skip,
// e.g. 20 -> 22 -> 24). It does not compile as-is. The comments here
// describe only what the visible lines show.
//
// CaseValues<T> checks the CASE selectors of one SELECT CASE construct
// whose selector expression has type T. T is an evaluate type providing
// category, kind and GetType(). Each CASE value is folded to an
// evaluate::Scalar<T> so that ranges can be compared and checked for
// overlap.
22 template <typename T
> class CaseValues
{
// Keeps references (not copies) to the semantics context, which receives
// the diagnostics, and to the dynamic type of the SELECT CASE expression.
// Both referenced objects must outlive this CaseValues object.
24 CaseValues(SemanticsContext
&c
, const evaluate::DynamicType
&t
)
25 : context_
{c
}, caseExprType_
{t
} {}
// Checks all CASE statements of one construct. It iterates over `cases`,
// sorts the collected entries in cases_ with Comparator, and reports
// overlapping selectors when the sorted entries are not pairwise disjoint
// (constraint C1149).
27 void Check(const std::list
<parser::CaseConstruct::Case
> &cases
) {
28 for (const parser::CaseConstruct::Case
&c
: cases
) {
// NOTE(review): original lines 29-31 are absent from this listing. The
// loop body (presumably a call to AddCase(c)) and any condition guarding
// the sort below cannot be seen here.
32 cases_
.sort(Comparator
{});
33 if (!AreCasesDisjoint()) { // C1149
34 ReportConflictingCases();
// NOTE(review): original lines 35-39 are absent here. They hold closing
// braces and, presumably, an access specifier.
// Value is the scalar representation of a constant of type T. CASE bounds
// are stored as std::optional<Value>.
40 using Value
= evaluate::Scalar
<T
>;
// Records one CASE statement of the construct in cases_.
42 void AddCase(const parser::CaseConstruct::Case
&c
) {
43 const auto &stmt
{std::get
<parser::Statement
<parser::CaseStmt
>>(c
.t
)};
44 const parser::CaseStmt
&caseStmt
{stmt
.statement
};
45 const auto &selector
{std::get
<parser::CaseSelector
>(caseStmt
.t
)};
// NOTE(review): the dispatch over `selector` (original lines 46-47) is not
// in this listing. The two lambdas below appear to be its alternatives:
// one handles a list of value ranges, the other handles CASE DEFAULT.
48 [&](const std::list
<parser::CaseValueRange
> &ranges
) {
49 for (const auto &range
: ranges
) {
// Each range is reduced to optional (lower, upper) bounds.
50 auto pair
{ComputeBounds(range
)};
51 if (pair
.first
&& pair
.second
&& *pair
.first
> *pair
.second
) {
// lower > upper: the range selects nothing, so only warn.
52 context_
.Warn(common::UsageWarning::EmptyCase
, stmt
.source
,
53 "CASE has lower bound greater than upper bound"_warn_en_US
);
// NOTE(review): original line 54 is absent. This listing therefore does
// not show whether an empty range is still recorded below.
// C1148: a LOGICAL selector may name single values only. It is an error
// when at least one bound is present and the two bounds are not both
// present and equal.
55 if constexpr (T::category
== TypeCategory::Logical
) { // C1148
56 if ((pair
.first
|| pair
.second
) &&
57 (!pair
.first
|| !pair
.second
||
58 *pair
.first
!= *pair
.second
)) {
59 context_
.Say(stmt
.source
,
60 "CASE range is not allowed for LOGICAL"_err_en_US
);
// Record the range with its (possibly absent) bounds. The optionals are
// moved out of `pair`.
63 cases_
.emplace_back(stmt
);
64 cases_
.back().lower
= std::move(pair
.first
);
65 cases_
.back().upper
= std::move(pair
.second
);
// CASE DEFAULT is recorded at the front of cases_ with no bounds, so
// Case::IsDefault() is true for it.
69 [&](const parser::Default
&) { cases_
.emplace_front(stmt
); },
// Converts one CASE value to a constant of type T, or issues a diagnostic.
// C1147: the value's analyzed expression must have the same type category
// as the SELECT CASE expression and, for CHARACTER, the same kind as well.
// When it does, the expression is folded, converted to T and folded again.
// If that yields a scalar constant, it is converted back to the value's
// own type (`back`), presumably to detect values that T cannot represent.
// The comparison itself is in original lines 92-95, absent from this
// listing.
// Diagnostics visible here: a CaseOverflow warning, a "must be a constant
// scalar" error, and an incompatible-type error (a value with no type is
// reported as "typeless").
// NOTE(review): the return statements and the remaining control flow are
// in absent lines (92-95, 99-103, 106-107, 112-117). ComputeBounds treats
// an empty result as "no usable value".
74 std::optional
<Value
> GetValue(const parser::CaseValue
&caseValue
) {
75 const parser::Expr
&expr
{caseValue
.thing
.thing
.value()};
76 auto *x
{expr
.typedExpr
.get()};
77 if (x
&& x
->v
) { // C1147
78 auto type
{x
->v
->GetType()};
79 if (type
&& type
->category() == caseExprType_
.category() &&
80 (type
->category() != TypeCategory::Character
||
81 type
->kind() == caseExprType_
.kind())) {
// Fold with a private message buffer so that folding diagnostics are
// dropped rather than reported.
82 parser::Messages buffer
; // discarded folding messages
83 parser::ContextualMessages foldingMessages
{expr
.source
, &buffer
};
84 evaluate::FoldingContext foldingContext
{
85 context_
.foldingContext(), foldingMessages
};
86 auto folded
{evaluate::Fold(foldingContext
, SomeExpr
{*x
->v
})};
87 if (auto converted
{evaluate::Fold(foldingContext
,
88 evaluate::ConvertToType(T::GetType(), SomeExpr
{folded
}))}) {
89 if (auto value
{evaluate::GetScalarConstantValue
<T
>(*converted
)}) {
// Convert back to the value's original type for a round-trip check.
90 auto back
{evaluate::Fold(foldingContext
,
91 evaluate::ConvertToType(*type
, SomeExpr
{*converted
}))};
96 context_
.Warn(common::UsageWarning::CaseOverflow
, expr
.source
,
97 "CASE value (%s) overflows type (%s) of SELECT CASE expression"_warn_en_US
,
98 folded
.AsFortran(), caseExprType_
.AsFortran());
104 context_
.Say(expr
.source
,
105 "CASE value (%s) must be a constant scalar"_err_en_US
,
// Type mismatch with the SELECT CASE expression (C1147). The branch that
// leads here is in absent lines 106-107.
108 std::string typeStr
{type
? type
->AsFortran() : "typeless"s
};
109 context_
.Say(expr
.source
,
110 "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US
,
111 typeStr
, caseExprType_
.AsFortran());
// (lower, upper) bounds of one CASE value range. An empty optional stands
// for a bound that was omitted or that could not be evaluated.
118 using PairOfValues
= std::pair
<std::optional
<Value
>, std::optional
<Value
>>;
// Evaluates one CaseValueRange to its bounds:
//  - a single value v yields {v, v}, with both empty if v cannot be
//    evaluated;
//  - a range yields its evaluated bounds, or {} (both empty) when a bound
//    was written but GetValue produced no value for it.
// NOTE(review): original lines 121, 125, 128, 130-131, 133, 136 and 138-143
// are absent from this listing.
119 PairOfValues
ComputeBounds(const parser::CaseValueRange
&range
) {
120 return common::visit(
122 [&](const parser::CaseValue
&x
) {
123 auto value
{GetValue(x
)};
124 return PairOfValues
{value
, value
};
126 [&](const parser::CaseValueRange::Range
&x
) {
127 std::optional
<Value
> lo
, hi
;
// NOTE(review): x.lower and x.upper are dereferenced below. The guards
// that make this safe for an omitted bound are presumably in absent lines
// 128 and 131.
129 lo
= GetValue(*x
.lower
);
132 hi
= GetValue(*x
.upper
);
// A bound that was written but could not be evaluated invalidates the
// whole range.
134 if ((x
.lower
&& !lo
) || (x
.upper
&& !hi
)) {
135 return PairOfValues
{}; // error case
137 return PairOfValues
{std::move(lo
), std::move(hi
)};
// NOTE(review): the opening of this nested type (in original lines
// 138-143, presumably `struct Case {`) is absent from this listing.
// Case is one recorded CASE entry. It holds a reference to its CASE
// statement, whose source location is used in diagnostics, plus the
// optional bounds. An entry with neither bound is CASE DEFAULT.
144 explicit Case(const parser::Statement
<parser::CaseStmt
> &s
) : stmt
{s
} {}
145 bool IsDefault() const { return !lower
&& !upper
; }
// Renders the bounds as Fortran text for diagnostics. Visible forms:
// "(" followed by lower, then ":" and upper when upper differs from lower;
// and "(:upper)" for an upper-only range. The declaration of `result` and
// the remaining branches are in absent lines (147-148, 150, 152-153,
// 156-158, 160-166).
146 std::string
AsFortran() const {
149 llvm::raw_string_ostream bs
{result
};
151 evaluate::Constant
<T
>{*lower
}.AsFortran(bs
<< '(');
154 } else if (*lower
!= *upper
) {
155 evaluate::Constant
<T
>{*upper
}.AsFortran(bs
<< ':');
159 evaluate::Constant
<T
>{*upper
}.AsFortran(bs
<< "(:") << ')';
// Reference into the parse tree, which must outlive this entry.
167 const parser::Statement
<parser::CaseStmt
> &stmt
;
// Bounds of the range. Both absent means CASE DEFAULT.
168 std::optional
<Value
> lower
, upper
;
171 // Defines a comparator for use with std::list<>::sort().
172 // Returns true if and only if the highest value in range x is less
173 // than the least value in range y. The DEFAULT case is arbitrarily
174 // defined to be less than all others. When two ranges overlap,
175 // neither is less than the other.
// NOTE(review): the enclosing `struct Comparator {` and the IsDefault()
// test that chooses between the two returns below are in original lines
// 176, 178 and 180, absent from this listing.
// NOTE(review): std::list::sort formally requires a strict weak ordering.
// "Overlaps" is not a transitive equivalence (A overlaps B and B overlaps
// C while A < C), so the order produced for overlapping input is
// unspecified. The adjacency test in AreCasesDisjoint still rejects such
// input: a sequence whose neighbours are all strictly ordered (each with
// lower <= upper) is pairwise disjoint.
177 bool operator()(const Case
&x
, const Case
&y
) const {
179 return !y
.IsDefault();
181 return x
.upper
&& y
.lower
&& *x
.upper
< *y
.lower
;
// Checks that every entry of the sorted cases_ is strictly less, per
// Comparator, than its successor. After sorting, that means the cases are
// pairwise disjoint.
// NOTE(review): the copy of `iter` into `next` (original line 189) and the
// return statements (191-196) are absent from this listing. Presumably the
// function returns false at the first adjacent pair that is not strictly
// ordered, and true otherwise.
186 bool AreCasesDisjoint() const {
187 auto endIter
{cases_
.end()};
188 for (auto iter
{cases_
.begin()}; iter
!= endIter
; ++iter
) {
190 if (++next
!= endIter
&& !Comparator
{}(*iter
, *next
)) {
197 // This has quadratic time, but only runs in error cases
// For each case, scans every case whose statement starts earlier in the
// source and overlaps it (neither compares less under Comparator). The
// error "CASE %s conflicts with previous cases" is issued at the later
// case's statement. Each earlier overlapping case is reported with
// "Conflicting CASE %s" at its own location.
// NOTE(review): `msg` starts as nullptr. The test that issues the error
// only once per case, and the call that attaches each earlier case to
// `msg`, are in original lines 204 and 207-209. Those lines are absent from
// this listing, so both behaviours are inferred rather than visible.
198 void ReportConflictingCases() {
199 for (auto iter
{cases_
.begin()}; iter
!= cases_
.end(); ++iter
) {
200 parser::Message
*msg
{nullptr};
201 for (auto p
{cases_
.begin()}; p
!= cases_
.end(); ++p
) {
// Presumably both statements' source ranges point into the same source
// buffer, so pointer order gives textual order.
202 if (p
->stmt
.source
.begin() < iter
->stmt
.source
.begin() &&
203 !Comparator
{}(*p
, *iter
) && !Comparator
{}(*iter
, *p
)) {
205 msg
= &context_
.Say(iter
->stmt
.source
,
206 "CASE %s conflicts with previous cases"_err_en_US
,
210 p
->stmt
.source
, "Conflicting CASE %s"_en_US
, p
->AsFortran());
// Destination for diagnostics (held by reference).
216 SemanticsContext
&context_
;
// Dynamic type of the SELECT CASE expression (held by reference).
217 const evaluate::DynamicType
&caseExprType_
;
// All recorded CASE entries. CASE DEFAULT entries are pushed to the front.
218 std::list
<Case
> cases_
;
// NOTE(review): initialized to false here. No read or write of hasErrors_
// is visible in this listing, so its role lies in absent lines.
219 bool hasErrors_
{false};
// Visitor over every type of category CAT (evaluate::CategoryTypes<CAT>).
// Test<T>() runs CaseValues<T>::Check on the construct's cases when T's
// kind equals the SELECT CASE expression's kind.
// NOTE(review): the Result alias (original line 223) and Test's return
// statements (228-232) are absent from this listing. How a Result value
// ends the search over Types cannot be seen here.
222 template <TypeCategory CAT
> struct TypeVisitor
{
224 using Types
= evaluate::CategoryTypes
<CAT
>;
225 template <typename T
> Result
Test() {
226 if (T::kind
== exprType
.kind()) {
227 CaseValues
<T
>(context
, exprType
).Check(caseList
);
// References supplied by CaseChecker::Enter, where exprType refers to a
// local. A TypeVisitor must not outlive that call.
233 SemanticsContext
&context
;
234 const evaluate::DynamicType
&exprType
;
235 const std::list
<parser::CaseConstruct::Case
> &caseList
;
// Semantic check for one SELECT CASE construct (Enter hook).
// It fetches the analyzed SELECT CASE expression. If that expression has a
// type, the construct's CASE list is checked with the CaseValues
// instantiation that matches the type:
//  - INTEGER and CHARACTER: via TypeVisitor, which matches the kind;
//  - LOGICAL: always with LOGICAL kind 1;
//  - anything else: the error "SELECT CASE expression must be integer,
//    logical, or character". This is presumably the switch's default; its
//    label is in absent lines 263-267.
// NOTE(review): several lines are absent from this listing: the null check
// on `x` (original line 245), the calls that consume the TypeVisitor
// objects (253, 261), the call made on the LOGICAL CaseValues object (258,
// presumably `.Check(caseList)`), and the `break` statements.
238 void CaseChecker::Enter(const parser::CaseConstruct
&construct
) {
239 const auto &selectCaseStmt
{
240 std::get
<parser::Statement
<parser::SelectCaseStmt
>>(construct
.t
)};
241 const auto &selectCase
{selectCaseStmt
.statement
};
242 const auto &selectExpr
{
243 std::get
<parser::Scalar
<parser::Expr
>>(selectCase
.t
).thing
};
244 const auto *x
{GetExpr(context_
, selectExpr
)};
246 return; // expression semantics failed
248 if (auto exprType
{x
->GetType()}) {
249 const auto &caseList
{
250 std::get
<std::list
<parser::CaseConstruct::Case
>>(construct
.t
)};
251 switch (exprType
->category()) {
252 case TypeCategory::Integer
:
254 TypeVisitor
<TypeCategory::Integer
>{context_
, *exprType
, caseList
});
256 case TypeCategory::Logical
:
// Every LOGICAL kind is checked as LOGICAL(1). GetValue converts each CASE
// value to that type.
257 CaseValues
<evaluate::Type
<TypeCategory::Logical
, 1>>{context_
, *exprType
}
260 case TypeCategory::Character
:
262 TypeVisitor
<TypeCategory::Character
>{context_
, *exprType
, caseList
});
268 context_
.Say(selectExpr
.source
,
269 "SELECT CASE expression must be integer, logical, or character"_err_en_US
);
271 } // namespace Fortran::semantics