[mlir] Fix resource printing in the presence of multiple dialects
[llvm-project.git] / mlir / lib / IR / SymbolTable.cpp
blob2494cb7086f0d7dcf21704bba2d21d73bcec76d7
1 //===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/SymbolTable.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/OpImplementation.h"
12 #include "llvm/ADT/SetVector.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/SmallString.h"
15 #include "llvm/ADT/StringSwitch.h"
16 #include <optional>
18 using namespace mlir;
20 /// Return true if the given operation is unknown and may potentially define a
21 /// symbol table.
22 static bool isPotentiallyUnknownSymbolTable(Operation *op) {
23 return op->getNumRegions() == 1 && !op->getDialect();
26 /// Returns the string name of the given symbol, or null if this is not a
27 /// symbol.
28 static StringAttr getNameIfSymbol(Operation *op) {
29 return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
31 static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
32 return op->getAttrOfType<StringAttr>(symbolAttrNameId);
35 /// Computes the nested symbol reference attribute for the symbol 'symbolName'
36 /// that are usable within the symbol table operations from 'symbol' as far up
37 /// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
38 /// Returns success if all references up to 'within' could be computed.
39 static LogicalResult
40 collectValidReferencesFor(Operation *symbol, StringAttr symbolName,
41 Operation *within,
42 SmallVectorImpl<SymbolRefAttr> &results) {
43 assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
44 MLIRContext *ctx = symbol->getContext();
46 auto leafRef = FlatSymbolRefAttr::get(symbolName);
47 results.push_back(leafRef);
49 // Early exit for when 'within' is the parent of 'symbol'.
50 Operation *symbolTableOp = symbol->getParentOp();
51 if (within == symbolTableOp)
52 return success();
54 // Collect references until 'symbolTableOp' reaches 'within'.
55 SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
56 StringAttr symbolNameId =
57 StringAttr::get(ctx, SymbolTable::getSymbolAttrName());
58 do {
59 // Each parent of 'symbol' should define a symbol table.
60 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
61 return failure();
62 // Each parent of 'symbol' should also be a symbol.
63 StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
64 if (!symbolTableName)
65 return failure();
66 results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
68 symbolTableOp = symbolTableOp->getParentOp();
69 if (symbolTableOp == within)
70 break;
71 nestedRefs.insert(nestedRefs.begin(),
72 FlatSymbolRefAttr::get(symbolTableName));
73 } while (true);
74 return success();
77 /// Walk all of the operations within the given set of regions, without
78 /// traversing into any nested symbol tables. Stops walking if the result of the
79 /// callback is anything other than `WalkResult::advance`.
80 static std::optional<WalkResult>
81 walkSymbolTable(MutableArrayRef<Region> regions,
82 function_ref<std::optional<WalkResult>(Operation *)> callback) {
83 SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
84 while (!worklist.empty()) {
85 for (Operation &op : worklist.pop_back_val()->getOps()) {
86 std::optional<WalkResult> result = callback(&op);
87 if (result != WalkResult::advance())
88 return result;
90 // If this op defines a new symbol table scope, we can't traverse. Any
91 // symbol references nested within 'op' are different semantically.
92 if (!op.hasTrait<OpTrait::SymbolTable>()) {
93 for (Region &region : op.getRegions())
94 worklist.push_back(&region);
98 return WalkResult::advance();
101 /// Walk all of the operations nested under, and including, the given operation,
102 /// without traversing into any nested symbol tables. Stops walking if the
103 /// result of the callback is anything other than `WalkResult::advance`.
104 static std::optional<WalkResult>
105 walkSymbolTable(Operation *op,
106 function_ref<std::optional<WalkResult>(Operation *)> callback) {
107 std::optional<WalkResult> result = callback(op);
108 if (result != WalkResult::advance() || op->hasTrait<OpTrait::SymbolTable>())
109 return result;
110 return walkSymbolTable(op->getRegions(), callback);
113 //===----------------------------------------------------------------------===//
114 // SymbolTable
115 //===----------------------------------------------------------------------===//
117 /// Build a symbol table with the symbols within the given operation.
118 SymbolTable::SymbolTable(Operation *symbolTableOp)
119 : symbolTableOp(symbolTableOp) {
120 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
121 "expected operation to have SymbolTable trait");
122 assert(symbolTableOp->getNumRegions() == 1 &&
123 "expected operation to have a single region");
124 assert(llvm::hasSingleElement(symbolTableOp->getRegion(0)) &&
125 "expected operation to have a single block");
127 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
128 SymbolTable::getSymbolAttrName());
129 for (auto &op : symbolTableOp->getRegion(0).front()) {
130 StringAttr name = getNameIfSymbol(&op, symbolNameId);
131 if (!name)
132 continue;
134 auto inserted = symbolTable.insert({name, &op});
135 (void)inserted;
136 assert(inserted.second &&
137 "expected region to contain uniquely named symbol operations");
141 /// Look up a symbol with the specified name, returning null if no such name
142 /// exists. Names never include the @ on them.
143 Operation *SymbolTable::lookup(StringRef name) const {
144 return lookup(StringAttr::get(symbolTableOp->getContext(), name));
146 Operation *SymbolTable::lookup(StringAttr name) const {
147 return symbolTable.lookup(name);
150 void SymbolTable::remove(Operation *op) {
151 StringAttr name = getNameIfSymbol(op);
152 assert(name && "expected valid 'name' attribute");
153 assert(op->getParentOp() == symbolTableOp &&
154 "expected this operation to be inside of the operation with this "
155 "SymbolTable");
157 auto it = symbolTable.find(name);
158 if (it != symbolTable.end() && it->second == op)
159 symbolTable.erase(it);
162 void SymbolTable::erase(Operation *symbol) {
163 remove(symbol);
164 symbol->erase();
167 // TODO: Consider if this should be renamed to something like insertOrUpdate
168 /// Insert a new symbol into the table and associated operation if not already
169 /// there and rename it as necessary to avoid collisions. Return the name of
170 /// the symbol after insertion as attribute.
171 StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
172 // The symbol cannot be the child of another op and must be the child of the
173 // symbolTableOp after this.
175 // TODO: consider if SymbolTable's constructor should behave the same.
176 if (!symbol->getParentOp()) {
177 auto &body = symbolTableOp->getRegion(0).front();
178 if (insertPt == Block::iterator()) {
179 insertPt = Block::iterator(body.end());
180 } else {
181 assert((insertPt == body.end() ||
182 insertPt->getParentOp() == symbolTableOp) &&
183 "expected insertPt to be in the associated module operation");
185 // Insert before the terminator, if any.
186 if (insertPt == Block::iterator(body.end()) && !body.empty() &&
187 std::prev(body.end())->hasTrait<OpTrait::IsTerminator>())
188 insertPt = std::prev(body.end());
190 body.getOperations().insert(insertPt, symbol);
192 assert(symbol->getParentOp() == symbolTableOp &&
193 "symbol is already inserted in another op");
195 // Add this symbol to the symbol table, uniquing the name if a conflict is
196 // detected.
197 StringAttr name = getSymbolName(symbol);
198 if (symbolTable.insert({name, symbol}).second)
199 return name;
200 // If the symbol was already in the table, also return.
201 if (symbolTable.lookup(name) == symbol)
202 return name;
203 // If a conflict was detected, then the symbol will not have been added to
204 // the symbol table. Try suffixes until we get to a unique name that works.
205 SmallString<128> nameBuffer(name.getValue());
206 unsigned originalLength = nameBuffer.size();
208 MLIRContext *context = symbol->getContext();
210 // Iteratively try suffixes until we find one that isn't used.
211 do {
212 nameBuffer.resize(originalLength);
213 nameBuffer += '_';
214 nameBuffer += std::to_string(uniquingCounter++);
215 } while (!symbolTable.insert({StringAttr::get(context, nameBuffer), symbol})
216 .second);
217 setSymbolName(symbol, nameBuffer);
218 return getSymbolName(symbol);
221 /// Returns the name of the given symbol operation.
222 StringAttr SymbolTable::getSymbolName(Operation *symbol) {
223 StringAttr name = getNameIfSymbol(symbol);
224 assert(name && "expected valid symbol name");
225 return name;
228 /// Sets the name of the given symbol operation.
229 void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) {
230 symbol->setAttr(getSymbolAttrName(), name);
233 /// Returns the visibility of the given symbol operation.
234 SymbolTable::Visibility SymbolTable::getSymbolVisibility(Operation *symbol) {
235 // If the attribute doesn't exist, assume public.
236 StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
237 if (!vis)
238 return Visibility::Public;
240 // Otherwise, switch on the string value.
241 return StringSwitch<Visibility>(vis.getValue())
242 .Case("private", Visibility::Private)
243 .Case("nested", Visibility::Nested)
244 .Case("public", Visibility::Public);
246 /// Sets the visibility of the given symbol operation.
247 void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
248 MLIRContext *ctx = symbol->getContext();
250 // If the visibility is public, just drop the attribute as this is the
251 // default.
252 if (vis == Visibility::Public) {
253 symbol->removeAttr(StringAttr::get(ctx, getVisibilityAttrName()));
254 return;
257 // Otherwise, update the attribute.
258 assert((vis == Visibility::Private || vis == Visibility::Nested) &&
259 "unknown symbol visibility kind");
261 StringRef visName = vis == Visibility::Private ? "private" : "nested";
262 symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
265 /// Returns the nearest symbol table from a given operation `from`. Returns
266 /// nullptr if no valid parent symbol table could be found.
267 Operation *SymbolTable::getNearestSymbolTable(Operation *from) {
268 assert(from && "expected valid operation");
269 if (isPotentiallyUnknownSymbolTable(from))
270 return nullptr;
272 while (!from->hasTrait<OpTrait::SymbolTable>()) {
273 from = from->getParentOp();
275 // Check that this is a valid op and isn't an unknown symbol table.
276 if (!from || isPotentiallyUnknownSymbolTable(from))
277 return nullptr;
279 return from;
282 /// Walks all symbol table operations nested within, and including, `op`. For
283 /// each symbol table operation, the provided callback is invoked with the op
284 /// and a boolean signifying if the symbols within that symbol table can be
285 /// treated as if all uses are visible. `allSymUsesVisible` identifies whether
286 /// all of the symbol uses of symbols within `op` are visible.
287 void SymbolTable::walkSymbolTables(
288 Operation *op, bool allSymUsesVisible,
289 function_ref<void(Operation *, bool)> callback) {
290 bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
291 if (isSymbolTable) {
292 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
293 allSymUsesVisible |= !symbol || symbol.isPrivate();
294 } else {
295 // Otherwise if 'op' is not a symbol table, any nested symbols are
296 // guaranteed to be hidden.
297 allSymUsesVisible = true;
300 for (Region &region : op->getRegions())
301 for (Block &block : region)
302 for (Operation &nestedOp : block)
303 walkSymbolTables(&nestedOp, allSymUsesVisible, callback);
305 // If 'op' had the symbol table trait, visit it after any nested symbol
306 // tables.
307 if (isSymbolTable)
308 callback(op, allSymUsesVisible);
311 /// Returns the operation registered with the given symbol name with the
312 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
313 /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
314 /// was found.
315 Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
316 StringAttr symbol) {
317 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
318 Region &region = symbolTableOp->getRegion(0);
319 if (region.empty())
320 return nullptr;
322 // Look for a symbol with the given name.
323 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
324 SymbolTable::getSymbolAttrName());
325 for (auto &op : region.front())
326 if (getNameIfSymbol(&op, symbolNameId) == symbol)
327 return &op;
328 return nullptr;
330 Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
331 SymbolRefAttr symbol) {
332 SmallVector<Operation *, 4> resolvedSymbols;
333 if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols)))
334 return nullptr;
335 return resolvedSymbols.back();
338 /// Internal implementation of `lookupSymbolIn` that allows for specialized
339 /// implementations of the lookup function.
340 static LogicalResult lookupSymbolInImpl(
341 Operation *symbolTableOp, SymbolRefAttr symbol,
342 SmallVectorImpl<Operation *> &symbols,
343 function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) {
344 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
346 // Lookup the root reference for this symbol.
347 symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
348 if (!symbolTableOp)
349 return failure();
350 symbols.push_back(symbolTableOp);
352 // If there are no nested references, just return the root symbol directly.
353 ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
354 if (nestedRefs.empty())
355 return success();
357 // Verify that the root is also a symbol table.
358 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
359 return failure();
361 // Otherwise, lookup each of the nested non-leaf references and ensure that
362 // each corresponds to a valid symbol table.
363 for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
364 symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr());
365 if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
366 return failure();
367 symbols.push_back(symbolTableOp);
369 symbols.push_back(lookupSymbolFn(symbolTableOp, symbol.getLeafReference()));
370 return success(symbols.back());
373 LogicalResult
374 SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
375 SmallVectorImpl<Operation *> &symbols) {
376 auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) {
377 return lookupSymbolIn(symbolTableOp, symbol);
379 return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
382 /// Returns the operation registered with the given symbol name within the
383 /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
384 /// nullptr if no valid symbol was found.
385 Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
386 StringAttr symbol) {
387 Operation *symbolTableOp = getNearestSymbolTable(from);
388 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
390 Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
391 SymbolRefAttr symbol) {
392 Operation *symbolTableOp = getNearestSymbolTable(from);
393 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
396 raw_ostream &mlir::operator<<(raw_ostream &os,
397 SymbolTable::Visibility visibility) {
398 switch (visibility) {
399 case SymbolTable::Visibility::Public:
400 return os << "public";
401 case SymbolTable::Visibility::Private:
402 return os << "private";
403 case SymbolTable::Visibility::Nested:
404 return os << "nested";
406 llvm_unreachable("Unexpected visibility");
409 //===----------------------------------------------------------------------===//
410 // SymbolTable Trait Types
411 //===----------------------------------------------------------------------===//
413 LogicalResult detail::verifySymbolTable(Operation *op) {
414 if (op->getNumRegions() != 1)
415 return op->emitOpError()
416 << "Operations with a 'SymbolTable' must have exactly one region";
417 if (!llvm::hasSingleElement(op->getRegion(0)))
418 return op->emitOpError()
419 << "Operations with a 'SymbolTable' must have exactly one block";
421 // Check that all symbols are uniquely named within child regions.
422 DenseMap<Attribute, Location> nameToOrigLoc;
423 for (auto &block : op->getRegion(0)) {
424 for (auto &op : block) {
425 // Check for a symbol name attribute.
426 auto nameAttr =
427 op.getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName());
428 if (!nameAttr)
429 continue;
431 // Try to insert this symbol into the table.
432 auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc());
433 if (!it.second)
434 return op.emitError()
435 .append("redefinition of symbol named '", nameAttr.getValue(), "'")
436 .attachNote(it.first->second)
437 .append("see existing symbol definition here");
441 // Verify any nested symbol user operations.
442 SymbolTableCollection symbolTable;
443 auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> {
444 if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
445 return WalkResult(user.verifySymbolUses(symbolTable));
446 return WalkResult::advance();
449 std::optional<WalkResult> result =
450 walkSymbolTable(op->getRegions(), verifySymbolUserFn);
451 return success(result && !result->wasInterrupted());
454 LogicalResult detail::verifySymbol(Operation *op) {
455 // Verify the name attribute.
456 if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
457 return op->emitOpError() << "requires string attribute '"
458 << mlir::SymbolTable::getSymbolAttrName() << "'";
460 // Verify the visibility attribute.
461 if (Attribute vis = op->getAttr(mlir::SymbolTable::getVisibilityAttrName())) {
462 StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis);
463 if (!visStrAttr)
464 return op->emitOpError() << "requires visibility attribute '"
465 << mlir::SymbolTable::getVisibilityAttrName()
466 << "' to be a string attribute, but got " << vis;
468 if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"},
469 visStrAttr.getValue()))
470 return op->emitOpError()
471 << "visibility expected to be one of [\"public\", \"private\", "
472 "\"nested\"], but got "
473 << visStrAttr;
475 return success();
478 //===----------------------------------------------------------------------===//
479 // Symbol Use Lists
480 //===----------------------------------------------------------------------===//
482 /// Walk all of the symbol references within the given operation, invoking the
483 /// provided callback for each found use. The callbacks takes the use of the
484 /// symbol.
485 static WalkResult
486 walkSymbolRefs(Operation *op,
487 function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
488 return op->getAttrDictionary().walk<WalkOrder::PreOrder>(
489 [&](SymbolRefAttr symbolRef) {
490 if (callback({op, symbolRef}).wasInterrupted())
491 return WalkResult::interrupt();
493 // Don't walk nested references.
494 return WalkResult::skip();
498 /// Walk all of the uses, for any symbol, that are nested within the given
499 /// regions, invoking the provided callback for each. This does not traverse
500 /// into any nested symbol tables.
501 static std::optional<WalkResult>
502 walkSymbolUses(MutableArrayRef<Region> regions,
503 function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
504 return walkSymbolTable(regions,
505 [&](Operation *op) -> std::optional<WalkResult> {
506 // Check that this isn't a potentially unknown symbol
507 // table.
508 if (isPotentiallyUnknownSymbolTable(op))
509 return std::nullopt;
511 return walkSymbolRefs(op, callback);
514 /// Walk all of the uses, for any symbol, that are nested within the given
515 /// operation 'from', invoking the provided callback for each. This does not
516 /// traverse into any nested symbol tables.
517 static std::optional<WalkResult>
518 walkSymbolUses(Operation *from,
519 function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
520 // If this operation has regions, and it, as well as its dialect, isn't
521 // registered then conservatively fail. The operation may define a
522 // symbol table, so we can't opaquely know if we should traverse to find
523 // nested uses.
524 if (isPotentiallyUnknownSymbolTable(from))
525 return std::nullopt;
527 // Walk the uses on this operation.
528 if (walkSymbolRefs(from, callback).wasInterrupted())
529 return WalkResult::interrupt();
531 // Only recurse if this operation is not a symbol table. A symbol table
532 // defines a new scope, so we can't walk the attributes from within the symbol
533 // table op.
534 if (!from->hasTrait<OpTrait::SymbolTable>())
535 return walkSymbolUses(from->getRegions(), callback);
536 return WalkResult::advance();
539 namespace {
540 /// This class represents a single symbol scope. A symbol scope represents the
541 /// set of operations nested within a symbol table that may reference symbols
542 /// within that table. A symbol scope does not contain the symbol table
543 /// operation itself, just its contained operations. A scope ends at leaf
544 /// operations or another symbol table operation.
545 struct SymbolScope {
546 /// Walk the symbol uses within this scope, invoking the given callback.
547 /// This variant is used when the callback type matches that expected by
548 /// 'walkSymbolUses'.
549 template <typename CallbackT,
550 std::enable_if_t<!std::is_same<
551 typename llvm::function_traits<CallbackT>::result_t,
552 void>::value> * = nullptr>
553 std::optional<WalkResult> walk(CallbackT cback) {
554 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
555 return walkSymbolUses(*region, cback);
556 return walkSymbolUses(limit.get<Operation *>(), cback);
558 /// This variant is used when the callback type matches a stripped down type:
559 /// void(SymbolTable::SymbolUse use)
560 template <typename CallbackT,
561 std::enable_if_t<std::is_same<
562 typename llvm::function_traits<CallbackT>::result_t,
563 void>::value> * = nullptr>
564 std::optional<WalkResult> walk(CallbackT cback) {
565 return walk([=](SymbolTable::SymbolUse use) {
566 return cback(use), WalkResult::advance();
570 /// Walk all of the operations nested under the current scope without
571 /// traversing into any nested symbol tables.
572 template <typename CallbackT>
573 std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
574 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
575 return ::walkSymbolTable(*region, cback);
576 return ::walkSymbolTable(limit.get<Operation *>(), cback);
579 /// The representation of the symbol within this scope.
580 SymbolRefAttr symbol;
582 /// The IR unit representing this scope.
583 llvm::PointerUnion<Operation *, Region *> limit;
585 } // namespace
587 /// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
588 static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
589 Operation *limit) {
590 StringAttr symName = SymbolTable::getSymbolName(symbol);
591 assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
593 // Compute the ancestors of 'limit'.
594 SetVector<Operation *, SmallVector<Operation *, 4>,
595 SmallPtrSet<Operation *, 4>>
596 limitAncestors;
597 Operation *limitAncestor = limit;
598 do {
599 // Check to see if 'symbol' is an ancestor of 'limit'.
600 if (limitAncestor == symbol) {
601 // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
602 // doesn't support parent references.
603 if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
604 symbol->getParentOp())
605 return {{SymbolRefAttr::get(symName), limit}};
606 return {};
609 limitAncestors.insert(limitAncestor);
610 } while ((limitAncestor = limitAncestor->getParentOp()));
612 // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
613 Operation *commonAncestor = symbol->getParentOp();
614 do {
615 if (limitAncestors.count(commonAncestor))
616 break;
617 } while ((commonAncestor = commonAncestor->getParentOp()));
618 assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");
620 // Compute the set of valid nested references for 'symbol' as far up to the
621 // common ancestor as possible.
622 SmallVector<SymbolRefAttr, 2> references;
623 bool collectedAllReferences = succeeded(
624 collectValidReferencesFor(symbol, symName, commonAncestor, references));
626 // Handle the case where the common ancestor is 'limit'.
627 if (commonAncestor == limit) {
628 SmallVector<SymbolScope, 2> scopes;
630 // Walk each of the ancestors of 'symbol', calling the compute function for
631 // each one.
632 Operation *limitIt = symbol->getParentOp();
633 for (size_t i = 0, e = references.size(); i != e;
634 ++i, limitIt = limitIt->getParentOp()) {
635 assert(limitIt->hasTrait<OpTrait::SymbolTable>());
636 scopes.push_back({references[i], &limitIt->getRegion(0)});
638 return scopes;
641 // Otherwise, we just need the symbol reference for 'symbol' that will be
642 // used within 'limit'. This is the last reference in the list we computed
643 // above if we were able to collect all references.
644 if (!collectedAllReferences)
645 return {};
646 return {{references.back(), limit}};
648 static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
649 Region *limit) {
650 auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
652 // If we collected some scopes to walk, make sure to constrain the one for
653 // limit to the specific region requested.
654 if (!scopes.empty())
655 scopes.back().limit = limit;
656 return scopes;
658 template <typename IRUnit>
659 static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
660 IRUnit *limit) {
661 return {{SymbolRefAttr::get(symbol), limit}};
664 /// Returns true if the given reference 'SubRef' is a sub reference of the
665 /// reference 'ref', i.e. 'ref' is a further qualified reference.
666 static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
667 if (ref == subRef)
668 return true;
670 // If the references are not pointer equal, check to see if `subRef` is a
671 // prefix of `ref`.
672 if (llvm::isa<FlatSymbolRefAttr>(ref) ||
673 ref.getRootReference() != subRef.getRootReference())
674 return false;
676 auto refLeafs = ref.getNestedReferences();
677 auto subRefLeafs = subRef.getNestedReferences();
678 return subRefLeafs.size() < refLeafs.size() &&
679 subRefLeafs == refLeafs.take_front(subRefLeafs.size());
682 //===----------------------------------------------------------------------===//
683 // SymbolTable::getSymbolUses
685 /// The implementation of SymbolTable::getSymbolUses below.
686 template <typename FromT>
687 static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
688 std::vector<SymbolTable::SymbolUse> uses;
689 auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
690 uses.push_back(symbolUse);
691 return WalkResult::advance();
693 auto result = walkSymbolUses(from, walkFn);
694 return result ? std::optional<SymbolTable::UseRange>(std::move(uses))
695 : std::nullopt;
698 /// Get an iterator range for all of the uses, for any symbol, that are nested
699 /// within the given operation 'from'. This does not traverse into any nested
700 /// symbol tables, and will also only return uses on 'from' if it does not
701 /// also define a symbol table. This is because we treat the region as the
702 /// boundary of the symbol table, and not the op itself. This function returns
703 /// std::nullopt if there are any unknown operations that may potentially be
704 /// symbol tables.
705 auto SymbolTable::getSymbolUses(Operation *from) -> std::optional<UseRange> {
706 return getSymbolUsesImpl(from);
708 auto SymbolTable::getSymbolUses(Region *from) -> std::optional<UseRange> {
709 return getSymbolUsesImpl(MutableArrayRef<Region>(*from));
712 //===----------------------------------------------------------------------===//
713 // SymbolTable::getSymbolUses
715 /// The implementation of SymbolTable::getSymbolUses below.
716 template <typename SymbolT, typename IRUnitT>
717 static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
718 IRUnitT *limit) {
719 std::vector<SymbolTable::SymbolUse> uses;
720 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
721 if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
722 if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
723 uses.push_back(symbolUse);
725 return std::nullopt;
727 return SymbolTable::UseRange(std::move(uses));
730 /// Get all of the uses of the given symbol that are nested within the given
731 /// operation 'from', invoking the provided callback for each. This does not
732 /// traverse into any nested symbol tables. This function returns std::nullopt
733 /// if there are any unknown operations that may potentially be symbol tables.
734 auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from)
735 -> std::optional<UseRange> {
736 return getSymbolUsesImpl(symbol, from);
738 auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from)
739 -> std::optional<UseRange> {
740 return getSymbolUsesImpl(symbol, from);
742 auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from)
743 -> std::optional<UseRange> {
744 return getSymbolUsesImpl(symbol, from);
746 auto SymbolTable::getSymbolUses(Operation *symbol, Region *from)
747 -> std::optional<UseRange> {
748 return getSymbolUsesImpl(symbol, from);
751 //===----------------------------------------------------------------------===//
752 // SymbolTable::symbolKnownUseEmpty
754 /// The implementation of SymbolTable::symbolKnownUseEmpty below.
755 template <typename SymbolT, typename IRUnitT>
756 static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
757 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
758 // Walk all of the symbol uses looking for a reference to 'symbol'.
759 if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
760 return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
761 ? WalkResult::interrupt()
762 : WalkResult::advance();
763 }) != WalkResult::advance())
764 return false;
766 return true;
769 /// Return if the given symbol is known to have no uses that are nested within
770 /// the given operation 'from'. This does not traverse into any nested symbol
771 /// tables. This function will also return false if there are any unknown
772 /// operations that may potentially be symbol tables.
773 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
774 return symbolKnownUseEmptyImpl(symbol, from);
776 bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) {
777 return symbolKnownUseEmptyImpl(symbol, from);
779 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
780 return symbolKnownUseEmptyImpl(symbol, from);
782 bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
783 return symbolKnownUseEmptyImpl(symbol, from);
786 //===----------------------------------------------------------------------===//
787 // SymbolTable::replaceAllSymbolUses
789 /// Generates a new symbol reference attribute with a new leaf reference.
790 static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
791 FlatSymbolRefAttr newLeafAttr) {
792 if (llvm::isa<FlatSymbolRefAttr>(oldAttr))
793 return newLeafAttr;
794 auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
795 nestedRefs.back() = newLeafAttr;
796 return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs);
799 /// The implementation of SymbolTable::replaceAllSymbolUses below.
800 template <typename SymbolT, typename IRUnitT>
801 static LogicalResult
802 replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
803 // Generate a new attribute to replace the given attribute.
804 FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
805 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
806 SymbolRefAttr oldAttr = scope.symbol;
807 SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
808 AttrTypeReplacer replacer;
809 replacer.addReplacement(
810 [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
811 // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
812 // want to accidentally replace an inner reference.
813 if (attr == oldAttr)
814 return {newAttr, WalkResult::skip()};
815 // Handle prefix matches.
816 if (isReferencePrefixOf(oldAttr, attr)) {
817 auto oldNestedRefs = oldAttr.getNestedReferences();
818 auto nestedRefs = attr.getNestedReferences();
819 if (oldNestedRefs.empty())
820 return {SymbolRefAttr::get(newSymbol, nestedRefs),
821 WalkResult::skip()};
823 auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
824 newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
825 return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs),
826 WalkResult::skip()};
828 return {attr, WalkResult::skip()};
831 auto walkFn = [&](Operation *op) -> std::optional<WalkResult> {
832 replacer.replaceElementsIn(op);
833 return WalkResult::advance();
835 if (!scope.walkSymbolTable(walkFn))
836 return failure();
838 return success();
841 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
842 /// provided symbol 'newSymbol' that are nested within the given operation
843 /// 'from'. This does not traverse into any nested symbol tables. If there are
844 /// any unknown operations that may potentially be symbol tables, no uses are
845 /// replaced and failure is returned.
846 LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
847 StringAttr newSymbol,
848 Operation *from) {
849 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
851 LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
852 StringAttr newSymbol,
853 Operation *from) {
854 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
856 LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
857 StringAttr newSymbol,
858 Region *from) {
859 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
861 LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
862 StringAttr newSymbol,
863 Region *from) {
864 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
867 //===----------------------------------------------------------------------===//
868 // SymbolTableCollection
869 //===----------------------------------------------------------------------===//
871 Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
872 StringAttr symbol) {
873 return getSymbolTable(symbolTableOp).lookup(symbol);
875 Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
876 SymbolRefAttr name) {
877 SmallVector<Operation *, 4> symbols;
878 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
879 return nullptr;
880 return symbols.back();
882 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
883 /// a given SymbolRefAttr. Returns failure if any of the nested references could
884 /// not be resolved.
885 LogicalResult
886 SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
887 SymbolRefAttr name,
888 SmallVectorImpl<Operation *> &symbols) {
889 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
890 return lookupSymbolIn(symbolTableOp, symbol);
892 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
895 /// Returns the operation registered with the given symbol name within the
896 /// closest parent operation of, or including, 'from' with the
897 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
898 /// found.
899 Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
900 StringAttr symbol) {
901 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
902 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
904 Operation *
905 SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
906 SymbolRefAttr symbol) {
907 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
908 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
911 /// Lookup, or create, a symbol table for an operation.
912 SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
913 auto it = symbolTables.try_emplace(op, nullptr);
914 if (it.second)
915 it.first->second = std::make_unique<SymbolTable>(op);
916 return *it.first->second;
919 //===----------------------------------------------------------------------===//
920 // LockedSymbolTableCollection
921 //===----------------------------------------------------------------------===//
923 Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
924 StringAttr symbol) {
925 return getSymbolTable(symbolTableOp).lookup(symbol);
928 Operation *
929 LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
930 FlatSymbolRefAttr symbol) {
931 return lookupSymbolIn(symbolTableOp, symbol.getAttr());
934 Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
935 SymbolRefAttr name) {
936 SmallVector<Operation *> symbols;
937 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
938 return nullptr;
939 return symbols.back();
942 LogicalResult LockedSymbolTableCollection::lookupSymbolIn(
943 Operation *symbolTableOp, SymbolRefAttr name,
944 SmallVectorImpl<Operation *> &symbols) {
945 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
946 return lookupSymbolIn(symbolTableOp, symbol);
948 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
951 SymbolTable &
952 LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) {
953 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
954 // Try to find an existing symbol table.
956 llvm::sys::SmartScopedReader<true> lock(mutex);
957 auto it = collection.symbolTables.find(symbolTableOp);
958 if (it != collection.symbolTables.end())
959 return *it->second;
961 // Create a symbol table for the operation. Perform construction outside of
962 // the critical section.
963 auto symbolTable = std::make_unique<SymbolTable>(symbolTableOp);
964 // Insert the constructed symbol table.
965 llvm::sys::SmartScopedWriter<true> lock(mutex);
966 return *collection.symbolTables
967 .insert({symbolTableOp, std::move(symbolTable)})
968 .first->second;
971 //===----------------------------------------------------------------------===//
972 // SymbolUserMap
973 //===----------------------------------------------------------------------===//
975 SymbolUserMap::SymbolUserMap(SymbolTableCollection &symbolTable,
976 Operation *symbolTableOp)
977 : symbolTable(symbolTable) {
978 // Walk each of the symbol tables looking for discardable callgraph nodes.
979 SmallVector<Operation *> symbols;
980 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
981 for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
982 auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
983 assert(symbolUses && "expected uses to be valid");
985 for (const SymbolTable::SymbolUse &use : *symbolUses) {
986 symbols.clear();
987 (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
988 symbols);
989 for (Operation *symbolOp : symbols)
990 symbolToUsers[symbolOp].insert(use.getUser());
994 // We just set `allSymUsesVisible` to false here because it isn't necessary
995 // for building the user map.
996 SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
997 walkFn);
1000 void SymbolUserMap::replaceAllUsesWith(Operation *symbol,
1001 StringAttr newSymbolName) {
1002 auto it = symbolToUsers.find(symbol);
1003 if (it == symbolToUsers.end())
1004 return;
1006 // Replace the uses within the users of `symbol`.
1007 for (Operation *user : it->second)
1008 (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
1010 // Move the current users of `symbol` to the new symbol if it is in the
1011 // symbol table.
1012 Operation *newSymbol =
1013 symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
1014 if (newSymbol != symbol) {
1015 // Transfer over the users to the new symbol. The reference to the old one
1016 // is fetched again as the iterator is invalidated during the insertion.
1017 auto newIt = symbolToUsers.try_emplace(newSymbol, SetVector<Operation *>{});
1018 auto oldIt = symbolToUsers.find(symbol);
1019 assert(oldIt != symbolToUsers.end() && "missing old users list");
1020 if (newIt.second)
1021 newIt.first->second = std::move(oldIt->second);
1022 else
1023 newIt.first->second.set_union(oldIt->second);
1024 symbolToUsers.erase(oldIt);
1028 //===----------------------------------------------------------------------===//
1029 // Visibility parsing implementation.
1030 //===----------------------------------------------------------------------===//
1032 ParseResult impl::parseOptionalVisibilityKeyword(OpAsmParser &parser,
1033 NamedAttrList &attrs) {
1034 StringRef visibility;
1035 if (parser.parseOptionalKeyword(&visibility, {"public", "private", "nested"}))
1036 return failure();
1038 StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility);
1039 attrs.push_back(parser.getBuilder().getNamedAttr(
1040 SymbolTable::getVisibilityAttrName(), visibilityAttr));
1041 return success();
1044 //===----------------------------------------------------------------------===//
1045 // Symbol Interfaces
1046 //===----------------------------------------------------------------------===//
1048 /// Include the generated symbol interfaces.
1049 #include "mlir/IR/SymbolInterfaces.cpp.inc"