1 //===-- IterationSpace.cpp ------------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
11 //===----------------------------------------------------------------------===//
13 #include "flang/Lower/IterationSpace.h"
14 #include "flang/Evaluate/expression.h"
15 #include "flang/Lower/AbstractConverter.h"
16 #include "flang/Lower/Support/Utils.h"
17 #include "llvm/Support/Debug.h"
20 #define DEBUG_TYPE "flang-lower-iteration-space"
/// Compute a hash for an ArrayBases value.
/// Whichever pointer alternative `x` currently holds is visited, and the
/// pointee is hashed with HashEvaluateExpr::getHashValue.
22 unsigned Fortran::lower::getHashValue(
23 const Fortran::lower::ExplicitIterSpace::ArrayBases
&x
) {
24 return Fortran::common::visit(
25 [&](const auto *p
) { return HashEvaluateExpr::getHashValue(*p
); }, x
);
/// Return whether two ArrayBases values denote the same base.
/// The visitor lambdas each take one pointer from each operand:
///  - The FrontEndSymbol overload compares the symbol pointers by identity.
///  - The generic overload compares pointees of the same type with
///    IsEqualEvaluateExpr::isEqual. Under LLVM_DEBUG, it also logs both
///    pointers and the result.
///  - Per the in-body comment, pointees of different subtree types are never
///    considered equal.
28 bool Fortran::lower::isEqual(
29 const Fortran::lower::ExplicitIterSpace::ArrayBases
&x
,
30 const Fortran::lower::ExplicitIterSpace::ArrayBases
&y
) {
31 return Fortran::common::visit(
32 Fortran::common::visitors
{
33 // Fortran::semantics::Symbol * are the exception here. These pointers
34 // have identity; if two Symbol * values are the same (different) then
35 // they are the same (different) logical symbol.
36 [&](Fortran::lower::FrontEndSymbol p
,
37 Fortran::lower::FrontEndSymbol q
) { return p
== q
; },
// Generic case: any other pair of pointer alternatives.
38 [&](const auto *p
, const auto *q
) {
// Only pointees of the same static type can be compared with
// IsEqualEvaluateExpr::isEqual.
39 if constexpr (std::is_same_v
<decltype(p
), decltype(q
)>) {
40 LLVM_DEBUG(llvm::dbgs()
41 << "is equal: " << p
<< ' ' << q
<< ' '
42 << IsEqualEvaluateExpr::isEqual(*p
, *q
) << '\n');
43 return IsEqualEvaluateExpr::isEqual(*p
, *q
);
45 // Different subtree types are never equal.
54 /// This class can recover the base array in an expression that contains
55 /// explicit iteration space symbols. Most of the class can be ignored as it is
56 /// boilerplate Fortran::evaluate::Expr traversal.
///
/// State used during traversal:
///  - `bases` accumulates the bases found by the find() overloads.
///  - `controlVars` holds the symbols copied from the constructor argument.
///    intersection() tests subscripts against these symbols.
57 class ArrayBaseFinder
{
61 ArrayBaseFinder(llvm::ArrayRef
<Fortran::lower::FrontEndSymbol
> syms
)
62 : controlVars(syms
.begin(), syms
.end()) {}
65 void operator()(const T
&x
) {
69 /// Get the list of bases.
70 llvm::ArrayRef
<Fortran::lower::ExplicitIterSpace::ArrayBases
>
72 LLVM_DEBUG(llvm::dbgs()
73 << "number of array bases found: " << bases
.size() << '\n');
78 // First, the cases that are of interest.
// A symbol with nonzero rank is itself recorded as an array base.
79 RT
find(const Fortran::semantics::Symbol
&symbol
) {
80 if (symbol
.Rank() > 0) {
81 bases
.push_back(&symbol
);
// Component: the component's base is searched first.
// The condition below then selects the case where nothing was found in the
// base, the base is scalar, and the component is array-valued.
86 RT
find(const Fortran::evaluate::Component
&x
) {
87 auto found
= find(x
.base());
88 if (!found
&& x
.base().Rank() == 0 && x
.Rank() > 0) {
// ArrayRef: the loop at the top iterates over the subscripts.
// - Symbol base: the reference is a candidate when it is array-valued or
//   when intersection() reports that a subscript involves a control variable.
// - Otherwise: the base is searched recursively. The reference is a
//   candidate only when nothing was found there and either the rank test or
//   the intersection test holds.
94 RT
find(const Fortran::evaluate::ArrayRef
&x
) {
95 for (const auto &sub
: x
.subscript())
97 if (x
.base().IsSymbol()) {
98 if (x
.Rank() > 0 || intersection(x
.subscript())) {
104 auto found
= find(x
.base());
105 if (!found
&& ((x
.base().Rank() == 0 && x
.Rank() > 0) ||
106 intersection(x
.subscript()))) {
// Triplet: inspects the optional lower and upper bounds, then returns the
// result of searching the stride.
112 RT
find(const Fortran::evaluate::Triplet
&x
) {
113 if (const auto *lower
= x
.GetLower())
115 if (const auto *upper
= x
.GetUpper())
117 return find(x
.GetStride());
119 RT
find(const Fortran::evaluate::IndirectSubscriptIntegerExpr
&x
) {
120 return find(x
.value());
122 RT
find(const Fortran::evaluate::Subscript
&x
) { return find(x
.u
); }
123 RT
find(const Fortran::evaluate::DataRef
&x
) { return find(x
.u
); }
// Coarray references are not expected in this context (asserted in builds
// where assert is enabled).
124 RT
find(const Fortran::evaluate::CoarrayRef
&x
) {
125 assert(false && "coarray reference");
// Returns true when Fortran::lower::symbolsIntersectSubscripts reports that
// `subscripts` intersect the control-variable symbols in controlVars.
129 template <typename A
>
130 bool intersection(const A
&subscripts
) {
131 return Fortran::lower::symbolsIntersectSubscripts(controlVars
, subscripts
);
134 // The rest is traversal boilerplate and can be ignored.
135 RT
find(const Fortran::evaluate::Substring
&x
) { return find(x
.parent()); }
136 template <typename A
>
137 RT
find(const Fortran::semantics::SymbolRef x
) {
140 RT
find(const Fortran::evaluate::NamedEntity
&x
) {
142 return find(x
.GetFirstSymbol());
143 return find(x
.GetComponent());
146 template <typename A
, bool C
>
147 RT
find(const Fortran::common::Indirection
<A
, C
> &x
) {
148 return find(x
.value());
150 template <typename A
>
151 RT
find(const std::unique_ptr
<A
> &x
) {
152 return find(x
.get());
154 template <typename A
>
155 RT
find(const std::shared_ptr
<A
> &x
) {
156 return find(x
.get());
158 template <typename A
>
159 RT
find(const A
*x
) {
164 template <typename A
>
165 RT
find(const std::optional
<A
> &x
) {
170 template <typename
... A
>
171 RT
find(const std::variant
<A
...> &u
) {
172 return Fortran::common::visit([&](const auto &v
) { return find(v
); }, u
);
174 template <typename A
>
175 RT
find(const std::vector
<A
> &x
) {
180 RT
find(const Fortran::evaluate::BOZLiteralConstant
&) { return {}; }
181 RT
find(const Fortran::evaluate::NullPointer
&) { return {}; }
182 template <typename T
>
183 RT
find(const Fortran::evaluate::Constant
<T
> &x
) {
186 RT
find(const Fortran::evaluate::StaticDataObject
&) { return {}; }
187 RT
find(const Fortran::evaluate::ImpliedDoIndex
&) { return {}; }
188 RT
find(const Fortran::evaluate::BaseObject
&x
) {
192 RT
find(const Fortran::evaluate::TypeParamInquiry
&) { return {}; }
193 RT
find(const Fortran::evaluate::ComplexPart
&x
) { return {}; }
194 template <typename T
>
195 RT
find(const Fortran::evaluate::Designator
<T
> &x
) {
198 template <typename T
>
199 RT
find(const Fortran::evaluate::Variable
<T
> &x
) {
202 RT
find(const Fortran::evaluate::DescriptorInquiry
&) { return {}; }
203 RT
find(const Fortran::evaluate::SpecificIntrinsic
&) { return {}; }
204 RT
find(const Fortran::evaluate::ProcedureDesignator
&x
) { return {}; }
// Procedure references: both the callee and the arguments are searched.
// The individual results are discarded with (void).
205 RT
find(const Fortran::evaluate::ProcedureRef
&x
) {
206 (void)find(x
.proc());
208 (void)find(x
.arguments());
211 RT
find(const Fortran::evaluate::ActualArgument
&x
) {
212 if (const auto *sym
= x
.GetAssumedTypeDummy())
215 (void)find(x
.UnwrapExpr());
218 template <typename T
>
219 RT
find(const Fortran::evaluate::FunctionRef
<T
> &x
) {
220 (void)find(static_cast<const Fortran::evaluate::ProcedureRef
&>(x
));
223 template <typename T
>
224 RT
find(const Fortran::evaluate::ArrayConstructorValue
<T
> &) {
227 template <typename T
>
228 RT
find(const Fortran::evaluate::ArrayConstructorValues
<T
> &) {
231 template <typename T
>
232 RT
find(const Fortran::evaluate::ImpliedDo
<T
> &) {
235 RT
find(const Fortran::semantics::ParamValue
&) { return {}; }
236 RT
find(const Fortran::semantics::DerivedTypeSpec
&) { return {}; }
237 RT
find(const Fortran::evaluate::StructureConstructor
&) { return {}; }
238 template <typename D
, typename R
, typename O
>
239 RT
find(const Fortran::evaluate::Operation
<D
, R
, O
> &op
) {
240 (void)find(op
.left());
243 template <typename D
, typename R
, typename LO
, typename RO
>
244 RT
find(const Fortran::evaluate::Operation
<D
, R
, LO
, RO
> &op
) {
245 (void)find(op
.left());
246 (void)find(op
.right());
249 RT
find(const Fortran::evaluate::Relational
<Fortran::evaluate::SomeType
> &x
) {
253 template <typename T
>
254 RT
find(const Fortran::evaluate::Expr
<T
> &x
) {
// Array bases collected during traversal, appended in the order found.
259 llvm::SmallVector
<Fortran::lower::ExplicitIterSpace::ArrayBases
> bases
;
// Symbols copied from the constructor argument.
// intersection() tests subscripts against these symbols.
260 llvm::SmallVector
<Fortran::lower::FrontEndSymbol
> controlVars
;
/// Leave an explicit iteration space context.
/// The innermost (last) entry of ccLoopNest is popped, then
/// conditionalCleanup() is called. conditionalCleanup() resets cached state
/// only once no FORALL context remains open (forallContextOpen == 0).
265 void Fortran::lower::ExplicitIterSpace::leave() {
266 ccLoopNest
.pop_back();
268 conditionalCleanup();
/// Register `sym` at the innermost (most recently pushed) level of
/// symbolStack.
/// A level must already have been opened with pushLevel(); this is asserted.
271 void Fortran::lower::ExplicitIterSpace::addSymbol(
272 Fortran::lower::FrontEndSymbol sym
) {
273 assert(!symbolStack
.empty());
274 symbolStack
.back().push_back(sym
);
/// Collect the array bases appearing in expression `x` and record them for
/// the current assignment.
/// The search uses an ArrayBaseFinder seeded with every symbol currently on
/// symbolStack (collectAllSymbols()).
///  - Left-hand-side path (see the assert message): the last base found is
///    pushed onto lhsBases, and any earlier bases are appended to the current
///    group, rhsBases.back().
///  - Other path: every base found is appended to rhsBases.back().
/// One path pushes std::nullopt onto lhsBases.
/// NOTE(review): presumably this happens when there is no lhs base to
/// record; confirm against the full function.
277 void Fortran::lower::ExplicitIterSpace::exprBase(Fortran::lower::FrontEndExpr x
,
279 ArrayBaseFinder
finder(collectAllSymbols());
281 llvm::ArrayRef
<Fortran::lower::ExplicitIterSpace::ArrayBases
> bases
=
283 if (rhsBases
.empty())
287 lhsBases
.push_back(std::nullopt
);
// The lhs must yield at least one base. The last one is the lhs array;
// any preceding ones join the current rhs group.
290 assert(bases
.size() >= 1 && "must detect an array reference on lhs");
291 if (bases
.size() > 1)
292 rhsBases
.back().append(bases
.begin(), bases
.end() - 1);
293 lhsBases
.push_back(bases
.back());
// Every base found is appended to the current rhs group.
296 rhsBases
.back().append(bases
.begin(), bases
.end());
299 void Fortran::lower::ExplicitIterSpace::endAssign() { rhsBases
.emplace_back(); }
/// Open a new, empty level on symbolStack.
/// Subsequent addSymbol() calls register symbols at this level until
/// popLevel() removes it.
301 void Fortran::lower::ExplicitIterSpace::pushLevel() {
302 symbolStack
.push_back(llvm::SmallVector
<Fortran::lower::FrontEndSymbol
>{});
305 void Fortran::lower::ExplicitIterSpace::popLevel() { symbolStack
.pop_back(); }
/// Reset per-construct state once the outermost FORALL context has been
/// exited (forallContextOpen == 0).
/// In that case:
///  - The outermost context is finalized and reset, which cleans up residual
///    mask buffers.
///  - Cached information, including loadBindings and outerLoop, is cleared.
307 void Fortran::lower::ExplicitIterSpace::conditionalCleanup() {
308 if (forallContextOpen
== 0) {
309 // Exiting the outermost FORALL context.
310 // Cleanup any residual mask buffers.
311 outermostContext().finalizeAndReset();
312 // Clear and reset all the cached information.
316 loadBindings
.clear();
319 outerLoop
= std::nullopt
;
/// Find the argument position corresponding to array_load `load` for the
/// current assignment (index `counter` into lhsBases).
/// When the current assignment has an lhs base, that base is looked up in
/// loadBindings. If the base is bound to `load`, the position is 0 (the lhs).
/// Where assertions are enabled, a recorded lhs base that is not bound to
/// `load` triggers an assertion.
325 std::optional
<size_t>
326 Fortran::lower::ExplicitIterSpace::findArgPosition(fir::ArrayLoadOp load
) {
327 if (lhsBases
[counter
]) {
328 auto ld
= loadBindings
.find(*lhsBases
[counter
]);
// Position 0 denotes the lhs array load.
329 std::optional
<size_t> optPos
;
330 if (ld
!= loadBindings
.end() && ld
->second
== load
)
331 optPos
= static_cast<size_t>(0u);
332 assert(optPos
.has_value() && "load does not correspond to lhs");
/// Flatten every level of symbolStack into a single list of symbols.
/// Levels are appended in stack order, outermost (first pushed) first.
/// exprBase() uses the result to seed its ArrayBaseFinder.
338 llvm::SmallVector
<Fortran::lower::FrontEndSymbol
>
339 Fortran::lower::ExplicitIterSpace::collectAllSymbols() {
340 llvm::SmallVector
<Fortran::lower::FrontEndSymbol
> result
;
// NOTE(review): `vec` is declared by value, so each level's SmallVector is
// copied on every iteration. Binding `const auto &vec` would avoid the copy.
341 for (llvm::SmallVector
<FrontEndSymbol
> vec
: symbolStack
)
342 result
.append(vec
.begin(), vec
.end());
/// Print the mask expressions held by an ImplicitIterSpace to `s`.
/// Each group of masks is iterated. Every mask is emitted as Fortran source
/// wrapped in parentheses and followed by ", ".
347 Fortran::lower::operator<<(llvm::raw_ostream
&s
,
348 const Fortran::lower::ImplicitIterSpace
&e
) {
349 for (const llvm::SmallVector
<
350 Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr
> &xs
:
353 for (const Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr
&x
: xs
)
354 x
->AsFortran(s
<< '(') << "), ";
/// Print the array bases recorded in an ExplicitIterSpace to `s` for
/// debugging.
/// The local `dump` helper prints a single ArrayBases entry:
///  - symbol: the symbol itself;
///  - array reference: the first symbol of its base when the base is a
///    symbol, otherwise the last symbol of the base component;
///  - component: its last symbol.
/// The optional lhs entries (presumably e.lhsBases; confirm) are printed
/// first, then every group in e.rhsBases.
361 Fortran::lower::operator<<(llvm::raw_ostream
&s
,
362 const Fortran::lower::ExplicitIterSpace
&e
) {
363 auto dump
= [&](const auto &u
) {
364 Fortran::common::visit(
365 Fortran::common::visitors
{
366 [&](const Fortran::semantics::Symbol
*y
) {
367 s
<< " " << *y
<< '\n';
369 [&](const Fortran::evaluate::ArrayRef
*y
) {
371 if (y
->base().IsSymbol())
372 s
<< y
->base().GetFirstSymbol();
374 s
<< y
->base().GetComponent().GetLastSymbol();
377 [&](const Fortran::evaluate::Component
*y
) {
378 s
<< " " << y
->GetLastSymbol() << '\n';
// Optional lhs entries, one per assignment.
383 for (const std::optional
<Fortran::lower::ExplicitIterSpace::ArrayBases
> &u
:
// One group of rhs bases per assignment.
388 for (const llvm::SmallVector
<Fortran::lower::ExplicitIterSpace::ArrayBases
>
389 &bases
: e
.rhsBases
) {
390 for (const Fortran::lower::ExplicitIterSpace::ArrayBases
&u
: bases
)
/// Debugging aid: print this implicit iteration space to llvm::errs(),
/// followed by a newline.
397 void Fortran::lower::ImplicitIterSpace::dump() const {
398 llvm::errs() << *this << '\n';
401 void Fortran::lower::ExplicitIterSpace::dump() const {
402 llvm::errs() << *this << '\n';