//===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringSwitch.h"
#include <optional>

using namespace mlir;
20 /// Return true if the given operation is unknown and may potentially define a
22 static bool isPotentiallyUnknownSymbolTable(Operation
*op
) {
23 return op
->getNumRegions() == 1 && !op
->getDialect();
26 /// Returns the string name of the given symbol, or null if this is not a
28 static StringAttr
getNameIfSymbol(Operation
*op
) {
29 return op
->getAttrOfType
<StringAttr
>(SymbolTable::getSymbolAttrName());
31 static StringAttr
getNameIfSymbol(Operation
*op
, StringAttr symbolAttrNameId
) {
32 return op
->getAttrOfType
<StringAttr
>(symbolAttrNameId
);
35 /// Computes the nested symbol reference attribute for the symbol 'symbolName'
36 /// that are usable within the symbol table operations from 'symbol' as far up
37 /// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
38 /// Returns success if all references up to 'within' could be computed.
40 collectValidReferencesFor(Operation
*symbol
, StringAttr symbolName
,
42 SmallVectorImpl
<SymbolRefAttr
> &results
) {
43 assert(within
->isAncestor(symbol
) && "expected 'within' to be an ancestor");
44 MLIRContext
*ctx
= symbol
->getContext();
46 auto leafRef
= FlatSymbolRefAttr::get(symbolName
);
47 results
.push_back(leafRef
);
49 // Early exit for when 'within' is the parent of 'symbol'.
50 Operation
*symbolTableOp
= symbol
->getParentOp();
51 if (within
== symbolTableOp
)
54 // Collect references until 'symbolTableOp' reaches 'within'.
55 SmallVector
<FlatSymbolRefAttr
, 1> nestedRefs(1, leafRef
);
56 StringAttr symbolNameId
=
57 StringAttr::get(ctx
, SymbolTable::getSymbolAttrName());
59 // Each parent of 'symbol' should define a symbol table.
60 if (!symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>())
62 // Each parent of 'symbol' should also be a symbol.
63 StringAttr symbolTableName
= getNameIfSymbol(symbolTableOp
, symbolNameId
);
66 results
.push_back(SymbolRefAttr::get(symbolTableName
, nestedRefs
));
68 symbolTableOp
= symbolTableOp
->getParentOp();
69 if (symbolTableOp
== within
)
71 nestedRefs
.insert(nestedRefs
.begin(),
72 FlatSymbolRefAttr::get(symbolTableName
));
77 /// Walk all of the operations within the given set of regions, without
78 /// traversing into any nested symbol tables. Stops walking if the result of the
79 /// callback is anything other than `WalkResult::advance`.
80 static std::optional
<WalkResult
>
81 walkSymbolTable(MutableArrayRef
<Region
> regions
,
82 function_ref
<std::optional
<WalkResult
>(Operation
*)> callback
) {
83 SmallVector
<Region
*, 1> worklist(llvm::make_pointer_range(regions
));
84 while (!worklist
.empty()) {
85 for (Operation
&op
: worklist
.pop_back_val()->getOps()) {
86 std::optional
<WalkResult
> result
= callback(&op
);
87 if (result
!= WalkResult::advance())
90 // If this op defines a new symbol table scope, we can't traverse. Any
91 // symbol references nested within 'op' are different semantically.
92 if (!op
.hasTrait
<OpTrait::SymbolTable
>()) {
93 for (Region
®ion
: op
.getRegions())
94 worklist
.push_back(®ion
);
98 return WalkResult::advance();
101 /// Walk all of the operations nested under, and including, the given operation,
102 /// without traversing into any nested symbol tables. Stops walking if the
103 /// result of the callback is anything other than `WalkResult::advance`.
104 static std::optional
<WalkResult
>
105 walkSymbolTable(Operation
*op
,
106 function_ref
<std::optional
<WalkResult
>(Operation
*)> callback
) {
107 std::optional
<WalkResult
> result
= callback(op
);
108 if (result
!= WalkResult::advance() || op
->hasTrait
<OpTrait::SymbolTable
>())
110 return walkSymbolTable(op
->getRegions(), callback
);
//===----------------------------------------------------------------------===//
// SymbolTable
//===----------------------------------------------------------------------===//
117 /// Build a symbol table with the symbols within the given operation.
118 SymbolTable::SymbolTable(Operation
*symbolTableOp
)
119 : symbolTableOp(symbolTableOp
) {
120 assert(symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>() &&
121 "expected operation to have SymbolTable trait");
122 assert(symbolTableOp
->getNumRegions() == 1 &&
123 "expected operation to have a single region");
124 assert(llvm::hasSingleElement(symbolTableOp
->getRegion(0)) &&
125 "expected operation to have a single block");
127 StringAttr symbolNameId
= StringAttr::get(symbolTableOp
->getContext(),
128 SymbolTable::getSymbolAttrName());
129 for (auto &op
: symbolTableOp
->getRegion(0).front()) {
130 StringAttr name
= getNameIfSymbol(&op
, symbolNameId
);
134 auto inserted
= symbolTable
.insert({name
, &op
});
136 assert(inserted
.second
&&
137 "expected region to contain uniquely named symbol operations");
141 /// Look up a symbol with the specified name, returning null if no such name
142 /// exists. Names never include the @ on them.
143 Operation
*SymbolTable::lookup(StringRef name
) const {
144 return lookup(StringAttr::get(symbolTableOp
->getContext(), name
));
146 Operation
*SymbolTable::lookup(StringAttr name
) const {
147 return symbolTable
.lookup(name
);
150 void SymbolTable::remove(Operation
*op
) {
151 StringAttr name
= getNameIfSymbol(op
);
152 assert(name
&& "expected valid 'name' attribute");
153 assert(op
->getParentOp() == symbolTableOp
&&
154 "expected this operation to be inside of the operation with this "
157 auto it
= symbolTable
.find(name
);
158 if (it
!= symbolTable
.end() && it
->second
== op
)
159 symbolTable
.erase(it
);
162 void SymbolTable::erase(Operation
*symbol
) {
167 // TODO: Consider if this should be renamed to something like insertOrUpdate
168 /// Insert a new symbol into the table and associated operation if not already
169 /// there and rename it as necessary to avoid collisions. Return the name of
170 /// the symbol after insertion as attribute.
171 StringAttr
SymbolTable::insert(Operation
*symbol
, Block::iterator insertPt
) {
172 // The symbol cannot be the child of another op and must be the child of the
173 // symbolTableOp after this.
175 // TODO: consider if SymbolTable's constructor should behave the same.
176 if (!symbol
->getParentOp()) {
177 auto &body
= symbolTableOp
->getRegion(0).front();
178 if (insertPt
== Block::iterator()) {
179 insertPt
= Block::iterator(body
.end());
181 assert((insertPt
== body
.end() ||
182 insertPt
->getParentOp() == symbolTableOp
) &&
183 "expected insertPt to be in the associated module operation");
185 // Insert before the terminator, if any.
186 if (insertPt
== Block::iterator(body
.end()) && !body
.empty() &&
187 std::prev(body
.end())->hasTrait
<OpTrait::IsTerminator
>())
188 insertPt
= std::prev(body
.end());
190 body
.getOperations().insert(insertPt
, symbol
);
192 assert(symbol
->getParentOp() == symbolTableOp
&&
193 "symbol is already inserted in another op");
195 // Add this symbol to the symbol table, uniquing the name if a conflict is
197 StringAttr name
= getSymbolName(symbol
);
198 if (symbolTable
.insert({name
, symbol
}).second
)
200 // If the symbol was already in the table, also return.
201 if (symbolTable
.lookup(name
) == symbol
)
204 MLIRContext
*context
= symbol
->getContext();
205 SmallString
<128> nameBuffer
= generateSymbolName
<128>(
207 [&](StringRef candidate
) {
209 .insert({StringAttr::get(context
, candidate
), symbol
})
213 setSymbolName(symbol
, nameBuffer
);
214 return getSymbolName(symbol
);
217 LogicalResult
SymbolTable::rename(StringAttr from
, StringAttr to
) {
218 Operation
*op
= lookup(from
);
219 return rename(op
, to
);
222 LogicalResult
SymbolTable::rename(Operation
*op
, StringAttr to
) {
223 StringAttr from
= getNameIfSymbol(op
);
226 assert(from
&& "expected valid 'name' attribute");
227 assert(op
->getParentOp() == symbolTableOp
&&
228 "expected this operation to be inside of the operation with this "
230 assert(lookup(from
) == op
&& "current name does not resolve to op");
231 assert(lookup(to
) == nullptr && "new name already exists");
233 if (failed(SymbolTable::replaceAllSymbolUses(op
, to
, getOp())))
236 // Remove op with old name, change name, add with new name. The order is
237 // important here due to how `remove` and `insert` rely on the op name.
239 setSymbolName(op
, to
);
242 assert(lookup(to
) == op
&& "new name does not resolve to renamed op");
243 assert(lookup(from
) == nullptr && "old name still exists");
248 LogicalResult
SymbolTable::rename(StringAttr from
, StringRef to
) {
249 auto toAttr
= StringAttr::get(getOp()->getContext(), to
);
250 return rename(from
, toAttr
);
253 LogicalResult
SymbolTable::rename(Operation
*op
, StringRef to
) {
254 auto toAttr
= StringAttr::get(getOp()->getContext(), to
);
255 return rename(op
, toAttr
);
258 FailureOr
<StringAttr
>
259 SymbolTable::renameToUnique(StringAttr oldName
,
260 ArrayRef
<SymbolTable
*> others
) {
262 // Determine new name that is unique in all symbol tables.
265 MLIRContext
*context
= oldName
.getContext();
266 SmallString
<64> prefix
= oldName
.getValue();
268 prefix
.push_back('_');
270 newName
= StringAttr::get(context
, prefix
+ Twine(uniqueId
++));
271 auto lookupNewName
= [&](SymbolTable
*st
) { return st
->lookup(newName
); };
272 if (!lookupNewName(this) && llvm::none_of(others
, lookupNewName
)) {
279 if (failed(rename(oldName
, newName
)))
284 FailureOr
<StringAttr
>
285 SymbolTable::renameToUnique(Operation
*op
, ArrayRef
<SymbolTable
*> others
) {
286 StringAttr from
= getNameIfSymbol(op
);
287 assert(from
&& "expected valid 'name' attribute");
288 return renameToUnique(from
, others
);
291 /// Returns the name of the given symbol operation.
292 StringAttr
SymbolTable::getSymbolName(Operation
*symbol
) {
293 StringAttr name
= getNameIfSymbol(symbol
);
294 assert(name
&& "expected valid symbol name");
298 /// Sets the name of the given symbol operation.
299 void SymbolTable::setSymbolName(Operation
*symbol
, StringAttr name
) {
300 symbol
->setAttr(getSymbolAttrName(), name
);
303 /// Returns the visibility of the given symbol operation.
304 SymbolTable::Visibility
SymbolTable::getSymbolVisibility(Operation
*symbol
) {
305 // If the attribute doesn't exist, assume public.
306 StringAttr vis
= symbol
->getAttrOfType
<StringAttr
>(getVisibilityAttrName());
308 return Visibility::Public
;
310 // Otherwise, switch on the string value.
311 return StringSwitch
<Visibility
>(vis
.getValue())
312 .Case("private", Visibility::Private
)
313 .Case("nested", Visibility::Nested
)
314 .Case("public", Visibility::Public
);
316 /// Sets the visibility of the given symbol operation.
317 void SymbolTable::setSymbolVisibility(Operation
*symbol
, Visibility vis
) {
318 MLIRContext
*ctx
= symbol
->getContext();
320 // If the visibility is public, just drop the attribute as this is the
322 if (vis
== Visibility::Public
) {
323 symbol
->removeAttr(StringAttr::get(ctx
, getVisibilityAttrName()));
327 // Otherwise, update the attribute.
328 assert((vis
== Visibility::Private
|| vis
== Visibility::Nested
) &&
329 "unknown symbol visibility kind");
331 StringRef visName
= vis
== Visibility::Private
? "private" : "nested";
332 symbol
->setAttr(getVisibilityAttrName(), StringAttr::get(ctx
, visName
));
335 /// Returns the nearest symbol table from a given operation `from`. Returns
336 /// nullptr if no valid parent symbol table could be found.
337 Operation
*SymbolTable::getNearestSymbolTable(Operation
*from
) {
338 assert(from
&& "expected valid operation");
339 if (isPotentiallyUnknownSymbolTable(from
))
342 while (!from
->hasTrait
<OpTrait::SymbolTable
>()) {
343 from
= from
->getParentOp();
345 // Check that this is a valid op and isn't an unknown symbol table.
346 if (!from
|| isPotentiallyUnknownSymbolTable(from
))
352 /// Walks all symbol table operations nested within, and including, `op`. For
353 /// each symbol table operation, the provided callback is invoked with the op
354 /// and a boolean signifying if the symbols within that symbol table can be
355 /// treated as if all uses are visible. `allSymUsesVisible` identifies whether
356 /// all of the symbol uses of symbols within `op` are visible.
357 void SymbolTable::walkSymbolTables(
358 Operation
*op
, bool allSymUsesVisible
,
359 function_ref
<void(Operation
*, bool)> callback
) {
360 bool isSymbolTable
= op
->hasTrait
<OpTrait::SymbolTable
>();
362 SymbolOpInterface symbol
= dyn_cast
<SymbolOpInterface
>(op
);
363 allSymUsesVisible
|= !symbol
|| symbol
.isPrivate();
365 // Otherwise if 'op' is not a symbol table, any nested symbols are
366 // guaranteed to be hidden.
367 allSymUsesVisible
= true;
370 for (Region
®ion
: op
->getRegions())
371 for (Block
&block
: region
)
372 for (Operation
&nestedOp
: block
)
373 walkSymbolTables(&nestedOp
, allSymUsesVisible
, callback
);
375 // If 'op' had the symbol table trait, visit it after any nested symbol
378 callback(op
, allSymUsesVisible
);
381 /// Returns the operation registered with the given symbol name with the
382 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
383 /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
385 Operation
*SymbolTable::lookupSymbolIn(Operation
*symbolTableOp
,
387 assert(symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>());
388 Region
®ion
= symbolTableOp
->getRegion(0);
392 // Look for a symbol with the given name.
393 StringAttr symbolNameId
= StringAttr::get(symbolTableOp
->getContext(),
394 SymbolTable::getSymbolAttrName());
395 for (auto &op
: region
.front())
396 if (getNameIfSymbol(&op
, symbolNameId
) == symbol
)
400 Operation
*SymbolTable::lookupSymbolIn(Operation
*symbolTableOp
,
401 SymbolRefAttr symbol
) {
402 SmallVector
<Operation
*, 4> resolvedSymbols
;
403 if (failed(lookupSymbolIn(symbolTableOp
, symbol
, resolvedSymbols
)))
405 return resolvedSymbols
.back();
408 /// Internal implementation of `lookupSymbolIn` that allows for specialized
409 /// implementations of the lookup function.
410 static LogicalResult
lookupSymbolInImpl(
411 Operation
*symbolTableOp
, SymbolRefAttr symbol
,
412 SmallVectorImpl
<Operation
*> &symbols
,
413 function_ref
<Operation
*(Operation
*, StringAttr
)> lookupSymbolFn
) {
414 assert(symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>());
416 // Lookup the root reference for this symbol.
417 symbolTableOp
= lookupSymbolFn(symbolTableOp
, symbol
.getRootReference());
420 symbols
.push_back(symbolTableOp
);
422 // If there are no nested references, just return the root symbol directly.
423 ArrayRef
<FlatSymbolRefAttr
> nestedRefs
= symbol
.getNestedReferences();
424 if (nestedRefs
.empty())
427 // Verify that the root is also a symbol table.
428 if (!symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>())
431 // Otherwise, lookup each of the nested non-leaf references and ensure that
432 // each corresponds to a valid symbol table.
433 for (FlatSymbolRefAttr ref
: nestedRefs
.drop_back()) {
434 symbolTableOp
= lookupSymbolFn(symbolTableOp
, ref
.getAttr());
435 if (!symbolTableOp
|| !symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>())
437 symbols
.push_back(symbolTableOp
);
439 symbols
.push_back(lookupSymbolFn(symbolTableOp
, symbol
.getLeafReference()));
440 return success(symbols
.back());
444 SymbolTable::lookupSymbolIn(Operation
*symbolTableOp
, SymbolRefAttr symbol
,
445 SmallVectorImpl
<Operation
*> &symbols
) {
446 auto lookupFn
= [](Operation
*symbolTableOp
, StringAttr symbol
) {
447 return lookupSymbolIn(symbolTableOp
, symbol
);
449 return lookupSymbolInImpl(symbolTableOp
, symbol
, symbols
, lookupFn
);
452 /// Returns the operation registered with the given symbol name within the
453 /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
454 /// nullptr if no valid symbol was found.
455 Operation
*SymbolTable::lookupNearestSymbolFrom(Operation
*from
,
457 Operation
*symbolTableOp
= getNearestSymbolTable(from
);
458 return symbolTableOp
? lookupSymbolIn(symbolTableOp
, symbol
) : nullptr;
460 Operation
*SymbolTable::lookupNearestSymbolFrom(Operation
*from
,
461 SymbolRefAttr symbol
) {
462 Operation
*symbolTableOp
= getNearestSymbolTable(from
);
463 return symbolTableOp
? lookupSymbolIn(symbolTableOp
, symbol
) : nullptr;
466 raw_ostream
&mlir::operator<<(raw_ostream
&os
,
467 SymbolTable::Visibility visibility
) {
468 switch (visibility
) {
469 case SymbolTable::Visibility::Public
:
470 return os
<< "public";
471 case SymbolTable::Visibility::Private
:
472 return os
<< "private";
473 case SymbolTable::Visibility::Nested
:
474 return os
<< "nested";
476 llvm_unreachable("Unexpected visibility");
//===----------------------------------------------------------------------===//
// SymbolTable Trait Types
//===----------------------------------------------------------------------===//
483 LogicalResult
detail::verifySymbolTable(Operation
*op
) {
484 if (op
->getNumRegions() != 1)
485 return op
->emitOpError()
486 << "Operations with a 'SymbolTable' must have exactly one region";
487 if (!llvm::hasSingleElement(op
->getRegion(0)))
488 return op
->emitOpError()
489 << "Operations with a 'SymbolTable' must have exactly one block";
491 // Check that all symbols are uniquely named within child regions.
492 DenseMap
<Attribute
, Location
> nameToOrigLoc
;
493 for (auto &block
: op
->getRegion(0)) {
494 for (auto &op
: block
) {
495 // Check for a symbol name attribute.
497 op
.getAttrOfType
<StringAttr
>(mlir::SymbolTable::getSymbolAttrName());
501 // Try to insert this symbol into the table.
502 auto it
= nameToOrigLoc
.try_emplace(nameAttr
, op
.getLoc());
504 return op
.emitError()
505 .append("redefinition of symbol named '", nameAttr
.getValue(), "'")
506 .attachNote(it
.first
->second
)
507 .append("see existing symbol definition here");
511 // Verify any nested symbol user operations.
512 SymbolTableCollection symbolTable
;
513 auto verifySymbolUserFn
= [&](Operation
*op
) -> std::optional
<WalkResult
> {
514 if (SymbolUserOpInterface user
= dyn_cast
<SymbolUserOpInterface
>(op
))
515 return WalkResult(user
.verifySymbolUses(symbolTable
));
516 return WalkResult::advance();
519 std::optional
<WalkResult
> result
=
520 walkSymbolTable(op
->getRegions(), verifySymbolUserFn
);
521 return success(result
&& !result
->wasInterrupted());
524 LogicalResult
detail::verifySymbol(Operation
*op
) {
525 // Verify the name attribute.
526 if (!op
->getAttrOfType
<StringAttr
>(mlir::SymbolTable::getSymbolAttrName()))
527 return op
->emitOpError() << "requires string attribute '"
528 << mlir::SymbolTable::getSymbolAttrName() << "'";
530 // Verify the visibility attribute.
531 if (Attribute vis
= op
->getAttr(mlir::SymbolTable::getVisibilityAttrName())) {
532 StringAttr visStrAttr
= llvm::dyn_cast
<StringAttr
>(vis
);
534 return op
->emitOpError() << "requires visibility attribute '"
535 << mlir::SymbolTable::getVisibilityAttrName()
536 << "' to be a string attribute, but got " << vis
;
538 if (!llvm::is_contained(ArrayRef
<StringRef
>{"public", "private", "nested"},
539 visStrAttr
.getValue()))
540 return op
->emitOpError()
541 << "visibility expected to be one of [\"public\", \"private\", "
542 "\"nested\"], but got "
//===----------------------------------------------------------------------===//
// Symbol Use Lists
//===----------------------------------------------------------------------===//
552 /// Walk all of the symbol references within the given operation, invoking the
553 /// provided callback for each found use. The callbacks takes the use of the
556 walkSymbolRefs(Operation
*op
,
557 function_ref
<WalkResult(SymbolTable::SymbolUse
)> callback
) {
558 return op
->getAttrDictionary().walk
<WalkOrder::PreOrder
>(
559 [&](SymbolRefAttr symbolRef
) {
560 if (callback({op
, symbolRef
}).wasInterrupted())
561 return WalkResult::interrupt();
563 // Don't walk nested references.
564 return WalkResult::skip();
568 /// Walk all of the uses, for any symbol, that are nested within the given
569 /// regions, invoking the provided callback for each. This does not traverse
570 /// into any nested symbol tables.
571 static std::optional
<WalkResult
>
572 walkSymbolUses(MutableArrayRef
<Region
> regions
,
573 function_ref
<WalkResult(SymbolTable::SymbolUse
)> callback
) {
574 return walkSymbolTable(regions
,
575 [&](Operation
*op
) -> std::optional
<WalkResult
> {
576 // Check that this isn't a potentially unknown symbol
578 if (isPotentiallyUnknownSymbolTable(op
))
581 return walkSymbolRefs(op
, callback
);
584 /// Walk all of the uses, for any symbol, that are nested within the given
585 /// operation 'from', invoking the provided callback for each. This does not
586 /// traverse into any nested symbol tables.
587 static std::optional
<WalkResult
>
588 walkSymbolUses(Operation
*from
,
589 function_ref
<WalkResult(SymbolTable::SymbolUse
)> callback
) {
590 // If this operation has regions, and it, as well as its dialect, isn't
591 // registered then conservatively fail. The operation may define a
592 // symbol table, so we can't opaquely know if we should traverse to find
594 if (isPotentiallyUnknownSymbolTable(from
))
597 // Walk the uses on this operation.
598 if (walkSymbolRefs(from
, callback
).wasInterrupted())
599 return WalkResult::interrupt();
601 // Only recurse if this operation is not a symbol table. A symbol table
602 // defines a new scope, so we can't walk the attributes from within the symbol
604 if (!from
->hasTrait
<OpTrait::SymbolTable
>())
605 return walkSymbolUses(from
->getRegions(), callback
);
606 return WalkResult::advance();
610 /// This class represents a single symbol scope. A symbol scope represents the
611 /// set of operations nested within a symbol table that may reference symbols
612 /// within that table. A symbol scope does not contain the symbol table
613 /// operation itself, just its contained operations. A scope ends at leaf
614 /// operations or another symbol table operation.
616 /// Walk the symbol uses within this scope, invoking the given callback.
617 /// This variant is used when the callback type matches that expected by
618 /// 'walkSymbolUses'.
619 template <typename CallbackT
,
620 std::enable_if_t
<!std::is_same
<
621 typename
llvm::function_traits
<CallbackT
>::result_t
,
622 void>::value
> * = nullptr>
623 std::optional
<WalkResult
> walk(CallbackT cback
) {
624 if (Region
*region
= llvm::dyn_cast_if_present
<Region
*>(limit
))
625 return walkSymbolUses(*region
, cback
);
626 return walkSymbolUses(limit
.get
<Operation
*>(), cback
);
628 /// This variant is used when the callback type matches a stripped down type:
629 /// void(SymbolTable::SymbolUse use)
630 template <typename CallbackT
,
631 std::enable_if_t
<std::is_same
<
632 typename
llvm::function_traits
<CallbackT
>::result_t
,
633 void>::value
> * = nullptr>
634 std::optional
<WalkResult
> walk(CallbackT cback
) {
635 return walk([=](SymbolTable::SymbolUse use
) {
636 return cback(use
), WalkResult::advance();
640 /// Walk all of the operations nested under the current scope without
641 /// traversing into any nested symbol tables.
642 template <typename CallbackT
>
643 std::optional
<WalkResult
> walkSymbolTable(CallbackT
&&cback
) {
644 if (Region
*region
= llvm::dyn_cast_if_present
<Region
*>(limit
))
645 return ::walkSymbolTable(*region
, cback
);
646 return ::walkSymbolTable(limit
.get
<Operation
*>(), cback
);
649 /// The representation of the symbol within this scope.
650 SymbolRefAttr symbol
;
652 /// The IR unit representing this scope.
653 llvm::PointerUnion
<Operation
*, Region
*> limit
;
657 /// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
658 static SmallVector
<SymbolScope
, 2> collectSymbolScopes(Operation
*symbol
,
660 StringAttr symName
= SymbolTable::getSymbolName(symbol
);
661 assert(!symbol
->hasTrait
<OpTrait::SymbolTable
>() || symbol
!= limit
);
663 // Compute the ancestors of 'limit'.
664 SetVector
<Operation
*, SmallVector
<Operation
*, 4>,
665 SmallPtrSet
<Operation
*, 4>>
667 Operation
*limitAncestor
= limit
;
669 // Check to see if 'symbol' is an ancestor of 'limit'.
670 if (limitAncestor
== symbol
) {
671 // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
672 // doesn't support parent references.
673 if (SymbolTable::getNearestSymbolTable(limit
->getParentOp()) ==
674 symbol
->getParentOp())
675 return {{SymbolRefAttr::get(symName
), limit
}};
679 limitAncestors
.insert(limitAncestor
);
680 } while ((limitAncestor
= limitAncestor
->getParentOp()));
682 // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
683 Operation
*commonAncestor
= symbol
->getParentOp();
685 if (limitAncestors
.count(commonAncestor
))
687 } while ((commonAncestor
= commonAncestor
->getParentOp()));
688 assert(commonAncestor
&& "'limit' and 'symbol' have no common ancestor");
690 // Compute the set of valid nested references for 'symbol' as far up to the
691 // common ancestor as possible.
692 SmallVector
<SymbolRefAttr
, 2> references
;
693 bool collectedAllReferences
= succeeded(
694 collectValidReferencesFor(symbol
, symName
, commonAncestor
, references
));
696 // Handle the case where the common ancestor is 'limit'.
697 if (commonAncestor
== limit
) {
698 SmallVector
<SymbolScope
, 2> scopes
;
700 // Walk each of the ancestors of 'symbol', calling the compute function for
702 Operation
*limitIt
= symbol
->getParentOp();
703 for (size_t i
= 0, e
= references
.size(); i
!= e
;
704 ++i
, limitIt
= limitIt
->getParentOp()) {
705 assert(limitIt
->hasTrait
<OpTrait::SymbolTable
>());
706 scopes
.push_back({references
[i
], &limitIt
->getRegion(0)});
711 // Otherwise, we just need the symbol reference for 'symbol' that will be
712 // used within 'limit'. This is the last reference in the list we computed
713 // above if we were able to collect all references.
714 if (!collectedAllReferences
)
716 return {{references
.back(), limit
}};
718 static SmallVector
<SymbolScope
, 2> collectSymbolScopes(Operation
*symbol
,
720 auto scopes
= collectSymbolScopes(symbol
, limit
->getParentOp());
722 // If we collected some scopes to walk, make sure to constrain the one for
723 // limit to the specific region requested.
725 scopes
.back().limit
= limit
;
728 static SmallVector
<SymbolScope
, 1> collectSymbolScopes(StringAttr symbol
,
730 return {{SymbolRefAttr::get(symbol
), limit
}};
733 static SmallVector
<SymbolScope
, 1> collectSymbolScopes(StringAttr symbol
,
735 SmallVector
<SymbolScope
, 1> scopes
;
736 auto symbolRef
= SymbolRefAttr::get(symbol
);
737 for (auto ®ion
: limit
->getRegions())
738 scopes
.push_back({symbolRef
, ®ion
});
742 /// Returns true if the given reference 'SubRef' is a sub reference of the
743 /// reference 'ref', i.e. 'ref' is a further qualified reference.
744 static bool isReferencePrefixOf(SymbolRefAttr subRef
, SymbolRefAttr ref
) {
748 // If the references are not pointer equal, check to see if `subRef` is a
750 if (llvm::isa
<FlatSymbolRefAttr
>(ref
) ||
751 ref
.getRootReference() != subRef
.getRootReference())
754 auto refLeafs
= ref
.getNestedReferences();
755 auto subRefLeafs
= subRef
.getNestedReferences();
756 return subRefLeafs
.size() < refLeafs
.size() &&
757 subRefLeafs
== refLeafs
.take_front(subRefLeafs
.size());
//===----------------------------------------------------------------------===//
// SymbolTable::getSymbolUses
763 /// The implementation of SymbolTable::getSymbolUses below.
764 template <typename FromT
>
765 static std::optional
<SymbolTable::UseRange
> getSymbolUsesImpl(FromT from
) {
766 std::vector
<SymbolTable::SymbolUse
> uses
;
767 auto walkFn
= [&](SymbolTable::SymbolUse symbolUse
) {
768 uses
.push_back(symbolUse
);
769 return WalkResult::advance();
771 auto result
= walkSymbolUses(from
, walkFn
);
772 return result
? std::optional
<SymbolTable::UseRange
>(std::move(uses
))
776 /// Get an iterator range for all of the uses, for any symbol, that are nested
777 /// within the given operation 'from'. This does not traverse into any nested
778 /// symbol tables, and will also only return uses on 'from' if it does not
779 /// also define a symbol table. This is because we treat the region as the
780 /// boundary of the symbol table, and not the op itself. This function returns
781 /// std::nullopt if there are any unknown operations that may potentially be
783 auto SymbolTable::getSymbolUses(Operation
*from
) -> std::optional
<UseRange
> {
784 return getSymbolUsesImpl(from
);
786 auto SymbolTable::getSymbolUses(Region
*from
) -> std::optional
<UseRange
> {
787 return getSymbolUsesImpl(MutableArrayRef
<Region
>(*from
));
//===----------------------------------------------------------------------===//
// SymbolTable::getSymbolUses (uses of a specific symbol)
793 /// The implementation of SymbolTable::getSymbolUses below.
794 template <typename SymbolT
, typename IRUnitT
>
795 static std::optional
<SymbolTable::UseRange
> getSymbolUsesImpl(SymbolT symbol
,
797 std::vector
<SymbolTable::SymbolUse
> uses
;
798 for (SymbolScope
&scope
: collectSymbolScopes(symbol
, limit
)) {
799 if (!scope
.walk([&](SymbolTable::SymbolUse symbolUse
) {
800 if (isReferencePrefixOf(scope
.symbol
, symbolUse
.getSymbolRef()))
801 uses
.push_back(symbolUse
);
805 return SymbolTable::UseRange(std::move(uses
));
808 /// Get all of the uses of the given symbol that are nested within the given
809 /// operation 'from', invoking the provided callback for each. This does not
810 /// traverse into any nested symbol tables. This function returns std::nullopt
811 /// if there are any unknown operations that may potentially be symbol tables.
812 auto SymbolTable::getSymbolUses(StringAttr symbol
, Operation
*from
)
813 -> std::optional
<UseRange
> {
814 return getSymbolUsesImpl(symbol
, from
);
816 auto SymbolTable::getSymbolUses(Operation
*symbol
, Operation
*from
)
817 -> std::optional
<UseRange
> {
818 return getSymbolUsesImpl(symbol
, from
);
820 auto SymbolTable::getSymbolUses(StringAttr symbol
, Region
*from
)
821 -> std::optional
<UseRange
> {
822 return getSymbolUsesImpl(symbol
, from
);
824 auto SymbolTable::getSymbolUses(Operation
*symbol
, Region
*from
)
825 -> std::optional
<UseRange
> {
826 return getSymbolUsesImpl(symbol
, from
);
//===----------------------------------------------------------------------===//
// SymbolTable::symbolKnownUseEmpty
832 /// The implementation of SymbolTable::symbolKnownUseEmpty below.
833 template <typename SymbolT
, typename IRUnitT
>
834 static bool symbolKnownUseEmptyImpl(SymbolT symbol
, IRUnitT
*limit
) {
835 for (SymbolScope
&scope
: collectSymbolScopes(symbol
, limit
)) {
836 // Walk all of the symbol uses looking for a reference to 'symbol'.
837 if (scope
.walk([&](SymbolTable::SymbolUse symbolUse
) {
838 return isReferencePrefixOf(scope
.symbol
, symbolUse
.getSymbolRef())
839 ? WalkResult::interrupt()
840 : WalkResult::advance();
841 }) != WalkResult::advance())
847 /// Return if the given symbol is known to have no uses that are nested within
848 /// the given operation 'from'. This does not traverse into any nested symbol
849 /// tables. This function will also return false if there are any unknown
850 /// operations that may potentially be symbol tables.
851 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol
, Operation
*from
) {
852 return symbolKnownUseEmptyImpl(symbol
, from
);
854 bool SymbolTable::symbolKnownUseEmpty(Operation
*symbol
, Operation
*from
) {
855 return symbolKnownUseEmptyImpl(symbol
, from
);
857 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol
, Region
*from
) {
858 return symbolKnownUseEmptyImpl(symbol
, from
);
860 bool SymbolTable::symbolKnownUseEmpty(Operation
*symbol
, Region
*from
) {
861 return symbolKnownUseEmptyImpl(symbol
, from
);
//===----------------------------------------------------------------------===//
// SymbolTable::replaceAllSymbolUses
//===----------------------------------------------------------------------===//
867 /// Generates a new symbol reference attribute with a new leaf reference.
868 static SymbolRefAttr
generateNewRefAttr(SymbolRefAttr oldAttr
,
869 FlatSymbolRefAttr newLeafAttr
) {
870 if (llvm::isa
<FlatSymbolRefAttr
>(oldAttr
))
872 auto nestedRefs
= llvm::to_vector
<2>(oldAttr
.getNestedReferences());
873 nestedRefs
.back() = newLeafAttr
;
874 return SymbolRefAttr::get(oldAttr
.getRootReference(), nestedRefs
);
877 /// The implementation of SymbolTable::replaceAllSymbolUses below.
878 template <typename SymbolT
, typename IRUnitT
>
880 replaceAllSymbolUsesImpl(SymbolT symbol
, StringAttr newSymbol
, IRUnitT
*limit
) {
881 // Generate a new attribute to replace the given attribute.
882 FlatSymbolRefAttr newLeafAttr
= FlatSymbolRefAttr::get(newSymbol
);
883 for (SymbolScope
&scope
: collectSymbolScopes(symbol
, limit
)) {
884 SymbolRefAttr oldAttr
= scope
.symbol
;
885 SymbolRefAttr newAttr
= generateNewRefAttr(scope
.symbol
, newLeafAttr
);
886 AttrTypeReplacer replacer
;
887 replacer
.addReplacement(
888 [&](SymbolRefAttr attr
) -> std::pair
<Attribute
, WalkResult
> {
889 // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
890 // want to accidentally replace an inner reference.
892 return {newAttr
, WalkResult::skip()};
893 // Handle prefix matches.
894 if (isReferencePrefixOf(oldAttr
, attr
)) {
895 auto oldNestedRefs
= oldAttr
.getNestedReferences();
896 auto nestedRefs
= attr
.getNestedReferences();
897 if (oldNestedRefs
.empty())
898 return {SymbolRefAttr::get(newSymbol
, nestedRefs
),
901 auto newNestedRefs
= llvm::to_vector
<4>(nestedRefs
);
902 newNestedRefs
[oldNestedRefs
.size() - 1] = newLeafAttr
;
903 return {SymbolRefAttr::get(attr
.getRootReference(), newNestedRefs
),
906 return {attr
, WalkResult::skip()};
909 auto walkFn
= [&](Operation
*op
) -> std::optional
<WalkResult
> {
910 replacer
.replaceElementsIn(op
);
911 return WalkResult::advance();
913 if (!scope
.walkSymbolTable(walkFn
))
919 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
920 /// provided symbol 'newSymbol' that are nested within the given operation
921 /// 'from'. This does not traverse into any nested symbol tables. If there are
922 /// any unknown operations that may potentially be symbol tables, no uses are
923 /// replaced and failure is returned.
924 LogicalResult
SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol
,
925 StringAttr newSymbol
,
927 return replaceAllSymbolUsesImpl(oldSymbol
, newSymbol
, from
);
929 LogicalResult
SymbolTable::replaceAllSymbolUses(Operation
*oldSymbol
,
930 StringAttr newSymbol
,
932 return replaceAllSymbolUsesImpl(oldSymbol
, newSymbol
, from
);
934 LogicalResult
SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol
,
935 StringAttr newSymbol
,
937 return replaceAllSymbolUsesImpl(oldSymbol
, newSymbol
, from
);
939 LogicalResult
SymbolTable::replaceAllSymbolUses(Operation
*oldSymbol
,
940 StringAttr newSymbol
,
942 return replaceAllSymbolUsesImpl(oldSymbol
, newSymbol
, from
);
945 //===----------------------------------------------------------------------===//
946 // SymbolTableCollection
947 //===----------------------------------------------------------------------===//
949 Operation
*SymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
951 return getSymbolTable(symbolTableOp
).lookup(symbol
);
953 Operation
*SymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
954 SymbolRefAttr name
) {
955 SmallVector
<Operation
*, 4> symbols
;
956 if (failed(lookupSymbolIn(symbolTableOp
, name
, symbols
)))
958 return symbols
.back();
960 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
961 /// a given SymbolRefAttr. Returns failure if any of the nested references could
964 SymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
966 SmallVectorImpl
<Operation
*> &symbols
) {
967 auto lookupFn
= [this](Operation
*symbolTableOp
, StringAttr symbol
) {
968 return lookupSymbolIn(symbolTableOp
, symbol
);
970 return lookupSymbolInImpl(symbolTableOp
, name
, symbols
, lookupFn
);
973 /// Returns the operation registered with the given symbol name within the
974 /// closest parent operation of, or including, 'from' with the
975 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
977 Operation
*SymbolTableCollection::lookupNearestSymbolFrom(Operation
*from
,
979 Operation
*symbolTableOp
= SymbolTable::getNearestSymbolTable(from
);
980 return symbolTableOp
? lookupSymbolIn(symbolTableOp
, symbol
) : nullptr;
983 SymbolTableCollection::lookupNearestSymbolFrom(Operation
*from
,
984 SymbolRefAttr symbol
) {
985 Operation
*symbolTableOp
= SymbolTable::getNearestSymbolTable(from
);
986 return symbolTableOp
? lookupSymbolIn(symbolTableOp
, symbol
) : nullptr;
989 /// Lookup, or create, a symbol table for an operation.
990 SymbolTable
&SymbolTableCollection::getSymbolTable(Operation
*op
) {
991 auto it
= symbolTables
.try_emplace(op
, nullptr);
993 it
.first
->second
= std::make_unique
<SymbolTable
>(op
);
994 return *it
.first
->second
;
997 //===----------------------------------------------------------------------===//
998 // LockedSymbolTableCollection
999 //===----------------------------------------------------------------------===//
1001 Operation
*LockedSymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
1002 StringAttr symbol
) {
1003 return getSymbolTable(symbolTableOp
).lookup(symbol
);
1007 LockedSymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
1008 FlatSymbolRefAttr symbol
) {
1009 return lookupSymbolIn(symbolTableOp
, symbol
.getAttr());
1012 Operation
*LockedSymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
1013 SymbolRefAttr name
) {
1014 SmallVector
<Operation
*> symbols
;
1015 if (failed(lookupSymbolIn(symbolTableOp
, name
, symbols
)))
1017 return symbols
.back();
1020 LogicalResult
LockedSymbolTableCollection::lookupSymbolIn(
1021 Operation
*symbolTableOp
, SymbolRefAttr name
,
1022 SmallVectorImpl
<Operation
*> &symbols
) {
1023 auto lookupFn
= [this](Operation
*symbolTableOp
, StringAttr symbol
) {
1024 return lookupSymbolIn(symbolTableOp
, symbol
);
1026 return lookupSymbolInImpl(symbolTableOp
, name
, symbols
, lookupFn
);
1030 LockedSymbolTableCollection::getSymbolTable(Operation
*symbolTableOp
) {
1031 assert(symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>());
1032 // Try to find an existing symbol table.
1034 llvm::sys::SmartScopedReader
<true> lock(mutex
);
1035 auto it
= collection
.symbolTables
.find(symbolTableOp
);
1036 if (it
!= collection
.symbolTables
.end())
1039 // Create a symbol table for the operation. Perform construction outside of
1040 // the critical section.
1041 auto symbolTable
= std::make_unique
<SymbolTable
>(symbolTableOp
);
1042 // Insert the constructed symbol table.
1043 llvm::sys::SmartScopedWriter
<true> lock(mutex
);
1044 return *collection
.symbolTables
1045 .insert({symbolTableOp
, std::move(symbolTable
)})
//===----------------------------------------------------------------------===//
// SymbolUserMap
//===----------------------------------------------------------------------===//
1053 SymbolUserMap::SymbolUserMap(SymbolTableCollection
&symbolTable
,
1054 Operation
*symbolTableOp
)
1055 : symbolTable(symbolTable
) {
1056 // Walk each of the symbol tables looking for discardable callgraph nodes.
1057 SmallVector
<Operation
*> symbols
;
1058 auto walkFn
= [&](Operation
*symbolTableOp
, bool allUsesVisible
) {
1059 for (Operation
&nestedOp
: symbolTableOp
->getRegion(0).getOps()) {
1060 auto symbolUses
= SymbolTable::getSymbolUses(&nestedOp
);
1061 assert(symbolUses
&& "expected uses to be valid");
1063 for (const SymbolTable::SymbolUse
&use
: *symbolUses
) {
1065 (void)symbolTable
.lookupSymbolIn(symbolTableOp
, use
.getSymbolRef(),
1067 for (Operation
*symbolOp
: symbols
)
1068 symbolToUsers
[symbolOp
].insert(use
.getUser());
1072 // We just set `allSymUsesVisible` to false here because it isn't necessary
1073 // for building the user map.
1074 SymbolTable::walkSymbolTables(symbolTableOp
, /*allSymUsesVisible=*/false,
1078 void SymbolUserMap::replaceAllUsesWith(Operation
*symbol
,
1079 StringAttr newSymbolName
) {
1080 auto it
= symbolToUsers
.find(symbol
);
1081 if (it
== symbolToUsers
.end())
1084 // Replace the uses within the users of `symbol`.
1085 for (Operation
*user
: it
->second
)
1086 (void)SymbolTable::replaceAllSymbolUses(symbol
, newSymbolName
, user
);
1088 // Move the current users of `symbol` to the new symbol if it is in the
1090 Operation
*newSymbol
=
1091 symbolTable
.lookupSymbolIn(symbol
->getParentOp(), newSymbolName
);
1092 if (newSymbol
!= symbol
) {
1093 // Transfer over the users to the new symbol. The reference to the old one
1094 // is fetched again as the iterator is invalidated during the insertion.
1095 auto newIt
= symbolToUsers
.try_emplace(newSymbol
, SetVector
<Operation
*>{});
1096 auto oldIt
= symbolToUsers
.find(symbol
);
1097 assert(oldIt
!= symbolToUsers
.end() && "missing old users list");
1099 newIt
.first
->second
= std::move(oldIt
->second
);
1101 newIt
.first
->second
.set_union(oldIt
->second
);
1102 symbolToUsers
.erase(oldIt
);
1106 //===----------------------------------------------------------------------===//
1107 // Visibility parsing implementation.
1108 //===----------------------------------------------------------------------===//
1110 ParseResult
impl::parseOptionalVisibilityKeyword(OpAsmParser
&parser
,
1111 NamedAttrList
&attrs
) {
1112 StringRef visibility
;
1113 if (parser
.parseOptionalKeyword(&visibility
, {"public", "private", "nested"}))
1116 StringAttr visibilityAttr
= parser
.getBuilder().getStringAttr(visibility
);
1117 attrs
.push_back(parser
.getBuilder().getNamedAttr(
1118 SymbolTable::getVisibilityAttrName(), visibilityAttr
));
1122 //===----------------------------------------------------------------------===//
1123 // Symbol Interfaces
1124 //===----------------------------------------------------------------------===//
1126 /// Include the generated symbol interfaces.
1127 #include "mlir/IR/SymbolInterfaces.cpp.inc"