[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / lib / IR / SymbolTable.cpp
blob 7180ea432ea057d4828f522b0554abf5cce25291
1 //===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/SymbolTable.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/OpImplementation.h"
12 #include "llvm/ADT/SetVector.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/SmallString.h"
15 #include "llvm/ADT/StringSwitch.h"
16 #include <optional>
18 using namespace mlir;
20 /// Return true if the given operation is unknown and may potentially define a
21 /// symbol table.
22 static bool isPotentiallyUnknownSymbolTable(Operation *op) {
23 return op->getNumRegions() == 1 && !op->getDialect();
26 /// Returns the string name of the given symbol, or null if this is not a
27 /// symbol.
28 static StringAttr getNameIfSymbol(Operation *op) {
29 return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
31 static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
32 return op->getAttrOfType<StringAttr>(symbolAttrNameId);
35 /// Computes the nested symbol reference attribute for the symbol 'symbolName'
36 /// that are usable within the symbol table operations from 'symbol' as far up
37 /// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
38 /// Returns success if all references up to 'within' could be computed.
39 static LogicalResult
40 collectValidReferencesFor(Operation *symbol, StringAttr symbolName,
41 Operation *within,
42 SmallVectorImpl<SymbolRefAttr> &results) {
43 assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
44 MLIRContext *ctx = symbol->getContext();
46 auto leafRef = FlatSymbolRefAttr::get(symbolName);
47 results.push_back(leafRef);
49 // Early exit for when 'within' is the parent of 'symbol'.
50 Operation *symbolTableOp = symbol->getParentOp();
51 if (within == symbolTableOp)
52 return success();
54 // Collect references until 'symbolTableOp' reaches 'within'.
55 SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
56 StringAttr symbolNameId =
57 StringAttr::get(ctx, SymbolTable::getSymbolAttrName());
58 do {
59 // Each parent of 'symbol' should define a symbol table.
60 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
61 return failure();
62 // Each parent of 'symbol' should also be a symbol.
63 StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
64 if (!symbolTableName)
65 return failure();
66 results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
68 symbolTableOp = symbolTableOp->getParentOp();
69 if (symbolTableOp == within)
70 break;
71 nestedRefs.insert(nestedRefs.begin(),
72 FlatSymbolRefAttr::get(symbolTableName));
73 } while (true);
74 return success();
77 /// Walk all of the operations within the given set of regions, without
78 /// traversing into any nested symbol tables. Stops walking if the result of the
79 /// callback is anything other than `WalkResult::advance`.
80 static std::optional<WalkResult>
81 walkSymbolTable(MutableArrayRef<Region> regions,
82 function_ref<std::optional<WalkResult>(Operation *)> callback) {
83 SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
84 while (!worklist.empty()) {
85 for (Operation &op : worklist.pop_back_val()->getOps()) {
86 std::optional<WalkResult> result = callback(&op);
87 if (result != WalkResult::advance())
88 return result;
90 // If this op defines a new symbol table scope, we can't traverse. Any
91 // symbol references nested within 'op' are different semantically.
92 if (!op.hasTrait<OpTrait::SymbolTable>()) {
93 for (Region &region : op.getRegions())
94 worklist.push_back(&region);
98 return WalkResult::advance();
101 /// Walk all of the operations nested under, and including, the given operation,
102 /// without traversing into any nested symbol tables. Stops walking if the
103 /// result of the callback is anything other than `WalkResult::advance`.
104 static std::optional<WalkResult>
105 walkSymbolTable(Operation *op,
106 function_ref<std::optional<WalkResult>(Operation *)> callback) {
107 std::optional<WalkResult> result = callback(op);
108 if (result != WalkResult::advance() || op->hasTrait<OpTrait::SymbolTable>())
109 return result;
110 return walkSymbolTable(op->getRegions(), callback);
113 //===----------------------------------------------------------------------===//
114 // SymbolTable
115 //===----------------------------------------------------------------------===//
117 /// Build a symbol table with the symbols within the given operation.
118 SymbolTable::SymbolTable(Operation *symbolTableOp)
119 : symbolTableOp(symbolTableOp) {
120 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
121 "expected operation to have SymbolTable trait");
122 assert(symbolTableOp->getNumRegions() == 1 &&
123 "expected operation to have a single region");
124 assert(llvm::hasSingleElement(symbolTableOp->getRegion(0)) &&
125 "expected operation to have a single block");
127 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
128 SymbolTable::getSymbolAttrName());
129 for (auto &op : symbolTableOp->getRegion(0).front()) {
130 StringAttr name = getNameIfSymbol(&op, symbolNameId);
131 if (!name)
132 continue;
134 auto inserted = symbolTable.insert({name, &op});
135 (void)inserted;
136 assert(inserted.second &&
137 "expected region to contain uniquely named symbol operations");
141 /// Look up a symbol with the specified name, returning null if no such name
142 /// exists. Names never include the @ on them.
143 Operation *SymbolTable::lookup(StringRef name) const {
144 return lookup(StringAttr::get(symbolTableOp->getContext(), name));
146 Operation *SymbolTable::lookup(StringAttr name) const {
147 return symbolTable.lookup(name);
150 void SymbolTable::remove(Operation *op) {
151 StringAttr name = getNameIfSymbol(op);
152 assert(name && "expected valid 'name' attribute");
153 assert(op->getParentOp() == symbolTableOp &&
154 "expected this operation to be inside of the operation with this "
155 "SymbolTable");
157 auto it = symbolTable.find(name);
158 if (it != symbolTable.end() && it->second == op)
159 symbolTable.erase(it);
162 void SymbolTable::erase(Operation *symbol) {
163 remove(symbol);
164 symbol->erase();
167 // TODO: Consider if this should be renamed to something like insertOrUpdate
168 /// Insert a new symbol into the table and associated operation if not already
169 /// there and rename it as necessary to avoid collisions. Return the name of
170 /// the symbol after insertion as attribute.
171 StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
172 // The symbol cannot be the child of another op and must be the child of the
173 // symbolTableOp after this.
175 // TODO: consider if SymbolTable's constructor should behave the same.
176 if (!symbol->getParentOp()) {
177 auto &body = symbolTableOp->getRegion(0).front();
178 if (insertPt == Block::iterator()) {
179 insertPt = Block::iterator(body.end());
180 } else {
181 assert((insertPt == body.end() ||
182 insertPt->getParentOp() == symbolTableOp) &&
183 "expected insertPt to be in the associated module operation");
185 // Insert before the terminator, if any.
186 if (insertPt == Block::iterator(body.end()) && !body.empty() &&
187 std::prev(body.end())->hasTrait<OpTrait::IsTerminator>())
188 insertPt = std::prev(body.end());
190 body.getOperations().insert(insertPt, symbol);
192 assert(symbol->getParentOp() == symbolTableOp &&
193 "symbol is already inserted in another op");
195 // Add this symbol to the symbol table, uniquing the name if a conflict is
196 // detected.
197 StringAttr name = getSymbolName(symbol);
198 if (symbolTable.insert({name, symbol}).second)
199 return name;
200 // If the symbol was already in the table, also return.
201 if (symbolTable.lookup(name) == symbol)
202 return name;
203 // If a conflict was detected, then the symbol will not have been added to
204 // the symbol table. Try suffixes until we get to a unique name that works.
205 SmallString<128> nameBuffer(name.getValue());
206 unsigned originalLength = nameBuffer.size();
208 MLIRContext *context = symbol->getContext();
210 // Iteratively try suffixes until we find one that isn't used.
211 do {
212 nameBuffer.resize(originalLength);
213 nameBuffer += '_';
214 nameBuffer += std::to_string(uniquingCounter++);
215 } while (!symbolTable.insert({StringAttr::get(context, nameBuffer), symbol})
216 .second);
217 setSymbolName(symbol, nameBuffer);
218 return getSymbolName(symbol);
221 LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) {
222 Operation *op = lookup(from);
223 return rename(op, to);
226 LogicalResult SymbolTable::rename(Operation *op, StringAttr to) {
227 StringAttr from = getNameIfSymbol(op);
228 (void)from;
230 assert(from && "expected valid 'name' attribute");
231 assert(op->getParentOp() == symbolTableOp &&
232 "expected this operation to be inside of the operation with this "
233 "SymbolTable");
234 assert(lookup(from) == op && "current name does not resolve to op");
235 assert(lookup(to) == nullptr && "new name already exists");
237 if (failed(SymbolTable::replaceAllSymbolUses(op, to, getOp())))
238 return failure();
240 // Remove op with old name, change name, add with new name. The order is
241 // important here due to how `remove` and `insert` rely on the op name.
242 remove(op);
243 setSymbolName(op, to);
244 insert(op);
246 assert(lookup(to) == op && "new name does not resolve to renamed op");
247 assert(lookup(from) == nullptr && "old name still exists");
249 return success();
252 LogicalResult SymbolTable::rename(StringAttr from, StringRef to) {
253 auto toAttr = StringAttr::get(getOp()->getContext(), to);
254 return rename(from, toAttr);
257 LogicalResult SymbolTable::rename(Operation *op, StringRef to) {
258 auto toAttr = StringAttr::get(getOp()->getContext(), to);
259 return rename(op, toAttr);
262 FailureOr<StringAttr>
263 SymbolTable::renameToUnique(StringAttr oldName,
264 ArrayRef<SymbolTable *> others) {
266 // Determine new name that is unique in all symbol tables.
267 StringAttr newName;
269 MLIRContext *context = oldName.getContext();
270 SmallString<64> prefix = oldName.getValue();
271 int uniqueId = 0;
272 prefix.push_back('_');
273 while (true) {
274 newName = StringAttr::get(context, prefix + Twine(uniqueId++));
275 auto lookupNewName = [&](SymbolTable *st) { return st->lookup(newName); };
276 if (!lookupNewName(this) && llvm::none_of(others, lookupNewName)) {
277 break;
282 // Apply renaming.
283 if (failed(rename(oldName, newName)))
284 return failure();
285 return newName;
288 FailureOr<StringAttr>
289 SymbolTable::renameToUnique(Operation *op, ArrayRef<SymbolTable *> others) {
290 StringAttr from = getNameIfSymbol(op);
291 assert(from && "expected valid 'name' attribute");
292 return renameToUnique(from, others);
295 /// Returns the name of the given symbol operation.
296 StringAttr SymbolTable::getSymbolName(Operation *symbol) {
297 StringAttr name = getNameIfSymbol(symbol);
298 assert(name && "expected valid symbol name");
299 return name;
302 /// Sets the name of the given symbol operation.
303 void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) {
304 symbol->setAttr(getSymbolAttrName(), name);
307 /// Returns the visibility of the given symbol operation.
308 SymbolTable::Visibility SymbolTable::getSymbolVisibility(Operation *symbol) {
309 // If the attribute doesn't exist, assume public.
310 StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
311 if (!vis)
312 return Visibility::Public;
314 // Otherwise, switch on the string value.
315 return StringSwitch<Visibility>(vis.getValue())
316 .Case("private", Visibility::Private)
317 .Case("nested", Visibility::Nested)
318 .Case("public", Visibility::Public);
320 /// Sets the visibility of the given symbol operation.
321 void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
322 MLIRContext *ctx = symbol->getContext();
324 // If the visibility is public, just drop the attribute as this is the
325 // default.
326 if (vis == Visibility::Public) {
327 symbol->removeAttr(StringAttr::get(ctx, getVisibilityAttrName()));
328 return;
331 // Otherwise, update the attribute.
332 assert((vis == Visibility::Private || vis == Visibility::Nested) &&
333 "unknown symbol visibility kind");
335 StringRef visName = vis == Visibility::Private ? "private" : "nested";
336 symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
339 /// Returns the nearest symbol table from a given operation `from`. Returns
340 /// nullptr if no valid parent symbol table could be found.
341 Operation *SymbolTable::getNearestSymbolTable(Operation *from) {
342 assert(from && "expected valid operation");
343 if (isPotentiallyUnknownSymbolTable(from))
344 return nullptr;
346 while (!from->hasTrait<OpTrait::SymbolTable>()) {
347 from = from->getParentOp();
349 // Check that this is a valid op and isn't an unknown symbol table.
350 if (!from || isPotentiallyUnknownSymbolTable(from))
351 return nullptr;
353 return from;
356 /// Walks all symbol table operations nested within, and including, `op`. For
357 /// each symbol table operation, the provided callback is invoked with the op
358 /// and a boolean signifying if the symbols within that symbol table can be
359 /// treated as if all uses are visible. `allSymUsesVisible` identifies whether
360 /// all of the symbol uses of symbols within `op` are visible.
361 void SymbolTable::walkSymbolTables(
362 Operation *op, bool allSymUsesVisible,
363 function_ref<void(Operation *, bool)> callback) {
364 bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
365 if (isSymbolTable) {
366 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
367 allSymUsesVisible |= !symbol || symbol.isPrivate();
368 } else {
369 // Otherwise if 'op' is not a symbol table, any nested symbols are
370 // guaranteed to be hidden.
371 allSymUsesVisible = true;
374 for (Region &region : op->getRegions())
375 for (Block &block : region)
376 for (Operation &nestedOp : block)
377 walkSymbolTables(&nestedOp, allSymUsesVisible, callback);
379 // If 'op' had the symbol table trait, visit it after any nested symbol
380 // tables.
381 if (isSymbolTable)
382 callback(op, allSymUsesVisible);
385 /// Returns the operation registered with the given symbol name with the
386 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
387 /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
388 /// was found.
389 Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
390 StringAttr symbol) {
391 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
392 Region &region = symbolTableOp->getRegion(0);
393 if (region.empty())
394 return nullptr;
396 // Look for a symbol with the given name.
397 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
398 SymbolTable::getSymbolAttrName());
399 for (auto &op : region.front())
400 if (getNameIfSymbol(&op, symbolNameId) == symbol)
401 return &op;
402 return nullptr;
404 Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
405 SymbolRefAttr symbol) {
406 SmallVector<Operation *, 4> resolvedSymbols;
407 if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols)))
408 return nullptr;
409 return resolvedSymbols.back();
412 /// Internal implementation of `lookupSymbolIn` that allows for specialized
413 /// implementations of the lookup function.
414 static LogicalResult lookupSymbolInImpl(
415 Operation *symbolTableOp, SymbolRefAttr symbol,
416 SmallVectorImpl<Operation *> &symbols,
417 function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) {
418 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
420 // Lookup the root reference for this symbol.
421 symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
422 if (!symbolTableOp)
423 return failure();
424 symbols.push_back(symbolTableOp);
426 // If there are no nested references, just return the root symbol directly.
427 ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
428 if (nestedRefs.empty())
429 return success();
431 // Verify that the root is also a symbol table.
432 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
433 return failure();
435 // Otherwise, lookup each of the nested non-leaf references and ensure that
436 // each corresponds to a valid symbol table.
437 for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
438 symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr());
439 if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
440 return failure();
441 symbols.push_back(symbolTableOp);
443 symbols.push_back(lookupSymbolFn(symbolTableOp, symbol.getLeafReference()));
444 return success(symbols.back());
447 LogicalResult
448 SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
449 SmallVectorImpl<Operation *> &symbols) {
450 auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) {
451 return lookupSymbolIn(symbolTableOp, symbol);
453 return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
456 /// Returns the operation registered with the given symbol name within the
457 /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
458 /// nullptr if no valid symbol was found.
459 Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
460 StringAttr symbol) {
461 Operation *symbolTableOp = getNearestSymbolTable(from);
462 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
464 Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
465 SymbolRefAttr symbol) {
466 Operation *symbolTableOp = getNearestSymbolTable(from);
467 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
470 raw_ostream &mlir::operator<<(raw_ostream &os,
471 SymbolTable::Visibility visibility) {
472 switch (visibility) {
473 case SymbolTable::Visibility::Public:
474 return os << "public";
475 case SymbolTable::Visibility::Private:
476 return os << "private";
477 case SymbolTable::Visibility::Nested:
478 return os << "nested";
480 llvm_unreachable("Unexpected visibility");
483 //===----------------------------------------------------------------------===//
484 // SymbolTable Trait Types
485 //===----------------------------------------------------------------------===//
487 LogicalResult detail::verifySymbolTable(Operation *op) {
488 if (op->getNumRegions() != 1)
489 return op->emitOpError()
490 << "Operations with a 'SymbolTable' must have exactly one region";
491 if (!llvm::hasSingleElement(op->getRegion(0)))
492 return op->emitOpError()
493 << "Operations with a 'SymbolTable' must have exactly one block";
495 // Check that all symbols are uniquely named within child regions.
496 DenseMap<Attribute, Location> nameToOrigLoc;
497 for (auto &block : op->getRegion(0)) {
498 for (auto &op : block) {
499 // Check for a symbol name attribute.
500 auto nameAttr =
501 op.getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName());
502 if (!nameAttr)
503 continue;
505 // Try to insert this symbol into the table.
506 auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc());
507 if (!it.second)
508 return op.emitError()
509 .append("redefinition of symbol named '", nameAttr.getValue(), "'")
510 .attachNote(it.first->second)
511 .append("see existing symbol definition here");
515 // Verify any nested symbol user operations.
516 SymbolTableCollection symbolTable;
517 auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> {
518 if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
519 return WalkResult(user.verifySymbolUses(symbolTable));
520 return WalkResult::advance();
523 std::optional<WalkResult> result =
524 walkSymbolTable(op->getRegions(), verifySymbolUserFn);
525 return success(result && !result->wasInterrupted());
528 LogicalResult detail::verifySymbol(Operation *op) {
529 // Verify the name attribute.
530 if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
531 return op->emitOpError() << "requires string attribute '"
532 << mlir::SymbolTable::getSymbolAttrName() << "'";
534 // Verify the visibility attribute.
535 if (Attribute vis = op->getAttr(mlir::SymbolTable::getVisibilityAttrName())) {
536 StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis);
537 if (!visStrAttr)
538 return op->emitOpError() << "requires visibility attribute '"
539 << mlir::SymbolTable::getVisibilityAttrName()
540 << "' to be a string attribute, but got " << vis;
542 if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"},
543 visStrAttr.getValue()))
544 return op->emitOpError()
545 << "visibility expected to be one of [\"public\", \"private\", "
546 "\"nested\"], but got "
547 << visStrAttr;
549 return success();
552 //===----------------------------------------------------------------------===//
553 // Symbol Use Lists
554 //===----------------------------------------------------------------------===//
556 /// Walk all of the symbol references within the given operation, invoking the
557 /// provided callback for each found use. The callbacks takes the use of the
558 /// symbol.
559 static WalkResult
560 walkSymbolRefs(Operation *op,
561 function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
562 return op->getAttrDictionary().walk<WalkOrder::PreOrder>(
563 [&](SymbolRefAttr symbolRef) {
564 if (callback({op, symbolRef}).wasInterrupted())
565 return WalkResult::interrupt();
567 // Don't walk nested references.
568 return WalkResult::skip();
572 /// Walk all of the uses, for any symbol, that are nested within the given
573 /// regions, invoking the provided callback for each. This does not traverse
574 /// into any nested symbol tables.
575 static std::optional<WalkResult>
576 walkSymbolUses(MutableArrayRef<Region> regions,
577 function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
578 return walkSymbolTable(regions,
579 [&](Operation *op) -> std::optional<WalkResult> {
580 // Check that this isn't a potentially unknown symbol
581 // table.
582 if (isPotentiallyUnknownSymbolTable(op))
583 return std::nullopt;
585 return walkSymbolRefs(op, callback);
588 /// Walk all of the uses, for any symbol, that are nested within the given
589 /// operation 'from', invoking the provided callback for each. This does not
590 /// traverse into any nested symbol tables.
591 static std::optional<WalkResult>
592 walkSymbolUses(Operation *from,
593 function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
594 // If this operation has regions, and it, as well as its dialect, isn't
595 // registered then conservatively fail. The operation may define a
596 // symbol table, so we can't opaquely know if we should traverse to find
597 // nested uses.
598 if (isPotentiallyUnknownSymbolTable(from))
599 return std::nullopt;
601 // Walk the uses on this operation.
602 if (walkSymbolRefs(from, callback).wasInterrupted())
603 return WalkResult::interrupt();
605 // Only recurse if this operation is not a symbol table. A symbol table
606 // defines a new scope, so we can't walk the attributes from within the symbol
607 // table op.
608 if (!from->hasTrait<OpTrait::SymbolTable>())
609 return walkSymbolUses(from->getRegions(), callback);
610 return WalkResult::advance();
613 namespace {
614 /// This class represents a single symbol scope. A symbol scope represents the
615 /// set of operations nested within a symbol table that may reference symbols
616 /// within that table. A symbol scope does not contain the symbol table
617 /// operation itself, just its contained operations. A scope ends at leaf
618 /// operations or another symbol table operation.
619 struct SymbolScope {
620 /// Walk the symbol uses within this scope, invoking the given callback.
621 /// This variant is used when the callback type matches that expected by
622 /// 'walkSymbolUses'.
623 template <typename CallbackT,
624 std::enable_if_t<!std::is_same<
625 typename llvm::function_traits<CallbackT>::result_t,
626 void>::value> * = nullptr>
627 std::optional<WalkResult> walk(CallbackT cback) {
628 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
629 return walkSymbolUses(*region, cback);
630 return walkSymbolUses(limit.get<Operation *>(), cback);
632 /// This variant is used when the callback type matches a stripped down type:
633 /// void(SymbolTable::SymbolUse use)
634 template <typename CallbackT,
635 std::enable_if_t<std::is_same<
636 typename llvm::function_traits<CallbackT>::result_t,
637 void>::value> * = nullptr>
638 std::optional<WalkResult> walk(CallbackT cback) {
639 return walk([=](SymbolTable::SymbolUse use) {
640 return cback(use), WalkResult::advance();
644 /// Walk all of the operations nested under the current scope without
645 /// traversing into any nested symbol tables.
646 template <typename CallbackT>
647 std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
648 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
649 return ::walkSymbolTable(*region, cback);
650 return ::walkSymbolTable(limit.get<Operation *>(), cback);
653 /// The representation of the symbol within this scope.
654 SymbolRefAttr symbol;
656 /// The IR unit representing this scope.
657 llvm::PointerUnion<Operation *, Region *> limit;
659 } // namespace
661 /// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
662 static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
663 Operation *limit) {
664 StringAttr symName = SymbolTable::getSymbolName(symbol);
665 assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
667 // Compute the ancestors of 'limit'.
668 SetVector<Operation *, SmallVector<Operation *, 4>,
669 SmallPtrSet<Operation *, 4>>
670 limitAncestors;
671 Operation *limitAncestor = limit;
672 do {
673 // Check to see if 'symbol' is an ancestor of 'limit'.
674 if (limitAncestor == symbol) {
675 // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
676 // doesn't support parent references.
677 if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
678 symbol->getParentOp())
679 return {{SymbolRefAttr::get(symName), limit}};
680 return {};
683 limitAncestors.insert(limitAncestor);
684 } while ((limitAncestor = limitAncestor->getParentOp()));
686 // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
687 Operation *commonAncestor = symbol->getParentOp();
688 do {
689 if (limitAncestors.count(commonAncestor))
690 break;
691 } while ((commonAncestor = commonAncestor->getParentOp()));
692 assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");
694 // Compute the set of valid nested references for 'symbol' as far up to the
695 // common ancestor as possible.
696 SmallVector<SymbolRefAttr, 2> references;
697 bool collectedAllReferences = succeeded(
698 collectValidReferencesFor(symbol, symName, commonAncestor, references));
700 // Handle the case where the common ancestor is 'limit'.
701 if (commonAncestor == limit) {
702 SmallVector<SymbolScope, 2> scopes;
704 // Walk each of the ancestors of 'symbol', calling the compute function for
705 // each one.
706 Operation *limitIt = symbol->getParentOp();
707 for (size_t i = 0, e = references.size(); i != e;
708 ++i, limitIt = limitIt->getParentOp()) {
709 assert(limitIt->hasTrait<OpTrait::SymbolTable>());
710 scopes.push_back({references[i], &limitIt->getRegion(0)});
712 return scopes;
715 // Otherwise, we just need the symbol reference for 'symbol' that will be
716 // used within 'limit'. This is the last reference in the list we computed
717 // above if we were able to collect all references.
718 if (!collectedAllReferences)
719 return {};
720 return {{references.back(), limit}};
722 static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
723 Region *limit) {
724 auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
726 // If we collected some scopes to walk, make sure to constrain the one for
727 // limit to the specific region requested.
728 if (!scopes.empty())
729 scopes.back().limit = limit;
730 return scopes;
732 static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
733 Region *limit) {
734 return {{SymbolRefAttr::get(symbol), limit}};
737 static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
738 Operation *limit) {
739 SmallVector<SymbolScope, 1> scopes;
740 auto symbolRef = SymbolRefAttr::get(symbol);
741 for (auto &region : limit->getRegions())
742 scopes.push_back({symbolRef, &region});
743 return scopes;
746 /// Returns true if the given reference 'SubRef' is a sub reference of the
747 /// reference 'ref', i.e. 'ref' is a further qualified reference.
748 static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
749 if (ref == subRef)
750 return true;
752 // If the references are not pointer equal, check to see if `subRef` is a
753 // prefix of `ref`.
754 if (llvm::isa<FlatSymbolRefAttr>(ref) ||
755 ref.getRootReference() != subRef.getRootReference())
756 return false;
758 auto refLeafs = ref.getNestedReferences();
759 auto subRefLeafs = subRef.getNestedReferences();
760 return subRefLeafs.size() < refLeafs.size() &&
761 subRefLeafs == refLeafs.take_front(subRefLeafs.size());
764 //===----------------------------------------------------------------------===//
765 // SymbolTable::getSymbolUses
767 /// The implementation of SymbolTable::getSymbolUses below.
768 template <typename FromT>
769 static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
770 std::vector<SymbolTable::SymbolUse> uses;
771 auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
772 uses.push_back(symbolUse);
773 return WalkResult::advance();
775 auto result = walkSymbolUses(from, walkFn);
776 return result ? std::optional<SymbolTable::UseRange>(std::move(uses))
777 : std::nullopt;
780 /// Get an iterator range for all of the uses, for any symbol, that are nested
781 /// within the given operation 'from'. This does not traverse into any nested
782 /// symbol tables, and will also only return uses on 'from' if it does not
783 /// also define a symbol table. This is because we treat the region as the
784 /// boundary of the symbol table, and not the op itself. This function returns
785 /// std::nullopt if there are any unknown operations that may potentially be
786 /// symbol tables.
787 auto SymbolTable::getSymbolUses(Operation *from) -> std::optional<UseRange> {
788 return getSymbolUsesImpl(from);
790 auto SymbolTable::getSymbolUses(Region *from) -> std::optional<UseRange> {
791 return getSymbolUsesImpl(MutableArrayRef<Region>(*from));
794 //===----------------------------------------------------------------------===//
795 // SymbolTable::getSymbolUses
797 /// The implementation of SymbolTable::getSymbolUses below.
798 template <typename SymbolT, typename IRUnitT>
799 static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
800 IRUnitT *limit) {
801 std::vector<SymbolTable::SymbolUse> uses;
802 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
803 if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
804 if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
805 uses.push_back(symbolUse);
807 return std::nullopt;
809 return SymbolTable::UseRange(std::move(uses));
812 /// Get all of the uses of the given symbol that are nested within the given
813 /// operation 'from', invoking the provided callback for each. This does not
814 /// traverse into any nested symbol tables. This function returns std::nullopt
815 /// if there are any unknown operations that may potentially be symbol tables.
816 auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from)
817 -> std::optional<UseRange> {
818 return getSymbolUsesImpl(symbol, from);
820 auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from)
821 -> std::optional<UseRange> {
822 return getSymbolUsesImpl(symbol, from);
824 auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from)
825 -> std::optional<UseRange> {
826 return getSymbolUsesImpl(symbol, from);
828 auto SymbolTable::getSymbolUses(Operation *symbol, Region *from)
829 -> std::optional<UseRange> {
830 return getSymbolUsesImpl(symbol, from);
833 //===----------------------------------------------------------------------===//
834 // SymbolTable::symbolKnownUseEmpty
836 /// The implementation of SymbolTable::symbolKnownUseEmpty below.
837 template <typename SymbolT, typename IRUnitT>
838 static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
839 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
840 // Walk all of the symbol uses looking for a reference to 'symbol'.
841 if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
842 return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
843 ? WalkResult::interrupt()
844 : WalkResult::advance();
845 }) != WalkResult::advance())
846 return false;
848 return true;
851 /// Return if the given symbol is known to have no uses that are nested within
852 /// the given operation 'from'. This does not traverse into any nested symbol
853 /// tables. This function will also return false if there are any unknown
854 /// operations that may potentially be symbol tables.
855 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
856 return symbolKnownUseEmptyImpl(symbol, from);
858 bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) {
859 return symbolKnownUseEmptyImpl(symbol, from);
861 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
862 return symbolKnownUseEmptyImpl(symbol, from);
864 bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
865 return symbolKnownUseEmptyImpl(symbol, from);
868 //===----------------------------------------------------------------------===//
869 // SymbolTable::replaceAllSymbolUses
871 /// Generates a new symbol reference attribute with a new leaf reference.
872 static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
873 FlatSymbolRefAttr newLeafAttr) {
874 if (llvm::isa<FlatSymbolRefAttr>(oldAttr))
875 return newLeafAttr;
876 auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
877 nestedRefs.back() = newLeafAttr;
878 return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs);
881 /// The implementation of SymbolTable::replaceAllSymbolUses below.
882 template <typename SymbolT, typename IRUnitT>
883 static LogicalResult
884 replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
885 // Generate a new attribute to replace the given attribute.
886 FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
887 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
888 SymbolRefAttr oldAttr = scope.symbol;
889 SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
890 AttrTypeReplacer replacer;
891 replacer.addReplacement(
892 [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
893 // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
894 // want to accidentally replace an inner reference.
895 if (attr == oldAttr)
896 return {newAttr, WalkResult::skip()};
897 // Handle prefix matches.
898 if (isReferencePrefixOf(oldAttr, attr)) {
899 auto oldNestedRefs = oldAttr.getNestedReferences();
900 auto nestedRefs = attr.getNestedReferences();
901 if (oldNestedRefs.empty())
902 return {SymbolRefAttr::get(newSymbol, nestedRefs),
903 WalkResult::skip()};
905 auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
906 newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
907 return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs),
908 WalkResult::skip()};
910 return {attr, WalkResult::skip()};
913 auto walkFn = [&](Operation *op) -> std::optional<WalkResult> {
914 replacer.replaceElementsIn(op);
915 return WalkResult::advance();
917 if (!scope.walkSymbolTable(walkFn))
918 return failure();
920 return success();
923 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
924 /// provided symbol 'newSymbol' that are nested within the given operation
925 /// 'from'. This does not traverse into any nested symbol tables. If there are
926 /// any unknown operations that may potentially be symbol tables, no uses are
927 /// replaced and failure is returned.
928 LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
929 StringAttr newSymbol,
930 Operation *from) {
931 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
933 LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
934 StringAttr newSymbol,
935 Operation *from) {
936 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
938 LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
939 StringAttr newSymbol,
940 Region *from) {
941 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
943 LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
944 StringAttr newSymbol,
945 Region *from) {
946 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
949 //===----------------------------------------------------------------------===//
950 // SymbolTableCollection
951 //===----------------------------------------------------------------------===//
953 Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
954 StringAttr symbol) {
955 return getSymbolTable(symbolTableOp).lookup(symbol);
957 Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
958 SymbolRefAttr name) {
959 SmallVector<Operation *, 4> symbols;
960 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
961 return nullptr;
962 return symbols.back();
964 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
965 /// a given SymbolRefAttr. Returns failure if any of the nested references could
966 /// not be resolved.
967 LogicalResult
968 SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
969 SymbolRefAttr name,
970 SmallVectorImpl<Operation *> &symbols) {
971 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
972 return lookupSymbolIn(symbolTableOp, symbol);
974 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
977 /// Returns the operation registered with the given symbol name within the
978 /// closest parent operation of, or including, 'from' with the
979 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
980 /// found.
981 Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
982 StringAttr symbol) {
983 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
984 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
986 Operation *
987 SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
988 SymbolRefAttr symbol) {
989 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
990 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
993 /// Lookup, or create, a symbol table for an operation.
994 SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
995 auto it = symbolTables.try_emplace(op, nullptr);
996 if (it.second)
997 it.first->second = std::make_unique<SymbolTable>(op);
998 return *it.first->second;
1001 //===----------------------------------------------------------------------===//
1002 // LockedSymbolTableCollection
1003 //===----------------------------------------------------------------------===//
1005 Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
1006 StringAttr symbol) {
1007 return getSymbolTable(symbolTableOp).lookup(symbol);
1010 Operation *
1011 LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
1012 FlatSymbolRefAttr symbol) {
1013 return lookupSymbolIn(symbolTableOp, symbol.getAttr());
1016 Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
1017 SymbolRefAttr name) {
1018 SmallVector<Operation *> symbols;
1019 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
1020 return nullptr;
1021 return symbols.back();
1024 LogicalResult LockedSymbolTableCollection::lookupSymbolIn(
1025 Operation *symbolTableOp, SymbolRefAttr name,
1026 SmallVectorImpl<Operation *> &symbols) {
1027 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
1028 return lookupSymbolIn(symbolTableOp, symbol);
1030 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
1033 SymbolTable &
1034 LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) {
1035 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
1036 // Try to find an existing symbol table.
1038 llvm::sys::SmartScopedReader<true> lock(mutex);
1039 auto it = collection.symbolTables.find(symbolTableOp);
1040 if (it != collection.symbolTables.end())
1041 return *it->second;
1043 // Create a symbol table for the operation. Perform construction outside of
1044 // the critical section.
1045 auto symbolTable = std::make_unique<SymbolTable>(symbolTableOp);
1046 // Insert the constructed symbol table.
1047 llvm::sys::SmartScopedWriter<true> lock(mutex);
1048 return *collection.symbolTables
1049 .insert({symbolTableOp, std::move(symbolTable)})
1050 .first->second;
1053 //===----------------------------------------------------------------------===//
1054 // SymbolUserMap
1055 //===----------------------------------------------------------------------===//
1057 SymbolUserMap::SymbolUserMap(SymbolTableCollection &symbolTable,
1058 Operation *symbolTableOp)
1059 : symbolTable(symbolTable) {
1060 // Walk each of the symbol tables looking for discardable callgraph nodes.
1061 SmallVector<Operation *> symbols;
1062 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
1063 for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
1064 auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
1065 assert(symbolUses && "expected uses to be valid");
1067 for (const SymbolTable::SymbolUse &use : *symbolUses) {
1068 symbols.clear();
1069 (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
1070 symbols);
1071 for (Operation *symbolOp : symbols)
1072 symbolToUsers[symbolOp].insert(use.getUser());
1076 // We just set `allSymUsesVisible` to false here because it isn't necessary
1077 // for building the user map.
1078 SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
1079 walkFn);
1082 void SymbolUserMap::replaceAllUsesWith(Operation *symbol,
1083 StringAttr newSymbolName) {
1084 auto it = symbolToUsers.find(symbol);
1085 if (it == symbolToUsers.end())
1086 return;
1088 // Replace the uses within the users of `symbol`.
1089 for (Operation *user : it->second)
1090 (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
1092 // Move the current users of `symbol` to the new symbol if it is in the
1093 // symbol table.
1094 Operation *newSymbol =
1095 symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
1096 if (newSymbol != symbol) {
1097 // Transfer over the users to the new symbol. The reference to the old one
1098 // is fetched again as the iterator is invalidated during the insertion.
1099 auto newIt = symbolToUsers.try_emplace(newSymbol, SetVector<Operation *>{});
1100 auto oldIt = symbolToUsers.find(symbol);
1101 assert(oldIt != symbolToUsers.end() && "missing old users list");
1102 if (newIt.second)
1103 newIt.first->second = std::move(oldIt->second);
1104 else
1105 newIt.first->second.set_union(oldIt->second);
1106 symbolToUsers.erase(oldIt);
1110 //===----------------------------------------------------------------------===//
1111 // Visibility parsing implementation.
1112 //===----------------------------------------------------------------------===//
1114 ParseResult impl::parseOptionalVisibilityKeyword(OpAsmParser &parser,
1115 NamedAttrList &attrs) {
1116 StringRef visibility;
1117 if (parser.parseOptionalKeyword(&visibility, {"public", "private", "nested"}))
1118 return failure();
1120 StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility);
1121 attrs.push_back(parser.getBuilder().getNamedAttr(
1122 SymbolTable::getVisibilityAttrName(), visibilityAttr));
1123 return success();
1126 //===----------------------------------------------------------------------===//
1127 // Symbol Interfaces
1128 //===----------------------------------------------------------------------===//
1130 /// Include the generated symbol interfaces.
1131 #include "mlir/IR/SymbolInterfaces.cpp.inc"