[MemProf] Templatize CallStackRadixTreeBuilder (NFC) (#117014)
[llvm-project.git] / flang / lib / Optimizer / Transforms / CUFAddConstructor.cpp
blobdd204126be5dbce078baadb0d450b2895e1333bc
1 //===-- CUFAddConstructor.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/Builder/BoxValue.h"
10 #include "flang/Optimizer/Builder/FIRBuilder.h"
11 #include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
12 #include "flang/Optimizer/Builder/Todo.h"
13 #include "flang/Optimizer/CodeGen/Target.h"
14 #include "flang/Optimizer/CodeGen/TypeConverter.h"
15 #include "flang/Optimizer/Dialect/CUF/CUFOps.h"
16 #include "flang/Optimizer/Dialect/FIRAttr.h"
17 #include "flang/Optimizer/Dialect/FIRDialect.h"
18 #include "flang/Optimizer/Dialect/FIROps.h"
19 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
20 #include "flang/Optimizer/Dialect/FIRType.h"
21 #include "flang/Optimizer/Support/DataLayout.h"
22 #include "flang/Optimizer/Transforms/CUFCommon.h"
23 #include "flang/Runtime/CUDA/registration.h"
24 #include "flang/Runtime/entry-names.h"
25 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
26 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
27 #include "mlir/IR/Value.h"
28 #include "mlir/Pass/Pass.h"
29 #include "llvm/ADT/SmallVector.h"
31 namespace fir {
32 #define GEN_PASS_DEF_CUFADDCONSTRUCTOR
33 #include "flang/Optimizer/Transforms/Passes.h.inc"
34 } // namespace fir
36 using namespace Fortran::runtime::cuda;
38 namespace {
40 static constexpr llvm::StringRef cudaFortranCtorName{
41 "__cudaFortranConstructor"};
43 struct CUFAddConstructor
44 : public fir::impl::CUFAddConstructorBase<CUFAddConstructor> {
46 void runOnOperation() override {
47 mlir::ModuleOp mod = getOperation();
48 mlir::SymbolTable symTab(mod);
49 mlir::OpBuilder opBuilder{mod.getBodyRegion()};
50 fir::FirOpBuilder builder(opBuilder, mod);
51 fir::KindMapping kindMap{fir::getKindMapping(mod)};
52 builder.setInsertionPointToEnd(mod.getBody());
53 mlir::Location loc = mod.getLoc();
54 auto *ctx = mod.getContext();
55 auto voidTy = mlir::LLVM::LLVMVoidType::get(ctx);
56 auto idxTy = builder.getIndexType();
57 auto funcTy =
58 mlir::LLVM::LLVMFunctionType::get(voidTy, {}, /*isVarArg=*/false);
59 std::optional<mlir::DataLayout> dl =
60 fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/false);
61 if (!dl) {
62 mlir::emitError(mod.getLoc(),
63 "data layout attribute is required to perform " +
64 getName() + "pass");
67 // Symbol reference to CUFRegisterAllocator.
68 builder.setInsertionPointToEnd(mod.getBody());
69 auto registerFuncOp = builder.create<mlir::LLVM::LLVMFuncOp>(
70 loc, RTNAME_STRING(CUFRegisterAllocator), funcTy);
71 registerFuncOp.setVisibility(mlir::SymbolTable::Visibility::Private);
72 auto cufRegisterAllocatorRef = mlir::SymbolRefAttr::get(
73 mod.getContext(), RTNAME_STRING(CUFRegisterAllocator));
74 builder.setInsertionPointToEnd(mod.getBody());
76 // Create the constructor function that call CUFRegisterAllocator.
77 auto func = builder.create<mlir::LLVM::LLVMFuncOp>(loc, cudaFortranCtorName,
78 funcTy);
79 func.setLinkage(mlir::LLVM::Linkage::Internal);
80 builder.setInsertionPointToStart(func.addEntryBlock(builder));
81 builder.create<mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);
83 auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
84 if (gpuMod) {
85 auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx);
86 auto registeredMod = builder.create<cuf::RegisterModuleOp>(
87 loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName()));
89 fir::LLVMTypeConverter typeConverter(mod, /*applyTBAA=*/false,
90 /*forceUnifiedTBAATree=*/false, *dl);
91 // Register kernels
92 for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
93 if (func.isKernel()) {
94 auto kernelName = mlir::SymbolRefAttr::get(
95 builder.getStringAttr(cudaDeviceModuleName),
96 {mlir::SymbolRefAttr::get(builder.getContext(), func.getName())});
97 builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
101 // Register variables
102 for (fir::GlobalOp globalOp : mod.getOps<fir::GlobalOp>()) {
103 auto attr = globalOp.getDataAttrAttr();
104 if (!attr)
105 continue;
107 mlir::func::FuncOp func;
108 switch (attr.getValue()) {
109 case cuf::DataAttribute::Device:
110 case cuf::DataAttribute::Constant: {
111 func = fir::runtime::getRuntimeFunc<mkRTKey(CUFRegisterVariable)>(
112 loc, builder);
113 auto fTy = func.getFunctionType();
115 // Global variable name
116 std::string gblNameStr = globalOp.getSymbol().getValue().str();
117 gblNameStr += '\0';
118 mlir::Value gblName = fir::getBase(
119 fir::factory::createStringLiteral(builder, loc, gblNameStr));
121 // Global variable size
122 std::optional<uint64_t> size;
123 if (auto boxTy =
124 mlir::dyn_cast<fir::BaseBoxType>(globalOp.getType())) {
125 mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
126 size = dl->getTypeSizeInBits(structTy) / 8;
128 if (!size) {
129 size = fir::getTypeSizeAndAlignmentOrCrash(loc, globalOp.getType(),
130 *dl, kindMap)
131 .first;
133 auto sizeVal = builder.createIntegerConstant(loc, idxTy, *size);
135 // Global variable address
136 mlir::Value addr = builder.create<fir::AddrOfOp>(
137 loc, globalOp.resultType(), globalOp.getSymbol());
139 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
140 builder, loc, fTy, registeredMod, addr, gblName, sizeVal)};
141 builder.create<fir::CallOp>(loc, func, args);
142 } break;
143 case cuf::DataAttribute::Managed:
144 TODO(loc, "registration of managed variables");
145 default:
146 break;
148 if (!func)
149 continue;
152 builder.create<mlir::LLVM::ReturnOp>(loc, mlir::ValueRange{});
154 // Create the llvm.global_ctor with the function.
155 // TODO: We might want to have a utility that retrieve it if already
156 // created and adds new functions.
157 builder.setInsertionPointToEnd(mod.getBody());
158 llvm::SmallVector<mlir::Attribute> funcs;
159 funcs.push_back(
160 mlir::FlatSymbolRefAttr::get(mod.getContext(), func.getSymName()));
161 llvm::SmallVector<int> priorities;
162 priorities.push_back(0);
163 builder.create<mlir::LLVM::GlobalCtorsOp>(
164 mod.getLoc(), builder.getArrayAttr(funcs),
165 builder.getI32ArrayAttr(priorities));
169 } // end anonymous namespace