Bump version to 19.1.2
[llvm-project.git] / flang / lib / Lower / IterationSpace.cpp
blob930353640383733e8d6c854ee0f3c9082f0376ce
1 //===-- IterationSpace.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
11 //===----------------------------------------------------------------------===//
13 #include "flang/Lower/IterationSpace.h"
14 #include "flang/Evaluate/expression.h"
15 #include "flang/Lower/AbstractConverter.h"
16 #include "flang/Lower/Support/Utils.h"
17 #include "llvm/Support/Debug.h"
18 #include <optional>
20 #define DEBUG_TYPE "flang-lower-iteration-space"
22 unsigned Fortran::lower::getHashValue(
23 const Fortran::lower::ExplicitIterSpace::ArrayBases &x) {
24 return Fortran::common::visit(
25 [&](const auto *p) { return HashEvaluateExpr::getHashValue(*p); }, x);
28 bool Fortran::lower::isEqual(
29 const Fortran::lower::ExplicitIterSpace::ArrayBases &x,
30 const Fortran::lower::ExplicitIterSpace::ArrayBases &y) {
31 return Fortran::common::visit(
32 Fortran::common::visitors{
33 // Fortran::semantics::Symbol * are the exception here. These pointers
34 // have identity; if two Symbol * values are the same (different) then
35 // they are the same (different) logical symbol.
36 [&](Fortran::lower::FrontEndSymbol p,
37 Fortran::lower::FrontEndSymbol q) { return p == q; },
38 [&](const auto *p, const auto *q) {
39 if constexpr (std::is_same_v<decltype(p), decltype(q)>) {
40 LLVM_DEBUG(llvm::dbgs()
41 << "is equal: " << p << ' ' << q << ' '
42 << IsEqualEvaluateExpr::isEqual(*p, *q) << '\n');
43 return IsEqualEvaluateExpr::isEqual(*p, *q);
44 } else {
45 // Different subtree types are never equal.
46 return false;
48 }},
49 x, y);
52 namespace {
54 /// This class can recover the base array in an expression that contains
55 /// explicit iteration space symbols. Most of the class can be ignored as it is
56 /// boilerplate Fortran::evaluate::Expr traversal.
57 class ArrayBaseFinder {
58 public:
59 using RT = bool;
61 ArrayBaseFinder(llvm::ArrayRef<Fortran::lower::FrontEndSymbol> syms)
62 : controlVars(syms.begin(), syms.end()) {}
64 template <typename T>
65 void operator()(const T &x) {
66 (void)find(x);
69 /// Get the list of bases.
70 llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases>
71 getBases() const {
72 LLVM_DEBUG(llvm::dbgs()
73 << "number of array bases found: " << bases.size() << '\n');
74 return bases;
77 private:
78 // First, the cases that are of interest.
79 RT find(const Fortran::semantics::Symbol &symbol) {
80 if (symbol.Rank() > 0) {
81 bases.push_back(&symbol);
82 return true;
84 return {};
86 RT find(const Fortran::evaluate::Component &x) {
87 auto found = find(x.base());
88 if (!found && x.base().Rank() == 0 && x.Rank() > 0) {
89 bases.push_back(&x);
90 return true;
92 return found;
94 RT find(const Fortran::evaluate::ArrayRef &x) {
95 for (const auto &sub : x.subscript())
96 (void)find(sub);
97 if (x.base().IsSymbol()) {
98 if (x.Rank() > 0 || intersection(x.subscript())) {
99 bases.push_back(&x);
100 return true;
102 return {};
104 auto found = find(x.base());
105 if (!found && ((x.base().Rank() == 0 && x.Rank() > 0) ||
106 intersection(x.subscript()))) {
107 bases.push_back(&x);
108 return true;
110 return found;
112 RT find(const Fortran::evaluate::Triplet &x) {
113 if (const auto *lower = x.GetLower())
114 (void)find(*lower);
115 if (const auto *upper = x.GetUpper())
116 (void)find(*upper);
117 return find(x.GetStride());
119 RT find(const Fortran::evaluate::IndirectSubscriptIntegerExpr &x) {
120 return find(x.value());
122 RT find(const Fortran::evaluate::Subscript &x) { return find(x.u); }
123 RT find(const Fortran::evaluate::DataRef &x) { return find(x.u); }
124 RT find(const Fortran::evaluate::CoarrayRef &x) {
125 assert(false && "coarray reference");
126 return {};
129 template <typename A>
130 bool intersection(const A &subscripts) {
131 return Fortran::lower::symbolsIntersectSubscripts(controlVars, subscripts);
134 // The rest is traversal boilerplate and can be ignored.
135 RT find(const Fortran::evaluate::Substring &x) { return find(x.parent()); }
136 template <typename A>
137 RT find(const Fortran::semantics::SymbolRef x) {
138 return find(*x);
140 RT find(const Fortran::evaluate::NamedEntity &x) {
141 if (x.IsSymbol())
142 return find(x.GetFirstSymbol());
143 return find(x.GetComponent());
146 template <typename A, bool C>
147 RT find(const Fortran::common::Indirection<A, C> &x) {
148 return find(x.value());
150 template <typename A>
151 RT find(const std::unique_ptr<A> &x) {
152 return find(x.get());
154 template <typename A>
155 RT find(const std::shared_ptr<A> &x) {
156 return find(x.get());
158 template <typename A>
159 RT find(const A *x) {
160 if (x)
161 return find(*x);
162 return {};
164 template <typename A>
165 RT find(const std::optional<A> &x) {
166 if (x)
167 return find(*x);
168 return {};
170 template <typename... A>
171 RT find(const std::variant<A...> &u) {
172 return Fortran::common::visit([&](const auto &v) { return find(v); }, u);
174 template <typename A>
175 RT find(const std::vector<A> &x) {
176 for (auto &v : x)
177 (void)find(v);
178 return {};
180 RT find(const Fortran::evaluate::BOZLiteralConstant &) { return {}; }
181 RT find(const Fortran::evaluate::NullPointer &) { return {}; }
182 template <typename T>
183 RT find(const Fortran::evaluate::Constant<T> &x) {
184 return {};
186 RT find(const Fortran::evaluate::StaticDataObject &) { return {}; }
187 RT find(const Fortran::evaluate::ImpliedDoIndex &) { return {}; }
188 RT find(const Fortran::evaluate::BaseObject &x) {
189 (void)find(x.u);
190 return {};
192 RT find(const Fortran::evaluate::TypeParamInquiry &) { return {}; }
193 RT find(const Fortran::evaluate::ComplexPart &x) { return {}; }
194 template <typename T>
195 RT find(const Fortran::evaluate::Designator<T> &x) {
196 return find(x.u);
198 template <typename T>
199 RT find(const Fortran::evaluate::Variable<T> &x) {
200 return find(x.u);
202 RT find(const Fortran::evaluate::DescriptorInquiry &) { return {}; }
203 RT find(const Fortran::evaluate::SpecificIntrinsic &) { return {}; }
204 RT find(const Fortran::evaluate::ProcedureDesignator &x) { return {}; }
205 RT find(const Fortran::evaluate::ProcedureRef &x) {
206 (void)find(x.proc());
207 if (x.IsElemental())
208 (void)find(x.arguments());
209 return {};
211 RT find(const Fortran::evaluate::ActualArgument &x) {
212 if (const auto *sym = x.GetAssumedTypeDummy())
213 (void)find(*sym);
214 else
215 (void)find(x.UnwrapExpr());
216 return {};
218 template <typename T>
219 RT find(const Fortran::evaluate::FunctionRef<T> &x) {
220 (void)find(static_cast<const Fortran::evaluate::ProcedureRef &>(x));
221 return {};
223 template <typename T>
224 RT find(const Fortran::evaluate::ArrayConstructorValue<T> &) {
225 return {};
227 template <typename T>
228 RT find(const Fortran::evaluate::ArrayConstructorValues<T> &) {
229 return {};
231 template <typename T>
232 RT find(const Fortran::evaluate::ImpliedDo<T> &) {
233 return {};
235 RT find(const Fortran::semantics::ParamValue &) { return {}; }
236 RT find(const Fortran::semantics::DerivedTypeSpec &) { return {}; }
237 RT find(const Fortran::evaluate::StructureConstructor &) { return {}; }
238 template <typename D, typename R, typename O>
239 RT find(const Fortran::evaluate::Operation<D, R, O> &op) {
240 (void)find(op.left());
241 return false;
243 template <typename D, typename R, typename LO, typename RO>
244 RT find(const Fortran::evaluate::Operation<D, R, LO, RO> &op) {
245 (void)find(op.left());
246 (void)find(op.right());
247 return false;
249 RT find(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) {
250 (void)find(x.u);
251 return {};
253 template <typename T>
254 RT find(const Fortran::evaluate::Expr<T> &x) {
255 (void)find(x.u);
256 return {};
259 llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases> bases;
260 llvm::SmallVector<Fortran::lower::FrontEndSymbol> controlVars;
263 } // namespace
265 void Fortran::lower::ExplicitIterSpace::leave() {
266 ccLoopNest.pop_back();
267 --forallContextOpen;
268 conditionalCleanup();
271 void Fortran::lower::ExplicitIterSpace::addSymbol(
272 Fortran::lower::FrontEndSymbol sym) {
273 assert(!symbolStack.empty());
274 symbolStack.back().push_back(sym);
277 void Fortran::lower::ExplicitIterSpace::exprBase(Fortran::lower::FrontEndExpr x,
278 bool lhs) {
279 ArrayBaseFinder finder(collectAllSymbols());
280 finder(*x);
281 llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases> bases =
282 finder.getBases();
283 if (rhsBases.empty())
284 endAssign();
285 if (lhs) {
286 if (bases.empty()) {
287 lhsBases.push_back(std::nullopt);
288 return;
290 assert(bases.size() >= 1 && "must detect an array reference on lhs");
291 if (bases.size() > 1)
292 rhsBases.back().append(bases.begin(), bases.end() - 1);
293 lhsBases.push_back(bases.back());
294 return;
296 rhsBases.back().append(bases.begin(), bases.end());
299 void Fortran::lower::ExplicitIterSpace::endAssign() { rhsBases.emplace_back(); }
301 void Fortran::lower::ExplicitIterSpace::pushLevel() {
302 symbolStack.push_back(llvm::SmallVector<Fortran::lower::FrontEndSymbol>{});
305 void Fortran::lower::ExplicitIterSpace::popLevel() { symbolStack.pop_back(); }
307 void Fortran::lower::ExplicitIterSpace::conditionalCleanup() {
308 if (forallContextOpen == 0) {
309 // Exiting the outermost FORALL context.
310 // Cleanup any residual mask buffers.
311 outermostContext().finalizeAndReset();
312 // Clear and reset all the cached information.
313 symbolStack.clear();
314 lhsBases.clear();
315 rhsBases.clear();
316 loadBindings.clear();
317 ccLoopNest.clear();
318 innerArgs.clear();
319 outerLoop = std::nullopt;
320 clearLoops();
321 counter = 0;
325 std::optional<size_t>
326 Fortran::lower::ExplicitIterSpace::findArgPosition(fir::ArrayLoadOp load) {
327 if (lhsBases[counter]) {
328 auto ld = loadBindings.find(*lhsBases[counter]);
329 std::optional<size_t> optPos;
330 if (ld != loadBindings.end() && ld->second == load)
331 optPos = static_cast<size_t>(0u);
332 assert(optPos.has_value() && "load does not correspond to lhs");
333 return optPos;
335 return std::nullopt;
338 llvm::SmallVector<Fortran::lower::FrontEndSymbol>
339 Fortran::lower::ExplicitIterSpace::collectAllSymbols() {
340 llvm::SmallVector<Fortran::lower::FrontEndSymbol> result;
341 for (llvm::SmallVector<FrontEndSymbol> vec : symbolStack)
342 result.append(vec.begin(), vec.end());
343 return result;
346 llvm::raw_ostream &
347 Fortran::lower::operator<<(llvm::raw_ostream &s,
348 const Fortran::lower::ImplicitIterSpace &e) {
349 for (const llvm::SmallVector<
350 Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr> &xs :
351 e.getMasks()) {
352 s << "{ ";
353 for (const Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr &x : xs)
354 x->AsFortran(s << '(') << "), ";
355 s << "}\n";
357 return s;
360 llvm::raw_ostream &
361 Fortran::lower::operator<<(llvm::raw_ostream &s,
362 const Fortran::lower::ExplicitIterSpace &e) {
363 auto dump = [&](const auto &u) {
364 Fortran::common::visit(
365 Fortran::common::visitors{
366 [&](const Fortran::semantics::Symbol *y) {
367 s << " " << *y << '\n';
369 [&](const Fortran::evaluate::ArrayRef *y) {
370 s << " ";
371 if (y->base().IsSymbol())
372 s << y->base().GetFirstSymbol();
373 else
374 s << y->base().GetComponent().GetLastSymbol();
375 s << '\n';
377 [&](const Fortran::evaluate::Component *y) {
378 s << " " << y->GetLastSymbol() << '\n';
382 s << "LHS bases:\n";
383 for (const std::optional<Fortran::lower::ExplicitIterSpace::ArrayBases> &u :
384 e.lhsBases)
385 if (u)
386 dump(*u);
387 s << "RHS bases:\n";
388 for (const llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases>
389 &bases : e.rhsBases) {
390 for (const Fortran::lower::ExplicitIterSpace::ArrayBases &u : bases)
391 dump(u);
392 s << '\n';
394 return s;
397 void Fortran::lower::ImplicitIterSpace::dump() const {
398 llvm::errs() << *this << '\n';
401 void Fortran::lower::ExplicitIterSpace::dump() const {
402 llvm::errs() << *this << '\n';