[Clang][ASTMatcher] Add a matcher for the name of a DependentScopeDeclRefExpr (#121656)
[llvm-project.git] / flang / lib / Optimizer / Transforms / CUFDeviceGlobal.cpp
blob2e6c272fa90891106e396e0cc068bdf5062bdfff
1 //===-- CUFDeviceGlobal.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "flang/Common/Fortran.h"
10 #include "flang/Optimizer/Builder/CUFCommon.h"
11 #include "flang/Optimizer/Dialect/CUF/CUFOps.h"
12 #include "flang/Optimizer/Dialect/FIRDialect.h"
13 #include "flang/Optimizer/Dialect/FIROps.h"
14 #include "flang/Optimizer/HLFIR/HLFIROps.h"
15 #include "flang/Optimizer/Support/InternalNames.h"
16 #include "flang/Runtime/CUDA/common.h"
17 #include "flang/Runtime/allocatable.h"
18 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
19 #include "mlir/IR/SymbolTable.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/DenseSet.h"
24 namespace fir {
25 #define GEN_PASS_DEF_CUFDEVICEGLOBAL
26 #include "flang/Optimizer/Transforms/Passes.h.inc"
27 } // namespace fir
29 namespace {
31 static void processAddrOfOp(fir::AddrOfOp addrOfOp,
32 mlir::SymbolTable &symbolTable,
33 llvm::DenseSet<fir::GlobalOp> &candidates,
34 bool recurseInGlobal) {
35 if (auto globalOp = symbolTable.lookup<fir::GlobalOp>(
36 addrOfOp.getSymbol().getRootReference().getValue())) {
37 // TO DO: limit candidates to non-scalars. Scalars appear to have been
38 // folded in already.
39 if (globalOp.getConstant()) {
40 if (recurseInGlobal)
41 globalOp.walk([&](fir::AddrOfOp op) {
42 processAddrOfOp(op, symbolTable, candidates, recurseInGlobal);
43 });
44 candidates.insert(globalOp);
49 static void processEmboxOp(fir::EmboxOp emboxOp, mlir::SymbolTable &symbolTable,
50 llvm::DenseSet<fir::GlobalOp> &candidates) {
51 if (auto recTy = mlir::dyn_cast<fir::RecordType>(
52 fir::unwrapRefType(emboxOp.getMemref().getType()))) {
53 if (auto globalOp = symbolTable.lookup<fir::GlobalOp>(
54 fir::NameUniquer::getTypeDescriptorName(recTy.getName()))) {
55 if (!candidates.contains(globalOp)) {
56 globalOp.walk([&](fir::AddrOfOp op) {
57 processAddrOfOp(op, symbolTable, candidates,
58 /*recurseInGlobal=*/true);
59 });
60 candidates.insert(globalOp);
66 static void
67 prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp,
68 mlir::SymbolTable &symbolTable,
69 llvm::DenseSet<fir::GlobalOp> &candidates) {
70 auto cudaProcAttr{
71 funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName())};
72 if (cudaProcAttr && cudaProcAttr.getValue() != cuf::ProcAttribute::Host) {
73 funcOp.walk([&](fir::AddrOfOp op) {
74 processAddrOfOp(op, symbolTable, candidates, /*recurseInGlobal=*/false);
75 });
76 funcOp.walk(
77 [&](fir::EmboxOp op) { processEmboxOp(op, symbolTable, candidates); });
81 class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
82 public:
83 void runOnOperation() override {
84 mlir::Operation *op = getOperation();
85 mlir::ModuleOp mod = mlir::dyn_cast<mlir::ModuleOp>(op);
86 if (!mod)
87 return signalPassFailure();
89 llvm::DenseSet<fir::GlobalOp> candidates;
90 mlir::SymbolTable symTable(mod);
91 mod.walk([&](mlir::func::FuncOp funcOp) {
92 prepareImplicitDeviceGlobals(funcOp, symTable, candidates);
93 return mlir::WalkResult::advance();
94 });
95 mod.walk([&](cuf::KernelOp kernelOp) {
96 kernelOp.walk([&](fir::AddrOfOp addrOfOp) {
97 processAddrOfOp(addrOfOp, symTable, candidates,
98 /*recurseInGlobal=*/false);
99 });
102 // Copying the device global variable into the gpu module
103 mlir::SymbolTable parentSymTable(mod);
104 auto gpuMod = cuf::getOrCreateGPUModule(mod, parentSymTable);
105 if (!gpuMod)
106 return signalPassFailure();
107 mlir::SymbolTable gpuSymTable(gpuMod);
108 for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
109 if (cuf::isRegisteredDeviceGlobal(globalOp))
110 candidates.insert(globalOp);
112 for (auto globalOp : candidates) {
113 auto globalName{globalOp.getSymbol().getValue()};
114 if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) {
115 break;
117 gpuSymTable.insert(globalOp->clone());
121 } // namespace