1 //===-- CUFAddConstructor.cpp ---------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/Builder/BoxValue.h"
10 #include "flang/Optimizer/Builder/CUFCommon.h"
11 #include "flang/Optimizer/Builder/FIRBuilder.h"
12 #include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
13 #include "flang/Optimizer/Builder/Todo.h"
14 #include "flang/Optimizer/CodeGen/Target.h"
15 #include "flang/Optimizer/CodeGen/TypeConverter.h"
16 #include "flang/Optimizer/Dialect/CUF/CUFOps.h"
17 #include "flang/Optimizer/Dialect/FIRAttr.h"
18 #include "flang/Optimizer/Dialect/FIRDialect.h"
19 #include "flang/Optimizer/Dialect/FIROps.h"
20 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
21 #include "flang/Optimizer/Dialect/FIRType.h"
22 #include "flang/Optimizer/Support/DataLayout.h"
23 #include "flang/Runtime/CUDA/registration.h"
24 #include "flang/Runtime/entry-names.h"
25 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
26 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
27 #include "mlir/IR/Value.h"
28 #include "mlir/Pass/Pass.h"
29 #include "llvm/ADT/SmallVector.h"
32 #define GEN_PASS_DEF_CUFADDCONSTRUCTOR
33 #include "flang/Optimizer/Transforms/Passes.h.inc"
36 using namespace Fortran::runtime::cuda
;
40 static constexpr llvm::StringRef cudaFortranCtorName
{
41 "__cudaFortranConstructor"};
43 struct CUFAddConstructor
44 : public fir::impl::CUFAddConstructorBase
<CUFAddConstructor
> {
46 void runOnOperation() override
{
47 mlir::ModuleOp mod
= getOperation();
48 mlir::SymbolTable
symTab(mod
);
49 mlir::OpBuilder opBuilder
{mod
.getBodyRegion()};
50 fir::FirOpBuilder
builder(opBuilder
, mod
);
51 fir::KindMapping kindMap
{fir::getKindMapping(mod
)};
52 builder
.setInsertionPointToEnd(mod
.getBody());
53 mlir::Location loc
= mod
.getLoc();
54 auto *ctx
= mod
.getContext();
55 auto voidTy
= mlir::LLVM::LLVMVoidType::get(ctx
);
56 auto idxTy
= builder
.getIndexType();
58 mlir::LLVM::LLVMFunctionType::get(voidTy
, {}, /*isVarArg=*/false);
59 std::optional
<mlir::DataLayout
> dl
=
60 fir::support::getOrSetDataLayout(mod
, /*allowDefaultLayout=*/false);
62 mlir::emitError(mod
.getLoc(),
63 "data layout attribute is required to perform " +
67 // Symbol reference to CUFRegisterAllocator.
68 builder
.setInsertionPointToEnd(mod
.getBody());
69 auto registerFuncOp
= builder
.create
<mlir::LLVM::LLVMFuncOp
>(
70 loc
, RTNAME_STRING(CUFRegisterAllocator
), funcTy
);
71 registerFuncOp
.setVisibility(mlir::SymbolTable::Visibility::Private
);
72 auto cufRegisterAllocatorRef
= mlir::SymbolRefAttr::get(
73 mod
.getContext(), RTNAME_STRING(CUFRegisterAllocator
));
74 builder
.setInsertionPointToEnd(mod
.getBody());
76 // Create the constructor function that call CUFRegisterAllocator.
77 auto func
= builder
.create
<mlir::LLVM::LLVMFuncOp
>(loc
, cudaFortranCtorName
,
79 func
.setLinkage(mlir::LLVM::Linkage::Internal
);
80 builder
.setInsertionPointToStart(func
.addEntryBlock(builder
));
81 builder
.create
<mlir::LLVM::CallOp
>(loc
, funcTy
, cufRegisterAllocatorRef
);
83 auto gpuMod
= symTab
.lookup
<mlir::gpu::GPUModuleOp
>(cudaDeviceModuleName
);
85 auto llvmPtrTy
= mlir::LLVM::LLVMPointerType::get(ctx
);
86 auto registeredMod
= builder
.create
<cuf::RegisterModuleOp
>(
87 loc
, llvmPtrTy
, mlir::SymbolRefAttr::get(ctx
, gpuMod
.getName()));
89 fir::LLVMTypeConverter
typeConverter(mod
, /*applyTBAA=*/false,
90 /*forceUnifiedTBAATree=*/false, *dl
);
92 for (auto func
: gpuMod
.getOps
<mlir::gpu::GPUFuncOp
>()) {
93 if (func
.isKernel()) {
94 auto kernelName
= mlir::SymbolRefAttr::get(
95 builder
.getStringAttr(cudaDeviceModuleName
),
96 {mlir::SymbolRefAttr::get(builder
.getContext(), func
.getName())});
97 builder
.create
<cuf::RegisterKernelOp
>(loc
, kernelName
, registeredMod
);
101 // Register variables
102 for (fir::GlobalOp globalOp
: mod
.getOps
<fir::GlobalOp
>()) {
103 auto attr
= globalOp
.getDataAttrAttr();
107 mlir::func::FuncOp func
;
108 switch (attr
.getValue()) {
109 case cuf::DataAttribute::Device
:
110 case cuf::DataAttribute::Constant
: {
111 func
= fir::runtime::getRuntimeFunc
<mkRTKey(CUFRegisterVariable
)>(
113 auto fTy
= func
.getFunctionType();
115 // Global variable name
116 std::string gblNameStr
= globalOp
.getSymbol().getValue().str();
118 mlir::Value gblName
= fir::getBase(
119 fir::factory::createStringLiteral(builder
, loc
, gblNameStr
));
121 // Global variable size
122 std::optional
<uint64_t> size
;
124 mlir::dyn_cast
<fir::BaseBoxType
>(globalOp
.getType())) {
125 mlir::Type structTy
= typeConverter
.convertBoxTypeAsStruct(boxTy
);
126 size
= dl
->getTypeSizeInBits(structTy
) / 8;
129 size
= fir::getTypeSizeAndAlignmentOrCrash(loc
, globalOp
.getType(),
133 auto sizeVal
= builder
.createIntegerConstant(loc
, idxTy
, *size
);
135 // Global variable address
136 mlir::Value addr
= builder
.create
<fir::AddrOfOp
>(
137 loc
, globalOp
.resultType(), globalOp
.getSymbol());
139 llvm::SmallVector
<mlir::Value
> args
{fir::runtime::createArguments(
140 builder
, loc
, fTy
, registeredMod
, addr
, gblName
, sizeVal
)};
141 builder
.create
<fir::CallOp
>(loc
, func
, args
);
143 case cuf::DataAttribute::Managed
:
144 TODO(loc
, "registration of managed variables");
150 builder
.create
<mlir::LLVM::ReturnOp
>(loc
, mlir::ValueRange
{});
152 // Create the llvm.global_ctor with the function.
153 // TODO: We might want to have a utility that retrieve it if already
154 // created and adds new functions.
155 builder
.setInsertionPointToEnd(mod
.getBody());
156 llvm::SmallVector
<mlir::Attribute
> funcs
;
158 mlir::FlatSymbolRefAttr::get(mod
.getContext(), func
.getSymName()));
159 llvm::SmallVector
<int> priorities
;
160 priorities
.push_back(0);
161 builder
.create
<mlir::LLVM::GlobalCtorsOp
>(
162 mod
.getLoc(), builder
.getArrayAttr(funcs
),
163 builder
.getI32ArrayAttr(priorities
));
167 } // end anonymous namespace