DAG: Fix assuming f16 is the only 16-bit fp type in concat vector combine (#121637)
[llvm-project.git] / flang / lib / Optimizer / Transforms / CompilerGeneratedNames.cpp
blob f92c60908b149622a1b3b55aaf3ad1fd84ccff6c
1 //=== CompilerGeneratedNames.cpp - convert special symbols in global names ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/Dialect/FIRDialect.h"
10 #include "flang/Optimizer/Dialect/FIROps.h"
11 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
12 #include "flang/Optimizer/Support/InternalNames.h"
13 #include "flang/Optimizer/Transforms/Passes.h"
14 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/SymbolTable.h"
17 #include "mlir/Pass/Pass.h"
19 namespace fir {
20 #define GEN_PASS_DEF_COMPILERGENERATEDNAMESCONVERSION
21 #include "flang/Optimizer/Transforms/Passes.h.inc"
22 } // namespace fir
24 using namespace mlir;
26 namespace {
28 class CompilerGeneratedNamesConversionPass
29 : public fir::impl::CompilerGeneratedNamesConversionBase<
30 CompilerGeneratedNamesConversionPass> {
31 public:
32 using CompilerGeneratedNamesConversionBase<
33 CompilerGeneratedNamesConversionPass>::
34 CompilerGeneratedNamesConversionBase;
36 mlir::ModuleOp getModule() { return getOperation(); }
37 void runOnOperation() override;
39 } // namespace
41 void CompilerGeneratedNamesConversionPass::runOnOperation() {
42 auto op = getOperation();
43 auto *context = &getContext();
45 llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr> remappings;
47 auto processOp = [&](mlir::Operation &op) {
48 auto symName = op.getAttrOfType<mlir::StringAttr>(
49 mlir::SymbolTable::getSymbolAttrName());
50 auto deconstructedName = fir::NameUniquer::deconstruct(symName);
51 if (deconstructedName.first != fir::NameUniquer::NameKind::NOT_UNIQUED &&
52 !fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) {
53 std::string newName =
54 fir::NameUniquer::replaceSpecialSymbols(symName.getValue().str());
55 if (newName != symName) {
56 auto newAttr = mlir::StringAttr::get(context, newName);
57 mlir::SymbolTable::setSymbolName(&op, newAttr);
58 auto newSymRef = mlir::FlatSymbolRefAttr::get(newAttr);
59 remappings.try_emplace(symName, newSymRef);
63 for (auto &op : op->getRegion(0).front()) {
64 if (llvm::isa<mlir::func::FuncOp>(op) || llvm::isa<fir::GlobalOp>(op))
65 processOp(op);
66 else if (auto gpuMod = mlir::dyn_cast<mlir::gpu::GPUModuleOp>(&op))
67 for (auto &op : gpuMod->getRegion(0).front())
68 if (llvm::isa<mlir::func::FuncOp>(op) || llvm::isa<fir::GlobalOp>(op) ||
69 llvm::isa<mlir::gpu::GPUFuncOp>(op))
70 processOp(op);
73 if (remappings.empty())
74 return;
76 // Update all uses of the functions and globals that have been renamed.
77 op.walk([&remappings](mlir::Operation *nestedOp) {
78 llvm::SmallVector<std::pair<mlir::StringAttr, mlir::SymbolRefAttr>> updates;
79 for (const mlir::NamedAttribute &attr : nestedOp->getAttrDictionary())
80 if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue()))
81 if (auto remap = remappings.find(symRef.getRootReference());
82 remap != remappings.end())
83 updates.emplace_back(std::pair<mlir::StringAttr, mlir::SymbolRefAttr>{
84 attr.getName(), mlir::SymbolRefAttr(remap->second)});
85 for (auto update : updates)
86 nestedOp->setAttr(update.first, update.second);
87 });