// Extracted from llvm-project (git-web view at commit "[Clang] Prevent
// `mlink-builtin-bitcode` from internalizing the RPC client (#118661)"):
// mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
// blob eb873fd1b7f6fe110461e048d9a08b6b48499899
1 //===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenMP dialect and LLVM
10 // IR.
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
14 #include "mlir/Analysis/TopologicalSortUtils.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
17 #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
18 #include "mlir/IR/IRMapping.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
22 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
23 #include "mlir/Transforms/RegionUtils.h"
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Frontend/OpenMP/OMPConstants.h"
29 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
30 #include "llvm/IR/DebugInfoMetadata.h"
31 #include "llvm/IR/IRBuilder.h"
32 #include "llvm/IR/ReplaceConstant.h"
33 #include "llvm/Support/FileSystem.h"
34 #include "llvm/TargetParser/Triple.h"
35 #include "llvm/Transforms/Utils/ModuleUtils.h"
37 #include <any>
38 #include <cstdint>
39 #include <iterator>
40 #include <numeric>
41 #include <optional>
42 #include <utility>
44 using namespace mlir;
46 namespace {
47 static llvm::omp::ScheduleKind
48 convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
49 if (!schedKind.has_value())
50 return llvm::omp::OMP_SCHEDULE_Default;
51 switch (schedKind.value()) {
52 case omp::ClauseScheduleKind::Static:
53 return llvm::omp::OMP_SCHEDULE_Static;
54 case omp::ClauseScheduleKind::Dynamic:
55 return llvm::omp::OMP_SCHEDULE_Dynamic;
56 case omp::ClauseScheduleKind::Guided:
57 return llvm::omp::OMP_SCHEDULE_Guided;
58 case omp::ClauseScheduleKind::Auto:
59 return llvm::omp::OMP_SCHEDULE_Auto;
60 case omp::ClauseScheduleKind::Runtime:
61 return llvm::omp::OMP_SCHEDULE_Runtime;
63 llvm_unreachable("unhandled schedule clause argument");
66 /// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
67 /// insertion points for allocas.
68 class OpenMPAllocaStackFrame
69 : public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
70 public:
71 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame)
73 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
74 : allocaInsertPoint(allocaIP) {}
75 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
78 /// ModuleTranslation stack frame containing the partial mapping between MLIR
79 /// values and their LLVM IR equivalents.
80 class OpenMPVarMappingStackFrame
81 : public LLVM::ModuleTranslation::StackFrameBase<
82 OpenMPVarMappingStackFrame> {
83 public:
84 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPVarMappingStackFrame)
86 explicit OpenMPVarMappingStackFrame(
87 const DenseMap<Value, llvm::Value *> &mapping)
88 : mapping(mapping) {}
90 DenseMap<Value, llvm::Value *> mapping;
93 /// Custom error class to signal translation errors that don't need reporting,
94 /// since encountering them will have already triggered relevant error messages.
95 ///
96 /// Its purpose is to serve as the glue between MLIR failures represented as
97 /// \see LogicalResult instances and \see llvm::Error instances used to
98 /// propagate errors through the \see llvm::OpenMPIRBuilder. Generally, when an
99 /// error of the first type is raised, a message is emitted directly (the \see
100 /// LogicalResult itself does not hold any information). If we need to forward
101 /// this error condition as an \see llvm::Error while avoiding triggering some
102 /// redundant error reporting later on, we need a custom \see llvm::ErrorInfo
103 /// class to just signal this situation has happened.
105 /// For example, this class should be used to trigger errors from within
106 /// callbacks passed to the \see OpenMPIRBuilder when they were triggered by the
107 /// translation of their own regions. This unclutters the error log from
108 /// redundant messages.
109 class PreviouslyReportedError
110 : public llvm::ErrorInfo<PreviouslyReportedError> {
111 public:
112 void log(raw_ostream &) const override {
113 // Do not log anything.
116 std::error_code convertToErrorCode() const override {
117 llvm_unreachable(
118 "PreviouslyReportedError doesn't support ECError conversion");
121 // Used by ErrorInfo::classID.
122 static char ID;
125 char PreviouslyReportedError::ID = 0;
127 } // namespace
129 /// Looks up from the operation from and returns the PrivateClauseOp with
130 /// name symbolName
131 static omp::PrivateClauseOp findPrivatizer(Operation *from,
132 SymbolRefAttr symbolName) {
133 omp::PrivateClauseOp privatizer =
134 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from,
135 symbolName);
136 assert(privatizer && "privatizer not found in the symbol table");
137 return privatizer;
140 /// Check whether translation to LLVM IR for the given operation is currently
141 /// supported. If not, descriptive diagnostics will be emitted to let users know
142 /// this is a not-yet-implemented feature.
144 /// \returns success if no unimplemented features are needed to translate the
145 /// given operation.
146 static LogicalResult checkImplementationStatus(Operation &op) {
147 auto todo = [&op](StringRef clauseName) {
148 return op.emitError() << "not yet implemented: Unhandled clause "
149 << clauseName << " in " << op.getName()
150 << " operation";
153 auto checkAllocate = [&todo](auto op, LogicalResult &result) {
154 if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
155 result = todo("allocate");
157 auto checkBare = [&todo](auto op, LogicalResult &result) {
158 if (op.getBare())
159 result = todo("ompx_bare");
161 auto checkDepend = [&todo](auto op, LogicalResult &result) {
162 if (!op.getDependVars().empty() || op.getDependKinds())
163 result = todo("depend");
165 auto checkDevice = [&todo](auto op, LogicalResult &result) {
166 if (op.getDevice())
167 result = todo("device");
169 auto checkHasDeviceAddr = [&todo](auto op, LogicalResult &result) {
170 if (!op.getHasDeviceAddrVars().empty())
171 result = todo("has_device_addr");
173 auto checkHint = [](auto op, LogicalResult &) {
174 if (op.getHint())
175 op.emitWarning("hint clause discarded");
177 auto checkHostEval = [](auto op, LogicalResult &result) {
178 // Host evaluated clauses are supported, except for loop bounds.
179 for (BlockArgument arg :
180 cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs())
181 for (Operation *user : arg.getUsers())
182 if (isa<omp::LoopNestOp>(user))
183 result = op.emitError("not yet implemented: host evaluation of loop "
184 "bounds in omp.target operation");
186 auto checkInReduction = [&todo](auto op, LogicalResult &result) {
187 if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
188 op.getInReductionSyms())
189 result = todo("in_reduction");
191 auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
192 if (!op.getIsDevicePtrVars().empty())
193 result = todo("is_device_ptr");
195 auto checkLinear = [&todo](auto op, LogicalResult &result) {
196 if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
197 result = todo("linear");
199 auto checkNontemporal = [&todo](auto op, LogicalResult &result) {
200 if (!op.getNontemporalVars().empty())
201 result = todo("nontemporal");
203 auto checkNowait = [&todo](auto op, LogicalResult &result) {
204 if (op.getNowait())
205 result = todo("nowait");
207 auto checkOrder = [&todo](auto op, LogicalResult &result) {
208 if (op.getOrder() || op.getOrderMod())
209 result = todo("order");
211 auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
212 if (op.getParLevelSimd())
213 result = todo("parallelization-level");
215 auto checkPriority = [&todo](auto op, LogicalResult &result) {
216 if (op.getPriority())
217 result = todo("priority");
219 auto checkPrivate = [&todo](auto op, LogicalResult &result) {
220 if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) {
221 // Privatization clauses are supported, except on some situations, so we
222 // need to check here whether any of these unsupported cases are being
223 // translated.
224 if (std::optional<ArrayAttr> privateSyms = op.getPrivateSyms()) {
225 for (Attribute privatizerNameAttr : *privateSyms) {
226 omp::PrivateClauseOp privatizer = findPrivatizer(
227 op.getOperation(), cast<SymbolRefAttr>(privatizerNameAttr));
229 if (privatizer.getDataSharingType() ==
230 omp::DataSharingClauseType::FirstPrivate)
231 result = todo("firstprivate");
234 } else {
235 if (!op.getPrivateVars().empty() || op.getPrivateSyms())
236 result = todo("privatization");
239 auto checkReduction = [&todo](auto op, LogicalResult &result) {
240 if (isa<omp::TeamsOp>(op) || isa<omp::SimdOp>(op))
241 if (!op.getReductionVars().empty() || op.getReductionByref() ||
242 op.getReductionSyms())
243 result = todo("reduction");
244 if (op.getReductionMod() &&
245 op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
246 result = todo("reduction with modifier");
248 auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
249 if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
250 op.getTaskReductionSyms())
251 result = todo("task_reduction");
253 auto checkUntied = [&todo](auto op, LogicalResult &result) {
254 if (op.getUntied())
255 result = todo("untied");
258 LogicalResult result = success();
259 llvm::TypeSwitch<Operation &>(op)
260 .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
261 .Case([&](omp::SectionsOp op) {
262 checkAllocate(op, result);
263 checkPrivate(op, result);
264 checkReduction(op, result);
266 .Case([&](omp::SingleOp op) {
267 checkAllocate(op, result);
268 checkPrivate(op, result);
270 .Case([&](omp::TeamsOp op) {
271 checkAllocate(op, result);
272 checkPrivate(op, result);
273 checkReduction(op, result);
275 .Case([&](omp::TaskOp op) {
276 checkAllocate(op, result);
277 checkInReduction(op, result);
279 .Case([&](omp::TaskgroupOp op) {
280 checkAllocate(op, result);
281 checkTaskReduction(op, result);
283 .Case([&](omp::TaskwaitOp op) {
284 checkDepend(op, result);
285 checkNowait(op, result);
287 .Case([&](omp::TaskloopOp op) {
288 // TODO: Add other clauses check
289 checkUntied(op, result);
290 checkPriority(op, result);
292 .Case([&](omp::WsloopOp op) {
293 checkAllocate(op, result);
294 checkLinear(op, result);
295 checkOrder(op, result);
296 checkReduction(op, result);
298 .Case([&](omp::ParallelOp op) {
299 checkAllocate(op, result);
300 checkReduction(op, result);
302 .Case([&](omp::SimdOp op) {
303 checkLinear(op, result);
304 checkNontemporal(op, result);
305 checkReduction(op, result);
307 .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
308 omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
309 .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
310 [&](auto op) { checkDepend(op, result); })
311 .Case([&](omp::TargetOp op) {
312 checkAllocate(op, result);
313 checkBare(op, result);
314 checkDevice(op, result);
315 checkHasDeviceAddr(op, result);
316 checkHostEval(op, result);
317 checkInReduction(op, result);
318 checkIsDevicePtr(op, result);
319 checkPrivate(op, result);
321 .Default([](Operation &) {
322 // Assume all clauses for an operation can be translated unless they are
323 // checked above.
325 return result;
328 static LogicalResult handleError(llvm::Error error, Operation &op) {
329 LogicalResult result = success();
330 if (error) {
331 llvm::handleAllErrors(
332 std::move(error),
333 [&](const PreviouslyReportedError &) { result = failure(); },
334 [&](const llvm::ErrorInfoBase &err) {
335 result = op.emitError(err.message());
338 return result;
341 template <typename T>
342 static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
343 if (!result)
344 return handleError(result.takeError(), op);
346 return success();
349 /// Find the insertion point for allocas given the current insertion point for
350 /// normal operations in the builder.
351 static llvm::OpenMPIRBuilder::InsertPointTy
352 findAllocaInsertPoint(llvm::IRBuilderBase &builder,
353 const LLVM::ModuleTranslation &moduleTranslation) {
354 // If there is an alloca insertion point on stack, i.e. we are in a nested
355 // operation and a specific point was provided by some surrounding operation,
356 // use it.
357 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
358 WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
359 [&](const OpenMPAllocaStackFrame &frame) {
360 allocaInsertPoint = frame.allocaInsertPoint;
361 return WalkResult::interrupt();
363 if (walkResult.wasInterrupted())
364 return allocaInsertPoint;
366 // Otherwise, insert to the entry block of the surrounding function.
367 // If the current IRBuilder InsertPoint is the function's entry, it cannot
368 // also be used for alloca insertion which would result in insertion order
369 // confusion. Create a new BasicBlock for the Builder and use the entry block
370 // for the allocs.
371 // TODO: Create a dedicated alloca BasicBlock at function creation such that
372 // we do not need to move the current InertPoint here.
373 if (builder.GetInsertBlock() ==
374 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
375 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
376 "Assuming end of basic block");
377 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
378 builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
379 builder.GetInsertBlock()->getNextNode());
380 builder.CreateBr(entryBB);
381 builder.SetInsertPoint(entryBB);
384 llvm::BasicBlock &funcEntryBlock =
385 builder.GetInsertBlock()->getParent()->getEntryBlock();
386 return llvm::OpenMPIRBuilder::InsertPointTy(
387 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
390 /// Converts the given region that appears within an OpenMP dialect operation to
391 /// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
392 /// region, and a branch from any block with an successor-less OpenMP terminator
393 /// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
394 /// of the continuation block if provided.
395 static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
396 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
397 LLVM::ModuleTranslation &moduleTranslation,
398 SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
399 llvm::BasicBlock *continuationBlock =
400 splitBB(builder, true, "omp.region.cont");
401 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
403 llvm::LLVMContext &llvmContext = builder.getContext();
404 for (Block &bb : region) {
405 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
406 llvmContext, blockName, builder.GetInsertBlock()->getParent(),
407 builder.GetInsertBlock()->getNextNode());
408 moduleTranslation.mapBlock(&bb, llvmBB);
411 llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();
413 // Terminators (namely YieldOp) may be forwarding values to the region that
414 // need to be available in the continuation block. Collect the types of these
415 // operands in preparation of creating PHI nodes.
416 SmallVector<llvm::Type *> continuationBlockPHITypes;
417 bool operandsProcessed = false;
418 unsigned numYields = 0;
419 for (Block &bb : region.getBlocks()) {
420 if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
421 if (!operandsProcessed) {
422 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
423 continuationBlockPHITypes.push_back(
424 moduleTranslation.convertType(yield->getOperand(i).getType()));
426 operandsProcessed = true;
427 } else {
428 assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
429 "mismatching number of values yielded from the region");
430 for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
431 llvm::Type *operandType =
432 moduleTranslation.convertType(yield->getOperand(i).getType());
433 (void)operandType;
434 assert(continuationBlockPHITypes[i] == operandType &&
435 "values of mismatching types yielded from the region");
438 numYields++;
442 // Insert PHI nodes in the continuation block for any values forwarded by the
443 // terminators in this region.
444 if (!continuationBlockPHITypes.empty())
445 assert(
446 continuationBlockPHIs &&
447 "expected continuation block PHIs if converted regions yield values");
448 if (continuationBlockPHIs) {
449 llvm::IRBuilderBase::InsertPointGuard guard(builder);
450 continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
451 builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
452 for (llvm::Type *ty : continuationBlockPHITypes)
453 continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
456 // Convert blocks one by one in topological order to ensure
457 // defs are converted before uses.
458 SetVector<Block *> blocks = getBlocksSortedByDominance(region);
459 for (Block *bb : blocks) {
460 llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
461 // Retarget the branch of the entry block to the entry block of the
462 // converted region (regions are single-entry).
463 if (bb->isEntryBlock()) {
464 assert(sourceTerminator->getNumSuccessors() == 1 &&
465 "provided entry block has multiple successors");
466 assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
467 "ContinuationBlock is not the successor of the entry block");
468 sourceTerminator->setSuccessor(0, llvmBB);
471 llvm::IRBuilderBase::InsertPointGuard guard(builder);
472 if (failed(
473 moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
474 return llvm::make_error<PreviouslyReportedError>();
476 // Special handling for `omp.yield` and `omp.terminator` (we may have more
477 // than one): they return the control to the parent OpenMP dialect operation
478 // so replace them with the branch to the continuation block. We handle this
479 // here to avoid relying inter-function communication through the
480 // ModuleTranslation class to set up the correct insertion point. This is
481 // also consistent with MLIR's idiom of handling special region terminators
482 // in the same code that handles the region-owning operation.
483 Operation *terminator = bb->getTerminator();
484 if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
485 builder.CreateBr(continuationBlock);
487 for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
488 (*continuationBlockPHIs)[i]->addIncoming(
489 moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
492 // After all blocks have been traversed and values mapped, connect the PHI
493 // nodes to the results of preceding blocks.
494 LLVM::detail::connectPHINodes(region, moduleTranslation);
496 // Remove the blocks and values defined in this region from the mapping since
497 // they are not visible outside of this region. This allows the same region to
498 // be converted several times, that is cloned, without clashes, and slightly
499 // speeds up the lookups.
500 moduleTranslation.forgetMapping(region);
502 return continuationBlock;
505 /// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
506 static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
507 switch (kind) {
508 case omp::ClauseProcBindKind::Close:
509 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
510 case omp::ClauseProcBindKind::Master:
511 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
512 case omp::ClauseProcBindKind::Primary:
513 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
514 case omp::ClauseProcBindKind::Spread:
515 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
517 llvm_unreachable("Unknown ClauseProcBindKind kind");
520 /// Helper function to map block arguments defined by ignored loop wrappers to
521 /// LLVM values and prevent any uses of those from triggering null pointer
522 /// dereferences.
524 /// This must be called after block arguments of parent wrappers have already
525 /// been mapped to LLVM IR values.
526 static LogicalResult
527 convertIgnoredWrapper(omp::LoopWrapperInterface &opInst,
528 LLVM::ModuleTranslation &moduleTranslation) {
529 // Map block arguments directly to the LLVM value associated to the
530 // corresponding operand. This is semantically equivalent to this wrapper not
531 // being present.
532 auto forwardArgs =
533 [&moduleTranslation](llvm::ArrayRef<BlockArgument> blockArgs,
534 OperandRange operands) {
535 for (auto [arg, var] : llvm::zip_equal(blockArgs, operands))
536 moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
539 return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
540 .Case([&](omp::SimdOp op) {
541 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*op);
542 forwardArgs(blockArgIface.getPrivateBlockArgs(), op.getPrivateVars());
543 forwardArgs(blockArgIface.getReductionBlockArgs(),
544 op.getReductionVars());
545 op.emitWarning() << "simd information on composite construct discarded";
546 return success();
548 .Default([&](Operation *op) {
549 return op->emitError() << "cannot ignore nested wrapper";
553 /// Helper function to call \c convertIgnoredWrapper() for all wrappers of the
554 /// given \c loopOp nested inside of \c parentOp. This has the effect of mapping
555 /// entry block arguments defined by these operations to outside values.
557 /// It must be called after block arguments of \c parentOp have already been
558 /// mapped themselves.
559 static LogicalResult
560 convertIgnoredWrappers(omp::LoopNestOp loopOp,
561 omp::LoopWrapperInterface parentOp,
562 LLVM::ModuleTranslation &moduleTranslation) {
563 SmallVector<omp::LoopWrapperInterface> wrappers;
564 loopOp.gatherWrappers(wrappers);
566 // Process wrappers nested inside of `parentOp` from outermost to innermost.
567 for (auto it =
568 std::next(std::find(wrappers.rbegin(), wrappers.rend(), parentOp));
569 it != wrappers.rend(); ++it) {
570 if (failed(convertIgnoredWrapper(*it, moduleTranslation)))
571 return failure();
574 return success();
577 /// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
578 static LogicalResult
579 convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
580 LLVM::ModuleTranslation &moduleTranslation) {
581 auto maskedOp = cast<omp::MaskedOp>(opInst);
582 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
584 if (failed(checkImplementationStatus(opInst)))
585 return failure();
587 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
588 // MaskedOp has only one region associated with it.
589 auto &region = maskedOp.getRegion();
590 builder.restoreIP(codeGenIP);
591 return convertOmpOpRegions(region, "omp.masked.region", builder,
592 moduleTranslation)
593 .takeError();
596 // TODO: Perform finalization actions for variables. This has to be
597 // called for variables which have destructors/finalizers.
598 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
600 llvm::Value *filterVal = nullptr;
601 if (auto filterVar = maskedOp.getFilteredThreadId()) {
602 filterVal = moduleTranslation.lookupValue(filterVar);
603 } else {
604 llvm::LLVMContext &llvmContext = builder.getContext();
605 filterVal =
606 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
608 assert(filterVal != nullptr);
609 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
610 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
611 moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB,
612 finiCB, filterVal);
614 if (failed(handleError(afterIP, opInst)))
615 return failure();
617 builder.restoreIP(*afterIP);
618 return success();
621 /// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
622 static LogicalResult
623 convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
624 LLVM::ModuleTranslation &moduleTranslation) {
625 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
626 auto masterOp = cast<omp::MasterOp>(opInst);
628 if (failed(checkImplementationStatus(opInst)))
629 return failure();
631 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
632 // MasterOp has only one region associated with it.
633 auto &region = masterOp.getRegion();
634 builder.restoreIP(codeGenIP);
635 return convertOmpOpRegions(region, "omp.master.region", builder,
636 moduleTranslation)
637 .takeError();
640 // TODO: Perform finalization actions for variables. This has to be
641 // called for variables which have destructors/finalizers.
642 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
644 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
645 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
646 moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
647 finiCB);
649 if (failed(handleError(afterIP, opInst)))
650 return failure();
652 builder.restoreIP(*afterIP);
653 return success();
656 /// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
657 static LogicalResult
658 convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
659 LLVM::ModuleTranslation &moduleTranslation) {
660 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
661 auto criticalOp = cast<omp::CriticalOp>(opInst);
663 if (failed(checkImplementationStatus(opInst)))
664 return failure();
666 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
667 // CriticalOp has only one region associated with it.
668 auto &region = cast<omp::CriticalOp>(opInst).getRegion();
669 builder.restoreIP(codeGenIP);
670 return convertOmpOpRegions(region, "omp.critical.region", builder,
671 moduleTranslation)
672 .takeError();
675 // TODO: Perform finalization actions for variables. This has to be
676 // called for variables which have destructors/finalizers.
677 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
679 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
680 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
681 llvm::Constant *hint = nullptr;
683 // If it has a name, it probably has a hint too.
684 if (criticalOp.getNameAttr()) {
685 // The verifiers in OpenMP Dialect guarentee that all the pointers are
686 // non-null
687 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
688 auto criticalDeclareOp =
689 SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
690 symbolRef);
691 hint =
692 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
693 static_cast<int>(criticalDeclareOp.getHint()));
695 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
696 moduleTranslation.getOpenMPBuilder()->createCritical(
697 ompLoc, bodyGenCB, finiCB, criticalOp.getName().value_or(""), hint);
699 if (failed(handleError(afterIP, opInst)))
700 return failure();
702 builder.restoreIP(*afterIP);
703 return success();
706 /// Populates `privatizations` with privatization declarations used for the
707 /// given op.
708 template <class OP>
709 static void collectPrivatizationDecls(
710 OP op, SmallVectorImpl<omp::PrivateClauseOp> &privatizations) {
711 std::optional<ArrayAttr> attr = op.getPrivateSyms();
712 if (!attr)
713 return;
715 privatizations.reserve(privatizations.size() + attr->size());
716 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
717 privatizations.push_back(findPrivatizer(op, symbolRef));
721 /// Populates `reductions` with reduction declarations used in the given op.
722 template <typename T>
723 static void
724 collectReductionDecls(T op,
725 SmallVectorImpl<omp::DeclareReductionOp> &reductions) {
726 std::optional<ArrayAttr> attr = op.getReductionSyms();
727 if (!attr)
728 return;
730 reductions.reserve(reductions.size() + op.getNumReductionVars());
731 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
732 reductions.push_back(
733 SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
734 op, symbolRef));
738 /// Translates the blocks contained in the given region and appends them to at
739 /// the current insertion point of `builder`. The operations of the entry block
740 /// are appended to the current insertion block. If set, `continuationBlockArgs`
741 /// is populated with translated values that correspond to the values
742 /// omp.yield'ed from the region.
743 static LogicalResult inlineConvertOmpRegions(
744 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
745 LLVM::ModuleTranslation &moduleTranslation,
746 SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
747 if (region.empty())
748 return success();
750 // Special case for single-block regions that don't create additional blocks:
751 // insert operations without creating additional blocks.
752 if (llvm::hasSingleElement(region)) {
753 llvm::Instruction *potentialTerminator =
754 builder.GetInsertBlock()->empty() ? nullptr
755 : &builder.GetInsertBlock()->back();
757 if (potentialTerminator && potentialTerminator->isTerminator())
758 potentialTerminator->removeFromParent();
759 moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
761 if (failed(moduleTranslation.convertBlock(
762 region.front(), /*ignoreArguments=*/true, builder)))
763 return failure();
765 // The continuation arguments are simply the translated terminator operands.
766 if (continuationBlockArgs)
767 llvm::append_range(
768 *continuationBlockArgs,
769 moduleTranslation.lookupValues(region.front().back().getOperands()));
771 // Drop the mapping that is no longer necessary so that the same region can
772 // be processed multiple times.
773 moduleTranslation.forgetMapping(region);
775 if (potentialTerminator && potentialTerminator->isTerminator()) {
776 llvm::BasicBlock *block = builder.GetInsertBlock();
777 if (block->empty()) {
778 // this can happen for really simple reduction init regions e.g.
779 // %0 = llvm.mlir.constant(0 : i32) : i32
780 // omp.yield(%0 : i32)
781 // because the llvm.mlir.constant (MLIR op) isn't converted into any
782 // llvm op
783 potentialTerminator->insertInto(block, block->begin());
784 } else {
785 potentialTerminator->insertAfter(&block->back());
789 return success();
792 SmallVector<llvm::PHINode *> phis;
793 llvm::Expected<llvm::BasicBlock *> continuationBlock =
794 convertOmpOpRegions(region, blockName, builder, moduleTranslation, &phis);
796 if (failed(handleError(continuationBlock, *region.getParentOp())))
797 return failure();
799 if (continuationBlockArgs)
800 llvm::append_range(*continuationBlockArgs, phis);
801 builder.SetInsertPoint(*continuationBlock,
802 (*continuationBlock)->getFirstInsertionPt());
803 return success();
806 namespace {
807 /// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
808 /// store lambdas with capture.
809 using OwningReductionGen =
810 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
811 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
812 llvm::Value *&)>;
813 using OwningAtomicReductionGen =
814 std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
815 llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
816 llvm::Value *)>;
817 } // namespace
819 /// Create an OpenMPIRBuilder-compatible reduction generator for the given
820 /// reduction declaration. The generator uses `builder` but ignores its
821 /// insertion point.
static OwningReductionGen
makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  // The lambda is mutable because we need access to non-const methods of decl
  // (which aren't actually mutating it), and we must capture decl by-value to
  // avoid the dangling reference after the parent function returns.
  OwningReductionGen gen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
                llvm::Value *lhs, llvm::Value *rhs,
                llvm::Value *&result) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // Bind the combiner region's block arguments to the incoming LHS/RHS
    // values, then inline the region at the requested insertion point.
    moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
    builder.restoreIP(insertPoint);
    SmallVector<llvm::Value *> phis;
    if (failed(inlineConvertOmpRegions(decl.getReductionRegion(),
                                       "omp.reduction.nonatomic.body", builder,
                                       moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `combiner` region of `omp.declare_reduction`");
    // The combiner region yields exactly one value: the combined result.
    assert(phis.size() == 1);
    result = phis[0];
    return builder.saveIP();
  };
  return gen;
}
849 /// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
850 /// given reduction declaration. The generator uses `builder` but ignores its
851 /// insertion point. Returns null if there is no atomic region available in the
852 /// reduction declaration.
static OwningAtomicReductionGen
makeAtomicReductionGen(omp::DeclareReductionOp decl,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  // A null generator signals to the caller that no atomic combination is
  // available for this reduction declaration.
  if (decl.getAtomicReductionRegion().empty())
    return OwningAtomicReductionGen();

  // The lambda is mutable because we need access to non-const methods of decl
  // (which aren't actually mutating it), and we must capture decl by-value to
  // avoid the dangling reference after the parent function returns.
  OwningAtomicReductionGen atomicGen =
      [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
                llvm::Value *lhs, llvm::Value *rhs) mutable
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // Bind the atomic region's block arguments, then inline it in place.
    moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
    moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
    builder.restoreIP(insertPoint);
    SmallVector<llvm::Value *> phis;
    if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(),
                                       "omp.reduction.atomic.body", builder,
                                       moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `atomic` region of `omp.declare_reduction`");
    // The atomic region performs the update through memory and yields nothing.
    assert(phis.empty());
    return builder.saveIP();
  };
  return atomicGen;
}
882 /// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
883 static LogicalResult
884 convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
885 LLVM::ModuleTranslation &moduleTranslation) {
886 auto orderedOp = cast<omp::OrderedOp>(opInst);
888 if (failed(checkImplementationStatus(opInst)))
889 return failure();
891 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
892 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
893 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
894 SmallVector<llvm::Value *> vecValues =
895 moduleTranslation.lookupValues(orderedOp.getDoacrossDependVars());
897 size_t indexVecValues = 0;
898 while (indexVecValues < vecValues.size()) {
899 SmallVector<llvm::Value *> storeValues;
900 storeValues.reserve(numLoops);
901 for (unsigned i = 0; i < numLoops; i++) {
902 storeValues.push_back(vecValues[indexVecValues]);
903 indexVecValues++;
905 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
906 findAllocaInsertPoint(builder, moduleTranslation);
907 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
908 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
909 ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
911 return success();
914 /// Converts an OpenMP 'ordered_region' operation into LLVM IR using
915 /// OpenMPIRBuilder.
916 static LogicalResult
917 convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
918 LLVM::ModuleTranslation &moduleTranslation) {
919 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
920 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
922 if (failed(checkImplementationStatus(opInst)))
923 return failure();
925 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
926 // OrderedOp has only one region associated with it.
927 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
928 builder.restoreIP(codeGenIP);
929 return convertOmpOpRegions(region, "omp.ordered.region", builder,
930 moduleTranslation)
931 .takeError();
934 // TODO: Perform finalization actions for variables. This has to be
935 // called for variables which have destructors/finalizers.
936 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
938 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
939 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
940 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
941 ompLoc, bodyGenCB, finiCB, !orderedRegionOp.getParLevelSimd());
943 if (failed(handleError(afterIP, opInst)))
944 return failure();
946 builder.restoreIP(*afterIP);
947 return success();
namespace {
/// Contains the arguments for an LLVM store operation
struct DeferredStore {
  DeferredStore(llvm::Value *value, llvm::Value *address)
      : value(value), address(address) {}

  llvm::Value *value;   // value to be stored
  llvm::Value *address; // destination address for the store
};
} // namespace
961 /// Allocate space for privatized reduction variables.
962 /// `deferredStores` contains information to create store operations which needs
963 /// to be inserted after all allocas
template <typename T>
static LogicalResult
allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs,
                   llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation,
                   const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
                   SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
                   SmallVectorImpl<llvm::Value *> &privateReductionVariables,
                   DenseMap<Value, llvm::Value *> &reductionVariableMap,
                   SmallVectorImpl<DeferredStore> &deferredStores,
                   llvm::ArrayRef<bool> isByRefs) {
  // Restore the caller's insertion point when we are done allocating.
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

  // delay creating stores until after all allocas
  deferredStores.reserve(loop.getNumReductionVars());

  for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
    Region &allocRegion = reductionDecls[i].getAllocRegion();
    if (isByRefs[i]) {
      // By-ref without an alloc region is handled later in initReductionVars.
      if (allocRegion.empty())
        continue;

      SmallVector<llvm::Value *, 1> phis;
      if (failed(inlineConvertOmpRegions(allocRegion, "omp.reduction.alloc",
                                         builder, moduleTranslation, &phis)))
        return loop.emitError(
            "failed to inline `alloc` region of `omp.declare_reduction`");

      assert(phis.size() == 1 && "expected one allocation to be yielded");
      // Inlining the region may have moved the insertion point; put it back
      // at the end of the alloca block.
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

      // Allocate reduction variable (which is a pointer to the real reduction
      // variable allocated in the inlined region)
      llvm::Value *var = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));
      // The store of the region's result into `var` is deferred until after
      // every alloca has been emitted.
      deferredStores.emplace_back(phis[0], var);

      privateReductionVariables[i] = var;
      moduleTranslation.mapValue(reductionArgs[i], phis[0]);
      reductionVariableMap.try_emplace(loop.getReductionVars()[i], phis[0]);
    } else {
      assert(allocRegion.empty() &&
             "allocaction is implicit for by-val reduction");
      // By-value: a plain alloca holds the private copy directly.
      llvm::Value *var = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));
      moduleTranslation.mapValue(reductionArgs[i], var);
      privateReductionVariables[i] = var;
      reductionVariableMap.try_emplace(loop.getReductionVars()[i], var);
    }
  }

  return success();
}
1019 /// Map input arguments to reduction initialization region
1020 template <typename T>
1021 static void
1022 mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation,
1023 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
1024 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1025 unsigned i) {
1026 // map input argument to the initialization region
1027 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1028 Region &initializerRegion = reduction.getInitializerRegion();
1029 Block &entry = initializerRegion.front();
1031 mlir::Value mlirSource = loop.getReductionVars()[i];
1032 llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
1033 assert(llvmSource && "lookup reduction var");
1034 moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource);
1036 if (entry.getNumArguments() > 1) {
1037 llvm::Value *allocation =
1038 reductionVariableMap.lookup(loop.getReductionVars()[i]);
1039 moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
1043 /// Inline reductions' `init` regions. This functions assumes that the
1044 /// `builder`'s insertion point is where the user wants the `init` regions to be
1045 /// inlined; i.e. it does not try to find a proper insertion location for the
1046 /// `init` regions. It also leaves the `builder's insertions point in a state
1047 /// where the user can continue the code-gen directly afterwards.
template <typename OP>
static LogicalResult
initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
                  llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation,
                  llvm::BasicBlock *latestAllocaBlock,
                  SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
                  SmallVectorImpl<llvm::Value *> &privateReductionVariables,
                  DenseMap<Value, llvm::Value *> &reductionVariableMap,
                  llvm::ArrayRef<bool> isByRef,
                  SmallVectorImpl<DeferredStore> &deferredStores) {
  if (op.getNumReductionVars() == 0)
    return success();

  // Split off a dedicated block for initialization code, keeping allocas in
  // the (latest) alloca block.
  llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
  auto allocaIP = llvm::IRBuilderBase::InsertPoint(
      latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
  builder.restoreIP(allocaIP);
  SmallVector<llvm::Value *> byRefVars(op.getNumReductionVars());

  // Emit the allocas for by-ref reductions that have no alloc region.
  for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
    if (isByRef[i]) {
      if (!reductionDecls[i].getAllocRegion().empty())
        continue;

      // TODO: remove after all users of by-ref are updated to use the alloc
      // region: Allocate reduction variable (which is a pointer to the real
      // reduciton variable allocated in the inlined region)
      byRefVars[i] = builder.CreateAlloca(
          moduleTranslation.convertType(reductionDecls[i].getType()));
    }
  }

  // Emit initialization code before the init block's terminator, if any.
  if (initBlock->empty() || initBlock->getTerminator() == nullptr)
    builder.SetInsertPoint(initBlock);
  else
    builder.SetInsertPoint(initBlock->getTerminator());

  // store result of the alloc region to the allocated pointer to the real
  // reduction variable
  for (auto [data, addr] : deferredStores)
    builder.CreateStore(data, addr);

  // Before the loop, store the initial values of reductions into reduction
  // variables. Although this could be done after allocas, we don't want to mess
  // up with the alloca insertion point.
  for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
    SmallVector<llvm::Value *, 1> phis;

    // map block argument to initializer region
    mapInitializationArgs(op, moduleTranslation, reductionDecls,
                          reductionVariableMap, i);

    if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
                                       "omp.reduction.neutral", builder,
                                       moduleTranslation, &phis)))
      return failure();

    assert(phis.size() == 1 && "expected one value to be yielded from the "
                               "reduction neutral element declaration region");

    // Inlining may have left the builder mid-block; re-anchor it before the
    // current block's terminator when one exists.
    if (builder.GetInsertBlock()->empty() ||
        builder.GetInsertBlock()->getTerminator() == nullptr)
      builder.SetInsertPoint(builder.GetInsertBlock())
    else
      builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());

    if (isByRef[i]) {
      if (!reductionDecls[i].getAllocRegion().empty())
        // done in allocReductionVars
        continue;

      // TODO: this path can be removed once all users of by-ref are updated to
      // use an alloc region

      // Store the result of the inlined region to the allocated reduction var
      // ptr
      builder.CreateStore(phis[0], byRefVars[i]);

      privateReductionVariables[i] = byRefVars[i];
      moduleTranslation.mapValue(reductionArgs[i], phis[0]);
      reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
    } else {
      // for by-ref case the store is inside of the reduction region
      builder.CreateStore(phis[0], privateReductionVariables[i]);
      // the rest was handled in allocByValReductionVars
    }

    // forget the mapping for the initializer region because we might need a
    // different mapping if this reduction declaration is re-used for a
    // different variable
    moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
  }

  return success();
}
1145 /// Collect reduction info
template <typename T>
static void collectReductionInfo(
    T loop, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
    SmallVectorImpl<OwningReductionGen> &owningReductionGens,
    SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
    const ArrayRef<llvm::Value *> privateReductionVariables,
    SmallVectorImpl<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos) {
  unsigned numReductions = loop.getNumReductionVars();

  // Build all generators first; ReductionInfo entries refer to elements of
  // these vectors, so they must not grow afterwards.
  for (unsigned i = 0; i < numReductions; ++i) {
    owningReductionGens.push_back(
        makeReductionGen(reductionDecls[i], builder, moduleTranslation));
    owningAtomicReductionGens.push_back(
        makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
  }

  // Collect the reduction information.
  reductionInfos.reserve(numReductions);
  for (unsigned i = 0; i < numReductions; ++i) {
    // makeAtomicReductionGen returns a null generator when the declaration
    // has no atomic region.
    llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
    if (owningAtomicReductionGens[i])
      atomicGen = owningAtomicReductionGens[i];
    llvm::Value *variable =
        moduleTranslation.lookupValue(loop.getReductionVars()[i]);
    reductionInfos.push_back(
        {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
         privateReductionVariables[i],
         /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
         owningReductionGens[i],
         /*ReductionGenClang=*/nullptr, atomicGen});
  }
}
1181 /// handling of DeclareReductionOp's cleanup region
1182 static LogicalResult
1183 inlineOmpRegionCleanup(llvm::SmallVectorImpl<Region *> &cleanupRegions,
1184 llvm::ArrayRef<llvm::Value *> privateVariables,
1185 LLVM::ModuleTranslation &moduleTranslation,
1186 llvm::IRBuilderBase &builder, StringRef regionName,
1187 bool shouldLoadCleanupRegionArg = true) {
1188 for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
1189 if (cleanupRegion->empty())
1190 continue;
1192 // map the argument to the cleanup region
1193 Block &entry = cleanupRegion->front();
1195 llvm::Instruction *potentialTerminator =
1196 builder.GetInsertBlock()->empty() ? nullptr
1197 : &builder.GetInsertBlock()->back();
1198 if (potentialTerminator && potentialTerminator->isTerminator())
1199 builder.SetInsertPoint(potentialTerminator);
1200 llvm::Value *privateVarValue =
1201 shouldLoadCleanupRegionArg
1202 ? builder.CreateLoad(
1203 moduleTranslation.convertType(entry.getArgument(0).getType()),
1204 privateVariables[i])
1205 : privateVariables[i];
1207 moduleTranslation.mapValue(entry.getArgument(0), privateVarValue);
1209 if (failed(inlineConvertOmpRegions(*cleanupRegion, regionName, builder,
1210 moduleTranslation)))
1211 return failure();
1213 // clear block argument mapping in case it needs to be re-created with a
1214 // different source for another use of the same reduction decl
1215 moduleTranslation.forgetMapping(*cleanupRegion);
1217 return success();
1220 // TODO: not used by ParallelOp
1221 template <class OP>
1222 static LogicalResult createReductionsAndCleanup(
1223 OP op, llvm::IRBuilderBase &builder,
1224 LLVM::ModuleTranslation &moduleTranslation,
1225 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1226 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
1227 ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef) {
1228 // Process the reductions if required.
1229 if (op.getNumReductionVars() == 0)
1230 return success();
1232 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1234 // Create the reduction generators. We need to own them here because
1235 // ReductionInfo only accepts references to the generators.
1236 SmallVector<OwningReductionGen> owningReductionGens;
1237 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1238 SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1239 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1240 owningReductionGens, owningAtomicReductionGens,
1241 privateReductionVariables, reductionInfos);
1243 // The call to createReductions below expects the block to have a
1244 // terminator. Create an unreachable instruction to serve as terminator
1245 // and remove it later.
1246 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1247 builder.SetInsertPoint(tempTerminator);
1248 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1249 ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1250 isByRef, op.getNowait());
1252 if (failed(handleError(contInsertPoint, *op)))
1253 return failure();
1255 if (!contInsertPoint->getBlock())
1256 return op->emitOpError() << "failed to convert reductions";
1258 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1259 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1261 if (failed(handleError(afterIP, *op)))
1262 return failure();
1264 tempTerminator->eraseFromParent();
1265 builder.restoreIP(*afterIP);
1267 // after the construct, deallocate private reduction variables
1268 SmallVector<Region *> reductionRegions;
1269 llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1270 [](omp::DeclareReductionOp reductionDecl) {
1271 return &reductionDecl.getCleanupRegion();
1273 return inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
1274 moduleTranslation, builder,
1275 "omp.reduction.cleanup");
1276 return success();
1279 static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
1280 if (!attr)
1281 return {};
1282 return *attr;
1285 // TODO: not used by omp.parallel
1286 template <typename OP>
1287 static LogicalResult allocAndInitializeReductionVars(
1288 OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
1289 LLVM::ModuleTranslation &moduleTranslation,
1290 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1291 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
1292 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1293 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1294 llvm::ArrayRef<bool> isByRef) {
1295 if (op.getNumReductionVars() == 0)
1296 return success();
1298 SmallVector<DeferredStore> deferredStores;
1300 if (failed(allocReductionVars(op, reductionArgs, builder, moduleTranslation,
1301 allocaIP, reductionDecls,
1302 privateReductionVariables, reductionVariableMap,
1303 deferredStores, isByRef)))
1304 return failure();
1306 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1307 allocaIP.getBlock(), reductionDecls,
1308 privateReductionVariables, reductionVariableMap,
1309 isByRef, deferredStores);
1312 /// Return the llvm::Value * corresponding to the `privateVar` that
1313 /// is being privatized. It isn't always as simple as looking up
1314 /// moduleTranslation with privateVar. For instance, in case of
1315 /// an allocatable, the descriptor for the allocatable is privatized.
1316 /// This descriptor is mapped using an MapInfoOp. So, this function
1317 /// will return a pointer to the llvm::Value corresponding to the
1318 /// block argument for the mapped descriptor.
1319 static llvm::Value *
1320 findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1321 LLVM::ModuleTranslation &moduleTranslation,
1322 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1323 if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1324 return moduleTranslation.lookupValue(privateVar);
1326 Value blockArg = (*mappedPrivateVars)[privateVar];
1327 Type privVarType = privateVar.getType();
1328 Type blockArgType = blockArg.getType();
1329 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1330 "A block argument corresponding to a mapped var should have "
1331 "!llvm.ptr type");
1333 if (privVarType == blockArgType)
1334 return moduleTranslation.lookupValue(blockArg);
1336 // This typically happens when the privatized type is lowered from
1337 // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1338 // struct/pair is passed by value. But, mapped values are passed only as
1339 // pointers, so before we privatize, we must load the pointer.
1340 if (!isa<LLVM::LLVMPointerType>(privVarType))
1341 return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1342 moduleTranslation.lookupValue(blockArg));
1344 return moduleTranslation.lookupValue(privateVar);
1347 /// Allocate delayed private variables. Returns the basic block which comes
1348 /// after all of these allocations. llvm::Value * for each of these private
1349 /// variables are populated in llvmPrivateVars.
static llvm::Expected<llvm::BasicBlock *>
allocatePrivateVars(llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    MutableArrayRef<BlockArgument> privateBlockArgs,
                    MutableArrayRef<omp::PrivateClauseOp> privateDecls,
                    MutableArrayRef<mlir::Value> mlirPrivateVars,
                    llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
                    const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
                    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
  // Allocate private vars
  llvm::BranchInst *allocaTerminator =
      llvm::cast<llvm::BranchInst>(allocaIP.getBlock()->getTerminator());
  // Split off everything after the alloca insertion point into a new block so
  // that the alloca block stays dedicated to allocations.
  splitBB(llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
                                               allocaTerminator->getIterator()),
          true, "omp.region.after_alloca");

  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  // Update the allocaTerminator in case the alloca block was split above.
  allocaTerminator =
      llvm::cast<llvm::BranchInst>(allocaIP.getBlock()->getTerminator());
  builder.SetInsertPoint(allocaTerminator);
  assert(allocaTerminator->getNumSuccessors() == 1 &&
         "This is an unconditional branch created by OpenMPIRBuilder");

  // The block returned to the caller: execution continues here after all
  // private allocations are emitted.
  llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);

  // FIXME: Some of the allocation regions do more than just allocating.
  // They read from their block argument (amongst other non-alloca things).
  // When OpenMPIRBuilder outlines the parallel region into a different
  // function it places the loads for live in-values (such as these block
  // arguments) at the end of the entry block (because the entry block is
  // assumed to contain only allocas). Therefore, if we put these complicated
  // alloc blocks in the entry block, these will not dominate the availability
  // of the live-in values they are using. Fix this by adding a latealloc
  // block after the entry block to put these in (this also helps to avoid
  // mixing non-alloca code with allocas).
  // Alloc regions which do not use the block argument can still be placed in
  // the entry block (therefore keeping the allocas together).
  llvm::BasicBlock *privAllocBlock = nullptr;
  if (!privateBlockArgs.empty())
    privAllocBlock = splitBB(builder, true, "omp.private.latealloc");
  for (auto [privDecl, mlirPrivVar, blockArg] :
       llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs)) {
    Region &allocRegion = privDecl.getAllocRegion();

    // map allocation region block argument
    llvm::Value *nonPrivateVar = findAssociatedValue(
        mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
    assert(nonPrivateVar);
    moduleTranslation.mapValue(privDecl.getAllocMoldArg(), nonPrivateVar);

    // in-place convert the private allocation region
    SmallVector<llvm::Value *, 1> phis;
    if (privDecl.getAllocMoldArg().getUses().empty()) {
      // TODO this should use
      // allocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca() so it goes before
      // the code for fetching the thread id. Not doing this for now to avoid
      // test churn.
      builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
    } else {
      // The region reads its mold argument, so it must go in the latealloc
      // block (see the FIXME above).
      builder.SetInsertPoint(privAllocBlock->getTerminator());
    }

    if (failed(inlineConvertOmpRegions(allocRegion, "omp.private.alloc",
                                       builder, moduleTranslation, &phis)))
      return llvm::createStringError(
          "failed to inline `alloc` region of `omp.private`");

    assert(phis.size() == 1 && "expected one allocation to be yielded");

    moduleTranslation.mapValue(blockArg, phis[0]);
    llvmPrivateVars.push_back(phis[0]);

    // clear alloc region block argument mapping in case it needs to be
    // re-created with a different source for another use of the same
    // reduction decl
    moduleTranslation.forgetMapping(allocRegion);
  }
  return afterAllocas;
}
/// Inline the `copy` region of each firstprivate `omp.private` declaration,
/// copying the original value into its private allocation. No-op when no
/// declaration is firstprivate.
static LogicalResult
initFirstPrivateVars(llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation,
                     SmallVectorImpl<mlir::Value> &mlirPrivateVars,
                     SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
                     SmallVectorImpl<omp::PrivateClauseOp> &privateDecls,
                     llvm::BasicBlock *afterAllocas) {
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  // Apply copy region for firstprivate.
  bool needsFirstprivate =
      llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
        return privOp.getDataSharingType() ==
               omp::DataSharingClauseType::FirstPrivate;
      });

  if (!needsFirstprivate)
    return success();

  assert(afterAllocas->getSinglePredecessor());

  // Find the end of the allocation blocks
  builder.SetInsertPoint(afterAllocas->getSinglePredecessor()->getTerminator());
  llvm::BasicBlock *copyBlock =
      splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
  builder.SetInsertPoint(copyBlock->getFirstNonPHIOrDbgOrAlloca());

  for (auto [decl, mlirVar, llvmVar] :
       llvm::zip_equal(privateDecls, mlirPrivateVars, llvmPrivateVars)) {
    if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
      continue;

    // copyRegion implements `lhs = rhs`
    Region &copyRegion = decl.getCopyRegion();

    // map copyRegion rhs arg
    llvm::Value *nonPrivateVar = moduleTranslation.lookupValue(mlirVar);
    assert(nonPrivateVar);
    moduleTranslation.mapValue(decl.getCopyMoldArg(), nonPrivateVar);

    // map copyRegion lhs arg
    moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);

    // in-place convert copy region
    builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
    if (failed(inlineConvertOmpRegions(copyRegion, "omp.private.copy", builder,
                                       moduleTranslation)))
      return decl.emitError("failed to inline `copy` region of `omp.private`");

    // ignore unused value yielded from copy region

    // clear copy region block argument mapping in case it needs to be
    // re-created with different sources for reuse of the same reduction
    // decl
    moduleTranslation.forgetMapping(copyRegion);
  }

  return success();
}
1490 static LogicalResult
1491 cleanupPrivateVars(llvm::IRBuilderBase &builder,
1492 LLVM::ModuleTranslation &moduleTranslation, Location loc,
1493 SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
1494 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls) {
1495 // private variable deallocation
1496 SmallVector<Region *> privateCleanupRegions;
1497 llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
1498 [](omp::PrivateClauseOp privatizer) {
1499 return &privatizer.getDeallocRegion();
1502 if (failed(inlineOmpRegionCleanup(
1503 privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
1504 "omp.private.dealloc", /*shouldLoadCleanupRegionArg=*/false)))
1505 return mlir::emitError(loc, "failed to inline `dealloc` region of an "
1506 "`omp.private` op in");
1508 return success();
/// Converts an OpenMP 'sections' construct into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  using StorableBodyGenCallbackTy =
      llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;

  auto sectionsOp = cast<omp::SectionsOp>(opInst);

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  llvm::ArrayRef<bool> isByRef = getIsByRef(sectionsOp.getReductionByref());
  assert(isByRef.size() == sectionsOp.getNumReductionVars());

  SmallVector<omp::DeclareReductionOp> reductionDecls;
  collectReductionDecls(sectionsOp, reductionDecls);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  SmallVector<llvm::Value *> privateReductionVariables(
      sectionsOp.getNumReductionVars());
  DenseMap<Value, llvm::Value *> reductionVariableMap;

  MutableArrayRef<BlockArgument> reductionArgs =
      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();

  if (failed(allocAndInitializeReductionVars(
          sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
          reductionDecls, privateReductionVariables, reductionVariableMap,
          isByRef)))
    return failure();

  // Store the mapping between reduction variables and their private copies on
  // ModuleTranslation stack. It can be then recovered when translating
  // omp.reduce operations in a separate call.
  LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
      moduleTranslation, reductionVariableMap);

  // One body-generation callback per omp.section op in the construct.
  SmallVector<StorableBodyGenCallbackTy> sectionCBs;

  for (Operation &op : *sectionsOp.getRegion().begin()) {
    auto sectionOp = dyn_cast<omp::SectionOp>(op);
    if (!sectionOp) // omp.terminator
      continue;

    Region &region = sectionOp.getRegion();
    // NOTE: the callback captures `region` and `sectionsOp` by reference;
    // both outlive the createSections call below.
    auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
                         InsertPointTy allocaIP, InsertPointTy codeGenIP) {
      builder.restoreIP(codeGenIP);

      // map the omp.section reduction block argument to the omp.sections block
      // arguments
      // TODO: this assumes that the only block arguments are reduction
      // variables
      assert(region.getNumArguments() ==
             sectionsOp.getRegion().getNumArguments());
      for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
               sectionsOp.getRegion().getArguments(), region.getArguments())) {
        llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
        assert(llvmVal);
        moduleTranslation.mapValue(sectionArg, llvmVal);
      }

      return convertOmpOpRegions(region, "omp.section.region", builder,
                                 moduleTranslation)
          .takeError();
    };
    sectionCBs.push_back(sectionCB);
  }

  // No sections within omp.sections operation - skip generation. This situation
  // is only possible if there is only a terminator operation inside the
  // sections operation
  if (sectionCBs.empty())
    return success();

  assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));

  // TODO: Perform appropriate actions according to the data-sharing
  // attribute (shared, private, firstprivate, ...) of variables.
  // Currently defaults to shared.
  auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
                    llvm::Value &vPtr, llvm::Value *&replacementValue)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    replacementValue = &vPtr;
    return codeGenIP;
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };

  allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createSections(
          ompLoc, allocaIP, sectionCBs, privCB, finiCB, false,
          sectionsOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  // Process the reductions if required.
  return createReductionsAndCleanup(sectionsOp, builder, moduleTranslation,
                                    allocaIP, reductionDecls,
                                    privateReductionVariables, isByRef);
}
1622 /// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
1623 static LogicalResult
1624 convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
1625 LLVM::ModuleTranslation &moduleTranslation) {
1626 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1627 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1629 if (failed(checkImplementationStatus(*singleOp)))
1630 return failure();
1632 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1633 builder.restoreIP(codegenIP);
1634 return convertOmpOpRegions(singleOp.getRegion(), "omp.single.region",
1635 builder, moduleTranslation)
1636 .takeError();
1638 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1640 // Handle copyprivate
1641 Operation::operand_range cpVars = singleOp.getCopyprivateVars();
1642 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1643 llvm::SmallVector<llvm::Value *> llvmCPVars;
1644 llvm::SmallVector<llvm::Function *> llvmCPFuncs;
1645 for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
1646 llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
1647 auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
1648 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1649 llvmCPFuncs.push_back(
1650 moduleTranslation.lookupFunction(llvmFuncOp.getName()));
1653 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1654 moduleTranslation.getOpenMPBuilder()->createSingle(
1655 ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars,
1656 llvmCPFuncs);
1658 if (failed(handleError(afterIP, *singleOp)))
1659 return failure();
1661 builder.restoreIP(*afterIP);
1662 return success();
1665 // Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
1666 static LogicalResult
1667 convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
1668 LLVM::ModuleTranslation &moduleTranslation) {
1669 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1670 if (failed(checkImplementationStatus(*op)))
1671 return failure();
1673 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1674 LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
1675 moduleTranslation, allocaIP);
1676 builder.restoreIP(codegenIP);
1677 return convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
1678 moduleTranslation)
1679 .takeError();
1682 llvm::Value *numTeamsLower = nullptr;
1683 if (Value numTeamsLowerVar = op.getNumTeamsLower())
1684 numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
1686 llvm::Value *numTeamsUpper = nullptr;
1687 if (Value numTeamsUpperVar = op.getNumTeamsUpper())
1688 numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
1690 llvm::Value *threadLimit = nullptr;
1691 if (Value threadLimitVar = op.getThreadLimit())
1692 threadLimit = moduleTranslation.lookupValue(threadLimitVar);
1694 llvm::Value *ifExpr = nullptr;
1695 if (Value ifVar = op.getIfExpr())
1696 ifExpr = moduleTranslation.lookupValue(ifVar);
1698 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1699 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1700 moduleTranslation.getOpenMPBuilder()->createTeams(
1701 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
1703 if (failed(handleError(afterIP, *op)))
1704 return failure();
1706 builder.restoreIP(*afterIP);
1707 return success();
1710 static void
1711 buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
1712 LLVM::ModuleTranslation &moduleTranslation,
1713 SmallVectorImpl<llvm::OpenMPIRBuilder::DependData> &dds) {
1714 if (dependVars.empty())
1715 return;
1716 for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
1717 llvm::omp::RTLDependenceKindTy type;
1718 switch (
1719 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
1720 case mlir::omp::ClauseTaskDepend::taskdependin:
1721 type = llvm::omp::RTLDependenceKindTy::DepIn;
1722 break;
1723 // The OpenMP runtime requires that the codegen for 'depend' clause for
1724 // 'out' dependency kind must be the same as codegen for 'depend' clause
1725 // with 'inout' dependency.
1726 case mlir::omp::ClauseTaskDepend::taskdependout:
1727 case mlir::omp::ClauseTaskDepend::taskdependinout:
1728 type = llvm::omp::RTLDependenceKindTy::DepInOut;
1729 break;
1730 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
1731 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
1732 break;
1733 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
1734 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
1735 break;
1737 llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
1738 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
1739 dds.emplace_back(dd);
/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
///
/// Privatization for the task is handled inside the body callback so that the
/// private copies live in the outlined task function; dependences are lowered
/// separately via buildDependData and forwarded to createTask.
static LogicalResult
convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  if (failed(checkImplementationStatus(*taskOp)))
    return failure();

  // Collect delayed privatisation declarations
  MutableArrayRef<BlockArgument> privateBlockArgs =
      cast<omp::BlockArgOpenMPOpInterface>(*taskOp).getPrivateBlockArgs();
  SmallVector<mlir::Value> mlirPrivateVars;
  SmallVector<llvm::Value *> llvmPrivateVars;
  SmallVector<omp::PrivateClauseOp> privateDecls;
  mlirPrivateVars.reserve(privateBlockArgs.size());
  llvmPrivateVars.reserve(privateBlockArgs.size());
  collectPrivatizationDecls(taskOp, privateDecls);
  for (mlir::Value privateVar : taskOp.getPrivateVars())
    mlirPrivateVars.push_back(privateVar);

  // Body callback invoked by the OpenMPIRBuilder inside the outlined task.
  // Errors discovered here are reported immediately via handleError and then
  // propagated as PreviouslyReportedError so they are not reported twice.
  auto bodyCB = [&](InsertPointTy allocaIP,
                    InsertPointTy codegenIP) -> llvm::Error {
    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);

    // Allocate (and later initialize) the private copies before translating
    // the task body, so the body sees the private values.
    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
        builder, moduleTranslation, privateBlockArgs, privateDecls,
        mlirPrivateVars, llvmPrivateVars, allocaIP);
    if (handleError(afterAllocas, *taskOp).failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(initFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
                                    llvmPrivateVars, privateDecls,
                                    afterAllocas.get())))
      return llvm::make_error<PreviouslyReportedError>();

    // translate the body of the task:
    builder.restoreIP(codegenIP);
    auto continuationBlockOrError = convertOmpOpRegions(
        taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
    if (failed(handleError(continuationBlockOrError, *taskOp)))
      return llvm::make_error<PreviouslyReportedError>();

    // Emit private-variable cleanup just before the continuation block's
    // terminator, i.e. at the end of the task body.
    builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());

    if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
                                  llvmPrivateVars, privateDecls)))
      return llvm::make_error<PreviouslyReportedError>();

    return llvm::Error::success();
  };

  // Lower the `depend` clause operands into the form expected by createTask.
  SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
  buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
                  moduleTranslation, dds);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  // Note: lookupValue on an absent clause operand yields null, which
  // createTask interprets as "clause not present".
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTask(
          ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
          moduleTranslation.lookupValue(taskOp.getFinal()),
          moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
          taskOp.getMergeable(),
          moduleTranslation.lookupValue(taskOp.getEventHandle()),
          moduleTranslation.lookupValue(taskOp.getPriority()));

  if (failed(handleError(afterIP, *taskOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
1820 /// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
1821 static LogicalResult
1822 convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
1823 LLVM::ModuleTranslation &moduleTranslation) {
1824 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1825 if (failed(checkImplementationStatus(*tgOp)))
1826 return failure();
1828 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1829 builder.restoreIP(codegenIP);
1830 return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
1831 builder, moduleTranslation)
1832 .takeError();
1835 InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
1836 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1837 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1838 moduleTranslation.getOpenMPBuilder()->createTaskgroup(ompLoc, allocaIP,
1839 bodyCB);
1841 if (failed(handleError(afterIP, *tgOp)))
1842 return failure();
1844 builder.restoreIP(*afterIP);
1845 return success();
1848 static LogicalResult
1849 convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
1850 LLVM::ModuleTranslation &moduleTranslation) {
1851 if (failed(checkImplementationStatus(*twOp)))
1852 return failure();
1854 moduleTranslation.getOpenMPBuilder()->createTaskwait(builder.saveIP());
1855 return success();
1858 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
1859 static LogicalResult
1860 convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
1861 LLVM::ModuleTranslation &moduleTranslation) {
1862 auto wsloopOp = cast<omp::WsloopOp>(opInst);
1863 if (failed(checkImplementationStatus(opInst)))
1864 return failure();
1866 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
1867 llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
1868 assert(isByRef.size() == wsloopOp.getNumReductionVars());
1870 // Static is the default.
1871 auto schedule =
1872 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
1874 // Find the loop configuration.
1875 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[0]);
1876 llvm::Type *ivType = step->getType();
1877 llvm::Value *chunk = nullptr;
1878 if (wsloopOp.getScheduleChunk()) {
1879 llvm::Value *chunkVar =
1880 moduleTranslation.lookupValue(wsloopOp.getScheduleChunk());
1881 chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
1884 MutableArrayRef<BlockArgument> privateBlockArgs =
1885 cast<omp::BlockArgOpenMPOpInterface>(*wsloopOp).getPrivateBlockArgs();
1886 SmallVector<mlir::Value> mlirPrivateVars;
1887 SmallVector<llvm::Value *> llvmPrivateVars;
1888 SmallVector<omp::PrivateClauseOp> privateDecls;
1889 mlirPrivateVars.reserve(privateBlockArgs.size());
1890 llvmPrivateVars.reserve(privateBlockArgs.size());
1891 collectPrivatizationDecls(wsloopOp, privateDecls);
1893 for (mlir::Value privateVar : wsloopOp.getPrivateVars())
1894 mlirPrivateVars.push_back(privateVar);
1896 SmallVector<omp::DeclareReductionOp> reductionDecls;
1897 collectReductionDecls(wsloopOp, reductionDecls);
1898 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1899 findAllocaInsertPoint(builder, moduleTranslation);
1901 SmallVector<llvm::Value *> privateReductionVariables(
1902 wsloopOp.getNumReductionVars());
1904 llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
1905 builder, moduleTranslation, privateBlockArgs, privateDecls,
1906 mlirPrivateVars, llvmPrivateVars, allocaIP);
1907 if (handleError(afterAllocas, opInst).failed())
1908 return failure();
1910 DenseMap<Value, llvm::Value *> reductionVariableMap;
1912 MutableArrayRef<BlockArgument> reductionArgs =
1913 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1915 SmallVector<DeferredStore> deferredStores;
1917 if (failed(allocReductionVars(wsloopOp, reductionArgs, builder,
1918 moduleTranslation, allocaIP, reductionDecls,
1919 privateReductionVariables, reductionVariableMap,
1920 deferredStores, isByRef)))
1921 return failure();
1923 if (failed(initFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
1924 llvmPrivateVars, privateDecls,
1925 afterAllocas.get())))
1926 return failure();
1928 assert(afterAllocas.get()->getSinglePredecessor());
1929 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
1930 moduleTranslation,
1931 afterAllocas.get()->getSinglePredecessor(),
1932 reductionDecls, privateReductionVariables,
1933 reductionVariableMap, isByRef, deferredStores)))
1934 return failure();
1936 // TODO: Replace this with proper composite translation support.
1937 // Currently, all nested wrappers are ignored, so 'do/for simd' will be
1938 // treated the same as a standalone 'do/for'. This is allowed by the spec,
1939 // since it's equivalent to always using a SIMD length of 1.
1940 if (failed(convertIgnoredWrappers(loopOp, wsloopOp, moduleTranslation)))
1941 return failure();
1943 // Store the mapping between reduction variables and their private copies on
1944 // ModuleTranslation stack. It can be then recovered when translating
1945 // omp.reduce operations in a separate call.
1946 LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
1947 moduleTranslation, reductionVariableMap);
1949 // Set up the source location value for OpenMP runtime.
1950 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1952 // Generator of the canonical loop body.
1953 SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
1954 SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
1955 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
1956 llvm::Value *iv) -> llvm::Error {
1957 // Make sure further conversions know about the induction variable.
1958 moduleTranslation.mapValue(
1959 loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
1961 // Capture the body insertion point for use in nested loops. BodyIP of the
1962 // CanonicalLoopInfo always points to the beginning of the entry block of
1963 // the body.
1964 bodyInsertPoints.push_back(ip);
1966 if (loopInfos.size() != loopOp.getNumLoops() - 1)
1967 return llvm::Error::success();
1969 // Convert the body of the loop.
1970 builder.restoreIP(ip);
1971 return convertOmpOpRegions(loopOp.getRegion(), "omp.wsloop.region", builder,
1972 moduleTranslation)
1973 .takeError();
1976 // Delegate actual loop construction to the OpenMP IRBuilder.
1977 // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
1978 // loop, i.e. it has a positive step, uses signed integer semantics.
1979 // Reconsider this code when the nested loop operation clearly supports more
1980 // cases.
1981 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1982 for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
1983 llvm::Value *lowerBound =
1984 moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
1985 llvm::Value *upperBound =
1986 moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
1987 llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
1989 // Make sure loop trip count are emitted in the preheader of the outermost
1990 // loop at the latest so that they are all available for the new collapsed
1991 // loop will be created below.
1992 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
1993 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
1994 if (i != 0) {
1995 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back());
1996 computeIP = loopInfos.front()->getPreheaderIP();
1999 llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
2000 ompBuilder->createCanonicalLoop(
2001 loc, bodyGen, lowerBound, upperBound, step,
2002 /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
2004 if (failed(handleError(loopResult, *loopOp)))
2005 return failure();
2007 loopInfos.push_back(*loopResult);
2010 // Collapse loops. Store the insertion point because LoopInfos may get
2011 // invalidated.
2012 llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
2013 llvm::CanonicalLoopInfo *loopInfo =
2014 ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
2016 allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
2018 // TODO: Handle doacross loops when the ordered clause has a parameter.
2019 bool isOrdered = wsloopOp.getOrdered().has_value();
2020 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
2021 bool isSimd = wsloopOp.getScheduleSimd();
2023 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2024 ompBuilder->applyWorkshareLoop(
2025 ompLoc.DL, loopInfo, allocaIP, !wsloopOp.getNowait(),
2026 convertToScheduleKind(schedule), chunk, isSimd,
2027 scheduleMod == omp::ScheduleModifier::monotonic,
2028 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered);
2030 if (failed(handleError(wsloopIP, opInst)))
2031 return failure();
2033 // Continue building IR after the loop. Note that the LoopInfo returned by
2034 // `collapseLoops` points inside the outermost loop and is intended for
2035 // potential further loop transformations. Use the insertion point stored
2036 // before collapsing loops instead.
2037 builder.restoreIP(afterIP);
2039 // Process the reductions if required.
2040 if (failed(createReductionsAndCleanup(wsloopOp, builder, moduleTranslation,
2041 allocaIP, reductionDecls,
2042 privateReductionVariables, isByRef)))
2043 return failure();
2045 return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
2046 llvmPrivateVars, privateDecls);
/// Converts the OpenMP parallel operation to LLVM IR.
///
/// Privatization and reductions are materialized inside bodyGenCB (so they
/// live in the outlined parallel function); privCB is therefore a no-op, and
/// finiCB inlines reduction cleanup regions and private-variable cleanup.
static LogicalResult
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
  assert(isByRef.size() == opInst.getNumReductionVars());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  // Collect delayed privatization declarations
  MutableArrayRef<BlockArgument> privateBlockArgs =
      cast<omp::BlockArgOpenMPOpInterface>(*opInst).getPrivateBlockArgs();
  SmallVector<mlir::Value> mlirPrivateVars;
  SmallVector<llvm::Value *> llvmPrivateVars;
  SmallVector<omp::PrivateClauseOp> privateDecls;
  mlirPrivateVars.reserve(privateBlockArgs.size());
  llvmPrivateVars.reserve(privateBlockArgs.size());
  collectPrivatizationDecls(opInst, privateDecls);
  for (mlir::Value privateVar : opInst.getPrivateVars())
    mlirPrivateVars.push_back(privateVar);

  // Collect reduction declarations
  SmallVector<omp::DeclareReductionOp> reductionDecls;
  collectReductionDecls(opInst, reductionDecls);
  SmallVector<llvm::Value *> privateReductionVariables(
      opInst.getNumReductionVars());
  SmallVector<DeferredStore> deferredStores;

  // Body callback: allocate/init privates and reductions, translate the
  // region, then emit the reduction combination at the continuation point.
  auto bodyGenCB = [&](InsertPointTy allocaIP,
                       InsertPointTy codeGenIP) -> llvm::Error {
    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
        builder, moduleTranslation, privateBlockArgs, privateDecls,
        mlirPrivateVars, llvmPrivateVars, allocaIP);
    if (handleError(afterAllocas, *opInst).failed())
      return llvm::make_error<PreviouslyReportedError>();

    // Allocate reduction vars
    DenseMap<Value, llvm::Value *> reductionVariableMap;

    MutableArrayRef<BlockArgument> reductionArgs =
        cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();

    // Re-anchor the alloca insertion point at the block's terminator so the
    // reduction allocas land after the private-variable allocas.
    allocaIP =
        InsertPointTy(allocaIP.getBlock(),
                      allocaIP.getBlock()->getTerminator()->getIterator());

    if (failed(allocReductionVars(
            opInst, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,
            deferredStores, isByRef)))
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(initFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
                                    llvmPrivateVars, privateDecls,
                                    afterAllocas.get())))
      return llvm::make_error<PreviouslyReportedError>();

    assert(afterAllocas.get()->getSinglePredecessor());
    builder.restoreIP(codeGenIP);

    if (failed(
            initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
                              afterAllocas.get()->getSinglePredecessor(),
                              reductionDecls, privateReductionVariables,
                              reductionVariableMap, isByRef, deferredStores)))
      return llvm::make_error<PreviouslyReportedError>();

    // Store the mapping between reduction variables and their private copies on
    // ModuleTranslation stack. It can be then recovered when translating
    // omp.reduce operations in a separate call.
    LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
        moduleTranslation, reductionVariableMap);

    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);

    // ParallelOp has only one region associated with it.
    llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
        opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    // Process the reductions if required.
    if (opInst.getNumReductionVars() > 0) {
      // Collect reduction info
      SmallVector<OwningReductionGen> owningReductionGens;
      SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
      collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
                           owningReductionGens, owningAtomicReductionGens,
                           privateReductionVariables, reductionInfos);

      // Move to region cont block
      builder.SetInsertPoint((*regionBlock)->getTerminator());

      // Generate reductions from info
      // A temporary unreachable terminator anchors the insertion point while
      // createReductions splits/extends the block; it is erased afterwards.
      llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
      builder.SetInsertPoint(tempTerminator);

      llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
          ompBuilder->createReductions(builder.saveIP(), allocaIP,
                                       reductionInfos, isByRef, false);
      if (!contInsertPoint)
        return contInsertPoint.takeError();

      if (!contInsertPoint->getBlock())
        return llvm::make_error<PreviouslyReportedError>();

      tempTerminator->eraseFromParent();
      builder.restoreIP(*contInsertPoint);
    }
    return llvm::Error::success();
  };

  auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
                   llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
    // tell OpenMPIRBuilder not to do anything. We handled Privatisation in
    // bodyGenCB.
    replVal = &val;
    return codeGenIP;
  };

  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
    InsertPointTy oldIP = builder.saveIP();
    builder.restoreIP(codeGenIP);

    // if the reduction has a cleanup region, inline it here to finalize the
    // reduction variables
    SmallVector<Region *> reductionCleanupRegions;
    llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
                    [](omp::DeclareReductionOp reductionDecl) {
                      return &reductionDecl.getCleanupRegion();
                    });
    if (failed(inlineOmpRegionCleanup(
            reductionCleanupRegions, privateReductionVariables,
            moduleTranslation, builder, "omp.reduction.cleanup")))
      return llvm::createStringError(
          "failed to inline `cleanup` region of `omp.declare_reduction`");

    if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
                                  llvmPrivateVars, privateDecls)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.restoreIP(oldIP);
    return llvm::Error::success();
  };

  // Translate clause operands; absent clauses become null / defaults.
  llvm::Value *ifCond = nullptr;
  if (auto ifVar = opInst.getIfExpr())
    ifCond = moduleTranslation.lookupValue(ifVar);
  llvm::Value *numThreads = nullptr;
  if (auto numThreadsVar = opInst.getNumThreads())
    numThreads = moduleTranslation.lookupValue(numThreadsVar);
  auto pbKind = llvm::omp::OMP_PROC_BIND_default;
  if (auto bind = opInst.getProcBindKind())
    pbKind = getProcBindKind(*bind);
  // TODO: Is the Parallel construct cancellable?
  bool isCancellable = false;

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
                                 ifCond, numThreads, pbKind, isCancellable);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
2230 /// Convert Order attribute to llvm::omp::OrderKind.
2231 static llvm::omp::OrderKind
2232 convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
2233 if (!o)
2234 return llvm::omp::OrderKind::OMP_ORDER_unknown;
2235 switch (*o) {
2236 case omp::ClauseOrderKind::Concurrent:
2237 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
2239 llvm_unreachable("Unknown ClauseOrderKind kind");
/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
///
/// Builds the canonical loop nest for the wrapped omp.loop_nest, collapses
/// it, and attaches SIMD metadata (simdlen, safelen, order, aligned, if) via
/// OpenMPIRBuilder::applySimd.
static LogicalResult
convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
               LLVM::ModuleTranslation &moduleTranslation) {
  auto simdOp = cast<omp::SimdOp>(opInst);
  auto loopOp = cast<omp::LoopNestOp>(simdOp.getWrappedLoop());

  if (failed(checkImplementationStatus(opInst)))
    return failure();

  // Collect delayed privatization declarations.
  MutableArrayRef<BlockArgument> privateBlockArgs =
      cast<omp::BlockArgOpenMPOpInterface>(*simdOp).getPrivateBlockArgs();
  SmallVector<mlir::Value> mlirPrivateVars;
  SmallVector<llvm::Value *> llvmPrivateVars;
  SmallVector<omp::PrivateClauseOp> privateDecls;
  mlirPrivateVars.reserve(privateBlockArgs.size());
  llvmPrivateVars.reserve(privateBlockArgs.size());
  collectPrivatizationDecls(simdOp, privateDecls);

  for (mlir::Value privateVar : simdOp.getPrivateVars())
    mlirPrivateVars.push_back(privateVar);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Allocate private copies before translating the loop body.
  llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
      builder, moduleTranslation, privateBlockArgs, privateDecls,
      mlirPrivateVars, llvmPrivateVars, allocaIP);
  if (handleError(afterAllocas, opInst).failed())
    return failure();

  // Generator of the canonical loop body.
  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
  SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
                     llvm::Value *iv) -> llvm::Error {
    // Make sure further conversions know about the induction variable.
    moduleTranslation.mapValue(
        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);

    // Capture the body insertion point for use in nested loops. BodyIP of the
    // CanonicalLoopInfo always points to the beginning of the entry block of
    // the body.
    bodyInsertPoints.push_back(ip);

    // Only translate the region once the innermost loop has been reached.
    if (loopInfos.size() != loopOp.getNumLoops() - 1)
      return llvm::Error::success();

    // Convert the body of the loop.
    builder.restoreIP(ip);
    return convertOmpOpRegions(loopOp.getRegion(), "omp.simd.region", builder,
                               moduleTranslation)
        .takeError();
  };

  // Delegate actual loop construction to the OpenMP IRBuilder.
  // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
  // loop, i.e. it has a positive step, uses signed integer semantics.
  // Reconsider this code when the nested loop operation clearly supports more
  // cases.
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);

    // Make sure loop trip count are emitted in the preheader of the outermost
    // loop at the latest so that they are all available for the new collapsed
    // loop will be created below.
    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
    if (i != 0) {
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
                                                       ompLoc.DL);
      computeIP = loopInfos.front()->getPreheaderIP();
    }

    llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
        ompBuilder->createCanonicalLoop(
            loc, bodyGen, lowerBound, upperBound, step,
            /*IsSigned=*/true, /*InclusiveStop=*/true, computeIP);

    if (failed(handleError(loopResult, *loopOp)))
      return failure();

    loopInfos.push_back(*loopResult);
  }

  // Collapse loops.
  llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
  llvm::CanonicalLoopInfo *loopInfo =
      ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});

  llvm::ConstantInt *simdlen = nullptr;
  if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
    simdlen = builder.getInt64(simdlenVar.value());

  llvm::ConstantInt *safelen = nullptr;
  if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
    safelen = builder.getInt64(safelenVar.value());

  // Translate the aligned clause: each aligned pointer is loaded in the
  // block preceding the loop so its alignment assumption dominates the loop.
  llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
  llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
  std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
  mlir::OperandRange operands = simdOp.getAlignedVars();
  for (size_t i = 0; i < operands.size(); ++i) {
    llvm::Value *alignment = nullptr;
    llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
    llvm::Type *ty = llvmVal->getType();
    if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
      alignment = builder.getInt64(intAttr.getInt());
      assert(ty->isPointerTy() && "Invalid type for aligned variable");
      assert(alignment && "Invalid alignment value");
      auto curInsert = builder.saveIP();
      builder.SetInsertPoint(sourceBlock->getTerminator());
      llvmVal = builder.CreateLoad(ty, llvmVal);
      builder.restoreIP(curInsert);
      alignedVars[llvmVal] = alignment;
    }
  }
  ompBuilder->applySimd(loopInfo, alignedVars,
                        simdOp.getIfExpr()
                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
                            : nullptr,
                        order, simdlen, safelen);

  builder.restoreIP(afterIP);

  return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
                            llvmPrivateVars, privateDecls);
}
2378 /// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
2379 static llvm::AtomicOrdering
2380 convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
2381 if (!ao)
2382 return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
2384 switch (*ao) {
2385 case omp::ClauseMemoryOrderKind::Seq_cst:
2386 return llvm::AtomicOrdering::SequentiallyConsistent;
2387 case omp::ClauseMemoryOrderKind::Acq_rel:
2388 return llvm::AtomicOrdering::AcquireRelease;
2389 case omp::ClauseMemoryOrderKind::Acquire:
2390 return llvm::AtomicOrdering::Acquire;
2391 case omp::ClauseMemoryOrderKind::Release:
2392 return llvm::AtomicOrdering::Release;
2393 case omp::ClauseMemoryOrderKind::Relaxed:
2394 return llvm::AtomicOrdering::Monotonic;
2396 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
2399 /// Convert omp.atomic.read operation to LLVM IR.
2400 static LogicalResult
2401 convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
2402 LLVM::ModuleTranslation &moduleTranslation) {
2403 auto readOp = cast<omp::AtomicReadOp>(opInst);
2404 if (failed(checkImplementationStatus(opInst)))
2405 return failure();
2407 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2409 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2411 llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrder());
2412 llvm::Value *x = moduleTranslation.lookupValue(readOp.getX());
2413 llvm::Value *v = moduleTranslation.lookupValue(readOp.getV());
2415 llvm::Type *elementType =
2416 moduleTranslation.convertType(readOp.getElementType());
2418 llvm::OpenMPIRBuilder::AtomicOpValue V = {v, elementType, false, false};
2419 llvm::OpenMPIRBuilder::AtomicOpValue X = {x, elementType, false, false};
2420 builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO));
2421 return success();
2424 /// Converts an omp.atomic.write operation to LLVM IR.
2425 static LogicalResult
2426 convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
2427 LLVM::ModuleTranslation &moduleTranslation) {
2428 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
2429 if (failed(checkImplementationStatus(opInst)))
2430 return failure();
2432 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2434 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2435 llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
2436 llvm::Value *expr = moduleTranslation.lookupValue(writeOp.getExpr());
2437 llvm::Value *dest = moduleTranslation.lookupValue(writeOp.getX());
2438 llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
2439 llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
2440 /*isVolatile=*/false};
2441 builder.restoreIP(ompBuilder->createAtomicWrite(ompLoc, x, expr, ao));
2442 return success();
2445 /// Converts an LLVM dialect binary operation to the corresponding enum value
2446 /// for `atomicrmw` supported binary operation.
2447 llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
2448 return llvm::TypeSwitch<Operation *, llvm::AtomicRMWInst::BinOp>(&op)
2449 .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
2450 .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
2451 .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
2452 .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
2453 .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
2454 .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
2455 .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
2456 .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
2457 .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
2458 .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
/// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
///
/// When the update region contains exactly one computation op (plus the
/// terminator) that can be mapped to an `atomicrmw` binop, a single atomic
/// instruction is emitted; otherwise OpenMPIRBuilder falls back to a
/// compare-exchange loop driven by the `updateFn` callback below.
static LogicalResult
convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
                       llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  // Convert values and types.
  auto &innerOpList = opInst.getRegion().front().getOperations();
  bool isXBinopExpr{false};
  llvm::AtomicRMWInst::BinOp binop;
  mlir::Value mlirExpr;
  llvm::Value *llvmExpr = nullptr;
  llvm::Value *llvmX = nullptr;
  llvm::Type *llvmXElementType = nullptr;
  if (innerOpList.size() == 2) {
    // The two operations here are the update and the terminator.
    // Since we can identify the update operation, there is a possibility
    // that we can generate the atomicrmw instruction.
    mlir::Operation &innerOp = *opInst.getRegion().front().begin();
    if (!llvm::is_contained(innerOp.getOperands(),
                            opInst.getRegion().getArgument(0))) {
      return opInst.emitError("no atomic update operation with region argument"
                              " as operand found inside atomic.update region");
    }
    binop = convertBinOpToAtomic(innerOp);
    // Determine whether the region argument is the LHS (x binop expr) or the
    // RHS (expr binop x) of the update; the other operand is the expression.
    isXBinopExpr = innerOp.getOperand(0) == opInst.getRegion().getArgument(0);
    mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  } else {
    // Since the update region includes more than one operation
    // we will resort to generating a cmpxchg loop.
    binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
  }

  llvmX = moduleTranslation.lookupValue(opInst.getX());
  llvmXElementType = moduleTranslation.convertType(
      opInst.getRegion().getArgument(0).getType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(opInst.getMemoryOrder());

  // Generate update code: inline the MLIR update region into the IR position
  // chosen by OpenMPIRBuilder, mapping the region argument to the current
  // atomic value, and yield the value produced by the region's terminator.
  auto updateFn =
      [&opInst, &moduleTranslation](
          llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    Block &bb = *opInst.getRegion().begin();
    moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
                                     atomicOrdering, binop, updateFn,
                                     isXBinopExpr);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
/// Converts an omp.atomic.capture operation to LLVM IR.
///
/// The capture region pairs an atomic read with either an atomic update or an
/// atomic write. Whether the captured value is taken before or after the
/// modification (prefix vs. postfix) depends on the order of the two inner
/// ops, which is forwarded to OpenMPIRBuilder via `isPostfixUpdate`.
static LogicalResult
convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  if (failed(checkImplementationStatus(*atomicCaptureOp)))
    return failure();

  mlir::Value mlirExpr;
  bool isXBinopExpr = false, isPostfixUpdate = false;
  llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;

  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();

  assert((atomicUpdateOp || atomicWriteOp) &&
         "internal op must be an atomic.update or atomic.write op");

  if (atomicWriteOp) {
    // A write always overwrites x, so the capture is necessarily postfix:
    // v receives the old value of x.
    isPostfixUpdate = true;
    mlirExpr = atomicWriteOp.getExpr();
  } else {
    // Postfix iff the update is the second op, i.e. the read happens first.
    isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
                      atomicCaptureOp.getAtomicUpdateOp().getOperation();
    auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
    // Find the binary update operation that uses the region argument
    // and get the expression to update
    if (innerOpList.size() == 2) {
      mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
      if (!llvm::is_contained(innerOp.getOperands(),
                              atomicUpdateOp.getRegion().getArgument(0))) {
        return atomicUpdateOp.emitError(
            "no atomic update operation with region argument"
            " as operand found inside atomic.update region");
      }
      binop = convertBinOpToAtomic(innerOp);
      isXBinopExpr =
          innerOp.getOperand(0) == atomicUpdateOp.getRegion().getArgument(0);
      mlirExpr = (isXBinopExpr ? innerOp.getOperand(1) : innerOp.getOperand(0));
    } else {
      // More than one computation op: fall back to a cmpxchg loop.
      binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
    }
  }

  llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
  llvm::Value *llvmX =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getX());
  llvm::Value *llvmV =
      moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().getV());
  llvm::Type *llvmXElementType = moduleTranslation.convertType(
      atomicCaptureOp.getAtomicReadOp().getElementType());
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};
  llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
                                                      /*isSigned=*/false,
                                                      /*isVolatile=*/false};

  llvm::AtomicOrdering atomicOrdering =
      convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());

  // Callback that materializes the new value of x: for a write it is simply
  // the written expression; for an update the region is inlined with the
  // region argument bound to the current atomic value.
  auto updateFn =
      [&](llvm::Value *atomicx,
          llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
    if (atomicWriteOp)
      return moduleTranslation.lookupValue(atomicWriteOp.getExpr());
    Block &bb = *atomicUpdateOp.getRegion().begin();
    moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
                               atomicx);
    moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
    if (failed(moduleTranslation.convertBlock(bb, true, builder)))
      return llvm::make_error<PreviouslyReportedError>();

    omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
    assert(yieldop && yieldop.getResults().size() == 1 &&
           "terminator must be omp.yield op and it must have exactly one "
           "argument");
    return moduleTranslation.lookupValue(yieldop.getResults()[0]);
  };

  // Handle ambiguous alloca, if any.
  auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createAtomicCapture(
          ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
          binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr);

  if (failed(handleError(afterIP, *atomicCaptureOp)))
    return failure();

  builder.restoreIP(*afterIP);
  return success();
}
2635 /// Converts an OpenMP Threadprivate operation into LLVM IR using
2636 /// OpenMPIRBuilder.
2637 static LogicalResult
2638 convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
2639 LLVM::ModuleTranslation &moduleTranslation) {
2640 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2641 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2642 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
2644 if (failed(checkImplementationStatus(opInst)))
2645 return failure();
2647 Value symAddr = threadprivateOp.getSymAddr();
2648 auto *symOp = symAddr.getDefiningOp();
2650 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
2651 symOp = asCast.getOperand().getDefiningOp();
2653 if (!isa<LLVM::AddressOfOp>(symOp))
2654 return opInst.emitError("Addressing symbol not found");
2655 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
2657 LLVM::GlobalOp global =
2658 addressOfOp.getGlobal(moduleTranslation.symbolTable());
2659 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
2661 if (!ompBuilder->Config.isTargetDevice()) {
2662 llvm::Type *type = globalValue->getValueType();
2663 llvm::TypeSize typeSize =
2664 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
2665 type);
2666 llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
2667 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
2668 ompLoc, globalValue, size, global.getSymName() + ".cache");
2669 moduleTranslation.mapValue(opInst.getResult(0), callInst);
2670 } else {
2671 moduleTranslation.mapValue(opInst.getResult(0), globalValue);
2674 return success();
2677 static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
2678 convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
2679 switch (deviceClause) {
2680 case mlir::omp::DeclareTargetDeviceType::host:
2681 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
2682 break;
2683 case mlir::omp::DeclareTargetDeviceType::nohost:
2684 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
2685 break;
2686 case mlir::omp::DeclareTargetDeviceType::any:
2687 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
2688 break;
2690 llvm_unreachable("unhandled device clause");
2693 static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
2694 convertToCaptureClauseKind(
2695 mlir::omp::DeclareTargetCaptureClause captureClause) {
2696 switch (captureClause) {
2697 case mlir::omp::DeclareTargetCaptureClause::to:
2698 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
2699 case mlir::omp::DeclareTargetCaptureClause::link:
2700 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
2701 case mlir::omp::DeclareTargetCaptureClause::enter:
2702 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
2704 llvm_unreachable("unhandled capture clause");
2707 static llvm::SmallString<64>
2708 getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
2709 llvm::OpenMPIRBuilder &ompBuilder) {
2710 llvm::SmallString<64> suffix;
2711 llvm::raw_svector_ostream os(suffix);
2712 if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
2713 auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
2714 auto fileInfoCallBack = [&loc]() {
2715 return std::pair<std::string, uint64_t>(
2716 llvm::StringRef(loc.getFilename()), loc.getLine());
2719 os << llvm::format(
2720 "_%x", ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack).FileID);
2722 os << "_decl_tgt_ref_ptr";
2724 return suffix;
2727 static bool isDeclareTargetLink(mlir::Value value) {
2728 if (auto addressOfOp =
2729 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
2730 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
2731 Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
2732 if (auto declareTargetGlobal =
2733 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp))
2734 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
2735 mlir::omp::DeclareTargetCaptureClause::link)
2736 return true;
2738 return false;
2741 // Returns the reference pointer generated by the lowering of the declare target
2742 // operation in cases where the link clause is used or the to clause is used in
2743 // USM mode.
2744 static llvm::Value *
2745 getRefPtrIfDeclareTarget(mlir::Value value,
2746 LLVM::ModuleTranslation &moduleTranslation) {
2747 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2749 // An easier way to do this may just be to keep track of any pointer
2750 // references and their mapping to their respective operation
2751 if (auto addressOfOp =
2752 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
2753 if (auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
2754 addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
2755 addressOfOp.getGlobalName()))) {
2757 if (auto declareTargetGlobal =
2758 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
2759 gOp.getOperation())) {
2761 // In this case, we must utilise the reference pointer generated by the
2762 // declare target operation, similar to Clang
2763 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
2764 mlir::omp::DeclareTargetCaptureClause::link) ||
2765 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
2766 mlir::omp::DeclareTargetCaptureClause::to &&
2767 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
2768 llvm::SmallString<64> suffix =
2769 getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);
2771 if (gOp.getSymName().contains(suffix))
2772 return moduleTranslation.getLLVMModule()->getNamedValue(
2773 gOp.getSymName());
2775 return moduleTranslation.getLLVMModule()->getNamedValue(
2776 (gOp.getSymName().str() + suffix.str()).str());
2782 return nullptr;
namespace {
// A small helper structure to contain data gathered
// for map lowering and coalesce it into one area, avoiding extra
// computations such as searches in the llvm module for lowered mapped
// variables or checking if something is declare target (and retrieving the
// value) more than necessary.
//
// The vectors below run parallel to the inherited MapInfosTy arrays: entry i
// of each describes the same mapped variable.
struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
  // True when the mapped variable is a declare-target global whose reference
  // pointer is used as the base pointer.
  llvm::SmallVector<bool, 4> IsDeclareTarget;
  // True when the variable is a member of a larger mapped object.
  llvm::SmallVector<bool, 4> IsAMember;
  // Identify if mapping was added by mapClause or use_device clauses.
  llvm::SmallVector<bool, 4> IsAMapping;
  // The omp.map.info operation each entry originates from.
  llvm::SmallVector<mlir::Operation *, 4> MapClause;
  // The translated LLVM value of the mapped variable before any rewriting.
  llvm::SmallVector<llvm::Value *, 4> OriginalValue;
  // Stripped off array/pointer to get the underlying
  // element type
  llvm::SmallVector<llvm::Type *, 4> BaseType;

  /// Append arrays in \a CurInfo.
  void append(MapInfoData &CurInfo) {
    IsDeclareTarget.append(CurInfo.IsDeclareTarget.begin(),
                           CurInfo.IsDeclareTarget.end());
    MapClause.append(CurInfo.MapClause.begin(), CurInfo.MapClause.end());
    OriginalValue.append(CurInfo.OriginalValue.begin(),
                         CurInfo.OriginalValue.end());
    BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
    llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo);
  }
};
} // namespace
2816 uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl) {
2817 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
2818 arrTy.getElementType()))
2819 return getArrayElementSizeInBits(nestedArrTy, dl);
2820 return dl.getTypeSizeInBits(arrTy.getElementType());
2823 // This function calculates the size to be offloaded for a specified type, given
2824 // its associated map clause (which can contain bounds information which affects
2825 // the total size), this size is calculated based on the underlying element type
2826 // e.g. given a 1-D array of ints, we will calculate the size from the integer
2827 // type * number of elements in the array. This size can be used in other
2828 // calculations but is ultimately used as an argument to the OpenMP runtimes
2829 // kernel argument structure which is generated through the combinedInfo data
2830 // structures.
2831 // This function is somewhat equivalent to Clang's getExprTypeSize inside of
2832 // CGOpenMPRuntime.cpp.
2833 llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
2834 Operation *clauseOp, llvm::Value *basePointer,
2835 llvm::Type *baseType, llvm::IRBuilderBase &builder,
2836 LLVM::ModuleTranslation &moduleTranslation) {
2837 if (auto memberClause =
2838 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
2839 // This calculates the size to transfer based on bounds and the underlying
2840 // element type, provided bounds have been specified (Fortran
2841 // pointers/allocatables/target and arrays that have sections specified fall
2842 // into this as well).
2843 if (!memberClause.getBounds().empty()) {
2844 llvm::Value *elementCount = builder.getInt64(1);
2845 for (auto bounds : memberClause.getBounds()) {
2846 if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
2847 bounds.getDefiningOp())) {
2848 // The below calculation for the size to be mapped calculated from the
2849 // map.info's bounds is: (elemCount * [UB - LB] + 1), later we
2850 // multiply by the underlying element types byte size to get the full
2851 // size to be offloaded based on the bounds
2852 elementCount = builder.CreateMul(
2853 elementCount,
2854 builder.CreateAdd(
2855 builder.CreateSub(
2856 moduleTranslation.lookupValue(boundOp.getUpperBound()),
2857 moduleTranslation.lookupValue(boundOp.getLowerBound())),
2858 builder.getInt64(1)));
2862 // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
2863 // the size in inconsistent byte or bit format.
2864 uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
2865 if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
2866 underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);
2868 // The size in bytes x number of elements, the sizeInBytes stored is
2869 // the underyling types size, e.g. if ptr<i32>, it'll be the i32's
2870 // size, so we do some on the fly runtime math to get the size in
2871 // bytes from the extent (ub - lb) * sizeInBytes. NOTE: This may need
2872 // some adjustment for members with more complex types.
2873 return builder.CreateMul(elementCount,
2874 builder.getInt64(underlyingTypeSzInBits / 8));
2878 return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
/// Populates \p mapData with one parallel-array entry per mapped variable,
/// first from the map clauses in \p mapVars and then from the
/// use_device_addr / use_device_ptr operand lists (skipping values already
/// recorded by a map clause, which are instead tagged RETURN_PARAM).
static void collectMapDataFromMapOperands(
    MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
    LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
    llvm::IRBuilderBase &builder, const ArrayRef<Value> &useDevPtrOperands = {},
    const ArrayRef<Value> &useDevAddrOperands = {}) {
  auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
    // Check if this is a member mapping and correctly assign that it is, if
    // it is a member of a larger object.
    // TODO: Need better handling of members, and distinguishing of members
    // that are implicitly allocated on device vs explicitly passed in as
    // arguments.
    // TODO: May require some further additions to support nested record
    // types, i.e. member maps that can have member maps.
    for (Value mapValue : mapVars) {
      auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
      for (auto member : map.getMembers())
        if (member == mapOp)
          return true;
    }
    return false;
  };

  // Process MapOperands
  for (Value mapValue : mapVars) {
    auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
    // Prefer the pointer-to-pointer operand when present (e.g. descriptor
    // base addresses); otherwise the variable pointer itself.
    Value offloadPtr =
        mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
    mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
    mapData.Pointers.push_back(mapData.OriginalValue.back());

    if (llvm::Value *refPtr =
            getRefPtrIfDeclareTarget(offloadPtr,
                                     moduleTranslation)) { // declare target
      mapData.IsDeclareTarget.push_back(true);
      mapData.BasePointers.push_back(refPtr);
    } else { // regular mapped variable
      mapData.IsDeclareTarget.push_back(false);
      mapData.BasePointers.push_back(mapData.OriginalValue.back());
    }

    mapData.BaseType.push_back(
        moduleTranslation.convertType(mapOp.getVarType()));
    mapData.Sizes.push_back(
        getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
                       mapData.BaseType.back(), builder, moduleTranslation));
    mapData.MapClause.push_back(mapOp.getOperation());
    mapData.Types.push_back(
        llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
    mapData.Names.push_back(LLVM::createMappingInformation(
        mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
    mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
    mapData.IsAMapping.push_back(true);
    mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
  }

  // Marks every existing map-clause entry for \p val as a return parameter
  // with the given device-info kind; returns whether any entry was found.
  auto findMapInfo = [&mapData](llvm::Value *val,
                                llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
    unsigned index = 0;
    bool found = false;
    for (llvm::Value *basePtr : mapData.OriginalValue) {
      if (basePtr == val && mapData.IsAMapping[index]) {
        found = true;
        mapData.Types[index] |=
            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
        mapData.DevicePointers[index] = devInfoTy;
      }
      index++;
    }
    return found;
  };

  // Process useDevPtr(Addr)Operands
  auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
                         llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
    for (Value mapValue : useDevOperands) {
      auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
      Value offloadPtr =
          mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
      llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);

      // Check if map info is already present for this entry.
      if (!findMapInfo(origValue, devInfoTy)) {
        mapData.OriginalValue.push_back(origValue);
        mapData.Pointers.push_back(mapData.OriginalValue.back());
        mapData.IsDeclareTarget.push_back(false);
        mapData.BasePointers.push_back(mapData.OriginalValue.back());
        mapData.BaseType.push_back(
            moduleTranslation.convertType(mapOp.getVarType()));
        // use_device entries transfer no data, so their size is zero.
        mapData.Sizes.push_back(builder.getInt64(0));
        mapData.MapClause.push_back(mapOp.getOperation());
        mapData.Types.push_back(
            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
        mapData.Names.push_back(LLVM::createMappingInformation(
            mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
        mapData.DevicePointers.push_back(devInfoTy);
        mapData.IsAMapping.push_back(false);
        mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
      }
    }
  };

  addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
  addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
}
2986 static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
2987 auto *res = llvm::find(mapData.MapClause, memberOp);
2988 assert(res != mapData.MapClause.end() &&
2989 "MapInfoOp for member not found in MapData, cannot return index");
2990 return std::distance(mapData.MapClause.begin(), res);
2993 static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
2994 bool first) {
2995 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
2996 // Only 1 member has been mapped, we can return it.
2997 if (indexAttr.size() == 1)
2998 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
3000 llvm::SmallVector<size_t> indices(indexAttr.size());
3001 std::iota(indices.begin(), indices.end(), 0);
3003 llvm::sort(indices.begin(), indices.end(),
3004 [&](const size_t a, const size_t b) {
3005 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
3006 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
3007 for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
3008 int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
3009 int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
3011 if (aIndex == bIndex)
3012 continue;
3014 if (aIndex < bIndex)
3015 return first;
3017 if (aIndex > bIndex)
3018 return !first;
3021 // Iterated the up until the end of the smallest member and
3022 // they were found to be equal up to that point, so select
3023 // the member with the lowest index count, so the "parent"
3024 return memberIndicesA.size() < memberIndicesB.size();
3027 return llvm::cast<omp::MapInfoOp>(
3028 mapInfo.getMembers()[indices.front()].getDefiningOp());
/// This function calculates the array/pointer offset for map data provided
/// with bounds operations, e.g. when provided something like the following:
///
/// Fortran
///     map(tofrom: array(2:5, 3:2))
/// or
/// C++
///     map(tofrom: array[1:4][2:3])
/// We must calculate the initial pointer offset to pass across, this function
/// performs this using bounds.
///
/// NOTE: which while specified in row-major order it currently needs to be
/// flipped for Fortran's column order array allocation and access (as
/// opposed to C++'s row-major, hence the backwards processing where order is
/// important). This is likely important to keep in mind for the future when
/// we incorporate a C++ frontend, both frontends will need to agree on the
/// ordering of generated bounds operations (one may have to flip them) to
/// make the below lowering frontend agnostic. The offload size
/// calculation may also have to be adjusted for C++.
std::vector<llvm::Value *>
calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
                      llvm::IRBuilderBase &builder, bool isArrayTy,
                      OperandRange bounds) {
  std::vector<llvm::Value *> idx;
  // There's no bounds to calculate an offset from, we can safely
  // ignore and return no indices.
  if (bounds.empty())
    return idx;

  // If we have an array type, then we have its type so can treat it as a
  // normal GEP instruction where the bounds operations are simply indexes
  // into the array. We currently do reverse order of the bounds, which
  // I believe leans more towards Fortran's column-major in memory.
  if (isArrayTy) {
    // Leading zero index steps through the pointer to the array itself.
    idx.push_back(builder.getInt64(0));
    for (int i = bounds.size() - 1; i >= 0; --i) {
      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
              bounds[i].getDefiningOp())) {
        idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
      }
    }
  } else {
    // If we do not have an array type, but we have bounds, then we're dealing
    // with a pointer that's being treated like an array and we have the
    // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
    // address (pointer pointing to the actual data) so we must calculate the
    // offset using a single index which the following two loops attempt to
    // compute.

    // Calculates the size offset we need to make per row e.g. first row or
    // column only needs to be offset by one, but the next would have to be
    // the previous row/column offset multiplied by the extent of current row.
    //
    // For example ([1][10][100]):
    //
    // - First row/column we move by 1 for each index increment
    // - Second row/column we move by 1 (first row/column) * 10 (extent/size of
    //   current) for 10 for each index increment
    // - Third row/column we would move by 10 (second row/column) *
    //   (extent/size of current) 100 for 1000 for each index increment
    std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
    for (size_t i = 1; i < bounds.size(); ++i) {
      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
              bounds[i].getDefiningOp())) {
        dimensionIndexSizeOffset.push_back(builder.CreateMul(
            moduleTranslation.lookupValue(boundOp.getExtent()),
            dimensionIndexSizeOffset[i - 1]));
      }
    }

    // Now that we have calculated how much we move by per index, we must
    // multiply each lower bound offset in indexes by the size offset we
    // have calculated in the previous and accumulate the results to get
    // our final resulting offset.
    for (int i = bounds.size() - 1; i >= 0; --i) {
      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
              bounds[i].getDefiningOp())) {
        if (idx.empty())
          idx.emplace_back(builder.CreateMul(
              moduleTranslation.lookupValue(boundOp.getLowerBound()),
              dimensionIndexSizeOffset[i]));
        else
          idx.back() = builder.CreateAdd(
              idx.back(), builder.CreateMul(moduleTranslation.lookupValue(
                                                boundOp.getLowerBound()),
                                            dimensionIndexSizeOffset[i]));
      }
    }
  }

  return idx;
}
3124 // This creates two insertions into the MapInfosTy data structure for the
3125 // "parent" of a set of members, (usually a container e.g.
3126 // class/structure/derived type) when subsequent members have also been
3127 // explicitly mapped on the same map clause. Certain types, such as Fortran
3128 // descriptors are mapped like this as well, however, the members are
3129 // implicit as far as a user is concerned, but we must explicitly map them
3130 // internally.
3132 // This function also returns the memberOfFlag for this particular parent,
3133 // which is utilised in subsequent member mappings (by modifying there map type
3134 // with it) to indicate that a member is part of this parent and should be
3135 // treated by the runtime as such. Important to achieve the correct mapping.
3137 // This function borrows a lot from Clang's emitCombinedEntry function
3138 // inside of CGOpenMPRuntime.cpp
3139 static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
3140 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
3141 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
3142 llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
3143 uint64_t mapDataIndex, bool isTargetParams) {
3144 // Map the first segment of our structure
3145 combinedInfo.Types.emplace_back(
3146 isTargetParams
3147 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
3148 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
3149 combinedInfo.DevicePointers.emplace_back(
3150 mapData.DevicePointers[mapDataIndex]);
3151 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
3152 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3153 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
3155 // Calculate size of the parent object being mapped based on the
3156 // addresses at runtime, highAddr - lowAddr = size. This of course
3157 // doesn't factor in allocated data like pointers, hence the further
3158 // processing of members specified by users, or in the case of
3159 // Fortran pointers and allocatables, the mapping of the pointed to
3160 // data by the descriptor (which itself, is a structure containing
3161 // runtime information on the dynamically allocated data).
3162 auto parentClause =
3163 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3165 llvm::Value *lowAddr, *highAddr;
3166 if (!parentClause.getPartialMap()) {
3167 lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
3168 builder.getPtrTy());
3169 highAddr = builder.CreatePointerCast(
3170 builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
3171 mapData.Pointers[mapDataIndex], 1),
3172 builder.getPtrTy());
3173 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3174 } else {
3175 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3176 int firstMemberIdx = getMapDataMemberIdx(
3177 mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
3178 lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
3179 builder.getPtrTy());
3180 int lastMemberIdx = getMapDataMemberIdx(
3181 mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
3182 highAddr = builder.CreatePointerCast(
3183 builder.CreateGEP(mapData.BaseType[lastMemberIdx],
3184 mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
3185 builder.getPtrTy());
3186 combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
3189 llvm::Value *size = builder.CreateIntCast(
3190 builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
3191 builder.getInt64Ty(),
3192 /*isSigned=*/false);
3193 combinedInfo.Sizes.push_back(size);
3195 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
3196 ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
3198 // This creates the initial MEMBER_OF mapping that consists of
3199 // the parent/top level container (same as above effectively, except
3200 // with a fixed initial compile time size and separate maptype which
3201 // indicates the true mape type (tofrom etc.). This parent mapping is
3202 // only relevant if the structure in its totality is being mapped,
3203 // otherwise the above suffices.
3204 if (!parentClause.getPartialMap()) {
3205 // TODO: This will need to be expanded to include the whole host of logic
3206 // for the map flags that Clang currently supports (e.g. it should do some
3207 // further case specific flag modifications). For the moment, it handles
3208 // what we support as expected.
3209 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
3210 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3211 combinedInfo.Types.emplace_back(mapFlag);
3212 combinedInfo.DevicePointers.emplace_back(
3213 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
3214 combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
3215 mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
3216 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
3217 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
3218 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
3220 return memberOfFlag;
3223 // The intent is to verify if the mapped data being passed is a
3224 // pointer -> pointee that requires special handling in certain cases,
3225 // e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
3227 // There may be a better way to verify this, but unfortunately with
3228 // opaque pointers we lose the ability to easily check if something is
3229 // a pointer whilst maintaining access to the underlying type.
3230 static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
3231 // If we have a varPtrPtr field assigned then the underlying type is a pointer
3232 if (mapOp.getVarPtrPtr())
3233 return true;
3235 // If the map data is declare target with a link clause, then it's represented
3236 // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
3237 // no relation to pointers.
3238 if (isDeclareTargetLink(mapOp.getVarPtr()))
3239 return true;
3241 return false;
3244 // This function is intended to add explicit mappings of members
3245 static void processMapMembersWithParent(
3246 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
3247 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
3248 llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
3249 uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
3251 auto parentClause =
3252 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3254 for (auto mappedMembers : parentClause.getMembers()) {
3255 auto memberClause =
3256 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
3257 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
3259 assert(memberDataIdx >= 0 && "could not find mapped member of structure");
3261 // If we're currently mapping a pointer to a block of data, we must
3262 // initially map the pointer, and then attatch/bind the data with a
3263 // subsequent map to the pointer. This segment of code generates the
3264 // pointer mapping, which can in certain cases be optimised out as Clang
3265 // currently does in its lowering. However, for the moment we do not do so,
3266 // in part as we currently have substantially less information on the data
3267 // being mapped at this stage.
3268 if (checkIfPointerMap(memberClause)) {
3269 auto mapFlag = llvm::omp::OpenMPOffloadMappingFlags(
3270 memberClause.getMapType().value());
3271 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
3272 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
3273 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3274 combinedInfo.Types.emplace_back(mapFlag);
3275 combinedInfo.DevicePointers.emplace_back(
3276 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
3277 combinedInfo.Names.emplace_back(
3278 LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
3279 combinedInfo.BasePointers.emplace_back(
3280 mapData.BasePointers[mapDataIndex]);
3281 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
3282 combinedInfo.Sizes.emplace_back(builder.getInt64(
3283 moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
3286 // Same MemberOfFlag to indicate its link with parent and other members
3287 // of.
3288 auto mapFlag =
3289 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType().value());
3290 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
3291 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
3292 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
3293 if (checkIfPointerMap(memberClause))
3294 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
3296 combinedInfo.Types.emplace_back(mapFlag);
3297 combinedInfo.DevicePointers.emplace_back(
3298 mapData.DevicePointers[memberDataIdx]);
3299 combinedInfo.Names.emplace_back(
3300 LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
3301 uint64_t basePointerIndex =
3302 checkIfPointerMap(memberClause) ? memberDataIdx : mapDataIndex;
3303 combinedInfo.BasePointers.emplace_back(
3304 mapData.BasePointers[basePointerIndex]);
3305 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
3306 combinedInfo.Sizes.emplace_back(mapData.Sizes[memberDataIdx]);
3310 static void
3311 processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
3312 llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
3313 bool isTargetParams, int mapDataParentIdx = -1) {
3314 // Declare Target Mappings are excluded from being marked as
3315 // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
3316 // marked with OMP_MAP_PTR_AND_OBJ instead.
3317 auto mapFlag = mapData.Types[mapDataIdx];
3318 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
3320 bool isPtrTy = checkIfPointerMap(mapInfoOp);
3321 if (isPtrTy)
3322 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
3324 if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
3325 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
3327 if (mapInfoOp.getMapCaptureType().value() ==
3328 omp::VariableCaptureKind::ByCopy &&
3329 !isPtrTy)
3330 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
3332 // if we're provided a mapDataParentIdx, then the data being mapped is
3333 // part of a larger object (in a parent <-> member mapping) and in this
3334 // case our BasePointer should be the parent.
3335 if (mapDataParentIdx >= 0)
3336 combinedInfo.BasePointers.emplace_back(
3337 mapData.BasePointers[mapDataParentIdx]);
3338 else
3339 combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
3341 combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
3342 combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
3343 combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
3344 combinedInfo.Types.emplace_back(mapFlag);
3345 combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
3348 static void processMapWithMembersOf(
3349 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
3350 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
3351 llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
3352 uint64_t mapDataIndex, bool isTargetParams) {
3353 auto parentClause =
3354 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3356 // If we have a partial map (no parent referenced in the map clauses of the
3357 // directive, only members) and only a single member, we do not need to bind
3358 // the map of the member to the parent, we can pass the member separately.
3359 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
3360 auto memberClause = llvm::cast<omp::MapInfoOp>(
3361 parentClause.getMembers()[0].getDefiningOp());
3362 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
3363 // Note: Clang treats arrays with explicit bounds that fall into this
3364 // category as a parent with map case, however, it seems this isn't a
3365 // requirement, and processing them as an individual map is fine. So,
3366 // we will handle them as individual maps for the moment, as it's
3367 // difficult for us to check this as we always require bounds to be
3368 // specified currently and it's also marginally more optimal (single
3369 // map rather than two). The difference may come from the fact that
3370 // Clang maps array without bounds as pointers (which we do not
3371 // currently do), whereas we treat them as arrays in all cases
3372 // currently.
3373 processIndividualMap(mapData, memberDataIdx, combinedInfo, isTargetParams,
3374 mapDataIndex);
3375 return;
3378 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
3379 mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
3380 combinedInfo, mapData, mapDataIndex, isTargetParams);
3381 processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
3382 combinedInfo, mapData, mapDataIndex,
3383 memberOfParentFlag);
3386 // This is a variation on Clang's GenerateOpenMPCapturedVars, which
3387 // generates different operation (e.g. load/store) combinations for
3388 // arguments to the kernel, based on map capture kinds which are then
3389 // utilised in the combinedInfo in place of the original Map value.
3390 static void
3391 createAlteredByCaptureMap(MapInfoData &mapData,
3392 LLVM::ModuleTranslation &moduleTranslation,
3393 llvm::IRBuilderBase &builder) {
3394 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
3395 // if it's declare target, skip it, it's handled separately.
3396 if (!mapData.IsDeclareTarget[i]) {
3397 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
3398 omp::VariableCaptureKind captureKind =
3399 mapOp.getMapCaptureType().value_or(omp::VariableCaptureKind::ByRef);
3400 bool isPtrTy = checkIfPointerMap(mapOp);
3402 // Currently handles array sectioning lowerbound case, but more
3403 // logic may be required in the future. Clang invokes EmitLValue,
3404 // which has specialised logic for special Clang types such as user
3405 // defines, so it is possible we will have to extend this for
3406 // structures or other complex types. As the general idea is that this
3407 // function mimics some of the logic from Clang that we require for
3408 // kernel argument passing from host -> device.
3409 switch (captureKind) {
3410 case omp::VariableCaptureKind::ByRef: {
3411 llvm::Value *newV = mapData.Pointers[i];
3412 std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
3413 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
3414 mapOp.getBounds());
3415 if (isPtrTy)
3416 newV = builder.CreateLoad(builder.getPtrTy(), newV);
3418 if (!offsetIdx.empty())
3419 newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
3420 "array_offset");
3421 mapData.Pointers[i] = newV;
3422 } break;
3423 case omp::VariableCaptureKind::ByCopy: {
3424 llvm::Type *type = mapData.BaseType[i];
3425 llvm::Value *newV;
3426 if (mapData.Pointers[i]->getType()->isPointerTy())
3427 newV = builder.CreateLoad(type, mapData.Pointers[i]);
3428 else
3429 newV = mapData.Pointers[i];
3431 if (!isPtrTy) {
3432 auto curInsert = builder.saveIP();
3433 builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
3434 auto *memTempAlloc =
3435 builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
3436 builder.restoreIP(curInsert);
3438 builder.CreateStore(newV, memTempAlloc);
3439 newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
3442 mapData.Pointers[i] = newV;
3443 mapData.BasePointers[i] = newV;
3444 } break;
3445 case omp::VariableCaptureKind::This:
3446 case omp::VariableCaptureKind::VLAType:
3447 mapData.MapClause[i]->emitOpError("Unhandled capture kind");
3448 break;
3454 // Generate all map related information and fill the combinedInfo.
3455 static void genMapInfos(llvm::IRBuilderBase &builder,
3456 LLVM::ModuleTranslation &moduleTranslation,
3457 DataLayout &dl,
3458 llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
3459 MapInfoData &mapData, bool isTargetParams = false) {
3460 // We wish to modify some of the methods in which arguments are
3461 // passed based on their capture type by the target region, this can
3462 // involve generating new loads and stores, which changes the
3463 // MLIR value to LLVM value mapping, however, we only wish to do this
3464 // locally for the current function/target and also avoid altering
3465 // ModuleTranslation, so we remap the base pointer or pointer stored
3466 // in the map infos corresponding MapInfoData, which is later accessed
3467 // by genMapInfos and createTarget to help generate the kernel and
3468 // kernel arg structure. It primarily becomes relevant in cases like
3469 // bycopy, or byref range'd arrays. In the default case, we simply
3470 // pass thee pointer byref as both basePointer and pointer.
3471 if (!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
3472 createAlteredByCaptureMap(mapData, moduleTranslation, builder);
3474 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3476 // We operate under the assumption that all vectors that are
3477 // required in MapInfoData are of equal lengths (either filled with
3478 // default constructed data or appropiate information) so we can
3479 // utilise the size from any component of MapInfoData, if we can't
3480 // something is missing from the initial MapInfoData construction.
3481 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
3482 // NOTE/TODO: We currently do not support arbitrary depth record
3483 // type mapping.
3484 if (mapData.IsAMember[i])
3485 continue;
3487 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
3488 if (!mapInfoOp.getMembers().empty()) {
3489 processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
3490 combinedInfo, mapData, i, isTargetParams);
3491 continue;
3494 processIndividualMap(mapData, i, combinedInfo, isTargetParams);
// Lowers omp.target_data, omp.target_enter_data, omp.target_exit_data and
// omp.target_update to LLVM IR through the OpenMPIRBuilder. Map clauses are
// collected into MapInfoData/combinedInfo; for target_data the region body
// is inlined via callbacks, while the standalone enter/exit/update forms
// select the matching __tgt_target_data_* runtime function (nowait variants
// when requested).
3498 static LogicalResult
3499 convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
3500 LLVM::ModuleTranslation &moduleTranslation) {
3501 llvm::Value *ifCond = nullptr;
3502 int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
3503 SmallVector<Value> mapVars;
3504 SmallVector<Value> useDevicePtrVars;
3505 SmallVector<Value> useDeviceAddrVars;
3506 llvm::omp::RuntimeFunction RTLFn;
3507 DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
3509 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3510 llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true,
3511 /*SeparateBeginEndCalls=*/true);
// Extract the clauses of whichever of the four ops we were handed. The
// device clause is only honoured when it is a compile-time LLVM constant
// integer; otherwise deviceID keeps its OMP_DEVICEID_UNDEF default.
3513 LogicalResult result =
3514 llvm::TypeSwitch<Operation *, LogicalResult>(op)
3515 .Case([&](omp::TargetDataOp dataOp) {
3516 if (failed(checkImplementationStatus(*dataOp)))
3517 return failure();
3519 if (auto ifVar = dataOp.getIfExpr())
3520 ifCond = moduleTranslation.lookupValue(ifVar);
3522 if (auto devId = dataOp.getDevice())
3523 if (auto constOp =
3524 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
3525 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
3526 deviceID = intAttr.getInt();
3528 mapVars = dataOp.getMapVars();
3529 useDevicePtrVars = dataOp.getUseDevicePtrVars();
3530 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
3531 return success();
3533 .Case([&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
3534 if (failed(checkImplementationStatus(*enterDataOp)))
3535 return failure();
3537 if (auto ifVar = enterDataOp.getIfExpr())
3538 ifCond = moduleTranslation.lookupValue(ifVar);
3540 if (auto devId = enterDataOp.getDevice())
3541 if (auto constOp =
3542 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
3543 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
3544 deviceID = intAttr.getInt();
3545 RTLFn =
3546 enterDataOp.getNowait()
3547 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
3548 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
3549 mapVars = enterDataOp.getMapVars();
3550 info.HasNoWait = enterDataOp.getNowait();
3551 return success();
3553 .Case([&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
3554 if (failed(checkImplementationStatus(*exitDataOp)))
3555 return failure();
3557 if (auto ifVar = exitDataOp.getIfExpr())
3558 ifCond = moduleTranslation.lookupValue(ifVar);
3560 if (auto devId = exitDataOp.getDevice())
3561 if (auto constOp =
3562 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
3563 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
3564 deviceID = intAttr.getInt();
3566 RTLFn = exitDataOp.getNowait()
3567 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
3568 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
3569 mapVars = exitDataOp.getMapVars();
3570 info.HasNoWait = exitDataOp.getNowait();
3571 return success();
3573 .Case([&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
3574 if (failed(checkImplementationStatus(*updateDataOp)))
3575 return failure();
3577 if (auto ifVar = updateDataOp.getIfExpr())
3578 ifCond = moduleTranslation.lookupValue(ifVar);
3580 if (auto devId = updateDataOp.getDevice())
3581 if (auto constOp =
3582 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
3583 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
3584 deviceID = intAttr.getInt();
3586 RTLFn =
3587 updateDataOp.getNowait()
3588 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
3589 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
3590 mapVars = updateDataOp.getMapVars();
3591 info.HasNoWait = updateDataOp.getNowait();
3592 return success();
3594 .Default([&](Operation *op) {
3595 llvm_unreachable("unexpected operation");
3596 return failure();
3599 if (failed(result))
3600 return failure();
3602 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3604 MapInfoData mapData;
3605 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
3606 builder, useDevicePtrVars, useDeviceAddrVars);
3608 // Fill up the arrays with all the mapped variables.
3609 llvm::OpenMPIRBuilder::MapInfosTy combinedInfo;
3610 auto genMapInfoCB =
3611 [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
3612 builder.restoreIP(codeGenIP);
3613 genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
3614 return combinedInfo;
3617 // Define a lambda to apply mappings between use_device_addr and
3618 // use_device_ptr base pointers, and their associated block arguments.
3619 auto mapUseDevice =
3620 [&moduleTranslation](
3621 llvm::OpenMPIRBuilder::DeviceInfoTy type,
3622 llvm::ArrayRef<BlockArgument> blockArgs,
3623 llvm::SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
3624 llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
3625 for (auto [arg, useDevVar] :
3626 llvm::zip_equal(blockArgs, useDeviceVars)) {
// A use_device var and a map clause refer to the same data when their
// base pointers (varPtrPtr if present, else varPtr) coincide.
3628 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
3629 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
3630 : mapInfoOp.getVarPtr();
3633 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
3634 for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
3635 mapInfoData.MapClause, mapInfoData.DevicePointers,
3636 mapInfoData.BasePointers)) {
3637 auto mapOp = cast<omp::MapInfoOp>(mapClause)
3638 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
3639 devicePointer != type)
3640 continue;
3642 if (llvm::Value *devPtrInfoMap =
3643 mapper ? mapper(basePointer) : basePointer) {
3644 moduleTranslation.mapValue(arg, devPtrInfoMap);
3645 break;
3651 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
3652 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
3653 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
3654 assert(isa<omp::TargetDataOp>(op) &&
3655 "BodyGen requested for non TargetDataOp");
3656 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
3657 Region &region = cast<omp::TargetDataOp>(op).getRegion();
3658 switch (bodyGenType) {
3659 case BodyGenTy::Priv:
3660 // Check if any device ptr/addr info is available
3661 if (!info.DevicePtrInfoMap.empty()) {
3662 builder.restoreIP(codeGenIP);
3664 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
3665 blockArgIface.getUseDeviceAddrBlockArgs(),
3666 useDeviceAddrVars, mapData,
3667 [&](llvm::Value *basePointer) -> llvm::Value * {
3668 if (!info.DevicePtrInfoMap[basePointer].second)
3669 return nullptr;
3670 return builder.CreateLoad(
3671 builder.getPtrTy(),
3672 info.DevicePtrInfoMap[basePointer].second);
3674 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
3675 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
3676 mapData, [&](llvm::Value *basePointer) {
3677 return info.DevicePtrInfoMap[basePointer].second;
3680 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
3681 moduleTranslation)))
3682 return llvm::make_error<PreviouslyReportedError>();
3684 break;
3685 case BodyGenTy::DupNoPriv:
3686 break;
3687 case BodyGenTy::NoPriv:
3688 // If device info is available then region has already been generated
3689 if (info.DevicePtrInfoMap.empty()) {
3690 builder.restoreIP(codeGenIP);
3691 // For device pass, if use_device_ptr(addr) mappings were present,
3692 // we need to link them here before codegen.
3693 if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
3694 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
3695 blockArgIface.getUseDeviceAddrBlockArgs(),
3696 useDeviceAddrVars, mapData);
3697 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
3698 blockArgIface.getUseDevicePtrBlockArgs(),
3699 useDevicePtrVars, mapData);
3702 if (failed(inlineConvertOmpRegions(region, "omp.data.region", builder,
3703 moduleTranslation)))
3704 return llvm::make_error<PreviouslyReportedError>();
3706 break;
3708 return builder.saveIP();
3711 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3712 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3713 findAllocaInsertPoint(builder, moduleTranslation);
// target_data gets the body callback; the standalone enter/exit/update
// forms instead pass the runtime function selected above.
3714 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
3715 if (isa<omp::TargetDataOp>(op))
3716 return ompBuilder->createTargetData(
3717 ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID),
3718 ifCond, info, genMapInfoCB, nullptr, bodyGenCB);
3719 return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
3720 builder.getInt64(deviceID), ifCond,
3721 info, genMapInfoCB, &RTLFn);
3722 }();
3724 if (failed(handleError(afterIP, *op)))
3725 return failure();
3727 builder.restoreIP(*afterIP);
3728 return success();
3731 /// Lowers the FlagsAttr which is applied to the module on the device
3732 /// pass when offloading, this attribute contains OpenMP RTL globals that can
3733 /// be passed as flags to the frontend, otherwise they are set to default
3734 LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
3735 LLVM::ModuleTranslation &moduleTranslation) {
3736 if (!cast<mlir::ModuleOp>(op))
3737 return failure();
3739 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3741 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
3742 attribute.getOpenmpDeviceVersion());
3744 if (attribute.getNoGpuLib())
3745 return success();
3747 ompBuilder->createGlobalFlag(
3748 attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/,
3749 "__omp_rtl_debug_kind");
3750 ompBuilder->createGlobalFlag(
3751 attribute
3752 .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/
3754 "__omp_rtl_assume_teams_oversubscription");
3755 ompBuilder->createGlobalFlag(
3756 attribute
3757 .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/
3759 "__omp_rtl_assume_threads_oversubscription");
3760 ompBuilder->createGlobalFlag(
3761 attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/,
3762 "__omp_rtl_assume_no_thread_state");
3763 ompBuilder->createGlobalFlag(
3764 attribute
3765 .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/
3767 "__omp_rtl_assume_no_nested_parallelism");
3768 return success();
3771 static bool getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
3772 omp::TargetOp targetOp,
3773 llvm::StringRef parentName = "") {
3774 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
3776 assert(fileLoc && "No file found from location");
3777 StringRef fileName = fileLoc.getFilename().getValue();
3779 llvm::sys::fs::UniqueID id;
3780 if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
3781 targetOp.emitError("Unable to get unique ID for file");
3782 return false;
3785 uint64_t line = fileLoc.getLine();
3786 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
3787 id.getFile(), line);
3788 return true;
3791 static void
3792 handleDeclareTargetMapVar(MapInfoData &mapData,
3793 LLVM::ModuleTranslation &moduleTranslation,
3794 llvm::IRBuilderBase &builder, llvm::Function *func) {
3795 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
3796 // In the case of declare target mapped variables, the basePointer is
3797 // the reference pointer generated by the convertDeclareTargetAttr
3798 // method. Whereas the kernelValue is the original variable, so for
3799 // the device we must replace all uses of this original global variable
3800 // (stored in kernelValue) with the reference pointer (stored in
3801 // basePointer for declare target mapped variables), as for device the
3802 // data is mapped into this reference pointer and should be loaded
3803 // from it, the original variable is discarded. On host both exist and
3804 // metadata is generated (elsewhere in the convertDeclareTargetAttr)
3805 // function to link the two variables in the runtime and then both the
3806 // reference pointer and the pointer are assigned in the kernel argument
3807 // structure for the host.
3808 if (mapData.IsDeclareTarget[i]) {
3809 // If the original map value is a constant, then we have to make sure all
3810 // of it's uses within the current kernel/function that we are going to
3811 // rewrite are converted to instructions, as we will be altering the old
3812 // use (OriginalValue) from a constant to an instruction, which will be
3813 // illegal and ICE the compiler if the user is a constant expression of
3814 // some kind e.g. a constant GEP.
3815 if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
3816 convertUsersOfConstantsToInstructions(constant, func, false);
3818 // The users iterator will get invalidated if we modify an element,
3819 // so we populate this vector of uses to alter each user on an
3820 // individual basis to emit its own load (rather than one load for
3821 // all).
3822 llvm::SmallVector<llvm::User *> userVec;
3823 for (llvm::User *user : mapData.OriginalValue[i]->users())
3824 userVec.push_back(user);
3826 for (llvm::User *user : userVec) {
3827 if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
3828 if (insn->getFunction() == func) {
3829 auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
3830 mapData.BasePointers[i]);
3831 load->moveBefore(insn->getIterator());
3832 user->replaceUsesOfWith(mapData.OriginalValue[i], load);
3840 // The createDeviceArgumentAccessor function generates
3841 // instructions for retrieving (accessing) kernel
3842 // arguments inside of the device kernel for use by
3843 // the kernel. This enables different semantics such as
3844 // the creation of temporary copies of data allowing
3845 // semantics like read-only/no host write back kernel
3846 // arguments.
3848 // This currently implements a very light version of Clang's
3849 // EmitParmDecl's handling of direct argument handling as well
3850 // as a portion of the argument access generation based on
3851 // capture types found at the end of emitOutlinedFunctionPrologue
3852 // in Clang. The indirect path handling of EmitParmDecl's may be
3853 // required for future work, but a direct 1-to-1 copy doesn't seem
3854 // possible as the logic is rather scattered throughout Clang's
3855 // lowering and perhaps we wish to deviate slightly.
// \param mapData - A container containing vectors of information
// corresponding to the input argument, which should have a
// corresponding entry in the MapInfoData containers
// OriginalValue's.
// \param arg - This is the generated kernel function argument that
// corresponds to the passed in input argument. We generate different
// accesses of this Argument, based on capture type and other Input
// related information.
// \param input - This is the host side value that will be passed to
// the kernel i.e. the kernel input, we rewrite all uses of this within
// the kernel (as we generate the kernel body based on the target's region
// which maintains references to the original input) to the retVal argument
// upon exit of this function inside of the OMPIRBuilder. This interlinks
// the kernel argument to future uses of it in the function providing
// appropriate "glue" instructions in between.
3872 // \param retVal - This is the value that all uses of input inside of the
3873 // kernel will be re-written to, the goal of this function is to generate
3874 // an appropriate location for the kernel argument to be accessed from,
3875 // e.g. ByRef will result in a temporary allocation location and then
3876 // a store of the kernel argument into this allocated memory which
3877 // will then be loaded from, ByCopy will use the allocated memory
3878 // directly.
3879 static llvm::IRBuilderBase::InsertPoint
3880 createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
3881 llvm::Value *input, llvm::Value *&retVal,
3882 llvm::IRBuilderBase &builder,
3883 llvm::OpenMPIRBuilder &ompBuilder,
3884 LLVM::ModuleTranslation &moduleTranslation,
3885 llvm::IRBuilderBase::InsertPoint allocaIP,
3886 llvm::IRBuilderBase::InsertPoint codeGenIP) {
3887 builder.restoreIP(allocaIP);
3889 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
3891 // Find the associated MapInfoData entry for the current input
3892 for (size_t i = 0; i < mapData.MapClause.size(); ++i)
3893 if (mapData.OriginalValue[i] == input) {
3894 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
3895 capture =
3896 mapOp.getMapCaptureType().value_or(omp::VariableCaptureKind::ByRef);
3898 break;
3901 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
3902 unsigned int defaultAS =
3903 ompBuilder.M.getDataLayout().getProgramAddressSpace();
3905 // Create the alloca for the argument the current point.
3906 llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
3908 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
3909 v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
3911 builder.CreateStore(&arg, v);
3913 builder.restoreIP(codeGenIP);
3915 switch (capture) {
3916 case omp::VariableCaptureKind::ByCopy: {
3917 retVal = v;
3918 break;
3920 case omp::VariableCaptureKind::ByRef: {
3921 retVal = builder.CreateAlignedLoad(
3922 v->getType(), v,
3923 ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
3924 break;
3926 case omp::VariableCaptureKind::This:
3927 case omp::VariableCaptureKind::VLAType:
3928 // TODO: Consider returning error to use standard reporting for
3929 // unimplemented features.
3930 assert(false && "Currently unsupported capture kind");
3931 break;
3934 return builder.saveIP();
3937 /// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
3938 /// operation and populate output variables with their corresponding host value
3939 /// (i.e. operand evaluated outside of the target region), based on their uses
3940 /// inside of the target region.
3942 /// Loop bounds and steps are only optionally populated, if output vectors are
3943 /// provided.
3944 static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
3945 Value &numTeamsLower, Value &numTeamsUpper,
3946 Value &threadLimit) {
3947 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
3948 for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
3949 blockArgIface.getHostEvalBlockArgs())) {
3950 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
3952 for (Operation *user : blockArg.getUsers()) {
3953 llvm::TypeSwitch<Operation *>(user)
3954 .Case([&](omp::TeamsOp teamsOp) {
3955 if (teamsOp.getNumTeamsLower() == blockArg)
3956 numTeamsLower = hostEvalVar;
3957 else if (teamsOp.getNumTeamsUpper() == blockArg)
3958 numTeamsUpper = hostEvalVar;
3959 else if (teamsOp.getThreadLimit() == blockArg)
3960 threadLimit = hostEvalVar;
3961 else
3962 llvm_unreachable("unsupported host_eval use");
3964 .Case([&](omp::ParallelOp parallelOp) {
3965 if (parallelOp.getNumThreads() == blockArg)
3966 numThreads = hostEvalVar;
3967 else
3968 llvm_unreachable("unsupported host_eval use");
3970 .Case([&](omp::LoopNestOp loopOp) {
3971 // TODO: Extract bounds and step values. Currently, this cannot be
3972 // reached because translation would have been stopped earlier as a
3973 // result of `checkImplementationStatus` detecting and reporting
3974 // this situation.
3975 llvm_unreachable("unsupported host_eval use");
3977 .Default([](Operation *) {
3978 llvm_unreachable("unsupported host_eval use");
3984 /// If \p op is of the given type parameter, return it casted to that type.
3985 /// Otherwise, if its immediate parent operation (or some other higher-level
3986 /// parent, if \p immediateParent is false) is of that type, return that parent
3987 /// casted to the given type.
3989 /// If \p op is \c null or neither it or its parent(s) are of the specified
3990 /// type, return a \c null operation.
3991 template <typename OpTy>
3992 static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
3993 if (!op)
3994 return OpTy();
3996 if (OpTy casted = dyn_cast<OpTy>(op))
3997 return casted;
3999 if (immediateParent)
4000 return dyn_cast_if_present<OpTy>(op->getParentOp());
4002 return op->getParentOfType<OpTy>();
4005 /// If the given \p value is defined by an \c llvm.mlir.constant operation and
4006 /// it is of an integer type, return its value.
4007 static std::optional<int64_t> extractConstInteger(Value value) {
4008 if (!value)
4009 return std::nullopt;
4011 if (auto constOp =
4012 dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
4013 if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4014 return constAttr.getInt();
4016 return std::nullopt;
4019 /// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
4020 /// values as stated by the corresponding clauses, if constant.
4022 /// These default values must be set before the creation of the outlined LLVM
4023 /// function for the target region, so that they can be used to initialize the
4024 /// corresponding global `ConfigurationEnvironmentTy` structure.
4025 static void
4026 initTargetDefaultAttrs(omp::TargetOp targetOp,
4027 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
4028 bool isTargetDevice) {
4029 // TODO: Handle constant 'if' clauses.
4030 Operation *capturedOp = targetOp.getInnermostCapturedOmpOp();
4032 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
4033 if (!isTargetDevice) {
4034 extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
4035 threadLimit);
4036 } else {
4037 // In the target device, values for these clauses are not passed as
4038 // host_eval, but instead evaluated prior to entry to the region. This
4039 // ensures values are mapped and available inside of the target region.
4040 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
4041 numTeamsLower = teamsOp.getNumTeamsLower();
4042 numTeamsUpper = teamsOp.getNumTeamsUpper();
4043 threadLimit = teamsOp.getThreadLimit();
4046 if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
4047 numThreads = parallelOp.getNumThreads();
4050 // Handle clauses impacting the number of teams.
4052 int32_t minTeamsVal = 1, maxTeamsVal = -1;
4053 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
4054 // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match
4055 // clang and set min and max to the same value.
4056 if (numTeamsUpper) {
4057 if (auto val = extractConstInteger(numTeamsUpper))
4058 minTeamsVal = maxTeamsVal = *val;
4059 } else {
4060 minTeamsVal = maxTeamsVal = 0;
4062 } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
4063 /*immediateParent=*/true) ||
4064 castOrGetParentOfType<omp::SimdOp>(capturedOp,
4065 /*immediateParent=*/true)) {
4066 minTeamsVal = maxTeamsVal = 1;
4067 } else {
4068 minTeamsVal = maxTeamsVal = -1;
4071 // Handle clauses impacting the number of threads.
4073 auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
4074 if (!clauseValue)
4075 return;
4077 if (auto val = extractConstInteger(clauseValue))
4078 result = *val;
4080 // Found an applicable clause, so it's not undefined. Mark as unknown
4081 // because it's not constant.
4082 if (result < 0)
4083 result = 0;
4086 // Extract 'thread_limit' clause from 'target' and 'teams' directives.
4087 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
4088 setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
4089 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
4091 // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
4092 int32_t maxThreadsVal = -1;
4093 if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
4094 setMaxValueFromClause(numThreads, maxThreadsVal);
4095 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
4096 /*immediateParent=*/true))
4097 maxThreadsVal = 1;
4099 // For max values, < 0 means unset, == 0 means set but unknown. Select the
4100 // minimum value between 'max_threads' and 'thread_limit' clauses that were
4101 // set.
4102 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
4103 if (combinedMaxThreadsVal < 0 ||
4104 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
4105 combinedMaxThreadsVal = teamsThreadLimitVal;
4107 if (combinedMaxThreadsVal < 0 ||
4108 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
4109 combinedMaxThreadsVal = maxThreadsVal;
4111 // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
4112 attrs.MinTeams = minTeamsVal;
4113 attrs.MaxTeams.front() = maxTeamsVal;
4114 attrs.MinThreads = 1;
4115 attrs.MaxThreads.front() = combinedMaxThreadsVal;
4118 /// Gather LLVM runtime values for all clauses evaluated in the host that are
4119 /// passed to the kernel invocation.
4121 /// This function must be called only when compiling for the host. Also, it will
4122 /// only provide correct results if it's called after the body of \c targetOp
4123 /// has been fully generated.
4124 static void
4125 initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
4126 LLVM::ModuleTranslation &moduleTranslation,
4127 omp::TargetOp targetOp,
4128 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4129 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
4130 extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
4131 teamsThreadLimit);
4133 // TODO: Handle constant 'if' clauses.
4134 if (Value targetThreadLimit = targetOp.getThreadLimit())
4135 attrs.TargetThreadLimit.front() =
4136 moduleTranslation.lookupValue(targetThreadLimit);
4138 if (numTeamsLower)
4139 attrs.MinTeams = moduleTranslation.lookupValue(numTeamsLower);
4141 if (numTeamsUpper)
4142 attrs.MaxTeams.front() = moduleTranslation.lookupValue(numTeamsUpper);
4144 if (teamsThreadLimit)
4145 attrs.TeamsThreadLimit.front() =
4146 moduleTranslation.lookupValue(teamsThreadLimit);
4148 if (numThreads)
4149 attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
4151 // TODO: Populate attrs.LoopTripCount if it is target SPMD.
/// Lowers an `omp.target` operation to LLVM IR: outlines its region into a
/// kernel function through the OpenMPIRBuilder, wiring up map, host_eval and
/// private entry-block arguments, and registers the offload entry when
/// offloading is enabled.
static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  auto targetOp = cast<omp::TargetOp>(opInst);
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  bool isTargetDevice = ompBuilder->Config.isTargetDevice();

  auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
  auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
  auto &targetRegion = targetOp.getRegion();
  // Holds the private vars that have been mapped along with the block argument
  // that corresponds to the MapInfoOp corresponding to the private var in
  // question. So, for instance:
  //
  // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
  // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
  //
  // Then, %10 has been created so that the descriptor can be used by the
  // privatizer @box.privatizer on the device side. Here we'd record {%6#0,
  // %arg0} in the mappedPrivateVars map.
  llvm::DenseMap<Value, Value> mappedPrivateVars;
  DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
  SmallVector<Value> mapVars = targetOp.getMapVars();
  ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
  llvm::Function *llvmOutlinedFn = nullptr;

  // TODO: It can also be false if a compile-time constant `false` IF clause is
  // specified.
  bool isOffloadEntry =
      isTargetDevice || !ompBuilder->Config.TargetTriples.empty();

  // For some private variables, the MapsForPrivatizedVariablesPass
  // creates MapInfoOp instances. Go through the private variables and
  // the mapped variables so that during code generation we are able
  // to quickly look up the corresponding map variable, if any, for each
  // private variable.
  if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
    OperandRange privateVars = targetOp.getPrivateVars();
    std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
    std::optional<DenseI64ArrayAttr> privateMapIndices =
        targetOp.getPrivateMapsAttr();

    for (auto [privVarIdx, privVarSymPair] :
         llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
      auto privVar = std::get<0>(privVarSymPair);
      auto privSym = std::get<1>(privVarSymPair);

      SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
      omp::PrivateClauseOp privatizer =
          findPrivatizer(targetOp, privatizerName);

      // Only privatizers that request a map clause participate here.
      if (!privatizer.needsMap())
        continue;

      mlir::Value mappedValue =
          targetOp.getMappedValueForPrivateVar(privVarIdx);
      assert(mappedValue && "Expected to find mapped value for a privatized "
                            "variable that needs mapping");

      // The MapInfoOp defining the map var isn't really needed later.
      // So, we don't store it in any datastructure. Instead, we just
      // do some sanity checks on it right now.
      auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
      [[maybe_unused]] Type varType = mapInfoOp.getVarType();

      // Check #1: Check that the type of the private variable matches
      // the type of the variable being mapped.
      if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
        assert(
            varType == privVar.getType() &&
            "Type of private var doesn't match the type of the mapped value");

      // Ok, only 1 sanity check for now.
      // Record the block argument corresponding to this mapvar.
      mappedPrivateVars.insert(
          {privVar,
           targetRegion.getArgument(argIface.getMapBlockArgsStart() +
                                    (*privateMapIndices)[privVarIdx])});
    }
  }

  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  // Callback that fills in the body of the outlined kernel function.
  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // Forward target-cpu and target-features function attributes from the
    // original function to the new outlined function.
    llvm::Function *llvmParentFn =
        moduleTranslation.lookupFunction(parentFn.getName());
    llvmOutlinedFn = codeGenIP.getBlock()->getParent();
    assert(llvmParentFn && llvmOutlinedFn &&
           "Both parent and outlined functions must exist at this point");

    if (auto attr = llvmParentFn->getFnAttribute("target-cpu");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    if (auto attr = llvmParentFn->getFnAttribute("target-features");
        attr.isStringAttribute())
      llvmOutlinedFn->addFnAttr(attr);

    // Map each `map_entries` block argument to the LLVM value of the
    // variable pointer of its defining MapInfoOp.
    for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
      auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
      llvm::Value *mapOpValue =
          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
      moduleTranslation.mapValue(arg, mapOpValue);
    }

    // Do privatization after moduleTranslation has already recorded
    // mapped values.
    MutableArrayRef<BlockArgument> privateBlockArgs =
        argIface.getPrivateBlockArgs();
    SmallVector<mlir::Value> mlirPrivateVars;
    SmallVector<llvm::Value *> llvmPrivateVars;
    SmallVector<omp::PrivateClauseOp> privateDecls;
    mlirPrivateVars.reserve(privateBlockArgs.size());
    llvmPrivateVars.reserve(privateBlockArgs.size());
    collectPrivatizationDecls(targetOp, privateDecls);
    for (mlir::Value privateVar : targetOp.getPrivateVars())
      mlirPrivateVars.push_back(privateVar);

    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
        builder, moduleTranslation, privateBlockArgs, privateDecls,
        mlirPrivateVars, llvmPrivateVars, allocaIP, &mappedPrivateVars);

    if (failed(handleError(afterAllocas, *targetOp)))
      return llvm::make_error<PreviouslyReportedError>();

    // Collect `dealloc` regions now so cleanup can run after the body.
    SmallVector<Region *> privateCleanupRegions;
    llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
                    [](omp::PrivateClauseOp privatizer) {
                      return &privatizer.getDeallocRegion();
                    });

    builder.restoreIP(codeGenIP);
    llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions(
        targetRegion, "omp.target", builder, moduleTranslation);

    if (!exitBlock)
      return exitBlock.takeError();

    builder.SetInsertPoint(*exitBlock);
    if (!privateCleanupRegions.empty()) {
      if (failed(inlineOmpRegionCleanup(
              privateCleanupRegions, llvmPrivateVars, moduleTranslation,
              builder, "omp.targetop.private.cleanup",
              /*shouldLoadCleanupRegionArg=*/false))) {
        return llvm::createStringError(
            "failed to inline `dealloc` region of `omp.private` "
            "op in the target region");
      }
    }

    return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
  };

  StringRef parentName = parentFn.getName();

  llvm::TargetRegionEntryInfo entryInfo;

  if (!getTargetEntryUniqueInfo(entryInfo, targetOp, parentName))
    return failure();

  MapInfoData mapData;
  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                builder);

  // Callback providing the combined map information for kernel launch.
  llvm::OpenMPIRBuilder::MapInfosTy combinedInfos;
  auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::MapInfosTy & {
    builder.restoreIP(codeGenIP);
    genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, true);
    return combinedInfos;
  };

  // Callback generating in-kernel accesses for each kernel argument.
  auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
                           llvm::Value *&retVal, InsertPointTy allocaIP,
                           InsertPointTy codeGenIP)
      -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
    // We just return the unaltered argument for the host function
    // for now, some alterations may be required in the future to
    // keep host fallback functions working identically to the device
    // version (e.g. pass ByCopy values should be treated as such on
    // host and device, currently not always the case)
    if (!isTargetDevice) {
      retVal = cast<llvm::Value>(&arg);
      return codeGenIP;
    }

    return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
                                        *ompBuilder, moduleTranslation,
                                        allocaIP, codeGenIP);
  };

  llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
  initTargetDefaultAttrs(targetOp, defaultAttrs, isTargetDevice);

  // Collect host-evaluated values needed to properly launch the kernel from
  // the host.
  if (!isTargetDevice)
    initTargetRuntimeAttrs(builder, moduleTranslation, targetOp, runtimeAttrs);

  // Pass host-evaluated values as parameters to the kernel / host fallback,
  // except if they are constants. In any case, map the MLIR block argument to
  // the corresponding LLVM values.
  llvm::SmallVector<llvm::Value *, 4> kernelInput;
  SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
  ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
  for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
    llvm::Value *value = moduleTranslation.lookupValue(var);
    moduleTranslation.mapValue(arg, value);

    if (!llvm::isa<llvm::Constant>(value))
      kernelInput.push_back(value);
  }

  for (size_t i = 0; i < mapVars.size(); ++i) {
    // declare target arguments are not passed to kernels as arguments
    // TODO: We currently do not handle cases where a member is explicitly
    // passed in as an argument, this will likely need to be handled in
    // the near future, rather than using IsAMember, it may be better to
    // test if the relevant BlockArg is used within the target region and
    // then use that as a basis for exclusion in the kernel inputs.
    if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
      kernelInput.push_back(mapData.OriginalValue[i]);
  }

  SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
  buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
                  moduleTranslation, dds);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::Value *ifCond = nullptr;
  if (Value targetIfCond = targetOp.getIfExpr())
    ifCond = moduleTranslation.lookupValue(targetIfCond);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createTarget(
          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
          defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
          argAccessorCB, dds, targetOp.getNowait());

  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  // Remap access operations to declare target reference pointers for the
  // device, essentially generating extra loadop's as necessary
  if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
    handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
                              llvmOutlinedFn);

  return success();
}
/// Applies the `omp.declare_target` attribute during translation: for
/// functions on the device it erases host-only wrapper functions; for globals
/// it registers the variable with the OpenMPIRBuilder's offloading machinery
/// (and, where required on the device, materializes the declare-target
/// reference pointer).
static LogicalResult
convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
                         LLVM::ModuleTranslation &moduleTranslation) {
  // Amend omp.declare_target by deleting the IR of the outlined functions
  // created for target regions. They cannot be filtered out from MLIR earlier
  // because the omp.target operation inside must be translated to LLVM, but
  // the wrapper functions themselves must not remain at the end of the
  // process. We know that functions where omp.declare_target does not match
  // omp.is_target_device at this stage can only be wrapper functions because
  // those that aren't are removed earlier as an MLIR transformation pass.
  if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
    if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
            op->getParentOfType<ModuleOp>().getOperation())) {
      if (!offloadMod.getIsTargetDevice())
        return success();

      omp::DeclareTargetDeviceType declareType =
          attribute.getDeviceType().getValue();

      if (declareType == omp::DeclareTargetDeviceType::host) {
        llvm::Function *llvmFunc =
            moduleTranslation.lookupFunction(funcOp.getName());
        llvmFunc->dropAllReferences();
        llvmFunc->eraseFromParent();
      }
    }
    return success();
  }

  if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
    llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
    if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
      llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
      bool isDeclaration = gOp.isDeclaration();
      bool isExternallyVisible =
          gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
      auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
      llvm::StringRef mangledName = gOp.getSymName();
      auto captureClause =
          convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
      auto deviceClause =
          convertToDeviceClauseKind(attribute.getDeviceType().getValue());
      // unused for MLIR at the moment, required in Clang for book
      // keeping
      std::vector<llvm::GlobalVariable *> generatedRefs;

      std::vector<llvm::Triple> targetTriple;
      auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
          op->getParentOfType<mlir::ModuleOp>()->getAttr(
              LLVM::LLVMDialect::getTargetTripleAttrName()));
      if (targetTripleAttr)
        targetTriple.emplace_back(targetTripleAttr.data());

      // Provides filename/line info for unique offload entry naming; falls
      // back to an empty filename and line 0 when the location carries no
      // file:line information.
      auto fileInfoCallBack = [&loc]() {
        std::string filename = "";
        std::uint64_t lineNo = 0;

        if (loc) {
          filename = loc.getFilename().str();
          lineNo = loc.getLine();
        }

        return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
                                                     lineNo);
      };

      ompBuilder->registerTargetGlobalVariable(
          captureClause, deviceClause, isDeclaration, isExternallyVisible,
          ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
          generatedRefs, /*OpenMPSimd*/ false, targetTriple,
          /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
          gVal->getType(), gVal);

      if (ompBuilder->Config.isTargetDevice() &&
          (attribute.getCaptureClause().getValue() !=
               mlir::omp::DeclareTargetCaptureClause::to ||
           ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
        ompBuilder->getAddrOfDeclareTargetVar(
            captureClause, deviceClause, isDeclaration, isExternallyVisible,
            ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack), mangledName,
            generatedRefs, /*OpenMPSimd*/ false, targetTriple, gVal->getType(),
            /*GlobalInitializer*/ nullptr,
            /*VariableLinkage*/ nullptr);
      }
    }
  }

  return success();
}
4506 // Returns true if the operation is inside a TargetOp or
4507 // is part of a declare target function.
4508 static bool isTargetDeviceOp(Operation *op) {
4509 // Assumes no reverse offloading
4510 if (op->getParentOfType<omp::TargetOp>())
4511 return true;
4513 // Certain operations return results, and whether utilised in host or
4514 // target there is a chance an LLVM Dialect operation depends on it
4515 // by taking it in as an operand, so we must always lower these in
4516 // some manner or result in an ICE (whether they end up in a no-op
4517 // or otherwise).
4518 if (mlir::isa<omp::ThreadprivateOp>(op))
4519 return true;
4521 if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
4522 if (auto declareTargetIface =
4523 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
4524 parentFn.getOperation()))
4525 if (declareTargetIface.isDeclareTarget() &&
4526 declareTargetIface.getDeclareTargetDeviceType() !=
4527 mlir::omp::DeclareTargetDeviceType::host)
4528 return true;
4530 return false;
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR
/// (including OpenMP runtime calls).
static LogicalResult
convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) {

  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  // Dispatch on the concrete OpenMP operation type; each case either calls
  // the matching convertOmp* helper or emits IR through the OpenMPIRBuilder
  // directly.
  return llvm::TypeSwitch<Operation *, LogicalResult>(op)
      .Case([&](omp::BarrierOp op) -> LogicalResult {
        if (failed(checkImplementationStatus(*op)))
          return failure();

        llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
            ompBuilder->createBarrier(builder.saveIP(),
                                      llvm::omp::OMPD_barrier);
        return handleError(afterIP, *op);
      })
      .Case([&](omp::TaskyieldOp op) {
        if (failed(checkImplementationStatus(*op)))
          return failure();

        ompBuilder->createTaskyield(builder.saveIP());
        return success();
      })
      .Case([&](omp::FlushOp op) {
        if (failed(checkImplementationStatus(*op)))
          return failure();

        // No support in Openmp runtime function (__kmpc_flush) to accept
        // the argument list.
        // OpenMP standard states the following:
        // "An implementation may implement a flush with a list by ignoring
        // the list, and treating it the same as a flush without a list."
        //
        // The argument list is discarded so that, flush with a list is treated
        // same as a flush without a list.
        ompBuilder->createFlush(builder.saveIP());
        return success();
      })
      .Case([&](omp::ParallelOp op) {
        return convertOmpParallel(op, builder, moduleTranslation);
      })
      .Case([&](omp::MaskedOp) {
        return convertOmpMasked(*op, builder, moduleTranslation);
      })
      .Case([&](omp::MasterOp) {
        return convertOmpMaster(*op, builder, moduleTranslation);
      })
      .Case([&](omp::CriticalOp) {
        return convertOmpCritical(*op, builder, moduleTranslation);
      })
      .Case([&](omp::OrderedRegionOp) {
        return convertOmpOrderedRegion(*op, builder, moduleTranslation);
      })
      .Case([&](omp::OrderedOp) {
        return convertOmpOrdered(*op, builder, moduleTranslation);
      })
      .Case([&](omp::WsloopOp) {
        return convertOmpWsloop(*op, builder, moduleTranslation);
      })
      .Case([&](omp::SimdOp) {
        return convertOmpSimd(*op, builder, moduleTranslation);
      })
      .Case([&](omp::AtomicReadOp) {
        return convertOmpAtomicRead(*op, builder, moduleTranslation);
      })
      .Case([&](omp::AtomicWriteOp) {
        return convertOmpAtomicWrite(*op, builder, moduleTranslation);
      })
      .Case([&](omp::AtomicUpdateOp op) {
        return convertOmpAtomicUpdate(op, builder, moduleTranslation);
      })
      .Case([&](omp::AtomicCaptureOp op) {
        return convertOmpAtomicCapture(op, builder, moduleTranslation);
      })
      .Case([&](omp::SectionsOp) {
        return convertOmpSections(*op, builder, moduleTranslation);
      })
      .Case([&](omp::SingleOp op) {
        return convertOmpSingle(op, builder, moduleTranslation);
      })
      .Case([&](omp::TeamsOp op) {
        return convertOmpTeams(op, builder, moduleTranslation);
      })
      .Case([&](omp::TaskOp op) {
        return convertOmpTaskOp(op, builder, moduleTranslation);
      })
      .Case([&](omp::TaskgroupOp op) {
        return convertOmpTaskgroupOp(op, builder, moduleTranslation);
      })
      .Case([&](omp::TaskwaitOp op) {
        return convertOmpTaskwaitOp(op, builder, moduleTranslation);
      })
      .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareReductionOp,
            omp::CriticalDeclareOp>([](auto op) {
        // `yield` and `terminator` can be just omitted. The block structure
        // was created in the region that handles their parent operation.
        // `declare_reduction` will be used by reductions and is not
        // converted directly, skip it.
        // `critical.declare` is only used to declare names of critical
        // sections which will be used by `critical` ops and hence can be
        // ignored for lowering. The OpenMP IRBuilder will create unique
        // name for critical section names.
        return success();
      })
      .Case([&](omp::ThreadprivateOp) {
        return convertOmpThreadprivate(*op, builder, moduleTranslation);
      })
      .Case<omp::TargetDataOp, omp::TargetEnterDataOp, omp::TargetExitDataOp,
            omp::TargetUpdateOp>([&](auto op) {
        return convertOmpTargetData(op, builder, moduleTranslation);
      })
      .Case([&](omp::TargetOp) {
        return convertOmpTarget(*op, builder, moduleTranslation);
      })
      .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
          [&](auto op) {
            // No-op, should be handled by relevant owning operations e.g.
            // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp etc.
            // and then discarded
            return success();
          })
      .Default([&](Operation *inst) {
        return inst->emitError() << "not yet implemented: " << inst->getName();
      });
}
4661 static LogicalResult
4662 convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
4663 LLVM::ModuleTranslation &moduleTranslation) {
4664 return convertHostOrTargetOperation(op, builder, moduleTranslation);
4667 static LogicalResult
4668 convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
4669 LLVM::ModuleTranslation &moduleTranslation) {
4670 if (isa<omp::TargetOp>(op))
4671 return convertOmpTarget(*op, builder, moduleTranslation);
4672 if (isa<omp::TargetDataOp>(op))
4673 return convertOmpTargetData(op, builder, moduleTranslation);
4674 bool interrupted =
4675 op->walk<WalkOrder::PreOrder>([&](Operation *oper) {
4676 if (isa<omp::TargetOp>(oper)) {
4677 if (failed(convertOmpTarget(*oper, builder, moduleTranslation)))
4678 return WalkResult::interrupt();
4679 return WalkResult::skip();
4681 if (isa<omp::TargetDataOp>(oper)) {
4682 if (failed(convertOmpTargetData(oper, builder, moduleTranslation)))
4683 return WalkResult::interrupt();
4684 return WalkResult::skip();
4686 return WalkResult::advance();
4687 }).wasInterrupted();
4688 return failure(interrupted);
4691 namespace {
4693 /// Implementation of the dialect interface that converts operations belonging
4694 /// to the OpenMP dialect to LLVM IR.
4695 class OpenMPDialectLLVMIRTranslationInterface
4696 : public LLVMTranslationDialectInterface {
4697 public:
4698 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
4700 /// Translates the given operation to LLVM IR using the provided IR builder
4701 /// and saving the state in `moduleTranslation`.
4702 LogicalResult
4703 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
4704 LLVM::ModuleTranslation &moduleTranslation) const final;
4706 /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR,
4707 /// runtime calls, or operation amendments
4708 LogicalResult
4709 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
4710 NamedAttribute attribute,
4711 LLVM::ModuleTranslation &moduleTranslation) const final;
4714 } // namespace
4716 LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
4717 Operation *op, ArrayRef<llvm::Instruction *> instructions,
4718 NamedAttribute attribute,
4719 LLVM::ModuleTranslation &moduleTranslation) const {
4720 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
4721 attribute.getName())
4722 .Case("omp.is_target_device",
4723 [&](Attribute attr) {
4724 if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
4725 llvm::OpenMPIRBuilderConfig &config =
4726 moduleTranslation.getOpenMPBuilder()->Config;
4727 config.setIsTargetDevice(deviceAttr.getValue());
4728 return success();
4730 return failure();
4732 .Case("omp.is_gpu",
4733 [&](Attribute attr) {
4734 if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
4735 llvm::OpenMPIRBuilderConfig &config =
4736 moduleTranslation.getOpenMPBuilder()->Config;
4737 config.setIsGPU(gpuAttr.getValue());
4738 return success();
4740 return failure();
4742 .Case("omp.host_ir_filepath",
4743 [&](Attribute attr) {
4744 if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
4745 llvm::OpenMPIRBuilder *ompBuilder =
4746 moduleTranslation.getOpenMPBuilder();
4747 ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
4748 return success();
4750 return failure();
4752 .Case("omp.flags",
4753 [&](Attribute attr) {
4754 if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
4755 return convertFlagsAttr(op, rtlAttr, moduleTranslation);
4756 return failure();
4758 .Case("omp.version",
4759 [&](Attribute attr) {
4760 if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
4761 llvm::OpenMPIRBuilder *ompBuilder =
4762 moduleTranslation.getOpenMPBuilder();
4763 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
4764 versionAttr.getVersion());
4765 return success();
4767 return failure();
4769 .Case("omp.declare_target",
4770 [&](Attribute attr) {
4771 if (auto declareTargetAttr =
4772 dyn_cast<omp::DeclareTargetAttr>(attr))
4773 return convertDeclareTargetAttr(op, declareTargetAttr,
4774 moduleTranslation);
4775 return failure();
4777 .Case("omp.requires",
4778 [&](Attribute attr) {
4779 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
4780 using Requires = omp::ClauseRequires;
4781 Requires flags = requiresAttr.getValue();
4782 llvm::OpenMPIRBuilderConfig &config =
4783 moduleTranslation.getOpenMPBuilder()->Config;
4784 config.setHasRequiresReverseOffload(
4785 bitEnumContainsAll(flags, Requires::reverse_offload));
4786 config.setHasRequiresUnifiedAddress(
4787 bitEnumContainsAll(flags, Requires::unified_address));
4788 config.setHasRequiresUnifiedSharedMemory(
4789 bitEnumContainsAll(flags, Requires::unified_shared_memory));
4790 config.setHasRequiresDynamicAllocators(
4791 bitEnumContainsAll(flags, Requires::dynamic_allocators));
4792 return success();
4794 return failure();
4796 .Case("omp.target_triples",
4797 [&](Attribute attr) {
4798 if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
4799 llvm::OpenMPIRBuilderConfig &config =
4800 moduleTranslation.getOpenMPBuilder()->Config;
4801 config.TargetTriples.clear();
4802 config.TargetTriples.reserve(triplesAttr.size());
4803 for (Attribute tripleAttr : triplesAttr) {
4804 if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
4805 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
4806 else
4807 return failure();
4809 return success();
4811 return failure();
4813 .Default([](Attribute) {
4814 // Fall through for omp attributes that do not require lowering.
4815 return success();
4816 })(attribute.getValue());
4818 return failure();
4821 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
4822 /// (including OpenMP runtime calls).
4823 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
4824 Operation *op, llvm::IRBuilderBase &builder,
4825 LLVM::ModuleTranslation &moduleTranslation) const {
4827 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4828 if (ompBuilder->Config.isTargetDevice()) {
4829 if (isTargetDeviceOp(op)) {
4830 return convertTargetDeviceOp(op, builder, moduleTranslation);
4831 } else {
4832 return convertTargetOpsInNest(op, builder, moduleTranslation);
4835 return convertHostOrTargetOperation(op, builder, moduleTranslation);
4838 void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {
4839 registry.insert<omp::OpenMPDialect>();
4840 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
4841 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
4845 void mlir::registerOpenMPDialectTranslation(MLIRContext &context) {
4846 DialectRegistry registry;
4847 registerOpenMPDialectTranslation(registry);
4848 context.appendDialectRegistry(registry);