1 //===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/SymbolTable.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/OpImplementation.h"
12 #include "llvm/ADT/SetVector.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/SmallString.h"
15 #include "llvm/ADT/StringSwitch.h"
20 /// Return true if the given operation is unknown and may potentially define a
22 static bool isPotentiallyUnknownSymbolTable(Operation
*op
) {
23 return op
->getNumRegions() == 1 && !op
->getDialect();
26 /// Returns the string name of the given symbol, or null if this is not a
28 static StringAttr
getNameIfSymbol(Operation
*op
) {
29 return op
->getAttrOfType
<StringAttr
>(SymbolTable::getSymbolAttrName());
31 static StringAttr
getNameIfSymbol(Operation
*op
, StringAttr symbolAttrNameId
) {
32 return op
->getAttrOfType
<StringAttr
>(symbolAttrNameId
);
35 /// Computes the nested symbol reference attribute for the symbol 'symbolName'
36 /// that are usable within the symbol table operations from 'symbol' as far up
37 /// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
38 /// Returns success if all references up to 'within' could be computed.
40 collectValidReferencesFor(Operation
*symbol
, StringAttr symbolName
,
42 SmallVectorImpl
<SymbolRefAttr
> &results
) {
43 assert(within
->isAncestor(symbol
) && "expected 'within' to be an ancestor");
44 MLIRContext
*ctx
= symbol
->getContext();
46 auto leafRef
= FlatSymbolRefAttr::get(symbolName
);
47 results
.push_back(leafRef
);
49 // Early exit for when 'within' is the parent of 'symbol'.
50 Operation
*symbolTableOp
= symbol
->getParentOp();
51 if (within
== symbolTableOp
)
54 // Collect references until 'symbolTableOp' reaches 'within'.
55 SmallVector
<FlatSymbolRefAttr
, 1> nestedRefs(1, leafRef
);
56 StringAttr symbolNameId
=
57 StringAttr::get(ctx
, SymbolTable::getSymbolAttrName());
59 // Each parent of 'symbol' should define a symbol table.
60 if (!symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>())
62 // Each parent of 'symbol' should also be a symbol.
63 StringAttr symbolTableName
= getNameIfSymbol(symbolTableOp
, symbolNameId
);
66 results
.push_back(SymbolRefAttr::get(symbolTableName
, nestedRefs
));
68 symbolTableOp
= symbolTableOp
->getParentOp();
69 if (symbolTableOp
== within
)
71 nestedRefs
.insert(nestedRefs
.begin(),
72 FlatSymbolRefAttr::get(symbolTableName
));
77 /// Walk all of the operations within the given set of regions, without
78 /// traversing into any nested symbol tables. Stops walking if the result of the
79 /// callback is anything other than `WalkResult::advance`.
80 static std::optional
<WalkResult
>
81 walkSymbolTable(MutableArrayRef
<Region
> regions
,
82 function_ref
<std::optional
<WalkResult
>(Operation
*)> callback
) {
83 SmallVector
<Region
*, 1> worklist(llvm::make_pointer_range(regions
));
84 while (!worklist
.empty()) {
85 for (Operation
&op
: worklist
.pop_back_val()->getOps()) {
86 std::optional
<WalkResult
> result
= callback(&op
);
87 if (result
!= WalkResult::advance())
90 // If this op defines a new symbol table scope, we can't traverse. Any
91 // symbol references nested within 'op' are different semantically.
92 if (!op
.hasTrait
<OpTrait::SymbolTable
>()) {
93 for (Region
®ion
: op
.getRegions())
94 worklist
.push_back(®ion
);
98 return WalkResult::advance();
101 /// Walk all of the operations nested under, and including, the given operation,
102 /// without traversing into any nested symbol tables. Stops walking if the
103 /// result of the callback is anything other than `WalkResult::advance`.
104 static std::optional
<WalkResult
>
105 walkSymbolTable(Operation
*op
,
106 function_ref
<std::optional
<WalkResult
>(Operation
*)> callback
) {
107 std::optional
<WalkResult
> result
= callback(op
);
108 if (result
!= WalkResult::advance() || op
->hasTrait
<OpTrait::SymbolTable
>())
110 return walkSymbolTable(op
->getRegions(), callback
);
//===----------------------------------------------------------------------===//
// SymbolTable
//===----------------------------------------------------------------------===//
117 /// Build a symbol table with the symbols within the given operation.
118 SymbolTable::SymbolTable(Operation
*symbolTableOp
)
119 : symbolTableOp(symbolTableOp
) {
120 assert(symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>() &&
121 "expected operation to have SymbolTable trait");
122 assert(symbolTableOp
->getNumRegions() == 1 &&
123 "expected operation to have a single region");
124 assert(llvm::hasSingleElement(symbolTableOp
->getRegion(0)) &&
125 "expected operation to have a single block");
127 StringAttr symbolNameId
= StringAttr::get(symbolTableOp
->getContext(),
128 SymbolTable::getSymbolAttrName());
129 for (auto &op
: symbolTableOp
->getRegion(0).front()) {
130 StringAttr name
= getNameIfSymbol(&op
, symbolNameId
);
134 auto inserted
= symbolTable
.insert({name
, &op
});
136 assert(inserted
.second
&&
137 "expected region to contain uniquely named symbol operations");
141 /// Look up a symbol with the specified name, returning null if no such name
142 /// exists. Names never include the @ on them.
143 Operation
*SymbolTable::lookup(StringRef name
) const {
144 return lookup(StringAttr::get(symbolTableOp
->getContext(), name
));
146 Operation
*SymbolTable::lookup(StringAttr name
) const {
147 return symbolTable
.lookup(name
);
150 void SymbolTable::remove(Operation
*op
) {
151 StringAttr name
= getNameIfSymbol(op
);
152 assert(name
&& "expected valid 'name' attribute");
153 assert(op
->getParentOp() == symbolTableOp
&&
154 "expected this operation to be inside of the operation with this "
157 auto it
= symbolTable
.find(name
);
158 if (it
!= symbolTable
.end() && it
->second
== op
)
159 symbolTable
.erase(it
);
162 void SymbolTable::erase(Operation
*symbol
) {
167 // TODO: Consider if this should be renamed to something like insertOrUpdate
168 /// Insert a new symbol into the table and associated operation if not already
169 /// there and rename it as necessary to avoid collisions. Return the name of
170 /// the symbol after insertion as attribute.
171 StringAttr
SymbolTable::insert(Operation
*symbol
, Block::iterator insertPt
) {
172 // The symbol cannot be the child of another op and must be the child of the
173 // symbolTableOp after this.
175 // TODO: consider if SymbolTable's constructor should behave the same.
176 if (!symbol
->getParentOp()) {
177 auto &body
= symbolTableOp
->getRegion(0).front();
178 if (insertPt
== Block::iterator()) {
179 insertPt
= Block::iterator(body
.end());
181 assert((insertPt
== body
.end() ||
182 insertPt
->getParentOp() == symbolTableOp
) &&
183 "expected insertPt to be in the associated module operation");
185 // Insert before the terminator, if any.
186 if (insertPt
== Block::iterator(body
.end()) && !body
.empty() &&
187 std::prev(body
.end())->hasTrait
<OpTrait::IsTerminator
>())
188 insertPt
= std::prev(body
.end());
190 body
.getOperations().insert(insertPt
, symbol
);
192 assert(symbol
->getParentOp() == symbolTableOp
&&
193 "symbol is already inserted in another op");
195 // Add this symbol to the symbol table, uniquing the name if a conflict is
197 StringAttr name
= getSymbolName(symbol
);
198 if (symbolTable
.insert({name
, symbol
}).second
)
200 // If the symbol was already in the table, also return.
201 if (symbolTable
.lookup(name
) == symbol
)
203 // If a conflict was detected, then the symbol will not have been added to
204 // the symbol table. Try suffixes until we get to a unique name that works.
205 SmallString
<128> nameBuffer(name
.getValue());
206 unsigned originalLength
= nameBuffer
.size();
208 MLIRContext
*context
= symbol
->getContext();
210 // Iteratively try suffixes until we find one that isn't used.
212 nameBuffer
.resize(originalLength
);
214 nameBuffer
+= std::to_string(uniquingCounter
++);
215 } while (!symbolTable
.insert({StringAttr::get(context
, nameBuffer
), symbol
})
217 setSymbolName(symbol
, nameBuffer
);
218 return getSymbolName(symbol
);
221 LogicalResult
SymbolTable::rename(StringAttr from
, StringAttr to
) {
222 Operation
*op
= lookup(from
);
223 return rename(op
, to
);
226 LogicalResult
SymbolTable::rename(Operation
*op
, StringAttr to
) {
227 StringAttr from
= getNameIfSymbol(op
);
230 assert(from
&& "expected valid 'name' attribute");
231 assert(op
->getParentOp() == symbolTableOp
&&
232 "expected this operation to be inside of the operation with this "
234 assert(lookup(from
) == op
&& "current name does not resolve to op");
235 assert(lookup(to
) == nullptr && "new name already exists");
237 if (failed(SymbolTable::replaceAllSymbolUses(op
, to
, getOp())))
240 // Remove op with old name, change name, add with new name. The order is
241 // important here due to how `remove` and `insert` rely on the op name.
243 setSymbolName(op
, to
);
246 assert(lookup(to
) == op
&& "new name does not resolve to renamed op");
247 assert(lookup(from
) == nullptr && "old name still exists");
252 LogicalResult
SymbolTable::rename(StringAttr from
, StringRef to
) {
253 auto toAttr
= StringAttr::get(getOp()->getContext(), to
);
254 return rename(from
, toAttr
);
257 LogicalResult
SymbolTable::rename(Operation
*op
, StringRef to
) {
258 auto toAttr
= StringAttr::get(getOp()->getContext(), to
);
259 return rename(op
, toAttr
);
262 FailureOr
<StringAttr
>
263 SymbolTable::renameToUnique(StringAttr oldName
,
264 ArrayRef
<SymbolTable
*> others
) {
266 // Determine new name that is unique in all symbol tables.
269 MLIRContext
*context
= oldName
.getContext();
270 SmallString
<64> prefix
= oldName
.getValue();
272 prefix
.push_back('_');
274 newName
= StringAttr::get(context
, prefix
+ Twine(uniqueId
++));
275 auto lookupNewName
= [&](SymbolTable
*st
) { return st
->lookup(newName
); };
276 if (!lookupNewName(this) && llvm::none_of(others
, lookupNewName
)) {
283 if (failed(rename(oldName
, newName
)))
288 FailureOr
<StringAttr
>
289 SymbolTable::renameToUnique(Operation
*op
, ArrayRef
<SymbolTable
*> others
) {
290 StringAttr from
= getNameIfSymbol(op
);
291 assert(from
&& "expected valid 'name' attribute");
292 return renameToUnique(from
, others
);
295 /// Returns the name of the given symbol operation.
296 StringAttr
SymbolTable::getSymbolName(Operation
*symbol
) {
297 StringAttr name
= getNameIfSymbol(symbol
);
298 assert(name
&& "expected valid symbol name");
302 /// Sets the name of the given symbol operation.
303 void SymbolTable::setSymbolName(Operation
*symbol
, StringAttr name
) {
304 symbol
->setAttr(getSymbolAttrName(), name
);
307 /// Returns the visibility of the given symbol operation.
308 SymbolTable::Visibility
SymbolTable::getSymbolVisibility(Operation
*symbol
) {
309 // If the attribute doesn't exist, assume public.
310 StringAttr vis
= symbol
->getAttrOfType
<StringAttr
>(getVisibilityAttrName());
312 return Visibility::Public
;
314 // Otherwise, switch on the string value.
315 return StringSwitch
<Visibility
>(vis
.getValue())
316 .Case("private", Visibility::Private
)
317 .Case("nested", Visibility::Nested
)
318 .Case("public", Visibility::Public
);
320 /// Sets the visibility of the given symbol operation.
321 void SymbolTable::setSymbolVisibility(Operation
*symbol
, Visibility vis
) {
322 MLIRContext
*ctx
= symbol
->getContext();
324 // If the visibility is public, just drop the attribute as this is the
326 if (vis
== Visibility::Public
) {
327 symbol
->removeAttr(StringAttr::get(ctx
, getVisibilityAttrName()));
331 // Otherwise, update the attribute.
332 assert((vis
== Visibility::Private
|| vis
== Visibility::Nested
) &&
333 "unknown symbol visibility kind");
335 StringRef visName
= vis
== Visibility::Private
? "private" : "nested";
336 symbol
->setAttr(getVisibilityAttrName(), StringAttr::get(ctx
, visName
));
339 /// Returns the nearest symbol table from a given operation `from`. Returns
340 /// nullptr if no valid parent symbol table could be found.
341 Operation
*SymbolTable::getNearestSymbolTable(Operation
*from
) {
342 assert(from
&& "expected valid operation");
343 if (isPotentiallyUnknownSymbolTable(from
))
346 while (!from
->hasTrait
<OpTrait::SymbolTable
>()) {
347 from
= from
->getParentOp();
349 // Check that this is a valid op and isn't an unknown symbol table.
350 if (!from
|| isPotentiallyUnknownSymbolTable(from
))
356 /// Walks all symbol table operations nested within, and including, `op`. For
357 /// each symbol table operation, the provided callback is invoked with the op
358 /// and a boolean signifying if the symbols within that symbol table can be
359 /// treated as if all uses are visible. `allSymUsesVisible` identifies whether
360 /// all of the symbol uses of symbols within `op` are visible.
361 void SymbolTable::walkSymbolTables(
362 Operation
*op
, bool allSymUsesVisible
,
363 function_ref
<void(Operation
*, bool)> callback
) {
364 bool isSymbolTable
= op
->hasTrait
<OpTrait::SymbolTable
>();
366 SymbolOpInterface symbol
= dyn_cast
<SymbolOpInterface
>(op
);
367 allSymUsesVisible
|= !symbol
|| symbol
.isPrivate();
369 // Otherwise if 'op' is not a symbol table, any nested symbols are
370 // guaranteed to be hidden.
371 allSymUsesVisible
= true;
374 for (Region
®ion
: op
->getRegions())
375 for (Block
&block
: region
)
376 for (Operation
&nestedOp
: block
)
377 walkSymbolTables(&nestedOp
, allSymUsesVisible
, callback
);
379 // If 'op' had the symbol table trait, visit it after any nested symbol
382 callback(op
, allSymUsesVisible
);
385 /// Returns the operation registered with the given symbol name with the
386 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
387 /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
389 Operation
*SymbolTable::lookupSymbolIn(Operation
*symbolTableOp
,
391 assert(symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>());
392 Region
®ion
= symbolTableOp
->getRegion(0);
396 // Look for a symbol with the given name.
397 StringAttr symbolNameId
= StringAttr::get(symbolTableOp
->getContext(),
398 SymbolTable::getSymbolAttrName());
399 for (auto &op
: region
.front())
400 if (getNameIfSymbol(&op
, symbolNameId
) == symbol
)
404 Operation
*SymbolTable::lookupSymbolIn(Operation
*symbolTableOp
,
405 SymbolRefAttr symbol
) {
406 SmallVector
<Operation
*, 4> resolvedSymbols
;
407 if (failed(lookupSymbolIn(symbolTableOp
, symbol
, resolvedSymbols
)))
409 return resolvedSymbols
.back();
412 /// Internal implementation of `lookupSymbolIn` that allows for specialized
413 /// implementations of the lookup function.
414 static LogicalResult
lookupSymbolInImpl(
415 Operation
*symbolTableOp
, SymbolRefAttr symbol
,
416 SmallVectorImpl
<Operation
*> &symbols
,
417 function_ref
<Operation
*(Operation
*, StringAttr
)> lookupSymbolFn
) {
418 assert(symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>());
420 // Lookup the root reference for this symbol.
421 symbolTableOp
= lookupSymbolFn(symbolTableOp
, symbol
.getRootReference());
424 symbols
.push_back(symbolTableOp
);
426 // If there are no nested references, just return the root symbol directly.
427 ArrayRef
<FlatSymbolRefAttr
> nestedRefs
= symbol
.getNestedReferences();
428 if (nestedRefs
.empty())
431 // Verify that the root is also a symbol table.
432 if (!symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>())
435 // Otherwise, lookup each of the nested non-leaf references and ensure that
436 // each corresponds to a valid symbol table.
437 for (FlatSymbolRefAttr ref
: nestedRefs
.drop_back()) {
438 symbolTableOp
= lookupSymbolFn(symbolTableOp
, ref
.getAttr());
439 if (!symbolTableOp
|| !symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>())
441 symbols
.push_back(symbolTableOp
);
443 symbols
.push_back(lookupSymbolFn(symbolTableOp
, symbol
.getLeafReference()));
444 return success(symbols
.back());
448 SymbolTable::lookupSymbolIn(Operation
*symbolTableOp
, SymbolRefAttr symbol
,
449 SmallVectorImpl
<Operation
*> &symbols
) {
450 auto lookupFn
= [](Operation
*symbolTableOp
, StringAttr symbol
) {
451 return lookupSymbolIn(symbolTableOp
, symbol
);
453 return lookupSymbolInImpl(symbolTableOp
, symbol
, symbols
, lookupFn
);
456 /// Returns the operation registered with the given symbol name within the
457 /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
458 /// nullptr if no valid symbol was found.
459 Operation
*SymbolTable::lookupNearestSymbolFrom(Operation
*from
,
461 Operation
*symbolTableOp
= getNearestSymbolTable(from
);
462 return symbolTableOp
? lookupSymbolIn(symbolTableOp
, symbol
) : nullptr;
464 Operation
*SymbolTable::lookupNearestSymbolFrom(Operation
*from
,
465 SymbolRefAttr symbol
) {
466 Operation
*symbolTableOp
= getNearestSymbolTable(from
);
467 return symbolTableOp
? lookupSymbolIn(symbolTableOp
, symbol
) : nullptr;
470 raw_ostream
&mlir::operator<<(raw_ostream
&os
,
471 SymbolTable::Visibility visibility
) {
472 switch (visibility
) {
473 case SymbolTable::Visibility::Public
:
474 return os
<< "public";
475 case SymbolTable::Visibility::Private
:
476 return os
<< "private";
477 case SymbolTable::Visibility::Nested
:
478 return os
<< "nested";
480 llvm_unreachable("Unexpected visibility");
483 //===----------------------------------------------------------------------===//
484 // SymbolTable Trait Types
485 //===----------------------------------------------------------------------===//
487 LogicalResult
detail::verifySymbolTable(Operation
*op
) {
488 if (op
->getNumRegions() != 1)
489 return op
->emitOpError()
490 << "Operations with a 'SymbolTable' must have exactly one region";
491 if (!llvm::hasSingleElement(op
->getRegion(0)))
492 return op
->emitOpError()
493 << "Operations with a 'SymbolTable' must have exactly one block";
495 // Check that all symbols are uniquely named within child regions.
496 DenseMap
<Attribute
, Location
> nameToOrigLoc
;
497 for (auto &block
: op
->getRegion(0)) {
498 for (auto &op
: block
) {
499 // Check for a symbol name attribute.
501 op
.getAttrOfType
<StringAttr
>(mlir::SymbolTable::getSymbolAttrName());
505 // Try to insert this symbol into the table.
506 auto it
= nameToOrigLoc
.try_emplace(nameAttr
, op
.getLoc());
508 return op
.emitError()
509 .append("redefinition of symbol named '", nameAttr
.getValue(), "'")
510 .attachNote(it
.first
->second
)
511 .append("see existing symbol definition here");
515 // Verify any nested symbol user operations.
516 SymbolTableCollection symbolTable
;
517 auto verifySymbolUserFn
= [&](Operation
*op
) -> std::optional
<WalkResult
> {
518 if (SymbolUserOpInterface user
= dyn_cast
<SymbolUserOpInterface
>(op
))
519 return WalkResult(user
.verifySymbolUses(symbolTable
));
520 return WalkResult::advance();
523 std::optional
<WalkResult
> result
=
524 walkSymbolTable(op
->getRegions(), verifySymbolUserFn
);
525 return success(result
&& !result
->wasInterrupted());
528 LogicalResult
detail::verifySymbol(Operation
*op
) {
529 // Verify the name attribute.
530 if (!op
->getAttrOfType
<StringAttr
>(mlir::SymbolTable::getSymbolAttrName()))
531 return op
->emitOpError() << "requires string attribute '"
532 << mlir::SymbolTable::getSymbolAttrName() << "'";
534 // Verify the visibility attribute.
535 if (Attribute vis
= op
->getAttr(mlir::SymbolTable::getVisibilityAttrName())) {
536 StringAttr visStrAttr
= llvm::dyn_cast
<StringAttr
>(vis
);
538 return op
->emitOpError() << "requires visibility attribute '"
539 << mlir::SymbolTable::getVisibilityAttrName()
540 << "' to be a string attribute, but got " << vis
;
542 if (!llvm::is_contained(ArrayRef
<StringRef
>{"public", "private", "nested"},
543 visStrAttr
.getValue()))
544 return op
->emitOpError()
545 << "visibility expected to be one of [\"public\", \"private\", "
546 "\"nested\"], but got "
//===----------------------------------------------------------------------===//
// Symbol Use Lists
//===----------------------------------------------------------------------===//
556 /// Walk all of the symbol references within the given operation, invoking the
557 /// provided callback for each found use. The callbacks takes the use of the
560 walkSymbolRefs(Operation
*op
,
561 function_ref
<WalkResult(SymbolTable::SymbolUse
)> callback
) {
562 return op
->getAttrDictionary().walk
<WalkOrder::PreOrder
>(
563 [&](SymbolRefAttr symbolRef
) {
564 if (callback({op
, symbolRef
}).wasInterrupted())
565 return WalkResult::interrupt();
567 // Don't walk nested references.
568 return WalkResult::skip();
572 /// Walk all of the uses, for any symbol, that are nested within the given
573 /// regions, invoking the provided callback for each. This does not traverse
574 /// into any nested symbol tables.
575 static std::optional
<WalkResult
>
576 walkSymbolUses(MutableArrayRef
<Region
> regions
,
577 function_ref
<WalkResult(SymbolTable::SymbolUse
)> callback
) {
578 return walkSymbolTable(regions
,
579 [&](Operation
*op
) -> std::optional
<WalkResult
> {
580 // Check that this isn't a potentially unknown symbol
582 if (isPotentiallyUnknownSymbolTable(op
))
585 return walkSymbolRefs(op
, callback
);
588 /// Walk all of the uses, for any symbol, that are nested within the given
589 /// operation 'from', invoking the provided callback for each. This does not
590 /// traverse into any nested symbol tables.
591 static std::optional
<WalkResult
>
592 walkSymbolUses(Operation
*from
,
593 function_ref
<WalkResult(SymbolTable::SymbolUse
)> callback
) {
594 // If this operation has regions, and it, as well as its dialect, isn't
595 // registered then conservatively fail. The operation may define a
596 // symbol table, so we can't opaquely know if we should traverse to find
598 if (isPotentiallyUnknownSymbolTable(from
))
601 // Walk the uses on this operation.
602 if (walkSymbolRefs(from
, callback
).wasInterrupted())
603 return WalkResult::interrupt();
605 // Only recurse if this operation is not a symbol table. A symbol table
606 // defines a new scope, so we can't walk the attributes from within the symbol
608 if (!from
->hasTrait
<OpTrait::SymbolTable
>())
609 return walkSymbolUses(from
->getRegions(), callback
);
610 return WalkResult::advance();
614 /// This class represents a single symbol scope. A symbol scope represents the
615 /// set of operations nested within a symbol table that may reference symbols
616 /// within that table. A symbol scope does not contain the symbol table
617 /// operation itself, just its contained operations. A scope ends at leaf
618 /// operations or another symbol table operation.
620 /// Walk the symbol uses within this scope, invoking the given callback.
621 /// This variant is used when the callback type matches that expected by
622 /// 'walkSymbolUses'.
623 template <typename CallbackT
,
624 std::enable_if_t
<!std::is_same
<
625 typename
llvm::function_traits
<CallbackT
>::result_t
,
626 void>::value
> * = nullptr>
627 std::optional
<WalkResult
> walk(CallbackT cback
) {
628 if (Region
*region
= llvm::dyn_cast_if_present
<Region
*>(limit
))
629 return walkSymbolUses(*region
, cback
);
630 return walkSymbolUses(limit
.get
<Operation
*>(), cback
);
632 /// This variant is used when the callback type matches a stripped down type:
633 /// void(SymbolTable::SymbolUse use)
634 template <typename CallbackT
,
635 std::enable_if_t
<std::is_same
<
636 typename
llvm::function_traits
<CallbackT
>::result_t
,
637 void>::value
> * = nullptr>
638 std::optional
<WalkResult
> walk(CallbackT cback
) {
639 return walk([=](SymbolTable::SymbolUse use
) {
640 return cback(use
), WalkResult::advance();
644 /// Walk all of the operations nested under the current scope without
645 /// traversing into any nested symbol tables.
646 template <typename CallbackT
>
647 std::optional
<WalkResult
> walkSymbolTable(CallbackT
&&cback
) {
648 if (Region
*region
= llvm::dyn_cast_if_present
<Region
*>(limit
))
649 return ::walkSymbolTable(*region
, cback
);
650 return ::walkSymbolTable(limit
.get
<Operation
*>(), cback
);
653 /// The representation of the symbol within this scope.
654 SymbolRefAttr symbol
;
656 /// The IR unit representing this scope.
657 llvm::PointerUnion
<Operation
*, Region
*> limit
;
661 /// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
662 static SmallVector
<SymbolScope
, 2> collectSymbolScopes(Operation
*symbol
,
664 StringAttr symName
= SymbolTable::getSymbolName(symbol
);
665 assert(!symbol
->hasTrait
<OpTrait::SymbolTable
>() || symbol
!= limit
);
667 // Compute the ancestors of 'limit'.
668 SetVector
<Operation
*, SmallVector
<Operation
*, 4>,
669 SmallPtrSet
<Operation
*, 4>>
671 Operation
*limitAncestor
= limit
;
673 // Check to see if 'symbol' is an ancestor of 'limit'.
674 if (limitAncestor
== symbol
) {
675 // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
676 // doesn't support parent references.
677 if (SymbolTable::getNearestSymbolTable(limit
->getParentOp()) ==
678 symbol
->getParentOp())
679 return {{SymbolRefAttr::get(symName
), limit
}};
683 limitAncestors
.insert(limitAncestor
);
684 } while ((limitAncestor
= limitAncestor
->getParentOp()));
686 // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
687 Operation
*commonAncestor
= symbol
->getParentOp();
689 if (limitAncestors
.count(commonAncestor
))
691 } while ((commonAncestor
= commonAncestor
->getParentOp()));
692 assert(commonAncestor
&& "'limit' and 'symbol' have no common ancestor");
694 // Compute the set of valid nested references for 'symbol' as far up to the
695 // common ancestor as possible.
696 SmallVector
<SymbolRefAttr
, 2> references
;
697 bool collectedAllReferences
= succeeded(
698 collectValidReferencesFor(symbol
, symName
, commonAncestor
, references
));
700 // Handle the case where the common ancestor is 'limit'.
701 if (commonAncestor
== limit
) {
702 SmallVector
<SymbolScope
, 2> scopes
;
704 // Walk each of the ancestors of 'symbol', calling the compute function for
706 Operation
*limitIt
= symbol
->getParentOp();
707 for (size_t i
= 0, e
= references
.size(); i
!= e
;
708 ++i
, limitIt
= limitIt
->getParentOp()) {
709 assert(limitIt
->hasTrait
<OpTrait::SymbolTable
>());
710 scopes
.push_back({references
[i
], &limitIt
->getRegion(0)});
715 // Otherwise, we just need the symbol reference for 'symbol' that will be
716 // used within 'limit'. This is the last reference in the list we computed
717 // above if we were able to collect all references.
718 if (!collectedAllReferences
)
720 return {{references
.back(), limit
}};
722 static SmallVector
<SymbolScope
, 2> collectSymbolScopes(Operation
*symbol
,
724 auto scopes
= collectSymbolScopes(symbol
, limit
->getParentOp());
726 // If we collected some scopes to walk, make sure to constrain the one for
727 // limit to the specific region requested.
729 scopes
.back().limit
= limit
;
732 static SmallVector
<SymbolScope
, 1> collectSymbolScopes(StringAttr symbol
,
734 return {{SymbolRefAttr::get(symbol
), limit
}};
737 static SmallVector
<SymbolScope
, 1> collectSymbolScopes(StringAttr symbol
,
739 SmallVector
<SymbolScope
, 1> scopes
;
740 auto symbolRef
= SymbolRefAttr::get(symbol
);
741 for (auto ®ion
: limit
->getRegions())
742 scopes
.push_back({symbolRef
, ®ion
});
746 /// Returns true if the given reference 'SubRef' is a sub reference of the
747 /// reference 'ref', i.e. 'ref' is a further qualified reference.
748 static bool isReferencePrefixOf(SymbolRefAttr subRef
, SymbolRefAttr ref
) {
752 // If the references are not pointer equal, check to see if `subRef` is a
754 if (llvm::isa
<FlatSymbolRefAttr
>(ref
) ||
755 ref
.getRootReference() != subRef
.getRootReference())
758 auto refLeafs
= ref
.getNestedReferences();
759 auto subRefLeafs
= subRef
.getNestedReferences();
760 return subRefLeafs
.size() < refLeafs
.size() &&
761 subRefLeafs
== refLeafs
.take_front(subRefLeafs
.size());
//===----------------------------------------------------------------------===//
// SymbolTable::getSymbolUses
//===----------------------------------------------------------------------===//
767 /// The implementation of SymbolTable::getSymbolUses below.
768 template <typename FromT
>
769 static std::optional
<SymbolTable::UseRange
> getSymbolUsesImpl(FromT from
) {
770 std::vector
<SymbolTable::SymbolUse
> uses
;
771 auto walkFn
= [&](SymbolTable::SymbolUse symbolUse
) {
772 uses
.push_back(symbolUse
);
773 return WalkResult::advance();
775 auto result
= walkSymbolUses(from
, walkFn
);
776 return result
? std::optional
<SymbolTable::UseRange
>(std::move(uses
))
780 /// Get an iterator range for all of the uses, for any symbol, that are nested
781 /// within the given operation 'from'. This does not traverse into any nested
782 /// symbol tables, and will also only return uses on 'from' if it does not
783 /// also define a symbol table. This is because we treat the region as the
784 /// boundary of the symbol table, and not the op itself. This function returns
785 /// std::nullopt if there are any unknown operations that may potentially be
787 auto SymbolTable::getSymbolUses(Operation
*from
) -> std::optional
<UseRange
> {
788 return getSymbolUsesImpl(from
);
790 auto SymbolTable::getSymbolUses(Region
*from
) -> std::optional
<UseRange
> {
791 return getSymbolUsesImpl(MutableArrayRef
<Region
>(*from
));
//===----------------------------------------------------------------------===//
// SymbolTable::getSymbolUses
//===----------------------------------------------------------------------===//
797 /// The implementation of SymbolTable::getSymbolUses below.
798 template <typename SymbolT
, typename IRUnitT
>
799 static std::optional
<SymbolTable::UseRange
> getSymbolUsesImpl(SymbolT symbol
,
801 std::vector
<SymbolTable::SymbolUse
> uses
;
802 for (SymbolScope
&scope
: collectSymbolScopes(symbol
, limit
)) {
803 if (!scope
.walk([&](SymbolTable::SymbolUse symbolUse
) {
804 if (isReferencePrefixOf(scope
.symbol
, symbolUse
.getSymbolRef()))
805 uses
.push_back(symbolUse
);
809 return SymbolTable::UseRange(std::move(uses
));
812 /// Get all of the uses of the given symbol that are nested within the given
813 /// operation 'from', invoking the provided callback for each. This does not
814 /// traverse into any nested symbol tables. This function returns std::nullopt
815 /// if there are any unknown operations that may potentially be symbol tables.
816 auto SymbolTable::getSymbolUses(StringAttr symbol
, Operation
*from
)
817 -> std::optional
<UseRange
> {
818 return getSymbolUsesImpl(symbol
, from
);
820 auto SymbolTable::getSymbolUses(Operation
*symbol
, Operation
*from
)
821 -> std::optional
<UseRange
> {
822 return getSymbolUsesImpl(symbol
, from
);
824 auto SymbolTable::getSymbolUses(StringAttr symbol
, Region
*from
)
825 -> std::optional
<UseRange
> {
826 return getSymbolUsesImpl(symbol
, from
);
828 auto SymbolTable::getSymbolUses(Operation
*symbol
, Region
*from
)
829 -> std::optional
<UseRange
> {
830 return getSymbolUsesImpl(symbol
, from
);
//===----------------------------------------------------------------------===//
// SymbolTable::symbolKnownUseEmpty
//===----------------------------------------------------------------------===//
836 /// The implementation of SymbolTable::symbolKnownUseEmpty below.
837 template <typename SymbolT
, typename IRUnitT
>
838 static bool symbolKnownUseEmptyImpl(SymbolT symbol
, IRUnitT
*limit
) {
839 for (SymbolScope
&scope
: collectSymbolScopes(symbol
, limit
)) {
840 // Walk all of the symbol uses looking for a reference to 'symbol'.
841 if (scope
.walk([&](SymbolTable::SymbolUse symbolUse
) {
842 return isReferencePrefixOf(scope
.symbol
, symbolUse
.getSymbolRef())
843 ? WalkResult::interrupt()
844 : WalkResult::advance();
845 }) != WalkResult::advance())
851 /// Return if the given symbol is known to have no uses that are nested within
852 /// the given operation 'from'. This does not traverse into any nested symbol
853 /// tables. This function will also return false if there are any unknown
854 /// operations that may potentially be symbol tables.
855 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol
, Operation
*from
) {
856 return symbolKnownUseEmptyImpl(symbol
, from
);
858 bool SymbolTable::symbolKnownUseEmpty(Operation
*symbol
, Operation
*from
) {
859 return symbolKnownUseEmptyImpl(symbol
, from
);
861 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol
, Region
*from
) {
862 return symbolKnownUseEmptyImpl(symbol
, from
);
864 bool SymbolTable::symbolKnownUseEmpty(Operation
*symbol
, Region
*from
) {
865 return symbolKnownUseEmptyImpl(symbol
, from
);
868 //===----------------------------------------------------------------------===//
869 // SymbolTable::replaceAllSymbolUses
871 /// Generates a new symbol reference attribute with a new leaf reference.
872 static SymbolRefAttr
generateNewRefAttr(SymbolRefAttr oldAttr
,
873 FlatSymbolRefAttr newLeafAttr
) {
874 if (llvm::isa
<FlatSymbolRefAttr
>(oldAttr
))
876 auto nestedRefs
= llvm::to_vector
<2>(oldAttr
.getNestedReferences());
877 nestedRefs
.back() = newLeafAttr
;
878 return SymbolRefAttr::get(oldAttr
.getRootReference(), nestedRefs
);
881 /// The implementation of SymbolTable::replaceAllSymbolUses below.
882 template <typename SymbolT
, typename IRUnitT
>
884 replaceAllSymbolUsesImpl(SymbolT symbol
, StringAttr newSymbol
, IRUnitT
*limit
) {
885 // Generate a new attribute to replace the given attribute.
886 FlatSymbolRefAttr newLeafAttr
= FlatSymbolRefAttr::get(newSymbol
);
887 for (SymbolScope
&scope
: collectSymbolScopes(symbol
, limit
)) {
888 SymbolRefAttr oldAttr
= scope
.symbol
;
889 SymbolRefAttr newAttr
= generateNewRefAttr(scope
.symbol
, newLeafAttr
);
890 AttrTypeReplacer replacer
;
891 replacer
.addReplacement(
892 [&](SymbolRefAttr attr
) -> std::pair
<Attribute
, WalkResult
> {
893 // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
894 // want to accidentally replace an inner reference.
896 return {newAttr
, WalkResult::skip()};
897 // Handle prefix matches.
898 if (isReferencePrefixOf(oldAttr
, attr
)) {
899 auto oldNestedRefs
= oldAttr
.getNestedReferences();
900 auto nestedRefs
= attr
.getNestedReferences();
901 if (oldNestedRefs
.empty())
902 return {SymbolRefAttr::get(newSymbol
, nestedRefs
),
905 auto newNestedRefs
= llvm::to_vector
<4>(nestedRefs
);
906 newNestedRefs
[oldNestedRefs
.size() - 1] = newLeafAttr
;
907 return {SymbolRefAttr::get(attr
.getRootReference(), newNestedRefs
),
910 return {attr
, WalkResult::skip()};
913 auto walkFn
= [&](Operation
*op
) -> std::optional
<WalkResult
> {
914 replacer
.replaceElementsIn(op
);
915 return WalkResult::advance();
917 if (!scope
.walkSymbolTable(walkFn
))
923 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
924 /// provided symbol 'newSymbol' that are nested within the given operation
925 /// 'from'. This does not traverse into any nested symbol tables. If there are
926 /// any unknown operations that may potentially be symbol tables, no uses are
927 /// replaced and failure is returned.
928 LogicalResult
SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol
,
929 StringAttr newSymbol
,
931 return replaceAllSymbolUsesImpl(oldSymbol
, newSymbol
, from
);
933 LogicalResult
SymbolTable::replaceAllSymbolUses(Operation
*oldSymbol
,
934 StringAttr newSymbol
,
936 return replaceAllSymbolUsesImpl(oldSymbol
, newSymbol
, from
);
938 LogicalResult
SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol
,
939 StringAttr newSymbol
,
941 return replaceAllSymbolUsesImpl(oldSymbol
, newSymbol
, from
);
943 LogicalResult
SymbolTable::replaceAllSymbolUses(Operation
*oldSymbol
,
944 StringAttr newSymbol
,
946 return replaceAllSymbolUsesImpl(oldSymbol
, newSymbol
, from
);
949 //===----------------------------------------------------------------------===//
950 // SymbolTableCollection
951 //===----------------------------------------------------------------------===//
953 Operation
*SymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
955 return getSymbolTable(symbolTableOp
).lookup(symbol
);
957 Operation
*SymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
958 SymbolRefAttr name
) {
959 SmallVector
<Operation
*, 4> symbols
;
960 if (failed(lookupSymbolIn(symbolTableOp
, name
, symbols
)))
962 return symbols
.back();
964 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
965 /// a given SymbolRefAttr. Returns failure if any of the nested references could
968 SymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
970 SmallVectorImpl
<Operation
*> &symbols
) {
971 auto lookupFn
= [this](Operation
*symbolTableOp
, StringAttr symbol
) {
972 return lookupSymbolIn(symbolTableOp
, symbol
);
974 return lookupSymbolInImpl(symbolTableOp
, name
, symbols
, lookupFn
);
977 /// Returns the operation registered with the given symbol name within the
978 /// closest parent operation of, or including, 'from' with the
979 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
981 Operation
*SymbolTableCollection::lookupNearestSymbolFrom(Operation
*from
,
983 Operation
*symbolTableOp
= SymbolTable::getNearestSymbolTable(from
);
984 return symbolTableOp
? lookupSymbolIn(symbolTableOp
, symbol
) : nullptr;
987 SymbolTableCollection::lookupNearestSymbolFrom(Operation
*from
,
988 SymbolRefAttr symbol
) {
989 Operation
*symbolTableOp
= SymbolTable::getNearestSymbolTable(from
);
990 return symbolTableOp
? lookupSymbolIn(symbolTableOp
, symbol
) : nullptr;
993 /// Lookup, or create, a symbol table for an operation.
994 SymbolTable
&SymbolTableCollection::getSymbolTable(Operation
*op
) {
995 auto it
= symbolTables
.try_emplace(op
, nullptr);
997 it
.first
->second
= std::make_unique
<SymbolTable
>(op
);
998 return *it
.first
->second
;
1001 //===----------------------------------------------------------------------===//
1002 // LockedSymbolTableCollection
1003 //===----------------------------------------------------------------------===//
1005 Operation
*LockedSymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
1006 StringAttr symbol
) {
1007 return getSymbolTable(symbolTableOp
).lookup(symbol
);
1011 LockedSymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
1012 FlatSymbolRefAttr symbol
) {
1013 return lookupSymbolIn(symbolTableOp
, symbol
.getAttr());
1016 Operation
*LockedSymbolTableCollection::lookupSymbolIn(Operation
*symbolTableOp
,
1017 SymbolRefAttr name
) {
1018 SmallVector
<Operation
*> symbols
;
1019 if (failed(lookupSymbolIn(symbolTableOp
, name
, symbols
)))
1021 return symbols
.back();
1024 LogicalResult
LockedSymbolTableCollection::lookupSymbolIn(
1025 Operation
*symbolTableOp
, SymbolRefAttr name
,
1026 SmallVectorImpl
<Operation
*> &symbols
) {
1027 auto lookupFn
= [this](Operation
*symbolTableOp
, StringAttr symbol
) {
1028 return lookupSymbolIn(symbolTableOp
, symbol
);
1030 return lookupSymbolInImpl(symbolTableOp
, name
, symbols
, lookupFn
);
1034 LockedSymbolTableCollection::getSymbolTable(Operation
*symbolTableOp
) {
1035 assert(symbolTableOp
->hasTrait
<OpTrait::SymbolTable
>());
1036 // Try to find an existing symbol table.
1038 llvm::sys::SmartScopedReader
<true> lock(mutex
);
1039 auto it
= collection
.symbolTables
.find(symbolTableOp
);
1040 if (it
!= collection
.symbolTables
.end())
1043 // Create a symbol table for the operation. Perform construction outside of
1044 // the critical section.
1045 auto symbolTable
= std::make_unique
<SymbolTable
>(symbolTableOp
);
1046 // Insert the constructed symbol table.
1047 llvm::sys::SmartScopedWriter
<true> lock(mutex
);
1048 return *collection
.symbolTables
1049 .insert({symbolTableOp
, std::move(symbolTable
)})
1053 //===----------------------------------------------------------------------===//
1055 //===----------------------------------------------------------------------===//
1057 SymbolUserMap::SymbolUserMap(SymbolTableCollection
&symbolTable
,
1058 Operation
*symbolTableOp
)
1059 : symbolTable(symbolTable
) {
1060 // Walk each of the symbol tables looking for discardable callgraph nodes.
1061 SmallVector
<Operation
*> symbols
;
1062 auto walkFn
= [&](Operation
*symbolTableOp
, bool allUsesVisible
) {
1063 for (Operation
&nestedOp
: symbolTableOp
->getRegion(0).getOps()) {
1064 auto symbolUses
= SymbolTable::getSymbolUses(&nestedOp
);
1065 assert(symbolUses
&& "expected uses to be valid");
1067 for (const SymbolTable::SymbolUse
&use
: *symbolUses
) {
1069 (void)symbolTable
.lookupSymbolIn(symbolTableOp
, use
.getSymbolRef(),
1071 for (Operation
*symbolOp
: symbols
)
1072 symbolToUsers
[symbolOp
].insert(use
.getUser());
1076 // We just set `allSymUsesVisible` to false here because it isn't necessary
1077 // for building the user map.
1078 SymbolTable::walkSymbolTables(symbolTableOp
, /*allSymUsesVisible=*/false,
1082 void SymbolUserMap::replaceAllUsesWith(Operation
*symbol
,
1083 StringAttr newSymbolName
) {
1084 auto it
= symbolToUsers
.find(symbol
);
1085 if (it
== symbolToUsers
.end())
1088 // Replace the uses within the users of `symbol`.
1089 for (Operation
*user
: it
->second
)
1090 (void)SymbolTable::replaceAllSymbolUses(symbol
, newSymbolName
, user
);
1092 // Move the current users of `symbol` to the new symbol if it is in the
1094 Operation
*newSymbol
=
1095 symbolTable
.lookupSymbolIn(symbol
->getParentOp(), newSymbolName
);
1096 if (newSymbol
!= symbol
) {
1097 // Transfer over the users to the new symbol. The reference to the old one
1098 // is fetched again as the iterator is invalidated during the insertion.
1099 auto newIt
= symbolToUsers
.try_emplace(newSymbol
, SetVector
<Operation
*>{});
1100 auto oldIt
= symbolToUsers
.find(symbol
);
1101 assert(oldIt
!= symbolToUsers
.end() && "missing old users list");
1103 newIt
.first
->second
= std::move(oldIt
->second
);
1105 newIt
.first
->second
.set_union(oldIt
->second
);
1106 symbolToUsers
.erase(oldIt
);
1110 //===----------------------------------------------------------------------===//
1111 // Visibility parsing implementation.
1112 //===----------------------------------------------------------------------===//
1114 ParseResult
impl::parseOptionalVisibilityKeyword(OpAsmParser
&parser
,
1115 NamedAttrList
&attrs
) {
1116 StringRef visibility
;
1117 if (parser
.parseOptionalKeyword(&visibility
, {"public", "private", "nested"}))
1120 StringAttr visibilityAttr
= parser
.getBuilder().getStringAttr(visibility
);
1121 attrs
.push_back(parser
.getBuilder().getNamedAttr(
1122 SymbolTable::getVisibilityAttrName(), visibilityAttr
));
1126 //===----------------------------------------------------------------------===//
1127 // Symbol Interfaces
1128 //===----------------------------------------------------------------------===//
1130 /// Include the generated symbol interfaces.
1131 #include "mlir/IR/SymbolInterfaces.cpp.inc"