[GlobalISel] Combine [s,z]ext of undef into 0 (#117439)
[llvm-project.git] / mlir / lib / IR / SymbolTable.cpp
blobe83d19553d62ce8beec16d380fd5a25d5e1b6e19
1 //===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/SymbolTable.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/OpImplementation.h"
12 #include "llvm/ADT/SetVector.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/SmallString.h"
15 #include "llvm/ADT/StringSwitch.h"
16 #include <optional>
18 using namespace mlir;
20 /// Return true if the given operation is unknown and may potentially define a
21 /// symbol table.
22 static bool isPotentiallyUnknownSymbolTable(Operation *op) {
23 return op->getNumRegions() == 1 && !op->getDialect();
26 /// Returns the string name of the given symbol, or null if this is not a
27 /// symbol.
28 static StringAttr getNameIfSymbol(Operation *op) {
29 return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
31 static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
32 return op->getAttrOfType<StringAttr>(symbolAttrNameId);
35 /// Computes the nested symbol reference attribute for the symbol 'symbolName'
36 /// that are usable within the symbol table operations from 'symbol' as far up
37 /// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
38 /// Returns success if all references up to 'within' could be computed.
39 static LogicalResult
40 collectValidReferencesFor(Operation *symbol, StringAttr symbolName,
41 Operation *within,
42 SmallVectorImpl<SymbolRefAttr> &results) {
43 assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
44 MLIRContext *ctx = symbol->getContext();
46 auto leafRef = FlatSymbolRefAttr::get(symbolName);
47 results.push_back(leafRef);
49 // Early exit for when 'within' is the parent of 'symbol'.
50 Operation *symbolTableOp = symbol->getParentOp();
51 if (within == symbolTableOp)
52 return success();
54 // Collect references until 'symbolTableOp' reaches 'within'.
55 SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
56 StringAttr symbolNameId =
57 StringAttr::get(ctx, SymbolTable::getSymbolAttrName());
58 do {
59 // Each parent of 'symbol' should define a symbol table.
60 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
61 return failure();
62 // Each parent of 'symbol' should also be a symbol.
63 StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
64 if (!symbolTableName)
65 return failure();
66 results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
68 symbolTableOp = symbolTableOp->getParentOp();
69 if (symbolTableOp == within)
70 break;
71 nestedRefs.insert(nestedRefs.begin(),
72 FlatSymbolRefAttr::get(symbolTableName));
73 } while (true);
74 return success();
77 /// Walk all of the operations within the given set of regions, without
78 /// traversing into any nested symbol tables. Stops walking if the result of the
79 /// callback is anything other than `WalkResult::advance`.
80 static std::optional<WalkResult>
81 walkSymbolTable(MutableArrayRef<Region> regions,
82 function_ref<std::optional<WalkResult>(Operation *)> callback) {
83 SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
84 while (!worklist.empty()) {
85 for (Operation &op : worklist.pop_back_val()->getOps()) {
86 std::optional<WalkResult> result = callback(&op);
87 if (result != WalkResult::advance())
88 return result;
90 // If this op defines a new symbol table scope, we can't traverse. Any
91 // symbol references nested within 'op' are different semantically.
92 if (!op.hasTrait<OpTrait::SymbolTable>()) {
93 for (Region &region : op.getRegions())
94 worklist.push_back(&region);
98 return WalkResult::advance();
101 /// Walk all of the operations nested under, and including, the given operation,
102 /// without traversing into any nested symbol tables. Stops walking if the
103 /// result of the callback is anything other than `WalkResult::advance`.
104 static std::optional<WalkResult>
105 walkSymbolTable(Operation *op,
106 function_ref<std::optional<WalkResult>(Operation *)> callback) {
107 std::optional<WalkResult> result = callback(op);
108 if (result != WalkResult::advance() || op->hasTrait<OpTrait::SymbolTable>())
109 return result;
110 return walkSymbolTable(op->getRegions(), callback);
113 //===----------------------------------------------------------------------===//
114 // SymbolTable
115 //===----------------------------------------------------------------------===//
117 /// Build a symbol table with the symbols within the given operation.
118 SymbolTable::SymbolTable(Operation *symbolTableOp)
119 : symbolTableOp(symbolTableOp) {
120 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
121 "expected operation to have SymbolTable trait");
122 assert(symbolTableOp->getNumRegions() == 1 &&
123 "expected operation to have a single region");
124 assert(llvm::hasSingleElement(symbolTableOp->getRegion(0)) &&
125 "expected operation to have a single block");
127 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
128 SymbolTable::getSymbolAttrName());
129 for (auto &op : symbolTableOp->getRegion(0).front()) {
130 StringAttr name = getNameIfSymbol(&op, symbolNameId);
131 if (!name)
132 continue;
134 auto inserted = symbolTable.insert({name, &op});
135 (void)inserted;
136 assert(inserted.second &&
137 "expected region to contain uniquely named symbol operations");
141 /// Look up a symbol with the specified name, returning null if no such name
142 /// exists. Names never include the @ on them.
143 Operation *SymbolTable::lookup(StringRef name) const {
144 return lookup(StringAttr::get(symbolTableOp->getContext(), name));
146 Operation *SymbolTable::lookup(StringAttr name) const {
147 return symbolTable.lookup(name);
150 void SymbolTable::remove(Operation *op) {
151 StringAttr name = getNameIfSymbol(op);
152 assert(name && "expected valid 'name' attribute");
153 assert(op->getParentOp() == symbolTableOp &&
154 "expected this operation to be inside of the operation with this "
155 "SymbolTable");
157 auto it = symbolTable.find(name);
158 if (it != symbolTable.end() && it->second == op)
159 symbolTable.erase(it);
162 void SymbolTable::erase(Operation *symbol) {
163 remove(symbol);
164 symbol->erase();
// TODO: Consider if this should be renamed to something like insertOrUpdate
/// Insert a new symbol into the table and associated operation if not already
/// there and rename it as necessary to avoid collisions. Return the name of
/// the symbol after insertion as attribute.
///
/// If `symbol` has no parent, it is spliced into the body of the symbol table
/// operation at `insertPt`. A default-constructed `insertPt` means the end of
/// the body. An end-of-body insertion point is moved before a trailing
/// terminator, if present.
///
/// If `symbol` already has a parent, that parent must be the symbol table
/// operation.
///
/// If the symbol's name is already mapped to a different operation, the
/// symbol is renamed via `generateSymbolName` using `uniquingCounter`.
StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
  // The symbol cannot be the child of another op and must be the child of the
  // symbolTableOp after this.
  //
  // TODO: consider if SymbolTable's constructor should behave the same.
  if (!symbol->getParentOp()) {
    auto &body = symbolTableOp->getRegion(0).front();
    if (insertPt == Block::iterator()) {
      // No explicit insertion point was given: default to the end of the body.
      insertPt = Block::iterator(body.end());
    } else {
      assert((insertPt == body.end() ||
              insertPt->getParentOp() == symbolTableOp) &&
             "expected insertPt to be in the associated module operation");
    }
    // Insert before the terminator, if any.
    if (insertPt == Block::iterator(body.end()) && !body.empty() &&
        std::prev(body.end())->hasTrait<OpTrait::IsTerminator>())
      insertPt = std::prev(body.end());

    body.getOperations().insert(insertPt, symbol);
  }
  assert(symbol->getParentOp() == symbolTableOp &&
         "symbol is already inserted in another op");

  // Add this symbol to the symbol table, uniquing the name if a conflict is
  // detected.
  StringAttr name = getSymbolName(symbol);
  if (symbolTable.insert({name, symbol}).second)
    return name;
  // If the symbol was already in the table, also return.
  if (symbolTable.lookup(name) == symbol)
    return name;

  // The name belongs to a different operation. The checker callback reports a
  // candidate as taken when it cannot be inserted into the map. A successful
  // insertion already registers `symbol` under the new name, so only the
  // attribute needs updating afterwards.
  MLIRContext *context = symbol->getContext();
  SmallString<128> nameBuffer = generateSymbolName<128>(
      name.getValue(),
      [&](StringRef candidate) {
        return !symbolTable
                    .insert({StringAttr::get(context, candidate), symbol})
                    .second;
      },
      uniquingCounter);
  setSymbolName(symbol, nameBuffer);
  return getSymbolName(symbol);
}
217 LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) {
218 Operation *op = lookup(from);
219 return rename(op, to);
222 LogicalResult SymbolTable::rename(Operation *op, StringAttr to) {
223 StringAttr from = getNameIfSymbol(op);
224 (void)from;
226 assert(from && "expected valid 'name' attribute");
227 assert(op->getParentOp() == symbolTableOp &&
228 "expected this operation to be inside of the operation with this "
229 "SymbolTable");
230 assert(lookup(from) == op && "current name does not resolve to op");
231 assert(lookup(to) == nullptr && "new name already exists");
233 if (failed(SymbolTable::replaceAllSymbolUses(op, to, getOp())))
234 return failure();
236 // Remove op with old name, change name, add with new name. The order is
237 // important here due to how `remove` and `insert` rely on the op name.
238 remove(op);
239 setSymbolName(op, to);
240 insert(op);
242 assert(lookup(to) == op && "new name does not resolve to renamed op");
243 assert(lookup(from) == nullptr && "old name still exists");
245 return success();
248 LogicalResult SymbolTable::rename(StringAttr from, StringRef to) {
249 auto toAttr = StringAttr::get(getOp()->getContext(), to);
250 return rename(from, toAttr);
253 LogicalResult SymbolTable::rename(Operation *op, StringRef to) {
254 auto toAttr = StringAttr::get(getOp()->getContext(), to);
255 return rename(op, toAttr);
258 FailureOr<StringAttr>
259 SymbolTable::renameToUnique(StringAttr oldName,
260 ArrayRef<SymbolTable *> others) {
262 // Determine new name that is unique in all symbol tables.
263 StringAttr newName;
265 MLIRContext *context = oldName.getContext();
266 SmallString<64> prefix = oldName.getValue();
267 int uniqueId = 0;
268 prefix.push_back('_');
269 while (true) {
270 newName = StringAttr::get(context, prefix + Twine(uniqueId++));
271 auto lookupNewName = [&](SymbolTable *st) { return st->lookup(newName); };
272 if (!lookupNewName(this) && llvm::none_of(others, lookupNewName)) {
273 break;
278 // Apply renaming.
279 if (failed(rename(oldName, newName)))
280 return failure();
281 return newName;
284 FailureOr<StringAttr>
285 SymbolTable::renameToUnique(Operation *op, ArrayRef<SymbolTable *> others) {
286 StringAttr from = getNameIfSymbol(op);
287 assert(from && "expected valid 'name' attribute");
288 return renameToUnique(from, others);
291 /// Returns the name of the given symbol operation.
292 StringAttr SymbolTable::getSymbolName(Operation *symbol) {
293 StringAttr name = getNameIfSymbol(symbol);
294 assert(name && "expected valid symbol name");
295 return name;
298 /// Sets the name of the given symbol operation.
299 void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) {
300 symbol->setAttr(getSymbolAttrName(), name);
303 /// Returns the visibility of the given symbol operation.
304 SymbolTable::Visibility SymbolTable::getSymbolVisibility(Operation *symbol) {
305 // If the attribute doesn't exist, assume public.
306 StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
307 if (!vis)
308 return Visibility::Public;
310 // Otherwise, switch on the string value.
311 return StringSwitch<Visibility>(vis.getValue())
312 .Case("private", Visibility::Private)
313 .Case("nested", Visibility::Nested)
314 .Case("public", Visibility::Public);
316 /// Sets the visibility of the given symbol operation.
317 void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
318 MLIRContext *ctx = symbol->getContext();
320 // If the visibility is public, just drop the attribute as this is the
321 // default.
322 if (vis == Visibility::Public) {
323 symbol->removeAttr(StringAttr::get(ctx, getVisibilityAttrName()));
324 return;
327 // Otherwise, update the attribute.
328 assert((vis == Visibility::Private || vis == Visibility::Nested) &&
329 "unknown symbol visibility kind");
331 StringRef visName = vis == Visibility::Private ? "private" : "nested";
332 symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
335 /// Returns the nearest symbol table from a given operation `from`. Returns
336 /// nullptr if no valid parent symbol table could be found.
337 Operation *SymbolTable::getNearestSymbolTable(Operation *from) {
338 assert(from && "expected valid operation");
339 if (isPotentiallyUnknownSymbolTable(from))
340 return nullptr;
342 while (!from->hasTrait<OpTrait::SymbolTable>()) {
343 from = from->getParentOp();
345 // Check that this is a valid op and isn't an unknown symbol table.
346 if (!from || isPotentiallyUnknownSymbolTable(from))
347 return nullptr;
349 return from;
352 /// Walks all symbol table operations nested within, and including, `op`. For
353 /// each symbol table operation, the provided callback is invoked with the op
354 /// and a boolean signifying if the symbols within that symbol table can be
355 /// treated as if all uses are visible. `allSymUsesVisible` identifies whether
356 /// all of the symbol uses of symbols within `op` are visible.
357 void SymbolTable::walkSymbolTables(
358 Operation *op, bool allSymUsesVisible,
359 function_ref<void(Operation *, bool)> callback) {
360 bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
361 if (isSymbolTable) {
362 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
363 allSymUsesVisible |= !symbol || symbol.isPrivate();
364 } else {
365 // Otherwise if 'op' is not a symbol table, any nested symbols are
366 // guaranteed to be hidden.
367 allSymUsesVisible = true;
370 for (Region &region : op->getRegions())
371 for (Block &block : region)
372 for (Operation &nestedOp : block)
373 walkSymbolTables(&nestedOp, allSymUsesVisible, callback);
375 // If 'op' had the symbol table trait, visit it after any nested symbol
376 // tables.
377 if (isSymbolTable)
378 callback(op, allSymUsesVisible);
381 /// Returns the operation registered with the given symbol name with the
382 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
383 /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
384 /// was found.
385 Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
386 StringAttr symbol) {
387 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
388 Region &region = symbolTableOp->getRegion(0);
389 if (region.empty())
390 return nullptr;
392 // Look for a symbol with the given name.
393 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
394 SymbolTable::getSymbolAttrName());
395 for (auto &op : region.front())
396 if (getNameIfSymbol(&op, symbolNameId) == symbol)
397 return &op;
398 return nullptr;
400 Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
401 SymbolRefAttr symbol) {
402 SmallVector<Operation *, 4> resolvedSymbols;
403 if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols)))
404 return nullptr;
405 return resolvedSymbols.back();
408 /// Internal implementation of `lookupSymbolIn` that allows for specialized
409 /// implementations of the lookup function.
410 static LogicalResult lookupSymbolInImpl(
411 Operation *symbolTableOp, SymbolRefAttr symbol,
412 SmallVectorImpl<Operation *> &symbols,
413 function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) {
414 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
416 // Lookup the root reference for this symbol.
417 symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
418 if (!symbolTableOp)
419 return failure();
420 symbols.push_back(symbolTableOp);
422 // If there are no nested references, just return the root symbol directly.
423 ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
424 if (nestedRefs.empty())
425 return success();
427 // Verify that the root is also a symbol table.
428 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
429 return failure();
431 // Otherwise, lookup each of the nested non-leaf references and ensure that
432 // each corresponds to a valid symbol table.
433 for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
434 symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr());
435 if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
436 return failure();
437 symbols.push_back(symbolTableOp);
439 symbols.push_back(lookupSymbolFn(symbolTableOp, symbol.getLeafReference()));
440 return success(symbols.back());
443 LogicalResult
444 SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
445 SmallVectorImpl<Operation *> &symbols) {
446 auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) {
447 return lookupSymbolIn(symbolTableOp, symbol);
449 return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
452 /// Returns the operation registered with the given symbol name within the
453 /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
454 /// nullptr if no valid symbol was found.
455 Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
456 StringAttr symbol) {
457 Operation *symbolTableOp = getNearestSymbolTable(from);
458 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
460 Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
461 SymbolRefAttr symbol) {
462 Operation *symbolTableOp = getNearestSymbolTable(from);
463 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
466 raw_ostream &mlir::operator<<(raw_ostream &os,
467 SymbolTable::Visibility visibility) {
468 switch (visibility) {
469 case SymbolTable::Visibility::Public:
470 return os << "public";
471 case SymbolTable::Visibility::Private:
472 return os << "private";
473 case SymbolTable::Visibility::Nested:
474 return os << "nested";
476 llvm_unreachable("Unexpected visibility");
479 //===----------------------------------------------------------------------===//
480 // SymbolTable Trait Types
481 //===----------------------------------------------------------------------===//
483 LogicalResult detail::verifySymbolTable(Operation *op) {
484 if (op->getNumRegions() != 1)
485 return op->emitOpError()
486 << "Operations with a 'SymbolTable' must have exactly one region";
487 if (!llvm::hasSingleElement(op->getRegion(0)))
488 return op->emitOpError()
489 << "Operations with a 'SymbolTable' must have exactly one block";
491 // Check that all symbols are uniquely named within child regions.
492 DenseMap<Attribute, Location> nameToOrigLoc;
493 for (auto &block : op->getRegion(0)) {
494 for (auto &op : block) {
495 // Check for a symbol name attribute.
496 auto nameAttr =
497 op.getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName());
498 if (!nameAttr)
499 continue;
501 // Try to insert this symbol into the table.
502 auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc());
503 if (!it.second)
504 return op.emitError()
505 .append("redefinition of symbol named '", nameAttr.getValue(), "'")
506 .attachNote(it.first->second)
507 .append("see existing symbol definition here");
511 // Verify any nested symbol user operations.
512 SymbolTableCollection symbolTable;
513 auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> {
514 if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
515 return WalkResult(user.verifySymbolUses(symbolTable));
516 return WalkResult::advance();
519 std::optional<WalkResult> result =
520 walkSymbolTable(op->getRegions(), verifySymbolUserFn);
521 return success(result && !result->wasInterrupted());
524 LogicalResult detail::verifySymbol(Operation *op) {
525 // Verify the name attribute.
526 if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
527 return op->emitOpError() << "requires string attribute '"
528 << mlir::SymbolTable::getSymbolAttrName() << "'";
530 // Verify the visibility attribute.
531 if (Attribute vis = op->getAttr(mlir::SymbolTable::getVisibilityAttrName())) {
532 StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis);
533 if (!visStrAttr)
534 return op->emitOpError() << "requires visibility attribute '"
535 << mlir::SymbolTable::getVisibilityAttrName()
536 << "' to be a string attribute, but got " << vis;
538 if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"},
539 visStrAttr.getValue()))
540 return op->emitOpError()
541 << "visibility expected to be one of [\"public\", \"private\", "
542 "\"nested\"], but got "
543 << visStrAttr;
545 return success();
548 //===----------------------------------------------------------------------===//
549 // Symbol Use Lists
550 //===----------------------------------------------------------------------===//
552 /// Walk all of the symbol references within the given operation, invoking the
553 /// provided callback for each found use. The callbacks takes the use of the
554 /// symbol.
555 static WalkResult
556 walkSymbolRefs(Operation *op,
557 function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
558 return op->getAttrDictionary().walk<WalkOrder::PreOrder>(
559 [&](SymbolRefAttr symbolRef) {
560 if (callback({op, symbolRef}).wasInterrupted())
561 return WalkResult::interrupt();
563 // Don't walk nested references.
564 return WalkResult::skip();
568 /// Walk all of the uses, for any symbol, that are nested within the given
569 /// regions, invoking the provided callback for each. This does not traverse
570 /// into any nested symbol tables.
571 static std::optional<WalkResult>
572 walkSymbolUses(MutableArrayRef<Region> regions,
573 function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
574 return walkSymbolTable(regions,
575 [&](Operation *op) -> std::optional<WalkResult> {
576 // Check that this isn't a potentially unknown symbol
577 // table.
578 if (isPotentiallyUnknownSymbolTable(op))
579 return std::nullopt;
581 return walkSymbolRefs(op, callback);
584 /// Walk all of the uses, for any symbol, that are nested within the given
585 /// operation 'from', invoking the provided callback for each. This does not
586 /// traverse into any nested symbol tables.
587 static std::optional<WalkResult>
588 walkSymbolUses(Operation *from,
589 function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
590 // If this operation has regions, and it, as well as its dialect, isn't
591 // registered then conservatively fail. The operation may define a
592 // symbol table, so we can't opaquely know if we should traverse to find
593 // nested uses.
594 if (isPotentiallyUnknownSymbolTable(from))
595 return std::nullopt;
597 // Walk the uses on this operation.
598 if (walkSymbolRefs(from, callback).wasInterrupted())
599 return WalkResult::interrupt();
601 // Only recurse if this operation is not a symbol table. A symbol table
602 // defines a new scope, so we can't walk the attributes from within the symbol
603 // table op.
604 if (!from->hasTrait<OpTrait::SymbolTable>())
605 return walkSymbolUses(from->getRegions(), callback);
606 return WalkResult::advance();
609 namespace {
610 /// This class represents a single symbol scope. A symbol scope represents the
611 /// set of operations nested within a symbol table that may reference symbols
612 /// within that table. A symbol scope does not contain the symbol table
613 /// operation itself, just its contained operations. A scope ends at leaf
614 /// operations or another symbol table operation.
615 struct SymbolScope {
616 /// Walk the symbol uses within this scope, invoking the given callback.
617 /// This variant is used when the callback type matches that expected by
618 /// 'walkSymbolUses'.
619 template <typename CallbackT,
620 std::enable_if_t<!std::is_same<
621 typename llvm::function_traits<CallbackT>::result_t,
622 void>::value> * = nullptr>
623 std::optional<WalkResult> walk(CallbackT cback) {
624 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
625 return walkSymbolUses(*region, cback);
626 return walkSymbolUses(limit.get<Operation *>(), cback);
628 /// This variant is used when the callback type matches a stripped down type:
629 /// void(SymbolTable::SymbolUse use)
630 template <typename CallbackT,
631 std::enable_if_t<std::is_same<
632 typename llvm::function_traits<CallbackT>::result_t,
633 void>::value> * = nullptr>
634 std::optional<WalkResult> walk(CallbackT cback) {
635 return walk([=](SymbolTable::SymbolUse use) {
636 return cback(use), WalkResult::advance();
640 /// Walk all of the operations nested under the current scope without
641 /// traversing into any nested symbol tables.
642 template <typename CallbackT>
643 std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
644 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
645 return ::walkSymbolTable(*region, cback);
646 return ::walkSymbolTable(limit.get<Operation *>(), cback);
649 /// The representation of the symbol within this scope.
650 SymbolRefAttr symbol;
652 /// The IR unit representing this scope.
653 llvm::PointerUnion<Operation *, Region *> limit;
655 } // namespace
/// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
///
/// Each returned scope pairs the reference to 'symbol' that is valid within
/// it with the IR unit to search. The result is empty when no valid reference
/// to 'symbol' can be formed from within 'limit', i.e. when:
///  - 'limit' is nested inside 'symbol' but the nearest symbol table above
///    'limit' is not 'symbol's parent; or
///  - the references up to the common ancestor could not all be collected.
static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
                                                       Operation *limit) {
  StringAttr symName = SymbolTable::getSymbolName(symbol);
  assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);

  // Compute the ancestors of 'limit'.
  SetVector<Operation *, SmallVector<Operation *, 4>,
            SmallPtrSet<Operation *, 4>>
      limitAncestors;
  Operation *limitAncestor = limit;
  do {
    // Check to see if 'symbol' is an ancestor of 'limit'.
    if (limitAncestor == symbol) {
      // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
      // doesn't support parent references.
      if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
          symbol->getParentOp())
        return {{SymbolRefAttr::get(symName), limit}};
      return {};
    }

    limitAncestors.insert(limitAncestor);
  } while ((limitAncestor = limitAncestor->getParentOp()));

  // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
  Operation *commonAncestor = symbol->getParentOp();
  do {
    if (limitAncestors.count(commonAncestor))
      break;
  } while ((commonAncestor = commonAncestor->getParentOp()));
  assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");

  // Compute the set of valid nested references for 'symbol' as far up to the
  // common ancestor as possible.
  SmallVector<SymbolRefAttr, 2> references;
  bool collectedAllReferences = succeeded(
      collectValidReferencesFor(symbol, symName, commonAncestor, references));

  // Handle the case where the common ancestor is 'limit'.
  if (commonAncestor == limit) {
    SmallVector<SymbolScope, 2> scopes;

    // Walk each of the ancestors of 'symbol', calling the compute function for
    // each one. references[i] is the reference valid inside the i-th
    // enclosing symbol table, so it is paired with that table's region. Only
    // as many scopes as references were collected are produced.
    Operation *limitIt = symbol->getParentOp();
    for (size_t i = 0, e = references.size(); i != e;
         ++i, limitIt = limitIt->getParentOp()) {
      assert(limitIt->hasTrait<OpTrait::SymbolTable>());
      scopes.push_back({references[i], &limitIt->getRegion(0)});
    }
    return scopes;
  }

  // Otherwise, we just need the symbol reference for 'symbol' that will be
  // used within 'limit'. This is the last reference in the list we computed
  // above if we were able to collect all references.
  if (!collectedAllReferences)
    return {};
  return {{references.back(), limit}};
}
718 static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
719 Region *limit) {
720 auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
722 // If we collected some scopes to walk, make sure to constrain the one for
723 // limit to the specific region requested.
724 if (!scopes.empty())
725 scopes.back().limit = limit;
726 return scopes;
728 static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
729 Region *limit) {
730 return {{SymbolRefAttr::get(symbol), limit}};
733 static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
734 Operation *limit) {
735 SmallVector<SymbolScope, 1> scopes;
736 auto symbolRef = SymbolRefAttr::get(symbol);
737 for (auto &region : limit->getRegions())
738 scopes.push_back({symbolRef, &region});
739 return scopes;
742 /// Returns true if the given reference 'SubRef' is a sub reference of the
743 /// reference 'ref', i.e. 'ref' is a further qualified reference.
744 static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
745 if (ref == subRef)
746 return true;
748 // If the references are not pointer equal, check to see if `subRef` is a
749 // prefix of `ref`.
750 if (llvm::isa<FlatSymbolRefAttr>(ref) ||
751 ref.getRootReference() != subRef.getRootReference())
752 return false;
754 auto refLeafs = ref.getNestedReferences();
755 auto subRefLeafs = subRef.getNestedReferences();
756 return subRefLeafs.size() < refLeafs.size() &&
757 subRefLeafs == refLeafs.take_front(subRefLeafs.size());
760 //===----------------------------------------------------------------------===//
761 // SymbolTable::getSymbolUses
763 /// The implementation of SymbolTable::getSymbolUses below.
764 template <typename FromT>
765 static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
766 std::vector<SymbolTable::SymbolUse> uses;
767 auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
768 uses.push_back(symbolUse);
769 return WalkResult::advance();
771 auto result = walkSymbolUses(from, walkFn);
772 return result ? std::optional<SymbolTable::UseRange>(std::move(uses))
773 : std::nullopt;
776 /// Get an iterator range for all of the uses, for any symbol, that are nested
777 /// within the given operation 'from'. This does not traverse into any nested
778 /// symbol tables, and will also only return uses on 'from' if it does not
779 /// also define a symbol table. This is because we treat the region as the
780 /// boundary of the symbol table, and not the op itself. This function returns
781 /// std::nullopt if there are any unknown operations that may potentially be
782 /// symbol tables.
783 auto SymbolTable::getSymbolUses(Operation *from) -> std::optional<UseRange> {
784 return getSymbolUsesImpl(from);
786 auto SymbolTable::getSymbolUses(Region *from) -> std::optional<UseRange> {
787 return getSymbolUsesImpl(MutableArrayRef<Region>(*from));
790 //===----------------------------------------------------------------------===//
791 // SymbolTable::getSymbolUses
793 /// The implementation of SymbolTable::getSymbolUses below.
794 template <typename SymbolT, typename IRUnitT>
795 static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
796 IRUnitT *limit) {
797 std::vector<SymbolTable::SymbolUse> uses;
798 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
799 if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
800 if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
801 uses.push_back(symbolUse);
803 return std::nullopt;
805 return SymbolTable::UseRange(std::move(uses));
808 /// Get all of the uses of the given symbol that are nested within the given
809 /// operation 'from', invoking the provided callback for each. This does not
810 /// traverse into any nested symbol tables. This function returns std::nullopt
811 /// if there are any unknown operations that may potentially be symbol tables.
812 auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from)
813 -> std::optional<UseRange> {
814 return getSymbolUsesImpl(symbol, from);
816 auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from)
817 -> std::optional<UseRange> {
818 return getSymbolUsesImpl(symbol, from);
820 auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from)
821 -> std::optional<UseRange> {
822 return getSymbolUsesImpl(symbol, from);
824 auto SymbolTable::getSymbolUses(Operation *symbol, Region *from)
825 -> std::optional<UseRange> {
826 return getSymbolUsesImpl(symbol, from);
829 //===----------------------------------------------------------------------===//
830 // SymbolTable::symbolKnownUseEmpty
832 /// The implementation of SymbolTable::symbolKnownUseEmpty below.
833 template <typename SymbolT, typename IRUnitT>
834 static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
835 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
836 // Walk all of the symbol uses looking for a reference to 'symbol'.
837 if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
838 return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
839 ? WalkResult::interrupt()
840 : WalkResult::advance();
841 }) != WalkResult::advance())
842 return false;
844 return true;
847 /// Return if the given symbol is known to have no uses that are nested within
848 /// the given operation 'from'. This does not traverse into any nested symbol
849 /// tables. This function will also return false if there are any unknown
850 /// operations that may potentially be symbol tables.
851 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
852 return symbolKnownUseEmptyImpl(symbol, from);
854 bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) {
855 return symbolKnownUseEmptyImpl(symbol, from);
857 bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
858 return symbolKnownUseEmptyImpl(symbol, from);
860 bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
861 return symbolKnownUseEmptyImpl(symbol, from);
864 //===----------------------------------------------------------------------===//
865 // SymbolTable::replaceAllSymbolUses
867 /// Generates a new symbol reference attribute with a new leaf reference.
868 static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
869 FlatSymbolRefAttr newLeafAttr) {
870 if (llvm::isa<FlatSymbolRefAttr>(oldAttr))
871 return newLeafAttr;
872 auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
873 nestedRefs.back() = newLeafAttr;
874 return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs);
877 /// The implementation of SymbolTable::replaceAllSymbolUses below.
878 template <typename SymbolT, typename IRUnitT>
879 static LogicalResult
880 replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
881 // Generate a new attribute to replace the given attribute.
882 FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
883 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
884 SymbolRefAttr oldAttr = scope.symbol;
885 SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
886 AttrTypeReplacer replacer;
887 replacer.addReplacement(
888 [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
889 // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
890 // want to accidentally replace an inner reference.
891 if (attr == oldAttr)
892 return {newAttr, WalkResult::skip()};
893 // Handle prefix matches.
894 if (isReferencePrefixOf(oldAttr, attr)) {
895 auto oldNestedRefs = oldAttr.getNestedReferences();
896 auto nestedRefs = attr.getNestedReferences();
897 if (oldNestedRefs.empty())
898 return {SymbolRefAttr::get(newSymbol, nestedRefs),
899 WalkResult::skip()};
901 auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
902 newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
903 return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs),
904 WalkResult::skip()};
906 return {attr, WalkResult::skip()};
909 auto walkFn = [&](Operation *op) -> std::optional<WalkResult> {
910 replacer.replaceElementsIn(op);
911 return WalkResult::advance();
913 if (!scope.walkSymbolTable(walkFn))
914 return failure();
916 return success();
919 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
920 /// provided symbol 'newSymbol' that are nested within the given operation
921 /// 'from'. This does not traverse into any nested symbol tables. If there are
922 /// any unknown operations that may potentially be symbol tables, no uses are
923 /// replaced and failure is returned.
924 LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
925 StringAttr newSymbol,
926 Operation *from) {
927 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
929 LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
930 StringAttr newSymbol,
931 Operation *from) {
932 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
934 LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
935 StringAttr newSymbol,
936 Region *from) {
937 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
939 LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
940 StringAttr newSymbol,
941 Region *from) {
942 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
945 //===----------------------------------------------------------------------===//
946 // SymbolTableCollection
947 //===----------------------------------------------------------------------===//
949 Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
950 StringAttr symbol) {
951 return getSymbolTable(symbolTableOp).lookup(symbol);
953 Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
954 SymbolRefAttr name) {
955 SmallVector<Operation *, 4> symbols;
956 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
957 return nullptr;
958 return symbols.back();
960 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
961 /// a given SymbolRefAttr. Returns failure if any of the nested references could
962 /// not be resolved.
963 LogicalResult
964 SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
965 SymbolRefAttr name,
966 SmallVectorImpl<Operation *> &symbols) {
967 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
968 return lookupSymbolIn(symbolTableOp, symbol);
970 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
973 /// Returns the operation registered with the given symbol name within the
974 /// closest parent operation of, or including, 'from' with the
975 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
976 /// found.
977 Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
978 StringAttr symbol) {
979 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
980 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
982 Operation *
983 SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
984 SymbolRefAttr symbol) {
985 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
986 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
989 /// Lookup, or create, a symbol table for an operation.
990 SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
991 auto it = symbolTables.try_emplace(op, nullptr);
992 if (it.second)
993 it.first->second = std::make_unique<SymbolTable>(op);
994 return *it.first->second;
997 //===----------------------------------------------------------------------===//
998 // LockedSymbolTableCollection
999 //===----------------------------------------------------------------------===//
1001 Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
1002 StringAttr symbol) {
1003 return getSymbolTable(symbolTableOp).lookup(symbol);
1006 Operation *
1007 LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
1008 FlatSymbolRefAttr symbol) {
1009 return lookupSymbolIn(symbolTableOp, symbol.getAttr());
1012 Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
1013 SymbolRefAttr name) {
1014 SmallVector<Operation *> symbols;
1015 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
1016 return nullptr;
1017 return symbols.back();
1020 LogicalResult LockedSymbolTableCollection::lookupSymbolIn(
1021 Operation *symbolTableOp, SymbolRefAttr name,
1022 SmallVectorImpl<Operation *> &symbols) {
1023 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
1024 return lookupSymbolIn(symbolTableOp, symbol);
1026 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
1029 SymbolTable &
1030 LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) {
1031 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
1032 // Try to find an existing symbol table.
1034 llvm::sys::SmartScopedReader<true> lock(mutex);
1035 auto it = collection.symbolTables.find(symbolTableOp);
1036 if (it != collection.symbolTables.end())
1037 return *it->second;
1039 // Create a symbol table for the operation. Perform construction outside of
1040 // the critical section.
1041 auto symbolTable = std::make_unique<SymbolTable>(symbolTableOp);
1042 // Insert the constructed symbol table.
1043 llvm::sys::SmartScopedWriter<true> lock(mutex);
1044 return *collection.symbolTables
1045 .insert({symbolTableOp, std::move(symbolTable)})
1046 .first->second;
1049 //===----------------------------------------------------------------------===//
1050 // SymbolUserMap
1051 //===----------------------------------------------------------------------===//
1053 SymbolUserMap::SymbolUserMap(SymbolTableCollection &symbolTable,
1054 Operation *symbolTableOp)
1055 : symbolTable(symbolTable) {
1056 // Walk each of the symbol tables looking for discardable callgraph nodes.
1057 SmallVector<Operation *> symbols;
1058 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
1059 for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
1060 auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
1061 assert(symbolUses && "expected uses to be valid");
1063 for (const SymbolTable::SymbolUse &use : *symbolUses) {
1064 symbols.clear();
1065 (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
1066 symbols);
1067 for (Operation *symbolOp : symbols)
1068 symbolToUsers[symbolOp].insert(use.getUser());
1072 // We just set `allSymUsesVisible` to false here because it isn't necessary
1073 // for building the user map.
1074 SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
1075 walkFn);
1078 void SymbolUserMap::replaceAllUsesWith(Operation *symbol,
1079 StringAttr newSymbolName) {
1080 auto it = symbolToUsers.find(symbol);
1081 if (it == symbolToUsers.end())
1082 return;
1084 // Replace the uses within the users of `symbol`.
1085 for (Operation *user : it->second)
1086 (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
1088 // Move the current users of `symbol` to the new symbol if it is in the
1089 // symbol table.
1090 Operation *newSymbol =
1091 symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
1092 if (newSymbol != symbol) {
1093 // Transfer over the users to the new symbol. The reference to the old one
1094 // is fetched again as the iterator is invalidated during the insertion.
1095 auto newIt = symbolToUsers.try_emplace(newSymbol, SetVector<Operation *>{});
1096 auto oldIt = symbolToUsers.find(symbol);
1097 assert(oldIt != symbolToUsers.end() && "missing old users list");
1098 if (newIt.second)
1099 newIt.first->second = std::move(oldIt->second);
1100 else
1101 newIt.first->second.set_union(oldIt->second);
1102 symbolToUsers.erase(oldIt);
1106 //===----------------------------------------------------------------------===//
1107 // Visibility parsing implementation.
1108 //===----------------------------------------------------------------------===//
1110 ParseResult impl::parseOptionalVisibilityKeyword(OpAsmParser &parser,
1111 NamedAttrList &attrs) {
1112 StringRef visibility;
1113 if (parser.parseOptionalKeyword(&visibility, {"public", "private", "nested"}))
1114 return failure();
1116 StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility);
1117 attrs.push_back(parser.getBuilder().getNamedAttr(
1118 SymbolTable::getVisibilityAttrName(), visibilityAttr));
1119 return success();
1122 //===----------------------------------------------------------------------===//
1123 // Symbol Interfaces
1124 //===----------------------------------------------------------------------===//
1126 /// Include the generated symbol interfaces.
1127 #include "mlir/IR/SymbolInterfaces.cpp.inc"