AMDGPU: Allow f16/bf16 for DS_READ_TR16_B64 gfx950 builtins (#118297)
[llvm-project.git] / flang / lib / Semantics / resolve-names-utils.cpp
blob a838d49c06104d42e5a00b160f61a3130ccbb980
1 //===-- lib/Semantics/resolve-names-utils.cpp -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "resolve-names-utils.h"
10 #include "flang/Common/Fortran-features.h"
11 #include "flang/Common/Fortran.h"
12 #include "flang/Common/idioms.h"
13 #include "flang/Common/indirection.h"
14 #include "flang/Evaluate/fold.h"
15 #include "flang/Evaluate/tools.h"
16 #include "flang/Evaluate/traverse.h"
17 #include "flang/Evaluate/type.h"
18 #include "flang/Parser/char-block.h"
19 #include "flang/Parser/parse-tree.h"
20 #include "flang/Semantics/expression.h"
21 #include "flang/Semantics/semantics.h"
22 #include "flang/Semantics/tools.h"
23 #include <initializer_list>
24 #include <variant>
26 namespace Fortran::semantics {
28 using common::LanguageFeature;
29 using common::LogicalOperator;
30 using common::NumericOperator;
31 using common::RelationalOperator;
32 using IntrinsicOperator = parser::DefinedOperator::IntrinsicOperator;
// Forward declaration: maps a parser-level intrinsic operator to the
// GenericKind used for generic resolution. The definition appears later in
// this file.
static GenericKind MapIntrinsicOperator(IntrinsicOperator);
36 Symbol *Resolve(const parser::Name &name, Symbol *symbol) {
37 if (symbol && !name.symbol) {
38 name.symbol = symbol;
40 return symbol;
42 Symbol &Resolve(const parser::Name &name, Symbol &symbol) {
43 return *Resolve(name, &symbol);
46 parser::MessageFixedText WithSeverity(
47 const parser::MessageFixedText &msg, parser::Severity severity) {
48 return parser::MessageFixedText{
49 msg.text().begin(), msg.text().size(), severity};
52 bool IsIntrinsicOperator(
53 const SemanticsContext &context, const SourceName &name) {
54 std::string str{name.ToString()};
55 for (int i{0}; i != common::LogicalOperator_enumSize; ++i) {
56 auto names{context.languageFeatures().GetNames(LogicalOperator{i})};
57 if (llvm::is_contained(names, str)) {
58 return true;
61 for (int i{0}; i != common::RelationalOperator_enumSize; ++i) {
62 auto names{context.languageFeatures().GetNames(RelationalOperator{i})};
63 if (llvm::is_contained(names, str)) {
64 return true;
67 return false;
70 bool IsLogicalConstant(
71 const SemanticsContext &context, const SourceName &name) {
72 std::string str{name.ToString()};
73 return str == ".true." || str == ".false." ||
74 (context.IsEnabled(LanguageFeature::LogicalAbbreviations) &&
75 (str == ".t" || str == ".f."));
78 void GenericSpecInfo::Resolve(Symbol *symbol) const {
79 if (symbol) {
80 if (auto *details{symbol->detailsIf<GenericDetails>()}) {
81 details->set_kind(kind_);
83 if (parseName_) {
84 semantics::Resolve(*parseName_, symbol);
89 void GenericSpecInfo::Analyze(const parser::DefinedOpName &name) {
90 kind_ = GenericKind::OtherKind::DefinedOp;
91 parseName_ = &name.v;
92 symbolName_ = name.v.source;
// Classifies a generic-spec into kind_. The alternatives are a name, a
// defined or intrinsic operator, an assignment, or one of the four defined-I/O
// specs.
// symbolName_ starts as the source text of the whole spec. For a plain name,
// or for a defined operator (via Analyze(DefinedOpName)), it is narrowed to
// that name, and parseName_ is recorded so Resolve() can bind it later.
void GenericSpecInfo::Analyze(const parser::GenericSpec &x) {
  symbolName_ = x.source;
  kind_ = common::visit(
      common::visitors{
          [&](const parser::Name &y) -> GenericKind {
            parseName_ = &y;
            symbolName_ = y.source;
            return GenericKind::OtherKind::Name;
          },
          [&](const parser::DefinedOperator &y) {
            return common::visit(
                common::visitors{
                    [&](const parser::DefinedOpName &z) -> GenericKind {
                      // Also sets kind_, parseName_ and symbolName_; the
                      // returned kind below matches what it sets.
                      Analyze(z);
                      return GenericKind::OtherKind::DefinedOp;
                    },
                    [&](const IntrinsicOperator &z) {
                      return MapIntrinsicOperator(z);
                    },
                },
                y.u);
          },
          [&](const parser::GenericSpec::Assignment &) -> GenericKind {
            return GenericKind::OtherKind::Assignment;
          },
          [&](const parser::GenericSpec::ReadFormatted &) -> GenericKind {
            return common::DefinedIo::ReadFormatted;
          },
          [&](const parser::GenericSpec::ReadUnformatted &) -> GenericKind {
            return common::DefinedIo::ReadUnformatted;
          },
          [&](const parser::GenericSpec::WriteFormatted &) -> GenericKind {
            return common::DefinedIo::WriteFormatted;
          },
          [&](const parser::GenericSpec::WriteUnformatted &) -> GenericKind {
            return common::DefinedIo::WriteUnformatted;
          },
      },
      x.u);
}
136 llvm::raw_ostream &operator<<(
137 llvm::raw_ostream &os, const GenericSpecInfo &info) {
138 os << "GenericSpecInfo: kind=" << info.kind_.ToString();
139 os << " parseName="
140 << (info.parseName_ ? info.parseName_->ToString() : "null");
141 os << " symbolName="
142 << (info.symbolName_ ? info.symbolName_->ToString() : "null");
143 return os;
// parser::DefinedOperator::IntrinsicOperator -> GenericKind
// Concatenation has its own GenericKind::OtherKind value. Every other
// operator returns the matching common:: NumericOperator, LogicalOperator or
// RelationalOperator enumerator, which converts implicitly to GenericKind.
static GenericKind MapIntrinsicOperator(IntrinsicOperator op) {
  switch (op) {
    SWITCH_COVERS_ALL_CASES
  case IntrinsicOperator::Concat:
    return GenericKind::OtherKind::Concat;
  case IntrinsicOperator::Power:
    return NumericOperator::Power;
  case IntrinsicOperator::Multiply:
    return NumericOperator::Multiply;
  case IntrinsicOperator::Divide:
    return NumericOperator::Divide;
  case IntrinsicOperator::Add:
    return NumericOperator::Add;
  case IntrinsicOperator::Subtract:
    return NumericOperator::Subtract;
  case IntrinsicOperator::AND:
    return LogicalOperator::And;
  case IntrinsicOperator::OR:
    return LogicalOperator::Or;
  case IntrinsicOperator::EQV:
    return LogicalOperator::Eqv;
  case IntrinsicOperator::NEQV:
    return LogicalOperator::Neqv;
  case IntrinsicOperator::NOT:
    return LogicalOperator::Not;
  case IntrinsicOperator::LT:
    return RelationalOperator::LT;
  case IntrinsicOperator::LE:
    return RelationalOperator::LE;
  case IntrinsicOperator::EQ:
    return RelationalOperator::EQ;
  case IntrinsicOperator::NE:
    return RelationalOperator::NE;
  case IntrinsicOperator::GE:
    return RelationalOperator::GE;
  case IntrinsicOperator::GT:
    return RelationalOperator::GT;
  }
}
187 class ArraySpecAnalyzer {
188 public:
189 ArraySpecAnalyzer(SemanticsContext &context) : context_{context} {}
190 ArraySpec Analyze(const parser::ArraySpec &);
191 ArraySpec AnalyzeDeferredShapeSpecList(const parser::DeferredShapeSpecList &);
192 ArraySpec Analyze(const parser::ComponentArraySpec &);
193 ArraySpec Analyze(const parser::CoarraySpec &);
195 private:
196 SemanticsContext &context_;
197 ArraySpec arraySpec_;
199 template <typename T> void Analyze(const std::list<T> &list) {
200 for (const auto &elem : list) {
201 Analyze(elem);
204 void Analyze(const parser::AssumedShapeSpec &);
205 void Analyze(const parser::ExplicitShapeSpec &);
206 void Analyze(const parser::AssumedImpliedSpec &);
207 void Analyze(const parser::DeferredShapeSpecList &);
208 void Analyze(const parser::AssumedRankSpec &);
209 void MakeExplicit(const std::optional<parser::SpecificationExpr> &,
210 const parser::SpecificationExpr &);
211 void MakeImplied(const std::optional<parser::SpecificationExpr> &);
212 void MakeDeferred(int);
213 Bound GetBound(const std::optional<parser::SpecificationExpr> &);
214 Bound GetBound(const parser::SpecificationExpr &);
217 ArraySpec AnalyzeArraySpec(
218 SemanticsContext &context, const parser::ArraySpec &arraySpec) {
219 return ArraySpecAnalyzer{context}.Analyze(arraySpec);
221 ArraySpec AnalyzeArraySpec(
222 SemanticsContext &context, const parser::ComponentArraySpec &arraySpec) {
223 return ArraySpecAnalyzer{context}.Analyze(arraySpec);
225 ArraySpec AnalyzeDeferredShapeSpecList(SemanticsContext &context,
226 const parser::DeferredShapeSpecList &deferredShapeSpecs) {
227 return ArraySpecAnalyzer{context}.AnalyzeDeferredShapeSpecList(
228 deferredShapeSpecs);
230 ArraySpec AnalyzeCoarraySpec(
231 SemanticsContext &context, const parser::CoarraySpec &coarraySpec) {
232 return ArraySpecAnalyzer{context}.Analyze(coarraySpec);
235 ArraySpec ArraySpecAnalyzer::Analyze(const parser::ComponentArraySpec &x) {
236 common::visit([this](const auto &y) { Analyze(y); }, x.u);
237 CHECK(!arraySpec_.empty());
238 return arraySpec_;
240 ArraySpec ArraySpecAnalyzer::Analyze(const parser::ArraySpec &x) {
241 common::visit(common::visitors{
242 [&](const parser::AssumedSizeSpec &y) {
243 Analyze(
244 std::get<std::list<parser::ExplicitShapeSpec>>(y.t));
245 Analyze(std::get<parser::AssumedImpliedSpec>(y.t));
247 [&](const parser::ImpliedShapeSpec &y) { Analyze(y.v); },
248 [&](const auto &y) { Analyze(y); },
250 x.u);
251 CHECK(!arraySpec_.empty());
252 return arraySpec_;
254 ArraySpec ArraySpecAnalyzer::AnalyzeDeferredShapeSpecList(
255 const parser::DeferredShapeSpecList &x) {
256 Analyze(x);
257 CHECK(!arraySpec_.empty());
258 return arraySpec_;
260 ArraySpec ArraySpecAnalyzer::Analyze(const parser::CoarraySpec &x) {
261 common::visit(
262 common::visitors{
263 [&](const parser::DeferredCoshapeSpecList &y) { MakeDeferred(y.v); },
264 [&](const parser::ExplicitCoshapeSpec &y) {
265 Analyze(std::get<std::list<parser::ExplicitShapeSpec>>(y.t));
266 MakeImplied(
267 std::get<std::optional<parser::SpecificationExpr>>(y.t));
270 x.u);
271 CHECK(!arraySpec_.empty());
272 return arraySpec_;
275 void ArraySpecAnalyzer::Analyze(const parser::AssumedShapeSpec &x) {
276 arraySpec_.push_back(ShapeSpec::MakeAssumedShape(GetBound(x.v)));
278 void ArraySpecAnalyzer::Analyze(const parser::ExplicitShapeSpec &x) {
279 MakeExplicit(std::get<std::optional<parser::SpecificationExpr>>(x.t),
280 std::get<parser::SpecificationExpr>(x.t));
282 void ArraySpecAnalyzer::Analyze(const parser::AssumedImpliedSpec &x) {
283 MakeImplied(x.v);
285 void ArraySpecAnalyzer::Analyze(const parser::DeferredShapeSpecList &x) {
286 MakeDeferred(x.v);
288 void ArraySpecAnalyzer::Analyze(const parser::AssumedRankSpec &) {
289 arraySpec_.push_back(ShapeSpec::MakeAssumedRank());
292 void ArraySpecAnalyzer::MakeExplicit(
293 const std::optional<parser::SpecificationExpr> &lb,
294 const parser::SpecificationExpr &ub) {
295 arraySpec_.push_back(ShapeSpec::MakeExplicit(GetBound(lb), GetBound(ub)));
297 void ArraySpecAnalyzer::MakeImplied(
298 const std::optional<parser::SpecificationExpr> &lb) {
299 arraySpec_.push_back(ShapeSpec::MakeImplied(GetBound(lb)));
301 void ArraySpecAnalyzer::MakeDeferred(int n) {
302 for (int i = 0; i < n; ++i) {
303 arraySpec_.push_back(ShapeSpec::MakeDeferred());
307 Bound ArraySpecAnalyzer::GetBound(
308 const std::optional<parser::SpecificationExpr> &x) {
309 return x ? GetBound(*x) : Bound{1};
311 Bound ArraySpecAnalyzer::GetBound(const parser::SpecificationExpr &x) {
312 MaybeSubscriptIntExpr expr;
313 if (MaybeExpr maybeExpr{AnalyzeExpr(context_, x.v)}) {
314 if (auto *intExpr{evaluate::UnwrapExpr<SomeIntExpr>(*maybeExpr)}) {
315 expr = evaluate::Fold(context_.foldingContext(),
316 evaluate::ConvertToType<evaluate::SubscriptInteger>(
317 std::move(*intExpr)));
320 return Bound{std::move(expr)};
323 // If src is SAVE (explicitly or implicitly),
324 // set SAVE attribute on all members of dst.
325 static void PropagateSaveAttr(
326 const EquivalenceObject &src, EquivalenceSet &dst) {
327 if (IsSaved(src.symbol)) {
328 for (auto &obj : dst) {
329 if (!obj.symbol.attrs().test(Attr::SAVE)) {
330 obj.symbol.attrs().set(Attr::SAVE);
331 // If the other equivalenced symbol itself is not SAVE,
332 // then adding SAVE here implies that it has to be implicit.
333 obj.symbol.implicitAttrs().set(Attr::SAVE);
338 static void PropagateSaveAttr(const EquivalenceSet &src, EquivalenceSet &dst) {
339 if (!src.empty()) {
340 PropagateSaveAttr(src.front(), dst);
// Adds one EQUIVALENCE object (a designator) to the set being built in
// currSet_. The designator is validated first; CheckDesignator() fills in
// currObject_ (symbol, subscripts, substring start). currObject_ is reset on
// every call, whether or not the designator was accepted.
void EquivalenceSets::AddToSet(const parser::Designator &designator) {
  if (CheckDesignator(designator)) {
    if (Symbol * symbol{currObject_.symbol}) {
      if (!currSet_.empty()) {
        // check this symbol against first of set for compatibility
        // Both orders are checked because CheckCanEquivalence is asymmetric.
        // && stops after the first diagnostic so a pair is reported at most
        // once; the combined result is intentionally unused.
        Symbol &first{currSet_.front().symbol};
        CheckCanEquivalence(designator.source, first, *symbol) &&
            CheckCanEquivalence(designator.source, *symbol, first);
      }
      auto subscripts{currObject_.subscripts};
      if (subscripts.empty()) {
        if (const ArraySpec * shape{symbol->GetShape()};
            shape && shape->IsExplicitShape()) {
          // record a whole array as its first element
          for (const ShapeSpec &spec : *shape) {
            if (auto lbound{spec.lbound().GetExplicit()}) {
              if (auto lbValue{evaluate::ToInt64(*lbound)}) {
                subscripts.push_back(*lbValue);
                continue;
              }
            }
            // A lower bound that is not a known constant: fall back to
            // recording no subscripts at all.
            subscripts.clear(); // error recovery
            break;
          }
        }
      }
      auto substringStart{currObject_.substringStart};
      currSet_.emplace_back(
          *symbol, subscripts, substringStart, designator.source);
      // If the new object is SAVE, every object already in the set becomes
      // SAVE as well.
      PropagateSaveAttr(currSet_.back(), currSet_);
    }
  }
  currObject_ = {};
}
379 void EquivalenceSets::FinishSet(const parser::CharBlock &source) {
380 std::set<std::size_t> existing; // indices of sets intersecting this one
381 for (auto &obj : currSet_) {
382 auto it{objectToSet_.find(obj)};
383 if (it != objectToSet_.end()) {
384 existing.insert(it->second); // symbol already in this set
387 if (existing.empty()) {
388 sets_.push_back({}); // create a new equivalence set
389 MergeInto(source, currSet_, sets_.size() - 1);
390 } else {
391 auto it{existing.begin()};
392 std::size_t dstIndex{*it};
393 MergeInto(source, currSet_, dstIndex);
394 while (++it != existing.end()) {
395 MergeInto(source, sets_[*it], dstIndex);
398 currSet_.clear();
// Report an error or warning if sym1 and sym2 cannot be in the same equivalence
// set.
// The check is asymmetric: some combinations are diagnosed only when sym1 and
// sym2 appear in a particular order, which is why AddToSet() tries both
// orders.
// Returns false if a diagnostic was selected for the pair, and true
// otherwise. Errors are reported via Say(). Portability messages are reported
// via Warn() under the associated LanguageFeature (which may presumably
// suppress them per the feature's warning settings; confirm).
bool EquivalenceSets::CheckCanEquivalence(
    const parser::CharBlock &source, const Symbol &sym1, const Symbol &sym2) {
  std::optional<common::LanguageFeature> feature;
  std::optional<parser::MessageFixedText> msg;
  const DeclTypeSpec *type1{sym1.GetType()};
  const DeclTypeSpec *type2{sym2.GetType()};
  bool isDefaultNum1{IsDefaultNumericSequenceType(type1)};
  bool isAnyNum1{IsAnyNumericSequenceType(type1)};
  bool isDefaultNum2{IsDefaultNumericSequenceType(type2)};
  bool isAnyNum2{IsAnyNumericSequenceType(type2)};
  bool isChar1{IsCharacterSequenceType(type1)};
  bool isChar2{IsCharacterSequenceType(type2)};
  if (sym1.attrs().test(Attr::PROTECTED) &&
      !sym2.attrs().test(Attr::PROTECTED)) { // C8114
    msg = "Equivalence set cannot contain '%s'"
          " with PROTECTED attribute and '%s' without"_err_en_US;
  } else if ((isDefaultNum1 && isDefaultNum2) || (isChar1 && isChar2)) {
    // ok & standard conforming
  } else if (!(isAnyNum1 || isChar1) &&
      !(isAnyNum2 || isChar2)) { // C8110 - C8113
    // Neither side is a numeric or character sequence type.
    if (AreTkCompatibleTypes(type1, type2)) {
      msg =
          "nonstandard: Equivalence set contains '%s' and '%s' with same type that is neither numeric nor character sequence type"_port_en_US;
      feature = LanguageFeature::EquivalenceSameNonSequence;
    } else {
      msg = "Equivalence set cannot contain '%s' and '%s' with distinct types "
            "that are not both numeric or character sequence types"_err_en_US;
    }
  } else if (isAnyNum1) {
    if (isChar2) {
      msg =
          "nonstandard: Equivalence set contains '%s' that is numeric sequence type and '%s' that is character"_port_en_US;
      feature = LanguageFeature::EquivalenceNumericWithCharacter;
    } else if (isAnyNum2) {
      // Both numeric, but not both default-kind (that case returned above).
      // A default/non-default pair is reported only when sym1 is the
      // default-kind side.
      if (isDefaultNum1) {
        msg =
            "nonstandard: Equivalence set contains '%s' that is a default "
            "numeric sequence type and '%s' that is numeric with non-default kind"_port_en_US;
      } else if (!isDefaultNum2) {
        msg = "nonstandard: Equivalence set contains '%s' and '%s' that are "
              "numeric sequence types with non-default kinds"_port_en_US;
      }
      // Only consulted when msg was set above.
      feature = LanguageFeature::EquivalenceNonDefaultNumeric;
    }
  }
  if (msg) {
    if (feature) {
      context_.Warn(
          *feature, source, std::move(*msg), sym1.name(), sym2.name());
    } else {
      context_.Say(source, std::move(*msg), sym1.name(), sym2.name());
    }
    return false;
  }
  return true;
}
// Move objects from src to sets_[dstIndex]
// SAVE is propagated in both directions. First, dst's first object may make
// everything in src SAVE. Then, after the copy, src's first object may make
// everything in the combined dst SAVE. Each moved object is re-indexed to
// dstIndex in objectToSet_, and src is left empty. `source` is not used here.
void EquivalenceSets::MergeInto(const parser::CharBlock &source,
    EquivalenceSet &src, std::size_t dstIndex) {
  EquivalenceSet &dst{sets_[dstIndex]};
  PropagateSaveAttr(dst, src);
  for (const auto &obj : src) {
    dst.push_back(obj);
    objectToSet_[obj] = dstIndex;
  }
  // Must run before src.clear(): it reads src.front().
  PropagateSaveAttr(src, dst);
  src.clear();
}
473 // If set has an object with this symbol, return it.
474 const EquivalenceObject *EquivalenceSets::Find(
475 const EquivalenceSet &set, const Symbol &symbol) {
476 for (const auto &obj : set) {
477 if (obj.symbol == symbol) {
478 return &obj;
481 return nullptr;
// Validates an EQUIVALENCE object and records its pieces in currObject_.
// A substring also has its bounds validated; an omitted start bound is
// recorded as an explicit start of 1. Every check runs (`&=`, not `&&`), so
// all diagnostics for the designator are reported.
bool EquivalenceSets::CheckDesignator(const parser::Designator &designator) {
  return common::visit(
      common::visitors{
          [&](const parser::DataRef &x) {
            return CheckDataRef(designator.source, x);
          },
          [&](const parser::Substring &x) {
            const auto &dataRef{std::get<parser::DataRef>(x.t)};
            const auto &range{std::get<parser::SubstringRange>(x.t)};
            bool ok{CheckDataRef(designator.source, dataRef)};
            if (const auto &lb{std::get<0>(range.t)}) {
              ok &= CheckSubstringBound(lb->thing.thing.value(), true);
            } else {
              currObject_.substringStart = 1;
            }
            // The start bound is checked first so that the end-bound check
            // can compare against the recorded start.
            if (const auto &ub{std::get<1>(range.t)}) {
              ok &= CheckSubstringBound(ub->thing.thing.value(), false);
            }
            return ok;
          },
      },
      designator.u);
}
// Validates the data-ref part of an EQUIVALENCE object.
// - A plain name is accepted if it resolved to a symbol.
// - Derived-type components (C8107), array sections (C924, R872) and
//   coindexed objects (C924) are rejected.
// - Array element subscripts are checked one by one. Every subscript is
//   visited even after a failure (`&=`), so each problem is reported.
bool EquivalenceSets::CheckDataRef(
    const parser::CharBlock &source, const parser::DataRef &x) {
  return common::visit(
      common::visitors{
          [&](const parser::Name &name) { return CheckObject(name); },
          [&](const common::Indirection<parser::StructureComponent> &) {
            context_.Say(source, // C8107
                "Derived type component '%s' is not allowed in an equivalence set"_err_en_US,
                source);
            return false;
          },
          [&](const common::Indirection<parser::ArrayElement> &elem) {
            bool ok{CheckDataRef(source, elem.value().base)};
            for (const auto &subscript : elem.value().subscripts) {
              ok &= common::visit(
                  common::visitors{
                      [&](const parser::SubscriptTriplet &) {
                        context_.Say(source, // C924, R872
                            "Array section '%s' is not allowed in an equivalence set"_err_en_US,
                            source);
                        return false;
                      },
                      [&](const parser::IntExpr &y) {
                        return CheckArrayBound(y.thing.value());
                      },
                  },
                  subscript.u);
            }
            return ok;
          },
          [&](const common::Indirection<parser::CoindexedNamedObject> &) {
            context_.Say(source, // C924 (R872)
                "Coindexed object '%s' is not allowed in an equivalence set"_err_en_US,
                source);
            return false;
          },
      },
      x.u);
}
548 bool EquivalenceSets::CheckObject(const parser::Name &name) {
549 currObject_.symbol = name.symbol;
550 return currObject_.symbol != nullptr;
553 bool EquivalenceSets::CheckArrayBound(const parser::Expr &bound) {
554 MaybeExpr expr{
555 evaluate::Fold(context_.foldingContext(), AnalyzeExpr(context_, bound))};
556 if (!expr) {
557 return false;
559 if (expr->Rank() > 0) {
560 context_.Say(bound.source, // C924, R872
561 "Array with vector subscript '%s' is not allowed in an equivalence set"_err_en_US,
562 bound.source);
563 return false;
565 auto subscript{evaluate::ToInt64(*expr)};
566 if (!subscript) {
567 context_.Say(bound.source, // C8109
568 "Array with nonconstant subscript '%s' is not allowed in an equivalence set"_err_en_US,
569 bound.source);
570 return false;
572 currObject_.subscripts.push_back(*subscript);
573 return true;
576 bool EquivalenceSets::CheckSubstringBound(
577 const parser::Expr &bound, bool isStart) {
578 MaybeExpr expr{
579 evaluate::Fold(context_.foldingContext(), AnalyzeExpr(context_, bound))};
580 if (!expr) {
581 return false;
583 auto subscript{evaluate::ToInt64(*expr)};
584 if (!subscript) {
585 context_.Say(bound.source, // C8109
586 "Substring with nonconstant bound '%s' is not allowed in an equivalence set"_err_en_US,
587 bound.source);
588 return false;
590 if (!isStart) {
591 auto start{currObject_.substringStart};
592 if (*subscript < (start ? *start : 1)) {
593 context_.Say(bound.source, // C8116
594 "Substring with zero length is not allowed in an equivalence set"_err_en_US);
595 return false;
597 } else if (*subscript != 1) {
598 currObject_.substringStart = *subscript;
600 return true;
603 bool EquivalenceSets::IsCharacterSequenceType(const DeclTypeSpec *type) {
604 return IsSequenceType(type, [&](const IntrinsicTypeSpec &type) {
605 auto kind{evaluate::ToInt64(type.kind())};
606 return type.category() == TypeCategory::Character && kind &&
607 kind.value() == context_.GetDefaultKind(TypeCategory::Character);
611 // Numeric or logical type of default kind or DOUBLE PRECISION or DOUBLE COMPLEX
612 bool EquivalenceSets::IsDefaultKindNumericType(const IntrinsicTypeSpec &type) {
613 if (auto kind{evaluate::ToInt64(type.kind())}) {
614 switch (type.category()) {
615 case TypeCategory::Integer:
616 case TypeCategory::Logical:
617 return *kind == context_.GetDefaultKind(TypeCategory::Integer);
618 case TypeCategory::Real:
619 case TypeCategory::Complex:
620 return *kind == context_.GetDefaultKind(TypeCategory::Real) ||
621 *kind == context_.doublePrecisionKind();
622 default:
623 return false;
626 return false;
629 bool EquivalenceSets::IsDefaultNumericSequenceType(const DeclTypeSpec *type) {
630 return IsSequenceType(type, [&](const IntrinsicTypeSpec &type) {
631 return IsDefaultKindNumericType(type);
635 bool EquivalenceSets::IsAnyNumericSequenceType(const DeclTypeSpec *type) {
636 return IsSequenceType(type, [&](const IntrinsicTypeSpec &type) {
637 return type.category() == TypeCategory::Logical ||
638 common::IsNumericTypeCategory(type.category());
642 // Is type an intrinsic type that satisfies predicate or a sequence type
643 // whose components do.
644 bool EquivalenceSets::IsSequenceType(const DeclTypeSpec *type,
645 std::function<bool(const IntrinsicTypeSpec &)> predicate) {
646 if (!type) {
647 return false;
648 } else if (const IntrinsicTypeSpec * intrinsic{type->AsIntrinsic()}) {
649 return predicate(*intrinsic);
650 } else if (const DerivedTypeSpec * derived{type->AsDerived()}) {
651 for (const auto &pair : *derived->typeSymbol().scope()) {
652 const Symbol &component{*pair.second};
653 if (IsAllocatableOrPointer(component) ||
654 !IsSequenceType(component.GetType(), predicate)) {
655 return false;
658 return true;
659 } else {
660 return false;
// MapSubprogramToNewSymbols() relies on the following recursive symbol/scope
// copying infrastructure to duplicate an interface's symbols and map all
// of the symbol references in their contained expressions and interfaces
// to the new symbols.

// Old-to-new correspondences accumulated while copying: original symbols to
// their copies, and original type specs to rebuilt type specs.
struct SymbolAndTypeMappings {
  std::map<const Symbol *, const Symbol *> symbolMap;
  std::map<const DeclTypeSpec *, const DeclTypeSpec *> typeMap;
};
// Expression traversal that retargets symbol references to their copies in
// `scope_`, consulting and extending the shared mappings in `map_`.
class SymbolMapper : public evaluate::AnyTraverse<SymbolMapper, bool> {
public:
  using Base = evaluate::AnyTraverse<SymbolMapper, bool>;
  SymbolMapper(Scope &scope, SymbolAndTypeMappings &map)
      : Base{*this}, scope_{scope}, map_{map} {}
  using Base::operator();
  // A SymbolRef with a known mapping is overwritten in place (hence the
  // const_cast) to point at the new symbol. An unmapped USE-associated symbol
  // is copied into scope_ instead, which records a mapping for later
  // references. Always returns false; presumably this keeps AnyTraverse
  // visiting the whole expression (confirm).
  bool operator()(const SymbolRef &ref) {
    if (const Symbol *mapped{MapSymbol(*ref)}) {
      const_cast<SymbolRef &>(ref) = *mapped;
    } else if (ref->has<UseDetails>()) {
      CopySymbol(&*ref);
    }
    return false;
  }
  // A mapped symbol reached other than through a SymbolRef cannot be
  // rewritten, so it is treated as an internal error.
  bool operator()(const Symbol &x) {
    if (MapSymbol(x)) {
      DIE("SymbolMapper hit symbol outside SymbolRef");
    }
    return false;
  }
  void MapSymbolExprs(Symbol &);
  Symbol *CopySymbol(const Symbol *);

private:
  // Rewrite the symbol references inside explicit parameter values and
  // bounds by traversing them.
  void MapParamValue(ParamValue &param) { (*this)(param.GetExplicit()); }
  void MapBound(Bound &bound) { (*this)(bound.GetExplicit()); }
  void MapShapeSpec(ShapeSpec &spec) {
    MapBound(spec.lbound());
    MapBound(spec.ubound());
  }
  const Symbol *MapSymbol(const Symbol &) const;
  const Symbol *MapSymbol(const Symbol *) const;
  const DeclTypeSpec *MapType(const DeclTypeSpec &);
  const DeclTypeSpec *MapType(const DeclTypeSpec *);
  const Symbol *MapInterface(const Symbol *);

  Scope &scope_;
  SymbolAndTypeMappings &map_;
};
// Copies `symbol` into scope_ and records the mapping in map_.symbolMap.
// - An interface subprogram gets a fresh symbol with the same name and
//   attributes, and its own Subprogram scope. It also gets new
//   SubprogramDetails carrying the isInterface/isDummy/defaultIgnoreTKR
//   settings and the Subroutine/Function flag. Its dummies and result are
//   then copied recursively via MapSubprogramToNewSymbols().
// - Any other non-subprogram symbol is copied with Scope::CopySymbol().
// Returns the copy, or nullptr if `symbol` is null, is a non-interface
// subprogram, or its name is already present in scope_ (for interfaces).
Symbol *SymbolMapper::CopySymbol(const Symbol *symbol) {
  if (symbol) {
    if (auto *subp{symbol->detailsIf<SubprogramDetails>()}) {
      if (subp->isInterface()) {
        if (auto pair{scope_.try_emplace(symbol->name(), symbol->attrs())};
            pair.second) {
          Symbol &copy{*pair.first->second};
          // Record the mapping before recursing so self-references resolve.
          map_.symbolMap[symbol] = &copy;
          copy.set(symbol->test(Symbol::Flag::Subroutine)
                  ? Symbol::Flag::Subroutine
                  : Symbol::Flag::Function);
          Scope &newScope{scope_.MakeScope(Scope::Kind::Subprogram, &copy)};
          copy.set_scope(&newScope);
          copy.set_details(SubprogramDetails{});
          auto &newSubp{copy.get<SubprogramDetails>()};
          newSubp.set_isInterface(true);
          newSubp.set_isDummy(subp->isDummy());
          newSubp.set_defaultIgnoreTKR(subp->defaultIgnoreTKR());
          MapSubprogramToNewSymbols(*symbol, copy, newScope, &map_);
          return &copy;
        }
      }
    } else if (Symbol * copy{scope_.CopySymbol(*symbol)}) {
      map_.symbolMap[symbol] = copy;
      return copy;
    }
  }
  return nullptr;
}
// Rewrites the expressions and references held in `symbol`'s details so they
// name mapped symbols and types:
// - object entities: their type, and every shape and coshape bound;
// - procedure entities: their interface (or, if none maps, their type), and
//   their initialization target;
// - host-associated symbols: the associated symbol.
// Other kinds of details are left unchanged.
void SymbolMapper::MapSymbolExprs(Symbol &symbol) {
  common::visit(
      common::visitors{[&](ObjectEntityDetails &object) {
                         if (const DeclTypeSpec * type{object.type()}) {
                           if (const DeclTypeSpec * newType{MapType(*type)}) {
                             object.ReplaceType(*newType);
                           }
                         }
                         for (ShapeSpec &spec : object.shape()) {
                           MapShapeSpec(spec);
                         }
                         for (ShapeSpec &spec : object.coshape()) {
                           MapShapeSpec(spec);
                         }
                       },
          [&](ProcEntityDetails &proc) {
            if (const Symbol *
                mappedSymbol{MapInterface(proc.rawProcInterface())}) {
              proc.set_procInterfaces(
                  *mappedSymbol, BypassGeneric(mappedSymbol->GetUltimate()));
            } else if (const DeclTypeSpec * mappedType{MapType(proc.type())}) {
              proc.set_type(*mappedType);
            }
            if (proc.init()) {
              if (const Symbol * mapped{MapSymbol(*proc.init())}) {
                proc.set_init(*mapped);
              }
            }
          },
          [&](const HostAssocDetails &hostAssoc) {
            if (const Symbol * mapped{MapSymbol(hostAssoc.symbol())}) {
              symbol.set_details(HostAssocDetails{*mapped});
            }
          },
          [](const auto &) {}},
      symbol.details());
}
782 const Symbol *SymbolMapper::MapSymbol(const Symbol &symbol) const {
783 if (auto iter{map_.symbolMap.find(&symbol)}; iter != map_.symbolMap.end()) {
784 return iter->second;
786 return nullptr;
789 const Symbol *SymbolMapper::MapSymbol(const Symbol *symbol) const {
790 return symbol ? MapSymbol(*symbol) : nullptr;
// Returns a rebuilt copy of `type` whose embedded expressions name mapped
// symbols, or nullptr when the type does not need rebuilding. Only two kinds
// of type are rebuilt: CHARACTER with an explicit length, and derived types
// with parameter values. Non-null results are cached in map_.typeMap;
// nullptr results are not cached.
const DeclTypeSpec *SymbolMapper::MapType(const DeclTypeSpec &type) {
  if (auto iter{map_.typeMap.find(&type)}; iter != map_.typeMap.end()) {
    return iter->second;
  }
  const DeclTypeSpec *newType{nullptr};
  if (type.category() == DeclTypeSpec::Category::Character) {
    const CharacterTypeSpec &charType{type.characterTypeSpec()};
    if (charType.length().GetExplicit()) {
      // Copy the length, then traverse the copy so its symbol references are
      // rewritten in place.
      ParamValue newLen{charType.length()};
      (*this)(newLen.GetExplicit());
      newType = &scope_.MakeCharacterType(
          std::move(newLen), KindExpr{charType.kind()});
    }
  } else if (const DerivedTypeSpec *derived{type.AsDerived()}) {
    if (!derived->parameters().empty()) {
      DerivedTypeSpec newDerived{derived->name(), derived->typeSymbol()};
      newDerived.CookParameters(scope_.context().foldingContext());
      for (const auto &[paramName, paramValue] : derived->parameters()) {
        ParamValue newParamValue{paramValue};
        MapParamValue(newParamValue);
        newDerived.AddParamValue(paramName, std::move(newParamValue));
      }
      // Scope::InstantiateDerivedTypes() instantiates it later.
      newType = &scope_.MakeDerivedType(type.category(), std::move(newDerived));
    }
  }
  if (newType) {
    map_.typeMap[&type] = newType;
  }
  return newType;
}
825 const DeclTypeSpec *SymbolMapper::MapType(const DeclTypeSpec *type) {
826 return type ? MapType(*type) : nullptr;
829 const Symbol *SymbolMapper::MapInterface(const Symbol *interface) {
830 if (const Symbol *mapped{MapSymbol(interface)}) {
831 return mapped;
833 if (interface) {
834 if (&interface->owner() != &scope_) {
835 return interface;
836 } else if (const auto *subp{interface->detailsIf<SubprogramDetails>()};
837 subp && subp->isInterface()) {
838 return CopySymbol(interface);
841 return nullptr;
// Populates `newScope` (the scope of `newSymbol`) with copies of the dummy
// arguments and function result of `oldSymbol`. It then rewrites every
// expression in newScope so that references to old symbols and types name
// their copies. When `mappings` is null, a local mapping table is used for
// this call only; otherwise the caller's table is shared and extended.
void MapSubprogramToNewSymbols(const Symbol &oldSymbol, Symbol &newSymbol,
    Scope &newScope, SymbolAndTypeMappings *mappings) {
  SymbolAndTypeMappings newMappings;
  if (!mappings) {
    mappings = &newMappings;
  }
  mappings->symbolMap[&oldSymbol] = &newSymbol;
  const auto &oldDetails{oldSymbol.get<SubprogramDetails>()};
  auto &newDetails{newSymbol.get<SubprogramDetails>()};
  SymbolMapper mapper{newScope, *mappings};
  for (const Symbol *dummyArg : oldDetails.dummyArgs()) {
    if (!dummyArg) {
      // A null dummy argument entry stands for an alternate return.
      newDetails.add_alternateReturn();
    } else if (Symbol * copy{mapper.CopySymbol(dummyArg)}) {
      copy->set(Symbol::Flag::Implicit, false);
      newDetails.add_dummyArg(*copy);
      mappings->symbolMap[dummyArg] = copy;
    }
  }
  if (oldDetails.isFunction()) {
    // Remove any entry named like the subprogram from newScope before the
    // result is copied, presumably so a result sharing the function's name
    // can be inserted (confirm).
    newScope.erase(newSymbol.name());
    const Symbol &result{oldDetails.result()};
    if (Symbol * copy{mapper.CopySymbol(&result)}) {
      newDetails.set_result(*copy);
      mappings->symbolMap[&result] = copy;
    }
  }
  // Rewrite expressions only after all copies exist, so that references
  // between dummies (e.g. in bounds or lengths) can be mapped.
  for (auto &[_, ref] : newScope) {
    mapper.MapSymbolExprs(*ref);
  }
  newScope.InstantiateDerivedTypes();
}
877 } // namespace Fortran::semantics