DAG: Fix assuming f16 is the only 16-bit fp type in concat vector combine (#121637)
[llvm-project.git] / flang / lib / Optimizer / Transforms / CUFOpConversion.cpp
blob8c525fc6daff5e31552662aea0c1c957e4c96645
//===-- CUFOpConversion.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/Transforms/CUFOpConversion.h"
10 #include "flang/Common/Fortran.h"
11 #include "flang/Optimizer/Builder/CUFCommon.h"
12 #include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
13 #include "flang/Optimizer/CodeGen/TypeConverter.h"
14 #include "flang/Optimizer/Dialect/CUF/CUFOps.h"
15 #include "flang/Optimizer/Dialect/FIRDialect.h"
16 #include "flang/Optimizer/Dialect/FIROps.h"
17 #include "flang/Optimizer/HLFIR/HLFIROps.h"
18 #include "flang/Optimizer/Support/DataLayout.h"
19 #include "flang/Runtime/CUDA/allocatable.h"
20 #include "flang/Runtime/CUDA/common.h"
21 #include "flang/Runtime/CUDA/descriptor.h"
22 #include "flang/Runtime/CUDA/memory.h"
23 #include "flang/Runtime/allocatable.h"
24 #include "mlir/Conversion/LLVMCommon/Pattern.h"
25 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Transforms/DialectConversion.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 namespace fir {
32 #define GEN_PASS_DEF_CUFOPCONVERSION
33 #include "flang/Optimizer/Transforms/Passes.h.inc"
34 } // namespace fir
36 using namespace fir;
37 using namespace mlir;
38 using namespace Fortran::runtime;
39 using namespace Fortran::runtime::cuda;
41 namespace {
43 static inline unsigned getMemType(cuf::DataAttribute attr) {
44 if (attr == cuf::DataAttribute::Device)
45 return kMemTypeDevice;
46 if (attr == cuf::DataAttribute::Managed)
47 return kMemTypeManaged;
48 if (attr == cuf::DataAttribute::Unified)
49 return kMemTypeUnified;
50 if (attr == cuf::DataAttribute::Pinned)
51 return kMemTypePinned;
52 llvm::report_fatal_error("unsupported memory type");
55 template <typename OpTy>
56 static bool isPinned(OpTy op) {
57 if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned)
58 return true;
59 return false;
62 template <typename OpTy>
63 static bool hasDoubleDescriptors(OpTy op) {
64 if (auto declareOp =
65 mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp())) {
66 if (mlir::isa_and_nonnull<fir::AddrOfOp>(
67 declareOp.getMemref().getDefiningOp())) {
68 if (isPinned(declareOp))
69 return false;
70 return true;
72 } else if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
73 op.getBox().getDefiningOp())) {
74 if (mlir::isa_and_nonnull<fir::AddrOfOp>(
75 declareOp.getMemref().getDefiningOp())) {
76 if (isPinned(declareOp))
77 return false;
78 return true;
81 return false;
84 static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
85 mlir::Location loc, mlir::Type toTy,
86 mlir::Value val) {
87 if (val.getType() != toTy)
88 return rewriter.create<fir::ConvertOp>(loc, toTy, val);
89 return val;
92 template <typename OpTy>
93 static mlir::LogicalResult convertOpToCall(OpTy op,
94 mlir::PatternRewriter &rewriter,
95 mlir::func::FuncOp func) {
96 auto mod = op->template getParentOfType<mlir::ModuleOp>();
97 fir::FirOpBuilder builder(rewriter, mod);
98 mlir::Location loc = op.getLoc();
99 auto fTy = func.getFunctionType();
101 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
102 mlir::Value sourceLine;
103 if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>)
104 sourceLine = fir::factory::locationToLineNo(
105 builder, loc, op.getSource() ? fTy.getInput(6) : fTy.getInput(5));
106 else
107 sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
109 mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true)
110 : builder.createBool(loc, false);
112 mlir::Value errmsg;
113 if (op.getErrmsg()) {
114 errmsg = op.getErrmsg();
115 } else {
116 mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType());
117 errmsg = builder.create<fir::AbsentOp>(loc, boxNoneTy).getResult();
119 llvm::SmallVector<mlir::Value> args;
120 if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>) {
121 if (op.getSource()) {
122 mlir::Value stream =
123 op.getStream()
124 ? op.getStream()
125 : builder.createIntegerConstant(loc, fTy.getInput(2), -1);
126 args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(),
127 op.getSource(), stream, hasStat,
128 errmsg, sourceFile, sourceLine);
129 } else {
130 mlir::Value stream =
131 op.getStream()
132 ? op.getStream()
133 : builder.createIntegerConstant(loc, fTy.getInput(1), -1);
134 args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(),
135 stream, hasStat, errmsg, sourceFile,
136 sourceLine);
138 } else {
139 args =
140 fir::runtime::createArguments(builder, loc, fTy, op.getBox(), hasStat,
141 errmsg, sourceFile, sourceLine);
143 auto callOp = builder.create<fir::CallOp>(loc, func, args);
144 rewriter.replaceOp(op, callOp);
145 return mlir::success();
148 struct CUFAllocateOpConversion
149 : public mlir::OpRewritePattern<cuf::AllocateOp> {
150 using OpRewritePattern::OpRewritePattern;
152 mlir::LogicalResult
153 matchAndRewrite(cuf::AllocateOp op,
154 mlir::PatternRewriter &rewriter) const override {
155 // TODO: Pinned is a reference to a logical value that can be set to true
156 // when pinned allocation succeed. This will require a new entry point.
157 if (op.getPinned())
158 return mlir::failure();
160 auto mod = op->getParentOfType<mlir::ModuleOp>();
161 fir::FirOpBuilder builder(rewriter, mod);
162 mlir::Location loc = op.getLoc();
164 if (hasDoubleDescriptors(op)) {
165 // Allocation for module variable are done with custom runtime entry point
166 // so the descriptors can be synchronized.
167 mlir::func::FuncOp func;
168 if (op.getSource())
169 func = fir::runtime::getRuntimeFunc<mkRTKey(
170 CUFAllocatableAllocateSourceSync)>(loc, builder);
171 else
172 func =
173 fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocateSync)>(
174 loc, builder);
175 return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
178 mlir::func::FuncOp func;
179 if (op.getSource())
180 func =
181 fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocateSource)>(
182 loc, builder);
183 else
184 func = fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>(
185 loc, builder);
187 return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
191 struct CUFDeallocateOpConversion
192 : public mlir::OpRewritePattern<cuf::DeallocateOp> {
193 using OpRewritePattern::OpRewritePattern;
195 mlir::LogicalResult
196 matchAndRewrite(cuf::DeallocateOp op,
197 mlir::PatternRewriter &rewriter) const override {
199 auto mod = op->getParentOfType<mlir::ModuleOp>();
200 fir::FirOpBuilder builder(rewriter, mod);
201 mlir::Location loc = op.getLoc();
203 if (hasDoubleDescriptors(op)) {
204 // Deallocation for module variable are done with custom runtime entry
205 // point so the descriptors can be synchronized.
206 mlir::func::FuncOp func =
207 fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableDeallocate)>(
208 loc, builder);
209 return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
212 // Deallocation for local descriptor falls back on the standard runtime
213 // AllocatableDeallocate as the dedicated deallocator is set in the
214 // descriptor before the call.
215 mlir::func::FuncOp func =
216 fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc,
217 builder);
218 return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
222 static bool inDeviceContext(mlir::Operation *op) {
223 if (op->getParentOfType<cuf::KernelOp>())
224 return true;
225 if (auto funcOp = op->getParentOfType<mlir::gpu::GPUFuncOp>())
226 return true;
227 if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) {
228 if (auto cudaProcAttr =
229 funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
230 cuf::getProcAttrName())) {
231 return cudaProcAttr.getValue() != cuf::ProcAttribute::Host &&
232 cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice;
235 return false;
238 static int computeWidth(mlir::Location loc, mlir::Type type,
239 fir::KindMapping &kindMap) {
240 auto eleTy = fir::unwrapSequenceType(type);
241 if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
242 return t.getWidth() / 8;
243 if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
244 return t.getWidth() / 8;
245 if (eleTy.isInteger(1))
246 return 1;
247 if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
248 return kindMap.getLogicalBitsize(t.getFKind()) / 8;
249 if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
250 int elemSize =
251 mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
252 return 2 * elemSize;
254 if (auto t{mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)})
255 return kindMap.getCharacterBitsize(t.getFKind()) / 8;
256 mlir::emitError(loc, "unsupported type");
257 return 0;
260 struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
261 using OpRewritePattern::OpRewritePattern;
263 CUFAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl,
264 const fir::LLVMTypeConverter *typeConverter)
265 : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}
267 mlir::LogicalResult
268 matchAndRewrite(cuf::AllocOp op,
269 mlir::PatternRewriter &rewriter) const override {
271 if (inDeviceContext(op.getOperation())) {
272 // In device context just replace the cuf.alloc operation with a fir.alloc
273 // the cuf.free will be removed.
274 rewriter.replaceOpWithNewOp<fir::AllocaOp>(
275 op, op.getInType(), op.getUniqName() ? *op.getUniqName() : "",
276 op.getBindcName() ? *op.getBindcName() : "", op.getTypeparams(),
277 op.getShape());
278 return mlir::success();
281 auto mod = op->getParentOfType<mlir::ModuleOp>();
282 fir::FirOpBuilder builder(rewriter, mod);
283 mlir::Location loc = op.getLoc();
284 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
286 if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
287 // Convert scalar and known size array allocations.
288 mlir::Value bytes;
289 fir::KindMapping kindMap{fir::getKindMapping(mod)};
290 if (fir::isa_trivial(op.getInType())) {
291 int width = computeWidth(loc, op.getInType(), kindMap);
292 bytes =
293 builder.createIntegerConstant(loc, builder.getIndexType(), width);
294 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
295 op.getInType())) {
296 std::size_t size = 0;
297 if (fir::isa_derived(seqTy.getEleTy())) {
298 mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
299 size = dl->getTypeSizeInBits(structTy) / 8;
300 } else {
301 size = computeWidth(loc, seqTy.getEleTy(), kindMap);
303 mlir::Value width =
304 builder.createIntegerConstant(loc, builder.getIndexType(), size);
305 mlir::Value nbElem;
306 if (fir::sequenceWithNonConstantShape(seqTy)) {
307 assert(!op.getShape().empty() && "expect shape with dynamic arrays");
308 nbElem = builder.loadIfRef(loc, op.getShape()[0]);
309 for (unsigned i = 1; i < op.getShape().size(); ++i) {
310 nbElem = rewriter.create<mlir::arith::MulIOp>(
311 loc, nbElem, builder.loadIfRef(loc, op.getShape()[i]));
313 } else {
314 nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
315 seqTy.getConstantArraySize());
317 bytes = rewriter.create<mlir::arith::MulIOp>(loc, nbElem, width);
318 } else if (fir::isa_derived(op.getInType())) {
319 mlir::Type structTy = typeConverter->convertType(op.getInType());
320 std::size_t structSize = dl->getTypeSizeInBits(structTy) / 8;
321 bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
322 structSize);
323 } else {
324 mlir::emitError(loc, "unsupported type in cuf.alloc\n");
326 mlir::func::FuncOp func =
327 fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
328 auto fTy = func.getFunctionType();
329 mlir::Value sourceLine =
330 fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
331 mlir::Value memTy = builder.createIntegerConstant(
332 loc, builder.getI32Type(), getMemType(op.getDataAttr()));
333 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
334 builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
335 auto callOp = builder.create<fir::CallOp>(loc, func, args);
336 auto convOp = builder.createConvert(loc, op.getResult().getType(),
337 callOp.getResult(0));
338 rewriter.replaceOp(op, convOp);
339 return mlir::success();
342 // Convert descriptor allocations to function call.
343 auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
344 mlir::func::FuncOp func =
345 fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDescriptor)>(loc, builder);
346 auto fTy = func.getFunctionType();
347 mlir::Value sourceLine =
348 fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
350 mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy);
351 std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
352 mlir::Value sizeInBytes =
353 builder.createIntegerConstant(loc, builder.getIndexType(), boxSize);
355 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
356 builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)};
357 auto callOp = builder.create<fir::CallOp>(loc, func, args);
358 auto convOp = builder.createConvert(loc, op.getResult().getType(),
359 callOp.getResult(0));
360 rewriter.replaceOp(op, convOp);
361 return mlir::success();
364 private:
365 mlir::DataLayout *dl;
366 const fir::LLVMTypeConverter *typeConverter;
369 struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
370 using OpRewritePattern::OpRewritePattern;
372 DeclareOpConversion(mlir::MLIRContext *context,
373 const mlir::SymbolTable &symtab)
374 : OpRewritePattern(context), symTab{symtab} {}
376 mlir::LogicalResult
377 matchAndRewrite(fir::DeclareOp op,
378 mlir::PatternRewriter &rewriter) const override {
379 if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
380 if (auto global = symTab.lookup<fir::GlobalOp>(
381 addrOfOp.getSymbol().getRootReference().getValue())) {
382 if (cuf::isRegisteredDeviceGlobal(global)) {
383 rewriter.setInsertionPointAfter(addrOfOp);
384 auto mod = op->getParentOfType<mlir::ModuleOp>();
385 fir::FirOpBuilder builder(rewriter, mod);
386 mlir::Location loc = op.getLoc();
387 mlir::func::FuncOp callee =
388 fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(
389 loc, builder);
390 auto fTy = callee.getFunctionType();
391 mlir::Type toTy = fTy.getInput(0);
392 mlir::Value inputArg =
393 createConvertOp(rewriter, loc, toTy, addrOfOp.getResult());
394 mlir::Value sourceFile =
395 fir::factory::locationToFilename(builder, loc);
396 mlir::Value sourceLine =
397 fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
398 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
399 builder, loc, fTy, inputArg, sourceFile, sourceLine)};
400 auto call = rewriter.create<fir::CallOp>(loc, callee, args);
401 mlir::Value cast = createConvertOp(
402 rewriter, loc, op.getMemref().getType(), call->getResult(0));
403 rewriter.startOpModification(op);
404 op.getMemrefMutable().assign(cast);
405 rewriter.finalizeOpModification(op);
406 return success();
410 return failure();
413 private:
414 const mlir::SymbolTable &symTab;
417 struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
418 using OpRewritePattern::OpRewritePattern;
420 mlir::LogicalResult
421 matchAndRewrite(cuf::FreeOp op,
422 mlir::PatternRewriter &rewriter) const override {
423 if (inDeviceContext(op.getOperation())) {
424 rewriter.eraseOp(op);
425 return mlir::success();
428 if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
429 return failure();
431 auto mod = op->getParentOfType<mlir::ModuleOp>();
432 fir::FirOpBuilder builder(rewriter, mod);
433 mlir::Location loc = op.getLoc();
434 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
436 auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
437 if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
438 mlir::func::FuncOp func =
439 fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
440 auto fTy = func.getFunctionType();
441 mlir::Value sourceLine =
442 fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
443 mlir::Value memTy = builder.createIntegerConstant(
444 loc, builder.getI32Type(), getMemType(op.getDataAttr()));
445 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
446 builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
447 builder.create<fir::CallOp>(loc, func, args);
448 rewriter.eraseOp(op);
449 return mlir::success();
452 // Convert cuf.free on descriptors.
453 mlir::func::FuncOp func =
454 fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDescriptor)>(loc, builder);
455 auto fTy = func.getFunctionType();
456 mlir::Value sourceLine =
457 fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
458 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
459 builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)};
460 builder.create<fir::CallOp>(loc, func, args);
461 rewriter.eraseOp(op);
462 return mlir::success();
466 static bool isDstGlobal(cuf::DataTransferOp op) {
467 if (auto declareOp = op.getDst().getDefiningOp<fir::DeclareOp>())
468 if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
469 return true;
470 if (auto declareOp = op.getDst().getDefiningOp<hlfir::DeclareOp>())
471 if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
472 return true;
473 return false;
476 static mlir::Value getShapeFromDecl(mlir::Value src) {
477 if (auto declareOp = src.getDefiningOp<fir::DeclareOp>())
478 return declareOp.getShape();
479 if (auto declareOp = src.getDefiningOp<hlfir::DeclareOp>())
480 return declareOp.getShape();
481 return mlir::Value{};
484 static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
485 cuf::DataTransferOp op,
486 const mlir::SymbolTable &symtab) {
487 auto mod = op->getParentOfType<mlir::ModuleOp>();
488 mlir::Location loc = op.getLoc();
489 fir::FirOpBuilder builder(rewriter, mod);
490 mlir::Value addr;
491 mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
492 if (fir::isa_trivial(srcTy) &&
493 mlir::matchPattern(op.getSrc().getDefiningOp(), mlir::m_Constant())) {
494 mlir::Value src = op.getSrc();
495 if (srcTy.isInteger(1)) {
496 // i1 is not a supported type in the descriptor and it is actually coming
497 // from a LOGICAL constant. Store it as a fir.logical.
498 srcTy = fir::LogicalType::get(rewriter.getContext(), 4);
499 src = createConvertOp(rewriter, loc, srcTy, src);
501 // Put constant in memory if it is not.
502 mlir::Value alloc = builder.createTemporary(loc, srcTy);
503 builder.create<fir::StoreOp>(loc, src, alloc);
504 addr = alloc;
505 } else {
506 addr = op.getSrc();
508 llvm::SmallVector<mlir::Value> lenParams;
509 mlir::Type boxTy = fir::BoxType::get(srcTy);
510 mlir::Value box =
511 builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getSrc()),
512 /*slice=*/nullptr, lenParams,
513 /*tdesc=*/nullptr);
514 mlir::Value src = builder.createTemporary(loc, box.getType());
515 builder.create<fir::StoreOp>(loc, box, src);
516 return src;
519 static mlir::Value emboxDst(mlir::PatternRewriter &rewriter,
520 cuf::DataTransferOp op,
521 const mlir::SymbolTable &symtab) {
522 auto mod = op->getParentOfType<mlir::ModuleOp>();
523 mlir::Location loc = op.getLoc();
524 fir::FirOpBuilder builder(rewriter, mod);
525 mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
526 mlir::Value dstAddr = op.getDst();
527 mlir::Type dstBoxTy = fir::BoxType::get(dstTy);
528 llvm::SmallVector<mlir::Value> lenParams;
529 mlir::Value dstBox =
530 builder.createBox(loc, dstBoxTy, dstAddr, getShapeFromDecl(op.getDst()),
531 /*slice=*/nullptr, lenParams,
532 /*tdesc=*/nullptr);
533 mlir::Value dst = builder.createTemporary(loc, dstBox.getType());
534 builder.create<fir::StoreOp>(loc, dstBox, dst);
535 return dst;
538 struct CUFDataTransferOpConversion
539 : public mlir::OpRewritePattern<cuf::DataTransferOp> {
540 using OpRewritePattern::OpRewritePattern;
542 CUFDataTransferOpConversion(mlir::MLIRContext *context,
543 const mlir::SymbolTable &symtab,
544 mlir::DataLayout *dl,
545 const fir::LLVMTypeConverter *typeConverter)
546 : OpRewritePattern(context), symtab{symtab}, dl{dl},
547 typeConverter{typeConverter} {}
549 mlir::LogicalResult
550 matchAndRewrite(cuf::DataTransferOp op,
551 mlir::PatternRewriter &rewriter) const override {
553 mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
554 mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
556 mlir::Location loc = op.getLoc();
557 unsigned mode = 0;
558 if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) {
559 mode = kHostToDevice;
560 } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceHost) {
561 mode = kDeviceToHost;
562 } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceDevice) {
563 mode = kDeviceToDevice;
564 } else {
565 mlir::emitError(loc, "unsupported transfer kind\n");
568 auto mod = op->getParentOfType<mlir::ModuleOp>();
569 fir::FirOpBuilder builder(rewriter, mod);
570 fir::KindMapping kindMap{fir::getKindMapping(mod)};
571 mlir::Value modeValue =
572 builder.createIntegerConstant(loc, builder.getI32Type(), mode);
574 // Convert data transfer without any descriptor.
575 if (!mlir::isa<fir::BaseBoxType>(srcTy) &&
576 !mlir::isa<fir::BaseBoxType>(dstTy)) {
578 if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) {
579 // Initialization of an array from a scalar value should be implemented
580 // via a kernel launch. Use the flan runtime via the Assign function
581 // until we have more infrastructure.
582 mlir::Value src = emboxSrc(rewriter, op, symtab);
583 mlir::Value dst = emboxDst(rewriter, op, symtab);
584 mlir::func::FuncOp func =
585 fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
586 loc, builder);
587 auto fTy = func.getFunctionType();
588 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
589 mlir::Value sourceLine =
590 fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
591 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
592 builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
593 builder.create<fir::CallOp>(loc, func, args);
594 rewriter.eraseOp(op);
595 return mlir::success();
598 mlir::Type i64Ty = builder.getI64Type();
599 mlir::Value nbElement;
600 if (op.getShape()) {
601 llvm::SmallVector<mlir::Value> extents;
602 if (auto shapeOp =
603 mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp())) {
604 extents = shapeOp.getExtents();
605 } else if (auto shapeShiftOp = mlir::dyn_cast<fir::ShapeShiftOp>(
606 op.getShape().getDefiningOp())) {
607 for (auto i : llvm::enumerate(shapeShiftOp.getPairs()))
608 if (i.index() & 1)
609 extents.push_back(i.value());
612 nbElement = rewriter.create<fir::ConvertOp>(loc, i64Ty, extents[0]);
613 for (unsigned i = 1; i < extents.size(); ++i) {
614 auto operand =
615 rewriter.create<fir::ConvertOp>(loc, i64Ty, extents[i]);
616 nbElement =
617 rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand);
619 } else {
620 if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(dstTy))
621 nbElement = builder.createIntegerConstant(
622 loc, i64Ty, seqTy.getConstantArraySize());
624 unsigned width = 0;
625 if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) {
626 mlir::Type structTy =
627 typeConverter->convertType(fir::unwrapSequenceType(dstTy));
628 width = dl->getTypeSizeInBits(structTy) / 8;
629 } else {
630 width = computeWidth(loc, dstTy, kindMap);
632 mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>(
633 loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));
634 mlir::Value bytes =
635 nbElement
636 ? rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue)
637 : widthValue;
639 mlir::func::FuncOp func =
640 fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrPtr)>(loc,
641 builder);
642 auto fTy = func.getFunctionType();
643 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
644 mlir::Value sourceLine =
645 fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
647 mlir::Value dst = op.getDst();
648 mlir::Value src = op.getSrc();
649 // Materialize the src if constant.
650 if (matchPattern(src.getDefiningOp(), mlir::m_Constant())) {
651 mlir::Value temp = builder.createTemporary(loc, srcTy);
652 builder.create<fir::StoreOp>(loc, src, temp);
653 src = temp;
655 llvm::SmallVector<mlir::Value> args{
656 fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
657 modeValue, sourceFile, sourceLine)};
658 builder.create<fir::CallOp>(loc, func, args);
659 rewriter.eraseOp(op);
660 return mlir::success();
663 auto materializeBoxIfNeeded = [&](mlir::Value val) -> mlir::Value {
664 if (mlir::isa<fir::EmboxOp, fir::ReboxOp>(val.getDefiningOp())) {
665 // Materialize the box to memory to be able to call the runtime.
666 mlir::Value box = builder.createTemporary(loc, val.getType());
667 builder.create<fir::StoreOp>(loc, val, box);
668 return box;
670 return val;
673 // Conversion of data transfer involving at least one descriptor.
674 if (mlir::isa<fir::BaseBoxType>(dstTy)) {
675 // Transfer to a descriptor.
676 mlir::func::FuncOp func =
677 isDstGlobal(op)
678 ? fir::runtime::getRuntimeFunc<mkRTKey(
679 CUFDataTransferGlobalDescDesc)>(loc, builder)
680 : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
681 loc, builder);
682 mlir::Value dst = op.getDst();
683 mlir::Value src = op.getSrc();
684 if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
685 src = emboxSrc(rewriter, op, symtab);
686 if (fir::isa_trivial(srcTy))
687 func = fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
688 loc, builder);
691 src = materializeBoxIfNeeded(src);
692 dst = materializeBoxIfNeeded(dst);
694 auto fTy = func.getFunctionType();
695 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
696 mlir::Value sourceLine =
697 fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
698 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
699 builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
700 builder.create<fir::CallOp>(loc, func, args);
701 rewriter.eraseOp(op);
702 } else {
703 // Transfer from a descriptor.
704 mlir::Value dst = emboxDst(rewriter, op, symtab);
705 mlir::Value src = materializeBoxIfNeeded(op.getSrc());
707 mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
708 CUFDataTransferDescDescNoRealloc)>(loc, builder);
710 auto fTy = func.getFunctionType();
711 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
712 mlir::Value sourceLine =
713 fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
714 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
715 builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
716 builder.create<fir::CallOp>(loc, func, args);
717 rewriter.eraseOp(op);
719 return mlir::success();
722 private:
723 const mlir::SymbolTable &symtab;
724 mlir::DataLayout *dl;
725 const fir::LLVMTypeConverter *typeConverter;
728 struct CUFLaunchOpConversion
729 : public mlir::OpRewritePattern<cuf::KernelLaunchOp> {
730 public:
731 using OpRewritePattern::OpRewritePattern;
733 CUFLaunchOpConversion(mlir::MLIRContext *context,
734 const mlir::SymbolTable &symTab)
735 : OpRewritePattern(context), symTab{symTab} {}
737 mlir::LogicalResult
738 matchAndRewrite(cuf::KernelLaunchOp op,
739 mlir::PatternRewriter &rewriter) const override {
740 mlir::Location loc = op.getLoc();
741 auto idxTy = mlir::IndexType::get(op.getContext());
742 auto zero = rewriter.create<mlir::arith::ConstantOp>(
743 loc, rewriter.getIntegerType(32), rewriter.getI32IntegerAttr(0));
744 auto gridSizeX =
745 rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridX());
746 auto gridSizeY =
747 rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridY());
748 auto gridSizeZ =
749 rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridZ());
750 auto blockSizeX =
751 rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockX());
752 auto blockSizeY =
753 rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockY());
754 auto blockSizeZ =
755 rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockZ());
756 auto kernelName = mlir::SymbolRefAttr::get(
757 rewriter.getStringAttr(cudaDeviceModuleName),
758 {mlir::SymbolRefAttr::get(
759 rewriter.getContext(),
760 op.getCallee().getLeafReference().getValue())});
761 mlir::Value clusterDimX, clusterDimY, clusterDimZ;
762 if (auto funcOp = symTab.lookup<mlir::func::FuncOp>(
763 op.getCallee().getLeafReference())) {
764 if (auto clusterDimsAttr = funcOp->getAttrOfType<cuf::ClusterDimsAttr>(
765 cuf::getClusterDimsAttrName())) {
766 clusterDimX = rewriter.create<mlir::arith::ConstantIndexOp>(
767 loc, clusterDimsAttr.getX().getInt());
768 clusterDimY = rewriter.create<mlir::arith::ConstantIndexOp>(
769 loc, clusterDimsAttr.getY().getInt());
770 clusterDimZ = rewriter.create<mlir::arith::ConstantIndexOp>(
771 loc, clusterDimsAttr.getZ().getInt());
774 auto gpuLaunchOp = rewriter.create<mlir::gpu::LaunchFuncOp>(
775 loc, kernelName, mlir::gpu::KernelDim3{gridSizeX, gridSizeY, gridSizeZ},
776 mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero,
777 op.getArgs());
778 if (clusterDimX && clusterDimY && clusterDimZ) {
779 gpuLaunchOp.getClusterSizeXMutable().assign(clusterDimX);
780 gpuLaunchOp.getClusterSizeYMutable().assign(clusterDimY);
781 gpuLaunchOp.getClusterSizeZMutable().assign(clusterDimZ);
783 rewriter.replaceOp(op, gpuLaunchOp);
784 return mlir::success();
787 private:
788 const mlir::SymbolTable &symTab;
791 struct CUFSyncDescriptorOpConversion
792 : public mlir::OpRewritePattern<cuf::SyncDescriptorOp> {
793 using OpRewritePattern::OpRewritePattern;
795 mlir::LogicalResult
796 matchAndRewrite(cuf::SyncDescriptorOp op,
797 mlir::PatternRewriter &rewriter) const override {
798 auto mod = op->getParentOfType<mlir::ModuleOp>();
799 fir::FirOpBuilder builder(rewriter, mod);
800 mlir::Location loc = op.getLoc();
802 auto globalOp = mod.lookupSymbol<fir::GlobalOp>(op.getGlobalName());
803 if (!globalOp)
804 return mlir::failure();
806 auto hostAddr = builder.create<fir::AddrOfOp>(
807 loc, fir::ReferenceType::get(globalOp.getType()), op.getGlobalName());
808 mlir::func::FuncOp callee =
809 fir::runtime::getRuntimeFunc<mkRTKey(CUFSyncGlobalDescriptor)>(loc,
810 builder);
811 auto fTy = callee.getFunctionType();
812 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
813 mlir::Value sourceLine =
814 fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
815 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
816 builder, loc, fTy, hostAddr, sourceFile, sourceLine)};
817 builder.create<fir::CallOp>(loc, callee, args);
818 op.erase();
819 return mlir::success();
823 class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
824 public:
825 void runOnOperation() override {
826 auto *ctx = &getContext();
827 mlir::RewritePatternSet patterns(ctx);
828 mlir::ConversionTarget target(*ctx);
830 mlir::Operation *op = getOperation();
831 mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
832 if (!module)
833 return signalPassFailure();
834 mlir::SymbolTable symtab(module);
836 std::optional<mlir::DataLayout> dl =
837 fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
838 fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
839 /*forceUnifiedTBAATree=*/false, *dl);
840 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
841 mlir::gpu::GPUDialect>();
842 cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
843 patterns);
844 if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
845 std::move(patterns)))) {
846 mlir::emitError(mlir::UnknownLoc::get(ctx),
847 "error in CUF op conversion\n");
848 signalPassFailure();
851 target.addDynamicallyLegalOp<fir::DeclareOp>([&](fir::DeclareOp op) {
852 if (inDeviceContext(op))
853 return true;
854 if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
855 if (auto global = symtab.lookup<fir::GlobalOp>(
856 addrOfOp.getSymbol().getRootReference().getValue())) {
857 if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(global.getType())))
858 return true;
859 if (cuf::isRegisteredDeviceGlobal(global))
860 return false;
863 return true;
866 patterns.clear();
867 cuf::populateFIRCUFConversionPatterns(symtab, patterns);
868 if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
869 std::move(patterns)))) {
870 mlir::emitError(mlir::UnknownLoc::get(ctx),
871 "error in CUF op conversion\n");
872 signalPassFailure();
876 } // namespace
878 void cuf::populateCUFToFIRConversionPatterns(
879 const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
880 const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
881 patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
882 patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion,
883 CUFFreeOpConversion, CUFSyncDescriptorOpConversion>(
884 patterns.getContext());
885 patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
886 &dl, &converter);
887 patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
890 void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
891 mlir::RewritePatternSet &patterns) {
892 patterns.insert<DeclareOpConversion>(patterns.getContext(), symtab);