//===-- CUFOpConversion.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/Transforms/CUFOpConversion.h"
10 #include "flang/Common/Fortran.h"
11 #include "flang/Optimizer/Builder/CUFCommon.h"
12 #include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
13 #include "flang/Optimizer/CodeGen/TypeConverter.h"
14 #include "flang/Optimizer/Dialect/CUF/CUFOps.h"
15 #include "flang/Optimizer/Dialect/FIRDialect.h"
16 #include "flang/Optimizer/Dialect/FIROps.h"
17 #include "flang/Optimizer/HLFIR/HLFIROps.h"
18 #include "flang/Optimizer/Support/DataLayout.h"
19 #include "flang/Runtime/CUDA/allocatable.h"
20 #include "flang/Runtime/CUDA/common.h"
21 #include "flang/Runtime/CUDA/descriptor.h"
22 #include "flang/Runtime/CUDA/memory.h"
23 #include "flang/Runtime/allocatable.h"
24 #include "mlir/Conversion/LLVMCommon/Pattern.h"
25 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Transforms/DialectConversion.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace fir {
#define GEN_PASS_DEF_CUFOPCONVERSION
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

using namespace Fortran::runtime;
using namespace Fortran::runtime::cuda;

namespace {
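
/// Translate a CUF data attribute into the corresponding runtime memory type
/// code (one of the kMemType* constants).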
static inline unsigned getMemType(cuf::DataAttribute attr) {
  if (attr == cuf::DataAttribute::Device)
    return kMemTypeDevice;
  if (attr == cuf::DataAttribute::Managed)
    return kMemTypeManaged;
  if (attr == cuf::DataAttribute::Unified)
    return kMemTypeUnified;
  if (attr == cuf::DataAttribute::Pinned)
    return kMemTypePinned;
  llvm::report_fatal_error("unsupported memory type");
}

template <typename OpTy>
static bool isPinned(OpTy op) {
  if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned)
    return true;
  return false;
}

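/// Return true when the allocatable box is backed by a global (module
/// variable) address and is not pinned. Such variables have both a host and a
/// device descriptor that must be kept in sync, so the synchronized runtime
/// entry points are used for them.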
template <typename OpTy>
static bool hasDoubleDescriptors(OpTy op) {
  if (auto declareOp = mlir::dyn_cast_or_null<fir::DeclareOp>(
          op.getBox().getDefiningOp())) {
    if (mlir::isa_and_nonnull<fir::AddrOfOp>(
            declareOp.getMemref().getDefiningOp())) {
      if (isPinned(declareOp))
        return false;
      return true;
    }
  } else if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
                 op.getBox().getDefiningOp())) {
    if (mlir::isa_and_nonnull<fir::AddrOfOp>(
            declareOp.getMemref().getDefiningOp())) {
      if (isPinned(declareOp))
        return false;
      return true;
    }
  }
  return false;
}

static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
                                   mlir::Location loc, mlir::Type toTy,
                                   mlir::Value val) {
  if (val.getType() != toTy)
    return rewriter.create<fir::ConvertOp>(loc, toTy, val);
  return val;
}

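/// Rewrite a cuf.allocate or cuf.deallocate operation into a call to the given
/// runtime entry point, forwarding the box together with the stat, errmsg and
/// source location arguments (plus source and stream operands for allocate).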
template <typename OpTy>
static mlir::LogicalResult convertOpToCall(OpTy op,
                                           mlir::PatternRewriter &rewriter,
                                           mlir::func::FuncOp func) {
  auto mod = op->template getParentOfType<mlir::ModuleOp>();
  fir::FirOpBuilder builder(rewriter, mod);
  mlir::Location loc = op.getLoc();
  auto fTy = func.getFunctionType();

  mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
  mlir::Value sourceLine;
  if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>)
    sourceLine = fir::factory::locationToLineNo(
        builder, loc, op.getSource() ? fTy.getInput(6) : fTy.getInput(5));
  else
    sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));

  mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true)
                                        : builder.createBool(loc, false);

  mlir::Value errmsg;
  if (op.getErrmsg()) {
    errmsg = op.getErrmsg();
  } else {
    mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType());
    errmsg = builder.create<fir::AbsentOp>(loc, boxNoneTy).getResult();
  }
  llvm::SmallVector<mlir::Value> args;
  if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>) {
    if (op.getSource()) {
      mlir::Value stream =
          op.getStream()
              ? op.getStream()
              : builder.createIntegerConstant(loc, fTy.getInput(2), -1);
      args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(),
                                           op.getSource(), stream, hasStat,
                                           errmsg, sourceFile, sourceLine);
    } else {
      mlir::Value stream =
          op.getStream()
              ? op.getStream()
              : builder.createIntegerConstant(loc, fTy.getInput(1), -1);
      args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(),
                                           stream, hasStat, errmsg, sourceFile,
                                           sourceLine);
    }
  } else {
    args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(),
                                         hasStat, errmsg, sourceFile,
                                         sourceLine);
  }
  auto callOp = builder.create<fir::CallOp>(loc, func, args);
  rewriter.replaceOp(op, callOp);
  return mlir::success();
}

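/// Convert cuf.allocate to a CUF allocatable runtime call. Variables with both
/// host and device descriptors use the *Sync entry points so the two
/// descriptors stay consistent; other allocatables use the regular entry
/// points.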
struct CUFAllocateOpConversion
    : public mlir::OpRewritePattern<cuf::AllocateOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cuf::AllocateOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // TODO: Pinned is a reference to a logical value that can be set to true
    // when pinned allocation succeeds. This will require a new entry point.
    if (op.getPinned())
      return mlir::failure();

    auto mod = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, mod);
    mlir::Location loc = op.getLoc();

    if (hasDoubleDescriptors(op)) {
      // Allocations of module variables are done with a custom runtime entry
      // point so the descriptors can be synchronized.
      mlir::func::FuncOp func;
      if (op.getSource())
        func = fir::runtime::getRuntimeFunc<mkRTKey(
            CUFAllocatableAllocateSourceSync)>(loc, builder);
      else
        func =
            fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocateSync)>(
                loc, builder);
      return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
    }

    mlir::func::FuncOp func;
    if (op.getSource())
      func =
          fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocateSource)>(
              loc, builder);
    else
      func = fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>(
          loc, builder);
    return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
  }
};

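/// Convert cuf.deallocate. Variables with both descriptors go through
/// CUFAllocatableDeallocate; local descriptors fall back on the standard
/// AllocatableDeallocate entry point.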
struct CUFDeallocateOpConversion
    : public mlir::OpRewritePattern<cuf::DeallocateOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cuf::DeallocateOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mod = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, mod);
    mlir::Location loc = op.getLoc();

    if (hasDoubleDescriptors(op)) {
      // Deallocations of module variables are done with a custom runtime entry
      // point so the descriptors can be synchronized.
      mlir::func::FuncOp func =
          fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableDeallocate)>(
              loc, builder);
      return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
    }

    // Deallocation for a local descriptor falls back on the standard runtime
    // AllocatableDeallocate as the dedicated deallocator is set in the
    // descriptor before the call.
    mlir::func::FuncOp func =
        fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc,
                                                                     builder);
    return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
  }
};

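/// Return true when the operation is nested in device code: inside a
/// cuf.kernel region, a gpu.func, or a func.func carrying a CUF procedure
/// attribute other than host or host_device.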
static bool inDeviceContext(mlir::Operation *op) {
  if (op->getParentOfType<cuf::KernelOp>())
    return true;
  if (auto funcOp = op->getParentOfType<mlir::gpu::GPUFuncOp>())
    return true;
  if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) {
    if (auto cudaProcAttr =
            funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
                cuf::getProcAttrName())) {
      return cudaProcAttr.getValue() != cuf::ProcAttribute::Host &&
             cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice;
    }
  }
  return false;
}

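/// Width in bytes of one element of the given type, used to size raw
/// allocations and pointer-to-pointer transfers.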
static int computeWidth(mlir::Location loc, mlir::Type type,
                        fir::KindMapping &kindMap) {
  auto eleTy = fir::unwrapSequenceType(type);
  if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
    return t.getWidth() / 8;
  if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
    return t.getWidth() / 8;
  if (eleTy.isInteger(1))
    return 1;
  if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
    return kindMap.getLogicalBitsize(t.getFKind()) / 8;
  if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
    int elemSize =
        mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
    return 2 * elemSize;
  }
  if (auto t{mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)})
    return kindMap.getCharacterBitsize(t.getFKind()) / 8;
  mlir::emitError(loc, "unsupported type");
  return 0;
}

struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
  using OpRewritePattern::OpRewritePattern;

  CUFAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl,
                       const fir::LLVMTypeConverter *typeConverter)
      : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}

  mlir::LogicalResult
  matchAndRewrite(cuf::AllocOp op,
                  mlir::PatternRewriter &rewriter) const override {

    if (inDeviceContext(op.getOperation())) {
      // In device context just replace the cuf.alloc operation with a
      // fir.alloca; the matching cuf.free will be removed.
      rewriter.replaceOpWithNewOp<fir::AllocaOp>(
          op, op.getInType(), op.getUniqName() ? *op.getUniqName() : "",
          op.getBindcName() ? *op.getBindcName() : "", op.getTypeparams(),
          op.getShape());
      return mlir::success();
    }

    auto mod = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, mod);
    mlir::Location loc = op.getLoc();
    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);

    if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
      // Convert scalar and known size array allocations.
      mlir::Value bytes;
      fir::KindMapping kindMap{fir::getKindMapping(mod)};
      if (fir::isa_trivial(op.getInType())) {
        int width = computeWidth(loc, op.getInType(), kindMap);
        bytes =
            builder.createIntegerConstant(loc, builder.getIndexType(), width);
      } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
                     op.getInType())) {
        std::size_t size = 0;
        if (fir::isa_derived(seqTy.getEleTy())) {
          mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
          size = dl->getTypeSizeInBits(structTy) / 8;
        } else {
          size = computeWidth(loc, seqTy.getEleTy(), kindMap);
        }
        mlir::Value width =
            builder.createIntegerConstant(loc, builder.getIndexType(), size);
        mlir::Value nbElem;
        if (fir::sequenceWithNonConstantShape(seqTy)) {
          assert(!op.getShape().empty() && "expect shape with dynamic arrays");
          nbElem = builder.loadIfRef(loc, op.getShape()[0]);
          for (unsigned i = 1; i < op.getShape().size(); ++i) {
            nbElem = rewriter.create<mlir::arith::MulIOp>(
                loc, nbElem, builder.loadIfRef(loc, op.getShape()[i]));
          }
        } else {
          nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
                                                 seqTy.getConstantArraySize());
        }
        bytes = rewriter.create<mlir::arith::MulIOp>(loc, nbElem, width);
      } else if (fir::isa_derived(op.getInType())) {
        mlir::Type structTy = typeConverter->convertType(op.getInType());
        std::size_t structSize = dl->getTypeSizeInBits(structTy) / 8;
        bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
                                              structSize);
      } else {
        mlir::emitError(loc, "unsupported type in cuf.alloc\n");
      }
      mlir::func::FuncOp func =
          fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
      auto fTy = func.getFunctionType();
      mlir::Value sourceLine =
          fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
      mlir::Value memTy = builder.createIntegerConstant(
          loc, builder.getI32Type(), getMemType(op.getDataAttr()));
      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
          builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
      auto callOp = builder.create<fir::CallOp>(loc, func, args);
      auto convOp = builder.createConvert(loc, op.getResult().getType(),
                                          callOp.getResult(0));
      rewriter.replaceOp(op, convOp);
      return mlir::success();
    }

    // Convert descriptor allocations to function call.
    auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
    mlir::func::FuncOp func =
        fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDescriptor)>(loc, builder);
    auto fTy = func.getFunctionType();
    mlir::Value sourceLine =
        fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));

    mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy);
    std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
    mlir::Value sizeInBytes =
        builder.createIntegerConstant(loc, builder.getIndexType(), boxSize);

    llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
        builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)};
    auto callOp = builder.create<fir::CallOp>(loc, func, args);
    auto convOp = builder.createConvert(loc, op.getResult().getType(),
                                        callOp.getResult(0));
    rewriter.replaceOp(op, convOp);
    return mlir::success();
  }

private:
  mlir::DataLayout *dl;
  const fir::LLVMTypeConverter *typeConverter;
};

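/// Rewrite fir.declare of a registered device global so that its memref is the
/// device address obtained from the CUFGetDeviceAddress runtime call rather
/// than the host address of the global.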
struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
  using OpRewritePattern::OpRewritePattern;

  DeclareOpConversion(mlir::MLIRContext *context,
                      const mlir::SymbolTable &symtab)
      : OpRewritePattern(context), symTab{symtab} {}

  mlir::LogicalResult
  matchAndRewrite(fir::DeclareOp op,
                  mlir::PatternRewriter &rewriter) const override {
    if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
      if (auto global = symTab.lookup<fir::GlobalOp>(
              addrOfOp.getSymbol().getRootReference().getValue())) {
        if (cuf::isRegisteredDeviceGlobal(global)) {
          rewriter.setInsertionPointAfter(addrOfOp);
          auto mod = op->getParentOfType<mlir::ModuleOp>();
          fir::FirOpBuilder builder(rewriter, mod);
          mlir::Location loc = op.getLoc();
          mlir::func::FuncOp callee =
              fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(
                  loc, builder);
          auto fTy = callee.getFunctionType();
          mlir::Type toTy = fTy.getInput(0);
          mlir::Value inputArg =
              createConvertOp(rewriter, loc, toTy, addrOfOp.getResult());
          mlir::Value sourceFile =
              fir::factory::locationToFilename(builder, loc);
          mlir::Value sourceLine =
              fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
          llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
              builder, loc, fTy, inputArg, sourceFile, sourceLine)};
          auto call = rewriter.create<fir::CallOp>(loc, callee, args);
          mlir::Value cast = createConvertOp(
              rewriter, loc, op.getMemref().getType(), call->getResult(0));
          rewriter.startOpModification(op);
          op.getMemrefMutable().assign(cast);
          rewriter.finalizeOpModification(op);
          return mlir::success();
        }
      }
    }
    return mlir::failure();
  }

private:
  const mlir::SymbolTable &symTab;
};

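/// Convert cuf.free. Frees are simply erased in device context; on the host,
/// raw pointers go through CUFMemFree and descriptors through
/// CUFFreeDescriptor.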
struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cuf::FreeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    if (inDeviceContext(op.getOperation())) {
      rewriter.eraseOp(op);
      return mlir::success();
    }

    if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
      return mlir::failure();

    auto mod = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, mod);
    mlir::Location loc = op.getLoc();
    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);

    auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
    if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
      mlir::func::FuncOp func =
          fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
      auto fTy = func.getFunctionType();
      mlir::Value sourceLine =
          fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
      mlir::Value memTy = builder.createIntegerConstant(
          loc, builder.getI32Type(), getMemType(op.getDataAttr()));
      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
          builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
      builder.create<fir::CallOp>(loc, func, args);
      rewriter.eraseOp(op);
      return mlir::success();
    }

    // Convert cuf.free on descriptors.
    mlir::func::FuncOp func =
        fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDescriptor)>(loc, builder);
    auto fTy = func.getFunctionType();
    mlir::Value sourceLine =
        fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
    llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
        builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)};
    builder.create<fir::CallOp>(loc, func, args);
    rewriter.eraseOp(op);
    return mlir::success();
  }
};

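/// Return true when the destination of a data transfer is a global (its
/// declare operation is fed by an address-of operation).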
static bool isDstGlobal(cuf::DataTransferOp op) {
  if (auto declareOp = op.getDst().getDefiningOp<fir::DeclareOp>())
    if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
      return true;
  if (auto declareOp = op.getDst().getDefiningOp<hlfir::DeclareOp>())
    if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
      return true;
  return false;
}

static mlir::Value getShapeFromDecl(mlir::Value src) {
  if (auto declareOp = src.getDefiningOp<fir::DeclareOp>())
    return declareOp.getShape();
  if (auto declareOp = src.getDefiningOp<hlfir::DeclareOp>())
    return declareOp.getShape();
  return mlir::Value{};
}

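/// Embox the source of a data transfer and store the descriptor into a
/// temporary so descriptor-based runtime entry points can be used. Constant
/// scalars are materialized in memory first, and i1 values are stored as
/// fir.logical<4>.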
static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
                            cuf::DataTransferOp op,
                            const mlir::SymbolTable &symtab) {
  auto mod = op->getParentOfType<mlir::ModuleOp>();
  mlir::Location loc = op.getLoc();
  fir::FirOpBuilder builder(rewriter, mod);
  mlir::Value addr;
  mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
  if (fir::isa_trivial(srcTy) &&
      mlir::matchPattern(op.getSrc().getDefiningOp(), mlir::m_Constant())) {
    mlir::Value src = op.getSrc();
    if (srcTy.isInteger(1)) {
      // i1 is not a supported type in the descriptor and it is actually coming
      // from a LOGICAL constant. Store it as a fir.logical.
      srcTy = fir::LogicalType::get(rewriter.getContext(), 4);
      src = createConvertOp(rewriter, loc, srcTy, src);
    }
    // Put the constant in memory if it is not already.
    mlir::Value alloc = builder.createTemporary(loc, srcTy);
    builder.create<fir::StoreOp>(loc, src, alloc);
    addr = alloc;
  } else {
    addr = op.getSrc();
  }
  llvm::SmallVector<mlir::Value> lenParams;
  mlir::Type boxTy = fir::BoxType::get(srcTy);
  mlir::Value box =
      builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getSrc()),
                        /*slice=*/nullptr, lenParams,
                        /*tdesc=*/nullptr);
  mlir::Value src = builder.createTemporary(loc, box.getType());
  builder.create<fir::StoreOp>(loc, box, src);
  return src;
}

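/// Counterpart of emboxSrc for the destination of the transfer.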
static mlir::Value emboxDst(mlir::PatternRewriter &rewriter,
                            cuf::DataTransferOp op,
                            const mlir::SymbolTable &symtab) {
  auto mod = op->getParentOfType<mlir::ModuleOp>();
  mlir::Location loc = op.getLoc();
  fir::FirOpBuilder builder(rewriter, mod);
  mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
  mlir::Value dstAddr = op.getDst();
  mlir::Type dstBoxTy = fir::BoxType::get(dstTy);
  llvm::SmallVector<mlir::Value> lenParams;
  mlir::Value dstBox =
      builder.createBox(loc, dstBoxTy, dstAddr, getShapeFromDecl(op.getDst()),
                        /*slice=*/nullptr, lenParams,
                        /*tdesc=*/nullptr);
  mlir::Value dst = builder.createTemporary(loc, dstBox.getType());
  builder.create<fir::StoreOp>(loc, dstBox, dst);
  return dst;
}

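/// Convert cuf.data_transfer into CUF runtime calls. Transfers without any
/// descriptor use CUFDataTransferPtrPtr with an explicitly computed byte
/// count; transfers involving descriptors use the DescDesc/CstDesc entry
/// points, with the GlobalDescDesc variant when the destination is a global.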
struct CUFDataTransferOpConversion
    : public mlir::OpRewritePattern<cuf::DataTransferOp> {
  using OpRewritePattern::OpRewritePattern;

  CUFDataTransferOpConversion(mlir::MLIRContext *context,
                              const mlir::SymbolTable &symtab,
                              mlir::DataLayout *dl,
                              const fir::LLVMTypeConverter *typeConverter)
      : OpRewritePattern(context), symtab{symtab}, dl{dl},
        typeConverter{typeConverter} {}

  mlir::LogicalResult
  matchAndRewrite(cuf::DataTransferOp op,
                  mlir::PatternRewriter &rewriter) const override {

    mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
    mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());

    mlir::Location loc = op.getLoc();
    unsigned mode = 0;
    if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) {
      mode = kHostToDevice;
    } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceHost) {
      mode = kDeviceToHost;
    } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceDevice) {
      mode = kDeviceToDevice;
    } else {
      mlir::emitError(loc, "unsupported transfer kind\n");
    }

    auto mod = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, mod);
    fir::KindMapping kindMap{fir::getKindMapping(mod)};
    mlir::Value modeValue =
        builder.createIntegerConstant(loc, builder.getI32Type(), mode);

    // Convert data transfer without any descriptor.
    if (!mlir::isa<fir::BaseBoxType>(srcTy) &&
        !mlir::isa<fir::BaseBoxType>(dstTy)) {

      if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) {
        // Initialization of an array from a scalar value should be implemented
        // via a kernel launch. Use the flang runtime via the Assign function
        // until we have more infrastructure.
        mlir::Value src = emboxSrc(rewriter, op, symtab);
        mlir::Value dst = emboxDst(rewriter, op, symtab);
        mlir::func::FuncOp func =
            fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
                loc, builder);
        auto fTy = func.getFunctionType();
        mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
        mlir::Value sourceLine =
            fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
        llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
            builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
        builder.create<fir::CallOp>(loc, func, args);
        rewriter.eraseOp(op);
        return mlir::success();
      }

      mlir::Type i64Ty = builder.getI64Type();
      mlir::Value nbElement;
      if (op.getShape()) {
        llvm::SmallVector<mlir::Value> extents;
        if (auto shapeOp =
                mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp())) {
          extents = shapeOp.getExtents();
        } else if (auto shapeShiftOp = mlir::dyn_cast<fir::ShapeShiftOp>(
                       op.getShape().getDefiningOp())) {
          for (auto i : llvm::enumerate(shapeShiftOp.getPairs()))
            if (i.index() & 1)
              extents.push_back(i.value());
        }

        nbElement = rewriter.create<fir::ConvertOp>(loc, i64Ty, extents[0]);
        for (unsigned i = 1; i < extents.size(); ++i) {
          auto operand =
              rewriter.create<fir::ConvertOp>(loc, i64Ty, extents[i]);
          nbElement =
              rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand);
        }
      } else {
        if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(dstTy))
          nbElement = builder.createIntegerConstant(
              loc, i64Ty, seqTy.getConstantArraySize());
      }
      int width = 0;
      if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) {
        mlir::Type structTy =
            typeConverter->convertType(fir::unwrapSequenceType(dstTy));
        width = dl->getTypeSizeInBits(structTy) / 8;
      } else {
        width = computeWidth(loc, dstTy, kindMap);
      }
      mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>(
          loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));
      mlir::Value bytes =
          nbElement ? rewriter.create<mlir::arith::MulIOp>(loc, nbElement,
                                                           widthValue)
                    : widthValue;

      mlir::func::FuncOp func =
          fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrPtr)>(loc,
                                                                       builder);
      auto fTy = func.getFunctionType();
      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
      mlir::Value sourceLine =
          fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));

      mlir::Value dst = op.getDst();
      mlir::Value src = op.getSrc();
      // Materialize the src if constant.
      if (matchPattern(src.getDefiningOp(), mlir::m_Constant())) {
        mlir::Value temp = builder.createTemporary(loc, srcTy);
        builder.create<fir::StoreOp>(loc, src, temp);
        src = temp;
      }
      llvm::SmallVector<mlir::Value> args{
          fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
                                        modeValue, sourceFile, sourceLine)};
      builder.create<fir::CallOp>(loc, func, args);
      rewriter.eraseOp(op);
      return mlir::success();
    }

    auto materializeBoxIfNeeded = [&](mlir::Value val) -> mlir::Value {
      if (mlir::isa<fir::EmboxOp, fir::ReboxOp>(val.getDefiningOp())) {
        // Materialize the box to memory to be able to call the runtime.
        mlir::Value box = builder.createTemporary(loc, val.getType());
        builder.create<fir::StoreOp>(loc, val, box);
        return box;
      }
      return val;
    };

    // Conversion of data transfer involving at least one descriptor.
    if (mlir::isa<fir::BaseBoxType>(dstTy)) {
      // Transfer to a descriptor.
      mlir::func::FuncOp func =
          isDstGlobal(op)
              ? fir::runtime::getRuntimeFunc<mkRTKey(
                    CUFDataTransferGlobalDescDesc)>(loc, builder)
              : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
                    loc, builder);
      mlir::Value dst = op.getDst();
      mlir::Value src = op.getSrc();
      if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
        src = emboxSrc(rewriter, op, symtab);
        if (fir::isa_trivial(srcTy))
          func = fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
              loc, builder);
      }
      src = materializeBoxIfNeeded(src);
      dst = materializeBoxIfNeeded(dst);

      auto fTy = func.getFunctionType();
      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
      mlir::Value sourceLine =
          fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
          builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
      builder.create<fir::CallOp>(loc, func, args);
      rewriter.eraseOp(op);
    } else {
      // Transfer from a descriptor.
      mlir::Value dst = emboxDst(rewriter, op, symtab);
      mlir::Value src = materializeBoxIfNeeded(op.getSrc());

      mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
          CUFDataTransferDescDescNoRealloc)>(loc, builder);

      auto fTy = func.getFunctionType();
      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
      mlir::Value sourceLine =
          fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
          builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
      builder.create<fir::CallOp>(loc, func, args);
      rewriter.eraseOp(op);
    }
    return mlir::success();
  }

private:
  const mlir::SymbolTable &symtab;
  mlir::DataLayout *dl;
  const fir::LLVMTypeConverter *typeConverter;
};

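/// Convert cuf.kernel_launch to gpu.launch_func targeting the kernel in the
/// CUDA device module, forwarding grid and block sizes and, when the callee
/// carries a cluster_dims attribute, the cluster sizes.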
struct CUFLaunchOpConversion
    : public mlir::OpRewritePattern<cuf::KernelLaunchOp> {
  using OpRewritePattern::OpRewritePattern;

  CUFLaunchOpConversion(mlir::MLIRContext *context,
                        const mlir::SymbolTable &symTab)
      : OpRewritePattern(context), symTab{symTab} {}

  mlir::LogicalResult
  matchAndRewrite(cuf::KernelLaunchOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    auto idxTy = mlir::IndexType::get(op.getContext());
    auto zero = rewriter.create<mlir::arith::ConstantOp>(
        loc, rewriter.getIntegerType(32), rewriter.getI32IntegerAttr(0));
    auto gridSizeX =
        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridX());
    auto gridSizeY =
        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridY());
    auto gridSizeZ =
        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridZ());
    auto blockSizeX =
        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockX());
    auto blockSizeY =
        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockY());
    auto blockSizeZ =
        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockZ());
    auto kernelName = mlir::SymbolRefAttr::get(
        rewriter.getStringAttr(cudaDeviceModuleName),
        {mlir::SymbolRefAttr::get(
            rewriter.getContext(),
            op.getCallee().getLeafReference().getValue())});
    mlir::Value clusterDimX, clusterDimY, clusterDimZ;
    if (auto funcOp = symTab.lookup<mlir::func::FuncOp>(
            op.getCallee().getLeafReference())) {
      if (auto clusterDimsAttr = funcOp->getAttrOfType<cuf::ClusterDimsAttr>(
              cuf::getClusterDimsAttrName())) {
        clusterDimX = rewriter.create<mlir::arith::ConstantIndexOp>(
            loc, clusterDimsAttr.getX().getInt());
        clusterDimY = rewriter.create<mlir::arith::ConstantIndexOp>(
            loc, clusterDimsAttr.getY().getInt());
        clusterDimZ = rewriter.create<mlir::arith::ConstantIndexOp>(
            loc, clusterDimsAttr.getZ().getInt());
      }
    }
    auto gpuLaunchOp = rewriter.create<mlir::gpu::LaunchFuncOp>(
        loc, kernelName, mlir::gpu::KernelDim3{gridSizeX, gridSizeY, gridSizeZ},
        mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero,
        op.getArgs());
    if (clusterDimX && clusterDimY && clusterDimZ) {
      gpuLaunchOp.getClusterSizeXMutable().assign(clusterDimX);
      gpuLaunchOp.getClusterSizeYMutable().assign(clusterDimY);
      gpuLaunchOp.getClusterSizeZMutable().assign(clusterDimZ);
    }
    rewriter.replaceOp(op, gpuLaunchOp);
    return mlir::success();
  }

private:
  const mlir::SymbolTable &symTab;
};

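/// Convert cuf.sync_descriptor into a CUFSyncGlobalDescriptor runtime call
/// taking the host address of the named global descriptor.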
struct CUFSyncDescriptorOpConversion
    : public mlir::OpRewritePattern<cuf::SyncDescriptorOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cuf::SyncDescriptorOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mod = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, mod);
    mlir::Location loc = op.getLoc();

    auto globalOp = mod.lookupSymbol<fir::GlobalOp>(op.getGlobalName());
    if (!globalOp)
      return mlir::failure();

    auto hostAddr = builder.create<fir::AddrOfOp>(
        loc, fir::ReferenceType::get(globalOp.getType()), op.getGlobalName());
    mlir::func::FuncOp callee =
        fir::runtime::getRuntimeFunc<mkRTKey(CUFSyncGlobalDescriptor)>(loc,
                                                                       builder);
    auto fTy = callee.getFunctionType();
    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
    mlir::Value sourceLine =
        fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
    llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
        builder, loc, fTy, hostAddr, sourceFile, sourceLine)};
    builder.create<fir::CallOp>(loc, callee, args);
    rewriter.eraseOp(op);
    return mlir::success();
  }
};

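/// Pass driver: first convert CUF operations to FIR operations and runtime
/// calls, then rewrite fir.declare of registered device globals to use device
/// addresses.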
class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
public:
  void runOnOperation() override {
    auto *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);
    mlir::ConversionTarget target(*ctx);

    mlir::Operation *op = getOperation();
    mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
    if (!module)
      return signalPassFailure();
    mlir::SymbolTable symtab(module);

    std::optional<mlir::DataLayout> dl =
        fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
    fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
                                         /*forceUnifiedTBAATree=*/false, *dl);
    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
                           mlir::gpu::GPUDialect>();
    cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
                                            patterns);
    if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                  std::move(patterns)))) {
      mlir::emitError(mlir::UnknownLoc::get(ctx),
                      "error in CUF op conversion\n");
      signalPassFailure();
    }

    target.addDynamicallyLegalOp<fir::DeclareOp>([&](fir::DeclareOp op) {
      if (inDeviceContext(op))
        return true;
      if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
        if (auto global = symtab.lookup<fir::GlobalOp>(
                addrOfOp.getSymbol().getRootReference().getValue())) {
          if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(global.getType())))
            return true;
          if (cuf::isRegisteredDeviceGlobal(global))
            return false;
        }
      }
      return true;
    });

    cuf::populateFIRCUFConversionPatterns(symtab, patterns);
    if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                  std::move(patterns)))) {
      mlir::emitError(mlir::UnknownLoc::get(ctx),
                      "error in CUF op conversion\n");
      signalPassFailure();
    }
  }
};
} // namespace

void cuf::populateCUFToFIRConversionPatterns(
    const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
    const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
  patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
  patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion,
                  CUFFreeOpConversion, CUFSyncDescriptorOpConversion>(
      patterns.getContext());
  patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
                                               &dl, &converter);
  patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
}

void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
                                           mlir::RewritePatternSet &patterns) {
  patterns.insert<DeclareOpConversion>(patterns.getContext(), symtab);
}