[flang][cuda] Adapt ExternalNameConversion to work in gpu module (#117039)
[llvm-project.git] / flang / lib / Semantics / rewrite-directives.cpp
blobc94d0f3855bee31723ee29c739e0f00555c69f18
1 //===-- lib/Semantics/rewrite-directives.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "rewrite-directives.h"
10 #include "flang/Parser/parse-tree-visitor.h"
11 #include "flang/Parser/parse-tree.h"
12 #include "flang/Semantics/semantics.h"
13 #include "flang/Semantics/symbol.h"
14 #include "llvm/Frontend/OpenMP/OMP.h"
15 #include <list>
17 namespace Fortran::semantics {
19 using namespace parser::literals;
21 class DirectiveRewriteMutator {
22 public:
23 explicit DirectiveRewriteMutator(SemanticsContext &context)
24 : context_{context} {}
26 // Default action for a parse tree node is to visit children.
27 template <typename T> bool Pre(T &) { return true; }
28 template <typename T> void Post(T &) {}
30 protected:
31 SemanticsContext &context_;
34 // Rewrite atomic constructs to add an explicit memory ordering to all that do
35 // not specify it, honoring in this way the `atomic_default_mem_order` clause of
36 // the REQUIRES directive.
37 class OmpRewriteMutator : public DirectiveRewriteMutator {
38 public:
39 explicit OmpRewriteMutator(SemanticsContext &context)
40 : DirectiveRewriteMutator(context) {}
42 template <typename T> bool Pre(T &) { return true; }
43 template <typename T> void Post(T &) {}
45 bool Pre(parser::OpenMPAtomicConstruct &);
46 bool Pre(parser::OpenMPRequiresConstruct &);
48 private:
49 bool atomicDirectiveDefaultOrderFound_{false};
52 bool OmpRewriteMutator::Pre(parser::OpenMPAtomicConstruct &x) {
53 // Find top-level parent of the operation.
54 Symbol *topLevelParent{common::visit(
55 [&](auto &atomic) {
56 Symbol *symbol{nullptr};
57 Scope *scope{
58 &context_.FindScope(std::get<parser::Verbatim>(atomic.t).source)};
59 do {
60 if (Symbol * parent{scope->symbol()}) {
61 symbol = parent;
63 scope = &scope->parent();
64 } while (!scope->IsGlobal());
66 assert(symbol &&
67 "Atomic construct must be within a scope associated with a symbol");
68 return symbol;
70 x.u)};
72 // Get the `atomic_default_mem_order` clause from the top-level parent.
73 std::optional<common::OmpAtomicDefaultMemOrderType> defaultMemOrder;
74 common::visit(
75 [&](auto &details) {
76 if constexpr (std::is_convertible_v<decltype(&details),
77 WithOmpDeclarative *>) {
78 if (details.has_ompAtomicDefaultMemOrder()) {
79 defaultMemOrder = *details.ompAtomicDefaultMemOrder();
83 topLevelParent->details());
85 if (!defaultMemOrder) {
86 return false;
89 auto findMemOrderClause =
90 [](const std::list<parser::OmpAtomicClause> &clauses) {
91 return llvm::any_of(clauses, [](const auto &clause) {
92 return std::get_if<parser::OmpMemoryOrderClause>(&clause.u);
93 });
96 // Get the clause list to which the new memory order clause must be added,
97 // only if there are no other memory order clauses present for this atomic
98 // directive.
99 std::list<parser::OmpAtomicClause> *clauseList = common::visit(
100 common::visitors{[&](parser::OmpAtomic &atomicConstruct) {
101 // OmpAtomic only has a single list of clauses.
102 auto &clauses{std::get<parser::OmpAtomicClauseList>(
103 atomicConstruct.t)};
104 return !findMemOrderClause(clauses.v) ? &clauses.v
105 : nullptr;
107 [&](auto &atomicConstruct) {
108 // All other atomic constructs have two lists of clauses.
109 auto &clausesLhs{std::get<0>(atomicConstruct.t)};
110 auto &clausesRhs{std::get<2>(atomicConstruct.t)};
111 return !findMemOrderClause(clausesLhs.v) &&
112 !findMemOrderClause(clausesRhs.v)
113 ? &clausesRhs.v
114 : nullptr;
116 x.u);
118 // Add a memory order clause to the atomic directive.
119 if (clauseList) {
120 atomicDirectiveDefaultOrderFound_ = true;
121 switch (*defaultMemOrder) {
122 case common::OmpAtomicDefaultMemOrderType::AcqRel:
123 clauseList->emplace_back<parser::OmpMemoryOrderClause>(common::visit(
124 common::visitors{[](parser::OmpAtomicRead &) -> parser::OmpClause {
125 return parser::OmpClause::Acquire{};
127 [](parser::OmpAtomicCapture &) -> parser::OmpClause {
128 return parser::OmpClause::AcqRel{};
130 [](auto &) -> parser::OmpClause {
131 // parser::{OmpAtomic, OmpAtomicUpdate, OmpAtomicWrite}
132 return parser::OmpClause::Release{};
134 x.u));
135 break;
136 case common::OmpAtomicDefaultMemOrderType::Relaxed:
137 clauseList->emplace_back<parser::OmpMemoryOrderClause>(
138 parser::OmpClause{parser::OmpClause::Relaxed{}});
139 break;
140 case common::OmpAtomicDefaultMemOrderType::SeqCst:
141 clauseList->emplace_back<parser::OmpMemoryOrderClause>(
142 parser::OmpClause{parser::OmpClause::SeqCst{}});
143 break;
147 return false;
150 bool OmpRewriteMutator::Pre(parser::OpenMPRequiresConstruct &x) {
151 for (parser::OmpClause &clause : std::get<parser::OmpClauseList>(x.t).v) {
152 if (std::holds_alternative<parser::OmpClause::AtomicDefaultMemOrder>(
153 clause.u) &&
154 atomicDirectiveDefaultOrderFound_) {
155 context_.Say(clause.source,
156 "REQUIRES directive with '%s' clause found lexically after atomic "
157 "operation without a memory order clause"_err_en_US,
158 parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(
159 llvm::omp::OMPC_atomic_default_mem_order)
160 .str()));
163 return false;
166 bool RewriteOmpParts(SemanticsContext &context, parser::Program &program) {
167 if (!context.IsEnabled(common::LanguageFeature::OpenMP)) {
168 return true;
170 OmpRewriteMutator ompMutator{context};
171 parser::Walk(program, ompMutator);
172 return !context.AnyFatalError();
175 } // namespace Fortran::semantics