//===-- CUFDeviceGlobal.cpp -----------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "flang/Common/Fortran.h"
10 #include "flang/Optimizer/Builder/CUFCommon.h"
11 #include "flang/Optimizer/Dialect/CUF/CUFOps.h"
12 #include "flang/Optimizer/Dialect/FIRDialect.h"
13 #include "flang/Optimizer/Dialect/FIROps.h"
14 #include "flang/Optimizer/HLFIR/HLFIROps.h"
15 #include "flang/Optimizer/Support/InternalNames.h"
16 #include "flang/Runtime/CUDA/common.h"
17 #include "flang/Runtime/allocatable.h"
18 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
19 #include "mlir/IR/SymbolTable.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/DenseSet.h"
25 #define GEN_PASS_DEF_CUFDEVICEGLOBAL
26 #include "flang/Optimizer/Transforms/Passes.h.inc"
31 static void processAddrOfOp(fir::AddrOfOp addrOfOp
,
32 mlir::SymbolTable
&symbolTable
,
33 llvm::DenseSet
<fir::GlobalOp
> &candidates
,
34 bool recurseInGlobal
) {
35 if (auto globalOp
= symbolTable
.lookup
<fir::GlobalOp
>(
36 addrOfOp
.getSymbol().getRootReference().getValue())) {
37 // TO DO: limit candidates to non-scalars. Scalars appear to have been
39 if (globalOp
.getConstant()) {
41 globalOp
.walk([&](fir::AddrOfOp op
) {
42 processAddrOfOp(op
, symbolTable
, candidates
, recurseInGlobal
);
44 candidates
.insert(globalOp
);
49 static void processEmboxOp(fir::EmboxOp emboxOp
, mlir::SymbolTable
&symbolTable
,
50 llvm::DenseSet
<fir::GlobalOp
> &candidates
) {
51 if (auto recTy
= mlir::dyn_cast
<fir::RecordType
>(
52 fir::unwrapRefType(emboxOp
.getMemref().getType()))) {
53 if (auto globalOp
= symbolTable
.lookup
<fir::GlobalOp
>(
54 fir::NameUniquer::getTypeDescriptorName(recTy
.getName()))) {
55 if (!candidates
.contains(globalOp
)) {
56 globalOp
.walk([&](fir::AddrOfOp op
) {
57 processAddrOfOp(op
, symbolTable
, candidates
,
58 /*recurseInGlobal=*/true);
60 candidates
.insert(globalOp
);
67 prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp
,
68 mlir::SymbolTable
&symbolTable
,
69 llvm::DenseSet
<fir::GlobalOp
> &candidates
) {
71 funcOp
->getAttrOfType
<cuf::ProcAttributeAttr
>(cuf::getProcAttrName())};
72 if (cudaProcAttr
&& cudaProcAttr
.getValue() != cuf::ProcAttribute::Host
) {
73 funcOp
.walk([&](fir::AddrOfOp op
) {
74 processAddrOfOp(op
, symbolTable
, candidates
, /*recurseInGlobal=*/false);
77 [&](fir::EmboxOp op
) { processEmboxOp(op
, symbolTable
, candidates
); });
81 class CUFDeviceGlobal
: public fir::impl::CUFDeviceGlobalBase
<CUFDeviceGlobal
> {
83 void runOnOperation() override
{
84 mlir::Operation
*op
= getOperation();
85 mlir::ModuleOp mod
= mlir::dyn_cast
<mlir::ModuleOp
>(op
);
87 return signalPassFailure();
89 llvm::DenseSet
<fir::GlobalOp
> candidates
;
90 mlir::SymbolTable
symTable(mod
);
91 mod
.walk([&](mlir::func::FuncOp funcOp
) {
92 prepareImplicitDeviceGlobals(funcOp
, symTable
, candidates
);
93 return mlir::WalkResult::advance();
95 mod
.walk([&](cuf::KernelOp kernelOp
) {
96 kernelOp
.walk([&](fir::AddrOfOp addrOfOp
) {
97 processAddrOfOp(addrOfOp
, symTable
, candidates
,
98 /*recurseInGlobal=*/false);
102 // Copying the device global variable into the gpu module
103 mlir::SymbolTable
parentSymTable(mod
);
104 auto gpuMod
= cuf::getOrCreateGPUModule(mod
, parentSymTable
);
106 return signalPassFailure();
107 mlir::SymbolTable
gpuSymTable(gpuMod
);
108 for (auto globalOp
: mod
.getOps
<fir::GlobalOp
>()) {
109 if (cuf::isRegisteredDeviceGlobal(globalOp
))
110 candidates
.insert(globalOp
);
112 for (auto globalOp
: candidates
) {
113 auto globalName
{globalOp
.getSymbol().getValue()};
114 if (gpuSymTable
.lookup
<fir::GlobalOp
>(globalName
)) {
117 gpuSymTable
.insert(globalOp
->clone());