[flang] Move `genCommonBlockMember` from OpenMP to ConvertVariable, NFC (#74488)
[llvm-project.git] / flang / lib / Lower / OpenMP.cpp
blob0fa1ac76d57edbf655547b3d5d03ed57e7e12878
1 //===-- OpenMP.cpp -- Open MP directive lowering --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
11 //===----------------------------------------------------------------------===//
13 #include "flang/Lower/OpenMP.h"
14 #include "DirectivesCommon.h"
15 #include "flang/Common/idioms.h"
16 #include "flang/Lower/Bridge.h"
17 #include "flang/Lower/ConvertExpr.h"
18 #include "flang/Lower/ConvertVariable.h"
19 #include "flang/Lower/PFTBuilder.h"
20 #include "flang/Lower/StatementContext.h"
21 #include "flang/Optimizer/Builder/BoxValue.h"
22 #include "flang/Optimizer/Builder/FIRBuilder.h"
23 #include "flang/Optimizer/Builder/Todo.h"
24 #include "flang/Optimizer/HLFIR/HLFIROps.h"
25 #include "flang/Parser/dump-parse-tree.h"
26 #include "flang/Parser/parse-tree.h"
27 #include "flang/Semantics/openmp-directive-sets.h"
28 #include "flang/Semantics/tools.h"
29 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
30 #include "mlir/Dialect/SCF/IR/SCF.h"
31 #include "mlir/Transforms/RegionUtils.h"
32 #include "llvm/Frontend/OpenMP/OMPConstants.h"
33 #include "llvm/Support/CommandLine.h"
35 static llvm::cl::opt<bool> treatIndexAsSection(
36 "openmp-treat-index-as-section",
37 llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
38 llvm::cl::init(true));
40 using DeclareTargetCapturePair =
41 std::pair<mlir::omp::DeclareTargetCaptureClause,
42 Fortran::semantics::Symbol>;
44 //===----------------------------------------------------------------------===//
45 // Common helper functions
46 //===----------------------------------------------------------------------===//
48 static Fortran::semantics::Symbol *
49 getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
50 Fortran::semantics::Symbol *sym = nullptr;
51 std::visit(
52 Fortran::common::visitors{
53 [&](const Fortran::parser::Designator &designator) {
54 if (auto *arrayEle =
55 Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
56 designator)) {
57 sym = GetFirstName(arrayEle->base).symbol;
58 } else if (const Fortran::parser::Name *name =
59 Fortran::semantics::getDesignatorNameIfDataRef(
60 designator)) {
61 sym = name->symbol;
64 [&](const Fortran::parser::Name &name) { sym = name.symbol; }},
65 ompObject.u);
66 return sym;
69 static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
70 Fortran::lower::AbstractConverter &converter,
71 llvm::SmallVectorImpl<mlir::Value> &operands) {
72 auto addOperands = [&](Fortran::lower::SymbolRef sym) {
73 const mlir::Value variable = converter.getSymbolAddress(sym);
74 if (variable) {
75 operands.push_back(variable);
76 } else {
77 if (const auto *details =
78 sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
79 operands.push_back(converter.getSymbolAddress(details->symbol()));
80 converter.copySymbolBinding(details->symbol(), sym);
84 for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
85 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
86 addOperands(*sym);
90 static void gatherFuncAndVarSyms(
91 const Fortran::parser::OmpObjectList &objList,
92 mlir::omp::DeclareTargetCaptureClause clause,
93 llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
94 for (const Fortran::parser::OmpObject &ompObject : objList.v) {
95 Fortran::common::visit(
96 Fortran::common::visitors{
97 [&](const Fortran::parser::Designator &designator) {
98 if (const Fortran::parser::Name *name =
99 Fortran::semantics::getDesignatorNameIfDataRef(
100 designator)) {
101 symbolAndClause.emplace_back(clause, *name->symbol);
104 [&](const Fortran::parser::Name &name) {
105 symbolAndClause.emplace_back(clause, *name.symbol);
107 ompObject.u);
111 //===----------------------------------------------------------------------===//
112 // DataSharingProcessor
113 //===----------------------------------------------------------------------===//
115 class DataSharingProcessor {
116 bool hasLastPrivateOp;
117 mlir::OpBuilder::InsertPoint lastPrivIP;
118 mlir::OpBuilder::InsertPoint insPt;
119 mlir::Value loopIV;
120 // Symbols in private, firstprivate, and/or lastprivate clauses.
121 llvm::SetVector<const Fortran::semantics::Symbol *> privatizedSymbols;
122 llvm::SetVector<const Fortran::semantics::Symbol *> defaultSymbols;
123 llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInNestedRegions;
124 llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInParentRegions;
125 Fortran::lower::AbstractConverter &converter;
126 fir::FirOpBuilder &firOpBuilder;
127 const Fortran::parser::OmpClauseList &opClauseList;
128 Fortran::lower::pft::Evaluation &eval;
130 bool needBarrier();
131 void collectSymbols(Fortran::semantics::Symbol::Flag flag);
132 void collectOmpObjectListSymbol(
133 const Fortran::parser::OmpObjectList &ompObjectList,
134 llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet);
135 void collectSymbolsForPrivatization();
136 void insertBarrier();
137 void collectDefaultSymbols();
138 void privatize();
139 void defaultPrivatize();
140 void copyLastPrivatize(mlir::Operation *op);
141 void insertLastPrivateCompare(mlir::Operation *op);
142 void cloneSymbol(const Fortran::semantics::Symbol *sym);
143 void copyFirstPrivateSymbol(const Fortran::semantics::Symbol *sym);
144 void copyLastPrivateSymbol(const Fortran::semantics::Symbol *sym,
145 mlir::OpBuilder::InsertPoint *lastPrivIP);
146 void insertDeallocs();
148 public:
149 DataSharingProcessor(Fortran::lower::AbstractConverter &converter,
150 const Fortran::parser::OmpClauseList &opClauseList,
151 Fortran::lower::pft::Evaluation &eval)
152 : hasLastPrivateOp(false), converter(converter),
153 firOpBuilder(converter.getFirOpBuilder()), opClauseList(opClauseList),
154 eval(eval) {}
155 // Privatisation is split into two steps.
156 // Step1 performs cloning of all privatisation clauses and copying for
157 // firstprivates. Step1 is performed at the place where process/processStep1
158 // is called. This is usually inside the Operation corresponding to the OpenMP
159 // construct, for looping constructs this is just before the Operation. The
160 // split into two steps was performed basically to be able to call
161 // privatisation for looping constructs before the operation is created since
162 // the bounds of the MLIR OpenMP operation can be privatised.
163 // Step2 performs the copying for lastprivates and requires knowledge of the
164 // MLIR operation to insert the last private update. Step2 adds
165 // dealocation code as well.
166 void processStep1();
167 void processStep2(mlir::Operation *op, bool isLoop);
169 void setLoopIV(mlir::Value iv) {
170 assert(!loopIV && "Loop iteration variable already set");
171 loopIV = iv;
/// First privatization step: gather the symbols to privatize, emit their
/// clones (plus firstprivate copies), then emit a barrier when needBarrier()
/// requests one. The call order matters: the collection calls fill the symbol
/// sets that the privatize calls consume.
void DataSharingProcessor::processStep1() {
  // Explicit private/firstprivate/lastprivate clauses.
  collectSymbolsForPrivatization();
  // Symbols covered by a DEFAULT(PRIVATE|FIRSTPRIVATE) clause.
  collectDefaultSymbols();
  // Clone (and firstprivate-copy) the explicitly listed symbols...
  privatize();
  // ...then the default-privatized ones not already handled.
  defaultPrivatize();
  insertBarrier();
}
183 void DataSharingProcessor::processStep2(mlir::Operation *op, bool isLoop) {
184 insPt = firOpBuilder.saveInsertionPoint();
185 copyLastPrivatize(op);
186 firOpBuilder.restoreInsertionPoint(insPt);
188 if (isLoop) {
189 // push deallocs out of the loop
190 firOpBuilder.setInsertionPointAfter(op);
191 insertDeallocs();
192 } else {
193 // insert dummy instruction to mark the insertion position
194 mlir::Value undefMarker = firOpBuilder.create<fir::UndefOp>(
195 op->getLoc(), firOpBuilder.getIndexType());
196 insertDeallocs();
197 firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());
201 void DataSharingProcessor::insertDeallocs() {
202 for (const Fortran::semantics::Symbol *sym : privatizedSymbols)
203 if (Fortran::semantics::IsAllocatable(sym->GetUltimate())) {
204 converter.createHostAssociateVarCloneDealloc(*sym);
208 void DataSharingProcessor::cloneSymbol(const Fortran::semantics::Symbol *sym) {
209 // Privatization for symbols which are pre-determined (like loop index
210 // variables) happen separately, for everything else privatize here.
211 if (sym->test(Fortran::semantics::Symbol::Flag::OmpPreDetermined))
212 return;
213 bool success = converter.createHostAssociateVarClone(*sym);
214 (void)success;
215 assert(success && "Privatization failed due to existing binding");
218 void DataSharingProcessor::copyFirstPrivateSymbol(
219 const Fortran::semantics::Symbol *sym) {
220 if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate))
221 converter.copyHostAssociateVar(*sym);
224 void DataSharingProcessor::copyLastPrivateSymbol(
225 const Fortran::semantics::Symbol *sym,
226 [[maybe_unused]] mlir::OpBuilder::InsertPoint *lastPrivIP) {
227 if (sym->test(Fortran::semantics::Symbol::Flag::OmpLastPrivate))
228 converter.copyHostAssociateVar(*sym, lastPrivIP);
231 void DataSharingProcessor::collectOmpObjectListSymbol(
232 const Fortran::parser::OmpObjectList &ompObjectList,
233 llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet) {
234 for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
235 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
236 symbolSet.insert(sym);
240 void DataSharingProcessor::collectSymbolsForPrivatization() {
241 bool hasCollapse = false;
242 for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
243 if (const auto &privateClause =
244 std::get_if<Fortran::parser::OmpClause::Private>(&clause.u)) {
245 collectOmpObjectListSymbol(privateClause->v, privatizedSymbols);
246 } else if (const auto &firstPrivateClause =
247 std::get_if<Fortran::parser::OmpClause::Firstprivate>(
248 &clause.u)) {
249 collectOmpObjectListSymbol(firstPrivateClause->v, privatizedSymbols);
250 } else if (const auto &lastPrivateClause =
251 std::get_if<Fortran::parser::OmpClause::Lastprivate>(
252 &clause.u)) {
253 collectOmpObjectListSymbol(lastPrivateClause->v, privatizedSymbols);
254 hasLastPrivateOp = true;
255 } else if (std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) {
256 hasCollapse = true;
260 if (hasCollapse && hasLastPrivateOp)
261 TODO(converter.getCurrentLocation(), "Collapse clause with lastprivate");
264 bool DataSharingProcessor ::needBarrier() {
265 for (const Fortran::semantics::Symbol *sym : privatizedSymbols) {
266 if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate) &&
267 sym->test(Fortran::semantics::Symbol::Flag::OmpLastPrivate))
268 return true;
270 return false;
273 void DataSharingProcessor ::insertBarrier() {
274 // Emit implicit barrier to synchronize threads and avoid data races on
275 // initialization of firstprivate variables and post-update of lastprivate
276 // variables.
277 // FIXME: Emit barrier for lastprivate clause when 'sections' directive has
278 // 'nowait' clause. Otherwise, emit barrier when 'sections' directive has
279 // both firstprivate and lastprivate clause.
280 // Emit implicit barrier for linear clause. Maybe on somewhere else.
281 if (needBarrier())
282 firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation());
/// Prepare `lastPrivIP`, the insertion point where lastprivate copy-back code
/// for \p op will be emitted, creating any guarding control flow it needs.
///
/// - `omp.section`: only the lexically last section of the enclosing sections
///   construct gets an update point. In structured code it is inside an
///   always-true `fir.if` (and the member `insPt` is set to just before that
///   `fir.if`); in unstructured code it is at the start of the last block of
///   the section's region.
/// - `omp.wsloop`: the update is guarded by a compare that detects the last
///   iteration, generated at most once.
/// - any other op: TODO (not yet supported).
///
/// The builder's insertion point on entry is restored before returning.
void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
  bool cmpCreated = false;
  mlir::OpBuilder::InsertPoint localInsPt = firOpBuilder.saveInsertionPoint();
  for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
    if (std::get_if<Fortran::parser::OmpClause::Lastprivate>(&clause.u)) {
      // TODO: Add lastprivate support for simd construct
      if (mlir::isa<mlir::omp::SectionOp>(op)) {
        if (&eval == &eval.parentConstruct->getLastNestedEvaluation()) {
          // For `omp.sections`, lastprivatized variables occur in
          // lexically final `omp.section` operation. The following FIR
          // shall be generated for the same:
          //
          // omp.sections lastprivate(...) {
          //  omp.section {...}
          //  omp.section {...}
          //  omp.section {
          //      fir.allocate for `private`/`firstprivate`
          //      <More operations here>
          //      fir.if %true {
          //          ^%lpv_update_blk
          //      }
          //  }
          // }
          //
          // To keep code consistency while handling privatization
          // through this control flow, add a `fir.if` operation
          // that always evaluates to true, in order to create
          // a dedicated sub-region in `omp.section` where
          // lastprivate FIR can reside. Later canonicalizations
          // will optimize away this operation.
          if (!eval.lowerAsUnstructured()) {
            auto ifOp = firOpBuilder.create<fir::IfOp>(
                op->getLoc(),
                firOpBuilder.createIntegerConstant(
                    op->getLoc(), firOpBuilder.getIntegerType(1), 0x1),
                /*else*/ false);
            firOpBuilder.setInsertionPointToStart(
                &ifOp.getThenRegion().front());

            // Look up the enclosing SECTIONS construct to inspect the clauses
            // on its END SECTIONS directive.
            const Fortran::parser::OpenMPConstruct *parentOmpConstruct =
                eval.parentConstruct->getIf<Fortran::parser::OpenMPConstruct>();
            assert(parentOmpConstruct &&
                   "Expected a valid enclosing OpenMP construct");
            const Fortran::parser::OpenMPSectionsConstruct *sectionsConstruct =
                std::get_if<Fortran::parser::OpenMPSectionsConstruct>(
                    &parentOmpConstruct->u);
            assert(sectionsConstruct &&
                   "Expected an enclosing omp.sections construct");
            const Fortran::parser::OmpClauseList &sectionsEndClauseList =
                std::get<Fortran::parser::OmpClauseList>(
                    std::get<Fortran::parser::OmpEndSectionsDirective>(
                        sectionsConstruct->t)
                        .t);
            for (const Fortran::parser::OmpClause &otherClause :
                 sectionsEndClauseList.v)
              if (std::get_if<Fortran::parser::OmpClause::Nowait>(
                      &otherClause.u))
                // Emit implicit barrier to synchronize threads and avoid data
                // races on post-update of lastprivate variables when `nowait`
                // clause is present.
                firOpBuilder.create<mlir::omp::BarrierOp>(
                    converter.getCurrentLocation());
            // NOTE(review): this resets the insertion point to the start of
            // the then-block, so code later emitted at `lastPrivIP` lands
            // BEFORE any barrier created just above. Presumably intentional
            // (update, then synchronize) — confirm against the tests.
            firOpBuilder.setInsertionPointToStart(
                &ifOp.getThenRegion().front());
            lastPrivIP = firOpBuilder.saveInsertionPoint();
            // `insPt` is restored by processStep2 after copyLastPrivatize.
            firOpBuilder.setInsertionPoint(ifOp);
            insPt = firOpBuilder.saveInsertionPoint();
          } else {
            // Lastprivate operation is inserted at the end
            // of the lexically last section in the sections
            // construct
            mlir::OpBuilder::InsertPoint unstructuredSectionsIP =
                firOpBuilder.saveInsertionPoint();
            firOpBuilder.setInsertionPointToStart(&op->getRegion(0).back());
            lastPrivIP = firOpBuilder.saveInsertionPoint();
            firOpBuilder.restoreInsertionPoint(unstructuredSectionsIP);
          }
        }
      } else if (mlir::isa<mlir::omp::WsLoopOp>(op)) {
        // Update the original variable just before exiting the worksharing
        // loop. Conversion as follows:
        //
        //                       omp.wsloop {
        // omp.wsloop {            ...
        //    ...                  store
        //    store       ===>     %v = arith.addi %iv, %step
        //    omp.yield            %cmp = %step < 0 ? %v < %ub : %v > %ub
        // }                       fir.if %cmp {
        //                           fir.store %v to %loopIV
        //                           ^%lpv_update_blk:
        //                         }
        //                         omp.yield
        //                       }
        //

        // Only generate the compare once in presence of multiple LastPrivate
        // clauses.
        if (cmpCreated)
          continue;
        cmpCreated = true;

        // Insert the compare just before the loop body's terminator.
        mlir::Location loc = op->getLoc();
        mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
        firOpBuilder.setInsertionPoint(lastOper);

        // Only the first (outermost) induction variable, upper bound and
        // step are used here.
        mlir::Value iv = op->getRegion(0).front().getArguments()[0];
        mlir::Value ub =
            mlir::dyn_cast<mlir::omp::WsLoopOp>(op).getUpperBound()[0];
        mlir::Value step = mlir::dyn_cast<mlir::omp::WsLoopOp>(op).getStep()[0];

        // v = iv + step
        // cmp = step < 0 ? v < ub : v > ub
        mlir::Value v = firOpBuilder.create<mlir::arith::AddIOp>(loc, iv, step);
        mlir::Value zero =
            firOpBuilder.createIntegerConstant(loc, step.getType(), 0);
        mlir::Value negativeStep = firOpBuilder.create<mlir::arith::CmpIOp>(
            loc, mlir::arith::CmpIPredicate::slt, step, zero);
        mlir::Value vLT = firOpBuilder.create<mlir::arith::CmpIOp>(
            loc, mlir::arith::CmpIPredicate::slt, v, ub);
        mlir::Value vGT = firOpBuilder.create<mlir::arith::CmpIOp>(
            loc, mlir::arith::CmpIPredicate::sgt, v, ub);
        mlir::Value cmpOp = firOpBuilder.create<mlir::arith::SelectOp>(
            loc, negativeStep, vLT, vGT);

        // Inside the guard, store the post-loop value of the induction
        // variable, then leave lastPrivIP after that store.
        auto ifOp = firOpBuilder.create<fir::IfOp>(loc, cmpOp, /*else*/ false);
        firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
        assert(loopIV && "loopIV was not set");
        firOpBuilder.create<fir::StoreOp>(op->getLoc(), v, loopIV);
        lastPrivIP = firOpBuilder.saveInsertionPoint();
      } else {
        TODO(converter.getCurrentLocation(),
             "lastprivate clause in constructs other than "
             "simd/worksharing-loop");
      }
    }
  }
  firOpBuilder.restoreInsertionPoint(localInsPt);
}
424 void DataSharingProcessor::collectSymbols(
425 Fortran::semantics::Symbol::Flag flag) {
426 converter.collectSymbolSet(eval, defaultSymbols, flag,
427 /*collectSymbols=*/true,
428 /*collectHostAssociatedSymbols=*/true);
429 for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) {
430 if (e.hasNestedEvaluations())
431 converter.collectSymbolSet(e, symbolsInNestedRegions, flag,
432 /*collectSymbols=*/true,
433 /*collectHostAssociatedSymbols=*/false);
434 else
435 converter.collectSymbolSet(e, symbolsInParentRegions, flag,
436 /*collectSymbols=*/false,
437 /*collectHostAssociatedSymbols=*/true);
441 void DataSharingProcessor::collectDefaultSymbols() {
442 for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
443 if (const auto &defaultClause =
444 std::get_if<Fortran::parser::OmpClause::Default>(&clause.u)) {
445 if (defaultClause->v.v ==
446 Fortran::parser::OmpDefaultClause::Type::Private)
447 collectSymbols(Fortran::semantics::Symbol::Flag::OmpPrivate);
448 else if (defaultClause->v.v ==
449 Fortran::parser::OmpDefaultClause::Type::Firstprivate)
450 collectSymbols(Fortran::semantics::Symbol::Flag::OmpFirstPrivate);
455 void DataSharingProcessor::privatize() {
456 for (const Fortran::semantics::Symbol *sym : privatizedSymbols) {
457 if (const auto *commonDet =
458 sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
459 for (const auto &mem : commonDet->objects()) {
460 cloneSymbol(&*mem);
461 copyFirstPrivateSymbol(&*mem);
463 } else {
464 cloneSymbol(sym);
465 copyFirstPrivateSymbol(sym);
470 void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) {
471 insertLastPrivateCompare(op);
472 for (const Fortran::semantics::Symbol *sym : privatizedSymbols)
473 if (const auto *commonDet =
474 sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
475 for (const auto &mem : commonDet->objects()) {
476 copyLastPrivateSymbol(&*mem, &lastPrivIP);
478 } else {
479 copyLastPrivateSymbol(sym, &lastPrivIP);
483 void DataSharingProcessor::defaultPrivatize() {
484 for (const Fortran::semantics::Symbol *sym : defaultSymbols) {
485 if (!Fortran::semantics::IsProcedure(*sym) &&
486 !sym->GetUltimate().has<Fortran::semantics::DerivedTypeDetails>() &&
487 !sym->GetUltimate().has<Fortran::semantics::NamelistDetails>() &&
488 !symbolsInNestedRegions.contains(sym) &&
489 !symbolsInParentRegions.contains(sym) &&
490 !privatizedSymbols.contains(sym)) {
491 cloneSymbol(sym);
492 copyFirstPrivateSymbol(sym);
497 //===----------------------------------------------------------------------===//
498 // ClauseProcessor
499 //===----------------------------------------------------------------------===//
501 /// Class that handles the processing of OpenMP clauses.
503 /// Its `process<ClauseName>()` methods perform MLIR code generation for their
504 /// corresponding clause if it is present in the clause list. Otherwise, they
505 /// will return `false` to signal that the clause was not found.
507 /// The intended use is of this class is to move clause processing outside of
508 /// construct processing, since the same clauses can appear attached to
509 /// different constructs and constructs can be combined, so that code
510 /// duplication is minimized.
512 /// Each construct-lowering function only calls the `process<ClauseName>()`
513 /// methods that relate to clauses that can impact the lowering of that
514 /// construct.
515 class ClauseProcessor {
516 using ClauseTy = Fortran::parser::OmpClause;
518 public:
519 ClauseProcessor(Fortran::lower::AbstractConverter &converter,
520 const Fortran::parser::OmpClauseList &clauses)
521 : converter(converter), clauses(clauses) {}
523 // 'Unique' clauses: They can appear at most once in the clause list.
524 bool
525 processCollapse(mlir::Location currentLocation,
526 Fortran::lower::pft::Evaluation &eval,
527 llvm::SmallVectorImpl<mlir::Value> &lowerBound,
528 llvm::SmallVectorImpl<mlir::Value> &upperBound,
529 llvm::SmallVectorImpl<mlir::Value> &step,
530 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
531 std::size_t &loopVarTypeSize) const;
532 bool processDefault() const;
533 bool processDevice(Fortran::lower::StatementContext &stmtCtx,
534 mlir::Value &result) const;
535 bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const;
536 bool processFinal(Fortran::lower::StatementContext &stmtCtx,
537 mlir::Value &result) const;
538 bool processHint(mlir::IntegerAttr &result) const;
539 bool processMergeable(mlir::UnitAttr &result) const;
540 bool processNowait(mlir::UnitAttr &result) const;
541 bool processNumTeams(Fortran::lower::StatementContext &stmtCtx,
542 mlir::Value &result) const;
543 bool processNumThreads(Fortran::lower::StatementContext &stmtCtx,
544 mlir::Value &result) const;
545 bool processOrdered(mlir::IntegerAttr &result) const;
546 bool processPriority(Fortran::lower::StatementContext &stmtCtx,
547 mlir::Value &result) const;
548 bool processProcBind(mlir::omp::ClauseProcBindKindAttr &result) const;
549 bool processSafelen(mlir::IntegerAttr &result) const;
550 bool processSchedule(mlir::omp::ClauseScheduleKindAttr &valAttr,
551 mlir::omp::ScheduleModifierAttr &modifierAttr,
552 mlir::UnitAttr &simdModifierAttr) const;
553 bool processScheduleChunk(Fortran::lower::StatementContext &stmtCtx,
554 mlir::Value &result) const;
555 bool processSimdlen(mlir::IntegerAttr &result) const;
556 bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx,
557 mlir::Value &result) const;
558 bool processUntied(mlir::UnitAttr &result) const;
560 // 'Repeatable' clauses: They can appear multiple times in the clause list.
561 bool
562 processAllocate(llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
563 llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const;
564 bool processCopyin() const;
565 bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
566 llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
567 bool
568 processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
569 bool
570 processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
571 mlir::Value &result) const;
572 bool
573 processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
575 // This method is used to process a map clause.
576 // The optional parameters - mapSymTypes, mapSymLocs & mapSymbols are used to
577 // store the original type, location and Fortran symbol for the map operands.
578 // They may be used later on to create the block_arguments for some of the
579 // target directives that require it.
580 bool processMap(mlir::Location currentLocation,
581 const llvm::omp::Directive &directive,
582 Fortran::semantics::SemanticsContext &semanticsContext,
583 Fortran::lower::StatementContext &stmtCtx,
584 llvm::SmallVectorImpl<mlir::Value> &mapOperands,
585 llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr,
586 llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
587 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
588 *mapSymbols = nullptr) const;
589 bool processReduction(
590 mlir::Location currentLocation,
591 llvm::SmallVectorImpl<mlir::Value> &reductionVars,
592 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const;
593 bool processSectionsReduction(mlir::Location currentLocation) const;
594 bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
595 bool
596 processUseDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
597 llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
598 llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
599 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
600 &useDeviceSymbols) const;
601 bool
602 processUseDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
603 llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
604 llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
605 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
606 &useDeviceSymbols) const;
608 // Call this method for these clauses that should be supported but are not
609 // implemented yet. It triggers a compilation error if any of the given
610 // clauses is found.
611 template <typename... Ts>
612 void processTODO(mlir::Location currentLocation,
613 llvm::omp::Directive directive) const;
615 private:
616 using ClauseIterator = std::list<ClauseTy>::const_iterator;
618 /// Utility to find a clause within a range in the clause list.
619 template <typename T>
620 static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end) {
621 for (ClauseIterator it = begin; it != end; ++it) {
622 if (std::get_if<T>(&it->u))
623 return it;
626 return end;
629 /// Return the first instance of the given clause found in the clause list or
630 /// `nullptr` if not present. If more than one instance is expected, use
631 /// `findRepeatableClause` instead.
632 template <typename T>
633 const T *
634 findUniqueClause(const Fortran::parser::CharBlock **source = nullptr) const {
635 ClauseIterator it = findClause<T>(clauses.v.begin(), clauses.v.end());
636 if (it != clauses.v.end()) {
637 if (source)
638 *source = &it->source;
639 return &std::get<T>(it->u);
641 return nullptr;
644 /// Call `callbackFn` for each occurrence of the given clause. Return `true`
645 /// if at least one instance was found.
646 template <typename T>
647 bool findRepeatableClause(
648 std::function<void(const T *, const Fortran::parser::CharBlock &source)>
649 callbackFn) const {
650 bool found = false;
651 ClauseIterator nextIt, endIt = clauses.v.end();
652 for (ClauseIterator it = clauses.v.begin(); it != endIt; it = nextIt) {
653 nextIt = findClause<T>(it, endIt);
655 if (nextIt != endIt) {
656 callbackFn(&std::get<T>(nextIt->u), nextIt->source);
657 found = true;
658 ++nextIt;
661 return found;
664 /// Set the `result` to a new `mlir::UnitAttr` if the clause is present.
665 template <typename T>
666 bool markClauseOccurrence(mlir::UnitAttr &result) const {
667 if (findUniqueClause<T>()) {
668 result = converter.getFirOpBuilder().getUnitAttr();
669 return true;
671 return false;
674 Fortran::lower::AbstractConverter &converter;
675 const Fortran::parser::OmpClauseList &clauses;
678 //===----------------------------------------------------------------------===//
679 // ClauseProcessor helper functions
680 //===----------------------------------------------------------------------===//
682 /// Check for unsupported map operand types.
683 static void checkMapType(mlir::Location location, mlir::Type type) {
684 if (auto refType = type.dyn_cast<fir::ReferenceType>())
685 type = refType.getElementType();
686 if (auto boxType = type.dyn_cast_or_null<fir::BoxType>())
687 if (!boxType.getElementType().isa<fir::PointerType>())
688 TODO(location, "OMPD_target_data MapOperand BoxType");
691 static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
692 return (llvm::Twine(name) +
693 (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
694 llvm::Twine(ty.getIntOrFloatBitWidth()))
695 .str();
698 static std::string getReductionName(
699 Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
700 mlir::Type ty) {
701 std::string reductionName;
703 switch (intrinsicOp) {
704 case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
705 reductionName = "add_reduction";
706 break;
707 case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
708 reductionName = "multiply_reduction";
709 break;
710 case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
711 return "and_reduction";
712 case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
713 return "eqv_reduction";
714 case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
715 return "or_reduction";
716 case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
717 return "neqv_reduction";
718 default:
719 reductionName = "other_reduction";
720 break;
723 return getReductionName(reductionName, ty);
726 /// This function returns the identity value of the operator \p reductionOpName.
727 /// For example:
728 /// 0 + x = x,
729 /// 1 * x = x
730 static int getOperationIdentity(llvm::StringRef reductionOpName,
731 mlir::Location loc) {
732 if (reductionOpName.contains("add") || reductionOpName.contains("or") ||
733 reductionOpName.contains("neqv"))
734 return 0;
735 if (reductionOpName.contains("multiply") || reductionOpName.contains("and") ||
736 reductionOpName.contains("eqv"))
737 return 1;
738 TODO(loc, "Reduction of some intrinsic operators is not supported");
741 static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
742 llvm::StringRef reductionOpName,
743 fir::FirOpBuilder &builder) {
744 assert((fir::isa_integer(type) || fir::isa_real(type) ||
745 type.isa<fir::LogicalType>()) &&
746 "only integer, logical and real types are currently supported");
747 if (reductionOpName.contains("max")) {
748 if (auto ty = type.dyn_cast<mlir::FloatType>()) {
749 const llvm::fltSemantics &sem = ty.getFloatSemantics();
750 return builder.createRealConstant(
751 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
753 unsigned bits = type.getIntOrFloatBitWidth();
754 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
755 return builder.createIntegerConstant(loc, type, minInt);
756 } else if (reductionOpName.contains("min")) {
757 if (auto ty = type.dyn_cast<mlir::FloatType>()) {
758 const llvm::fltSemantics &sem = ty.getFloatSemantics();
759 return builder.createRealConstant(
760 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
762 unsigned bits = type.getIntOrFloatBitWidth();
763 int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
764 return builder.createIntegerConstant(loc, type, maxInt);
765 } else if (reductionOpName.contains("ior")) {
766 unsigned bits = type.getIntOrFloatBitWidth();
767 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
768 return builder.createIntegerConstant(loc, type, zeroInt);
769 } else if (reductionOpName.contains("ieor")) {
770 unsigned bits = type.getIntOrFloatBitWidth();
771 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
772 return builder.createIntegerConstant(loc, type, zeroInt);
773 } else if (reductionOpName.contains("iand")) {
774 unsigned bits = type.getIntOrFloatBitWidth();
775 int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
776 return builder.createIntegerConstant(loc, type, allOnInt);
777 } else {
778 if (type.isa<mlir::FloatType>())
779 return builder.create<mlir::arith::ConstantOp>(
780 loc, type,
781 builder.getFloatAttr(
782 type, (double)getOperationIdentity(reductionOpName, loc)));
784 if (type.isa<fir::LogicalType>()) {
785 mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
786 loc, builder.getI1Type(),
787 builder.getIntegerAttr(builder.getI1Type(),
788 getOperationIdentity(reductionOpName, loc)));
789 return builder.createConvert(loc, type, intConst);
792 return builder.create<mlir::arith::ConstantOp>(
793 loc, type,
794 builder.getIntegerAttr(type,
795 getOperationIdentity(reductionOpName, loc)));
799 template <typename FloatOp, typename IntegerOp>
800 static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
801 mlir::Type type, mlir::Location loc,
802 mlir::Value op1, mlir::Value op2) {
803 assert(type.isIntOrIndexOrFloat() &&
804 "only integer and float types are currently supported");
805 if (type.isIntOrIndex())
806 return builder.create<IntegerOp>(loc, op1, op2);
807 return builder.create<FloatOp>(loc, op1, op2);
810 static mlir::omp::ReductionDeclareOp
811 createMinimalReductionDecl(fir::FirOpBuilder &builder,
812 llvm::StringRef reductionOpName, mlir::Type type,
813 mlir::Location loc) {
814 mlir::ModuleOp module = builder.getModule();
815 mlir::OpBuilder modBuilder(module.getBodyRegion());
817 mlir::omp::ReductionDeclareOp decl =
818 modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
819 type);
820 builder.createBlock(&decl.getInitializerRegion(),
821 decl.getInitializerRegion().end(), {type}, {loc});
822 builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
823 mlir::Value init = getReductionInitValue(loc, type, reductionOpName, builder);
824 builder.create<mlir::omp::YieldOp>(loc, init);
826 builder.createBlock(&decl.getReductionRegion(),
827 decl.getReductionRegion().end(), {type, type},
828 {loc, loc});
830 return decl;
833 /// Creates an OpenMP reduction declaration and inserts it into the provided
834 /// symbol table. The declaration has a constant initializer with the neutral
835 /// value `initValue`, and the reduction combiner carried over from `reduce`.
836 /// TODO: Generalize this for non-integer types, add atomic region.
837 static mlir::omp::ReductionDeclareOp
838 createReductionDecl(fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
839 const Fortran::parser::ProcedureDesignator &procDesignator,
840 mlir::Type type, mlir::Location loc) {
841 mlir::OpBuilder::InsertionGuard guard(builder);
842 mlir::ModuleOp module = builder.getModule();
844 auto decl =
845 module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
846 if (decl)
847 return decl;
849 decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
850 builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
851 mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
852 mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
854 mlir::Value reductionOp;
855 if (const auto *name{
856 Fortran::parser::Unwrap<Fortran::parser::Name>(procDesignator)}) {
857 if (name->source == "max") {
858 reductionOp =
859 getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
860 builder, type, loc, op1, op2);
861 } else if (name->source == "min") {
862 reductionOp =
863 getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
864 builder, type, loc, op1, op2);
865 } else if (name->source == "ior") {
866 assert((type.isIntOrIndex()) && "only integer is expected");
867 reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
868 } else if (name->source == "ieor") {
869 assert((type.isIntOrIndex()) && "only integer is expected");
870 reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
871 } else if (name->source == "iand") {
872 assert((type.isIntOrIndex()) && "only integer is expected");
873 reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
874 } else {
875 TODO(loc, "Reduction of some intrinsic operators is not supported");
879 builder.create<mlir::omp::YieldOp>(loc, reductionOp);
880 return decl;
883 /// Creates an OpenMP reduction declaration and inserts it into the provided
884 /// symbol table. The declaration has a constant initializer with the neutral
885 /// value `initValue`, and the reduction combiner carried over from `reduce`.
886 /// TODO: Generalize this for non-integer types, add atomic region.
887 static mlir::omp::ReductionDeclareOp createReductionDecl(
888 fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
889 Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
890 mlir::Type type, mlir::Location loc) {
891 mlir::OpBuilder::InsertionGuard guard(builder);
892 mlir::ModuleOp module = builder.getModule();
894 auto decl =
895 module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
896 if (decl)
897 return decl;
899 decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
900 builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
901 mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
902 mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
904 mlir::Value reductionOp;
905 switch (intrinsicOp) {
906 case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
907 reductionOp =
908 getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
909 builder, type, loc, op1, op2);
910 break;
911 case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
912 reductionOp =
913 getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
914 builder, type, loc, op1, op2);
915 break;
916 case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: {
917 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
918 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
920 mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
922 reductionOp = builder.createConvert(loc, type, andiOp);
923 break;
925 case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: {
926 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
927 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
929 mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
931 reductionOp = builder.createConvert(loc, type, oriOp);
932 break;
934 case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: {
935 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
936 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
938 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
939 loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
941 reductionOp = builder.createConvert(loc, type, cmpiOp);
942 break;
944 case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: {
945 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
946 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
948 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
949 loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
951 reductionOp = builder.createConvert(loc, type, cmpiOp);
952 break;
954 default:
955 TODO(loc, "Reduction of some intrinsic operators is not supported");
958 builder.create<mlir::omp::YieldOp>(loc, reductionOp);
959 return decl;
962 static mlir::omp::ScheduleModifier
963 translateScheduleModifier(const Fortran::parser::OmpScheduleModifierType &m) {
964 switch (m.v) {
965 case Fortran::parser::OmpScheduleModifierType::ModType::Monotonic:
966 return mlir::omp::ScheduleModifier::monotonic;
967 case Fortran::parser::OmpScheduleModifierType::ModType::Nonmonotonic:
968 return mlir::omp::ScheduleModifier::nonmonotonic;
969 case Fortran::parser::OmpScheduleModifierType::ModType::Simd:
970 return mlir::omp::ScheduleModifier::simd;
972 return mlir::omp::ScheduleModifier::none;
975 static mlir::omp::ScheduleModifier
976 getScheduleModifier(const Fortran::parser::OmpScheduleClause &x) {
977 const auto &modifier =
978 std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t);
979 // The input may have the modifier any order, so we look for one that isn't
980 // SIMD. If modifier is not set at all, fall down to the bottom and return
981 // "none".
982 if (modifier) {
983 const auto &modType1 =
984 std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t);
985 if (modType1.v.v ==
986 Fortran::parser::OmpScheduleModifierType::ModType::Simd) {
987 const auto &modType2 = std::get<
988 std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>(
989 modifier->t);
990 if (modType2 &&
991 modType2->v.v !=
992 Fortran::parser::OmpScheduleModifierType::ModType::Simd)
993 return translateScheduleModifier(modType2->v);
995 return mlir::omp::ScheduleModifier::none;
998 return translateScheduleModifier(modType1.v);
1000 return mlir::omp::ScheduleModifier::none;
1003 static mlir::omp::ScheduleModifier
1004 getSimdModifier(const Fortran::parser::OmpScheduleClause &x) {
1005 const auto &modifier =
1006 std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t);
1007 // Either of the two possible modifiers in the input can be the SIMD modifier,
1008 // so look in either one, and return simd if we find one. Not found = return
1009 // "none".
1010 if (modifier) {
1011 const auto &modType1 =
1012 std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t);
1013 if (modType1.v.v == Fortran::parser::OmpScheduleModifierType::ModType::Simd)
1014 return mlir::omp::ScheduleModifier::simd;
1016 const auto &modType2 = std::get<
1017 std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>(
1018 modifier->t);
1019 if (modType2 && modType2->v.v ==
1020 Fortran::parser::OmpScheduleModifierType::ModType::Simd)
1021 return mlir::omp::ScheduleModifier::simd;
1023 return mlir::omp::ScheduleModifier::none;
1026 static void
1027 genAllocateClause(Fortran::lower::AbstractConverter &converter,
1028 const Fortran::parser::OmpAllocateClause &ompAllocateClause,
1029 llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
1030 llvm::SmallVectorImpl<mlir::Value> &allocateOperands) {
1031 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1032 mlir::Location currentLocation = converter.getCurrentLocation();
1033 Fortran::lower::StatementContext stmtCtx;
1035 mlir::Value allocatorOperand;
1036 const Fortran::parser::OmpObjectList &ompObjectList =
1037 std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t);
1038 const auto &allocateModifier = std::get<
1039 std::optional<Fortran::parser::OmpAllocateClause::AllocateModifier>>(
1040 ompAllocateClause.t);
1042 // If the allocate modifier is present, check if we only use the allocator
1043 // submodifier. ALIGN in this context is unimplemented
1044 const bool onlyAllocator =
1045 allocateModifier &&
1046 std::holds_alternative<
1047 Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>(
1048 allocateModifier->u);
1050 if (allocateModifier && !onlyAllocator) {
1051 TODO(currentLocation, "OmpAllocateClause ALIGN modifier");
1054 // Check if allocate clause has allocator specified. If so, add it
1055 // to list of allocators, otherwise, add default allocator to
1056 // list of allocators.
1057 if (onlyAllocator) {
1058 const auto &allocatorValue = std::get<
1059 Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>(
1060 allocateModifier->u);
1061 allocatorOperand = fir::getBase(converter.genExprValue(
1062 *Fortran::semantics::GetExpr(allocatorValue.v), stmtCtx));
1063 allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
1064 allocatorOperand);
1065 } else {
1066 allocatorOperand = firOpBuilder.createIntegerConstant(
1067 currentLocation, firOpBuilder.getI32Type(), 1);
1068 allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
1069 allocatorOperand);
1071 genObjectList(ompObjectList, converter, allocateOperands);
1074 static mlir::omp::ClauseProcBindKindAttr genProcBindKindAttr(
1075 fir::FirOpBuilder &firOpBuilder,
1076 const Fortran::parser::OmpClause::ProcBind *procBindClause) {
1077 mlir::omp::ClauseProcBindKind procBindKind;
1078 switch (procBindClause->v.v) {
1079 case Fortran::parser::OmpProcBindClause::Type::Master:
1080 procBindKind = mlir::omp::ClauseProcBindKind::Master;
1081 break;
1082 case Fortran::parser::OmpProcBindClause::Type::Close:
1083 procBindKind = mlir::omp::ClauseProcBindKind::Close;
1084 break;
1085 case Fortran::parser::OmpProcBindClause::Type::Spread:
1086 procBindKind = mlir::omp::ClauseProcBindKind::Spread;
1087 break;
1088 case Fortran::parser::OmpProcBindClause::Type::Primary:
1089 procBindKind = mlir::omp::ClauseProcBindKind::Primary;
1090 break;
1092 return mlir::omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(),
1093 procBindKind);
1096 static mlir::omp::ClauseTaskDependAttr
1097 genDependKindAttr(fir::FirOpBuilder &firOpBuilder,
1098 const Fortran::parser::OmpClause::Depend *dependClause) {
1099 mlir::omp::ClauseTaskDepend pbKind;
1100 switch (
1101 std::get<Fortran::parser::OmpDependenceType>(
1102 std::get<Fortran::parser::OmpDependClause::InOut>(dependClause->v.u)
1104 .v) {
1105 case Fortran::parser::OmpDependenceType::Type::In:
1106 pbKind = mlir::omp::ClauseTaskDepend::taskdependin;
1107 break;
1108 case Fortran::parser::OmpDependenceType::Type::Out:
1109 pbKind = mlir::omp::ClauseTaskDepend::taskdependout;
1110 break;
1111 case Fortran::parser::OmpDependenceType::Type::Inout:
1112 pbKind = mlir::omp::ClauseTaskDepend::taskdependinout;
1113 break;
1114 default:
1115 llvm_unreachable("unknown parser task dependence type");
1116 break;
1118 return mlir::omp::ClauseTaskDependAttr::get(firOpBuilder.getContext(),
1119 pbKind);
1122 static mlir::Value getIfClauseOperand(
1123 Fortran::lower::AbstractConverter &converter,
1124 const Fortran::parser::OmpClause::If *ifClause,
1125 Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
1126 mlir::Location clauseLocation) {
1127 // Only consider the clause if it's intended for the given directive.
1128 auto &directive = std::get<
1129 std::optional<Fortran::parser::OmpIfClause::DirectiveNameModifier>>(
1130 ifClause->v.t);
1131 if (directive && directive.value() != directiveName)
1132 return nullptr;
1134 Fortran::lower::StatementContext stmtCtx;
1135 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1136 auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
1137 mlir::Value ifVal = fir::getBase(
1138 converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
1139 return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(),
1140 ifVal);
1143 /// Creates a reduction declaration and associates it with an OpenMP block
1144 /// directive.
1145 static void
1146 addReductionDecl(mlir::Location currentLocation,
1147 Fortran::lower::AbstractConverter &converter,
1148 const Fortran::parser::OmpReductionClause &reduction,
1149 llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1150 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
1151 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1152 mlir::omp::ReductionDeclareOp decl;
1153 const auto &redOperator{
1154 std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
1155 const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
1156 if (const auto &redDefinedOp =
1157 std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
1158 const auto &intrinsicOp{
1159 std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
1160 redDefinedOp->u)};
1161 switch (intrinsicOp) {
1162 case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
1163 case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
1164 case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
1165 case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
1166 case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
1167 case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
1168 break;
1170 default:
1171 TODO(currentLocation,
1172 "Reduction of some intrinsic operators is not supported");
1173 break;
1175 for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
1176 if (const auto *name{
1177 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1178 if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
1179 mlir::Value symVal = converter.getSymbolAddress(*symbol);
1180 if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
1181 symVal = declOp.getBase();
1182 mlir::Type redType =
1183 symVal.getType().cast<fir::ReferenceType>().getEleTy();
1184 reductionVars.push_back(symVal);
1185 if (redType.isa<fir::LogicalType>())
1186 decl = createReductionDecl(
1187 firOpBuilder,
1188 getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
1189 intrinsicOp, redType, currentLocation);
1190 else if (redType.isIntOrIndexOrFloat()) {
1191 decl = createReductionDecl(firOpBuilder,
1192 getReductionName(intrinsicOp, redType),
1193 intrinsicOp, redType, currentLocation);
1194 } else {
1195 TODO(currentLocation, "Reduction of some types is not supported");
1197 reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
1198 firOpBuilder.getContext(), decl.getSymName()));
1202 } else if (const auto *reductionIntrinsic =
1203 std::get_if<Fortran::parser::ProcedureDesignator>(
1204 &redOperator.u)) {
1205 if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
1206 reductionIntrinsic)}) {
1207 if ((name->source != "max") && (name->source != "min") &&
1208 (name->source != "ior") && (name->source != "ieor") &&
1209 (name->source != "iand")) {
1210 TODO(currentLocation,
1211 "Reduction of intrinsic procedures is not supported");
1213 std::string intrinsicOp = name->ToString();
1214 for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
1215 if (const auto *name{
1216 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1217 if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
1218 mlir::Value symVal = converter.getSymbolAddress(*symbol);
1219 if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
1220 symVal = declOp.getBase();
1221 mlir::Type redType =
1222 symVal.getType().cast<fir::ReferenceType>().getEleTy();
1223 reductionVars.push_back(symVal);
1224 assert(redType.isIntOrIndexOrFloat() &&
1225 "Unsupported reduction type");
1226 decl = createReductionDecl(
1227 firOpBuilder, getReductionName(intrinsicOp, redType),
1228 *reductionIntrinsic, redType, currentLocation);
1229 reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
1230 firOpBuilder.getContext(), decl.getSymName()));
1238 static void
1239 addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
1240 const Fortran::parser::OmpObjectList &useDeviceClause,
1241 llvm::SmallVectorImpl<mlir::Value> &operands,
1242 llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
1243 llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
1244 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
1245 &useDeviceSymbols) {
1246 genObjectList(useDeviceClause, converter, operands);
1247 for (mlir::Value &operand : operands) {
1248 checkMapType(operand.getLoc(), operand.getType());
1249 useDeviceTypes.push_back(operand.getType());
1250 useDeviceLocs.push_back(operand.getLoc());
1252 for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) {
1253 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
1254 useDeviceSymbols.push_back(sym);
//===----------------------------------------------------------------------===//
// ClauseProcessor unique clauses
//===----------------------------------------------------------------------===//
1262 bool ClauseProcessor::processCollapse(
1263 mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
1264 llvm::SmallVectorImpl<mlir::Value> &lowerBound,
1265 llvm::SmallVectorImpl<mlir::Value> &upperBound,
1266 llvm::SmallVectorImpl<mlir::Value> &step,
1267 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
1268 std::size_t &loopVarTypeSize) const {
1269 bool found = false;
1270 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1272 // Collect the loops to collapse.
1273 Fortran::lower::pft::Evaluation *doConstructEval =
1274 &eval.getFirstNestedEvaluation();
1275 if (doConstructEval->getIf<Fortran::parser::DoConstruct>()
1276 ->IsDoConcurrent()) {
1277 TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
1280 std::int64_t collapseValue = 1l;
1281 if (auto *collapseClause = findUniqueClause<ClauseTy::Collapse>()) {
1282 const auto *expr = Fortran::semantics::GetExpr(collapseClause->v);
1283 collapseValue = Fortran::evaluate::ToInt64(*expr).value();
1284 found = true;
1287 loopVarTypeSize = 0;
1288 do {
1289 Fortran::lower::pft::Evaluation *doLoop =
1290 &doConstructEval->getFirstNestedEvaluation();
1291 auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
1292 assert(doStmt && "Expected do loop to be in the nested evaluation");
1293 const auto &loopControl =
1294 std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
1295 const Fortran::parser::LoopControl::Bounds *bounds =
1296 std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
1297 assert(bounds && "Expected bounds for worksharing do loop");
1298 Fortran::lower::StatementContext stmtCtx;
1299 lowerBound.push_back(fir::getBase(converter.genExprValue(
1300 *Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
1301 upperBound.push_back(fir::getBase(converter.genExprValue(
1302 *Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
1303 if (bounds->step) {
1304 step.push_back(fir::getBase(converter.genExprValue(
1305 *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
1306 } else { // If `step` is not present, assume it as `1`.
1307 step.push_back(firOpBuilder.createIntegerConstant(
1308 currentLocation, firOpBuilder.getIntegerType(32), 1));
1310 iv.push_back(bounds->name.thing.symbol);
1311 loopVarTypeSize = std::max(loopVarTypeSize,
1312 bounds->name.thing.symbol->GetUltimate().size());
1313 collapseValue--;
1314 doConstructEval =
1315 &*std::next(doConstructEval->getNestedEvaluations().begin());
1316 } while (collapseValue > 0);
1318 return found;
1321 bool ClauseProcessor::processDefault() const {
1322 if (auto *defaultClause = findUniqueClause<ClauseTy::Default>()) {
1323 // Private, Firstprivate, Shared, None
1324 switch (defaultClause->v.v) {
1325 case Fortran::parser::OmpDefaultClause::Type::Shared:
1326 case Fortran::parser::OmpDefaultClause::Type::None:
1327 // Default clause with shared or none do not require any handling since
1328 // Shared is the default behavior in the IR and None is only required
1329 // for semantic checks.
1330 break;
1331 case Fortran::parser::OmpDefaultClause::Type::Private:
1332 // TODO Support default(private)
1333 break;
1334 case Fortran::parser::OmpDefaultClause::Type::Firstprivate:
1335 // TODO Support default(firstprivate)
1336 break;
1338 return true;
1340 return false;
1343 bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx,
1344 mlir::Value &result) const {
1345 const Fortran::parser::CharBlock *source = nullptr;
1346 if (auto *deviceClause = findUniqueClause<ClauseTy::Device>(&source)) {
1347 mlir::Location clauseLocation = converter.genLocation(*source);
1348 if (auto deviceModifier = std::get<
1349 std::optional<Fortran::parser::OmpDeviceClause::DeviceModifier>>(
1350 deviceClause->v.t)) {
1351 if (deviceModifier ==
1352 Fortran::parser::OmpDeviceClause::DeviceModifier::Ancestor) {
1353 TODO(clauseLocation, "OMPD_target Device Modifier Ancestor");
1356 if (const auto *deviceExpr = Fortran::semantics::GetExpr(
1357 std::get<Fortran::parser::ScalarIntExpr>(deviceClause->v.t))) {
1358 result = fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx));
1360 return true;
1362 return false;
1365 bool ClauseProcessor::processDeviceType(
1366 mlir::omp::DeclareTargetDeviceType &result) const {
1367 if (auto *deviceTypeClause = findUniqueClause<ClauseTy::DeviceType>()) {
1368 // Case: declare target ... device_type(any | host | nohost)
1369 switch (deviceTypeClause->v.v) {
1370 case Fortran::parser::OmpDeviceTypeClause::Type::Nohost:
1371 result = mlir::omp::DeclareTargetDeviceType::nohost;
1372 break;
1373 case Fortran::parser::OmpDeviceTypeClause::Type::Host:
1374 result = mlir::omp::DeclareTargetDeviceType::host;
1375 break;
1376 case Fortran::parser::OmpDeviceTypeClause::Type::Any:
1377 result = mlir::omp::DeclareTargetDeviceType::any;
1378 break;
1380 return true;
1382 return false;
1385 bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx,
1386 mlir::Value &result) const {
1387 const Fortran::parser::CharBlock *source = nullptr;
1388 if (auto *finalClause = findUniqueClause<ClauseTy::Final>(&source)) {
1389 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1390 mlir::Location clauseLocation = converter.genLocation(*source);
1392 mlir::Value finalVal = fir::getBase(converter.genExprValue(
1393 *Fortran::semantics::GetExpr(finalClause->v), stmtCtx));
1394 result = firOpBuilder.createConvert(clauseLocation,
1395 firOpBuilder.getI1Type(), finalVal);
1396 return true;
1398 return false;
1401 bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const {
1402 if (auto *hintClause = findUniqueClause<ClauseTy::Hint>()) {
1403 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1404 const auto *expr = Fortran::semantics::GetExpr(hintClause->v);
1405 int64_t hintValue = *Fortran::evaluate::ToInt64(*expr);
1406 result = firOpBuilder.getI64IntegerAttr(hintValue);
1407 return true;
1409 return false;
1412 bool ClauseProcessor::processMergeable(mlir::UnitAttr &result) const {
1413 return markClauseOccurrence<ClauseTy::Mergeable>(result);
1416 bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const {
1417 return markClauseOccurrence<ClauseTy::Nowait>(result);
1420 bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx,
1421 mlir::Value &result) const {
1422 // TODO Get lower and upper bounds for num_teams when parser is updated to
1423 // accept both.
1424 if (auto *numTeamsClause = findUniqueClause<ClauseTy::NumTeams>()) {
1425 result = fir::getBase(converter.genExprValue(
1426 *Fortran::semantics::GetExpr(numTeamsClause->v), stmtCtx));
1427 return true;
1429 return false;
1432 bool ClauseProcessor::processNumThreads(
1433 Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
1434 if (auto *numThreadsClause = findUniqueClause<ClauseTy::NumThreads>()) {
1435 // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
1436 result = fir::getBase(converter.genExprValue(
1437 *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
1438 return true;
1440 return false;
1443 bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const {
1444 if (auto *orderedClause = findUniqueClause<ClauseTy::Ordered>()) {
1445 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1446 int64_t orderedClauseValue = 0l;
1447 if (orderedClause->v.has_value()) {
1448 const auto *expr = Fortran::semantics::GetExpr(orderedClause->v);
1449 orderedClauseValue = *Fortran::evaluate::ToInt64(*expr);
1451 result = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
1452 return true;
1454 return false;
1457 bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx,
1458 mlir::Value &result) const {
1459 if (auto *priorityClause = findUniqueClause<ClauseTy::Priority>()) {
1460 result = fir::getBase(converter.genExprValue(
1461 *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx));
1462 return true;
1464 return false;
1467 bool ClauseProcessor::processProcBind(
1468 mlir::omp::ClauseProcBindKindAttr &result) const {
1469 if (auto *procBindClause = findUniqueClause<ClauseTy::ProcBind>()) {
1470 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1471 result = genProcBindKindAttr(firOpBuilder, procBindClause);
1472 return true;
1474 return false;
1477 bool ClauseProcessor::processSafelen(mlir::IntegerAttr &result) const {
1478 if (auto *safelenClause = findUniqueClause<ClauseTy::Safelen>()) {
1479 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1480 const auto *expr = Fortran::semantics::GetExpr(safelenClause->v);
1481 const std::optional<std::int64_t> safelenVal =
1482 Fortran::evaluate::ToInt64(*expr);
1483 result = firOpBuilder.getI64IntegerAttr(*safelenVal);
1484 return true;
1486 return false;
1489 bool ClauseProcessor::processSchedule(
1490 mlir::omp::ClauseScheduleKindAttr &valAttr,
1491 mlir::omp::ScheduleModifierAttr &modifierAttr,
1492 mlir::UnitAttr &simdModifierAttr) const {
1493 if (auto *scheduleClause = findUniqueClause<ClauseTy::Schedule>()) {
1494 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1495 mlir::MLIRContext *context = firOpBuilder.getContext();
1496 const Fortran::parser::OmpScheduleClause &scheduleType = scheduleClause->v;
1497 const auto &scheduleClauseKind =
1498 std::get<Fortran::parser::OmpScheduleClause::ScheduleType>(
1499 scheduleType.t);
1501 mlir::omp::ClauseScheduleKind scheduleKind;
1502 switch (scheduleClauseKind) {
1503 case Fortran::parser::OmpScheduleClause::ScheduleType::Static:
1504 scheduleKind = mlir::omp::ClauseScheduleKind::Static;
1505 break;
1506 case Fortran::parser::OmpScheduleClause::ScheduleType::Dynamic:
1507 scheduleKind = mlir::omp::ClauseScheduleKind::Dynamic;
1508 break;
1509 case Fortran::parser::OmpScheduleClause::ScheduleType::Guided:
1510 scheduleKind = mlir::omp::ClauseScheduleKind::Guided;
1511 break;
1512 case Fortran::parser::OmpScheduleClause::ScheduleType::Auto:
1513 scheduleKind = mlir::omp::ClauseScheduleKind::Auto;
1514 break;
1515 case Fortran::parser::OmpScheduleClause::ScheduleType::Runtime:
1516 scheduleKind = mlir::omp::ClauseScheduleKind::Runtime;
1517 break;
1520 mlir::omp::ScheduleModifier scheduleModifier =
1521 getScheduleModifier(scheduleClause->v);
1523 if (scheduleModifier != mlir::omp::ScheduleModifier::none)
1524 modifierAttr =
1525 mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier);
1527 if (getSimdModifier(scheduleClause->v) != mlir::omp::ScheduleModifier::none)
1528 simdModifierAttr = firOpBuilder.getUnitAttr();
1530 valAttr = mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
1531 return true;
1533 return false;
1536 bool ClauseProcessor::processScheduleChunk(
1537 Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
1538 if (auto *scheduleClause = findUniqueClause<ClauseTy::Schedule>()) {
1539 if (const auto &chunkExpr =
1540 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
1541 scheduleClause->v.t)) {
1542 if (const auto *expr = Fortran::semantics::GetExpr(*chunkExpr)) {
1543 result = fir::getBase(converter.genExprValue(*expr, stmtCtx));
1546 return true;
1548 return false;
1551 bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const {
1552 if (auto *simdlenClause = findUniqueClause<ClauseTy::Simdlen>()) {
1553 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1554 const auto *expr = Fortran::semantics::GetExpr(simdlenClause->v);
1555 const std::optional<std::int64_t> simdlenVal =
1556 Fortran::evaluate::ToInt64(*expr);
1557 result = firOpBuilder.getI64IntegerAttr(*simdlenVal);
1558 return true;
1560 return false;
1563 bool ClauseProcessor::processThreadLimit(
1564 Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
1565 if (auto *threadLmtClause = findUniqueClause<ClauseTy::ThreadLimit>()) {
1566 result = fir::getBase(converter.genExprValue(
1567 *Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx));
1568 return true;
1570 return false;
1573 bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
1574 return markClauseOccurrence<ClauseTy::Untied>(result);
//===----------------------------------------------------------------------===//
// ClauseProcessor repeatable clauses
//===----------------------------------------------------------------------===//
1581 bool ClauseProcessor::processAllocate(
1582 llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
1583 llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const {
1584 return findRepeatableClause<ClauseTy::Allocate>(
1585 [&](const ClauseTy::Allocate *allocateClause,
1586 const Fortran::parser::CharBlock &) {
1587 genAllocateClause(converter, allocateClause->v, allocatorOperands,
1588 allocateOperands);
1592 bool ClauseProcessor::processCopyin() const {
1593 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1594 mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint();
1595 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
1596 auto checkAndCopyHostAssociateVar =
1597 [&](Fortran::semantics::Symbol *sym,
1598 mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) {
1599 assert(sym->has<Fortran::semantics::HostAssocDetails>() &&
1600 "No host-association found");
1601 if (converter.isPresentShallowLookup(*sym))
1602 converter.copyHostAssociateVar(*sym, copyAssignIP);
1604 bool hasCopyin = findRepeatableClause<ClauseTy::Copyin>(
1605 [&](const ClauseTy::Copyin *copyinClause,
1606 const Fortran::parser::CharBlock &) {
1607 const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v;
1608 for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
1609 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
1610 if (const auto *commonDetails =
1611 sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
1612 for (const auto &mem : commonDetails->objects())
1613 checkAndCopyHostAssociateVar(&*mem, &insPt);
1614 break;
1616 if (Fortran::semantics::IsAllocatableOrObjectPointer(
1617 &sym->GetUltimate()))
1618 TODO(converter.getCurrentLocation(),
1619 "pointer or allocatable variables in Copyin clause");
1620 assert(sym->has<Fortran::semantics::HostAssocDetails>() &&
1621 "No host-association found");
1622 checkAndCopyHostAssociateVar(sym);
1626 // [OMP 5.0, 2.19.6.1] The copy is done after the team is formed and prior to
1627 // the execution of the associated structured block. Emit implicit barrier to
1628 // synchronize threads and avoid data races on propagation master's thread
1629 // values of threadprivate variables to local instances of that variables of
1630 // all other implicit threads.
1631 if (hasCopyin)
1632 firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation());
1633 firOpBuilder.restoreInsertionPoint(insPt);
1634 return hasCopyin;
1637 bool ClauseProcessor::processDepend(
1638 llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
1639 llvm::SmallVectorImpl<mlir::Value> &dependOperands) const {
1640 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1642 return findRepeatableClause<ClauseTy::Depend>(
1643 [&](const ClauseTy::Depend *dependClause,
1644 const Fortran::parser::CharBlock &) {
1645 const std::list<Fortran::parser::Designator> &depVal =
1646 std::get<std::list<Fortran::parser::Designator>>(
1647 std::get<Fortran::parser::OmpDependClause::InOut>(
1648 dependClause->v.u)
1649 .t);
1650 mlir::omp::ClauseTaskDependAttr dependTypeOperand =
1651 genDependKindAttr(firOpBuilder, dependClause);
1652 dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(),
1653 dependTypeOperand);
1654 for (const Fortran::parser::Designator &ompObject : depVal) {
1655 Fortran::semantics::Symbol *sym = nullptr;
1656 std::visit(
1657 Fortran::common::visitors{
1658 [&](const Fortran::parser::DataRef &designator) {
1659 if (const Fortran::parser::Name *name =
1660 std::get_if<Fortran::parser::Name>(&designator.u)) {
1661 sym = name->symbol;
1662 } else if (std::get_if<Fortran::common::Indirection<
1663 Fortran::parser::ArrayElement>>(
1664 &designator.u)) {
1665 TODO(converter.getCurrentLocation(),
1666 "array sections not supported for task depend");
1669 [&](const Fortran::parser::Substring &designator) {
1670 TODO(converter.getCurrentLocation(),
1671 "substring not supported for task depend");
1673 (ompObject).u);
1674 const mlir::Value variable = converter.getSymbolAddress(*sym);
1675 dependOperands.push_back(variable);
1680 bool ClauseProcessor::processIf(
1681 Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
1682 mlir::Value &result) const {
1683 bool found = false;
1684 findRepeatableClause<ClauseTy::If>(
1685 [&](const ClauseTy::If *ifClause,
1686 const Fortran::parser::CharBlock &source) {
1687 mlir::Location clauseLocation = converter.genLocation(source);
1688 mlir::Value operand = getIfClauseOperand(converter, ifClause,
1689 directiveName, clauseLocation);
1690 // Assume that, at most, a single 'if' clause will be applicable to the
1691 // given directive.
1692 if (operand) {
1693 result = operand;
1694 found = true;
1697 return found;
1700 bool ClauseProcessor::processLink(
1701 llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
1702 return findRepeatableClause<ClauseTy::Link>(
1703 [&](const ClauseTy::Link *linkClause,
1704 const Fortran::parser::CharBlock &) {
1705 // Case: declare target link(var1, var2)...
1706 gatherFuncAndVarSyms(
1707 linkClause->v, mlir::omp::DeclareTargetCaptureClause::link, result);
1711 static mlir::omp::MapInfoOp
1712 createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
1713 mlir::Value baseAddr, std::stringstream &name,
1714 mlir::SmallVector<mlir::Value> bounds, uint64_t mapType,
1715 mlir::omp::VariableCaptureKind mapCaptureType,
1716 mlir::Type retTy) {
1717 mlir::Value varPtr, varPtrPtr;
1718 mlir::TypeAttr varType;
1720 if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
1721 baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
1722 retTy = baseAddr.getType();
1725 varPtr = baseAddr;
1726 varType = mlir::TypeAttr::get(
1727 llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());
1729 mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
1730 loc, retTy, varPtr, varType, varPtrPtr, bounds,
1731 builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
1732 builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
1733 builder.getStringAttr(name.str()));
1734 return op;
1737 bool ClauseProcessor::processMap(
1738 mlir::Location currentLocation, const llvm::omp::Directive &directive,
1739 Fortran::semantics::SemanticsContext &semanticsContext,
1740 Fortran::lower::StatementContext &stmtCtx,
1741 llvm::SmallVectorImpl<mlir::Value> &mapOperands,
1742 llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
1743 llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
1744 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
1745 const {
1746 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1747 return findRepeatableClause<ClauseTy::Map>(
1748 [&](const ClauseTy::Map *mapClause,
1749 const Fortran::parser::CharBlock &source) {
1750 mlir::Location clauseLocation = converter.genLocation(source);
1751 const auto &oMapType =
1752 std::get<std::optional<Fortran::parser::OmpMapType>>(
1753 mapClause->v.t);
1754 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1755 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1756 // If the map type is specified, then process it else Tofrom is the
1757 // default.
1758 if (oMapType) {
1759 const Fortran::parser::OmpMapType::Type &mapType =
1760 std::get<Fortran::parser::OmpMapType::Type>(oMapType->t);
1761 switch (mapType) {
1762 case Fortran::parser::OmpMapType::Type::To:
1763 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1764 break;
1765 case Fortran::parser::OmpMapType::Type::From:
1766 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1767 break;
1768 case Fortran::parser::OmpMapType::Type::Tofrom:
1769 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1770 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1771 break;
1772 case Fortran::parser::OmpMapType::Type::Alloc:
1773 case Fortran::parser::OmpMapType::Type::Release:
1774 // alloc and release is the default map_type for the Target Data
1775 // Ops, i.e. if no bits for map_type is supplied then alloc/release
1776 // is implicitly assumed based on the target directive. Default
1777 // value for Target Data and Enter Data is alloc and for Exit Data
1778 // it is release.
1779 break;
1780 case Fortran::parser::OmpMapType::Type::Delete:
1781 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1784 if (std::get<std::optional<Fortran::parser::OmpMapType::Always>>(
1785 oMapType->t))
1786 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1787 } else {
1788 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1789 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1792 for (const Fortran::parser::OmpObject &ompObject :
1793 std::get<Fortran::parser::OmpObjectList>(mapClause->v.t).v) {
1794 llvm::SmallVector<mlir::Value> bounds;
1795 std::stringstream asFortran;
1796 mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
1797 Fortran::parser::OmpObject, mlir::omp::DataBoundsType,
1798 mlir::omp::DataBoundsOp>(
1799 converter, firOpBuilder, semanticsContext, stmtCtx, ompObject,
1800 clauseLocation, asFortran, bounds, treatIndexAsSection);
1802 // Explicit map captures are captured ByRef by default,
1803 // optimisation passes may alter this to ByCopy or other capture
1804 // types to optimise
1805 mlir::Value mapOp = createMapInfoOp(
1806 firOpBuilder, clauseLocation, baseAddr, asFortran, bounds,
1807 static_cast<
1808 std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
1809 mapTypeBits),
1810 mlir::omp::VariableCaptureKind::ByRef, baseAddr.getType());
1812 mapOperands.push_back(mapOp);
1813 if (mapSymTypes)
1814 mapSymTypes->push_back(baseAddr.getType());
1815 if (mapSymLocs)
1816 mapSymLocs->push_back(baseAddr.getLoc());
1817 if (mapSymbols)
1818 mapSymbols->push_back(getOmpObjectSymbol(ompObject));
1823 bool ClauseProcessor::processReduction(
1824 mlir::Location currentLocation,
1825 llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1826 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const {
1827 return findRepeatableClause<ClauseTy::Reduction>(
1828 [&](const ClauseTy::Reduction *reductionClause,
1829 const Fortran::parser::CharBlock &) {
1830 addReductionDecl(currentLocation, converter, reductionClause->v,
1831 reductionVars, reductionDeclSymbols);
1835 bool ClauseProcessor::processSectionsReduction(
1836 mlir::Location currentLocation) const {
1837 return findRepeatableClause<ClauseTy::Reduction>(
1838 [&](const ClauseTy::Reduction *, const Fortran::parser::CharBlock &) {
1839 TODO(currentLocation, "OMPC_Reduction");
1843 bool ClauseProcessor::processTo(
1844 llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
1845 return findRepeatableClause<ClauseTy::To>(
1846 [&](const ClauseTy::To *toClause, const Fortran::parser::CharBlock &) {
1847 // Case: declare target to(func, var1, var2)...
1848 gatherFuncAndVarSyms(toClause->v,
1849 mlir::omp::DeclareTargetCaptureClause::to, result);
1853 bool ClauseProcessor::processEnter(
1854 llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
1855 return findRepeatableClause<ClauseTy::Enter>(
1856 [&](const ClauseTy::Enter *enterClause,
1857 const Fortran::parser::CharBlock &) {
1858 // Case: declare target enter(func, var1, var2)...
1859 gatherFuncAndVarSyms(enterClause->v,
1860 mlir::omp::DeclareTargetCaptureClause::enter,
1861 result);
1865 bool ClauseProcessor::processUseDeviceAddr(
1866 llvm::SmallVectorImpl<mlir::Value> &operands,
1867 llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
1868 llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
1869 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
1870 const {
1871 return findRepeatableClause<ClauseTy::UseDeviceAddr>(
1872 [&](const ClauseTy::UseDeviceAddr *devAddrClause,
1873 const Fortran::parser::CharBlock &) {
1874 addUseDeviceClause(converter, devAddrClause->v, operands,
1875 useDeviceTypes, useDeviceLocs, useDeviceSymbols);
1879 bool ClauseProcessor::processUseDevicePtr(
1880 llvm::SmallVectorImpl<mlir::Value> &operands,
1881 llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
1882 llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
1883 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
1884 const {
1885 return findRepeatableClause<ClauseTy::UseDevicePtr>(
1886 [&](const ClauseTy::UseDevicePtr *devPtrClause,
1887 const Fortran::parser::CharBlock &) {
1888 addUseDeviceClause(converter, devPtrClause->v, operands, useDeviceTypes,
1889 useDeviceLocs, useDeviceSymbols);
1893 template <typename... Ts>
1894 void ClauseProcessor::processTODO(mlir::Location currentLocation,
1895 llvm::omp::Directive directive) const {
1896 auto checkUnhandledClause = [&](const auto *x) {
1897 if (!x)
1898 return;
1899 TODO(currentLocation,
1900 "Unhandled clause " +
1901 llvm::StringRef(Fortran::parser::ParseTreeDumper::GetNodeName(*x))
1902 .upper() +
1903 " in " + llvm::omp::getOpenMPDirectiveName(directive).upper() +
1904 " construct");
1907 for (ClauseIterator it = clauses.v.begin(); it != clauses.v.end(); ++it)
1908 (checkUnhandledClause(std::get_if<Ts>(&it->u)), ...);
1911 //===----------------------------------------------------------------------===//
1912 // Code generation helper functions
1913 //===----------------------------------------------------------------------===//
1915 static fir::GlobalOp globalInitialization(
1916 Fortran::lower::AbstractConverter &converter,
1917 fir::FirOpBuilder &firOpBuilder, const Fortran::semantics::Symbol &sym,
1918 const Fortran::lower::pft::Variable &var, mlir::Location currentLocation) {
1919 mlir::Type ty = converter.genType(sym);
1920 std::string globalName = converter.mangleName(sym);
1921 mlir::StringAttr linkage = firOpBuilder.createInternalLinkage();
1922 fir::GlobalOp global =
1923 firOpBuilder.createGlobal(currentLocation, ty, globalName, linkage);
1925 // Create default initialization for non-character scalar.
1926 if (Fortran::semantics::IsAllocatableOrObjectPointer(&sym)) {
1927 mlir::Type baseAddrType = ty.dyn_cast<fir::BoxType>().getEleTy();
1928 Fortran::lower::createGlobalInitialization(
1929 firOpBuilder, global, [&](fir::FirOpBuilder &b) {
1930 mlir::Value nullAddr =
1931 b.createNullConstant(currentLocation, baseAddrType);
1932 mlir::Value box =
1933 b.create<fir::EmboxOp>(currentLocation, ty, nullAddr);
1934 b.create<fir::HasValueOp>(currentLocation, box);
1936 } else {
1937 Fortran::lower::createGlobalInitialization(
1938 firOpBuilder, global, [&](fir::FirOpBuilder &b) {
1939 mlir::Value undef = b.create<fir::UndefOp>(currentLocation, ty);
1940 b.create<fir::HasValueOp>(currentLocation, undef);
1944 return global;
1947 static mlir::Operation *getCompareFromReductionOp(mlir::Operation *reductionOp,
1948 mlir::Value loadVal) {
1949 for (mlir::Value reductionOperand : reductionOp->getOperands()) {
1950 if (mlir::Operation *compareOp = reductionOperand.getDefiningOp()) {
1951 if (compareOp->getOperand(0) == loadVal ||
1952 compareOp->getOperand(1) == loadVal)
1953 assert((mlir::isa<mlir::arith::CmpIOp>(compareOp) ||
1954 mlir::isa<mlir::arith::CmpFOp>(compareOp)) &&
1955 "Expected comparison not found in reduction intrinsic");
1956 return compareOp;
1959 return nullptr;
1962 // Get the extended value for \p val by extracting additional variable
1963 // information from \p base.
1964 static fir::ExtendedValue getExtendedValue(fir::ExtendedValue base,
1965 mlir::Value val) {
1966 return base.match(
1967 [&](const fir::MutableBoxValue &box) -> fir::ExtendedValue {
1968 return fir::MutableBoxValue(val, box.nonDeferredLenParams(), {});
1970 [&](const auto &) -> fir::ExtendedValue {
1971 return fir::substBase(base, val);
1975 static void threadPrivatizeVars(Fortran::lower::AbstractConverter &converter,
1976 Fortran::lower::pft::Evaluation &eval) {
1977 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1978 mlir::Location currentLocation = converter.getCurrentLocation();
1979 mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint();
1980 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
1982 // Get the original ThreadprivateOp corresponding to the symbol and use the
1983 // symbol value from that operation to create one ThreadprivateOp copy
1984 // operation inside the parallel region.
1985 auto genThreadprivateOp = [&](Fortran::lower::SymbolRef sym) -> mlir::Value {
1986 mlir::Value symOriThreadprivateValue = converter.getSymbolAddress(sym);
1987 mlir::Operation *op = symOriThreadprivateValue.getDefiningOp();
1988 if (auto declOp = mlir::dyn_cast<hlfir::DeclareOp>(op))
1989 op = declOp.getMemref().getDefiningOp();
1990 assert(mlir::isa<mlir::omp::ThreadprivateOp>(op) &&
1991 "Threadprivate operation not created");
1992 mlir::Value symValue =
1993 mlir::dyn_cast<mlir::omp::ThreadprivateOp>(op).getSymAddr();
1994 return firOpBuilder.create<mlir::omp::ThreadprivateOp>(
1995 currentLocation, symValue.getType(), symValue);
1998 llvm::SetVector<const Fortran::semantics::Symbol *> threadprivateSyms;
1999 converter.collectSymbolSet(
2000 eval, threadprivateSyms,
2001 Fortran::semantics::Symbol::Flag::OmpThreadprivate);
2002 std::set<Fortran::semantics::SourceName> threadprivateSymNames;
2004 // For a COMMON block, the ThreadprivateOp is generated for itself instead of
2005 // its members, so only bind the value of the new copied ThreadprivateOp
2006 // inside the parallel region to the common block symbol only once for
2007 // multiple members in one COMMON block.
2008 llvm::SetVector<const Fortran::semantics::Symbol *> commonSyms;
2009 for (std::size_t i = 0; i < threadprivateSyms.size(); i++) {
2010 const Fortran::semantics::Symbol *sym = threadprivateSyms[i];
2011 mlir::Value symThreadprivateValue;
2012 // The variable may be used more than once, and each reference has one
2013 // symbol with the same name. Only do once for references of one variable.
2014 if (threadprivateSymNames.find(sym->name()) != threadprivateSymNames.end())
2015 continue;
2016 threadprivateSymNames.insert(sym->name());
2017 if (const Fortran::semantics::Symbol *common =
2018 Fortran::semantics::FindCommonBlockContaining(sym->GetUltimate())) {
2019 mlir::Value commonThreadprivateValue;
2020 if (commonSyms.contains(common)) {
2021 commonThreadprivateValue = converter.getSymbolAddress(*common);
2022 } else {
2023 commonThreadprivateValue = genThreadprivateOp(*common);
2024 converter.bindSymbol(*common, commonThreadprivateValue);
2025 commonSyms.insert(common);
2027 symThreadprivateValue = Fortran::lower::genCommonBlockMember(
2028 converter, currentLocation, *sym, commonThreadprivateValue);
2029 } else {
2030 symThreadprivateValue = genThreadprivateOp(*sym);
2033 fir::ExtendedValue sexv = converter.getSymbolExtendedValue(*sym);
2034 fir::ExtendedValue symThreadprivateExv =
2035 getExtendedValue(sexv, symThreadprivateValue);
2036 converter.bindSymbol(*sym, symThreadprivateExv);
2039 firOpBuilder.restoreInsertionPoint(insPt);
2042 static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
2043 std::size_t loopVarTypeSize) {
2044 // OpenMP runtime requires 32-bit or 64-bit loop variables.
2045 loopVarTypeSize = loopVarTypeSize * 8;
2046 if (loopVarTypeSize < 32) {
2047 loopVarTypeSize = 32;
2048 } else if (loopVarTypeSize > 64) {
2049 loopVarTypeSize = 64;
2050 mlir::emitWarning(converter.getCurrentLocation(),
2051 "OpenMP loop iteration variable cannot have more than 64 "
2052 "bits size and will be narrowed into 64 bits.");
2054 assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) &&
2055 "OpenMP loop iteration variable size must be transformed into 32-bit "
2056 "or 64-bit");
2057 return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
2060 static void resetBeforeTerminator(fir::FirOpBuilder &firOpBuilder,
2061 mlir::Operation *storeOp,
2062 mlir::Block &block) {
2063 if (storeOp)
2064 firOpBuilder.setInsertionPointAfter(storeOp);
2065 else
2066 firOpBuilder.setInsertionPointToStart(&block);
2069 static mlir::Operation *
2070 createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
2071 mlir::Location loc, mlir::Value indexVal,
2072 const Fortran::semantics::Symbol *sym) {
2073 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2074 mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint();
2075 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
2077 mlir::Type tempTy = converter.genType(*sym);
2078 mlir::Value temp = firOpBuilder.create<fir::AllocaOp>(
2079 loc, tempTy, /*pinned=*/true, /*lengthParams=*/mlir::ValueRange{},
2080 /*shapeParams*/ mlir::ValueRange{},
2081 llvm::ArrayRef<mlir::NamedAttribute>{
2082 fir::getAdaptToByRefAttr(firOpBuilder)});
2083 converter.bindSymbol(*sym, temp);
2084 firOpBuilder.restoreInsertionPoint(insPt);
2085 mlir::Value cvtVal = firOpBuilder.createConvert(loc, tempTy, indexVal);
2086 mlir::Operation *storeOp = firOpBuilder.create<fir::StoreOp>(
2087 loc, cvtVal, converter.getSymbolAddress(*sym));
2088 return storeOp;
2091 /// Create the body (block) for an OpenMP Operation.
2093 /// \param [in] op - the operation the body belongs to.
2094 /// \param [inout] converter - converter to use for the clauses.
2095 /// \param [in] loc - location in source code.
2096 /// \param [in] eval - current PFT node/evaluation.
2097 /// \oaran [in] clauses - list of clauses to process.
2098 /// \param [in] args - block arguments (induction variable[s]) for the
2099 //// region.
2100 /// \param [in] outerCombined - is this an outer operation - prevents
2101 /// privatization.
2102 template <typename Op>
2103 static void createBodyOfOp(
2104 Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc,
2105 Fortran::lower::pft::Evaluation &eval,
2106 const Fortran::parser::OmpClauseList *clauses = nullptr,
2107 const llvm::SmallVector<const Fortran::semantics::Symbol *> &args = {},
2108 bool outerCombined = false, DataSharingProcessor *dsp = nullptr) {
2109 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2110 // If an argument for the region is provided then create the block with that
2111 // argument. Also update the symbol's address with the mlir argument value.
2112 // e.g. For loops the argument is the induction variable. And all further
2113 // uses of the induction variable should use this mlir value.
2114 mlir::Operation *storeOp = nullptr;
2115 if (args.size()) {
2116 std::size_t loopVarTypeSize = 0;
2117 for (const Fortran::semantics::Symbol *arg : args)
2118 loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
2119 mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
2120 llvm::SmallVector<mlir::Type> tiv;
2121 llvm::SmallVector<mlir::Location> locs;
2122 for (int i = 0; i < (int)args.size(); i++) {
2123 tiv.push_back(loopVarType);
2124 locs.push_back(loc);
2126 firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
2127 int argIndex = 0;
2128 // The argument is not currently in memory, so make a temporary for the
2129 // argument, and store it there, then bind that location to the argument.
2130 for (const Fortran::semantics::Symbol *arg : args) {
2131 mlir::Value indexVal =
2132 fir::getBase(op.getRegion().front().getArgument(argIndex));
2133 storeOp = createAndSetPrivatizedLoopVar(converter, loc, indexVal, arg);
2134 argIndex++;
2136 } else {
2137 firOpBuilder.createBlock(&op.getRegion());
2139 // Set the insert for the terminator operation to go at the end of the
2140 // block - this is either empty or the block with the stores above,
2141 // the end of the block works for both.
2142 mlir::Block &block = op.getRegion().back();
2143 firOpBuilder.setInsertionPointToEnd(&block);
2145 // If it is an unstructured region and is not the outer region of a combined
2146 // construct, create empty blocks for all evaluations.
2147 if (eval.lowerAsUnstructured() && !outerCombined)
2148 Fortran::lower::createEmptyRegionBlocks<mlir::omp::TerminatorOp,
2149 mlir::omp::YieldOp>(
2150 firOpBuilder, eval.getNestedEvaluations());
2152 // Insert the terminator.
2153 if constexpr (std::is_same_v<Op, mlir::omp::WsLoopOp> ||
2154 std::is_same_v<Op, mlir::omp::SimdLoopOp>) {
2155 mlir::ValueRange results;
2156 firOpBuilder.create<mlir::omp::YieldOp>(loc, results);
2157 } else {
2158 firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
2160 // Reset the insert point to before the terminator.
2161 resetBeforeTerminator(firOpBuilder, storeOp, block);
2163 // Handle privatization. Do not privatize if this is the outer operation.
2164 if (clauses && !outerCombined) {
2165 constexpr bool isLoop = std::is_same_v<Op, mlir::omp::WsLoopOp> ||
2166 std::is_same_v<Op, mlir::omp::SimdLoopOp>;
2167 if (!dsp) {
2168 DataSharingProcessor proc(converter, *clauses, eval);
2169 proc.processStep1();
2170 proc.processStep2(op, isLoop);
2171 } else {
2172 if (isLoop && args.size() > 0)
2173 dsp->setLoopIV(converter.getSymbolAddress(*args[0]));
2174 dsp->processStep2(op, isLoop);
2177 if (storeOp)
2178 firOpBuilder.setInsertionPointAfter(storeOp);
2181 if constexpr (std::is_same_v<Op, mlir::omp::ParallelOp>) {
2182 threadPrivatizeVars(converter, eval);
2183 if (clauses)
2184 ClauseProcessor(converter, *clauses).processCopyin();
2188 static void genBodyOfTargetDataOp(
2189 Fortran::lower::AbstractConverter &converter,
2190 Fortran::lower::pft::Evaluation &eval, mlir::omp::DataOp &dataOp,
2191 const llvm::SmallVector<mlir::Type> &useDeviceTypes,
2192 const llvm::SmallVector<mlir::Location> &useDeviceLocs,
2193 const llvm::SmallVector<const Fortran::semantics::Symbol *>
2194 &useDeviceSymbols,
2195 const mlir::Location &currentLocation) {
2196 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2197 mlir::Region &region = dataOp.getRegion();
2199 firOpBuilder.createBlock(&region, {}, useDeviceTypes, useDeviceLocs);
2201 unsigned argIndex = 0;
2202 for (const Fortran::semantics::Symbol *sym : useDeviceSymbols) {
2203 const mlir::BlockArgument &arg = region.front().getArgument(argIndex);
2204 fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
2205 if (auto refType = arg.getType().dyn_cast<fir::ReferenceType>()) {
2206 if (fir::isa_builtin_cptr_type(refType.getElementType())) {
2207 converter.bindSymbol(*sym, arg);
2208 } else {
2209 extVal.match(
2210 [&](const fir::MutableBoxValue &mbv) {
2211 converter.bindSymbol(
2212 *sym,
2213 fir::MutableBoxValue(
2214 arg, fir::factory::getNonDeferredLenParams(extVal), {}));
2216 [&](const auto &) {
2217 TODO(converter.getCurrentLocation(),
2218 "use_device clause operand unsupported type");
2221 } else {
2222 TODO(converter.getCurrentLocation(),
2223 "use_device clause operand unsupported type");
2225 argIndex++;
2228 // Insert dummy instruction to remember the insertion position. The
2229 // marker will be deleted by clean up passes since there are no uses.
2230 // Remembering the position for further insertion is important since
2231 // there are hlfir.declares inserted above while setting block arguments
2232 // and new code from the body should be inserted after that.
2233 mlir::Value undefMarker = firOpBuilder.create<fir::UndefOp>(
2234 dataOp.getOperation()->getLoc(), firOpBuilder.getIndexType());
2236 // Create blocks for unstructured regions. This has to be done since
2237 // blocks are initially allocated with the function as the parent region.
2238 if (eval.lowerAsUnstructured()) {
2239 Fortran::lower::createEmptyRegionBlocks<mlir::omp::TerminatorOp,
2240 mlir::omp::YieldOp>(
2241 firOpBuilder, eval.getNestedEvaluations());
2244 firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
2246 // Set the insertion point after the marker.
2247 firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());
2250 template <typename OpTy, typename... Args>
2251 static OpTy genOpWithBody(Fortran::lower::AbstractConverter &converter,
2252 Fortran::lower::pft::Evaluation &eval,
2253 mlir::Location currentLocation, bool outerCombined,
2254 const Fortran::parser::OmpClauseList *clauseList,
2255 Args &&...args) {
2256 auto op = converter.getFirOpBuilder().create<OpTy>(
2257 currentLocation, std::forward<Args>(args)...);
2258 createBodyOfOp<OpTy>(op, converter, currentLocation, eval, clauseList,
2259 /*args=*/{}, outerCombined);
2260 return op;
2263 static mlir::omp::MasterOp
2264 genMasterOp(Fortran::lower::AbstractConverter &converter,
2265 Fortran::lower::pft::Evaluation &eval,
2266 mlir::Location currentLocation) {
2267 return genOpWithBody<mlir::omp::MasterOp>(converter, eval, currentLocation,
2268 /*outerCombined=*/false,
2269 /*clauseList=*/nullptr,
2270 /*resultTypes=*/mlir::TypeRange());
2273 static mlir::omp::OrderedRegionOp
2274 genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
2275 Fortran::lower::pft::Evaluation &eval,
2276 mlir::Location currentLocation) {
2277 return genOpWithBody<mlir::omp::OrderedRegionOp>(
2278 converter, eval, currentLocation, /*outerCombined=*/false,
2279 /*clauseList=*/nullptr, /*simd=*/false);
2282 static mlir::omp::ParallelOp
2283 genParallelOp(Fortran::lower::AbstractConverter &converter,
2284 Fortran::lower::pft::Evaluation &eval,
2285 mlir::Location currentLocation,
2286 const Fortran::parser::OmpClauseList &clauseList,
2287 bool outerCombined = false) {
2288 Fortran::lower::StatementContext stmtCtx;
2289 mlir::Value ifClauseOperand, numThreadsClauseOperand;
2290 mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
2291 llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
2292 reductionVars;
2293 llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
2295 ClauseProcessor cp(converter, clauseList);
2296 cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
2297 ifClauseOperand);
2298 cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
2299 cp.processProcBind(procBindKindAttr);
2300 cp.processDefault();
2301 cp.processAllocate(allocatorOperands, allocateOperands);
2302 if (!outerCombined)
2303 cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
2305 return genOpWithBody<mlir::omp::ParallelOp>(
2306 converter, eval, currentLocation, outerCombined, &clauseList,
2307 /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
2308 numThreadsClauseOperand, allocateOperands, allocatorOperands,
2309 reductionVars,
2310 reductionDeclSymbols.empty()
2311 ? nullptr
2312 : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
2313 reductionDeclSymbols),
2314 procBindKindAttr);
2317 static mlir::omp::SingleOp
2318 genSingleOp(Fortran::lower::AbstractConverter &converter,
2319 Fortran::lower::pft::Evaluation &eval,
2320 mlir::Location currentLocation,
2321 const Fortran::parser::OmpClauseList &beginClauseList,
2322 const Fortran::parser::OmpClauseList &endClauseList) {
2323 llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
2324 mlir::UnitAttr nowaitAttr;
2326 ClauseProcessor cp(converter, beginClauseList);
2327 cp.processAllocate(allocatorOperands, allocateOperands);
2328 cp.processTODO<Fortran::parser::OmpClause::Copyprivate>(
2329 currentLocation, llvm::omp::Directive::OMPD_single);
2331 ClauseProcessor(converter, endClauseList).processNowait(nowaitAttr);
2333 return genOpWithBody<mlir::omp::SingleOp>(
2334 converter, eval, currentLocation, /*outerCombined=*/false,
2335 &beginClauseList, allocateOperands, allocatorOperands, nowaitAttr);
2338 static mlir::omp::TaskOp
2339 genTaskOp(Fortran::lower::AbstractConverter &converter,
2340 Fortran::lower::pft::Evaluation &eval, mlir::Location currentLocation,
2341 const Fortran::parser::OmpClauseList &clauseList) {
2342 Fortran::lower::StatementContext stmtCtx;
2343 mlir::Value ifClauseOperand, finalClauseOperand, priorityClauseOperand;
2344 mlir::UnitAttr untiedAttr, mergeableAttr;
2345 llvm::SmallVector<mlir::Attribute> dependTypeOperands;
2346 llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
2347 dependOperands;
2349 ClauseProcessor cp(converter, clauseList);
2350 cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Task,
2351 ifClauseOperand);
2352 cp.processAllocate(allocatorOperands, allocateOperands);
2353 cp.processDefault();
2354 cp.processFinal(stmtCtx, finalClauseOperand);
2355 cp.processUntied(untiedAttr);
2356 cp.processMergeable(mergeableAttr);
2357 cp.processPriority(stmtCtx, priorityClauseOperand);
2358 cp.processDepend(dependTypeOperands, dependOperands);
2359 cp.processTODO<Fortran::parser::OmpClause::InReduction,
2360 Fortran::parser::OmpClause::Detach,
2361 Fortran::parser::OmpClause::Affinity>(
2362 currentLocation, llvm::omp::Directive::OMPD_task);
2364 return genOpWithBody<mlir::omp::TaskOp>(
2365 converter, eval, currentLocation, /*outerCombined=*/false, &clauseList,
2366 ifClauseOperand, finalClauseOperand, untiedAttr, mergeableAttr,
2367 /*in_reduction_vars=*/mlir::ValueRange(),
2368 /*in_reductions=*/nullptr, priorityClauseOperand,
2369 dependTypeOperands.empty()
2370 ? nullptr
2371 : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
2372 dependTypeOperands),
2373 dependOperands, allocateOperands, allocatorOperands);
2376 static mlir::omp::TaskGroupOp
2377 genTaskGroupOp(Fortran::lower::AbstractConverter &converter,
2378 Fortran::lower::pft::Evaluation &eval,
2379 mlir::Location currentLocation,
2380 const Fortran::parser::OmpClauseList &clauseList) {
2381 llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
2382 ClauseProcessor cp(converter, clauseList);
2383 cp.processAllocate(allocatorOperands, allocateOperands);
2384 cp.processTODO<Fortran::parser::OmpClause::TaskReduction>(
2385 currentLocation, llvm::omp::Directive::OMPD_taskgroup);
2386 return genOpWithBody<mlir::omp::TaskGroupOp>(
2387 converter, eval, currentLocation, /*outerCombined=*/false, &clauseList,
2388 /*task_reduction_vars=*/mlir::ValueRange(),
2389 /*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
2392 static mlir::omp::DataOp
2393 genDataOp(Fortran::lower::AbstractConverter &converter,
2394 Fortran::lower::pft::Evaluation &eval,
2395 Fortran::semantics::SemanticsContext &semanticsContext,
2396 mlir::Location currentLocation,
2397 const Fortran::parser::OmpClauseList &clauseList) {
2398 Fortran::lower::StatementContext stmtCtx;
2399 mlir::Value ifClauseOperand, deviceOperand;
2400 llvm::SmallVector<mlir::Value> mapOperands, devicePtrOperands,
2401 deviceAddrOperands;
2402 llvm::SmallVector<mlir::Type> useDeviceTypes;
2403 llvm::SmallVector<mlir::Location> useDeviceLocs;
2404 llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
2406 ClauseProcessor cp(converter, clauseList);
2407 cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData,
2408 ifClauseOperand);
2409 cp.processDevice(stmtCtx, deviceOperand);
2410 cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
2411 useDeviceSymbols);
2412 cp.processUseDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs,
2413 useDeviceSymbols);
2414 cp.processMap(currentLocation, llvm::omp::Directive::OMPD_target_data,
2415 semanticsContext, stmtCtx, mapOperands);
2417 auto dataOp = converter.getFirOpBuilder().create<mlir::omp::DataOp>(
2418 currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
2419 deviceAddrOperands, mapOperands);
2420 genBodyOfTargetDataOp(converter, eval, dataOp, useDeviceTypes, useDeviceLocs,
2421 useDeviceSymbols, currentLocation);
2422 return dataOp;
2425 template <typename OpTy>
2426 static OpTy
2427 genEnterExitDataOp(Fortran::lower::AbstractConverter &converter,
2428 Fortran::semantics::SemanticsContext &semanticsContext,
2429 mlir::Location currentLocation,
2430 const Fortran::parser::OmpClauseList &clauseList) {
2431 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2432 Fortran::lower::StatementContext stmtCtx;
2433 mlir::Value ifClauseOperand, deviceOperand;
2434 mlir::UnitAttr nowaitAttr;
2435 llvm::SmallVector<mlir::Value> mapOperands;
2437 Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
2438 llvm::omp::Directive directive;
2439 if constexpr (std::is_same_v<OpTy, mlir::omp::EnterDataOp>) {
2440 directiveName =
2441 Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData;
2442 directive = llvm::omp::Directive::OMPD_target_enter_data;
2443 } else if constexpr (std::is_same_v<OpTy, mlir::omp::ExitDataOp>) {
2444 directiveName =
2445 Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
2446 directive = llvm::omp::Directive::OMPD_target_exit_data;
2447 } else {
2448 return nullptr;
2451 ClauseProcessor cp(converter, clauseList);
2452 cp.processIf(directiveName, ifClauseOperand);
2453 cp.processDevice(stmtCtx, deviceOperand);
2454 cp.processNowait(nowaitAttr);
2455 cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
2456 mapOperands);
2457 cp.processTODO<Fortran::parser::OmpClause::Depend>(currentLocation,
2458 directive);
2460 return firOpBuilder.create<OpTy>(currentLocation, ifClauseOperand,
2461 deviceOperand, nowaitAttr, mapOperands);
2464 // This functions creates a block for the body of the targetOp's region. It adds
2465 // all the symbols present in mapSymbols as block arguments to this block.
2466 static void genBodyOfTargetOp(
2467 Fortran::lower::AbstractConverter &converter,
2468 Fortran::lower::pft::Evaluation &eval, mlir::omp::TargetOp &targetOp,
2469 const llvm::SmallVector<mlir::Type> &mapSymTypes,
2470 const llvm::SmallVector<mlir::Location> &mapSymLocs,
2471 const llvm::SmallVector<const Fortran::semantics::Symbol *> &mapSymbols,
2472 const mlir::Location &currentLocation) {
2473 assert(mapSymTypes.size() == mapSymLocs.size());
2475 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2476 mlir::Region &region = targetOp.getRegion();
2478 auto *regionBlock =
2479 firOpBuilder.createBlock(&region, {}, mapSymTypes, mapSymLocs);
2481 unsigned argIndex = 0;
2483 // Clones the `bounds` placing them inside the target region and returns them.
2484 auto cloneBound = [&](mlir::Value bound) {
2485 if (mlir::isMemoryEffectFree(bound.getDefiningOp())) {
2486 mlir::Operation *clonedOp = bound.getDefiningOp()->clone();
2487 regionBlock->push_back(clonedOp);
2488 return clonedOp->getResult(0);
2490 TODO(converter.getCurrentLocation(),
2491 "target map clause operand unsupported bound type");
2494 auto cloneBounds = [cloneBound](llvm::ArrayRef<mlir::Value> bounds) {
2495 llvm::SmallVector<mlir::Value> clonedBounds;
2496 for (mlir::Value bound : bounds)
2497 clonedBounds.emplace_back(cloneBound(bound));
2498 return clonedBounds;
2501 // Bind the symbols to their corresponding block arguments.
2502 for (const Fortran::semantics::Symbol *sym : mapSymbols) {
2503 const mlir::BlockArgument &arg = region.getArgument(argIndex);
2504 fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
2505 extVal.match(
2506 [&](const fir::BoxValue &v) {
2507 converter.bindSymbol(*sym,
2508 fir::BoxValue(arg, cloneBounds(v.getLBounds()),
2509 v.getExplicitParameters(),
2510 v.getExplicitExtents()));
2512 [&](const fir::MutableBoxValue &v) {
2513 converter.bindSymbol(
2514 *sym, fir::MutableBoxValue(arg, cloneBounds(v.getLBounds()),
2515 v.getMutableProperties()));
2517 [&](const fir::ArrayBoxValue &v) {
2518 converter.bindSymbol(
2519 *sym, fir::ArrayBoxValue(arg, cloneBounds(v.getExtents()),
2520 cloneBounds(v.getLBounds()),
2521 v.getSourceBox()));
2523 [&](const fir::CharArrayBoxValue &v) {
2524 converter.bindSymbol(
2525 *sym, fir::CharArrayBoxValue(arg, cloneBound(v.getLen()),
2526 cloneBounds(v.getExtents()),
2527 cloneBounds(v.getLBounds())));
2529 [&](const fir::CharBoxValue &v) {
2530 converter.bindSymbol(*sym,
2531 fir::CharBoxValue(arg, cloneBound(v.getLen())));
2533 [&](const fir::UnboxedValue &v) { converter.bindSymbol(*sym, arg); },
2534 [&](const auto &) {
2535 TODO(converter.getCurrentLocation(),
2536 "target map clause operand unsupported type");
2538 argIndex++;
2541 // Check if cloning the bounds introduced any dependency on the outer region.
2542 // If so, then either clone them as well if they are MemoryEffectFree, or else
2543 // copy them to a new temporary and add them to the map and block_argument
2544 // lists and replace their uses with the new temporary.
2545 llvm::SetVector<mlir::Value> valuesDefinedAbove;
2546 mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
2547 while (!valuesDefinedAbove.empty()) {
2548 for (mlir::Value val : valuesDefinedAbove) {
2549 mlir::Operation *valOp = val.getDefiningOp();
2550 if (mlir::isMemoryEffectFree(valOp)) {
2551 mlir::Operation *clonedOp = valOp->clone();
2552 regionBlock->push_front(clonedOp);
2553 val.replaceUsesWithIf(
2554 clonedOp->getResult(0), [regionBlock](mlir::OpOperand &use) {
2555 return use.getOwner()->getBlock() == regionBlock;
2557 } else {
2558 auto savedIP = firOpBuilder.getInsertionPoint();
2559 firOpBuilder.setInsertionPointAfter(valOp);
2560 auto copyVal =
2561 firOpBuilder.createTemporary(val.getLoc(), val.getType());
2562 firOpBuilder.createStoreWithConvert(copyVal.getLoc(), val, copyVal);
2564 llvm::SmallVector<mlir::Value> bounds;
2565 std::stringstream name;
2566 firOpBuilder.setInsertionPoint(targetOp);
2567 mlir::Value mapOp = createMapInfoOp(
2568 firOpBuilder, copyVal.getLoc(), copyVal, name, bounds,
2569 static_cast<
2570 std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
2571 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
2572 mlir::omp::VariableCaptureKind::ByCopy, copyVal.getType());
2573 targetOp.getMapOperandsMutable().append(mapOp);
2574 mlir::Value clonedValArg =
2575 region.addArgument(copyVal.getType(), copyVal.getLoc());
2576 firOpBuilder.setInsertionPointToStart(regionBlock);
2577 auto loadOp = firOpBuilder.create<fir::LoadOp>(clonedValArg.getLoc(),
2578 clonedValArg);
2579 val.replaceUsesWithIf(
2580 loadOp->getResult(0), [regionBlock](mlir::OpOperand &use) {
2581 return use.getOwner()->getBlock() == regionBlock;
2583 firOpBuilder.setInsertionPoint(regionBlock, savedIP);
2586 valuesDefinedAbove.clear();
2587 mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
2590 // Insert dummy instruction to remember the insertion position. The
2591 // marker will be deleted since there are not uses.
2592 // In the HLFIR flow there are hlfir.declares inserted above while
2593 // setting block arguments.
2594 mlir::Value undefMarker = firOpBuilder.create<fir::UndefOp>(
2595 targetOp.getOperation()->getLoc(), firOpBuilder.getIndexType());
2597 // Create blocks for unstructured regions. This has to be done since
2598 // blocks are initially allocated with the function as the parent region.
2599 // the parent region of blocks.
2600 if (eval.lowerAsUnstructured()) {
2601 Fortran::lower::createEmptyRegionBlocks<mlir::omp::TerminatorOp,
2602 mlir::omp::YieldOp>(
2603 firOpBuilder, eval.getNestedEvaluations());
2606 firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
2608 // Create the insertion point after the marker.
2609 firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());
2612 static mlir::omp::TargetOp
2613 genTargetOp(Fortran::lower::AbstractConverter &converter,
2614 Fortran::lower::pft::Evaluation &eval,
2615 Fortran::semantics::SemanticsContext &semanticsContext,
2616 mlir::Location currentLocation,
2617 const Fortran::parser::OmpClauseList &clauseList,
2618 llvm::omp::Directive directive, bool outerCombined = false) {
2619 Fortran::lower::StatementContext stmtCtx;
2620 mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand;
2621 mlir::UnitAttr nowaitAttr;
2622 llvm::SmallVector<mlir::Value> mapOperands;
2623 llvm::SmallVector<mlir::Type> mapSymTypes;
2624 llvm::SmallVector<mlir::Location> mapSymLocs;
2625 llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
2627 ClauseProcessor cp(converter, clauseList);
2628 cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
2629 ifClauseOperand);
2630 cp.processDevice(stmtCtx, deviceOperand);
2631 cp.processThreadLimit(stmtCtx, threadLimitOperand);
2632 cp.processNowait(nowaitAttr);
2633 cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
2634 mapOperands, &mapSymTypes, &mapSymLocs, &mapSymbols);
2635 cp.processTODO<Fortran::parser::OmpClause::Private,
2636 Fortran::parser::OmpClause::Depend,
2637 Fortran::parser::OmpClause::Firstprivate,
2638 Fortran::parser::OmpClause::IsDevicePtr,
2639 Fortran::parser::OmpClause::HasDeviceAddr,
2640 Fortran::parser::OmpClause::Reduction,
2641 Fortran::parser::OmpClause::InReduction,
2642 Fortran::parser::OmpClause::Allocate,
2643 Fortran::parser::OmpClause::UsesAllocators,
2644 Fortran::parser::OmpClause::Defaultmap>(
2645 currentLocation, llvm::omp::Directive::OMPD_target);
2647 // 5.8.1 Implicit Data-Mapping Attribute Rules
2648 // The following code follows the implicit data-mapping rules to map all the
2649 // symbols used inside the region that have not been explicitly mapped using
2650 // the map clause.
2651 auto captureImplicitMap = [&](const Fortran::semantics::Symbol &sym) {
2652 if (llvm::find(mapSymbols, &sym) == mapSymbols.end()) {
2653 mlir::Value baseOp = converter.getSymbolAddress(sym);
2654 if (!baseOp)
2655 if (const auto *details = sym.template detailsIf<
2656 Fortran::semantics::HostAssocDetails>()) {
2657 baseOp = converter.getSymbolAddress(details->symbol());
2658 converter.copySymbolBinding(details->symbol(), sym);
2661 if (baseOp) {
2662 llvm::SmallVector<mlir::Value> bounds;
2663 std::stringstream name;
2664 fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
2665 name << sym.name().ToString();
2667 mlir::Value baseAddr =
2668 getDataOperandBaseAddr(converter, converter.getFirOpBuilder(), sym,
2669 converter.getCurrentLocation());
2670 if (fir::unwrapRefType(baseAddr.getType()).isa<fir::BaseBoxType>())
2671 bounds =
2672 Fortran::lower::genBoundsOpsFromBox<mlir::omp::DataBoundsOp,
2673 mlir::omp::DataBoundsType>(
2674 converter.getFirOpBuilder(), converter.getCurrentLocation(),
2675 converter, dataExv, baseAddr);
2676 if (fir::unwrapRefType(baseAddr.getType()).isa<fir::SequenceType>())
2677 bounds = Fortran::lower::genBaseBoundsOps<mlir::omp::DataBoundsOp,
2678 mlir::omp::DataBoundsType>(
2679 converter.getFirOpBuilder(), converter.getCurrentLocation(),
2680 converter, dataExv, baseAddr);
2682 llvm::omp::OpenMPOffloadMappingFlags mapFlag =
2683 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
2684 mlir::omp::VariableCaptureKind captureKind =
2685 mlir::omp::VariableCaptureKind::ByRef;
2686 if (auto refType = baseOp.getType().dyn_cast<fir::ReferenceType>()) {
2687 auto eleType = refType.getElementType();
2688 if (fir::isa_trivial(eleType) || fir::isa_char(eleType)) {
2689 captureKind = mlir::omp::VariableCaptureKind::ByCopy;
2690 } else if (!fir::isa_builtin_cptr_type(eleType)) {
2691 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
2692 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
2696 mlir::Value mapOp = createMapInfoOp(
2697 converter.getFirOpBuilder(), baseOp.getLoc(), baseOp, name, bounds,
2698 static_cast<
2699 std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
2700 mapFlag),
2701 captureKind, baseOp.getType());
2703 mapOperands.push_back(mapOp);
2704 mapSymTypes.push_back(baseOp.getType());
2705 mapSymLocs.push_back(baseOp.getLoc());
2706 mapSymbols.push_back(&sym);
2710 Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap);
2712 auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
2713 currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
2714 nowaitAttr, mapOperands);
2716 genBodyOfTargetOp(converter, eval, targetOp, mapSymTypes, mapSymLocs,
2717 mapSymbols, currentLocation);
2719 return targetOp;
2722 static mlir::omp::TeamsOp
2723 genTeamsOp(Fortran::lower::AbstractConverter &converter,
2724 Fortran::lower::pft::Evaluation &eval,
2725 mlir::Location currentLocation,
2726 const Fortran::parser::OmpClauseList &clauseList,
2727 bool outerCombined = false) {
2728 Fortran::lower::StatementContext stmtCtx;
2729 mlir::Value numTeamsClauseOperand, ifClauseOperand, threadLimitClauseOperand;
2730 llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
2731 reductionVars;
2732 llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
2734 ClauseProcessor cp(converter, clauseList);
2735 cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
2736 ifClauseOperand);
2737 cp.processAllocate(allocatorOperands, allocateOperands);
2738 cp.processDefault();
2739 cp.processNumTeams(stmtCtx, numTeamsClauseOperand);
2740 cp.processThreadLimit(stmtCtx, threadLimitClauseOperand);
2741 cp.processTODO<Fortran::parser::OmpClause::Reduction>(
2742 currentLocation, llvm::omp::Directive::OMPD_teams);
2744 return genOpWithBody<mlir::omp::TeamsOp>(
2745 converter, eval, currentLocation, outerCombined, &clauseList,
2746 /*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
2747 threadLimitClauseOperand, allocateOperands, allocatorOperands,
2748 reductionVars,
2749 reductionDeclSymbols.empty()
2750 ? nullptr
2751 : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
2752 reductionDeclSymbols));
2755 /// Extract the list of function and variable symbols affected by the given
2756 /// 'declare target' directive and return the intended device type for them.
2757 static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
2758 Fortran::lower::AbstractConverter &converter,
2759 Fortran::lower::pft::Evaluation &eval,
2760 const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
2761 llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
2763 // The default capture type
2764 mlir::omp::DeclareTargetDeviceType deviceType =
2765 mlir::omp::DeclareTargetDeviceType::any;
2766 const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
2767 declareTargetConstruct.t);
2769 if (const auto *objectList{
2770 Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
2771 // Case: declare target(func, var1, var2)
2772 gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to,
2773 symbolAndClause);
2774 } else if (const auto *clauseList{
2775 Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
2776 spec.u)}) {
2777 if (clauseList->v.empty()) {
2778 // Case: declare target, implicit capture of function
2779 symbolAndClause.emplace_back(
2780 mlir::omp::DeclareTargetCaptureClause::to,
2781 eval.getOwningProcedure()->getSubprogramSymbol());
2784 ClauseProcessor cp(converter, *clauseList);
2785 cp.processTo(symbolAndClause);
2786 cp.processEnter(symbolAndClause);
2787 cp.processLink(symbolAndClause);
2788 cp.processDeviceType(deviceType);
2789 cp.processTODO<Fortran::parser::OmpClause::Indirect>(
2790 converter.getCurrentLocation(),
2791 llvm::omp::Directive::OMPD_declare_target);
2794 return deviceType;
2797 static std::optional<mlir::omp::DeclareTargetDeviceType>
2798 getDeclareTargetFunctionDevice(
2799 Fortran::lower::AbstractConverter &converter,
2800 Fortran::lower::pft::Evaluation &eval,
2801 const Fortran::parser::OpenMPDeclareTargetConstruct
2802 &declareTargetConstruct) {
2803 llvm::SmallVector<DeclareTargetCapturePair, 0> symbolAndClause;
2804 mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
2805 converter, eval, declareTargetConstruct, symbolAndClause);
2807 // Return the device type only if at least one of the targets for the
2808 // directive is a function or subroutine
2809 mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
2810 for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
2811 mlir::Operation *op = mod.lookupSymbol(
2812 converter.mangleName(std::get<Fortran::semantics::Symbol>(symClause)));
2814 if (mlir::isa<mlir::func::FuncOp>(op))
2815 return deviceType;
2818 return std::nullopt;
2821 //===----------------------------------------------------------------------===//
2822 // genOMP() Code generation helper functions
2823 //===----------------------------------------------------------------------===//
2825 static void
2826 genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter,
2827 Fortran::lower::pft::Evaluation &eval,
2828 Fortran::semantics::SemanticsContext &semanticsContext,
2829 const Fortran::parser::OpenMPSimpleStandaloneConstruct
2830 &simpleStandaloneConstruct) {
2831 const auto &directive =
2832 std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
2833 simpleStandaloneConstruct.t);
2834 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2835 const auto &opClauseList =
2836 std::get<Fortran::parser::OmpClauseList>(simpleStandaloneConstruct.t);
2837 mlir::Location currentLocation = converter.genLocation(directive.source);
2839 switch (directive.v) {
2840 default:
2841 break;
2842 case llvm::omp::Directive::OMPD_barrier:
2843 firOpBuilder.create<mlir::omp::BarrierOp>(currentLocation);
2844 break;
2845 case llvm::omp::Directive::OMPD_taskwait:
2846 ClauseProcessor(converter, opClauseList)
2847 .processTODO<Fortran::parser::OmpClause::Depend,
2848 Fortran::parser::OmpClause::Nowait>(
2849 currentLocation, llvm::omp::Directive::OMPD_taskwait);
2850 firOpBuilder.create<mlir::omp::TaskwaitOp>(currentLocation);
2851 break;
2852 case llvm::omp::Directive::OMPD_taskyield:
2853 firOpBuilder.create<mlir::omp::TaskyieldOp>(currentLocation);
2854 break;
2855 case llvm::omp::Directive::OMPD_target_data:
2856 genDataOp(converter, eval, semanticsContext, currentLocation, opClauseList);
2857 break;
2858 case llvm::omp::Directive::OMPD_target_enter_data:
2859 genEnterExitDataOp<mlir::omp::EnterDataOp>(converter, semanticsContext,
2860 currentLocation, opClauseList);
2861 break;
2862 case llvm::omp::Directive::OMPD_target_exit_data:
2863 genEnterExitDataOp<mlir::omp::ExitDataOp>(converter, semanticsContext,
2864 currentLocation, opClauseList);
2865 break;
2866 case llvm::omp::Directive::OMPD_target_update:
2867 TODO(currentLocation, "OMPD_target_update");
2868 case llvm::omp::Directive::OMPD_ordered:
2869 TODO(currentLocation, "OMPD_ordered");
2873 static void
2874 genOmpFlush(Fortran::lower::AbstractConverter &converter,
2875 Fortran::lower::pft::Evaluation &eval,
2876 const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
2877 llvm::SmallVector<mlir::Value, 4> operandRange;
2878 if (const auto &ompObjectList =
2879 std::get<std::optional<Fortran::parser::OmpObjectList>>(
2880 flushConstruct.t))
2881 genObjectList(*ompObjectList, converter, operandRange);
2882 const auto &memOrderClause =
2883 std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
2884 flushConstruct.t);
2885 if (memOrderClause && memOrderClause->size() > 0)
2886 TODO(converter.getCurrentLocation(), "Handle OmpMemoryOrderClause");
2887 converter.getFirOpBuilder().create<mlir::omp::FlushOp>(
2888 converter.getCurrentLocation(), operandRange);
2891 static void
2892 genOMP(Fortran::lower::AbstractConverter &converter,
2893 Fortran::lower::pft::Evaluation &eval,
2894 Fortran::semantics::SemanticsContext &semanticsContext,
2895 const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) {
2896 std::visit(
2897 Fortran::common::visitors{
2898 [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct
2899 &simpleStandaloneConstruct) {
2900 genOmpSimpleStandalone(converter, eval, semanticsContext,
2901 simpleStandaloneConstruct);
2903 [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
2904 genOmpFlush(converter, eval, flushConstruct);
2906 [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) {
2907 TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
2909 [&](const Fortran::parser::OpenMPCancellationPointConstruct
2910 &cancellationPointConstruct) {
2911 TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
2914 standaloneConstruct.u);
2917 static void genOMP(Fortran::lower::AbstractConverter &converter,
2918 Fortran::lower::pft::Evaluation &eval,
2919 Fortran::semantics::SemanticsContext &semanticsContext,
2920 const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
2921 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2922 llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, linearVars,
2923 linearStepVars, reductionVars;
2924 mlir::Value scheduleChunkClauseOperand;
2925 mlir::IntegerAttr orderedClauseOperand;
2926 mlir::omp::ClauseOrderKindAttr orderClauseOperand;
2927 mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
2928 mlir::omp::ScheduleModifierAttr scheduleModClauseOperand;
2929 mlir::UnitAttr nowaitClauseOperand, scheduleSimdClauseOperand;
2930 llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
2931 Fortran::lower::StatementContext stmtCtx;
2932 std::size_t loopVarTypeSize;
2933 llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
2935 const auto &beginLoopDirective =
2936 std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t);
2937 const auto &loopOpClauseList =
2938 std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t);
2939 mlir::Location currentLocation =
2940 converter.genLocation(beginLoopDirective.source);
2941 const auto ompDirective =
2942 std::get<Fortran::parser::OmpLoopDirective>(beginLoopDirective.t).v;
2944 bool validDirective = false;
2945 if (llvm::omp::topTaskloopSet.test(ompDirective)) {
2946 validDirective = true;
2947 TODO(currentLocation, "Taskloop construct");
2948 } else {
2949 // Create omp.{target, teams, distribute, parallel} nested operations
2950 if ((llvm::omp::allTargetSet & llvm::omp::loopConstructSet)
2951 .test(ompDirective)) {
2952 validDirective = true;
2953 genTargetOp(converter, eval, semanticsContext, currentLocation,
2954 loopOpClauseList, ompDirective, /*outerCombined=*/true);
2956 if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet)
2957 .test(ompDirective)) {
2958 validDirective = true;
2959 genTeamsOp(converter, eval, currentLocation, loopOpClauseList,
2960 /*outerCombined=*/true);
2962 if (llvm::omp::allDistributeSet.test(ompDirective)) {
2963 validDirective = true;
2964 TODO(currentLocation, "Distribute construct");
2966 if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet)
2967 .test(ompDirective)) {
2968 validDirective = true;
2969 genParallelOp(converter, eval, currentLocation, loopOpClauseList,
2970 /*outerCombined=*/true);
2973 if ((llvm::omp::allDoSet | llvm::omp::allSimdSet).test(ompDirective))
2974 validDirective = true;
2976 if (!validDirective) {
2977 TODO(currentLocation, "Unhandled loop directive (" +
2978 llvm::omp::getOpenMPDirectiveName(ompDirective) +
2979 ")");
2982 DataSharingProcessor dsp(converter, loopOpClauseList, eval);
2983 dsp.processStep1();
2985 ClauseProcessor cp(converter, loopOpClauseList);
2986 cp.processCollapse(currentLocation, eval, lowerBound, upperBound, step, iv,
2987 loopVarTypeSize);
2988 cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
2989 cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
2990 cp.processTODO<Fortran::parser::OmpClause::Linear,
2991 Fortran::parser::OmpClause::Order>(currentLocation,
2992 ompDirective);
2994 // The types of lower bound, upper bound, and step are converted into the
2995 // type of the loop variable if necessary.
2996 mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
2997 for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) {
2998 lowerBound[it] = firOpBuilder.createConvert(currentLocation, loopVarType,
2999 lowerBound[it]);
3000 upperBound[it] = firOpBuilder.createConvert(currentLocation, loopVarType,
3001 upperBound[it]);
3002 step[it] =
3003 firOpBuilder.createConvert(currentLocation, loopVarType, step[it]);
3006 // 2.9.3.1 SIMD construct
3007 if (llvm::omp::allSimdSet.test(ompDirective)) {
3008 llvm::SmallVector<mlir::Value> alignedVars, nontemporalVars;
3009 mlir::Value ifClauseOperand;
3010 mlir::IntegerAttr simdlenClauseOperand, safelenClauseOperand;
3011 cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
3012 ifClauseOperand);
3013 cp.processSimdlen(simdlenClauseOperand);
3014 cp.processSafelen(safelenClauseOperand);
3015 cp.processTODO<Fortran::parser::OmpClause::Aligned,
3016 Fortran::parser::OmpClause::Allocate,
3017 Fortran::parser::OmpClause::Nontemporal>(currentLocation,
3018 ompDirective);
3020 mlir::TypeRange resultType;
3021 auto simdLoopOp = firOpBuilder.create<mlir::omp::SimdLoopOp>(
3022 currentLocation, resultType, lowerBound, upperBound, step, alignedVars,
3023 /*alignment_values=*/nullptr, ifClauseOperand, nontemporalVars,
3024 orderClauseOperand, simdlenClauseOperand, safelenClauseOperand,
3025 /*inclusive=*/firOpBuilder.getUnitAttr());
3026 createBodyOfOp<mlir::omp::SimdLoopOp>(
3027 simdLoopOp, converter, currentLocation, eval, &loopOpClauseList, iv,
3028 /*outer=*/false, &dsp);
3029 return;
3032 auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
3033 currentLocation, lowerBound, upperBound, step, linearVars, linearStepVars,
3034 reductionVars,
3035 reductionDeclSymbols.empty()
3036 ? nullptr
3037 : mlir::ArrayAttr::get(firOpBuilder.getContext(),
3038 reductionDeclSymbols),
3039 scheduleValClauseOperand, scheduleChunkClauseOperand,
3040 /*schedule_modifiers=*/nullptr,
3041 /*simd_modifier=*/nullptr, nowaitClauseOperand, orderedClauseOperand,
3042 orderClauseOperand,
3043 /*inclusive=*/firOpBuilder.getUnitAttr());
3045 // Handle attribute based clauses.
3046 if (cp.processOrdered(orderedClauseOperand))
3047 wsLoopOp.setOrderedValAttr(orderedClauseOperand);
3049 if (cp.processSchedule(scheduleValClauseOperand, scheduleModClauseOperand,
3050 scheduleSimdClauseOperand)) {
3051 wsLoopOp.setScheduleValAttr(scheduleValClauseOperand);
3052 wsLoopOp.setScheduleModifierAttr(scheduleModClauseOperand);
3053 wsLoopOp.setSimdModifierAttr(scheduleSimdClauseOperand);
3055 // In FORTRAN `nowait` clause occur at the end of `omp do` directive.
3056 // i.e
3057 // !$omp do
3058 // <...>
3059 // !$omp end do nowait
3060 if (const auto &endClauseList =
3061 std::get<std::optional<Fortran::parser::OmpEndLoopDirective>>(
3062 loopConstruct.t)) {
3063 const auto &clauseList =
3064 std::get<Fortran::parser::OmpClauseList>((*endClauseList).t);
3065 if (ClauseProcessor(converter, clauseList)
3066 .processNowait(nowaitClauseOperand))
3067 wsLoopOp.setNowaitAttr(nowaitClauseOperand);
3070 createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, currentLocation,
3071 eval, &loopOpClauseList, iv,
3072 /*outer=*/false, &dsp);
3075 static void
3076 genOMP(Fortran::lower::AbstractConverter &converter,
3077 Fortran::lower::pft::Evaluation &eval,
3078 Fortran::semantics::SemanticsContext &semanticsContext,
3079 const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
3080 const auto &beginBlockDirective =
3081 std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
3082 const auto &endBlockDirective =
3083 std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t);
3084 const auto &directive =
3085 std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
3086 const auto &beginClauseList =
3087 std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
3088 const auto &endClauseList =
3089 std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t);
3091 for (const Fortran::parser::OmpClause &clause : beginClauseList.v) {
3092 mlir::Location clauseLocation = converter.genLocation(clause.source);
3093 if (!std::get_if<Fortran::parser::OmpClause::If>(&clause.u) &&
3094 !std::get_if<Fortran::parser::OmpClause::NumThreads>(&clause.u) &&
3095 !std::get_if<Fortran::parser::OmpClause::ProcBind>(&clause.u) &&
3096 !std::get_if<Fortran::parser::OmpClause::Allocate>(&clause.u) &&
3097 !std::get_if<Fortran::parser::OmpClause::Default>(&clause.u) &&
3098 !std::get_if<Fortran::parser::OmpClause::Final>(&clause.u) &&
3099 !std::get_if<Fortran::parser::OmpClause::Priority>(&clause.u) &&
3100 !std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u) &&
3101 !std::get_if<Fortran::parser::OmpClause::Depend>(&clause.u) &&
3102 !std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) &&
3103 !std::get_if<Fortran::parser::OmpClause::Firstprivate>(&clause.u) &&
3104 !std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u) &&
3105 !std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u) &&
3106 !std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u) &&
3107 !std::get_if<Fortran::parser::OmpClause::Map>(&clause.u) &&
3108 !std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
3109 !std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
3110 !std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
3111 !std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u)) {
3112 TODO(clauseLocation, "OpenMP Block construct clause");
3116 for (const auto &clause : endClauseList.v) {
3117 mlir::Location clauseLocation = converter.genLocation(clause.source);
3118 if (!std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u))
3119 TODO(clauseLocation, "OpenMP Block construct clause");
3122 mlir::Location currentLocation = converter.genLocation(directive.source);
3123 switch (directive.v) {
3124 case llvm::omp::Directive::OMPD_master:
3125 genMasterOp(converter, eval, currentLocation);
3126 break;
3127 case llvm::omp::Directive::OMPD_ordered:
3128 genOrderedRegionOp(converter, eval, currentLocation);
3129 break;
3130 case llvm::omp::Directive::OMPD_parallel:
3131 genParallelOp(converter, eval, currentLocation, beginClauseList);
3132 break;
3133 case llvm::omp::Directive::OMPD_single:
3134 genSingleOp(converter, eval, currentLocation, beginClauseList,
3135 endClauseList);
3136 break;
3137 case llvm::omp::Directive::OMPD_target:
3138 genTargetOp(converter, eval, semanticsContext, currentLocation,
3139 beginClauseList, directive.v);
3140 break;
3141 case llvm::omp::Directive::OMPD_target_data:
3142 genDataOp(converter, eval, semanticsContext, currentLocation,
3143 beginClauseList);
3144 break;
3145 case llvm::omp::Directive::OMPD_task:
3146 genTaskOp(converter, eval, currentLocation, beginClauseList);
3147 break;
3148 case llvm::omp::Directive::OMPD_taskgroup:
3149 genTaskGroupOp(converter, eval, currentLocation, beginClauseList);
3150 break;
3151 case llvm::omp::Directive::OMPD_teams:
3152 genTeamsOp(converter, eval, currentLocation, beginClauseList,
3153 /*outerCombined=*/false);
3154 break;
3155 case llvm::omp::Directive::OMPD_workshare:
3156 TODO(currentLocation, "Workshare construct");
3157 break;
3158 default: {
3159 // Codegen for combined directives
3160 bool combinedDirective = false;
3161 if ((llvm::omp::allTargetSet & llvm::omp::blockConstructSet)
3162 .test(directive.v)) {
3163 genTargetOp(converter, eval, semanticsContext, currentLocation,
3164 beginClauseList, directive.v, /*outerCombined=*/true);
3165 combinedDirective = true;
3167 if ((llvm::omp::allTeamsSet & llvm::omp::blockConstructSet)
3168 .test(directive.v)) {
3169 genTeamsOp(converter, eval, currentLocation, beginClauseList);
3170 combinedDirective = true;
3172 if ((llvm::omp::allParallelSet & llvm::omp::blockConstructSet)
3173 .test(directive.v)) {
3174 bool outerCombined =
3175 directive.v != llvm::omp::Directive::OMPD_target_parallel;
3176 genParallelOp(converter, eval, currentLocation, beginClauseList,
3177 outerCombined);
3178 combinedDirective = true;
3180 if ((llvm::omp::workShareSet & llvm::omp::blockConstructSet)
3181 .test(directive.v)) {
3182 TODO(currentLocation, "Workshare construct");
3183 combinedDirective = true;
3185 if (!combinedDirective)
3186 TODO(currentLocation, "Unhandled block directive (" +
3187 llvm::omp::getOpenMPDirectiveName(directive.v) +
3188 ")");
3189 break;
3194 static void
3195 genOMP(Fortran::lower::AbstractConverter &converter,
3196 Fortran::lower::pft::Evaluation &eval,
3197 const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) {
3198 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
3199 mlir::Location currentLocation = converter.getCurrentLocation();
3200 mlir::IntegerAttr hintClauseOp;
3201 std::string name;
3202 const Fortran::parser::OmpCriticalDirective &cd =
3203 std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t);
3204 if (std::get<std::optional<Fortran::parser::Name>>(cd.t).has_value()) {
3205 name =
3206 std::get<std::optional<Fortran::parser::Name>>(cd.t).value().ToString();
3209 const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
3210 ClauseProcessor(converter, clauseList).processHint(hintClauseOp);
3212 mlir::omp::CriticalOp criticalOp = [&]() {
3213 if (name.empty()) {
3214 return firOpBuilder.create<mlir::omp::CriticalOp>(
3215 currentLocation, mlir::FlatSymbolRefAttr());
3217 mlir::ModuleOp module = firOpBuilder.getModule();
3218 mlir::OpBuilder modBuilder(module.getBodyRegion());
3219 auto global = module.lookupSymbol<mlir::omp::CriticalDeclareOp>(name);
3220 if (!global)
3221 global = modBuilder.create<mlir::omp::CriticalDeclareOp>(
3222 currentLocation,
3223 mlir::StringAttr::get(firOpBuilder.getContext(), name), hintClauseOp);
3224 return firOpBuilder.create<mlir::omp::CriticalOp>(
3225 currentLocation, mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(),
3226 global.getSymName()));
3227 }();
3228 createBodyOfOp<mlir::omp::CriticalOp>(criticalOp, converter, currentLocation,
3229 eval);
3232 static void
3233 genOMP(Fortran::lower::AbstractConverter &converter,
3234 Fortran::lower::pft::Evaluation &eval,
3235 const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) {
3236 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
3237 mlir::Location currentLocation = converter.getCurrentLocation();
3238 const Fortran::parser::OpenMPConstruct *parentOmpConstruct =
3239 eval.parentConstruct->getIf<Fortran::parser::OpenMPConstruct>();
3240 assert(parentOmpConstruct &&
3241 "No enclosing parent OpenMPConstruct on SECTION construct");
3242 const Fortran::parser::OpenMPSectionsConstruct *sectionsConstruct =
3243 std::get_if<Fortran::parser::OpenMPSectionsConstruct>(
3244 &parentOmpConstruct->u);
3245 assert(sectionsConstruct && "SECTION construct must have parent"
3246 "SECTIONS construct");
3247 const Fortran::parser::OmpClauseList &sectionsClauseList =
3248 std::get<Fortran::parser::OmpClauseList>(
3249 std::get<Fortran::parser::OmpBeginSectionsDirective>(
3250 sectionsConstruct->t)
3251 .t);
3252 // Currently only private/firstprivate clause is handled, and
3253 // all privatization is done within `omp.section` operations.
3254 mlir::omp::SectionOp sectionOp =
3255 firOpBuilder.create<mlir::omp::SectionOp>(currentLocation);
3256 createBodyOfOp<mlir::omp::SectionOp>(sectionOp, converter, currentLocation,
3257 eval, &sectionsClauseList);
3260 static void
3261 genOMP(Fortran::lower::AbstractConverter &converter,
3262 Fortran::lower::pft::Evaluation &eval,
3263 const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) {
3264 mlir::Location currentLocation = converter.getCurrentLocation();
3265 llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
3266 mlir::UnitAttr nowaitClauseOperand;
3267 const auto &beginSectionsDirective =
3268 std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t);
3269 const auto &sectionsClauseList =
3270 std::get<Fortran::parser::OmpClauseList>(beginSectionsDirective.t);
3272 // Process clauses before optional omp.parallel, so that new variables are
3273 // allocated outside of the parallel region
3274 ClauseProcessor cp(converter, sectionsClauseList);
3275 cp.processSectionsReduction(currentLocation);
3276 cp.processAllocate(allocatorOperands, allocateOperands);
3278 llvm::omp::Directive dir =
3279 std::get<Fortran::parser::OmpSectionsDirective>(beginSectionsDirective.t)
3282 // Parallel wrapper of PARALLEL SECTIONS construct
3283 if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
3284 genParallelOp(converter, eval, currentLocation, sectionsClauseList,
3285 /*outerCombined=*/true);
3286 } else {
3287 const auto &endSectionsDirective =
3288 std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t);
3289 const auto &endSectionsClauseList =
3290 std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t);
3291 ClauseProcessor(converter, endSectionsClauseList)
3292 .processNowait(nowaitClauseOperand);
3295 // SECTIONS construct
3296 genOpWithBody<mlir::omp::SectionsOp>(converter, eval, currentLocation,
3297 /*outerCombined=*/false,
3298 /*clauseList=*/nullptr,
3299 /*reduction_vars=*/mlir::ValueRange(),
3300 /*reductions=*/nullptr, allocateOperands,
3301 allocatorOperands, nowaitClauseOperand);
3304 static void
3305 genOMP(Fortran::lower::AbstractConverter &converter,
3306 Fortran::lower::pft::Evaluation &eval,
3307 const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) {
3308 std::visit(
3309 Fortran::common::visitors{
3310 [&](const Fortran::parser::OmpAtomicRead &atomicRead) {
3311 mlir::Location loc = converter.genLocation(atomicRead.source);
3312 Fortran::lower::genOmpAccAtomicRead<
3313 Fortran::parser::OmpAtomicRead,
3314 Fortran::parser::OmpAtomicClauseList>(converter, atomicRead,
3315 loc);
3317 [&](const Fortran::parser::OmpAtomicWrite &atomicWrite) {
3318 mlir::Location loc = converter.genLocation(atomicWrite.source);
3319 Fortran::lower::genOmpAccAtomicWrite<
3320 Fortran::parser::OmpAtomicWrite,
3321 Fortran::parser::OmpAtomicClauseList>(converter, atomicWrite,
3322 loc);
3324 [&](const Fortran::parser::OmpAtomic &atomicConstruct) {
3325 mlir::Location loc = converter.genLocation(atomicConstruct.source);
3326 Fortran::lower::genOmpAtomic<Fortran::parser::OmpAtomic,
3327 Fortran::parser::OmpAtomicClauseList>(
3328 converter, atomicConstruct, loc);
3330 [&](const Fortran::parser::OmpAtomicUpdate &atomicUpdate) {
3331 mlir::Location loc = converter.genLocation(atomicUpdate.source);
3332 Fortran::lower::genOmpAccAtomicUpdate<
3333 Fortran::parser::OmpAtomicUpdate,
3334 Fortran::parser::OmpAtomicClauseList>(converter, atomicUpdate,
3335 loc);
3337 [&](const Fortran::parser::OmpAtomicCapture &atomicCapture) {
3338 mlir::Location loc = converter.genLocation(atomicCapture.source);
3339 Fortran::lower::genOmpAccAtomicCapture<
3340 Fortran::parser::OmpAtomicCapture,
3341 Fortran::parser::OmpAtomicClauseList>(converter, atomicCapture,
3342 loc);
3345 atomicConstruct.u);
3348 static void genOMP(Fortran::lower::AbstractConverter &converter,
3349 Fortran::lower::pft::Evaluation &eval,
3350 const Fortran::parser::OpenMPDeclareTargetConstruct
3351 &declareTargetConstruct) {
3352 llvm::SmallVector<DeclareTargetCapturePair, 0> symbolAndClause;
3353 mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
3354 mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
3355 converter, eval, declareTargetConstruct, symbolAndClause);
3357 for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
3358 mlir::Operation *op = mod.lookupSymbol(
3359 converter.mangleName(std::get<Fortran::semantics::Symbol>(symClause)));
3360 // There's several cases this can currently be triggered and it could be
3361 // one of the following:
3362 // 1) Invalid argument passed to a declare target that currently isn't
3363 // captured by a frontend semantic check
3364 // 2) The symbol of a valid argument is not correctly updated by one of
3365 // the prior passes, resulting in missing symbol information
3366 // 3) It's a variable internal to a module or program, that is legal by
3367 // Fortran OpenMP standards, but is currently unhandled as they do not
3368 // appear in the symbol table as they are represented as allocas
3369 if (!op)
3370 TODO(converter.getCurrentLocation(),
3371 "Missing symbol, possible case of currently unsupported use of "
3372 "a program local variable in declare target or erroneous symbol "
3373 "information ");
3375 auto declareTargetOp =
3376 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(op);
3377 if (!declareTargetOp)
3378 fir::emitFatalError(
3379 converter.getCurrentLocation(),
3380 "Attempt to apply declare target on unsupported operation");
3382 // The function or global already has a declare target applied to it, very
3383 // likely through implicit capture (usage in another declare target
3384 // function/subroutine). It should be marked as any if it has been assigned
3385 // both host and nohost, else we skip, as there is no change
3386 if (declareTargetOp.isDeclareTarget()) {
3387 if (declareTargetOp.getDeclareTargetDeviceType() != deviceType)
3388 declareTargetOp.setDeclareTarget(
3389 mlir::omp::DeclareTargetDeviceType::any,
3390 std::get<mlir::omp::DeclareTargetCaptureClause>(symClause));
3391 continue;
3394 declareTargetOp.setDeclareTarget(
3395 deviceType, std::get<mlir::omp::DeclareTargetCaptureClause>(symClause));
3399 //===----------------------------------------------------------------------===//
3400 // Public functions
3401 //===----------------------------------------------------------------------===//
3403 void Fortran::lower::genOpenMPTerminator(fir::FirOpBuilder &builder,
3404 mlir::Operation *op,
3405 mlir::Location loc) {
3406 if (mlir::isa<mlir::omp::WsLoopOp, mlir::omp::ReductionDeclareOp,
3407 mlir::omp::AtomicUpdateOp, mlir::omp::SimdLoopOp>(op))
3408 builder.create<mlir::omp::YieldOp>(loc);
3409 else
3410 builder.create<mlir::omp::TerminatorOp>(loc);
3413 void Fortran::lower::genOpenMPConstruct(
3414 Fortran::lower::AbstractConverter &converter,
3415 Fortran::semantics::SemanticsContext &semanticsContext,
3416 Fortran::lower::pft::Evaluation &eval,
3417 const Fortran::parser::OpenMPConstruct &ompConstruct) {
3418 std::visit(
3419 common::visitors{
3420 [&](const Fortran::parser::OpenMPStandaloneConstruct
3421 &standaloneConstruct) {
3422 genOMP(converter, eval, semanticsContext, standaloneConstruct);
3424 [&](const Fortran::parser::OpenMPSectionsConstruct
3425 &sectionsConstruct) {
3426 genOMP(converter, eval, sectionsConstruct);
3428 [&](const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) {
3429 genOMP(converter, eval, sectionConstruct);
3431 [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
3432 genOMP(converter, eval, semanticsContext, loopConstruct);
3434 [&](const Fortran::parser::OpenMPDeclarativeAllocate
3435 &execAllocConstruct) {
3436 TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate");
3438 [&](const Fortran::parser::OpenMPExecutableAllocate
3439 &execAllocConstruct) {
3440 TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate");
3442 [&](const Fortran::parser::OpenMPAllocatorsConstruct
3443 &allocsConstruct) {
3444 TODO(converter.getCurrentLocation(), "OpenMPAllocatorsConstruct");
3446 [&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
3447 genOMP(converter, eval, semanticsContext, blockConstruct);
3449 [&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) {
3450 genOMP(converter, eval, atomicConstruct);
3452 [&](const Fortran::parser::OpenMPCriticalConstruct
3453 &criticalConstruct) {
3454 genOMP(converter, eval, criticalConstruct);
3457 ompConstruct.u);
3460 void Fortran::lower::genOpenMPDeclarativeConstruct(
3461 Fortran::lower::AbstractConverter &converter,
3462 Fortran::lower::pft::Evaluation &eval,
3463 const Fortran::parser::OpenMPDeclarativeConstruct &ompDeclConstruct) {
3464 std::visit(
3465 common::visitors{
3466 [&](const Fortran::parser::OpenMPDeclarativeAllocate
3467 &declarativeAllocate) {
3468 TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate");
3470 [&](const Fortran::parser::OpenMPDeclareReductionConstruct
3471 &declareReductionConstruct) {
3472 TODO(converter.getCurrentLocation(),
3473 "OpenMPDeclareReductionConstruct");
3475 [&](const Fortran::parser::OpenMPDeclareSimdConstruct
3476 &declareSimdConstruct) {
3477 TODO(converter.getCurrentLocation(), "OpenMPDeclareSimdConstruct");
3479 [&](const Fortran::parser::OpenMPDeclareTargetConstruct
3480 &declareTargetConstruct) {
3481 genOMP(converter, eval, declareTargetConstruct);
3483 [&](const Fortran::parser::OpenMPRequiresConstruct
3484 &requiresConstruct) {
3485 // Requires directives are gathered and processed in semantics and
3486 // then combined in the lowering bridge before triggering codegen
3487 // just once. Hence, there is no need to lower each individual
3488 // occurrence here.
3490 [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) {
3491 // The directive is lowered when instantiating the variable to
3492 // support the case of threadprivate variable declared in module.
3495 ompDeclConstruct.u);
3498 int64_t Fortran::lower::getCollapseValue(
3499 const Fortran::parser::OmpClauseList &clauseList) {
3500 for (const Fortran::parser::OmpClause &clause : clauseList.v) {
3501 if (const auto &collapseClause =
3502 std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) {
3503 const auto *expr = Fortran::semantics::GetExpr(collapseClause->v);
3504 return Fortran::evaluate::ToInt64(*expr).value();
3507 return 1;
3510 void Fortran::lower::genThreadprivateOp(
3511 Fortran::lower::AbstractConverter &converter,
3512 const Fortran::lower::pft::Variable &var) {
3513 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
3514 mlir::Location currentLocation = converter.getCurrentLocation();
3516 const Fortran::semantics::Symbol &sym = var.getSymbol();
3517 mlir::Value symThreadprivateValue;
3518 if (const Fortran::semantics::Symbol *common =
3519 Fortran::semantics::FindCommonBlockContaining(sym.GetUltimate())) {
3520 mlir::Value commonValue = converter.getSymbolAddress(*common);
3521 if (mlir::isa<mlir::omp::ThreadprivateOp>(commonValue.getDefiningOp())) {
3522 // Generate ThreadprivateOp for a common block instead of its members and
3523 // only do it once for a common block.
3524 return;
3526 // Generate ThreadprivateOp and rebind the common block.
3527 mlir::Value commonThreadprivateValue =
3528 firOpBuilder.create<mlir::omp::ThreadprivateOp>(
3529 currentLocation, commonValue.getType(), commonValue);
3530 converter.bindSymbol(*common, commonThreadprivateValue);
3531 // Generate the threadprivate value for the common block member.
3532 symThreadprivateValue = genCommonBlockMember(converter, currentLocation,
3533 sym, commonThreadprivateValue);
3534 } else if (!var.isGlobal()) {
3535 // Non-global variable which can be in threadprivate directive must be one
3536 // variable in main program, and it has implicit SAVE attribute. Take it as
3537 // with SAVE attribute, so to create GlobalOp for it to simplify the
3538 // translation to LLVM IR.
3539 fir::GlobalOp global = globalInitialization(converter, firOpBuilder, sym,
3540 var, currentLocation);
3542 mlir::Value symValue = firOpBuilder.create<fir::AddrOfOp>(
3543 currentLocation, global.resultType(), global.getSymbol());
3544 symThreadprivateValue = firOpBuilder.create<mlir::omp::ThreadprivateOp>(
3545 currentLocation, symValue.getType(), symValue);
3546 } else {
3547 mlir::Value symValue = converter.getSymbolAddress(sym);
3549 // The symbol may be use-associated multiple times, and nothing needs to be
3550 // done after the original symbol is mapped to the threadprivatized value
3551 // for the first time. Use the threadprivatized value directly.
3552 mlir::Operation *op;
3553 if (auto declOp = symValue.getDefiningOp<hlfir::DeclareOp>())
3554 op = declOp.getMemref().getDefiningOp();
3555 else
3556 op = symValue.getDefiningOp();
3557 if (mlir::isa<mlir::omp::ThreadprivateOp>(op))
3558 return;
3560 symThreadprivateValue = firOpBuilder.create<mlir::omp::ThreadprivateOp>(
3561 currentLocation, symValue.getType(), symValue);
3564 fir::ExtendedValue sexv = converter.getSymbolExtendedValue(sym);
3565 fir::ExtendedValue symThreadprivateExv =
3566 getExtendedValue(sexv, symThreadprivateValue);
3567 converter.bindSymbol(sym, symThreadprivateExv);
3570 // This function replicates threadprivate's behaviour of generating
3571 // an internal fir.GlobalOp for non-global variables in the main program
3572 // that have the implicit SAVE attribute, to simplifiy LLVM-IR and MLIR
3573 // generation.
3574 void Fortran::lower::genDeclareTargetIntGlobal(
3575 Fortran::lower::AbstractConverter &converter,
3576 const Fortran::lower::pft::Variable &var) {
3577 if (!var.isGlobal()) {
3578 // A non-global variable which can be in a declare target directive must
3579 // be a variable in the main program, and it has the implicit SAVE
3580 // attribute. We create a GlobalOp for it to simplify the translation to
3581 // LLVM IR.
3582 globalInitialization(converter, converter.getFirOpBuilder(),
3583 var.getSymbol(), var, converter.getCurrentLocation());
3587 // Generate an OpenMP reduction operation.
3588 // TODO: Currently assumes it is either an integer addition/multiplication
3589 // reduction, or a logical and reduction. Generalize this for various reduction
3590 // operation types.
3591 // TODO: Generate the reduction operation during lowering instead of creating
3592 // and removing operations since this is not a robust approach. Also, removing
3593 // ops in the builder (instead of a rewriter) is probably not the best approach.
3594 void Fortran::lower::genOpenMPReduction(
3595 Fortran::lower::AbstractConverter &converter,
3596 const Fortran::parser::OmpClauseList &clauseList) {
3597 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
3599 for (const Fortran::parser::OmpClause &clause : clauseList.v) {
3600 if (const auto &reductionClause =
3601 std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
3602 const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
3603 reductionClause->v.t)};
3604 const auto &objectList{
3605 std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)};
3606 if (const auto *reductionOp =
3607 std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
3608 const auto &intrinsicOp{
3609 std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
3610 reductionOp->u)};
3612 switch (intrinsicOp) {
3613 case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
3614 case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
3615 case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
3616 case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
3617 case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
3618 case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
3619 break;
3620 default:
3621 continue;
3623 for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
3624 if (const auto *name{
3625 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
3626 if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
3627 mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
3628 if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
3629 reductionVal = declOp.getBase();
3630 mlir::Type reductionType =
3631 reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
3632 if (!reductionType.isa<fir::LogicalType>()) {
3633 if (!reductionType.isIntOrIndexOrFloat())
3634 continue;
3636 for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
3637 if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
3638 reductionValUse.getOwner())) {
3639 mlir::Value loadVal = loadOp.getRes();
3640 if (reductionType.isa<fir::LogicalType>()) {
3641 mlir::Operation *reductionOp = findReductionChain(loadVal);
3642 fir::ConvertOp convertOp =
3643 getConvertFromReductionOp(reductionOp, loadVal);
3644 updateReduction(reductionOp, firOpBuilder, loadVal,
3645 reductionVal, &convertOp);
3646 removeStoreOp(reductionOp, reductionVal);
3647 } else if (mlir::Operation *reductionOp =
3648 findReductionChain(loadVal, &reductionVal)) {
3649 updateReduction(reductionOp, firOpBuilder, loadVal,
3650 reductionVal);
3657 } else if (const auto *reductionIntrinsic =
3658 std::get_if<Fortran::parser::ProcedureDesignator>(
3659 &redOperator.u)) {
3660 if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
3661 reductionIntrinsic)}) {
3662 std::string redName = name->ToString();
3663 if ((name->source != "max") && (name->source != "min") &&
3664 (name->source != "ior") && (name->source != "ieor") &&
3665 (name->source != "iand")) {
3666 continue;
3668 for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
3669 if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
3670 ompObject)}) {
3671 if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
3672 mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
3673 if (auto declOp =
3674 reductionVal.getDefiningOp<hlfir::DeclareOp>())
3675 reductionVal = declOp.getBase();
3676 for (const mlir::OpOperand &reductionValUse :
3677 reductionVal.getUses()) {
3678 if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
3679 reductionValUse.getOwner())) {
3680 mlir::Value loadVal = loadOp.getRes();
3681 // Max is lowered as a compare -> select.
3682 // Match the pattern here.
3683 mlir::Operation *reductionOp =
3684 findReductionChain(loadVal, &reductionVal);
3685 if (reductionOp == nullptr)
3686 continue;
3688 if (redName == "max" || redName == "min") {
3689 assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
3690 "Selection Op not found in reduction intrinsic");
3691 mlir::Operation *compareOp =
3692 getCompareFromReductionOp(reductionOp, loadVal);
3693 updateReduction(compareOp, firOpBuilder, loadVal,
3694 reductionVal);
3696 if (redName == "ior" || redName == "ieor" ||
3697 redName == "iand") {
3699 updateReduction(reductionOp, firOpBuilder, loadVal,
3700 reductionVal);
3713 mlir::Operation *Fortran::lower::findReductionChain(mlir::Value loadVal,
3714 mlir::Value *reductionVal) {
3715 for (mlir::OpOperand &loadOperand : loadVal.getUses()) {
3716 if (mlir::Operation *reductionOp = loadOperand.getOwner()) {
3717 if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(reductionOp)) {
3718 for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) {
3719 if (mlir::Operation *reductionOp = convertOperand.getOwner())
3720 return reductionOp;
3723 for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
3724 if (auto store =
3725 mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner())) {
3726 if (store.getMemref() == *reductionVal) {
3727 store.erase();
3728 return reductionOp;
3731 if (auto assign =
3732 mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner())) {
3733 if (assign.getLhs() == *reductionVal) {
3734 assign.erase();
3735 return reductionOp;
3741 return nullptr;
3744 // for a logical operator 'op' reduction X = X op Y
3745 // This function returns the operation responsible for converting Y from
3746 // fir.logical<4> to i1
3747 fir::ConvertOp
3748 Fortran::lower::getConvertFromReductionOp(mlir::Operation *reductionOp,
3749 mlir::Value loadVal) {
3750 for (mlir::Value reductionOperand : reductionOp->getOperands()) {
3751 if (auto convertOp =
3752 mlir::dyn_cast<fir::ConvertOp>(reductionOperand.getDefiningOp())) {
3753 if (convertOp.getOperand() == loadVal)
3754 continue;
3755 return convertOp;
3758 return nullptr;
3761 void Fortran::lower::updateReduction(mlir::Operation *op,
3762 fir::FirOpBuilder &firOpBuilder,
3763 mlir::Value loadVal,
3764 mlir::Value reductionVal,
3765 fir::ConvertOp *convertOp) {
3766 mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint();
3767 firOpBuilder.setInsertionPoint(op);
3769 mlir::Value reductionOp;
3770 if (convertOp)
3771 reductionOp = convertOp->getOperand();
3772 else if (op->getOperand(0) == loadVal)
3773 reductionOp = op->getOperand(1);
3774 else
3775 reductionOp = op->getOperand(0);
3777 firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), reductionOp,
3778 reductionVal);
3779 firOpBuilder.restoreInsertionPoint(insertPtDel);
3782 void Fortran::lower::removeStoreOp(mlir::Operation *reductionOp,
3783 mlir::Value symVal) {
3784 for (mlir::Operation *reductionOpUse : reductionOp->getUsers()) {
3785 if (auto convertReduction =
3786 mlir::dyn_cast<fir::ConvertOp>(reductionOpUse)) {
3787 for (mlir::Operation *convertReductionUse :
3788 convertReduction.getRes().getUsers()) {
3789 if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(convertReductionUse)) {
3790 if (storeOp.getMemref() == symVal)
3791 storeOp.erase();
3793 if (auto assignOp =
3794 mlir::dyn_cast<hlfir::AssignOp>(convertReductionUse)) {
3795 if (assignOp.getLhs() == symVal)
3796 assignOp.erase();
3803 bool Fortran::lower::isOpenMPTargetConstruct(
3804 const Fortran::parser::OpenMPConstruct &omp) {
3805 llvm::omp::Directive dir = llvm::omp::Directive::OMPD_unknown;
3806 if (const auto *block =
3807 std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u)) {
3808 const auto &begin =
3809 std::get<Fortran::parser::OmpBeginBlockDirective>(block->t);
3810 dir = std::get<Fortran::parser::OmpBlockDirective>(begin.t).v;
3811 } else if (const auto *loop =
3812 std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u)) {
3813 const auto &begin =
3814 std::get<Fortran::parser::OmpBeginLoopDirective>(loop->t);
3815 dir = std::get<Fortran::parser::OmpLoopDirective>(begin.t).v;
3817 return llvm::omp::allTargetSet.test(dir);
3820 bool Fortran::lower::isOpenMPDeviceDeclareTarget(
3821 Fortran::lower::AbstractConverter &converter,
3822 Fortran::lower::pft::Evaluation &eval,
3823 const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) {
3824 return std::visit(
3825 Fortran::common::visitors{
3826 [&](const Fortran::parser::OpenMPDeclareTargetConstruct &ompReq) {
3827 mlir::omp::DeclareTargetDeviceType targetType =
3828 getDeclareTargetFunctionDevice(converter, eval, ompReq)
3829 .value_or(mlir::omp::DeclareTargetDeviceType::host);
3830 return targetType != mlir::omp::DeclareTargetDeviceType::host;
3832 [&](const auto &) { return false; },
3834 ompDecl.u);
3837 void Fortran::lower::genOpenMPRequires(
3838 mlir::Operation *mod, const Fortran::semantics::Symbol *symbol) {
3839 using MlirRequires = mlir::omp::ClauseRequires;
3840 using SemaRequires = Fortran::semantics::WithOmpDeclarative::RequiresFlag;
3842 if (auto offloadMod =
3843 llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod)) {
3844 Fortran::semantics::WithOmpDeclarative::RequiresFlags semaFlags;
3845 if (symbol) {
3846 Fortran::common::visit(
3847 [&](const auto &details) {
3848 if constexpr (std::is_base_of_v<
3849 Fortran::semantics::WithOmpDeclarative,
3850 std::decay_t<decltype(details)>>) {
3851 if (details.has_ompRequires())
3852 semaFlags = *details.ompRequires();
3855 symbol->details());
3858 MlirRequires mlirFlags = MlirRequires::none;
3859 if (semaFlags.test(SemaRequires::ReverseOffload))
3860 mlirFlags = mlirFlags | MlirRequires::reverse_offload;
3861 if (semaFlags.test(SemaRequires::UnifiedAddress))
3862 mlirFlags = mlirFlags | MlirRequires::unified_address;
3863 if (semaFlags.test(SemaRequires::UnifiedSharedMemory))
3864 mlirFlags = mlirFlags | MlirRequires::unified_shared_memory;
3865 if (semaFlags.test(SemaRequires::DynamicAllocators))
3866 mlirFlags = mlirFlags | MlirRequires::dynamic_allocators;
3868 offloadMod.setRequires(mlirFlags);