1 //===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenMP dialect and LLVM
10 // IR.
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
16 #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
17 #include "mlir/IR/IRMapping.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/Support/LLVM.h"
20 #include "mlir/Support/LogicalResult.h"
21 #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
22 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
23 #include "mlir/Transforms/RegionUtils.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Frontend/OpenMP/OMPConstants.h"
28 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
29 #include "llvm/IR/DebugInfoMetadata.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/Support/FileSystem.h"
32 #include "llvm/TargetParser/Triple.h"
33 #include "llvm/Transforms/Utils/ModuleUtils.h"
35 #include <any>
36 #include <iterator>
37 #include <optional>
38 #include <utility>
40 using namespace mlir;
42 namespace {
43 static llvm::omp::ScheduleKind
44 convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
45 if (!schedKind.has_value())
46 return llvm::omp::OMP_SCHEDULE_Default;
47 switch (schedKind.value()) {
48 case omp::ClauseScheduleKind::Static:
49 return llvm::omp::OMP_SCHEDULE_Static;
50 case omp::ClauseScheduleKind::Dynamic:
51 return llvm::omp::OMP_SCHEDULE_Dynamic;
52 case omp::ClauseScheduleKind::Guided:
53 return llvm::omp::OMP_SCHEDULE_Guided;
54 case omp::ClauseScheduleKind::Auto:
55 return llvm::omp::OMP_SCHEDULE_Auto;
56 case omp::ClauseScheduleKind::Runtime:
57 return llvm::omp::OMP_SCHEDULE_Runtime;
59 llvm_unreachable("unhandled schedule clause argument");
62 /// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
63 /// insertion points for allocas.
64 class OpenMPAllocaStackFrame
65 : public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
66 public:
67 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame)
69 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
70 : allocaInsertPoint(allocaIP) {}
71 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
74 /// ModuleTranslation stack frame containing the partial mapping between MLIR
75 /// values and their LLVM IR equivalents.
76 class OpenMPVarMappingStackFrame
77 : public LLVM::ModuleTranslation::StackFrameBase<
78 OpenMPVarMappingStackFrame> {
79 public:
80 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPVarMappingStackFrame)
82 explicit OpenMPVarMappingStackFrame(
83 const DenseMap<Value, llvm::Value *> &mapping)
84 : mapping(mapping) {}
86 DenseMap<Value, llvm::Value *> mapping;
88 } // namespace
90 /// Find the insertion point for allocas given the current insertion point for
91 /// normal operations in the builder.
92 static llvm::OpenMPIRBuilder::InsertPointTy
93 findAllocaInsertPoint(llvm::IRBuilderBase &builder,
94 const LLVM::ModuleTranslation &moduleTranslation) {
95 // If there is an alloca insertion point on the stack, i.e. we are in a nested
96 // operation and a specific point was provided by some surrounding operation,
97 // use it.
98 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
99 WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
100 [&](const OpenMPAllocaStackFrame &frame) {
101 allocaInsertPoint = frame.allocaInsertPoint;
102 return WalkResult::interrupt();
104 if (walkResult.wasInterrupted())
105 return allocaInsertPoint;
107 // Otherwise, insert into the entry block of the surrounding function.
108 // If the current IRBuilder InsertPoint is the function's entry, it cannot
109 // also be used for alloca insertion which would result in insertion order
110 // confusion. Create a new BasicBlock for the Builder and use the entry block
111 // for the allocs.
112 // TODO: Create a dedicated alloca BasicBlock at function creation such that
113 // we do not need to move the current InsertPoint here.
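// Rough sketch of what the code below does when the builder sits at the end of
// the function's entry block (block names are illustrative):
//   <original entry>:   ; allocas will later go to its first insertion point
//     ...
//     br label %entry   ; branch to a freshly created block
//   entry:              ; builder resumes regular code generation here
// The insertion point returned at the end of this function always refers to
// the function's (original) entry block.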
114 if (builder.GetInsertBlock() ==
115 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
116 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
117 "Assuming end of basic block");
118 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
119 builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
120 builder.GetInsertBlock()->getNextNode());
121 builder.CreateBr(entryBB);
122 builder.SetInsertPoint(entryBB);
125 llvm::BasicBlock &funcEntryBlock =
126 builder.GetInsertBlock()->getParent()->getEntryBlock();
127 return llvm::OpenMPIRBuilder::InsertPointTy(
128 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
131 /// Converts the given region that appears within an OpenMP dialect operation to
132 /// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
133 // region, and a branch from any block with a successor-less OpenMP terminator
134 /// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
135 /// of the continuation block if provided.
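///
/// Rough shape of the resulting LLVM IR CFG (illustrative; names vary):
///   <sourceBlock>:
///     br label %<blockName>      ; original terminator retargeted below
///   <blockName>:                 ; one LLVM block per MLIR block in the region
///     ...
///     br label %omp.region.cont  ; omp.yield / omp.terminator lower to this
///   omp.region.cont:
///     %v = phi ...               ; one PHI per value yielded from the region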
136 static llvm::BasicBlock *convertOmpOpRegions(
137 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
138 LLVM::ModuleTranslation &moduleTranslation, LogicalResult &bodyGenStatus,
139 SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
140 llvm::BasicBlock *continuationBlock =
141 splitBB(builder, true, "omp.region.cont");
142 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
144 llvm::LLVMContext &llvmContext = builder.getContext();
145 for (Block &bb : region) {
146 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
147 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
148 builder.GetInsertBlock()->getNextNode());
149 moduleTranslation.mapBlock(&bb, llvmBB);
152 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
154 // Terminators (namely YieldOp) may be forwarding values out of the region that
155 // need to be available in the continuation block. Collect the types of these
156 // operands in preparation of creating PHI nodes.
157 SmallVector<llvm::Type *> continuationBlockPHITypes;
158 bool operandsProcessed = false;
159 unsigned numYields = 0;
160 for (Block &bb : region.getBlocks()) {
161 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
162 if (!operandsProcessed) {
163 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
164 continuationBlockPHITypes.push_back(
165 moduleTranslation.convertType(yield->getOperand(i).getType()));
167 operandsProcessed = true;
168 } else {
169 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
170 "mismatching number of values yielded from the region");
171 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
172 llvm::Type *operandType =
173 moduleTranslation.convertType(yield->getOperand(i).getType());
174 (void)operandType;
175 assert(continuationBlockPHITypes[i] == operandType &&
176 "values of mismatching types yielded from the region");
179 numYields++;
183 // Insert PHI nodes in the continuation block for any values forwarded by the
184 // terminators in this region.
185 if (!continuationBlockPHITypes.empty())
186 assert(
187 continuationBlockPHIs &&
188 "expected continuation block PHIs if converted regions yield values");
189 if (continuationBlockPHIs) {
190 llvm::IRBuilderBase::InsertPointGuard guard(builder);
191 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
192 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
193 for (llvm::Type *ty : continuationBlockPHITypes)
194 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
197 // Convert blocks one by one in topological order to ensure
198 // defs are converted before uses.
199 SetVector<Block *> blocks = getTopologicallySortedBlocks(region);
200 for (Block *bb : blocks) {
201 llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
202 // Retarget the branch of the entry block to the entry block of the
203 // converted region (regions are single-entry).
204 if (bb->isEntryBlock()) {
205 assert(sourceTerminator->getNumSuccessors() == 1 &&
206 "provided entry block has multiple successors");
207 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
208 "ContinuationBlock is not the successor of the entry block");
209 sourceTerminator->setSuccessor(0, llvmBB);
212 llvm::IRBuilderBase::InsertPointGuard guard(builder);
213 if (failed(
214 moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) {
215 bodyGenStatus = failure();
216 return continuationBlock;
219 // Special handling for `omp.yield` and `omp.terminator` (we may have more
220 // than one): they return control to the parent OpenMP dialect operation, so
221 // we replace them with a branch to the continuation block. We handle this
222 // here to avoid relying on inter-function communication through the
223 // ModuleTranslation class to set up the correct insertion point. This is
224 // also consistent with MLIR's idiom of handling special region terminators
225 // in the same code that handles the region-owning operation.
226 Operation *terminator = bb->getTerminator();
227 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
228 builder.CreateBr(continuationBlock);
230 for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
231 (*continuationBlockPHIs)[i]->addIncoming(
232 moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
235 // After all blocks have been traversed and values mapped, connect the PHI
236 // nodes to the results of preceding blocks.
237 LLVM::detail::connectPHINodes(region, moduleTranslation);
239 // Remove the blocks and values defined in this region from the mapping since
240 // they are not visible outside of this region. This allows the same region to
241 // be converted several times, that is, cloned, without clashes, and slightly
242 // speeds up the lookups.
243 moduleTranslation.forgetMapping(region);
245 return continuationBlock;
248 /// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
249 static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
250 switch (kind) {
251 case omp::ClauseProcBindKind::Close:
252 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
253 case omp::ClauseProcBindKind::Master:
254 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
255 case omp::ClauseProcBindKind::Primary:
256 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
257 case omp::ClauseProcBindKind::Spread:
258 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
260 llvm_unreachable("Unknown ClauseProcBindKind kind");
263 /// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
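/// For reference, the operation being converted looks roughly as follows
/// (illustrative sketch, not taken from a test):
///   omp.master {
///     // region body
///     omp.terminator
///   }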
264 static LogicalResult
265 convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
266 LLVM::ModuleTranslation &moduleTranslation) {
267 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
268 // TODO: support error propagation in OpenMPIRBuilder and use it instead of
269 // relying on captured variables.
270 LogicalResult bodyGenStatus = success();
272 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
273 // MasterOp has only one region associated with it.
274 auto &region = cast<omp::MasterOp>(opInst).getRegion();
275 builder.restoreIP(codeGenIP);
276 convertOmpOpRegions(region, "omp.master.region", builder, moduleTranslation,
277 bodyGenStatus);
280 // TODO: Perform finalization actions for variables. This has to be
281 // called for variables which have destructors/finalizers.
282 auto finiCB = [&](InsertPointTy codeGenIP) {};
284 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
285 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createMaster(
286 ompLoc, bodyGenCB, finiCB));
287 return success();
290 /// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
291 static LogicalResult
292 convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
293 LLVM::ModuleTranslation &moduleTranslation) {
294 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
295 auto criticalOp = cast<omp::CriticalOp>(opInst);
296 // TODO: support error propagation in OpenMPIRBuilder and use it instead of
297 // relying on captured variables.
298 LogicalResult bodyGenStatus = success();
300 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
301 // CriticalOp has only one region associated with it.
302 auto &region = cast<omp::CriticalOp>(opInst).getRegion();
303 builder.restoreIP(codeGenIP);
304 convertOmpOpRegions(region, "omp.critical.region", builder,
305 moduleTranslation, bodyGenStatus);
308 // TODO: Perform finalization actions for variables. This has to be
309 // called for variables which have destructors/finalizers.
310 auto finiCB = [&](InsertPointTy codeGenIP) {};
312 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
313 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
314 llvm::Constant *hint = nullptr;
316 // If it has a name, it probably has a hint too.
317 if (criticalOp.getNameAttr()) {
318 // The verifiers in the OpenMP dialect guarantee that all the pointers are
319 // non-null.
320 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
321 auto criticalDeclareOp =
322 SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
323 symbolRef);
324 hint = llvm::ConstantInt::get(
325 llvm::Type::getInt32Ty(llvmContext),
326 static_cast<int>(criticalDeclareOp.getHintVal()));
328 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical(
329 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint));
330 return success();
333 /// Returns a reduction declaration that corresponds to the given reduction
334 /// operation in the given container. Currently only supports reductions inside
335 /// WsloopOp and ParallelOp but can be easily extended as long as the given
336 /// construct implements getNumReductionVars.
337 template <typename T>
338 static std::optional<omp::DeclareReductionOp>
339 findReductionDeclInContainer(T container, omp::ReductionOp reduction) {
340 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) {
341 if (container.getReductionVars()[i] != reduction.getAccumulator())
342 continue;
344 SymbolRefAttr reductionSymbol =
345 cast<SymbolRefAttr>((*container.getReductions())[i]);
346 auto declareOp =
347 SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
348 container, reductionSymbol);
349 return declareOp;
351 return std::nullopt;
354 /// Searches for a reduction in a provided region and the regions
355 /// it is nested in.
356 static omp::DeclareReductionOp findReductionDecl(Operation &containerOp,
357 omp::ReductionOp reduction) {
358 std::optional<omp::DeclareReductionOp> declareOp = std::nullopt;
359 Operation *container = &containerOp;
361 while (!declareOp.has_value() && container) {
362 // Check if current container is supported for reductions searches
363 if (auto par = dyn_cast<omp::ParallelOp>(*container)) {
364 declareOp = findReductionDeclInContainer(par, reduction);
365 } else if (auto loop = dyn_cast<omp::WsloopOp>(*container)) {
366 declareOp = findReductionDeclInContainer(loop, reduction);
367 } else {
368 break;
371 // See if we can search parent for reductions as well
372 container = container->getParentOp();
375 assert(declareOp.has_value() &&
376 "reduction operation must be associated with a declaration");
378 return *declareOp;
381 /// Populates `reductions` with reduction declarations used in the given loop.
382 template <typename T>
383 static void
384 collectReductionDecls(T loop,
385 SmallVectorImpl<omp::DeclareReductionOp> &reductions) {
386 std::optional<ArrayAttr> attr = loop.getReductions();
387 if (!attr)
388 return;
390 reductions.reserve(reductions.size() + loop.getNumReductionVars());
391 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
392 reductions.push_back(
393 SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
394 loop, symbolRef));
398 /// Translates the blocks contained in the given region and appends them at
399 /// the current insertion point of `builder`. The operations of the entry block
400 /// are appended to the current insertion block. If set, `continuationBlockArgs`
401 /// is populated with translated values that correspond to the values
402 /// omp.yield'ed from the region.
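///
/// Typical call pattern in this file (illustrative sketch; `someRegion` and the
/// block name are placeholders):
///   SmallVector<llvm::Value *> phis;
///   if (failed(inlineConvertOmpRegions(someRegion, "omp.some.region", builder,
///                                      moduleTranslation, &phis)))
///     return failure();
///   // `phis` now holds the values corresponding to the region's omp.yield.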
403 static LogicalResult inlineConvertOmpRegions(
404 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
405 LLVM::ModuleTranslation &moduleTranslation,
406 SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
407 if (region.empty())
408 return success();
410 // Special case for single-block regions: insert the operations directly into
411 // the current insertion block instead of creating additional blocks.
412 if (llvm::hasSingleElement(region)) {
413 llvm::Instruction *potentialTerminator =
414 builder.GetInsertBlock()->empty() ? nullptr
415 : &builder.GetInsertBlock()->back();
417 if (potentialTerminator && potentialTerminator->isTerminator())
418 potentialTerminator->removeFromParent();
419 moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
421 if (failed(moduleTranslation.convertBlock(
422 region.front(), /*ignoreArguments=*/true, builder)))
423 return failure();
425 // The continuation arguments are simply the translated terminator operands.
426 if (continuationBlockArgs)
427 llvm::append_range(
428 *continuationBlockArgs,
429 moduleTranslation.lookupValues(region.front().back().getOperands()));
431 // Drop the mapping that is no longer necessary so that the same region can
432 // be processed multiple times.
433 moduleTranslation.forgetMapping(region);
435 if (potentialTerminator && potentialTerminator->isTerminator())
436 potentialTerminator->insertAfter(&builder.GetInsertBlock()->back());
438 return success();
441 LogicalResult bodyGenStatus = success();
442 SmallVector<llvm::PHINode *> phis;
443 llvm::BasicBlock *continuationBlock = convertOmpOpRegions(
444 region, blockName, builder, moduleTranslation, bodyGenStatus, &phis);
445 if (failed(bodyGenStatus))
446 return failure();
447 if (continuationBlockArgs)
448 llvm::append_range(*continuationBlockArgs, phis);
449 builder.SetInsertPoint(continuationBlock,
450 continuationBlock->getFirstInsertionPt());
451 return success();
454 namespace {
455 /// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
456 /// store lambdas with capture.
457 using OwningReductionGen = std::function<llvm::OpenMPIRBuilder::InsertPointTy(
458 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
459 llvm::Value *&)>;
460 using OwningAtomicReductionGen =
461 std::function<llvm::OpenMPIRBuilder::InsertPointTy(
462 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
463 llvm::Value *)>;
464 } // namespace
466 /// Create an OpenMPIRBuilder-compatible reduction generator for the given
467 /// reduction declaration. The generator uses `builder` but ignores its
468 /// insertion point.
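/// The returned generator is invoked with an insertion point, the two values
/// to combine (`lhs`, `rhs`), and an out-parameter for the combined value; it
/// inlines the declaration's reduction region with that region's two block
/// arguments mapped to `lhs` and `rhs`, and reports the single yielded value
/// through the out-parameter.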
469 static OwningReductionGen
470 makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
471 LLVM::ModuleTranslation &moduleTranslation) {
472 // The lambda is mutable because we need access to non-const methods of decl
473 // (which aren't actually mutating it), and we must capture decl by-value to
474 // avoid the dangling reference after the parent function returns.
475 OwningReductionGen gen =
476 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
477 llvm::Value *lhs, llvm::Value *rhs,
478 llvm::Value *&result) mutable {
479 Region &reductionRegion = decl.getReductionRegion();
480 moduleTranslation.mapValue(reductionRegion.front().getArgument(0), lhs);
481 moduleTranslation.mapValue(reductionRegion.front().getArgument(1), rhs);
482 builder.restoreIP(insertPoint);
483 SmallVector<llvm::Value *> phis;
484 if (failed(inlineConvertOmpRegions(reductionRegion,
485 "omp.reduction.nonatomic.body",
486 builder, moduleTranslation, &phis)))
487 return llvm::OpenMPIRBuilder::InsertPointTy();
488 assert(phis.size() == 1);
489 result = phis[0];
490 return builder.saveIP();
492 return gen;
495 /// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
496 /// given reduction declaration. The generator uses `builder` but ignores its
497 /// insertion point. Returns null if there is no atomic region available in the
498 /// reduction declaration.
499 static OwningAtomicReductionGen
500 makeAtomicReductionGen(omp::DeclareReductionOp decl,
501 llvm::IRBuilderBase &builder,
502 LLVM::ModuleTranslation &moduleTranslation) {
503 if (decl.getAtomicReductionRegion().empty())
504 return OwningAtomicReductionGen();
506 // The lambda is mutable because we need access to non-const methods of decl
507 // (which aren't actually mutating it), and we must capture decl by-value to
508 // avoid the dangling reference after the parent function returns.
509 OwningAtomicReductionGen atomicGen =
510 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
511 llvm::Value *lhs, llvm::Value *rhs) mutable {
512 Region &atomicRegion = decl.getAtomicReductionRegion();
513 moduleTranslation.mapValue(atomicRegion.front().getArgument(0), lhs);
514 moduleTranslation.mapValue(atomicRegion.front().getArgument(1), rhs);
515 builder.restoreIP(insertPoint);
516 SmallVector<llvm::Value *> phis;
517 if (failed(inlineConvertOmpRegions(atomicRegion,
518 "omp.reduction.atomic.body", builder,
519 moduleTranslation, &phis)))
520 return llvm::OpenMPIRBuilder::InsertPointTy();
521 assert(phis.empty());
522 return builder.saveIP();
524 return atomicGen;
527 /// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
528 static LogicalResult
529 convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
530 LLVM::ModuleTranslation &moduleTranslation) {
531 auto orderedOp = cast<omp::OrderedOp>(opInst);
533 omp::ClauseDepend dependType = *orderedOp.getDependTypeVal();
534 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
535 unsigned numLoops = *orderedOp.getNumLoopsVal();
536 SmallVector<llvm::Value *> vecValues =
537 moduleTranslation.lookupValues(orderedOp.getDependVecVars());
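// The depend vector flattens one group of `numLoops` iteration values per
// depend clause; the loop below peels off one group at a time and emits a
// separate createOrderedDepend call for each group.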
539 size_t indexVecValues = 0;
540 while (indexVecValues < vecValues.size()) {
541 SmallVector<llvm::Value *> storeValues;
542 storeValues.reserve(numLoops);
543 for (unsigned i = 0; i < numLoops; i++) {
544 storeValues.push_back(vecValues[indexVecValues]);
545 indexVecValues++;
547 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
548 findAllocaInsertPoint(builder, moduleTranslation);
549 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
550 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
551 ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
553 return success();
556 /// Converts an OpenMP 'ordered_region' operation into LLVM IR using
557 /// OpenMPIRBuilder.
558 static LogicalResult
559 convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
560 LLVM::ModuleTranslation &moduleTranslation) {
561 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
562 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
564 // TODO: The code generation for ordered simd directive is not supported yet.
565 if (orderedRegionOp.getSimd())
566 return failure();
568 // TODO: support error propagation in OpenMPIRBuilder and use it instead of
569 // relying on captured variables.
570 LogicalResult bodyGenStatus = success();
572 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
573 // OrderedOp has only one region associated with it.
574 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
575 builder.restoreIP(codeGenIP);
576 convertOmpOpRegions(region, "omp.ordered.region", builder,
577 moduleTranslation, bodyGenStatus);
580 // TODO: Perform finalization actions for variables. This has to be
581 // called for variables which have destructors/finalizers.
582 auto finiCB = [&](InsertPointTy codeGenIP) {};
584 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
585 builder.restoreIP(
586 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
587 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getSimd()));
588 return bodyGenStatus;
591 static LogicalResult
592 convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
593 LLVM::ModuleTranslation &moduleTranslation) {
594 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
595 using StorableBodyGenCallbackTy =
596 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
598 auto sectionsOp = cast<omp::SectionsOp>(opInst);
600 // TODO: Support the following clauses: private, firstprivate, lastprivate,
601 // reduction, allocate
602 if (!sectionsOp.getReductionVars().empty() || sectionsOp.getReductions() ||
603 !sectionsOp.getAllocateVars().empty() ||
604 !sectionsOp.getAllocatorsVars().empty())
605 return emitError(sectionsOp.getLoc())
606 << "reduction and allocate clauses are not supported for sections "
607 "construct";
609 LogicalResult bodyGenStatus = success();
610 SmallVector<StorableBodyGenCallbackTy> sectionCBs;
612 for (Operation &op : *sectionsOp.getRegion().begin()) {
613 auto sectionOp = dyn_cast<omp::SectionOp>(op);
614 if (!sectionOp) // omp.terminator
615 continue;
617 Region &region = sectionOp.getRegion();
618 auto sectionCB = [&region, &builder, &moduleTranslation, &bodyGenStatus](
619 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
620 builder.restoreIP(codeGenIP);
621 convertOmpOpRegions(region, "omp.section.region", builder,
622 moduleTranslation, bodyGenStatus);
624 sectionCBs.push_back(sectionCB);
627 // No sections within omp.sections operation - skip generation. This situation
628 // is only possible if there is only a terminator operation inside the
629 // sections operation.
630 if (sectionCBs.empty())
631 return success();
633 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
635 // TODO: Perform appropriate actions according to the data-sharing
636 // attribute (shared, private, firstprivate, ...) of variables.
637 // Currently defaults to shared.
638 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
639 llvm::Value &vPtr,
640 llvm::Value *&replacementValue) -> InsertPointTy {
641 replacementValue = &vPtr;
642 return codeGenIP;
645 // TODO: Perform finalization actions for variables. This has to be
646 // called for variables which have destructors/finalizers.
647 auto finiCB = [&](InsertPointTy codeGenIP) {};
649 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
650 findAllocaInsertPoint(builder, moduleTranslation);
651 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
652 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createSections(
653 ompLoc, allocaIP, sectionCBs, privCB, finiCB, false,
654 sectionsOp.getNowait()));
655 return bodyGenStatus;
658 /// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
659 static LogicalResult
660 convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
661 LLVM::ModuleTranslation &moduleTranslation) {
662 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
663 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
664 LogicalResult bodyGenStatus = success();
665 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
666 builder.restoreIP(codegenIP);
667 convertOmpOpRegions(singleOp.getRegion(), "omp.single.region", builder,
668 moduleTranslation, bodyGenStatus);
670 auto finiCB = [&](InsertPointTy codeGenIP) {};
672 // Handle copyprivate
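// Each variable in getCopyprivateVars() is paired positionally with the copy
// function symbol in getCopyprivateFuncs(); both are translated to their LLVM
// equivalents and passed to createSingle below.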
673 Operation::operand_range cpVars = singleOp.getCopyprivateVars();
674 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateFuncs();
675 llvm::SmallVector<llvm::Value *> llvmCPVars;
676 llvm::SmallVector<llvm::Function *> llvmCPFuncs;
677 for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
678 llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
679 auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
680 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
681 llvmCPFuncs.push_back(
682 moduleTranslation.lookupFunction(llvmFuncOp.getName()));
685 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createSingle(
686 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars, llvmCPFuncs));
687 return bodyGenStatus;
690 /// Converts an OpenMP teams construct into LLVM IR using OpenMPIRBuilder.
691 static LogicalResult
692 convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
693 LLVM::ModuleTranslation &moduleTranslation) {
694 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
695 LogicalResult bodyGenStatus = success();
696 if (!op.getAllocatorsVars().empty() || op.getReductions())
697 return op.emitError("unhandled clauses for translation to LLVM IR");
699 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
700 LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
701 moduleTranslation, allocaIP);
702 builder.restoreIP(codegenIP);
703 convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
704 moduleTranslation, bodyGenStatus);
707 llvm::Value *numTeamsLower = nullptr;
708 if (Value numTeamsLowerVar = op.getNumTeamsLower())
709 numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
711 llvm::Value *numTeamsUpper = nullptr;
712 if (Value numTeamsUpperVar = op.getNumTeamsUpper())
713 numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
715 llvm::Value *threadLimit = nullptr;
716 if (Value threadLimitVar = op.getThreadLimit())
717 threadLimit = moduleTranslation.lookupValue(threadLimitVar);
719 llvm::Value *ifExpr = nullptr;
720 if (Value ifExprVar = op.getIfExpr())
721 ifExpr = moduleTranslation.lookupValue(ifExprVar);
723 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
724 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTeams(
725 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr));
726 return bodyGenStatus;
729 /// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
730 static LogicalResult
731 convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
732 LLVM::ModuleTranslation &moduleTranslation) {
733 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
734 LogicalResult bodyGenStatus = success();
735 if (taskOp.getUntiedAttr() || taskOp.getMergeableAttr() ||
736 taskOp.getInReductions() || taskOp.getPriority() ||
737 !taskOp.getAllocateVars().empty()) {
738 return taskOp.emitError("unhandled clauses for translation to LLVM IR");
740 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
741 // Save the alloca insertion point on ModuleTranslation stack for use in
742 // nested regions.
743 LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
744 moduleTranslation, allocaIP);
746 builder.restoreIP(codegenIP);
747 convertOmpOpRegions(taskOp.getRegion(), "omp.task.region", builder,
748 moduleTranslation, bodyGenStatus);
751 SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
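// Translate the depend clause, if present: each depend variable is paired with
// its dependence-kind attribute, mapped to the runtime's RTLDependenceKindTy,
// and collected as DependData for the createTask call below.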
752 if (!taskOp.getDependVars().empty() && taskOp.getDepends()) {
753 for (auto dep :
754 llvm::zip(taskOp.getDependVars(), taskOp.getDepends()->getValue())) {
755 llvm::omp::RTLDependenceKindTy type;
756 switch (
757 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
758 case mlir::omp::ClauseTaskDepend::taskdependin:
759 type = llvm::omp::RTLDependenceKindTy::DepIn;
760 break;
761 // The OpenMP runtime requires that the codegen for the 'depend' clause with
762 // the 'out' dependency kind be the same as the codegen for the 'depend'
763 // clause with the 'inout' dependency kind.
764 case mlir::omp::ClauseTaskDepend::taskdependout:
765 case mlir::omp::ClauseTaskDepend::taskdependinout:
766 type = llvm::omp::RTLDependenceKindTy::DepInOut;
767 break;
769 llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
770 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
771 dds.emplace_back(dd);
775 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
776 findAllocaInsertPoint(builder, moduleTranslation);
777 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
778 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTask(
779 ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
780 moduleTranslation.lookupValue(taskOp.getFinalExpr()),
781 moduleTranslation.lookupValue(taskOp.getIfExpr()), dds));
782 return bodyGenStatus;
785 /// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
786 static LogicalResult
787 convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
788 LLVM::ModuleTranslation &moduleTranslation) {
789 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
790 LogicalResult bodyGenStatus = success();
791 if (!tgOp.getTaskReductionVars().empty() || !tgOp.getAllocateVars().empty()) {
792 return tgOp.emitError("unhandled clauses for translation to LLVM IR");
794 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
795 builder.restoreIP(codegenIP);
796 convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region", builder,
797 moduleTranslation, bodyGenStatus);
799 InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
800 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
801 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTaskgroup(
802 ompLoc, allocaIP, bodyCB));
803 return bodyGenStatus;
806 /// Allocate space for privatized reduction variables.
807 template <typename T>
808 static void allocByValReductionVars(
809 T loop, llvm::IRBuilderBase &builder,
810 LLVM::ModuleTranslation &moduleTranslation,
811 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
812 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
813 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
814 DenseMap<Value, llvm::Value *> &reductionVariableMap) {
815 llvm::IRBuilderBase::InsertPointGuard guard(builder);
816 builder.restoreIP(allocaIP);
817 auto args =
818 loop.getRegion().getArguments().take_back(loop.getNumReductionVars());
820 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
821 llvm::Value *var = builder.CreateAlloca(
822 moduleTranslation.convertType(reductionDecls[i].getType()));
823 moduleTranslation.mapValue(args[i], var);
824 privateReductionVariables.push_back(var);
825 reductionVariableMap.try_emplace(loop.getReductionVars()[i], var);
829 /// Map the i-th reduction variable to the argument of its initializer region.
830 template <typename T>
831 static void
832 mapInitializationArg(T loop, LLVM::ModuleTranslation &moduleTranslation,
833 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
834 unsigned i) {
835 // map input argument to the initialization region
836 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
837 Region &initializerRegion = reduction.getInitializerRegion();
838 Block &entry = initializerRegion.front();
839 assert(entry.getNumArguments() == 1 &&
840 "the initialization region has one argument");
842 mlir::Value mlirSource = loop.getReductionVars()[i];
843 llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
844 assert(llvmSource && "lookup reduction var");
845 moduleTranslation.mapValue(entry.getArgument(0), llvmSource);
848 /// Collect reduction info
849 template <typename T>
850 static void collectReductionInfo(
851 T loop, llvm::IRBuilderBase &builder,
852 LLVM::ModuleTranslation &moduleTranslation,
853 SmallVector<omp::DeclareReductionOp> &reductionDecls,
854 SmallVector<OwningReductionGen> &owningReductionGens,
855 SmallVector<OwningAtomicReductionGen> &owningAtomicReductionGens,
856 const SmallVector<llvm::Value *> &privateReductionVariables,
857 SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos) {
858 unsigned numReductions = loop.getNumReductionVars();
860 for (unsigned i = 0; i < numReductions; ++i) {
861 owningReductionGens.push_back(
862 makeReductionGen(reductionDecls[i], builder, moduleTranslation));
863 owningAtomicReductionGens.push_back(
864 makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
867 // Collect the reduction information.
868 reductionInfos.reserve(numReductions);
869 for (unsigned i = 0; i < numReductions; ++i) {
870 llvm::OpenMPIRBuilder::AtomicReductionGenTy atomicGen = nullptr;
871 if (owningAtomicReductionGens[i])
872 atomicGen = owningAtomicReductionGens[i];
873 llvm::Value *variable =
874 moduleTranslation.lookupValue(loop.getReductionVars()[i]);
875 reductionInfos.push_back(
876 {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
877 privateReductionVariables[i], owningReductionGens[i], atomicGen});
881 /// Inline the given cleanup regions (e.g. DeclareReductionOp's cleanup region) for the corresponding private variables.
882 static LogicalResult
883 inlineOmpRegionCleanup(llvm::SmallVectorImpl<Region *> &cleanupRegions,
884 llvm::ArrayRef<llvm::Value *> privateVariables,
885 LLVM::ModuleTranslation &moduleTranslation,
886 llvm::IRBuilderBase &builder, StringRef regionName,
887 bool shouldLoadCleanupRegionArg = true) {
888 for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
889 if (cleanupRegion->empty())
890 continue;
892 // map the argument to the cleanup region
893 Block &entry = cleanupRegion->front();
895 llvm::Instruction *potentialTerminator =
896 builder.GetInsertBlock()->empty() ? nullptr
897 : &builder.GetInsertBlock()->back();
898 if (potentialTerminator && potentialTerminator->isTerminator())
899 builder.SetInsertPoint(potentialTerminator);
900 llvm::Value *privateVarValue =
901 shouldLoadCleanupRegionArg
902 ? builder.CreateLoad(
903 moduleTranslation.convertType(entry.getArgument(0).getType()),
904 privateVariables[i])
905 : privateVariables[i];
907 moduleTranslation.mapValue(entry.getArgument(0), privateVarValue);
909 if (failed(inlineConvertOmpRegions(*cleanupRegion, regionName, builder,
910 moduleTranslation)))
911 return failure();
913 // clear block argument mapping in case it needs to be re-created with a
914 // different source for another use of the same reduction decl
915 moduleTranslation.forgetMapping(*cleanupRegion);
917 return success();
920 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
921 static LogicalResult
922 convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
923 LLVM::ModuleTranslation &moduleTranslation) {
924 auto wsloopOp = cast<omp::WsloopOp>(opInst);
925 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
926 const bool isByRef = wsloopOp.getByref();
928 // TODO: this should be in the op verifier instead.
929 if (loopOp.getLowerBound().empty())
930 return failure();
932 // Static is the default.
933 auto schedule =
934 wsloopOp.getScheduleVal().value_or(omp::ClauseScheduleKind::Static);
936 // Find the loop configuration.
937 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getStep()[0]);
938 llvm::Type *ivType = step->getType();
939 llvm::Value *chunk = nullptr;
940 if (wsloopOp.getScheduleChunkVar()) {
941 llvm::Value *chunkVar =
942 moduleTranslation.lookupValue(wsloopOp.getScheduleChunkVar());
943 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
946 SmallVector<omp::DeclareReductionOp> reductionDecls;
947 collectReductionDecls(wsloopOp, reductionDecls);
948 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
949 findAllocaInsertPoint(builder, moduleTranslation);
951 SmallVector<llvm::Value *> privateReductionVariables;
952 DenseMap<Value, llvm::Value *> reductionVariableMap;
953 if (!isByRef) {
954 allocByValReductionVars(wsloopOp, builder, moduleTranslation, allocaIP,
955 reductionDecls, privateReductionVariables,
956 reductionVariableMap);
959 // Before the loop, store the initial values of reductions into reduction
960 // variables. Although this could be done after allocas, we don't want to mess
961 // with the alloca insertion point.
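// For a by-val reduction this amounts to, roughly (illustrative):
//   %red.priv = alloca <converted decl type>  ; emitted by allocByValReductionVars
//   store <neutral element>, ptr %red.priv    ; emitted in the loop below
// For by-ref reductions, the alloca instead holds a pointer to storage created
// by the inlined initializer region.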
962 ArrayRef<BlockArgument> reductionArgs = wsloopOp.getRegion().getArguments();
963 for (unsigned i = 0; i < wsloopOp.getNumReductionVars(); ++i) {
964 SmallVector<llvm::Value *> phis;
966 // map block argument to initializer region
967 mapInitializationArg(wsloopOp, moduleTranslation, reductionDecls, i);
969 if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
970 "omp.reduction.neutral", builder,
971 moduleTranslation, &phis)))
972 return failure();
973 assert(phis.size() == 1 && "expected one value to be yielded from the "
974 "reduction neutral element declaration region");
975 if (isByRef) {
976 // Allocate reduction variable (which is a pointer to the real reduction
977 // variable allocated in the inlined region)
978 llvm::Value *var = builder.CreateAlloca(
979 moduleTranslation.convertType(reductionDecls[i].getType()));
980 // Store the result of the inlined region to the allocated reduction var
981 // ptr
982 builder.CreateStore(phis[0], var);
984 privateReductionVariables.push_back(var);
985 moduleTranslation.mapValue(reductionArgs[i], phis[0]);
986 reductionVariableMap.try_emplace(wsloopOp.getReductionVars()[i], phis[0]);
987 } else {
988 // in the by-val case, store the neutral element into the variable itself;
989 builder.CreateStore(phis[0], privateReductionVariables[i]);
990 // the rest was handled in allocByValReductionVars
993 // forget the mapping for the initializer region because we might need a
994 // different mapping if this reduction declaration is re-used for a
995 // different variable
996 moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
999 // Store the mapping between reduction variables and their private copies on
1000 // ModuleTranslation stack. It can be then recovered when translating
1001 // omp.reduce operations in a separate call.
1002 LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
1003 moduleTranslation, reductionVariableMap);
1005 // Set up the source location value for OpenMP runtime.
1006 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1008 // Generator of the canonical loop body.
1009 // TODO: support error propagation in OpenMPIRBuilder and use it instead of
1010 // relying on captured variables.
1011 SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
1012 SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
1013 LogicalResult bodyGenStatus = success();
1014 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
1015 // Make sure further conversions know about the induction variable.
1016 moduleTranslation.mapValue(
1017 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
1019 // Capture the body insertion point for use in nested loops. BodyIP of the
1020 // CanonicalLoopInfo always points to the beginning of the entry block of
1021 // the body.
1022 bodyInsertPoints.push_back(ip);
1024 if (loopInfos.size() != loopOp.getNumLoops() - 1)
1025 return;
1027 // Convert the body of the loop.
1028 builder.restoreIP(ip);
1029 convertOmpOpRegions(loopOp.getRegion(), "omp.wsloop.region", builder,
1030 moduleTranslation, bodyGenStatus);
1033 // Delegate actual loop construction to the OpenMP IRBuilder.
1034 // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
1035 // loop, i.e. it has a positive step, uses signed integer semantics.
1036 // Reconsider this code when the nested loop operation clearly supports more
1037 // cases.
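// One CanonicalLoopInfo is created per loop in the omp.loop_nest below (via
// createCanonicalLoop); collapseLoops then fuses them into a single canonical
// loop before applyWorkshareLoop applies the workshare scheduling.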
1038 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1039 for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
1040 llvm::Value *lowerBound =
1041 moduleTranslation.lookupValue(loopOp.getLowerBound()[i]);
1042 llvm::Value *upperBound =
1043 moduleTranslation.lookupValue(loopOp.getUpperBound()[i]);
1044 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getStep()[i]);
1046 // Make sure loop trip counts are emitted in the preheader of the outermost
1047 // loop at the latest so that they are all available for the new collapsed
1048 // loop that will be created below.
1049 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
1050 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
1051 if (i != 0) {
1052 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back());
1053 computeIP = loopInfos.front()->getPreheaderIP();
1055 loopInfos.push_back(ompBuilder->createCanonicalLoop(
1056 loc, bodyGen, lowerBound, upperBound, step,
1057 /*IsSigned=*/true, loopOp.getInclusive(), computeIP));
1059 if (failed(bodyGenStatus))
1060 return failure();
1063 // Collapse loops. Store the insertion point because LoopInfos may get
1064 // invalidated.
1065 llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
1066 llvm::CanonicalLoopInfo *loopInfo =
1067 ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
1069 allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
1071 // TODO: Handle doacross loops when the ordered clause has a parameter.
1072 bool isOrdered = wsloopOp.getOrderedVal().has_value();
1073 std::optional<omp::ScheduleModifier> scheduleModifier =
1074 wsloopOp.getScheduleModifier();
1075 bool isSimd = wsloopOp.getSimdModifier();
1077 ompBuilder->applyWorkshareLoop(
1078 ompLoc.DL, loopInfo, allocaIP, !wsloopOp.getNowait(),
1079 convertToScheduleKind(schedule), chunk, isSimd,
1080 scheduleModifier == omp::ScheduleModifier::monotonic,
1081 scheduleModifier == omp::ScheduleModifier::nonmonotonic, isOrdered);
1083 // Continue building IR after the loop. Note that the LoopInfo returned by
1084 // `collapseLoops` points inside the outermost loop and is intended for
1085 // potential further loop transformations. Use the insertion point stored
1086 // before collapsing loops instead.
1087 builder.restoreIP(afterIP);
1089 // Process the reductions if required.
1090 if (wsloopOp.getNumReductionVars() == 0)
1091 return success();
1093 // Create the reduction generators. We need to own them here because
1094 // ReductionInfo only accepts references to the generators.
1095 SmallVector<OwningReductionGen> owningReductionGens;
1096 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1097 SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1098 collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
1099 owningReductionGens, owningAtomicReductionGens,
1100 privateReductionVariables, reductionInfos);
1102 // The call to createReductions below expects the block to have a
1103 // terminator. Create an unreachable instruction to serve as terminator
1104 // and remove it later.
1105 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1106 builder.SetInsertPoint(tempTerminator);
1107 llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
1108 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1109 wsloopOp.getNowait(), isByRef);
1110 if (!contInsertPoint.getBlock())
1111 return wsloopOp->emitOpError() << "failed to convert reductions";
1112 auto nextInsertionPoint =
1113 ompBuilder->createBarrier(contInsertPoint, llvm::omp::OMPD_for);
1114 tempTerminator->eraseFromParent();
1115 builder.restoreIP(nextInsertionPoint);
1117 // after the workshare loop, deallocate private reduction variables
1118 SmallVector<Region *> reductionRegions;
1119 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1120 [](omp::DeclareReductionOp reductionDecl) {
1121 return &reductionDecl.getCleanupRegion();
1123 return inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
1124 moduleTranslation, builder,
1125 "omp.reduction.cleanup");
1128 /// A RAII class that on construction replaces the region arguments of the
1129 /// parallel op (which correspond to private variables) with the actual private
1130 /// variables they correspond to. This prepares the parallel op so that it
1131 /// matches what is expected by the OMPIRBuilder.
1133 /// On destruction, it restores the original state of the operation so that on
1134 /// the MLIR side, the op is not affected by conversion to LLVM IR.
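///
/// This is needed because the OMPIRBuilder privatization callback works on the
/// LLVM values of the original variables rather than on the entry block
/// arguments of the parallel region, so uses of those arguments are temporarily
/// rewired to the corresponding private variable operands.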
1135 class OmpParallelOpConversionManager {
1136 public:
1137 OmpParallelOpConversionManager(omp::ParallelOp opInst)
1138 : region(opInst.getRegion()), privateVars(opInst.getPrivateVars()),
1139 privateArgBeginIdx(opInst.getNumReductionVars()),
1140 privateArgEndIdx(privateArgBeginIdx + privateVars.size()) {
1141 auto privateVarsIt = privateVars.begin();
1143 for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1144 ++argIdx, ++privateVarsIt)
1145 mlir::replaceAllUsesInRegionWith(region.getArgument(argIdx),
1146 *privateVarsIt, region);
1149 ~OmpParallelOpConversionManager() {
1150 auto privateVarsIt = privateVars.begin();
1152 for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1153 ++argIdx, ++privateVarsIt)
1154 mlir::replaceAllUsesInRegionWith(*privateVarsIt,
1155 region.getArgument(argIdx), region);
1158 private:
1159 Region &region;
1160 OperandRange privateVars;
1161 unsigned privateArgBeginIdx;
1162 unsigned privateArgEndIdx;
1165 /// Converts the OpenMP parallel operation to LLVM IR.
1166 static LogicalResult
1167 convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1168 LLVM::ModuleTranslation &moduleTranslation) {
1169 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1170 OmpParallelOpConversionManager raii(opInst);
1171 const bool isByRef = opInst.getByref();
1173 // TODO: support error propagation in OpenMPIRBuilder and use it instead of
1174 // relying on captured variables.
1175 LogicalResult bodyGenStatus = success();
1176 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1178 // Collect reduction declarations
1179 SmallVector<omp::DeclareReductionOp> reductionDecls;
1180 collectReductionDecls(opInst, reductionDecls);
1181 SmallVector<llvm::Value *> privateReductionVariables;
1183 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1184 // Allocate reduction vars
1185 DenseMap<Value, llvm::Value *> reductionVariableMap;
1186 if (!isByRef) {
1187 allocByValReductionVars(opInst, builder, moduleTranslation, allocaIP,
1188 reductionDecls, privateReductionVariables,
1189 reductionVariableMap);
1192 // Initialize reduction vars
1193 builder.restoreIP(allocaIP);
1194 MutableArrayRef<BlockArgument> reductionArgs =
1195 opInst.getRegion().getArguments().take_back(
1196 opInst.getNumReductionVars());
1197 for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
1198 SmallVector<llvm::Value *> phis;
1200 // map the block argument
1201 mapInitializationArg(opInst, moduleTranslation, reductionDecls, i);
1202 if (failed(inlineConvertOmpRegions(
1203 reductionDecls[i].getInitializerRegion(), "omp.reduction.neutral",
1204 builder, moduleTranslation, &phis)))
1205 bodyGenStatus = failure();
1206 assert(phis.size() == 1 &&
1207 "expected one value to be yielded from the "
1208 "reduction neutral element declaration region");
1209 builder.restoreIP(allocaIP);
1211 if (isByRef) {
1212 // Allocate reduction variable (which is a pointer to the real reduction
1213 // variable allocated in the inlined region)
1214 llvm::Value *var = builder.CreateAlloca(
1215 moduleTranslation.convertType(reductionDecls[i].getType()));
1216 // Store the result of the inlined region to the allocated reduction var
1217 // ptr
1218 builder.CreateStore(phis[0], var);
1220 privateReductionVariables.push_back(var);
1221 moduleTranslation.mapValue(reductionArgs[i], phis[0]);
1222 reductionVariableMap.try_emplace(opInst.getReductionVars()[i], phis[0]);
1223 } else {
1224 // in the by-val case, store the neutral element into the variable itself;
1225 builder.CreateStore(phis[0], privateReductionVariables[i]);
1226 // the rest is done in allocByValReductionVars
1229 // clear block argument mapping in case it needs to be re-created with a
1230 // different source for another use of the same reduction decl
1231 moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
1234 // Store the mapping between reduction variables and their private copies on
1235 // ModuleTranslation stack. It can be then recovered when translating
1236 // omp.reduce operations in a separate call.
1237 LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
1238 moduleTranslation, reductionVariableMap);
1240 // Save the alloca insertion point on ModuleTranslation stack for use in
1241 // nested regions.
1242 LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
1243 moduleTranslation, allocaIP);
1245 // ParallelOp has only one region associated with it.
1246 builder.restoreIP(codeGenIP);
1247 auto regionBlock =
1248 convertOmpOpRegions(opInst.getRegion(), "omp.par.region", builder,
1249 moduleTranslation, bodyGenStatus);
1251 // Process the reductions if required.
1252 if (opInst.getNumReductionVars() > 0) {
1253 // Collect reduction info
1254 SmallVector<OwningReductionGen> owningReductionGens;
1255 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1256 SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1257 collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
1258 owningReductionGens, owningAtomicReductionGens,
1259 privateReductionVariables, reductionInfos);
1261 // Move to region cont block
1262 builder.SetInsertPoint(regionBlock->getTerminator());
1264 // Generate reductions from info
1265 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1266 builder.SetInsertPoint(tempTerminator);
1268 llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
1269 ompBuilder->createReductions(builder.saveIP(), allocaIP,
1270 reductionInfos, false, isByRef);
1271 if (!contInsertPoint.getBlock()) {
1272 bodyGenStatus = opInst->emitOpError() << "failed to convert reductions";
1273 return;
1276 tempTerminator->eraseFromParent();
1277 builder.restoreIP(contInsertPoint);
1281 SmallVector<omp::PrivateClauseOp> privatizerClones;
1282 SmallVector<llvm::Value *> privateVariables;
1284 // TODO: Perform appropriate actions according to the data-sharing
1285 // attribute (shared, private, firstprivate, ...) of variables.
1286 // Currently shared and private are supported.
1287 auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
1288 llvm::Value &, llvm::Value &vPtr,
1289 llvm::Value *&replacementValue) -> InsertPointTy {
1290 replacementValue = &vPtr;
1292 // If this is a private value, this lambda will return the corresponding
1293 // MLIR value and its `PrivateClauseOp`. Otherwise, empty values are
1294 // returned.
1295 auto [privVar, privatizerClone] =
1296 [&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1297 if (!opInst.getPrivateVars().empty()) {
1298 auto privVars = opInst.getPrivateVars();
1299 auto privatizers = opInst.getPrivatizers();
1301 for (auto [privVar, privatizerAttr] :
1302 llvm::zip_equal(privVars, *privatizers)) {
1303 // Find the MLIR private variable corresponding to the LLVM value
1304 // being privatized.
1305 llvm::Value *llvmPrivVar = moduleTranslation.lookupValue(privVar);
1306 if (llvmPrivVar != &vPtr)
1307 continue;
1309 SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
1310 omp::PrivateClauseOp privatizer =
1311 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1312 opInst, privSym);
1314 // Clone the privatizer in case it is used by more than one parallel
1315 // region. The privatizer is processed in-place (see below) before it
1316 // gets inlined in the parallel region and therefore processing the
1317 // original op is dangerous.
1318 return {privVar, privatizer.clone()};
1322 return {mlir::Value(), omp::PrivateClauseOp()};
1323 }();
1325 if (privVar) {
1326 Region &allocRegion = privatizerClone.getAllocRegion();
1328 // If this is a `firstprivate` clause, prepare the `omp.private` op by:
1329 if (privatizerClone.getDataSharingType() ==
1330 omp::DataSharingClauseType::FirstPrivate) {
1331 auto oldAllocBackBlock = std::prev(allocRegion.end());
1332 omp::YieldOp oldAllocYieldOp =
1333 llvm::cast<omp::YieldOp>(oldAllocBackBlock->getTerminator());
1335 Region &copyRegion = privatizerClone.getCopyRegion();
1337 mlir::IRRewriter copyCloneBuilder(&moduleTranslation.getContext());
1338 // 1. Cloning the `copy` region to the end of the `alloc` region.
1339 copyCloneBuilder.cloneRegionBefore(copyRegion, allocRegion,
1340 allocRegion.end());
1342 auto newCopyRegionFrontBlock = std::next(oldAllocBackBlock);
1343 // 2. Merging the last `alloc` block with the first block in the `copy`
1344 // region clone.
1345 // 3. Re-mapping the first argument of the `copy` region to be the
1346 // argument of the `alloc` region and the second argument of the `copy`
1347 // region to be the yielded value of the `alloc` region (this is the
1348 // private clone of the privatized value).
1349 copyCloneBuilder.mergeBlocks(
1350 &*newCopyRegionFrontBlock, &*oldAllocBackBlock,
1351 {allocRegion.getArgument(0), oldAllocYieldOp.getOperand(0)});
1353 // 4. The old terminator of the `alloc` region is not needed anymore, so
1354 // delete it.
1355 oldAllocYieldOp.erase();
1358       // Replace the privatizer block argument with the MLIR value being privatized.
1359 // This way, the body of the privatizer will be changed from using the
1360 // region/block argument to the value being privatized.
1361 auto allocRegionArg = allocRegion.getArgument(0);
1362 replaceAllUsesInRegionWith(allocRegionArg, privVar, allocRegion);
1364 auto oldIP = builder.saveIP();
1365 builder.restoreIP(allocaIP);
1367 SmallVector<llvm::Value *, 1> yieldedValues;
1368 if (failed(inlineConvertOmpRegions(allocRegion, "omp.privatizer", builder,
1369 moduleTranslation, &yieldedValues))) {
1370 opInst.emitError("failed to inline `alloc` region of an `omp.private` "
1371 "op in the parallel region");
1372 bodyGenStatus = failure();
1373 privatizerClone.erase();
1374 } else {
1375 assert(yieldedValues.size() == 1);
1376 replacementValue = yieldedValues.front();
1378 // Keep the LLVM replacement value and the op clone in case we need to
1379 // emit cleanup (i.e. deallocation) logic.
1380 privateVariables.push_back(replacementValue);
1381 privatizerClones.push_back(privatizerClone);
1384 builder.restoreIP(oldIP);
1387 return codeGenIP;
1390 // TODO: Perform finalization actions for variables. This has to be
1391 // called for variables which have destructors/finalizers.
1392 auto finiCB = [&](InsertPointTy codeGenIP) {
1393 InsertPointTy oldIP = builder.saveIP();
1394 builder.restoreIP(codeGenIP);
1396     // If the reduction has a cleanup region, inline it here to finalize the
1397     // reduction variables.
1398 SmallVector<Region *> reductionCleanupRegions;
1399 llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
1400 [](omp::DeclareReductionOp reductionDecl) {
1401 return &reductionDecl.getCleanupRegion();
1403 if (failed(inlineOmpRegionCleanup(
1404 reductionCleanupRegions, privateReductionVariables,
1405 moduleTranslation, builder, "omp.reduction.cleanup")))
1406 bodyGenStatus = failure();
1408 SmallVector<Region *> privateCleanupRegions;
1409 llvm::transform(privatizerClones, std::back_inserter(privateCleanupRegions),
1410 [](omp::PrivateClauseOp privatizer) {
1411 return &privatizer.getDeallocRegion();
1414 if (failed(inlineOmpRegionCleanup(
1415 privateCleanupRegions, privateVariables, moduleTranslation, builder,
1416 "omp.private.dealloc", /*shouldLoadCleanupRegionArg=*/false)))
1417 bodyGenStatus = failure();
1419 builder.restoreIP(oldIP);
1422 llvm::Value *ifCond = nullptr;
1423 if (auto ifExprVar = opInst.getIfExprVar())
1424 ifCond = moduleTranslation.lookupValue(ifExprVar);
1425 llvm::Value *numThreads = nullptr;
1426 if (auto numThreadsVar = opInst.getNumThreadsVar())
1427 numThreads = moduleTranslation.lookupValue(numThreadsVar);
1428 auto pbKind = llvm::omp::OMP_PROC_BIND_default;
1429 if (auto bind = opInst.getProcBindVal())
1430 pbKind = getProcBindKind(*bind);
1431 // TODO: Is the Parallel construct cancellable?
1432 bool isCancellable = false;
1434 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1435 findAllocaInsertPoint(builder, moduleTranslation);
1436 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1438 builder.restoreIP(
1439 ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
1440 ifCond, numThreads, pbKind, isCancellable));
1442 for (mlir::omp::PrivateClauseOp privatizerClone : privatizerClones)
1443 privatizerClone.erase();
1445 return bodyGenStatus;
1448 /// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
1449 static LogicalResult
1450 convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
1451 LLVM::ModuleTranslation &moduleTranslation) {
1452 auto simdOp = cast<omp::SimdOp>(opInst);
1453 auto loopOp = cast<omp::LoopNestOp>(simdOp.getWrappedLoop());
1455 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1457 // Generator of the canonical loop body.
1458 // TODO: support error propagation in OpenMPIRBuilder and use it instead of
1459 // relying on captured variables.
1460 SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
1461 SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
1462 LogicalResult bodyGenStatus = success();
1463 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
1464 // Make sure further conversions know about the induction variable.
1465 moduleTranslation.mapValue(
1466 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
1468 // Capture the body insertion point for use in nested loops. BodyIP of the
1469 // CanonicalLoopInfo always points to the beginning of the entry block of
1470 // the body.
1471 bodyInsertPoints.push_back(ip);
1473 if (loopInfos.size() != loopOp.getNumLoops() - 1)
1474 return;
1476 // Convert the body of the loop.
1477 builder.restoreIP(ip);
1478 convertOmpOpRegions(loopOp.getRegion(), "omp.simd.region", builder,
1479 moduleTranslation, bodyGenStatus);
1482 // Delegate actual loop construction to the OpenMP IRBuilder.
1483   // TODO: this currently assumes omp.loop_nest is semantically similar to an
1484   // SCF loop, i.e. it has a positive step and uses signed integer semantics.
1485 // Reconsider this code when the nested loop operation clearly supports more
1486 // cases.
1487 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1488 for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
1489 llvm::Value *lowerBound =
1490 moduleTranslation.lookupValue(loopOp.getLowerBound()[i]);
1491 llvm::Value *upperBound =
1492 moduleTranslation.lookupValue(loopOp.getUpperBound()[i]);
1493 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getStep()[i]);
1495     // Make sure loop trip counts are emitted in the preheader of the outermost
1496     // loop at the latest so that they are all available for the new collapsed
1497     // loop to be created below.
1498 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
1499 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
1500 if (i != 0) {
1501 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
1502 ompLoc.DL);
1503 computeIP = loopInfos.front()->getPreheaderIP();
1505 loopInfos.push_back(ompBuilder->createCanonicalLoop(
1506 loc, bodyGen, lowerBound, upperBound, step,
1507 /*IsSigned=*/true, /*Inclusive=*/true, computeIP));
1509 if (failed(bodyGenStatus))
1510 return failure();
1513 // Collapse loops.
1514 llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
1515 llvm::CanonicalLoopInfo *loopInfo =
1516 ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
1518 llvm::ConstantInt *simdlen = nullptr;
1519 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
1520 simdlen = builder.getInt64(simdlenVar.value());
1522 llvm::ConstantInt *safelen = nullptr;
1523 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
1524 safelen = builder.getInt64(safelenVar.value());
1526 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
1527 ompBuilder->applySimd(
1528 loopInfo, alignedVars,
1529 simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr())
1530 : nullptr,
1531 llvm::omp::OrderKind::OMP_ORDER_unknown, simdlen, safelen);
1533 builder.restoreIP(afterIP);
1534 return success();
1537 /// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
1538 static llvm::AtomicOrdering
1539 convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
1540 if (!ao)
1541 return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
1543 switch (*ao) {
1544 case omp::ClauseMemoryOrderKind::Seq_cst:
1545 return llvm::AtomicOrdering::SequentiallyConsistent;
1546 case omp::ClauseMemoryOrderKind::Acq_rel:
1547 return llvm::AtomicOrdering::AcquireRelease;
1548 case omp::ClauseMemoryOrderKind::Acquire:
1549 return llvm::AtomicOrdering::Acquire;
1550 case omp::ClauseMemoryOrderKind::Release:
1551 return llvm::AtomicOrdering::Release;
1552 case omp::ClauseMemoryOrderKind::Relaxed:
1553 return llvm::AtomicOrdering::Monotonic;
1555 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
1558 /// Convert omp.atomic.read operation to LLVM IR.
1559 static LogicalResult
1560 convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
1561 LLVM::ModuleTranslation &moduleTranslation) {
1563 auto readOp = cast<omp::AtomicReadOp>(opInst);
1564 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1566 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1568 llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrderVal());
1569 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
1570 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
1572 llvm::Type *elementType =
1573 moduleTranslation.convertType(readOp.getElementType());
1575 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
1576 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
1577 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO));
1578 return success();
1581 /// Converts an omp.atomic.write operation to LLVM IR.
1582 static LogicalResult
1583 convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
1584 LLVM::ModuleTranslation &moduleTranslation) {
1585 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
1586 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1588 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1589 llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrderVal());
1590 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
1591 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
1592 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
1593 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
1594 /*isVolatile=*/false};
1595 builder.restoreIP(ompBuilder->createAtomicWrite(ompLoc, x, expr, ao));
1596 return success();
1599 /// Converts an LLVM dialect binary operation to the corresponding enum value
1600 /// for `atomicrmw` supported binary operation.
1601 llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
1602 return llvm::TypeSwitch<Operation *, llvm::AtomicRMWInst::BinOp>(&op)
1603 .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
1604 .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
1605 .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
1606 .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
1607 .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
1608 .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
1609 .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
1610 .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
1611 .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
1612 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
1615 /// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
1616 static LogicalResult
1617 convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
1618 llvm::IRBuilderBase &builder,
1619 LLVM::ModuleTranslation &moduleTranslation) {
1620 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1622 // Convert values and types.
1623 auto &innerOpList = opInst.getRegion().front().getOperations();
1624 bool isRegionArgUsed{false}, isXBinopExpr{false};
1625 llvm::AtomicRMWInst::BinOp binop;
1626 mlir::Value mlirExpr;
1627 // Find the binary update operation that uses the region argument
1628 // and get the expression to update
1629 for (Operation &innerOp : innerOpList) {
1630 if (innerOp.getNumOperands() == 2) {
1631 binop = convertBinOpToAtomic(innerOp);
1632 if (!llvm::is_contained(innerOp.getOperands(),
1633 opInst.getRegion().getArgument(0)))
1634 continue;
1635 isRegionArgUsed = true;
1636 isXBinopExpr = innerOp.getNumOperands() > 0 &&
1637 innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
1638 mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
1639 break;
1642 if (!isRegionArgUsed)
1643 return opInst.emitError("no atomic update operation with region argument"
1644 " as operand found inside atomic.update region");
1646 llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
1647 llvm::Value *llvmX = moduleTranslation.lookupValue(opInst.getX());
1648 llvm::Type *llvmXElementType = moduleTranslation.convertType(
1649 opInst.getRegion().getArgument(0).getType());
1650 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
1651 /*isSigned=*/false,
1652 /*isVolatile=*/false};
1654 llvm::AtomicOrdering atomicOrdering =
1655 convertAtomicOrdering(opInst.getMemoryOrderVal());
1657 // Generate update code.
1658 LogicalResult updateGenStatus = success();
1659 auto updateFn = [&opInst, &moduleTranslation, &updateGenStatus](
1660 llvm::Value *atomicx,
1661 llvm::IRBuilder<> &builder) -> llvm::Value * {
1662 Block &bb = *opInst.getRegion().begin();
1663 moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
1664 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
1665 if (failed(moduleTranslation.convertBlock(bb, true, builder))) {
1666 updateGenStatus = (opInst.emitError()
1667 << "unable to convert update operation to llvm IR");
1668 return nullptr;
1670 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
1671 assert(yieldop && yieldop.getResults().size() == 1 &&
1672 "terminator must be omp.yield op and it must have exactly one "
1673 "argument");
1674 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
1677 // Handle ambiguous alloca, if any.
1678 auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
1679 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1680 builder.restoreIP(ompBuilder->createAtomicUpdate(
1681 ompLoc, allocaIP, llvmAtomicX, llvmExpr, atomicOrdering, binop, updateFn,
1682 isXBinopExpr));
1683 return updateGenStatus;
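/// Converts an omp.atomic.capture operation to LLVM IR using OpenMPIRBuilder.
/// The capture region pairs an atomic read with either an atomic update or an
/// atomic write.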
1686 static LogicalResult
1687 convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
1688 llvm::IRBuilderBase &builder,
1689 LLVM::ModuleTranslation &moduleTranslation) {
1690 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1691 mlir::Value mlirExpr;
1692 bool isXBinopExpr = false, isPostfixUpdate = false;
1693 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
1695 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
1696 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
1698 assert((atomicUpdateOp || atomicWriteOp) &&
1699 "internal op must be an atomic.update or atomic.write op");
1701 if (atomicWriteOp) {
1702 isPostfixUpdate = true;
1703 mlirExpr = atomicWriteOp.getExpr();
1704 } else {
1705 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
1706 atomicCaptureOp.getAtomicUpdateOp().getOperation();
1707 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
1708 bool isRegionArgUsed{false};
1709 // Find the binary update operation that uses the region argument
1710 // and get the expression to update
1711 for (Operation &innerOp : innerOpList) {
1712 if (innerOp.getNumOperands() == 2) {
1713 binop = convertBinOpToAtomic(innerOp);
1714 if (!llvm::is_contained(innerOp.getOperands(),
1715 atomicUpdateOp.getRegion().getArgument(0)))
1716 continue;
1717 isRegionArgUsed = true;
1718 isXBinopExpr =
1719 innerOp.getNumOperands() > 0 &&
1720 innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
1721 mlirExpr =
1722 (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
1723 break;
1726 if (!isRegionArgUsed)
1727 return atomicUpdateOp.emitError(
1728 "no atomic update operation with region argument"
1729 " as operand found inside atomic.update region");
1732 llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
1733 llvm::Value *llvmX =
1734 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
1735 llvm::Value *llvmV =
1736 moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
1737 llvm::Type *llvmXElementType = moduleTranslation.convertType(
1738 atomicCaptureOp.getAtomicReadOp().getElementType());
1739 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
1740 /*isSigned=*/false,
1741 /*isVolatile=*/false};
1742 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
1743 /*isSigned=*/false,
1744 /*isVolatile=*/false};
1746 llvm::AtomicOrdering atomicOrdering =
1747 convertAtomicOrdering(atomicCaptureOp.getMemoryOrderVal());
1749 LogicalResult updateGenStatus = success();
1750 auto updateFn = [&](llvm::Value *atomicx,
1751 llvm::IRBuilder<> &builder) -> llvm::Value * {
1752 if (atomicWriteOp)
1753 return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
1754 Block &bb = *atomicUpdateOp.getRegion().begin();
1755 moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
1756 atomicx);
1757 moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
1758 if (failed(moduleTranslation.convertBlock(bb, true, builder))) {
1759 updateGenStatus = (atomicUpdateOp.emitError()
1760 << "unable to convert update operation to llvm IR");
1761 return nullptr;
1763 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
1764 assert(yieldop && yieldop.getResults().size() == 1 &&
1765 "terminator must be omp.yield op and it must have exactly one "
1766 "argument");
1767 return moduleTranslation.lookupValue(yieldop.getResults()[0]);
1770 // Handle ambiguous alloca, if any.
1771 auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
1772 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1773 builder.restoreIP(ompBuilder->createAtomicCapture(
1774 ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
1775 binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr));
1776 return updateGenStatus;
1779 /// Converts an OpenMP reduction operation using OpenMPIRBuilder. Expects the
1780 /// mapping between reduction variables and their private equivalents to have
1781 /// been stored on the ModuleTranslation stack. Currently only supports
1782 /// reduction within WsloopOp and ParallelOp, but can be easily extended.
1783 static LogicalResult
1784 convertOmpReductionOp(omp::ReductionOp reductionOp,
1785 llvm::IRBuilderBase &builder,
1786 LLVM::ModuleTranslation &moduleTranslation) {
1787 // Find the declaration that corresponds to the reduction op.
1788 omp::DeclareReductionOp declaration;
1789 Operation *reductionParent = reductionOp->getParentOp();
1790 if (dyn_cast<omp::ParallelOp>(reductionParent) ||
1791 dyn_cast<omp::WsloopOp>(reductionParent)) {
1792 declaration = findReductionDecl(*reductionParent, reductionOp);
1793 } else {
1794 llvm_unreachable("Unhandled reduction container");
1796 assert(declaration && "could not find reduction declaration");
1798 // Retrieve the mapping between reduction variables and their private
1799 // equivalents.
1800 const DenseMap<Value, llvm::Value *> *reductionVariableMap = nullptr;
1801 moduleTranslation.stackWalk<OpenMPVarMappingStackFrame>(
1802 [&](const OpenMPVarMappingStackFrame &frame) {
1803 if (frame.mapping.contains(reductionOp.getAccumulator())) {
1804 reductionVariableMap = &frame.mapping;
1805 return WalkResult::interrupt();
1807 return WalkResult::advance();
1809 assert(reductionVariableMap && "couldn't find private reduction variables");
1810 // Translate the reduction operation by emitting the body of the corresponding
1811 // reduction declaration.
1812 Region &reductionRegion = declaration.getReductionRegion();
1813 llvm::Value *privateReductionVar =
1814 reductionVariableMap->lookup(reductionOp.getAccumulator());
1815 llvm::Value *reductionVal = builder.CreateLoad(
1816 moduleTranslation.convertType(reductionOp.getOperand().getType()),
1817 privateReductionVar);
1819 moduleTranslation.mapValue(reductionRegion.front().getArgument(0),
1820 reductionVal);
1821 moduleTranslation.mapValue(
1822 reductionRegion.front().getArgument(1),
1823 moduleTranslation.lookupValue(reductionOp.getOperand()));
1825 SmallVector<llvm::Value *> phis;
1826 if (failed(inlineConvertOmpRegions(reductionRegion, "omp.reduction.body",
1827 builder, moduleTranslation, &phis)))
1828 return failure();
1829 assert(phis.size() == 1 && "expected one value to be yielded from "
1830 "the reduction body declaration region");
1831 builder.CreateStore(phis[0], privateReductionVar);
1832 return success();
1835 /// Converts an OpenMP Threadprivate operation into LLVM IR using
1836 /// OpenMPIRBuilder.
1837 static LogicalResult
1838 convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
1839 LLVM::ModuleTranslation &moduleTranslation) {
1840 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1841 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
1843 Value symAddr = threadprivateOp.getSymAddr();
1844 auto *symOp = symAddr.getDefiningOp();
1845 if (!isa<LLVM::AddressOfOp>(symOp))
1846 return opInst.emitError("Addressing symbol not found");
1847 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
1849 LLVM::GlobalOp global =
1850 addressOfOp.getGlobal(moduleTranslation.symbolTable());
1851 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
1852 llvm::Type *type = globalValue->getValueType();
1853 llvm::TypeSize typeSize =
1854 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
1855 type);
1856 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
1857 llvm::StringRef suffix = llvm::StringRef(".cache", 6);
1858 std::string cacheName = (Twine(global.getSymName()).concat(suffix)).str();
1859 llvm::Value *callInst =
1860 moduleTranslation.getOpenMPBuilder()->createCachedThreadPrivate(
1861 ompLoc, globalValue, size, cacheName);
1862 moduleTranslation.mapValue(opInst.getResult(0), callInst);
1863 return success();
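/// Maps the MLIR `declare target` device_type clause to the corresponding
/// OffloadEntriesInfoManager device clause kind.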
1866 static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
1867 convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
1868 switch (deviceClause) {
1869 case mlir::omp::DeclareTargetDeviceType::host:
1870 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
1871 break;
1872 case mlir::omp::DeclareTargetDeviceType::nohost:
1873 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
1874 break;
1875 case mlir::omp::DeclareTargetDeviceType::any:
1876 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
1877 break;
1879 llvm_unreachable("unhandled device clause");
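/// Maps the MLIR `declare target` capture clause (to/link/enter) to the
/// corresponding OffloadEntriesInfoManager global variable entry kind.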
1882 static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
1883 convertToCaptureClauseKind(
1884     mlir::omp::DeclareTargetCaptureClause captureClause) {
1885   switch (captureClause) {
1886 case mlir::omp::DeclareTargetCaptureClause::to:
1887 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
1888 case mlir::omp::DeclareTargetCaptureClause::link:
1889 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
1890 case mlir::omp::DeclareTargetCaptureClause::enter:
1891 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
1893 llvm_unreachable("unhandled capture clause");
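// Builds the name suffix used for the reference pointer of a `declare target`
// global. For private globals the suffix additionally encodes the unique file
// ID of the defining location, e.g. "_<fileid>_decl_tgt_ref_ptr" instead of
// just "_decl_tgt_ref_ptr".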
1896 static llvm::SmallString<64>
1897 getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
1898 llvm::OpenMPIRBuilder &ompBuilder) {
1899 llvm::SmallString<64> suffix;
1900 llvm::raw_svector_ostream os(suffix);
1901 if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
1902 auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
1903 auto fileInfoCallBack = [&loc]() {
1904 return std::pair<std::string, uint64_t>(
1905 llvm::StringRef(loc.getFilename()), loc.getLine());
1908 os << llvm::format(
1909 "_%x", ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack).FileID);
1911 os << "_decl_tgt_ref_ptr";
1913 return suffix;
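// Returns true if the given value is the address of a global that has been
// marked `declare target` with the `link` capture clause.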
1916 static bool isDeclareTargetLink(mlir::Value value) {
1917 if (auto addressOfOp =
1918 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
1919 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
1920 Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
1921 if (auto declareTargetGlobal =
1922 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp))
1923 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
1924 mlir::omp::DeclareTargetCaptureClause::link)
1925 return true;
1927 return false;
1930 // Returns the reference pointer generated by the lowering of the declare target
1931 // operation in cases where the link clause is used or the to clause is used in
1932 // USM mode.
1933 static llvm::Value *
1934 getRefPtrIfDeclareTarget(mlir::Value value,
1935 LLVM::ModuleTranslation &moduleTranslation) {
1936 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1938   // An easier way to do this may be to keep track of any pointer
1939   // references and their mapping to their respective operations.
1940 if (auto addressOfOp =
1941 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
1942 if (auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
1943 addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
1944 addressOfOp.getGlobalName()))) {
1946 if (auto declareTargetGlobal =
1947 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
1948 gOp.getOperation())) {
1950 // In this case, we must utilise the reference pointer generated by the
1951 // declare target operation, similar to Clang
1952 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
1953 mlir::omp::DeclareTargetCaptureClause::link) ||
1954 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
1955 mlir::omp::DeclareTargetCaptureClause::to &&
1956 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
1957 llvm::SmallString<64> suffix =
1958 getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);
1960 if (gOp.getSymName().contains(suffix))
1961 return moduleTranslation.getLLVMModule()->getNamedValue(
1962 gOp.getSymName());
1964 return moduleTranslation.getLLVMModule()->getNamedValue(
1965 (gOp.getSymName().str() + suffix.str()).str());
1971 return nullptr;
1974 // A small helper structure to contain data gathered
1975 // for map lowering and coalesce it into one area, avoiding
1976 // extra computations such as searches in the
1977 // LLVM module for lowered mapped variables or checking
1978 // if something is declare target (and retrieving the
1979 // value) more than necessary.
1980 struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
1981 llvm::SmallVector<bool, 4> IsDeclareTarget;
1982 llvm::SmallVector<bool, 4> IsAMember;
1983 llvm::SmallVector<mlir::Operation *, 4> MapClause;
1984 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
1985   // The underlying element type, with any array/pointer
1986   // wrappers stripped off.
1987 llvm::SmallVector<llvm::Type *, 4> BaseType;
1989 /// Append arrays in \a CurInfo.
1990 void append(MapInfoData &CurInfo) {
1991 IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
1992 CurInfo.IsDeclareTarget.end());
1993 MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
1994 OriginalValue.append(CurInfo.OriginalValue.begin(),
1995 CurInfo.OriginalValue.end());
1996 BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
1997 llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo);
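// Returns the size in bits of the innermost element type of a (possibly
// nested) LLVM array type, e.g. for an array of arrays of i32 this returns
// the size of i32.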
2001 uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl) {
2002 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
2003 arrTy.getElementType()))
2004 return getArrayElementSizeInBits(nestedArrTy, dl);
2005 return dl.getTypeSizeInBits(arrTy.getElementType());
2008 // This function calculates the size to be offloaded for a specified type,
2009 // given its associated map clause (which can contain bounds information that
2010 // affects the total size). The size is calculated based on the underlying
2011 // element type, e.g. given a 1-D array of ints, we calculate the size from
2012 // the integer type * number of elements in the array. This size can be used
2013 // in other calculations but is ultimately used as an argument to the OpenMP
2014 // runtime's kernel argument structure, which is generated through the
2015 // combinedInfo data structures.
2016 // This function is somewhat equivalent to Clang's getExprTypeSize inside of
2017 // CGOpenMPRuntime.cpp.
2018 llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
2019 Operation *clauseOp, llvm::Value *basePointer,
2020 llvm::Type *baseType, llvm::IRBuilderBase &builder,
2021 LLVM::ModuleTranslation &moduleTranslation) {
2022   // Utilise getTypeSizeInBits instead of getTypeSize, as getTypeSize gives
2023   // the size in an inconsistent byte or bit format.
2024 uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
2025 if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
2026 underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);
2028 if (auto memberClause =
2029 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
2030 // This calculates the size to transfer based on bounds and the underlying
2031 // element type, provided bounds have been specified (Fortran
2032 // pointers/allocatables/target and arrays that have sections specified fall
2033 // into this as well).
2034 if (!memberClause.getBounds().empty()) {
2035 llvm::Value *elementCount = builder.getInt64(1);
2036 for (auto bounds : memberClause.getBounds()) {
2037 if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
2038 bounds.getDefiningOp())) {
2039           // The size to be mapped is calculated from the map_info's bounds as
2040           // elemCount *= (UB - LB) + 1 for each bound; later we multiply by the
2041           // underlying element type's byte size to get the full size to be
2042           // offloaded based on the bounds.
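          // For example, a section covering elements 2:5 of an i32 array gives
          // elementCount = (5 - 2) + 1 = 4 and a mapped size of 4 * 4 = 16 bytes.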
2043 elementCount = builder.CreateMul(
2044 elementCount,
2045 builder.CreateAdd(
2046 builder.CreateSub(
2047 moduleTranslation.lookupValue(boundOp.getUpperBound()),
2048 moduleTranslation.lookupValue(boundOp.getLowerBound())),
2049 builder.getInt64(1)));
2053       // The size in bytes is the number of elements multiplied by the
2054       // underlying type's size, e.g. for ptr<i32> it is the i32's size,
2055       // so we do some on-the-fly runtime math to get the size in bytes
2056       // from the extent: (ub - lb) * sizeInBytes. NOTE: This may need
2057       // some adjustment for members with more complex types.
2058 return builder.CreateMul(elementCount,
2059 builder.getInt64(underlyingTypeSzInBits / 8));
2063 return builder.getInt64(underlyingTypeSzInBits / 8);
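// Walks the map operands of an operation and, for each MapInfoOp, gathers the
// information needed for lowering (original value, base pointer, element type,
// size, map type flags, mapping name, and whether the value is declare target
// or a member of another mapped object) into the MapInfoData structure.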
2066 void collectMapDataFromMapOperands(MapInfoData &mapData,
2067 llvm::SmallVectorImpl<Value> &mapOperands,
2068 LLVM::ModuleTranslation &moduleTranslation,
2069 DataLayout &dl,
2070 llvm::IRBuilderBase &builder) {
2071 for (mlir::Value mapValue : mapOperands) {
2072 if (auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2073 mapValue.getDefiningOp())) {
2074 mlir::Value offloadPtr =
2075 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
2076 mapData.OriginalValue.push_back(
2077 moduleTranslation.lookupValue(offloadPtr));
2078 mapData.Pointers.push_back(mapData.OriginalValue.back());
2080 if (llvm::Value *refPtr =
2081 getRefPtrIfDeclareTarget(offloadPtr,
2082 moduleTranslation)) { // declare target
2083 mapData.IsDeclareTarget.push_back(true);
2084 mapData.BasePointers.push_back(refPtr);
2085 } else { // regular mapped variable
2086 mapData.IsDeclareTarget.push_back(false);
2087 mapData.BasePointers.push_back(mapData.OriginalValue.back());
2090 mapData.BaseType.push_back(
2091 moduleTranslation.convertType(mapOp.getVarType()));
2092 mapData.Sizes.push_back(getSizeInBytes(
2093 dl, mapOp.getVarType(), mapOp, mapData.BasePointers.back(),
2094 mapData.BaseType.back(), builder, moduleTranslation));
2095 mapData.MapClause.push_back(mapOp.getOperation());
2096 mapData.Types.push_back(
2097 llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
2098 mapData.Names.push_back(LLVM::createMappingInformation(
2099 mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
2100 mapData.DevicePointers.push_back(
2101 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2103       // Check if this is a member mapping and, if it is a member of a larger
2104       // object, mark it as such.
2105 // TODO: Need better handling of members, and distinguishing of members
2106 // that are implicitly allocated on device vs explicitly passed in as
2107 // arguments.
2108 // TODO: May require some further additions to support nested record
2109 // types, i.e. member maps that can have member maps.
2110 mapData.IsAMember.push_back(false);
2111 for (mlir::Value mapValue : mapOperands) {
2112 if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2113 mapValue.getDefiningOp())) {
2114 for (auto member : map.getMembers()) {
2115 if (member == mapOp) {
2116 mapData.IsAMember.back() = true;
2125 /// This function calculates the array/pointer offset for map data provided
2126 /// with bounds operations, e.g. when provided something like the following:
2128 /// Fortran
2129 /// map(tofrom: array(2:5, 3:2))
2130 /// or
2131 /// C++
2132 /// map(tofrom: array[1:4][2:3])
2133 /// We must calculate the initial pointer offset to pass across, this function
2134 /// performs this using bounds.
2136 /// NOTE: While the bounds are specified in row-major order, they currently
2137 /// need to be flipped for Fortran's column-major array allocation and access
2138 /// (as opposed to C++'s row-major order, hence the backwards processing where
2139 /// order is important). This is likely important to keep in mind for the
2140 /// future when we incorporate a C++ frontend; both frontends will need to
2141 /// agree on the ordering of generated bounds operations (one may have to flip
2142 /// them) to make the below lowering frontend agnostic. The offload size
2143 /// calculation may also have to be adjusted for C++.
2144 std::vector<llvm::Value *>
2145 calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
2146 llvm::IRBuilderBase &builder, bool isArrayTy,
2147 mlir::OperandRange bounds) {
2148 std::vector<llvm::Value *> idx;
2149   // There are no bounds to calculate an offset from; we can safely
2150   // ignore this and return no indices.
2151 if (bounds.empty())
2152 return idx;
2154   // If we have an array type, then we have its type so we can treat it as a
2155   // normal GEP instruction where the bounds operations are simply indexes
2156   // into the array. We currently process the bounds in reverse order, which
2157   // leans more towards Fortran's column-major layout in memory.
2158 if (isArrayTy) {
2159 idx.push_back(builder.getInt64(0));
2160 for (int i = bounds.size() - 1; i >= 0; --i) {
2161 if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
2162 bounds[i].getDefiningOp())) {
2163 idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
2166 } else {
2167     // If we do not have an array type, but we have bounds, then we are dealing
2168     // with a pointer that is being treated like an array and we have the
2169     // underlying type, e.g. an i32 or f64, e.g. a Fortran descriptor base
2170     // address (a pointer pointing to the actual data), so we must calculate the
2171     // offset using a single index, which the following two loops attempt to
2172     // compute.
2174     // Calculate the size offset we need to make per row, e.g. the first row or
2175     // column only needs to be offset by one, but the next has to be
2176     // the previous row/column offset multiplied by the extent of the current row.
2178     // For example ([1][10][100]):
2180     // - First row/column: we move by 1 for each index increment
2181     // - Second row/column: we move by 1 (first row/column) * 10 (extent/size of
2182     //   current), i.e. 10, for each index increment
2183     // - Third row/column: we move by 10 (second row/column) * 100
2184     //   (extent/size of current), i.e. 1000, for each index increment
2185 std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
2186 for (size_t i = 1; i < bounds.size(); ++i) {
2187 if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
2188 bounds[i].getDefiningOp())) {
2189 dimensionIndexSizeOffset.push_back(builder.CreateMul(
2190 moduleTranslation.lookupValue(boundOp.getExtent()),
2191 dimensionIndexSizeOffset[i - 1]));
2195     // Now that we have calculated how much we move by per index, we must
2196     // multiply each lower bound offset in indexes by the size offset we
2197     // have calculated in the previous loop and accumulate the results to get
2198     // our final resulting offset.
2199 for (int i = bounds.size() - 1; i >= 0; --i) {
2200 if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
2201 bounds[i].getDefiningOp())) {
2202 if (idx.empty())
2203 idx.emplace_back(builder.CreateMul(
2204 moduleTranslation.lookupValue(boundOp.getLowerBound()),
2205 dimensionIndexSizeOffset[i]));
2206 else
2207 idx.back() = builder.CreateAdd(
2208 idx.back(), builder.CreateMul(moduleTranslation.lookupValue(
2209 boundOp.getLowerBound()),
2210 dimensionIndexSizeOffset[i]));
2215 return idx;
2218 // This creates two insertions into the MapInfosTy data structure for the
2219 // "parent" of a set of members (usually a container, e.g. a
2220 // class/structure/derived type) when subsequent members have also been
2221 // explicitly mapped on the same map clause. Certain types, such as Fortran
2222 // descriptors, are mapped like this as well; however, the members are
2223 // implicit as far as a user is concerned, but we must explicitly map them
2224 // internally.
2226 // This function also returns the memberOfFlag for this particular parent,
2227 // which is utilised in subsequent member mappings (by modifying their map
2228 // type with it) to indicate that a member is part of this parent and should
2229 // be treated by the runtime as such; this is important for correct mapping.
2230 static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
2231 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
2232 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
2233 llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
2234 uint64_t mapDataIndex, bool isTargetParams) {
2235 // Map the first segment of our structure
2236 combinedInfo.Types.emplace_back(
2237 isTargetParams
2238 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
2239 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
2240 combinedInfo.DevicePointers.emplace_back(
2241 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2242 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
2243 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
2244 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
2245 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
2247 // Calculate size of the parent object being mapped based on the
2248 // addresses at runtime, highAddr - lowAddr = size. This of course
2249 // doesn't factor in allocated data like pointers, hence the further
2250 // processing of members specified by users, or in the case of
2251 // Fortran pointers and allocatables, the mapping of the pointed to
2252   // data by the descriptor (which itself is a structure containing
2253 // runtime information on the dynamically allocated data).
2254 llvm::Value *lowAddr = builder.CreatePointerCast(
2255 mapData.Pointers[mapDataIndex], builder.getPtrTy());
2256 llvm::Value *highAddr = builder.CreatePointerCast(
2257 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
2258 mapData.Pointers[mapDataIndex], 1),
2259 builder.getPtrTy());
2260 llvm::Value *size = builder.CreateIntCast(
2261 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
2262 builder.getInt64Ty(),
2263 /*isSigned=*/false);
2264 combinedInfo.Sizes.push_back(size);
2266   // This creates the initial MEMBER_OF mapping that consists of
2267   // the parent/top level container (same as above effectively, except
2268   // with a fixed initial compile time size and a separate map type which
2269   // indicates the true map type (tofrom etc.), that it is part
2270   // of a larger mapping, and the link between it and its
2271   // members that are also explicitly mapped).
2272 llvm::omp::OpenMPOffloadMappingFlags mapFlag =
2273 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
2274 if (isTargetParams)
2275 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2277 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
2278 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
2279 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
2281 combinedInfo.Types.emplace_back(mapFlag);
2282 combinedInfo.DevicePointers.emplace_back(
2283 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2284 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
2285 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
2286 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
2287 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
2288 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
2290 return memberOfFlag;
2293 // The intent is to verify if the mapped data being passed is a
2294 // pointer -> pointee that requires special handling in certain cases,
2295 // e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
2297 // There may be a better way to verify this, but unfortunately with
2298 // opaque pointers we lose the ability to easily check if something is
2299 // a pointer whilst maintaining access to the underlying type.
2300 static bool checkIfPointerMap(mlir::omp::MapInfoOp mapOp) {
2301 // If we have a varPtrPtr field assigned then the underlying type is a pointer
2302 if (mapOp.getVarPtrPtr())
2303 return true;
2305 // If the map data is declare target with a link clause, then it's represented
2306 // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
2307 // no relation to pointers.
2308 if (isDeclareTargetLink(mapOp.getVarPtr()))
2309 return true;
2311 return false;
2314 // This function adds explicit mappings for the members of a parent map clause.
2315 static void processMapMembersWithParent(
2316 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
2317 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
2318 llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
2319 uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
2321 auto parentClause =
2322 mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
2324 for (auto mappedMembers : parentClause.getMembers()) {
2325 auto memberClause =
2326 mlir::dyn_cast<mlir::omp::MapInfoOp>(mappedMembers.getDefiningOp());
2327 int memberDataIdx = -1;
2328 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
2329 if (mapData.MapClause[i] == memberClause)
2330 memberDataIdx = i;
2333 assert(memberDataIdx >= 0 && "could not find mapped member of structure");
2335     // Use the same MemberOfFlag to indicate its link with the parent and the
2336     // other members, and flag that it is part of a pointer-and-object coupling.
2337 auto mapFlag =
2338 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType().value());
2339 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2340 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
2341 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
2342 if (checkIfPointerMap(memberClause))
2343 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
2345 combinedInfo.Types.emplace_back(mapFlag);
2346 combinedInfo.DevicePointers.emplace_back(
2347 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2348 combinedInfo.Names.emplace_back(
2349 LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
2351 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[memberDataIdx]);
2352 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
2353 combinedInfo.Sizes.emplace_back(mapData.Sizes[memberDataIdx]);
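// Emits the combined map entries for a parent map clause that has explicitly
// mapped members: first the parent entries (which also yield the MEMBER_OF
// flag), then one entry per member tagged with that flag.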
2357 static void processMapWithMembersOf(
2358 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
2359 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
2360 llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
2361 uint64_t mapDataIndex, bool isTargetParams) {
2362 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
2363 mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
2364 combinedInfo, mapData, mapDataIndex, isTargetParams);
2365 processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
2366 combinedInfo, mapData, mapDataIndex,
2367 memberOfParentFlag);
2370 // This is a variation on Clang's GenerateOpenMPCapturedVars, which
2371 // generates different operation (e.g. load/store) combinations for
2372 // arguments to the kernel, based on map capture kinds which are then
2373 // utilised in the combinedInfo in place of the original Map value.
2374 static void
2375 createAlteredByCaptureMap(MapInfoData &mapData,
2376 LLVM::ModuleTranslation &moduleTranslation,
2377 llvm::IRBuilderBase &builder) {
2378 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
2379     // If it is declare target, skip it; it is handled separately.
2380 if (!mapData.IsDeclareTarget[i]) {
2381 auto mapOp =
2382 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(mapData.MapClause[i]);
2383 mlir::omp::VariableCaptureKind captureKind =
2384 mapOp.getMapCaptureType().value_or(
2385 mlir::omp::VariableCaptureKind::ByRef);
2386 bool isPtrTy = checkIfPointerMap(mapOp);
2388       // Currently handles the array sectioning lowerbound case, but more
2389       // logic may be required in the future. Clang invokes EmitLValue,
2390       // which has specialised logic for special Clang types such as
2391       // user-defined types, so it is possible we will have to extend this for
2392       // structures or other complex types. The general idea is that this
2393       // function mimics some of the logic from Clang that we require for
2394       // kernel argument passing from host -> device.
2395 switch (captureKind) {
2396 case mlir::omp::VariableCaptureKind::ByRef: {
2397 llvm::Value *newV = mapData.Pointers[i];
2398 std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
2399 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
2400 mapOp.getBounds());
2401 if (isPtrTy)
2402 newV = builder.CreateLoad(builder.getPtrTy(), newV);
2404 if (!offsetIdx.empty())
2405 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
2406 "array_offset");
2407 mapData.Pointers[i] = newV;
2408 } break;
2409 case mlir::omp::VariableCaptureKind::ByCopy: {
2410 llvm::Type *type = mapData.BaseType[i];
2411 llvm::Value *newV;
2412 if (mapData.Pointers[i]->getType()->isPointerTy())
2413 newV = builder.CreateLoad(type, mapData.Pointers[i]);
2414 else
2415 newV = mapData.Pointers[i];
2417 if (!isPtrTy) {
2418 auto curInsert = builder.saveIP();
2419 builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
2420 auto *memTempAlloc =
2421 builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
2422 builder.restoreIP(curInsert);
2424 builder.CreateStore(newV, memTempAlloc);
2425 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
2428 mapData.Pointers[i] = newV;
2429 mapData.BasePointers[i] = newV;
2430 } break;
2431 case mlir::omp::VariableCaptureKind::This:
2432 case mlir::omp::VariableCaptureKind::VLAType:
2433 mapData.MapClause[i]->emitOpError("Unhandled capture kind");
2434 break;
2440 // Generate all map related information and fill the combinedInfo.
2441 static void genMapInfos(llvm::IRBuilderBase &builder,
2442 LLVM::ModuleTranslation &moduleTranslation,
2443 DataLayout &dl,
2444 llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
2445 MapInfoData &mapData,
2446 const SmallVector<Value> &devPtrOperands = {},
2447 const SmallVector<Value> &devAddrOperands = {},
2448 bool isTargetParams = false) {
2449   // We wish to modify some of the methods in which arguments are
2450   // passed based on their capture type by the target region. This can
2451   // involve generating new loads and stores, which changes the
2452   // MLIR value to LLVM value mapping; however, we only wish to do this
2453   // locally for the current function/target and also avoid altering
2454   // ModuleTranslation, so we remap the base pointer or pointer stored
2455   // in the map info's corresponding MapInfoData, which is later accessed
2456   // by genMapInfos and createTarget to help generate the kernel and
2457   // kernel arg structure. It primarily becomes relevant in cases like
2458   // bycopy, or byref range'd arrays. In the default case, we simply
2459   // pass the pointer byref as both basePointer and pointer.
2460 if (!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
2461 createAlteredByCaptureMap(mapData, moduleTranslation, builder);
2463 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2465 auto fail = [&combinedInfo]() -> void {
2466 combinedInfo.BasePointers.clear();
2467 combinedInfo.Pointers.clear();
2468 combinedInfo.DevicePointers.clear();
2469 combinedInfo.Sizes.clear();
2470 combinedInfo.Types.clear();
2471 combinedInfo.Names.clear();
2474   // We operate under the assumption that all vectors that are
2475   // required in MapInfoData are of equal lengths (either filled with
2476   // default constructed data or appropriate information) so we can
2477   // utilise the size from any component of MapInfoData; if we can't,
2478   // something is missing from the initial MapInfoData construction.
2479 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
2480     // NOTE/TODO: We currently do not handle member mapping separately from its
2481     // parent or explicit mapping of a parent and member in the same operation;
2482     // this will need to change in the near future. For now we primarily handle
2483     // descriptor mapping from Fortran, generalised as mapping record types
2484     // with implicit member maps. This lowering needs further generalisation to
2485     // fully support Fortran derived types, and C/C++ structures and classes.
2486 if (mapData.IsAMember[i])
2487 continue;
2489 auto mapInfoOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[i]);
2490 if (!mapInfoOp.getMembers().empty()) {
2491 processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
2492 combinedInfo, mapData, i, isTargetParams);
2493 continue;
2496 auto mapFlag = mapData.Types[i];
2497 bool isPtrTy = checkIfPointerMap(mapInfoOp);
2498 if (isPtrTy)
2499 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
2501 // Declare Target Mappings are excluded from being marked as
2502 // OMP_MAP_TARGET_PARAM as they are not passed as parameters.
2503 if (isTargetParams && !mapData.IsDeclareTarget[i])
2504 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2506 if (auto mapInfoOp = dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[i]))
2507 if (mapInfoOp.getMapCaptureType().value() ==
2508 mlir::omp::VariableCaptureKind::ByCopy &&
2509 !isPtrTy)
2510 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
2512 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[i]);
2513 combinedInfo.Pointers.emplace_back(mapData.Pointers[i]);
2514 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[i]);
2515 combinedInfo.Names.emplace_back(mapData.Names[i]);
2516 combinedInfo.Types.emplace_back(mapFlag);
2517 combinedInfo.Sizes.emplace_back(mapData.Sizes[i]);
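  // Returns true and sets `index` to the position of `val` among the base
  // pointers already collected in combinedInfo; returns false otherwise.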
2520 auto findMapInfo = [&combinedInfo](llvm::Value *val, unsigned &index) {
2521 index = 0;
2522 for (llvm::Value *basePtr : combinedInfo.BasePointers) {
2523 if (basePtr == val)
2524 return true;
2525 index++;
2527 return false;
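  // Appends or updates combined-info entries for use_device_ptr/use_device_addr
  // operands, marking them with the OMP_MAP_RETURN_PARAM flag.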
2530 auto addDevInfos = [&, fail](auto devOperands, auto devOpType) -> void {
2531 for (const auto &devOp : devOperands) {
2532 // TODO: Only LLVMPointerTypes are handled.
2533 if (!isa<LLVM::LLVMPointerType>(devOp.getType()))
2534 return fail();
2536 llvm::Value *mapOpValue = moduleTranslation.lookupValue(devOp);
2538 // Check if map info is already present for this entry.
2539 unsigned infoIndex;
2540 if (findMapInfo(mapOpValue, infoIndex)) {
2541 combinedInfo.Types[infoIndex] |=
2542 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
2543 combinedInfo.DevicePointers[infoIndex] = devOpType;
2544 } else {
2545 combinedInfo.BasePointers.emplace_back(mapOpValue);
2546 combinedInfo.Pointers.emplace_back(mapOpValue);
2547 combinedInfo.DevicePointers.emplace_back(devOpType);
2548 combinedInfo.Names.emplace_back(
2549 LLVM::createMappingInformation(devOp.getLoc(), *ompBuilder));
2550 combinedInfo.Types.emplace_back(
2551 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
2552 combinedInfo.Sizes.emplace_back(builder.getInt64(0));
2557 addDevInfos(devPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
2558 addDevInfos(devAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
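/// Converts omp::TargetDataOp, omp::TargetEnterDataOp, omp::TargetExitDataOp
/// and omp::TargetUpdateOp operations to LLVM IR through the OpenMPIRBuilder's
/// createTargetData, selecting the matching data-mapper runtime function for
/// the standalone data-movement operations.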
2561 static LogicalResult
2562 convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
2563 LLVM::ModuleTranslation &moduleTranslation) {
2564 llvm::Value *ifCond = nullptr;
2565 int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
2566 SmallVector<Value> mapOperands;
2567 SmallVector<Value> useDevPtrOperands;
2568 SmallVector<Value> useDevAddrOperands;
2569 llvm::omp::RuntimeFunction RTLFn;
2570 DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
2572 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2574 LogicalResult result =
2575 llvm::TypeSwitch<Operation *, LogicalResult>(op)
2576 .Case([&](omp::TargetDataOp dataOp) {
2577 if (auto ifExprVar = dataOp.getIfExpr())
2578 ifCond = moduleTranslation.lookupValue(ifExprVar);
2580 if (auto devId = dataOp.getDevice())
2581 if (auto constOp =
2582 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
2583 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
2584 deviceID = intAttr.getInt();
2586 mapOperands = dataOp.getMapOperands();
2587 useDevPtrOperands = dataOp.getUseDevicePtr();
2588 useDevAddrOperands = dataOp.getUseDeviceAddr();
2589 return success();
2591 .Case([&](omp::TargetEnterDataOp enterDataOp) {
2592 if (enterDataOp.getNowait())
2593 return (LogicalResult)(enterDataOp.emitError(
2594 "`nowait` is not supported yet"));
2596 if (auto ifExprVar = enterDataOp.getIfExpr())
2597 ifCond = moduleTranslation.lookupValue(ifExprVar);
2599 if (auto devId = enterDataOp.getDevice())
2600 if (auto constOp =
2601 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
2602 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
2603 deviceID = intAttr.getInt();
2604 RTLFn = llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
2605 mapOperands = enterDataOp.getMapOperands();
2606 return success();
2608 .Case([&](omp::TargetExitDataOp exitDataOp) {
2609 if (exitDataOp.getNowait())
2610 return (LogicalResult)(exitDataOp.emitError(
2611 "`nowait` is not supported yet"));
2613 if (auto ifExprVar = exitDataOp.getIfExpr())
2614 ifCond = moduleTranslation.lookupValue(ifExprVar);
2616 if (auto devId = exitDataOp.getDevice())
2617 if (auto constOp =
2618 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
2619 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
2620 deviceID = intAttr.getInt();
2622 RTLFn = llvm::omp::OMPRTL___tgt_target_data_end_mapper;
2623 mapOperands = exitDataOp.getMapOperands();
2624 return success();
2626 .Case([&](omp::TargetUpdateOp updateDataOp) {
2627 if (updateDataOp.getNowait())
2628 return (LogicalResult)(updateDataOp.emitError(
2629 "`nowait` is not supported yet"));
2631 if (auto ifExprVar = updateDataOp.getIfExpr())
2632 ifCond = moduleTranslation.lookupValue(ifExprVar);
2634 if (auto devId = updateDataOp.getDevice())
2635 if (auto constOp =
2636 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
2637 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
2638 deviceID = intAttr.getInt();
2640 RTLFn = llvm::omp::OMPRTL___tgt_target_data_update_mapper;
2641 mapOperands = updateDataOp.getMapOperands();
2642 return success();
2644 .Default([&](Operation *op) {
2645 return op->emitError("unsupported OpenMP operation: ")
2646 << op->getName();
2649 if (failed(result))
2650 return failure();
2652 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2654 MapInfoData mapData;
2655 collectMapDataFromMapOperands(mapData, mapOperands, moduleTranslation, DL,
2656 builder);
2658 // Fill up the arrays with all the mapped variables.
2659 llvm::OpenMPIRBuilder::MapInfosTy combinedInfo;
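// The OpenMPIRBuilder calls this back whenever it needs the map information.
// Only `omp.target_data` carries `use_device_ptr`/`use_device_addr` operands,
// so those are only forwarded to genMapInfos for that case.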
2660 auto genMapInfoCB =
2661 [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
2662 builder.restoreIP(codeGenIP);
2663 if (auto dataOp = dyn_cast<omp::TargetDataOp>(op)) {
2664 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
2665 useDevPtrOperands, useDevAddrOperands);
2666 } else {
2667 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
2669 return combinedInfo;
2672 llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true,
2673 /*SeparateBeginEndCalls=*/true);
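// The flags above ask the builder to record the device pointers it creates
// (consumed through info.DevicePtrInfoMap in the body callback below) and to
// emit separate begin/end mapper calls rather than a single combined call.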
2675 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
2676 LogicalResult bodyGenStatus = success();
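// The body callback is invoked in several phases (BodyGenTy). The region is
// inlined during the Priv phase when device pointer info is available, so the
// use_device_ptr/use_device_addr values can be bound to the block arguments
// first; otherwise it is inlined during the NoPriv phase.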
2677 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType) {
2678 assert(isa<omp::TargetDataOp>(op) &&
2679 "BodyGen requested for non TargetDataOp");
2680 Region &region = cast<omp::TargetDataOp>(op).getRegion();
2681 switch (bodyGenType) {
2682 case BodyGenTy::Priv:
2683 // Check if any device ptr/addr info is available
2684 if (!info.DevicePtrInfoMap.empty()) {
2685 builder.restoreIP(codeGenIP);
2686 unsigned argIndex = 0;
2687 for (auto &devPtrOp : useDevPtrOperands) {
2688 llvm::Value *mapOpValue = moduleTranslation.lookupValue(devPtrOp);
2689 const auto &arg = region.front().getArgument(argIndex);
2690 moduleTranslation.mapValue(arg,
2691 info.DevicePtrInfoMap[mapOpValue].second);
2692 argIndex++;
2695 for (auto &devAddrOp : useDevAddrOperands) {
2696 llvm::Value *mapOpValue = moduleTranslation.lookupValue(devAddrOp);
2697 const auto &arg = region.front().getArgument(argIndex);
2698 auto *LI = builder.CreateLoad(
2699 builder.getPtrTy(), info.DevicePtrInfoMap[mapOpValue].second);
2700 moduleTranslation.mapValue(arg, LI);
2701 argIndex++;
2704 bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region",
2705 builder, moduleTranslation);
2707 break;
2708 case BodyGenTy::DupNoPriv:
2709 break;
2710 case BodyGenTy::NoPriv:
2711 // If device info is available, then the region has already been generated.
2712 if (info.DevicePtrInfoMap.empty()) {
2713 builder.restoreIP(codeGenIP);
2714 bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region",
2715 builder, moduleTranslation);
2717 break;
2719 return builder.saveIP();
2722 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2723 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2724 findAllocaInsertPoint(builder, moduleTranslation);
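// `omp.target_data` has a region and therefore passes the body callback; the
// stand-alone enter/exit/update forms instead pass the runtime function
// selected above.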
2725 if (isa<omp::TargetDataOp>(op)) {
2726 builder.restoreIP(ompBuilder->createTargetData(
2727 ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
2728 info, genMapInfoCB, nullptr, bodyGenCB));
2729 } else {
2730 builder.restoreIP(ompBuilder->createTargetData(
2731 ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
2732 info, genMapInfoCB, &RTLFn));
2735 return bodyGenStatus;
2738 /// Lowers the FlagsAttr, which is applied to the module on the device
2739 /// pass when offloading. This attribute carries OpenMP RTL flags that the
2740 /// frontend may set explicitly; otherwise they keep their default values.
2741 LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
2742 LLVM::ModuleTranslation &moduleTranslation) {
2743 if (!isa<mlir::ModuleOp>(op))
2744 return failure();
2746 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2748 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
2749 attribute.getOpenmpDeviceVersion());
2751 if (attribute.getNoGpuLib())
2752 return success();
2754 ompBuilder->createGlobalFlag(
2755 attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/,
2756 "__omp_rtl_debug_kind");
2757 ompBuilder->createGlobalFlag(
2758 attribute
2759 .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/
2761 "__omp_rtl_assume_teams_oversubscription");
2762 ompBuilder->createGlobalFlag(
2763 attribute
2764 .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/
2766 "__omp_rtl_assume_threads_oversubscription");
2767 ompBuilder->createGlobalFlag(
2768 attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/,
2769 "__omp_rtl_assume_no_thread_state");
2770 ompBuilder->createGlobalFlag(
2771 attribute
2772 .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/
2774 "__omp_rtl_assume_no_nested_parallelism");
2775 return success();
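/// Builds a TargetRegionEntryInfo from the target op's source location,
/// combining the file's unique ID and line number with the parent function's
/// name so that each offloaded region gets a unique entry.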
2778 static bool getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
2779 omp::TargetOp targetOp,
2780 llvm::StringRef parentName = "") {
2781 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
2783 assert(fileLoc && "No file found from location");
2784 StringRef fileName = fileLoc.getFilename().getValue();
2786 llvm::sys::fs::UniqueID id;
2787 if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
2788 targetOp.emitError("Unable to get unique ID for file");
2789 return false;
2792 uint64_t line = fileLoc.getLine();
2793 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
2794 id.getFile(), line);
2795 return true;
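/// Returns true if the given `omp.target` op only uses clauses currently
/// supported by this translation; otherwise emits an error and returns false.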
2798 static bool targetOpSupported(Operation &opInst) {
2799 auto targetOp = cast<omp::TargetOp>(opInst);
2800 if (targetOp.getIfExpr()) {
2801 opInst.emitError("If clause not yet supported");
2802 return false;
2805 if (targetOp.getDevice()) {
2806 opInst.emitError("Device clause not yet supported");
2807 return false;
2810 if (targetOp.getThreadLimit()) {
2811 opInst.emitError("Thread limit clause not yet supported");
2812 return false;
2815 if (targetOp.getNowait()) {
2816 opInst.emitError("Nowait clause not yet supported");
2817 return false;
2820 return true;
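/// On the device, rewrites uses of declare target globals inside the generated
/// kernel so that they go through the reference pointers produced by
/// convertDeclareTargetAttr (see the detailed comment in the loop below).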
2823 static void
2824 handleDeclareTargetMapVar(MapInfoData &mapData,
2825 LLVM::ModuleTranslation &moduleTranslation,
2826 llvm::IRBuilderBase &builder) {
2827 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
2828 // In the case of declare target mapped variables, the basePointer is
2829 // the reference pointer generated by the convertDeclareTargetAttr
2830 // method, whereas the kernelValue is the original variable. For the
2831 // device we must therefore replace all uses of the original global
2832 // variable (stored in kernelValue) with the reference pointer (stored
2833 // in basePointer for declare target mapped variables): on the device
2834 // the data is mapped into this reference pointer and should be loaded
2835 // from it, while the original variable is discarded. On the host both
2836 // exist, and metadata is generated (elsewhere in the
2837 // convertDeclareTargetAttr function) to link the two variables in the
2838 // runtime; both the reference pointer and the pointer are then
2839 // assigned in the kernel argument structure for the host.
2840 if (mapData.IsDeclareTarget[i]) {
2841 // The users iterator will get invalidated if we modify an element,
2842 // so we first populate this vector of uses and then alter each user
2843 // individually, emitting its own load (rather than one load for all).
2844 llvm::SmallVector<llvm::User *> userVec;
2845 for (llvm::User *user : mapData.OriginalValue[i]->users())
2846 userVec.push_back(user);
2848 for (llvm::User *user : userVec) {
2849 if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
2850 auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
2851 mapData.BasePointers[i]);
2852 load->moveBefore(insn);
2853 user->replaceUsesOfWith(mapData.OriginalValue[i], load);
2860 // The createDeviceArgumentAccessor function generates
2861 // instructions for retrieving (accessing) kernel
2862 // arguments inside of the device kernel for use by
2863 // the kernel. This enables different semantics, such as
2864 // the creation of temporary copies of data, allowing
2865 // semantics like read-only/no host write-back kernel
2866 // arguments.
2868 // This currently implements a very light version of Clang's
2869 // EmitParmDecl handling of direct arguments, as well as a
2870 // portion of the argument access generation based on capture
2871 // types found at the end of emitOutlinedFunctionPrologue in
2872 // Clang. The indirect path handling of EmitParmDecl may be
2873 // required for future work, but a direct 1-to-1 copy doesn't seem
2874 // possible, as the logic is rather scattered throughout Clang's
2875 // lowering and perhaps we wish to deviate slightly.
2877 // \param mapData - A container of vectors of information
2878 // corresponding to the input argument, which should have a
2879 // corresponding entry in the MapInfoData container's
2880 // OriginalValue vector.
2881 // \param arg - The generated kernel function argument that
2882 // corresponds to the passed-in input argument. We generate different
2883 // accesses of this Argument based on capture type and other
2884 // input-related information.
2885 // \param input - The host-side value that will be passed to
2886 // the kernel, i.e. the kernel input. Upon exit of this function,
2887 // inside of the OMPIRBuilder, all uses of this value within the kernel
2888 // (generated from the target's region, which maintains references to
2889 // the original input) are rewritten to the retVal argument. This
2890 // interlinks the kernel argument with future uses of it in the function,
2891 // providing appropriate "glue" instructions in between.
2892 // \param retVal - The value that all uses of input inside of the
2893 // kernel will be rewritten to. The goal of this function is to generate
2894 // an appropriate location for the kernel argument to be accessed from:
2895 // e.g. ByRef will result in a temporary allocation location and then
2896 // a store of the kernel argument into this allocated memory, which
2897 // will then be loaded from, whereas ByCopy will use the allocated memory
2898 // directly.
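//
// As a rough, illustrative sketch (not the exact IR), for a pointer argument
// %arg this emits, at the alloca insertion point:
//   %arg.addr = alloca ptr
//   store ptr %arg, ptr %arg.addr
// retVal then becomes either %arg.addr itself (ByCopy) or, at the
// code-generation point, a load of it (ByRef):
//   %val = load ptr, ptr %arg.addr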
2899 static llvm::IRBuilderBase::InsertPoint
2900 createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
2901 llvm::Value *input, llvm::Value *&retVal,
2902 llvm::IRBuilderBase &builder,
2903 llvm::OpenMPIRBuilder &ompBuilder,
2904 LLVM::ModuleTranslation &moduleTranslation,
2905 llvm::IRBuilderBase::InsertPoint allocaIP,
2906 llvm::IRBuilderBase::InsertPoint codeGenIP) {
2907 builder.restoreIP(allocaIP);
2909 mlir::omp::VariableCaptureKind capture =
2910 mlir::omp::VariableCaptureKind::ByRef;
2912 // Find the associated MapInfoData entry for the current input
2913 for (size_t i = 0; i < mapData.MapClause.size(); ++i)
2914 if (mapData.OriginalValue[i] == input) {
2915 if (auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2916 mapData.MapClause[i])) {
2917 capture = mapOp.getMapCaptureType().value_or(
2918 mlir::omp::VariableCaptureKind::ByRef);
2921 break;
2924 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
2925 unsigned int defaultAS =
2926 ompBuilder.M.getDataLayout().getProgramAddressSpace();
2928 // Create the alloca for the argument at the current insertion point.
2929 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
2931 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
2932 v = builder.CreatePointerBitCastOrAddrSpaceCast(
2933 v, arg.getType()->getPointerTo(defaultAS));
2935 builder.CreateStore(&arg, v);
2937 builder.restoreIP(codeGenIP);
2939 switch (capture) {
2940 case mlir::omp::VariableCaptureKind::ByCopy: {
2941 retVal = v;
2942 break;
2944 case mlir::omp::VariableCaptureKind::ByRef: {
2945 retVal = builder.CreateAlignedLoad(
2946 v->getType(), v,
2947 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
2948 break;
2950 case mlir::omp::VariableCaptureKind::This:
2951 case mlir::omp::VariableCaptureKind::VLAType:
2952 assert(false && "Currently unsupported capture kind");
2953 break;
2956 return builder.saveIP();
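/// Lowers `omp.target` by outlining its region into a kernel through
/// OpenMPIRBuilder::createTarget and wiring the mapped variables up as kernel
/// inputs.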
2959 static LogicalResult
2960 convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
2961 LLVM::ModuleTranslation &moduleTranslation) {
2963 if (!targetOpSupported(opInst))
2964 return failure();
2966 auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
2967 auto targetOp = cast<omp::TargetOp>(opInst);
2968 auto &targetRegion = targetOp.getRegion();
2969 DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
2970 SmallVector<Value> mapOperands = targetOp.getMapOperands();
2972 LogicalResult bodyGenStatus = success();
2973 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2974 auto bodyCB = [&](InsertPointTy allocaIP,
2975 InsertPointTy codeGenIP) -> InsertPointTy {
2976 // Forward target-cpu and target-features function attributes from the
2977 // original function to the new outlined function.
2978 llvm::Function *llvmParentFn =
2979 moduleTranslation.lookupFunction(parentFn.getName());
2980 llvm::Function *llvmOutlinedFn = codeGenIP.getBlock()->getParent();
2981 assert(llvmParentFn && llvmOutlinedFn &&
2982 "Both parent and outlined functions must exist at this point");
2984 if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
2985 attr.isStringAttribute())
2986 llvmOutlinedFn->addFnAttr(attr);
2988 if (auto attr = llvmParentFn->getFnAttribute("target-features");
2989 attr.isStringAttribute())
2990 llvmOutlinedFn->addFnAttr(attr);
2992 builder.restoreIP(codeGenIP);
2993 unsigned argIndex = 0;
2994 for (auto &mapOp : mapOperands) {
2995 auto mapInfoOp =
2996 mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp());
2997 llvm::Value *mapOpValue =
2998 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
2999 const auto &arg = targetRegion.front().getArgument(argIndex);
3000 moduleTranslation.mapValue(arg, mapOpValue);
3001 argIndex++;
3003 llvm::BasicBlock *exitBlock = convertOmpOpRegions(
3004 targetRegion, "omp.target", builder, moduleTranslation, bodyGenStatus);
3005 builder.SetInsertPoint(exitBlock);
3006 return builder.saveIP();
3009 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3010 StringRef parentName = parentFn.getName();
3012 llvm::TargetRegionEntryInfo entryInfo;
3014 if (!getTargetEntryUniqueInfo(entryInfo, targetOp, parentName))
3015 return failure();
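// A value of -1 teams / 0 threads effectively leaves the launch bounds to the
// runtime; explicit clause values are not forwarded yet (targetOpSupported
// rejects the thread_limit clause above).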
3017 int32_t defaultValTeams = -1;
3018 int32_t defaultValThreads = 0;
3020 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3021 findAllocaInsertPoint(builder, moduleTranslation);
3023 MapInfoData mapData;
3024 collectMapDataFromMapOperands(mapData, mapOperands, moduleTranslation, dl,
3025 builder);
3027 llvm::OpenMPIRBuilder::MapInfosTy combinedInfos;
3028 auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
3029 -> llvm::OpenMPIRBuilder::MapInfosTy & {
3030 builder.restoreIP(codeGenIP);
3031 genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, {}, {},
3032 true);
3033 return combinedInfos;
3036 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
3037 llvm::Value *&retVal, InsertPointTy allocaIP,
3038 InsertPointTy codeGenIP) {
3039 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3041 // We just return the unaltered argument for the host function
3042 // for now; some alterations may be required in the future to
3043 // keep host fallback functions working identically to the device
3044 // version (e.g. ByCopy values should be treated as such on both
3045 // host and device, which is currently not always the case).
3046 if (!ompBuilder->Config.isTargetDevice()) {
3047 retVal = cast<llvm::Value>(&arg);
3048 return codeGenIP;
3051 return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
3052 *ompBuilder, moduleTranslation,
3053 allocaIP, codeGenIP);
3056 llvm::SmallVector<llvm::Value *, 4> kernelInput;
3057 for (size_t i = 0; i < mapOperands.size(); ++i) {
3058 // Declare target arguments are not passed to kernels as arguments.
3059 // TODO: We currently do not handle cases where a member is explicitly
3060 // passed in as an argument; this will likely need to be handled in
3061 // the near future. Rather than using IsAMember, it may be better to
3062 // test whether the relevant BlockArg is used within the target region
3063 // and then use that as a basis for exclusion from the kernel inputs.
3064 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
3065 kernelInput.push_back(mapData.OriginalValue[i]);
3068 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTarget(
3069 ompLoc, allocaIP, builder.saveIP(), entryInfo, defaultValTeams,
3070 defaultValThreads, kernelInput, genMapInfoCB, bodyCB, argAccessorCB));
3072 // Remap access operations to declare target reference pointers for the
3073 // device, essentially generating extra load ops as necessary.
3074 if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
3075 handleDeclareTargetMapVar(mapData, moduleTranslation, builder);
3077 return bodyGenStatus;
3080 static LogicalResult
3081 convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
3082 LLVM::ModuleTranslation &moduleTranslation) {
3083 // Amend omp.declare_target by deleting the IR of the outlined functions
3084 // created for target regions. They cannot be filtered out from MLIR earlier
3085 // because the omp.target operation inside must be translated to LLVM, but
3086 // the wrapper functions themselves must not remain at the end of the
3087 // process. We know that functions where omp.declare_target does not match
3088 // omp.is_target_device at this stage can only be wrapper functions because
3089 // those that aren't are removed earlier as an MLIR transformation pass.
3090 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
3091 if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
3092 op->getParentOfType<ModuleOp>().getOperation())) {
3093 if (!offloadMod.getIsTargetDevice())
3094 return success();
3096 omp::DeclareTargetDeviceType declareType =
3097 attribute.getDeviceType().getValue();
3099 if (declareType == omp::DeclareTargetDeviceType::host) {
3100 llvm::Function *llvmFunc =
3101 moduleTranslation.lookupFunction(funcOp.getName());
3102 llvmFunc->dropAllReferences();
3103 llvmFunc->eraseFromParent();
3106 return success();
3109 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
3110 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
3111 if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
3112 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3113 bool isDeclaration = gOp.isDeclaration();
3114 bool isExternallyVisible =
3115 gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
3116 auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
3117 llvm::StringRef mangledName = gOp.getSymName();
3118 auto captureClause =
3119 convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
3120 auto deviceClause =
3121 convertToDeviceClauseKind(attribute.getDeviceType().getValue());
3122 // Unused for MLIR at the moment, but required in Clang for
3123 // bookkeeping.
3124 std::vector<llvm::GlobalVariable *> generatedRefs;
3126 std::vector<llvm::Triple> targetTriple;
3127 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
3128 op->getParentOfType<mlir::ModuleOp>()->getAttr(
3129 LLVM::LLVMDialect::getTargetTripleAttrName()));
3130 if (targetTripleAttr)
3131 targetTriple.emplace_back(targetTripleAttr.data());
3133 auto fileInfoCallBack = [&loc]() {
3134 std::string filename = "";
3135 std::uint64_t lineNo = 0;
3137 if (loc) {
3138 filename = loc.getFilename().str();
3139 lineNo = loc.getLine();
3142 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
3143 lineNo);
3146 ompBuilder->registerTargetGlobalVariable(
3147 captureClause, deviceClause, isDeclaration, isExternallyVisible,
3148 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
3149 generatedRefs, /*OpenMPSimd*/ false, targetTriple,
3150 /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
3151 gVal->getType(), gVal);
3153 if (ompBuilder->Config.isTargetDevice() &&
3154 (attribute.getCaptureClause().getValue() !=
3155 mlir::omp::DeclareTargetCaptureClause::to ||
3156 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
3157 ompBuilder->getAddrOfDeclareTargetVar(
3158 captureClause, deviceClause, isDeclaration, isExternallyVisible,
3159 ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
3160 generatedRefs, /*OpenMPSimd*/ false, targetTriple, gVal->getType(),
3161 /*GlobalInitializer*/ nullptr,
3162 /*VariableLinkage*/ nullptr);
3167 return success();
3170 // Returns true if the operation is inside a TargetOp or
3171 // is part of a declare target function.
3172 static bool isTargetDeviceOp(Operation *op) {
3173 // Assumes no reverse offloading
3174 if (op->getParentOfType<omp::TargetOp>())
3175 return true;
3177 if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
3178 if (auto declareTargetIface =
3179 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
3180 parentFn.getOperation()))
3181 if (declareTargetIface.isDeclareTarget() &&
3182 declareTargetIface.getDeclareTargetDeviceType() !=
3183 mlir::omp::DeclareTargetDeviceType::host)
3184 return true;
3186 return false;
3189 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
3190 /// (including OpenMP runtime calls).
3191 static LogicalResult
3192 convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
3193 LLVM::ModuleTranslation &moduleTranslation) {
3195 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3197 return llvm::TypeSwitch<Operation *, LogicalResult>(op)
3198 .Case([&](omp::BarrierOp) {
3199 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
3200 return success();
3202 .Case([&](omp::TaskwaitOp) {
3203 ompBuilder->createTaskwait(builder.saveIP());
3204 return success();
3206 .Case([&](omp::TaskyieldOp) {
3207 ompBuilder->createTaskyield(builder.saveIP());
3208 return success();
3210 .Case([&](omp::FlushOp) {
3211 // There is no support in the OpenMP runtime function (__kmpc_flush)
3212 // for accepting the argument list.
3213 // The OpenMP standard states the following:
3214 // "An implementation may implement a flush with a list by ignoring
3215 // the list, and treating it the same as a flush without a list."
3217 // The argument list is therefore discarded, so that a flush with a list
3218 // is treated the same as a flush without a list.
3219 ompBuilder->createFlush(builder.saveIP());
3220 return success();
3222 .Case([&](omp::ParallelOp op) {
3223 return convertOmpParallel(op, builder, moduleTranslation);
3225 .Case([&](omp::ReductionOp reductionOp) {
3226 return convertOmpReductionOp(reductionOp, builder, moduleTranslation);
3228 .Case([&](omp::MasterOp) {
3229 return convertOmpMaster(*op, builder, moduleTranslation);
3231 .Case([&](omp::CriticalOp) {
3232 return convertOmpCritical(*op, builder, moduleTranslation);
3234 .Case([&](omp::OrderedRegionOp) {
3235 return convertOmpOrderedRegion(*op, builder, moduleTranslation);
3237 .Case([&](omp::OrderedOp) {
3238 return convertOmpOrdered(*op, builder, moduleTranslation);
3240 .Case([&](omp::WsloopOp) {
3241 return convertOmpWsloop(*op, builder, moduleTranslation);
3243 .Case([&](omp::SimdOp) {
3244 return convertOmpSimd(*op, builder, moduleTranslation);
3246 .Case([&](omp::AtomicReadOp) {
3247 return convertOmpAtomicRead(*op, builder, moduleTranslation);
3249 .Case([&](omp::AtomicWriteOp) {
3250 return convertOmpAtomicWrite(*op, builder, moduleTranslation);
3252 .Case([&](omp::AtomicUpdateOp op) {
3253 return convertOmpAtomicUpdate(op, builder, moduleTranslation);
3255 .Case([&](omp::AtomicCaptureOp op) {
3256 return convertOmpAtomicCapture(op, builder, moduleTranslation);
3258 .Case([&](omp::SectionsOp) {
3259 return convertOmpSections(*op, builder, moduleTranslation);
3261 .Case([&](omp::SingleOp op) {
3262 return convertOmpSingle(op, builder, moduleTranslation);
3264 .Case([&](omp::TeamsOp op) {
3265 return convertOmpTeams(op, builder, moduleTranslation);
3267 .Case([&](omp::TaskOp op) {
3268 return convertOmpTaskOp(op, builder, moduleTranslation);
3270 .Case([&](omp::TaskgroupOp op) {
3271 return convertOmpTaskgroupOp(op, builder, moduleTranslation);
3273 .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareReductionOp,
3274 omp::CriticalDeclareOp>([](auto op) {
3275 // `yield` and `terminator` can simply be omitted. The block structure
3276 // was created in the region that handles their parent operation.
3277 // `declare_reduction` will be used by reductions and is not
3278 // converted directly; skip it.
3279 // `critical.declare` is only used to declare the names of critical
3280 // sections, which will be used by `critical` ops, and hence can be
3281 // ignored for lowering. The OpenMP IRBuilder will create unique
3282 // names for critical sections itself.
3283 return success();
3285 .Case([&](omp::ThreadprivateOp) {
3286 return convertOmpThreadprivate(*op, builder, moduleTranslation);
3288 .Case<omp::TargetDataOp, omp::TargetEnterDataOp, omp::TargetExitDataOp,
3289 omp::TargetUpdateOp>([&](auto op) {
3290 return convertOmpTargetData(op, builder, moduleTranslation);
3292 .Case([&](omp::TargetOp) {
3293 return convertOmpTarget(*op, builder, moduleTranslation);
3295 .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
3296 [&](auto op) {
3297 // No-op; these are handled by the relevant owning operations (e.g.
3298 // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp, etc.)
3299 // and then discarded.
3300 return success();
3302 .Default([&](Operation *inst) {
3303 return inst->emitError("unsupported OpenMP operation: ")
3304 << inst->getName();
3308 static LogicalResult
3309 convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
3310 LLVM::ModuleTranslation &moduleTranslation) {
3311 return convertHostOrTargetOperation(op, builder, moduleTranslation);
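/// When compiling for the target device, only the device code nested inside a
/// host construct needs to be translated: walk the operation and convert any
/// `omp.target` and `omp.target_data` regions found, skipping everything else.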
3314 static LogicalResult
3315 convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
3316 LLVM::ModuleTranslation &moduleTranslation) {
3317 if (isa<omp::TargetOp>(op))
3318 return convertOmpTarget(*op, builder, moduleTranslation);
3319 if (isa<omp::TargetDataOp>(op))
3320 return convertOmpTargetData(op, builder, moduleTranslation);
3321 bool interrupted =
3322 op->walk<WalkOrder::PreOrder>([&](Operation *oper) {
3323 if (isa<omp::TargetOp>(oper)) {
3324 if (failed(convertOmpTarget(*oper, builder, moduleTranslation)))
3325 return WalkResult::interrupt();
3326 return WalkResult::skip();
3328 if (isa<omp::TargetDataOp>(oper)) {
3329 if (failed(convertOmpTargetData(oper, builder, moduleTranslation)))
3330 return WalkResult::interrupt();
3331 return WalkResult::skip();
3333 return WalkResult::advance();
3334 }).wasInterrupted();
3335 return failure(interrupted);
3338 namespace {
3340 /// Implementation of the dialect interface that converts operations belonging
3341 /// to the OpenMP dialect to LLVM IR.
3342 class OpenMPDialectLLVMIRTranslationInterface
3343 : public LLVMTranslationDialectInterface {
3344 public:
3345 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
3347 /// Translates the given operation to LLVM IR using the provided IR builder
3348 /// and saving the state in `moduleTranslation`.
3349 LogicalResult
3350 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
3351 LLVM::ModuleTranslation &moduleTranslation) const final;
3353 /// Given an OpenMP MLIR attribute, create the corresponding LLVM IR,
3354 /// runtime calls, or operation amendments.
3355 LogicalResult
3356 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
3357 NamedAttribute attribute,
3358 LLVM::ModuleTranslation &moduleTranslation) const final;
3361 } // namespace
3363 LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
3364 Operation *op, ArrayRef<llvm::Instruction *> instructions,
3365 NamedAttribute attribute,
3366 LLVM::ModuleTranslation &moduleTranslation) const {
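// Dispatch on the attribute name. Most of these module-level attributes
// (e.g. omp.is_target_device, omp.is_gpu, omp.requires) only configure the
// OpenMPIRBuilderConfig, while omp.flags and omp.declare_target may also
// emit IR.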
3367 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
3368 attribute.getName())
3369 .Case("omp.is_target_device",
3370 [&](Attribute attr) {
3371 if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
3372 llvm::OpenMPIRBuilderConfig &config =
3373 moduleTranslation.getOpenMPBuilder()->Config;
3374 config.setIsTargetDevice(deviceAttr.getValue());
3375 return success();
3377 return failure();
3379 .Case("omp.is_gpu",
3380 [&](Attribute attr) {
3381 if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
3382 llvm::OpenMPIRBuilderConfig &config =
3383 moduleTranslation.getOpenMPBuilder()->Config;
3384 config.setIsGPU(gpuAttr.getValue());
3385 return success();
3387 return failure();
3389 .Case("omp.host_ir_filepath",
3390 [&](Attribute attr) {
3391 if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
3392 llvm::OpenMPIRBuilder *ompBuilder =
3393 moduleTranslation.getOpenMPBuilder();
3394 ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
3395 return success();
3397 return failure();
3399 .Case("omp.flags",
3400 [&](Attribute attr) {
3401 if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
3402 return convertFlagsAttr(op, rtlAttr, moduleTranslation);
3403 return failure();
3405 .Case("omp.version",
3406 [&](Attribute attr) {
3407 if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
3408 llvm::OpenMPIRBuilder *ompBuilder =
3409 moduleTranslation.getOpenMPBuilder();
3410 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
3411 versionAttr.getVersion());
3412 return success();
3414 return failure();
3416 .Case("omp.declare_target",
3417 [&](Attribute attr) {
3418 if (auto declareTargetAttr =
3419 dyn_cast<omp::DeclareTargetAttr>(attr))
3420 return convertDeclareTargetAttr(op, declareTargetAttr,
3421 moduleTranslation);
3422 return failure();
3424 .Case("omp.requires",
3425 [&](Attribute attr) {
3426 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
3427 using Requires = omp::ClauseRequires;
3428 Requires flags = requiresAttr.getValue();
3429 llvm::OpenMPIRBuilderConfig &config =
3430 moduleTranslation.getOpenMPBuilder()->Config;
3431 config.setHasRequiresReverseOffload(
3432 bitEnumContainsAll(flags, Requires::reverse_offload));
3433 config.setHasRequiresUnifiedAddress(
3434 bitEnumContainsAll(flags, Requires::unified_address));
3435 config.setHasRequiresUnifiedSharedMemory(
3436 bitEnumContainsAll(flags, Requires::unified_shared_memory));
3437 config.setHasRequiresDynamicAllocators(
3438 bitEnumContainsAll(flags, Requires::dynamic_allocators));
3439 return success();
3441 return failure();
3443 .Default([](Attribute) {
3444 // Fall through for omp attributes that do not require lowering.
3445 return success();
3446 })(attribute.getValue());
3448 return failure();
3451 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
3452 /// (including OpenMP runtime calls).
3453 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
3454 Operation *op, llvm::IRBuilderBase &builder,
3455 LLVM::ModuleTranslation &moduleTranslation) const {
3457 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3458 if (ompBuilder->Config.isTargetDevice()) {
3459 if (isTargetDeviceOp(op)) {
3460 return convertTargetDeviceOp(op, builder, moduleTranslation);
3461 } else {
3462 return convertTargetOpsInNest(op, builder, moduleTranslation);
3465 return convertHostOrTargetOperation(op, builder, moduleTranslation);
3468 void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {
3469 registry.insert<omp::OpenMPDialect>();
3470 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
3471 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
3475 void mlir::registerOpenMPDialectTranslation(MLIRContext &context) {
3476 DialectRegistry registry;
3477 registerOpenMPDialectTranslation(registry);
3478 context.appendDialectRegistry(registry);