//===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringSwitch.h"
#include <optional>

using namespace mlir;
20 /// Return true if the given operation is unknown and may potentially define a
22 static bool isPotentiallyUnknownSymbolTable(Operation
*op
) {
23 return op
->getNumRegions() == 1 && !op
->getDialect();
26 /// Returns the string name of the given symbol, or null if this is not a
28 static StringAttr
getNameIfSymbol(Operation
*op
) {
29 return op
->getAttrOfType
<StringAttr
>(SymbolTable::getSymbolAttrName());
31 static StringAttr
getNameIfSymbol(Operation
*op
, StringAttr symbolAttrNameId
) {
32 return op
->getAttrOfType
<StringAttr
>(symbolAttrNameId
);
35 /// Computes the nested symbol reference attribute for the symbol 'symbolName'
36 /// that are usable within the symbol table operations from 'symbol' as far up
37 /// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
38 /// Returns success if all references up to 'within' could be computed.
40 collectValidReferencesFor(Operation
*symbol
, StringAttr symbolName
,
42 SmallVectorImpl
<SymbolRefAttr
> &results
) {
43 assert(within
->isAncestor(symbol
) && "expected 'within' to be an ancestor");
44 MLIRContext
*ctx
= symbol
->getContext();
46 auto leafRef
= FlatSymbolRefAttr::get(symbolName
);
47 results
.push_back(leafRef
);
49 // Early exit for when 'within' is the parent of 'symbol'.
50 Operation
*symbolTableOp
= symbol
->getParentOp();
51 if (within
== symbolTableOp
)
54 // Collect references until 'symbolTableOp' reaches 'within'.
55 SmallVector
<FlatSymbolRefAttr
, 1> nestedRefs(1, leafRef
);
56 StringAttr symbolNameId
=
57 StringAttr::get(ctx
, SymbolTable::getSymbolAttrName());
59 // Each parent of 'symbol' should define a symbol table.
60 if (!symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>())
62 // Each parent of 'symbol' should also be a symbol.
63 StringAttr symbolTableName
= getNameIfSymbol(symbolTableOp
, symbolNameId
);
66 results
.push_back(SymbolRefAttr::get(symbolTableName
, nestedRefs
));
68 symbolTableOp
= symbolTableOp
->getParentOp();
69 if (symbolTableOp
== within
)
71 nestedRefs
.insert(nestedRefs
.begin(),
72 FlatSymbolRefAttr::get(symbolTableName
));
77 /// Walk all of the operations within the given set of regions, without
78 /// traversing into any nested symbol tables. Stops walking if the result of the
79 /// callback is anything other than `WalkResult::advance`.
80 static std::optional
<WalkResult
>
81 walkSymbolTable(MutableArrayRef
<Region
> regions
,
82 function_ref
<std::optional
<WalkResult
>(Operation
*)> callback
) {
83 SmallVector
<Region
*, 1> worklist(llvm::make_pointer_range(regions
));
84 while (!worklist
.empty()) {
85 for (Operation
&op
: worklist
.pop_back_val()->getOps()) {
86 std::optional
<WalkResult
> result
= callback(&op
);
87 if (result
!= WalkResult::advance())
90 // If this op defines a new symbol table scope, we can't traverse. Any
91 // symbol references nested within 'op' are different semantically.
92 if (!op
.hasTrait
<OpTrait::SymbolTable
>()) {
93 for (Region
®ion
: op
.getRegions())
94 worklist
.push_back(®ion
);
98 return WalkResult::advance();
101 /// Walk all of the operations nested under, and including, the given operation,
102 /// without traversing into any nested symbol tables. Stops walking if the
103 /// result of the callback is anything other than `WalkResult::advance`.
104 static std::optional
<WalkResult
>
105 walkSymbolTable(Operation
*op
,
106 function_ref
<std::optional
<WalkResult
>(Operation
*)> callback
) {
107 std::optional
<WalkResult
> result
= callback(op
);
108 if (result
!= WalkResult::advance() || op
->hasTrait
<OpTrait::SymbolTable
>())
110 return walkSymbolTable(op
->getRegions(), callback
);
113 //===----------------------------------------------------------------------===//
115 //===----------------------------------------------------------------------===//
117 /// Build a symbol table with the symbols within the given operation.
118 SymbolTable::SymbolTable(Operation
*symbolTableOp
)
119 : symbolTableOp(symbolTableOp
) {
120 assert(symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>() &&
121 "expected operation to have SymbolTable trait");
122 assert(symbolTableOp
->getNumRegions() == 1 &&
123 "expected operation to have a single region");
124 assert(llvm::hasSingleElement(symbolTableOp
->getRegion(0)) &&
125 "expected operation to have a single block");
127 StringAttr symbolNameId
= StringAttr::get(symbolTableOp
->getContext(),
128 SymbolTable::getSymbolAttrName());
129 for (auto &op
: symbolTableOp
->getRegion(0).front()) {
130 StringAttr name
= getNameIfSymbol(&op
, symbolNameId
);
134 auto inserted
= symbolTable
.insert({name
, &op
});
136 assert(inserted
.second
&&
137 "expected region to contain uniquely named symbol operations");
141 /// Look up a symbol with the specified name, returning null if no such name
142 /// exists. Names never include the @ on them.
143 Operation
*SymbolTable::lookup(StringRef name
) const {
144 return lookup(StringAttr::get(symbolTableOp
->getContext(), name
));
146 Operation
*SymbolTable::lookup(StringAttr name
) const {
147 return symbolTable
.lookup(name
);
150 void SymbolTable::remove(Operation
*op
) {
151 StringAttr name
= getNameIfSymbol(op
);
152 assert(name
&& "expected valid 'name' attribute");
153 assert(op
->getParentOp() == symbolTableOp
&&
154 "expected this operation to be inside of the operation with this "
157 auto it
= symbolTable
.find(name
);
158 if (it
!= symbolTable
.end() && it
->second
== op
)
159 symbolTable
.erase(it
);
162 void SymbolTable::erase(Operation
*symbol
) {
167 // TODO: Consider if this should be renamed to something like insertOrUpdate
168 /// Insert a new symbol into the table and associated operation if not already
169 /// there and rename it as necessary to avoid collisions. Return the name of
170 /// the symbol after insertion as attribute.
171 StringAttr
SymbolTable::insert(Operation
*symbol
, Block::iterator insertPt
) {
172 // The symbol cannot be the child of another op and must be the child of the
173 // symbolTableOp after this.
175 // TODO: consider if SymbolTable's constructor should behave the same.
176 if (!symbol
->getParentOp()) {
177 auto &body
= symbolTableOp
->getRegion(0).front();
178 if (insertPt
== Block::iterator()) {
179 insertPt
= Block::iterator(body
.end());
181 assert((insertPt
== body
.end() ||
182 insertPt
->getParentOp() == symbolTableOp
) &&
183 "expected insertPt to be in the associated module operation");
185 // Insert before the terminator, if any.
186 if (insertPt
== Block::iterator(body
.end()) && !body
.empty() &&
187 std::prev(body
.end())->hasTrait
<OpTrait::IsTerminator
>())
188 insertPt
= std::prev(body
.end());
190 body
.getOperations().insert(insertPt
, symbol
);
192 assert(symbol
->getParentOp() == symbolTableOp
&&
193 "symbol is already inserted in another op");
195 // Add this symbol to the symbol table, uniquing the name if a conflict is
197 StringAttr name
= getSymbolName(symbol
);
198 if (symbolTable
.insert({name
, symbol
}).second
)
200 // If the symbol was already in the table, also return.
201 if (symbolTable
.lookup(name
) == symbol
)
203 // If a conflict was detected, then the symbol will not have been added to
204 // the symbol table. Try suffixes until we get to a unique name that works.
205 SmallString
<128> nameBuffer(name
.getValue());
206 unsigned originalLength
= nameBuffer
.size();
208 MLIRContext
*context
= symbol
->getContext();
210 // Iteratively try suffixes until we find one that isn't used.
212 nameBuffer
.resize(originalLength
);
214 nameBuffer
+= std::to_string(uniquingCounter
++);
215 } while (!symbolTable
.insert({StringAttr::get(context
, nameBuffer
), symbol
})
217 setSymbolName(symbol
, nameBuffer
);
218 return getSymbolName(symbol
);
221 /// Returns the name of the given symbol operation.
222 StringAttr
SymbolTable::getSymbolName(Operation
*symbol
) {
223 StringAttr name
= getNameIfSymbol(symbol
);
224 assert(name
&& "expected valid symbol name");
228 /// Sets the name of the given symbol operation.
229 void SymbolTable::setSymbolName(Operation
*symbol
, StringAttr name
) {
230 symbol
->setAttr(getSymbolAttrName(), name
);
233 /// Returns the visibility of the given symbol operation.
234 SymbolTable::Visibility
SymbolTable::getSymbolVisibility(Operation
*symbol
) {
235 // If the attribute doesn't exist, assume public.
236 StringAttr vis
= symbol
->getAttrOfType
<StringAttr
>(getVisibilityAttrName());
238 return Visibility::Public
;
240 // Otherwise, switch on the string value.
241 return StringSwitch
<Visibility
>(vis
.getValue())
242 .Case("private", Visibility::Private
)
243 .Case("nested", Visibility::Nested
)
244 .Case("public", Visibility::Public
);
246 /// Sets the visibility of the given symbol operation.
247 void SymbolTable::setSymbolVisibility(Operation
*symbol
, Visibility vis
) {
248 MLIRContext
*ctx
= symbol
->getContext();
250 // If the visibility is public, just drop the attribute as this is the
252 if (vis
== Visibility::Public
) {
253 symbol
->removeAttr(StringAttr::get(ctx
, getVisibilityAttrName()));
257 // Otherwise, update the attribute.
258 assert((vis
== Visibility::Private
|| vis
== Visibility::Nested
) &&
259 "unknown symbol visibility kind");
261 StringRef visName
= vis
== Visibility::Private
? "private" : "nested";
262 symbol
->setAttr(getVisibilityAttrName(), StringAttr::get(ctx
, visName
));
265 /// Returns the nearest symbol table from a given operation `from`. Returns
266 /// nullptr if no valid parent symbol table could be found.
267 Operation
*SymbolTable::getNearestSymbolTable(Operation
*from
) {
268 assert(from
&& "expected valid operation");
269 if (isPotentiallyUnknownSymbolTable(from
))
272 while (!from
->hasTrait
<OpTrait::SymbolTable
>()) {
273 from
= from
->getParentOp();
275 // Check that this is a valid op and isn't an unknown symbol table.
276 if (!from
|| isPotentiallyUnknownSymbolTable(from
))
282 /// Walks all symbol table operations nested within, and including, `op`. For
283 /// each symbol table operation, the provided callback is invoked with the op
284 /// and a boolean signifying if the symbols within that symbol table can be
285 /// treated as if all uses are visible. `allSymUsesVisible` identifies whether
286 /// all of the symbol uses of symbols within `op` are visible.
287 void SymbolTable::walkSymbolTables(
288 Operation
*op
, bool allSymUsesVisible
,
289 function_ref
<void(Operation
*, bool)> callback
) {
290 bool isSymbolTable
= op
->hasTrait
<OpTrait::SymbolTable
>();
292 SymbolOpInterface symbol
= dyn_cast
<SymbolOpInterface
>(op
);
293 allSymUsesVisible
|= !symbol
|| symbol
.isPrivate();
295 // Otherwise if 'op' is not a symbol table, any nested symbols are
296 // guaranteed to be hidden.
297 allSymUsesVisible
= true;
300 for (Region
®ion
: op
->getRegions())
301 for (Block
&block
: region
)
302 for (Operation
&nestedOp
: block
)
303 walkSymbolTables(&nestedOp
, allSymUsesVisible
, callback
);
305 // If 'op' had the symbol table trait, visit it after any nested symbol
308 callback(op
, allSymUsesVisible
);
311 /// Returns the operation registered with the given symbol name with the
312 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
313 /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
315 Operation
*SymbolTable::lookupSymbolIn(Operation
*symbolTableOp
,
317 assert(symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>());
318 Region
®ion
= symbolTableOp
->getRegion(0);
322 // Look for a symbol with the given name.
323 StringAttr symbolNameId
= StringAttr::get(symbolTableOp
->getContext(),
324 SymbolTable::getSymbolAttrName());
325 for (auto &op
: region
.front())
326 if (getNameIfSymbol(&op
, symbolNameId
) == symbol
)
330 Operation
*SymbolTable::lookupSymbolIn(Operation
*symbolTableOp
,
331 SymbolRefAttr symbol
) {
332 SmallVector
<Operation
*, 4> resolvedSymbols
;
333 if (failed(lookupSymbolIn(symbolTableOp
, symbol
, resolvedSymbols
)))
335 return resolvedSymbols
.back();
338 /// Internal implementation of `lookupSymbolIn` that allows for specialized
339 /// implementations of the lookup function.
340 static LogicalResult
lookupSymbolInImpl(
341 Operation
*symbolTableOp
, SymbolRefAttr symbol
,
342 SmallVectorImpl
<Operation
*> &symbols
,
343 function_ref
<Operation
*(Operation
*, StringAttr
)> lookupSymbolFn
) {
344 assert(symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>());
346 // Lookup the root reference for this symbol.
347 symbolTableOp
= lookupSymbolFn(symbolTableOp
, symbol
.getRootReference());
350 symbols
.push_back(symbolTableOp
);
352 // If there are no nested references, just return the root symbol directly.
353 ArrayRef
<FlatSymbolRefAttr
> nestedRefs
= symbol
.getNestedReferences();
354 if (nestedRefs
.empty())
357 // Verify that the root is also a symbol table.
358 if (!symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>())
361 // Otherwise, lookup each of the nested non-leaf references and ensure that
362 // each corresponds to a valid symbol table.
363 for (FlatSymbolRefAttr ref
: nestedRefs
.drop_back()) {
364 symbolTableOp
= lookupSymbolFn(symbolTableOp
, ref
.getAttr());
365 if (!symbolTableOp
|| !symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>())
367 symbols
.push_back(symbolTableOp
);
369 symbols
.push_back(lookupSymbolFn(symbolTableOp
, symbol
.getLeafReference()));
370 return success(symbols
.back());
374 SymbolTable::lookupSymbolIn(Operation
*symbolTableOp
, SymbolRefAttr symbol
,
375 SmallVectorImpl
<Operation
*> &symbols
) {
376 auto lookupFn
= [](Operation
*symbolTableOp
, StringAttr symbol
) {
377 return lookupSymbolIn(symbolTableOp
, symbol
);
379 return lookupSymbolInImpl(symbolTableOp
, symbol
, symbols
, lookupFn
);
382 /// Returns the operation registered with the given symbol name within the
383 /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
384 /// nullptr if no valid symbol was found.
385 Operation
*SymbolTable::lookupNearestSymbolFrom(Operation
*from
,
387 Operation
*symbolTableOp
= getNearestSymbolTable(from
);
388 return symbolTableOp
? lookupSymbolIn(symbolTableOp
, symbol
) : nullptr;
390 Operation
*SymbolTable::lookupNearestSymbolFrom(Operation
*from
,
391 SymbolRefAttr symbol
) {
392 Operation
*symbolTableOp
= getNearestSymbolTable(from
);
393 return symbolTableOp
? lookupSymbolIn(symbolTableOp
, symbol
) : nullptr;
396 raw_ostream
&mlir::operator<<(raw_ostream
&os
,
397 SymbolTable::Visibility visibility
) {
398 switch (visibility
) {
399 case SymbolTable::Visibility::Public
:
400 return os
<< "public";
401 case SymbolTable::Visibility::Private
:
402 return os
<< "private";
403 case SymbolTable::Visibility::Nested
:
404 return os
<< "nested";
406 llvm_unreachable("Unexpected visibility");
409 //===----------------------------------------------------------------------===//
410 // SymbolTable Trait Types
411 //===----------------------------------------------------------------------===//
413 LogicalResult
detail::verifySymbolTable(Operation
*op
) {
414 if (op
->getNumRegions() != 1)
415 return op
->emitOpError()
416 << "Operations with a 'SymbolTable' must have exactly one region";
417 if (!llvm::hasSingleElement(op
->getRegion(0)))
418 return op
->emitOpError()
419 << "Operations with a 'SymbolTable' must have exactly one block";
421 // Check that all symbols are uniquely named within child regions.
422 DenseMap
<Attribute
, Location
> nameToOrigLoc
;
423 for (auto &block
: op
->getRegion(0)) {
424 for (auto &op
: block
) {
425 // Check for a symbol name attribute.
427 op
.getAttrOfType
<StringAttr
>(mlir::SymbolTable::getSymbolAttrName());
431 // Try to insert this symbol into the table.
432 auto it
= nameToOrigLoc
.try_emplace(nameAttr
, op
.getLoc());
434 return op
.emitError()
435 .append("redefinition of symbol named '", nameAttr
.getValue(), "'")
436 .attachNote(it
.first
->second
)
437 .append("see existing symbol definition here");
441 // Verify any nested symbol user operations.
442 SymbolTableCollection symbolTable
;
443 auto verifySymbolUserFn
= [&](Operation
*op
) -> std::optional
<WalkResult
> {
444 if (SymbolUserOpInterface user
= dyn_cast
<SymbolUserOpInterface
>(op
))
445 return WalkResult(user
.verifySymbolUses(symbolTable
));
446 return WalkResult::advance();
449 std::optional
<WalkResult
> result
=
450 walkSymbolTable(op
->getRegions(), verifySymbolUserFn
);
451 return success(result
&& !result
->wasInterrupted());
454 LogicalResult
detail::verifySymbol(Operation
*op
) {
455 // Verify the name attribute.
456 if (!op
->getAttrOfType
<StringAttr
>(mlir::SymbolTable::getSymbolAttrName()))
457 return op
->emitOpError() << "requires string attribute '"
458 << mlir::SymbolTable::getSymbolAttrName() << "'";
460 // Verify the visibility attribute.
461 if (Attribute vis
= op
->getAttr(mlir::SymbolTable::getVisibilityAttrName())) {
462 StringAttr visStrAttr
= llvm::dyn_cast
<StringAttr
>(vis
);
464 return op
->emitOpError() << "requires visibility attribute '"
465 << mlir::SymbolTable::getVisibilityAttrName()
466 << "' to be a string attribute, but got " << vis
;
468 if (!llvm::is_contained(ArrayRef
<StringRef
>{"public", "private", "nested"},
469 visStrAttr
.getValue()))
470 return op
->emitOpError()
471 << "visibility expected to be one of [\"public\", \"private\", "
472 "\"nested\"], but got "
478 //===----------------------------------------------------------------------===//
480 //===----------------------------------------------------------------------===//
482 /// Walk all of the symbol references within the given operation, invoking the
483 /// provided callback for each found use. The callbacks takes the use of the
486 walkSymbolRefs(Operation
*op
,
487 function_ref
<WalkResult(SymbolTable::SymbolUse
)> callback
) {
488 return op
->getAttrDictionary().walk
<WalkOrder::PreOrder
>(
489 [&](SymbolRefAttr symbolRef
) {
490 if (callback({op
, symbolRef
}).wasInterrupted())
491 return WalkResult::interrupt();
493 // Don't walk nested references.
494 return WalkResult::skip();
498 /// Walk all of the uses, for any symbol, that are nested within the given
499 /// regions, invoking the provided callback for each. This does not traverse
500 /// into any nested symbol tables.
501 static std::optional
<WalkResult
>
502 walkSymbolUses(MutableArrayRef
<Region
> regions
,
503 function_ref
<WalkResult(SymbolTable::SymbolUse
)> callback
) {
504 return walkSymbolTable(regions
,
505 [&](Operation
*op
) -> std::optional
<WalkResult
> {
506 // Check that this isn't a potentially unknown symbol
508 if (isPotentiallyUnknownSymbolTable(op
))
511 return walkSymbolRefs(op
, callback
);
514 /// Walk all of the uses, for any symbol, that are nested within the given
515 /// operation 'from', invoking the provided callback for each. This does not
516 /// traverse into any nested symbol tables.
517 static std::optional
<WalkResult
>
518 walkSymbolUses(Operation
*from
,
519 function_ref
<WalkResult(SymbolTable::SymbolUse
)> callback
) {
520 // If this operation has regions, and it, as well as its dialect, isn't
521 // registered then conservatively fail. The operation may define a
522 // symbol table, so we can't opaquely know if we should traverse to find
524 if (isPotentiallyUnknownSymbolTable(from
))
527 // Walk the uses on this operation.
528 if (walkSymbolRefs(from
, callback
).wasInterrupted())
529 return WalkResult::interrupt();
531 // Only recurse if this operation is not a symbol table. A symbol table
532 // defines a new scope, so we can't walk the attributes from within the symbol
534 if (!from
->hasTrait
<OpTrait::SymbolTable
>())
535 return walkSymbolUses(from
->getRegions(), callback
);
536 return WalkResult::advance();
540 /// This class represents a single symbol scope. A symbol scope represents the
541 /// set of operations nested within a symbol table that may reference symbols
542 /// within that table. A symbol scope does not contain the symbol table
543 /// operation itself, just its contained operations. A scope ends at leaf
544 /// operations or another symbol table operation.
546 /// Walk the symbol uses within this scope, invoking the given callback.
547 /// This variant is used when the callback type matches that expected by
548 /// 'walkSymbolUses'.
549 template <typename CallbackT
,
550 std::enable_if_t
<!std::is_same
<
551 typename
llvm::function_traits
<CallbackT
>::result_t
,
552 void>::value
> * = nullptr>
553 std::optional
<WalkResult
> walk(CallbackT cback
) {
554 if (Region
*region
= llvm::dyn_cast_if_present
<Region
*>(limit
))
555 return walkSymbolUses(*region
, cback
);
556 return walkSymbolUses(limit
.get
<Operation
*>(), cback
);
558 /// This variant is used when the callback type matches a stripped down type:
559 /// void(SymbolTable::SymbolUse use)
560 template <typename CallbackT
,
561 std::enable_if_t
<std::is_same
<
562 typename
llvm::function_traits
<CallbackT
>::result_t
,
563 void>::value
> * = nullptr>
564 std::optional
<WalkResult
> walk(CallbackT cback
) {
565 return walk([=](SymbolTable::SymbolUse use
) {
566 return cback(use
), WalkResult::advance();
570 /// Walk all of the operations nested under the current scope without
571 /// traversing into any nested symbol tables.
572 template <typename CallbackT
>
573 std::optional
<WalkResult
> walkSymbolTable(CallbackT
&&cback
) {
574 if (Region
*region
= llvm::dyn_cast_if_present
<Region
*>(limit
))
575 return ::walkSymbolTable(*region
, cback
);
576 return ::walkSymbolTable(limit
.get
<Operation
*>(), cback
);
579 /// The representation of the symbol within this scope.
580 SymbolRefAttr symbol
;
582 /// The IR unit representing this scope.
583 llvm::PointerUnion
<Operation
*, Region
*> limit
;
587 /// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
588 static SmallVector
<SymbolScope
, 2> collectSymbolScopes(Operation
*symbol
,
590 StringAttr symName
= SymbolTable::getSymbolName(symbol
);
591 assert(!symbol
->hasTrait
<OpTrait::SymbolTable
>() || symbol
!= limit
);
593 // Compute the ancestors of 'limit'.
594 SetVector
<Operation
*, SmallVector
<Operation
*, 4>,
595 SmallPtrSet
<Operation
*, 4>>
597 Operation
*limitAncestor
= limit
;
599 // Check to see if 'symbol' is an ancestor of 'limit'.
600 if (limitAncestor
== symbol
) {
601 // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
602 // doesn't support parent references.
603 if (SymbolTable::getNearestSymbolTable(limit
->getParentOp()) ==
604 symbol
->getParentOp())
605 return {{SymbolRefAttr::get(symName
), limit
}};
609 limitAncestors
.insert(limitAncestor
);
610 } while ((limitAncestor
= limitAncestor
->getParentOp()));
612 // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
613 Operation
*commonAncestor
= symbol
->getParentOp();
615 if (limitAncestors
.count(commonAncestor
))
617 } while ((commonAncestor
= commonAncestor
->getParentOp()));
618 assert(commonAncestor
&& "'limit' and 'symbol' have no common ancestor");
620 // Compute the set of valid nested references for 'symbol' as far up to the
621 // common ancestor as possible.
622 SmallVector
<SymbolRefAttr
, 2> references
;
623 bool collectedAllReferences
= succeeded(
624 collectValidReferencesFor(symbol
, symName
, commonAncestor
, references
));
626 // Handle the case where the common ancestor is 'limit'.
627 if (commonAncestor
== limit
) {
628 SmallVector
<SymbolScope
, 2> scopes
;
630 // Walk each of the ancestors of 'symbol', calling the compute function for
632 Operation
*limitIt
= symbol
->getParentOp();
633 for (size_t i
= 0, e
= references
.size(); i
!= e
;
634 ++i
, limitIt
= limitIt
->getParentOp()) {
635 assert(limitIt
->hasTrait
<OpTrait::SymbolTable
>());
636 scopes
.push_back({references
[i
], &limitIt
->getRegion(0)});
641 // Otherwise, we just need the symbol reference for 'symbol' that will be
642 // used within 'limit'. This is the last reference in the list we computed
643 // above if we were able to collect all references.
644 if (!collectedAllReferences
)
646 return {{references
.back(), limit
}};
648 static SmallVector
<SymbolScope
, 2> collectSymbolScopes(Operation
*symbol
,
650 auto scopes
= collectSymbolScopes(symbol
, limit
->getParentOp());
652 // If we collected some scopes to walk, make sure to constrain the one for
653 // limit to the specific region requested.
655 scopes
.back().limit
= limit
;
658 template <typename IRUnit
>
659 static SmallVector
<SymbolScope
, 1> collectSymbolScopes(StringAttr symbol
,
661 return {{SymbolRefAttr::get(symbol
), limit
}};
664 /// Returns true if the given reference 'SubRef' is a sub reference of the
665 /// reference 'ref', i.e. 'ref' is a further qualified reference.
666 static bool isReferencePrefixOf(SymbolRefAttr subRef
, SymbolRefAttr ref
) {
670 // If the references are not pointer equal, check to see if `subRef` is a
672 if (llvm::isa
<FlatSymbolRefAttr
>(ref
) ||
673 ref
.getRootReference() != subRef
.getRootReference())
676 auto refLeafs
= ref
.getNestedReferences();
677 auto subRefLeafs
= subRef
.getNestedReferences();
678 return subRefLeafs
.size() < refLeafs
.size() &&
679 subRefLeafs
== refLeafs
.take_front(subRefLeafs
.size());
682 //===----------------------------------------------------------------------===//
683 // SymbolTable::getSymbolUses
685 /// The implementation of SymbolTable::getSymbolUses below.
686 template <typename FromT
>
687 static std::optional
<SymbolTable::UseRange
> getSymbolUsesImpl(FromT from
) {
688 std::vector
<SymbolTable::SymbolUse
> uses
;
689 auto walkFn
= [&](SymbolTable::SymbolUse symbolUse
) {
690 uses
.push_back(symbolUse
);
691 return WalkResult::advance();
693 auto result
= walkSymbolUses(from
, walkFn
);
694 return result
? std::optional
<SymbolTable::UseRange
>(std::move(uses
))
698 /// Get an iterator range for all of the uses, for any symbol, that are nested
699 /// within the given operation 'from'. This does not traverse into any nested
700 /// symbol tables, and will also only return uses on 'from' if it does not
701 /// also define a symbol table. This is because we treat the region as the
702 /// boundary of the symbol table, and not the op itself. This function returns
703 /// std::nullopt if there are any unknown operations that may potentially be
705 auto SymbolTable::getSymbolUses(Operation
*from
) -> std::optional
<UseRange
> {
706 return getSymbolUsesImpl(from
);
708 auto SymbolTable::getSymbolUses(Region
*from
) -> std::optional
<UseRange
> {
709 return getSymbolUsesImpl(MutableArrayRef
<Region
>(*from
));
712 //===----------------------------------------------------------------------===//
713 // SymbolTable::getSymbolUses
715 /// The implementation of SymbolTable::getSymbolUses below.
716 template <typename SymbolT
, typename IRUnitT
>
717 static std::optional
<SymbolTable::UseRange
> getSymbolUsesImpl(SymbolT symbol
,
719 std::vector
<SymbolTable::SymbolUse
> uses
;
720 for (SymbolScope
&scope
: collectSymbolScopes(symbol
, limit
)) {
721 if (!scope
.walk([&](SymbolTable::SymbolUse symbolUse
) {
722 if (isReferencePrefixOf(scope
.symbol
, symbolUse
.getSymbolRef()))
723 uses
.push_back(symbolUse
);
727 return SymbolTable::UseRange(std::move(uses
));
730 /// Get all of the uses of the given symbol that are nested within the given
731 /// operation 'from', invoking the provided callback for each. This does not
732 /// traverse into any nested symbol tables. This function returns std::nullopt
733 /// if there are any unknown operations that may potentially be symbol tables.
734 auto SymbolTable::getSymbolUses(StringAttr symbol
, Operation
*from
)
735 -> std::optional
<UseRange
> {
736 return getSymbolUsesImpl(symbol
, from
);
738 auto SymbolTable::getSymbolUses(Operation
*symbol
, Operation
*from
)
739 -> std::optional
<UseRange
> {
740 return getSymbolUsesImpl(symbol
, from
);
742 auto SymbolTable::getSymbolUses(StringAttr symbol
, Region
*from
)
743 -> std::optional
<UseRange
> {
744 return getSymbolUsesImpl(symbol
, from
);
746 auto SymbolTable::getSymbolUses(Operation
*symbol
, Region
*from
)
747 -> std::optional
<UseRange
> {
748 return getSymbolUsesImpl(symbol
, from
);
751 //===----------------------------------------------------------------------===//
752 // SymbolTable::symbolKnownUseEmpty
754 /// The implementation of SymbolTable::symbolKnownUseEmpty below.
755 template <typename SymbolT
, typename IRUnitT
>
756 static bool symbolKnownUseEmptyImpl(SymbolT symbol
, IRUnitT
*limit
) {
757 for (SymbolScope
&scope
: collectSymbolScopes(symbol
, limit
)) {
758 // Walk all of the symbol uses looking for a reference to 'symbol'.
759 if (scope
.walk([&](SymbolTable::SymbolUse symbolUse
) {
760 return isReferencePrefixOf(scope
.symbol
, symbolUse
.getSymbolRef())
761 ? WalkResult::interrupt()
762 : WalkResult::advance();
763 }) != WalkResult::advance())
769 /// Return if the given symbol is known to have no uses that are nested within
770 /// the given operation 'from'. This does not traverse into any nested symbol
771 /// tables. This function will also return false if there are any unknown
772 /// operations that may potentially be symbol tables.
773 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol
, Operation
*from
) {
774 return symbolKnownUseEmptyImpl(symbol
, from
);
776 bool SymbolTable::symbolKnownUseEmpty(Operation
*symbol
, Operation
*from
) {
777 return symbolKnownUseEmptyImpl(symbol
, from
);
779 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol
, Region
*from
) {
780 return symbolKnownUseEmptyImpl(symbol
, from
);
782 bool SymbolTable::symbolKnownUseEmpty(Operation
*symbol
, Region
*from
) {
783 return symbolKnownUseEmptyImpl(symbol
, from
);
786 //===----------------------------------------------------------------------===//
787 // SymbolTable::replaceAllSymbolUses
789 /// Generates a new symbol reference attribute with a new leaf reference.
790 static SymbolRefAttr
generateNewRefAttr(SymbolRefAttr oldAttr
,
791 FlatSymbolRefAttr newLeafAttr
) {
792 if (llvm::isa
<FlatSymbolRefAttr
>(oldAttr
))
794 auto nestedRefs
= llvm::to_vector
<2>(oldAttr
.getNestedReferences());
795 nestedRefs
.back() = newLeafAttr
;
796 return SymbolRefAttr::get(oldAttr
.getRootReference(), nestedRefs
);
799 /// The implementation of SymbolTable::replaceAllSymbolUses below.
800 template <typename SymbolT
, typename IRUnitT
>
802 replaceAllSymbolUsesImpl(SymbolT symbol
, StringAttr newSymbol
, IRUnitT
*limit
) {
803 // Generate a new attribute to replace the given attribute.
804 FlatSymbolRefAttr newLeafAttr
= FlatSymbolRefAttr::get(newSymbol
);
805 for (SymbolScope
&scope
: collectSymbolScopes(symbol
, limit
)) {
806 SymbolRefAttr oldAttr
= scope
.symbol
;
807 SymbolRefAttr newAttr
= generateNewRefAttr(scope
.symbol
, newLeafAttr
);
808 AttrTypeReplacer replacer
;
809 replacer
.addReplacement(
810 [&](SymbolRefAttr attr
) -> std::pair
<Attribute
, WalkResult
> {
811 // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
812 // want to accidentally replace an inner reference.
814 return {newAttr
, WalkResult::skip()};
815 // Handle prefix matches.
816 if (isReferencePrefixOf(oldAttr
, attr
)) {
817 auto oldNestedRefs
= oldAttr
.getNestedReferences();
818 auto nestedRefs
= attr
.getNestedReferences();
819 if (oldNestedRefs
.empty())
820 return {SymbolRefAttr::get(newSymbol
, nestedRefs
),
823 auto newNestedRefs
= llvm::to_vector
<4>(nestedRefs
);
824 newNestedRefs
[oldNestedRefs
.size() - 1] = newLeafAttr
;
825 return {SymbolRefAttr::get(attr
.getRootReference(), newNestedRefs
),
828 return {attr
, WalkResult::skip()};
831 auto walkFn
= [&](Operation
*op
) -> std::optional
<WalkResult
> {
832 replacer
.replaceElementsIn(op
);
833 return WalkResult::advance();
835 if (!scope
.walkSymbolTable(walkFn
))
841 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
842 /// provided symbol 'newSymbol' that are nested within the given operation
843 /// 'from'. This does not traverse into any nested symbol tables. If there are
844 /// any unknown operations that may potentially be symbol tables, no uses are
845 /// replaced and failure is returned.
846 LogicalResult
SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol
,
847 StringAttr newSymbol
,
849 return replaceAllSymbolUsesImpl(oldSymbol
, newSymbol
, from
);
851 LogicalResult
SymbolTable::replaceAllSymbolUses(Operation
*oldSymbol
,
852 StringAttr newSymbol
,
854 return replaceAllSymbolUsesImpl(oldSymbol
, newSymbol
, from
);
856 LogicalResult
SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol
,
857 StringAttr newSymbol
,
859 return replaceAllSymbolUsesImpl(oldSymbol
, newSymbol
, from
);
861 LogicalResult
SymbolTable::replaceAllSymbolUses(Operation
*oldSymbol
,
862 StringAttr newSymbol
,
864 return replaceAllSymbolUsesImpl(oldSymbol
, newSymbol
, from
);
//===----------------------------------------------------------------------===//
// SymbolTableCollection
//===----------------------------------------------------------------------===//
871 Operation
*SymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
873 return getSymbolTable(symbolTableOp
).lookup(symbol
);
875 Operation
*SymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
876 SymbolRefAttr name
) {
877 SmallVector
<Operation
*, 4> symbols
;
878 if (failed(lookupSymbolIn(symbolTableOp
, name
, symbols
)))
880 return symbols
.back();
882 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
883 /// a given SymbolRefAttr. Returns failure if any of the nested references could
886 SymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
888 SmallVectorImpl
<Operation
*> &symbols
) {
889 auto lookupFn
= [this](Operation
*symbolTableOp
, StringAttr symbol
) {
890 return lookupSymbolIn(symbolTableOp
, symbol
);
892 return lookupSymbolInImpl(symbolTableOp
, name
, symbols
, lookupFn
);
895 /// Returns the operation registered with the given symbol name within the
896 /// closest parent operation of, or including, 'from' with the
897 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
899 Operation
*SymbolTableCollection::lookupNearestSymbolFrom(Operation
*from
,
901 Operation
*symbolTableOp
= SymbolTable::getNearestSymbolTable(from
);
902 return symbolTableOp
? lookupSymbolIn(symbolTableOp
, symbol
) : nullptr;
905 SymbolTableCollection::lookupNearestSymbolFrom(Operation
*from
,
906 SymbolRefAttr symbol
) {
907 Operation
*symbolTableOp
= SymbolTable::getNearestSymbolTable(from
);
908 return symbolTableOp
? lookupSymbolIn(symbolTableOp
, symbol
) : nullptr;
911 /// Lookup, or create, a symbol table for an operation.
912 SymbolTable
&SymbolTableCollection::getSymbolTable(Operation
*op
) {
913 auto it
= symbolTables
.try_emplace(op
, nullptr);
915 it
.first
->second
= std::make_unique
<SymbolTable
>(op
);
916 return *it
.first
->second
;
//===----------------------------------------------------------------------===//
// LockedSymbolTableCollection
//===----------------------------------------------------------------------===//
923 Operation
*LockedSymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
925 return getSymbolTable(symbolTableOp
).lookup(symbol
);
929 LockedSymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
930 FlatSymbolRefAttr symbol
) {
931 return lookupSymbolIn(symbolTableOp
, symbol
.getAttr());
934 Operation
*LockedSymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
935 SymbolRefAttr name
) {
936 SmallVector
<Operation
*> symbols
;
937 if (failed(lookupSymbolIn(symbolTableOp
, name
, symbols
)))
939 return symbols
.back();
942 LogicalResult
LockedSymbolTableCollection::lookupSymbolIn(
943 Operation
*symbolTableOp
, SymbolRefAttr name
,
944 SmallVectorImpl
<Operation
*> &symbols
) {
945 auto lookupFn
= [this](Operation
*symbolTableOp
, StringAttr symbol
) {
946 return lookupSymbolIn(symbolTableOp
, symbol
);
948 return lookupSymbolInImpl(symbolTableOp
, name
, symbols
, lookupFn
);
952 LockedSymbolTableCollection::getSymbolTable(Operation
*symbolTableOp
) {
953 assert(symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>());
954 // Try to find an existing symbol table.
956 llvm::sys::SmartScopedReader
<true> lock(mutex
);
957 auto it
= collection
.symbolTables
.find(symbolTableOp
);
958 if (it
!= collection
.symbolTables
.end())
961 // Create a symbol table for the operation. Perform construction outside of
962 // the critical section.
963 auto symbolTable
= std::make_unique
<SymbolTable
>(symbolTableOp
);
964 // Insert the constructed symbol table.
965 llvm::sys::SmartScopedWriter
<true> lock(mutex
);
966 return *collection
.symbolTables
967 .insert({symbolTableOp
, std::move(symbolTable
)})
//===----------------------------------------------------------------------===//
// SymbolUserMap
//===----------------------------------------------------------------------===//
975 SymbolUserMap::SymbolUserMap(SymbolTableCollection
&symbolTable
,
976 Operation
*symbolTableOp
)
977 : symbolTable(symbolTable
) {
978 // Walk each of the symbol tables looking for discardable callgraph nodes.
979 SmallVector
<Operation
*> symbols
;
980 auto walkFn
= [&](Operation
*symbolTableOp
, bool allUsesVisible
) {
981 for (Operation
&nestedOp
: symbolTableOp
->getRegion(0).getOps()) {
982 auto symbolUses
= SymbolTable::getSymbolUses(&nestedOp
);
983 assert(symbolUses
&& "expected uses to be valid");
985 for (const SymbolTable::SymbolUse
&use
: *symbolUses
) {
987 (void)symbolTable
.lookupSymbolIn(symbolTableOp
, use
.getSymbolRef(),
989 for (Operation
*symbolOp
: symbols
)
990 symbolToUsers
[symbolOp
].insert(use
.getUser());
994 // We just set `allSymUsesVisible` to false here because it isn't necessary
995 // for building the user map.
996 SymbolTable::walkSymbolTables(symbolTableOp
, /*allSymUsesVisible=*/false,
1000 void SymbolUserMap::replaceAllUsesWith(Operation
*symbol
,
1001 StringAttr newSymbolName
) {
1002 auto it
= symbolToUsers
.find(symbol
);
1003 if (it
== symbolToUsers
.end())
1006 // Replace the uses within the users of `symbol`.
1007 for (Operation
*user
: it
->second
)
1008 (void)SymbolTable::replaceAllSymbolUses(symbol
, newSymbolName
, user
);
1010 // Move the current users of `symbol` to the new symbol if it is in the
1012 Operation
*newSymbol
=
1013 symbolTable
.lookupSymbolIn(symbol
->getParentOp(), newSymbolName
);
1014 if (newSymbol
!= symbol
) {
1015 // Transfer over the users to the new symbol. The reference to the old one
1016 // is fetched again as the iterator is invalidated during the insertion.
1017 auto newIt
= symbolToUsers
.try_emplace(newSymbol
, SetVector
<Operation
*>{});
1018 auto oldIt
= symbolToUsers
.find(symbol
);
1019 assert(oldIt
!= symbolToUsers
.end() && "missing old users list");
1021 newIt
.first
->second
= std::move(oldIt
->second
);
1023 newIt
.first
->second
.set_union(oldIt
->second
);
1024 symbolToUsers
.erase(oldIt
);
//===----------------------------------------------------------------------===//
// Visibility parsing implementation.
//===----------------------------------------------------------------------===//
1032 ParseResult
impl::parseOptionalVisibilityKeyword(OpAsmParser
&parser
,
1033 NamedAttrList
&attrs
) {
1034 StringRef visibility
;
1035 if (parser
.parseOptionalKeyword(&visibility
, {"public", "private", "nested"}))
1038 StringAttr visibilityAttr
= parser
.getBuilder().getStringAttr(visibility
);
1039 attrs
.push_back(parser
.getBuilder().getNamedAttr(
1040 SymbolTable::getVisibilityAttrName(), visibilityAttr
));
//===----------------------------------------------------------------------===//
// Symbol Interfaces
//===----------------------------------------------------------------------===//
1048 /// Include the generated symbol interfaces.
1049 #include "mlir/IR/SymbolInterfaces.cpp.inc"