1 //===- SymbolPrivatize.cpp - Pass to mark symbols private -----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file implements a pass that marks all symbols as private unless they
// are explicitly excluded.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Transforms/Passes.h"
16 #include "mlir/IR/SymbolTable.h"
19 #define GEN_PASS_DEF_SYMBOLPRIVATIZE
20 #include "mlir/Transforms/Passes.h.inc"
26 struct SymbolPrivatize
: public impl::SymbolPrivatizeBase
<SymbolPrivatize
> {
27 explicit SymbolPrivatize(ArrayRef
<std::string
> excludeSymbols
);
28 LogicalResult
initialize(MLIRContext
*context
) override
;
29 void runOnOperation() override
;
31 /// Symbols whose visibility won't be changed.
32 DenseSet
<StringAttr
> excludedSymbols
;
36 SymbolPrivatize::SymbolPrivatize(llvm::ArrayRef
<std::string
> excludeSymbols
) {
37 exclude
= excludeSymbols
;
40 LogicalResult
SymbolPrivatize::initialize(MLIRContext
*context
) {
41 for (const std::string
&symbol
: exclude
)
42 excludedSymbols
.insert(StringAttr::get(context
, symbol
));
46 void SymbolPrivatize::runOnOperation() {
47 for (Region
®ion
: getOperation()->getRegions()) {
48 for (Block
&block
: region
) {
49 for (Operation
&op
: block
) {
50 auto symbol
= dyn_cast
<SymbolOpInterface
>(op
);
53 if (!excludedSymbols
.contains(symbol
.getNameAttr()))
54 symbol
.setVisibility(SymbolTable::Visibility::Private
);
61 mlir::createSymbolPrivatizePass(ArrayRef
<std::string
> exclude
) {
62 return std::make_unique
<SymbolPrivatize
>(exclude
);