[SampleProfileLoader] Fix integer overflow in generateMDProfMetadata (#90217)
[llvm-project.git] / mlir / lib / Target / LLVMIR / Dialect / OpenACC / OpenACCToLLVMIRTranslation.cpp
blobb964d1c082b200b249b2b33e165297c7234f1124
1 //===- OpenACCToLLVMIRTranslation.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenACC dialect and LLVM
10 // IR.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/OpenACC/OpenACC.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/Support/LLVM.h"
20 #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
21 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
22 #include "mlir/Transforms/RegionUtils.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Frontend/OpenMP/OMPConstants.h"
26 #include "llvm/Support/FormatVariadic.h"
28 using namespace mlir;
30 using OpenACCIRBuilder = llvm::OpenMPIRBuilder;
32 //===----------------------------------------------------------------------===//
33 // Utility functions
34 //===----------------------------------------------------------------------===//
/// Map-type flag values passed to the offloading runtime. They are extracted
/// from openmp/libomptarget/include/omptarget.h and mapped to the
/// corresponding OpenACC data clauses.
static constexpr uint64_t kCreateFlag = 0x000;       // `alloc` mapping.
static constexpr uint64_t kDeviceCopyinFlag = 0x001; // `to` mapping.
static constexpr uint64_t kHostCopyoutFlag = 0x002;  // `from` mapping.
static constexpr uint64_t kDeleteFlag = 0x008;       // `delete` mapping.
static constexpr uint64_t kPresentFlag = 0x1000;     // `present` mapping.
// Runtime extension to implement the OpenACC second reference counter.
static constexpr uint64_t kHoldFlag = 0x2000;

/// Device id passed to every mapper call (the runtime's default device).
static constexpr int64_t kDefaultDevice = -1;
49 /// Create the location struct from the operation location information.
50 static llvm::Value *createSourceLocationInfo(OpenACCIRBuilder &builder,
51 Operation *op) {
52 auto loc = op->getLoc();
53 auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
54 StringRef funcName = funcOp ? funcOp.getName() : "unknown";
55 uint32_t strLen;
56 llvm::Constant *locStr = mlir::LLVM::createSourceLocStrFromLocation(
57 loc, builder, funcName, strLen);
58 return builder.getOrCreateIdent(locStr, strLen);
61 /// Return the runtime function used to lower the given operation.
62 static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder,
63 Operation *op) {
64 return llvm::TypeSwitch<Operation *, llvm::Function *>(op)
65 .Case([&](acc::EnterDataOp) {
66 return builder.getOrCreateRuntimeFunctionPtr(
67 llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
69 .Case([&](acc::ExitDataOp) {
70 return builder.getOrCreateRuntimeFunctionPtr(
71 llvm::omp::OMPRTL___tgt_target_data_end_mapper);
73 .Case([&](acc::UpdateOp) {
74 return builder.getOrCreateRuntimeFunctionPtr(
75 llvm::omp::OMPRTL___tgt_target_data_update_mapper);
76 });
77 llvm_unreachable("Unknown OpenACC operation");
80 /// Extract pointer, size and mapping information from operands
81 /// to populate the future functions arguments.
82 static LogicalResult
83 processOperands(llvm::IRBuilderBase &builder,
84 LLVM::ModuleTranslation &moduleTranslation, Operation *op,
85 ValueRange operands, unsigned totalNbOperand,
86 uint64_t operandFlag, SmallVector<uint64_t> &flags,
87 SmallVectorImpl<llvm::Constant *> &names, unsigned &index,
88 struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
89 OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
90 llvm::LLVMContext &ctx = builder.getContext();
91 auto *i8PtrTy = llvm::PointerType::getUnqual(ctx);
92 auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand);
93 auto *i64Ty = llvm::Type::getInt64Ty(ctx);
94 auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand);
96 for (Value data : operands) {
97 llvm::Value *dataValue = moduleTranslation.lookupValue(data);
99 llvm::Value *dataPtrBase;
100 llvm::Value *dataPtr;
101 llvm::Value *dataSize;
103 if (isa<LLVM::LLVMPointerType>(data.getType())) {
104 dataPtrBase = dataValue;
105 dataPtr = dataValue;
106 dataSize = accBuilder->getSizeInBytes(dataValue);
107 } else {
108 return op->emitOpError()
109 << "Data operand must be legalized before translation."
110 << "Unsupported type: " << data.getType();
113 // Store base pointer extracted from operand into the i-th position of
114 // argBase.
115 llvm::Value *ptrBaseGEP = builder.CreateInBoundsGEP(
116 arrI8PtrTy, mapperAllocas.ArgsBase,
117 {builder.getInt32(0), builder.getInt32(index)});
118 builder.CreateStore(dataPtrBase, ptrBaseGEP);
120 // Store pointer extracted from operand into the i-th position of args.
121 llvm::Value *ptrGEP = builder.CreateInBoundsGEP(
122 arrI8PtrTy, mapperAllocas.Args,
123 {builder.getInt32(0), builder.getInt32(index)});
124 builder.CreateStore(dataPtr, ptrGEP);
126 // Store size extracted from operand into the i-th position of argSizes.
127 llvm::Value *sizeGEP = builder.CreateInBoundsGEP(
128 arrI64Ty, mapperAllocas.ArgSizes,
129 {builder.getInt32(0), builder.getInt32(index)});
130 builder.CreateStore(dataSize, sizeGEP);
132 flags.push_back(operandFlag);
133 llvm::Constant *mapName =
134 mlir::LLVM::createMappingInformation(data.getLoc(), *accBuilder);
135 names.push_back(mapName);
136 ++index;
138 return success();
141 /// Process data operands from acc::EnterDataOp
142 static LogicalResult
143 processDataOperands(llvm::IRBuilderBase &builder,
144 LLVM::ModuleTranslation &moduleTranslation,
145 acc::EnterDataOp op, SmallVector<uint64_t> &flags,
146 SmallVectorImpl<llvm::Constant *> &names,
147 struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
148 // TODO add `create_zero` and `attach` operands
150 unsigned index = 0;
152 // Create operands are handled as `alloc` call.
153 // Copyin operands are handled as `to` call.
154 llvm::SmallVector<mlir::Value> create, copyin;
155 for (mlir::Value dataOp : op.getDataClauseOperands()) {
156 if (auto createOp =
157 mlir::dyn_cast_or_null<acc::CreateOp>(dataOp.getDefiningOp())) {
158 create.push_back(createOp.getVarPtr());
159 } else if (auto copyinOp = mlir::dyn_cast_or_null<acc::CopyinOp>(
160 dataOp.getDefiningOp())) {
161 copyin.push_back(copyinOp.getVarPtr());
165 auto nbTotalOperands = create.size() + copyin.size();
167 // Create operands are handled as `alloc` call.
168 if (failed(processOperands(builder, moduleTranslation, op, create,
169 nbTotalOperands, kCreateFlag, flags, names, index,
170 mapperAllocas)))
171 return failure();
173 // Copyin operands are handled as `to` call.
174 if (failed(processOperands(builder, moduleTranslation, op, copyin,
175 nbTotalOperands, kDeviceCopyinFlag, flags, names,
176 index, mapperAllocas)))
177 return failure();
179 return success();
182 /// Process data operands from acc::ExitDataOp
183 static LogicalResult
184 processDataOperands(llvm::IRBuilderBase &builder,
185 LLVM::ModuleTranslation &moduleTranslation,
186 acc::ExitDataOp op, SmallVector<uint64_t> &flags,
187 SmallVectorImpl<llvm::Constant *> &names,
188 struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
189 // TODO add `detach` operands
191 unsigned index = 0;
193 llvm::SmallVector<mlir::Value> deleteOperands, copyoutOperands;
194 for (mlir::Value dataOp : op.getDataClauseOperands()) {
195 if (auto devicePtrOp = mlir::dyn_cast_or_null<acc::GetDevicePtrOp>(
196 dataOp.getDefiningOp())) {
197 for (auto &u : devicePtrOp.getAccPtr().getUses()) {
198 if (mlir::dyn_cast_or_null<acc::DeleteOp>(u.getOwner()))
199 deleteOperands.push_back(devicePtrOp.getVarPtr());
200 else if (mlir::dyn_cast_or_null<acc::CopyoutOp>(u.getOwner()))
201 copyoutOperands.push_back(devicePtrOp.getVarPtr());
206 auto nbTotalOperands = deleteOperands.size() + copyoutOperands.size();
208 // Delete operands are handled as `delete` call.
209 if (failed(processOperands(builder, moduleTranslation, op, deleteOperands,
210 nbTotalOperands, kDeleteFlag, flags, names, index,
211 mapperAllocas)))
212 return failure();
214 // Copyout operands are handled as `from` call.
215 if (failed(processOperands(builder, moduleTranslation, op, copyoutOperands,
216 nbTotalOperands, kHostCopyoutFlag, flags, names,
217 index, mapperAllocas)))
218 return failure();
220 return success();
223 /// Process data operands from acc::UpdateOp
224 static LogicalResult
225 processDataOperands(llvm::IRBuilderBase &builder,
226 LLVM::ModuleTranslation &moduleTranslation,
227 acc::UpdateOp op, SmallVector<uint64_t> &flags,
228 SmallVectorImpl<llvm::Constant *> &names,
229 struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
230 unsigned index = 0;
232 // Host operands are handled as `from` call.
233 // Device operands are handled as `to` call.
234 llvm::SmallVector<mlir::Value> from, to;
235 for (mlir::Value dataOp : op.getDataClauseOperands()) {
236 if (auto getDevicePtrOp = mlir::dyn_cast_or_null<acc::GetDevicePtrOp>(
237 dataOp.getDefiningOp())) {
238 from.push_back(getDevicePtrOp.getVarPtr());
239 } else if (auto updateDeviceOp =
240 mlir::dyn_cast_or_null<acc::UpdateDeviceOp>(
241 dataOp.getDefiningOp())) {
242 to.push_back(updateDeviceOp.getVarPtr());
246 if (failed(processOperands(builder, moduleTranslation, op, from, from.size(),
247 kHostCopyoutFlag, flags, names, index,
248 mapperAllocas)))
249 return failure();
251 if (failed(processOperands(builder, moduleTranslation, op, to, to.size(),
252 kDeviceCopyinFlag, flags, names, index,
253 mapperAllocas)))
254 return failure();
255 return success();
258 //===----------------------------------------------------------------------===//
259 // Conversion functions
260 //===----------------------------------------------------------------------===//
262 /// Converts an OpenACC data operation into LLVM IR.
263 static LogicalResult convertDataOp(acc::DataOp &op,
264 llvm::IRBuilderBase &builder,
265 LLVM::ModuleTranslation &moduleTranslation) {
266 llvm::LLVMContext &ctx = builder.getContext();
267 auto enclosingFuncOp = op.getOperation()->getParentOfType<LLVM::LLVMFuncOp>();
268 llvm::Function *enclosingFunction =
269 moduleTranslation.lookupFunction(enclosingFuncOp.getName());
271 OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
273 llvm::Value *srcLocInfo = createSourceLocationInfo(*accBuilder, op);
275 llvm::Function *beginMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr(
276 llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
278 llvm::Function *endMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr(
279 llvm::omp::OMPRTL___tgt_target_data_end_mapper);
281 // Number of arguments in the data operation.
282 unsigned totalNbOperand = op.getNumDataOperands();
284 struct OpenACCIRBuilder::MapperAllocas mapperAllocas;
285 OpenACCIRBuilder::InsertPointTy allocaIP(
286 &enclosingFunction->getEntryBlock(),
287 enclosingFunction->getEntryBlock().getFirstInsertionPt());
288 accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand,
289 mapperAllocas);
291 SmallVector<uint64_t> flags;
292 SmallVector<llvm::Constant *> names;
293 unsigned index = 0;
295 // TODO handle no_create, deviceptr and attach operands.
297 llvm::SmallVector<mlir::Value> copyin, copyout, create, present,
298 deleteOperands;
299 for (mlir::Value dataOp : op.getDataClauseOperands()) {
300 if (auto devicePtrOp = mlir::dyn_cast_or_null<acc::GetDevicePtrOp>(
301 dataOp.getDefiningOp())) {
302 for (auto &u : devicePtrOp.getAccPtr().getUses()) {
303 if (mlir::dyn_cast_or_null<acc::DeleteOp>(u.getOwner())) {
304 deleteOperands.push_back(devicePtrOp.getVarPtr());
305 } else if (mlir::dyn_cast_or_null<acc::CopyoutOp>(u.getOwner())) {
306 // TODO copyout zero currenlty handled as copyout. Update when
307 // extension available.
308 copyout.push_back(devicePtrOp.getVarPtr());
311 } else if (auto copyinOp = mlir::dyn_cast_or_null<acc::CopyinOp>(
312 dataOp.getDefiningOp())) {
313 // TODO copyin readonly currenlty handled as copyin. Update when extension
314 // available.
315 copyin.push_back(copyinOp.getVarPtr());
316 } else if (auto createOp = mlir::dyn_cast_or_null<acc::CreateOp>(
317 dataOp.getDefiningOp())) {
318 // TODO create zero currenlty handled as create. Update when extension
319 // available.
320 create.push_back(createOp.getVarPtr());
321 } else if (auto presentOp = mlir::dyn_cast_or_null<acc::PresentOp>(
322 dataOp.getDefiningOp())) {
323 present.push_back(createOp.getVarPtr());
327 auto nbTotalOperands = copyin.size() + copyout.size() + create.size() +
328 present.size() + deleteOperands.size();
330 // Copyin operands are handled as `to` call.
331 if (failed(processOperands(builder, moduleTranslation, op, copyin,
332 nbTotalOperands, kDeviceCopyinFlag | kHoldFlag,
333 flags, names, index, mapperAllocas)))
334 return failure();
336 // Delete operands are handled as `delete` call.
337 if (failed(processOperands(builder, moduleTranslation, op, deleteOperands,
338 nbTotalOperands, kDeleteFlag, flags, names, index,
339 mapperAllocas)))
340 return failure();
342 // Copyout operands are handled as `from` call.
343 if (failed(processOperands(builder, moduleTranslation, op, copyout,
344 nbTotalOperands, kHostCopyoutFlag | kHoldFlag,
345 flags, names, index, mapperAllocas)))
346 return failure();
348 // Create operands are handled as `alloc` call.
349 if (failed(processOperands(builder, moduleTranslation, op, create,
350 nbTotalOperands, kCreateFlag | kHoldFlag, flags,
351 names, index, mapperAllocas)))
352 return failure();
354 if (failed(processOperands(builder, moduleTranslation, op, present,
355 nbTotalOperands, kPresentFlag | kHoldFlag, flags,
356 names, index, mapperAllocas)))
357 return failure();
359 llvm::GlobalVariable *maptypes =
360 accBuilder->createOffloadMaptypes(flags, ".offload_maptypes");
361 llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32(
362 llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand),
363 maptypes, /*Idx0=*/0, /*Idx1=*/0);
365 llvm::GlobalVariable *mapnames =
366 accBuilder->createOffloadMapnames(names, ".offload_mapnames");
367 llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32(
368 llvm::ArrayType::get(llvm::PointerType::getUnqual(ctx), totalNbOperand),
369 mapnames, /*Idx0=*/0, /*Idx1=*/0);
371 // Create call to start the data region.
372 accBuilder->emitMapperCall(builder.saveIP(), beginMapperFunc, srcLocInfo,
373 maptypesArg, mapnamesArg, mapperAllocas,
374 kDefaultDevice, totalNbOperand);
376 // Convert the region.
377 llvm::BasicBlock *entryBlock = nullptr;
379 for (Block &bb : op.getRegion()) {
380 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
381 ctx, "acc.data", builder.GetInsertBlock()->getParent());
382 if (entryBlock == nullptr)
383 entryBlock = llvmBB;
384 moduleTranslation.mapBlock(&bb, llvmBB);
387 auto afterDataRegion = builder.saveIP();
389 llvm::BranchInst *sourceTerminator = builder.CreateBr(entryBlock);
391 builder.restoreIP(afterDataRegion);
392 llvm::BasicBlock *endDataBlock = llvm::BasicBlock::Create(
393 ctx, "acc.end_data", builder.GetInsertBlock()->getParent());
395 SetVector<Block *> blocks = getTopologicallySortedBlocks(op.getRegion());
396 for (Block *bb : blocks) {
397 llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
398 if (bb->isEntryBlock()) {
399 assert(sourceTerminator->getNumSuccessors() == 1 &&
400 "provided entry block has multiple successors");
401 sourceTerminator->setSuccessor(0, llvmBB);
404 if (failed(
405 moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) {
406 return failure();
409 if (isa<acc::TerminatorOp, acc::YieldOp>(bb->getTerminator()))
410 builder.CreateBr(endDataBlock);
413 // Create call to end the data region.
414 builder.SetInsertPoint(endDataBlock);
415 accBuilder->emitMapperCall(builder.saveIP(), endMapperFunc, srcLocInfo,
416 maptypesArg, mapnamesArg, mapperAllocas,
417 kDefaultDevice, totalNbOperand);
419 return success();
422 /// Converts an OpenACC standalone data operation into LLVM IR.
423 template <typename OpTy>
424 static LogicalResult
425 convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder,
426 LLVM::ModuleTranslation &moduleTranslation) {
427 auto enclosingFuncOp =
428 op.getOperation()->template getParentOfType<LLVM::LLVMFuncOp>();
429 llvm::Function *enclosingFunction =
430 moduleTranslation.lookupFunction(enclosingFuncOp.getName());
432 OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
434 auto *srcLocInfo = createSourceLocationInfo(*accBuilder, op);
435 auto *mapperFunc = getAssociatedFunction(*accBuilder, op);
437 // Number of arguments in the enter_data operation.
438 unsigned totalNbOperand = op.getNumDataOperands();
440 llvm::LLVMContext &ctx = builder.getContext();
442 struct OpenACCIRBuilder::MapperAllocas mapperAllocas;
443 OpenACCIRBuilder::InsertPointTy allocaIP(
444 &enclosingFunction->getEntryBlock(),
445 enclosingFunction->getEntryBlock().getFirstInsertionPt());
446 accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand,
447 mapperAllocas);
449 SmallVector<uint64_t> flags;
450 SmallVector<llvm::Constant *> names;
452 if (failed(processDataOperands(builder, moduleTranslation, op, flags, names,
453 mapperAllocas)))
454 return failure();
456 llvm::GlobalVariable *maptypes =
457 accBuilder->createOffloadMaptypes(flags, ".offload_maptypes");
458 llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32(
459 llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand),
460 maptypes, /*Idx0=*/0, /*Idx1=*/0);
462 llvm::GlobalVariable *mapnames =
463 accBuilder->createOffloadMapnames(names, ".offload_mapnames");
464 llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32(
465 llvm::ArrayType::get(llvm::PointerType::getUnqual(ctx), totalNbOperand),
466 mapnames, /*Idx0=*/0, /*Idx1=*/0);
468 accBuilder->emitMapperCall(builder.saveIP(), mapperFunc, srcLocInfo,
469 maptypesArg, mapnamesArg, mapperAllocas,
470 kDefaultDevice, totalNbOperand);
472 return success();
475 namespace {
477 /// Implementation of the dialect interface that converts operations belonging
478 /// to the OpenACC dialect to LLVM IR.
479 class OpenACCDialectLLVMIRTranslationInterface
480 : public LLVMTranslationDialectInterface {
481 public:
482 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
484 /// Translates the given operation to LLVM IR using the provided IR builder
485 /// and saving the state in `moduleTranslation`.
486 LogicalResult
487 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
488 LLVM::ModuleTranslation &moduleTranslation) const final;
491 } // namespace
493 /// Given an OpenACC MLIR operation, create the corresponding LLVM IR
494 /// (including OpenACC runtime calls).
495 LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation(
496 Operation *op, llvm::IRBuilderBase &builder,
497 LLVM::ModuleTranslation &moduleTranslation) const {
499 return llvm::TypeSwitch<Operation *, LogicalResult>(op)
500 .Case([&](acc::DataOp dataOp) {
501 return convertDataOp(dataOp, builder, moduleTranslation);
503 .Case([&](acc::EnterDataOp enterDataOp) {
504 return convertStandaloneDataOp<acc::EnterDataOp>(enterDataOp, builder,
505 moduleTranslation);
507 .Case([&](acc::ExitDataOp exitDataOp) {
508 return convertStandaloneDataOp<acc::ExitDataOp>(exitDataOp, builder,
509 moduleTranslation);
511 .Case([&](acc::UpdateOp updateOp) {
512 return convertStandaloneDataOp<acc::UpdateOp>(updateOp, builder,
513 moduleTranslation);
515 .Case<acc::TerminatorOp, acc::YieldOp>([](auto op) {
516 // `yield` and `terminator` can be just omitted. The block structure was
517 // created in the function that handles their parent operation.
518 assert(op->getNumOperands() == 0 &&
519 "unexpected OpenACC terminator with operands");
520 return success();
522 .Case<acc::CreateOp, acc::CopyinOp, acc::CopyoutOp, acc::DeleteOp,
523 acc::UpdateDeviceOp, acc::GetDevicePtrOp>([](auto op) {
524 // NOP
525 return success();
527 .Default([&](Operation *op) {
528 return op->emitError("unsupported OpenACC operation: ")
529 << op->getName();
533 void mlir::registerOpenACCDialectTranslation(DialectRegistry &registry) {
534 registry.insert<acc::OpenACCDialect>();
535 registry.addExtension(+[](MLIRContext *ctx, acc::OpenACCDialect *dialect) {
536 dialect->addInterfaces<OpenACCDialectLLVMIRTranslationInterface>();
540 void mlir::registerOpenACCDialectTranslation(MLIRContext &context) {
541 DialectRegistry registry;
542 registerOpenACCDialectTranslation(registry);
543 context.appendDialectRegistry(registry);