[CodeGen][Hexagon] Replace PointerType::getUnqual(Type) with opaque version (NFC...
[llvm-project.git] / flang / lib / Lower / OpenACC.cpp
blobac1a1c00eb145f8c2be9b73b918e44eb120f5a22
1 //===-- OpenACC.cpp -- OpenACC directive lowering -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
11 //===----------------------------------------------------------------------===//
13 #include "flang/Lower/OpenACC.h"
15 #include "flang/Common/idioms.h"
16 #include "flang/Lower/Bridge.h"
17 #include "flang/Lower/ConvertType.h"
18 #include "flang/Lower/DirectivesCommon.h"
19 #include "flang/Lower/Mangler.h"
20 #include "flang/Lower/PFTBuilder.h"
21 #include "flang/Lower/StatementContext.h"
22 #include "flang/Lower/Support/Utils.h"
23 #include "flang/Optimizer/Builder/BoxValue.h"
24 #include "flang/Optimizer/Builder/Complex.h"
25 #include "flang/Optimizer/Builder/FIRBuilder.h"
26 #include "flang/Optimizer/Builder/HLFIRTools.h"
27 #include "flang/Optimizer/Builder/IntrinsicCall.h"
28 #include "flang/Optimizer/Builder/Todo.h"
29 #include "flang/Optimizer/Dialect/FIRType.h"
30 #include "flang/Parser/parse-tree-visitor.h"
31 #include "flang/Parser/parse-tree.h"
32 #include "flang/Semantics/expression.h"
33 #include "flang/Semantics/scope.h"
34 #include "flang/Semantics/tools.h"
35 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
36 #include "mlir/Support/LLVM.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/Frontend/OpenACC/ACC.h.inc"
39 #include "llvm/Support/CommandLine.h"
40 #include "llvm/Support/Debug.h"
42 #define DEBUG_TYPE "flang-lower-openacc"
44 static llvm::cl::opt<bool> unwrapFirBox(
45 "openacc-unwrap-fir-box",
46 llvm::cl::desc(
47 "Whether to use the address from fix.box in data clause operations."),
48 llvm::cl::init(false));
50 static llvm::cl::opt<bool> generateDefaultBounds(
51 "openacc-generate-default-bounds",
52 llvm::cl::desc("Whether to generate default bounds for arrays."),
53 llvm::cl::init(false));
55 // Special value for * passed in device_type or gang clauses.
56 static constexpr std::int64_t starCst = -1;
58 static unsigned routineCounter = 0;
59 static constexpr llvm::StringRef accRoutinePrefix = "acc_routine_";
60 static constexpr llvm::StringRef accPrivateInitName = "acc.private.init";
61 static constexpr llvm::StringRef accReductionInitName = "acc.reduction.init";
62 static constexpr llvm::StringRef accFirDescriptorPostfix = "_desc";
64 static mlir::Location
65 genOperandLocation(Fortran::lower::AbstractConverter &converter,
66 const Fortran::parser::AccObject &accObject) {
67 mlir::Location loc = converter.genUnknownLocation();
68 Fortran::common::visit(
69 Fortran::common::visitors{
70 [&](const Fortran::parser::Designator &designator) {
71 loc = converter.genLocation(designator.source);
73 [&](const Fortran::parser::Name &name) {
74 loc = converter.genLocation(name.source);
75 }},
76 accObject.u);
77 return loc;
80 static void addOperands(llvm::SmallVectorImpl<mlir::Value> &operands,
81 llvm::SmallVectorImpl<int32_t> &operandSegments,
82 llvm::ArrayRef<mlir::Value> clauseOperands) {
83 operands.append(clauseOperands.begin(), clauseOperands.end());
84 operandSegments.push_back(clauseOperands.size());
87 static void addOperand(llvm::SmallVectorImpl<mlir::Value> &operands,
88 llvm::SmallVectorImpl<int32_t> &operandSegments,
89 const mlir::Value &clauseOperand) {
90 if (clauseOperand) {
91 operands.push_back(clauseOperand);
92 operandSegments.push_back(1);
93 } else {
94 operandSegments.push_back(0);
98 template <typename Op>
99 static Op
100 createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
101 mlir::Value baseAddr, std::stringstream &name,
102 mlir::SmallVector<mlir::Value> bounds, bool structured,
103 bool implicit, mlir::acc::DataClause dataClause,
104 mlir::Type retTy, llvm::ArrayRef<mlir::Value> async,
105 llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
106 llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes,
107 bool unwrapBoxAddr = false, mlir::Value isPresent = {}) {
108 mlir::Value varPtrPtr;
109 // The data clause may apply to either the box reference itself or the
110 // pointer to the data it holds. So use `unwrapBoxAddr` to decide.
111 // When we have a box value - assume it refers to the data inside box.
112 if (unwrapFirBox &&
113 ((fir::isBoxAddress(baseAddr.getType()) && unwrapBoxAddr) ||
114 fir::isa_box_type(baseAddr.getType()))) {
115 if (isPresent) {
116 mlir::Type ifRetTy =
117 mlir::cast<fir::BaseBoxType>(fir::unwrapRefType(baseAddr.getType()))
118 .getEleTy();
119 if (!fir::isa_ref_type(ifRetTy))
120 ifRetTy = fir::ReferenceType::get(ifRetTy);
121 baseAddr =
122 builder
123 .genIfOp(loc, {ifRetTy}, isPresent,
124 /*withElseRegion=*/true)
125 .genThen([&]() {
126 if (fir::isBoxAddress(baseAddr.getType()))
127 baseAddr = builder.create<fir::LoadOp>(loc, baseAddr);
128 mlir::Value boxAddr =
129 builder.create<fir::BoxAddrOp>(loc, baseAddr);
130 builder.create<fir::ResultOp>(loc, mlir::ValueRange{boxAddr});
132 .genElse([&] {
133 mlir::Value absent =
134 builder.create<fir::AbsentOp>(loc, ifRetTy);
135 builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
137 .getResults()[0];
138 } else {
139 if (fir::isBoxAddress(baseAddr.getType()))
140 baseAddr = builder.create<fir::LoadOp>(loc, baseAddr);
141 baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
143 retTy = baseAddr.getType();
146 llvm::SmallVector<mlir::Value, 8> operands;
147 llvm::SmallVector<int32_t, 8> operandSegments;
149 addOperand(operands, operandSegments, baseAddr);
150 addOperand(operands, operandSegments, varPtrPtr);
151 addOperands(operands, operandSegments, bounds);
152 addOperands(operands, operandSegments, async);
154 Op op = builder.create<Op>(loc, retTy, operands);
155 op.setNameAttr(builder.getStringAttr(name.str()));
156 op.setStructured(structured);
157 op.setImplicit(implicit);
158 op.setDataClause(dataClause);
159 if (auto mappableTy =
160 mlir::dyn_cast<mlir::acc::MappableType>(baseAddr.getType())) {
161 op.setVarType(baseAddr.getType());
162 } else {
163 assert(mlir::isa<mlir::acc::PointerLikeType>(baseAddr.getType()) &&
164 "expected pointer-like");
165 op.setVarType(mlir::cast<mlir::acc::PointerLikeType>(baseAddr.getType())
166 .getElementType());
169 op->setAttr(Op::getOperandSegmentSizeAttr(),
170 builder.getDenseI32ArrayAttr(operandSegments));
171 if (!asyncDeviceTypes.empty())
172 op.setAsyncOperandsDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
173 if (!asyncOnlyDeviceTypes.empty())
174 op.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
175 return op;
178 static void addDeclareAttr(fir::FirOpBuilder &builder, mlir::Operation *op,
179 mlir::acc::DataClause clause) {
180 if (!op)
181 return;
182 op->setAttr(mlir::acc::getDeclareAttrName(),
183 mlir::acc::DeclareAttr::get(builder.getContext(),
184 mlir::acc::DataClauseAttr::get(
185 builder.getContext(), clause)));
188 static mlir::func::FuncOp
189 createDeclareFunc(mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder,
190 mlir::Location loc, llvm::StringRef funcName,
191 llvm::SmallVector<mlir::Type> argsTy = {},
192 llvm::SmallVector<mlir::Location> locs = {}) {
193 auto funcTy = mlir::FunctionType::get(modBuilder.getContext(), argsTy, {});
194 auto funcOp = modBuilder.create<mlir::func::FuncOp>(loc, funcName, funcTy);
195 funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
196 builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
197 locs);
198 builder.setInsertionPointToEnd(&funcOp.getRegion().back());
199 builder.create<mlir::func::ReturnOp>(loc);
200 builder.setInsertionPointToStart(&funcOp.getRegion().back());
201 return funcOp;
204 template <typename Op>
205 static Op
206 createSimpleOp(fir::FirOpBuilder &builder, mlir::Location loc,
207 const llvm::SmallVectorImpl<mlir::Value> &operands,
208 const llvm::SmallVectorImpl<int32_t> &operandSegments) {
209 llvm::ArrayRef<mlir::Type> argTy;
210 Op op = builder.create<Op>(loc, argTy, operands);
211 op->setAttr(Op::getOperandSegmentSizeAttr(),
212 builder.getDenseI32ArrayAttr(operandSegments));
213 return op;
216 template <typename EntryOp>
217 static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
218 fir::FirOpBuilder &builder,
219 mlir::Location loc, mlir::Type descTy,
220 llvm::StringRef funcNamePrefix,
221 std::stringstream &asFortran,
222 mlir::acc::DataClause clause) {
223 auto crtInsPt = builder.saveInsertionPoint();
224 std::stringstream registerFuncName;
225 registerFuncName << funcNamePrefix.str()
226 << Fortran::lower::declarePostAllocSuffix.str();
228 if (!mlir::isa<fir::ReferenceType>(descTy))
229 descTy = fir::ReferenceType::get(descTy);
230 auto registerFuncOp = createDeclareFunc(
231 modBuilder, builder, loc, registerFuncName.str(), {descTy}, {loc});
233 llvm::SmallVector<mlir::Value> bounds;
234 std::stringstream asFortranDesc;
235 asFortranDesc << asFortran.str();
236 if (unwrapFirBox)
237 asFortranDesc << accFirDescriptorPostfix.str();
239 // Updating descriptor must occur before the mapping of the data so that
240 // attached data pointer is not overwritten.
241 mlir::acc::UpdateDeviceOp updateDeviceOp =
242 createDataEntryOp<mlir::acc::UpdateDeviceOp>(
243 builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
244 /*structured=*/false, /*implicit=*/true,
245 mlir::acc::DataClause::acc_update_device, descTy,
246 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
247 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
248 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
249 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
251 if (unwrapFirBox) {
252 mlir::Value desc =
253 builder.create<fir::LoadOp>(loc, registerFuncOp.getArgument(0));
254 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, desc);
255 addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
256 EntryOp entryOp = createDataEntryOp<EntryOp>(
257 builder, loc, boxAddrOp.getResult(), asFortran, bounds,
258 /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType(),
259 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
260 builder.create<mlir::acc::DeclareEnterOp>(
261 loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
262 mlir::ValueRange(entryOp.getAccVar()));
265 modBuilder.setInsertionPointAfter(registerFuncOp);
266 builder.restoreInsertionPoint(crtInsPt);
269 template <typename ExitOp>
270 static void createDeclareDeallocFuncWithArg(
271 mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder, mlir::Location loc,
272 mlir::Type descTy, llvm::StringRef funcNamePrefix,
273 std::stringstream &asFortran, mlir::acc::DataClause clause) {
274 auto crtInsPt = builder.saveInsertionPoint();
275 // Generate the pre dealloc function.
276 std::stringstream preDeallocFuncName;
277 preDeallocFuncName << funcNamePrefix.str()
278 << Fortran::lower::declarePreDeallocSuffix.str();
279 if (!mlir::isa<fir::ReferenceType>(descTy))
280 descTy = fir::ReferenceType::get(descTy);
281 auto preDeallocOp = createDeclareFunc(
282 modBuilder, builder, loc, preDeallocFuncName.str(), {descTy}, {loc});
284 mlir::Value var = preDeallocOp.getArgument(0);
285 if (unwrapFirBox) {
286 mlir::Value loadOp =
287 builder.create<fir::LoadOp>(loc, preDeallocOp.getArgument(0));
288 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
289 addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
290 var = boxAddrOp.getResult();
293 llvm::SmallVector<mlir::Value> bounds;
294 mlir::acc::GetDevicePtrOp entryOp =
295 createDataEntryOp<mlir::acc::GetDevicePtrOp>(
296 builder, loc, var, asFortran, bounds,
297 /*structured=*/false, /*implicit=*/false, clause, var.getType(),
298 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
299 builder.create<mlir::acc::DeclareExitOp>(
300 loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccVar()));
302 if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
303 std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
304 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccVar(),
305 entryOp.getVar(), entryOp.getVarType(),
306 entryOp.getBounds(), entryOp.getAsyncOperands(),
307 entryOp.getAsyncOperandsDeviceTypeAttr(),
308 entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
309 /*structured=*/false, /*implicit=*/false,
310 builder.getStringAttr(*entryOp.getName()));
311 else
312 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccVar(),
313 entryOp.getBounds(), entryOp.getAsyncOperands(),
314 entryOp.getAsyncOperandsDeviceTypeAttr(),
315 entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
316 /*structured=*/false, /*implicit=*/false,
317 builder.getStringAttr(*entryOp.getName()));
319 // Generate the post dealloc function.
320 modBuilder.setInsertionPointAfter(preDeallocOp);
321 std::stringstream postDeallocFuncName;
322 postDeallocFuncName << funcNamePrefix.str()
323 << Fortran::lower::declarePostDeallocSuffix.str();
324 auto postDeallocOp = createDeclareFunc(
325 modBuilder, builder, loc, postDeallocFuncName.str(), {descTy}, {loc});
327 var = postDeallocOp.getArgument(0);
328 if (unwrapFirBox) {
329 var = builder.create<fir::LoadOp>(loc, postDeallocOp.getArgument(0));
330 asFortran << accFirDescriptorPostfix.str();
333 mlir::acc::UpdateDeviceOp updateDeviceOp =
334 createDataEntryOp<mlir::acc::UpdateDeviceOp>(
335 builder, loc, var, asFortran, bounds,
336 /*structured=*/false, /*implicit=*/true,
337 mlir::acc::DataClause::acc_update_device, var.getType(),
338 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
339 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
340 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
341 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
342 modBuilder.setInsertionPointAfter(postDeallocOp);
343 builder.restoreInsertionPoint(crtInsPt);
346 Fortran::semantics::Symbol &
347 getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) {
348 if (const auto *designator =
349 std::get_if<Fortran::parser::Designator>(&accObject.u)) {
350 if (const auto *name =
351 Fortran::semantics::getDesignatorNameIfDataRef(*designator))
352 return *name->symbol;
353 if (const auto *arrayElement =
354 Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
355 *designator)) {
356 const Fortran::parser::Name &name =
357 Fortran::parser::GetLastName(arrayElement->base);
358 return *name.symbol;
360 if (const auto *component =
361 Fortran::parser::Unwrap<Fortran::parser::StructureComponent>(
362 *designator)) {
363 return *component->component.symbol;
365 } else if (const auto *name =
366 std::get_if<Fortran::parser::Name>(&accObject.u)) {
367 return *name->symbol;
369 llvm::report_fatal_error("Could not find symbol");
372 template <typename Op>
373 static void
374 genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
375 Fortran::lower::AbstractConverter &converter,
376 Fortran::semantics::SemanticsContext &semanticsContext,
377 Fortran::lower::StatementContext &stmtCtx,
378 llvm::SmallVectorImpl<mlir::Value> &dataOperands,
379 mlir::acc::DataClause dataClause, bool structured,
380 bool implicit, llvm::ArrayRef<mlir::Value> async,
381 llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
382 llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes,
383 bool setDeclareAttr = false) {
384 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
385 Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
386 for (const auto &accObject : objectList.v) {
387 llvm::SmallVector<mlir::Value> bounds;
388 std::stringstream asFortran;
389 mlir::Location operandLocation = genOperandLocation(converter, accObject);
390 Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
391 Fortran::semantics::MaybeExpr designator = Fortran::common::visit(
392 [&](auto &&s) { return ea.Analyze(s); }, accObject.u);
393 fir::factory::AddrAndBoundsInfo info =
394 Fortran::lower::gatherDataOperandAddrAndBounds<
395 mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
396 converter, builder, semanticsContext, stmtCtx, symbol, designator,
397 operandLocation, asFortran, bounds,
398 /*treatIndexAsSection=*/true, /*unwrapFirBox=*/unwrapFirBox,
399 /*genDefaultBounds=*/generateDefaultBounds);
400 LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
402 // If the input value is optional and is not a descriptor, we use the
403 // rawInput directly.
404 mlir::Value baseAddr = ((fir::unwrapRefType(info.addr.getType()) !=
405 fir::unwrapRefType(info.rawInput.getType())) &&
406 info.isPresent)
407 ? info.rawInput
408 : info.addr;
409 Op op = createDataEntryOp<Op>(
410 builder, operandLocation, baseAddr, asFortran, bounds, structured,
411 implicit, dataClause, baseAddr.getType(), async, asyncDeviceTypes,
412 asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true, info.isPresent);
413 dataOperands.push_back(op.getAccVar());
417 template <typename EntryOp, typename ExitOp>
418 static void genDeclareDataOperandOperations(
419 const Fortran::parser::AccObjectList &objectList,
420 Fortran::lower::AbstractConverter &converter,
421 Fortran::semantics::SemanticsContext &semanticsContext,
422 Fortran::lower::StatementContext &stmtCtx,
423 llvm::SmallVectorImpl<mlir::Value> &dataOperands,
424 mlir::acc::DataClause dataClause, bool structured, bool implicit) {
425 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
426 Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
427 for (const auto &accObject : objectList.v) {
428 llvm::SmallVector<mlir::Value> bounds;
429 std::stringstream asFortran;
430 mlir::Location operandLocation = genOperandLocation(converter, accObject);
431 Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
432 Fortran::semantics::MaybeExpr designator = Fortran::common::visit(
433 [&](auto &&s) { return ea.Analyze(s); }, accObject.u);
434 fir::factory::AddrAndBoundsInfo info =
435 Fortran::lower::gatherDataOperandAddrAndBounds<
436 mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
437 converter, builder, semanticsContext, stmtCtx, symbol, designator,
438 operandLocation, asFortran, bounds,
439 /*treatIndexAsSection=*/true, /*unwrapFirBox=*/unwrapFirBox,
440 /*genDefaultBounds=*/generateDefaultBounds);
441 LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
442 EntryOp op = createDataEntryOp<EntryOp>(
443 builder, operandLocation, info.addr, asFortran, bounds, structured,
444 implicit, dataClause, info.addr.getType(),
445 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
446 dataOperands.push_back(op.getAccVar());
447 addDeclareAttr(builder, op.getVar().getDefiningOp(), dataClause);
448 if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(info.addr.getType()))) {
449 mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion());
450 modBuilder.setInsertionPointAfter(builder.getFunction());
451 std::string prefix = converter.mangleName(symbol);
452 createDeclareAllocFuncWithArg<EntryOp>(
453 modBuilder, builder, operandLocation, info.addr.getType(), prefix,
454 asFortran, dataClause);
455 if constexpr (!std::is_same_v<EntryOp, ExitOp>)
456 createDeclareDeallocFuncWithArg<ExitOp>(
457 modBuilder, builder, operandLocation, info.addr.getType(), prefix,
458 asFortran, dataClause);
463 template <typename EntryOp, typename ExitOp, typename Clause>
464 static void genDeclareDataOperandOperationsWithModifier(
465 const Clause *x, Fortran::lower::AbstractConverter &converter,
466 Fortran::semantics::SemanticsContext &semanticsContext,
467 Fortran::lower::StatementContext &stmtCtx,
468 Fortran::parser::AccDataModifier::Modifier mod,
469 llvm::SmallVectorImpl<mlir::Value> &dataClauseOperands,
470 const mlir::acc::DataClause clause,
471 const mlir::acc::DataClause clauseWithModifier) {
472 const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
473 const auto &accObjectList =
474 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
475 const auto &modifier =
476 std::get<std::optional<Fortran::parser::AccDataModifier>>(
477 listWithModifier.t);
478 mlir::acc::DataClause dataClause =
479 (modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
480 genDeclareDataOperandOperations<EntryOp, ExitOp>(
481 accObjectList, converter, semanticsContext, stmtCtx, dataClauseOperands,
482 dataClause,
483 /*structured=*/true, /*implicit=*/false);
486 template <typename EntryOp, typename ExitOp>
487 static void genDataExitOperations(fir::FirOpBuilder &builder,
488 llvm::SmallVector<mlir::Value> operands,
489 bool structured) {
490 for (mlir::Value operand : operands) {
491 auto entryOp = mlir::dyn_cast_or_null<EntryOp>(operand.getDefiningOp());
492 assert(entryOp && "data entry op expected");
493 if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
494 std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
495 builder.create<ExitOp>(
496 entryOp.getLoc(), entryOp.getAccVar(), entryOp.getVar(),
497 entryOp.getVarType(), entryOp.getBounds(), entryOp.getAsyncOperands(),
498 entryOp.getAsyncOperandsDeviceTypeAttr(), entryOp.getAsyncOnlyAttr(),
499 entryOp.getDataClause(), structured, entryOp.getImplicit(),
500 builder.getStringAttr(*entryOp.getName()));
501 else
502 builder.create<ExitOp>(
503 entryOp.getLoc(), entryOp.getAccVar(), entryOp.getBounds(),
504 entryOp.getAsyncOperands(), entryOp.getAsyncOperandsDeviceTypeAttr(),
505 entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(), structured,
506 entryOp.getImplicit(), builder.getStringAttr(*entryOp.getName()));
510 fir::ShapeOp genShapeOp(mlir::OpBuilder &builder, fir::SequenceType seqTy,
511 mlir::Location loc) {
512 llvm::SmallVector<mlir::Value> extents;
513 mlir::Type idxTy = builder.getIndexType();
514 for (auto extent : seqTy.getShape())
515 extents.push_back(builder.create<mlir::arith::ConstantOp>(
516 loc, idxTy, builder.getIntegerAttr(idxTy, extent)));
517 return builder.create<fir::ShapeOp>(loc, extents);
520 template <typename RecipeOp>
521 static void genPrivateLikeInitRegion(mlir::OpBuilder &builder, RecipeOp recipe,
522 mlir::Type ty, mlir::Location loc) {
523 mlir::Value retVal = recipe.getInitRegion().front().getArgument(0);
524 ty = fir::unwrapRefType(ty);
525 if (fir::isa_trivial(ty)) {
526 auto alloca = builder.create<fir::AllocaOp>(loc, ty);
527 auto declareOp = builder.create<hlfir::DeclareOp>(
528 loc, alloca, accPrivateInitName, /*shape=*/nullptr,
529 llvm::ArrayRef<mlir::Value>{}, /*dummy_scope=*/nullptr,
530 fir::FortranVariableFlagsAttr{});
531 retVal = declareOp.getBase();
532 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(ty)) {
533 if (fir::isa_trivial(seqTy.getEleTy())) {
534 mlir::Value shape;
535 llvm::SmallVector<mlir::Value> extents;
536 if (seqTy.hasDynamicExtents()) {
537 // Extents are passed as block arguments. First argument is the
538 // original value.
539 for (unsigned i = 1; i < recipe.getInitRegion().getArguments().size();
540 ++i)
541 extents.push_back(recipe.getInitRegion().getArgument(i));
542 shape = builder.create<fir::ShapeOp>(loc, extents);
543 } else {
544 shape = genShapeOp(builder, seqTy, loc);
546 auto alloca = builder.create<fir::AllocaOp>(
547 loc, seqTy, /*typeparams=*/mlir::ValueRange{}, extents);
548 auto declareOp = builder.create<hlfir::DeclareOp>(
549 loc, alloca, accPrivateInitName, shape, llvm::ArrayRef<mlir::Value>{},
550 /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{});
551 retVal = declareOp.getBase();
553 } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
554 mlir::Type innerTy = fir::unwrapRefType(boxTy.getEleTy());
555 if (!fir::isa_trivial(innerTy) && !mlir::isa<fir::SequenceType>(innerTy))
556 TODO(loc, "Unsupported boxed type in OpenACC privatization");
557 fir::FirOpBuilder firBuilder{builder, recipe.getOperation()};
558 hlfir::Entity source = hlfir::Entity{retVal};
559 auto [temp, cleanup] = hlfir::createTempFromMold(loc, firBuilder, source);
560 retVal = temp;
562 builder.create<mlir::acc::YieldOp>(loc, retVal);
565 mlir::acc::PrivateRecipeOp
566 Fortran::lower::createOrGetPrivateRecipe(mlir::OpBuilder &builder,
567 llvm::StringRef recipeName,
568 mlir::Location loc, mlir::Type ty) {
569 mlir::ModuleOp mod =
570 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
571 if (auto recipe = mod.lookupSymbol<mlir::acc::PrivateRecipeOp>(recipeName))
572 return recipe;
574 auto crtPos = builder.saveInsertionPoint();
575 mlir::OpBuilder modBuilder(mod.getBodyRegion());
576 auto recipe =
577 modBuilder.create<mlir::acc::PrivateRecipeOp>(loc, recipeName, ty);
578 llvm::SmallVector<mlir::Type> argsTy{ty};
579 llvm::SmallVector<mlir::Location> argsLoc{loc};
580 if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(ty)) {
581 if (auto seqTy =
582 mlir::dyn_cast_or_null<fir::SequenceType>(refTy.getEleTy())) {
583 if (seqTy.hasDynamicExtents()) {
584 mlir::Type idxTy = builder.getIndexType();
585 for (unsigned i = 0; i < seqTy.getDimension(); ++i) {
586 argsTy.push_back(idxTy);
587 argsLoc.push_back(loc);
592 builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(),
593 argsTy, argsLoc);
594 builder.setInsertionPointToEnd(&recipe.getInitRegion().back());
595 genPrivateLikeInitRegion<mlir::acc::PrivateRecipeOp>(builder, recipe, ty,
596 loc);
597 builder.restoreInsertionPoint(crtPos);
598 return recipe;
601 /// Check if the DataBoundsOp is a constant bound (lb and ub are constants or
602 /// extent is a constant).
603 bool isConstantBound(mlir::acc::DataBoundsOp &op) {
604 if (op.getLowerbound() && fir::getIntIfConstant(op.getLowerbound()) &&
605 op.getUpperbound() && fir::getIntIfConstant(op.getUpperbound()))
606 return true;
607 if (op.getExtent() && fir::getIntIfConstant(op.getExtent()))
608 return true;
609 return false;
612 /// Return true iff all the bounds are expressed with constant values.
613 bool areAllBoundConstant(const llvm::SmallVector<mlir::Value> &bounds) {
614 for (auto bound : bounds) {
615 auto dataBound =
616 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
617 assert(dataBound && "Must be DataBoundOp operation");
618 if (!isConstantBound(dataBound))
619 return false;
621 return true;
624 static llvm::SmallVector<mlir::Value>
625 genConstantBounds(fir::FirOpBuilder &builder, mlir::Location loc,
626 mlir::acc::DataBoundsOp &dataBound) {
627 mlir::Type idxTy = builder.getIndexType();
628 mlir::Value lb, ub, step;
629 if (dataBound.getLowerbound() &&
630 fir::getIntIfConstant(dataBound.getLowerbound()) &&
631 dataBound.getUpperbound() &&
632 fir::getIntIfConstant(dataBound.getUpperbound())) {
633 lb = builder.createIntegerConstant(
634 loc, idxTy, *fir::getIntIfConstant(dataBound.getLowerbound()));
635 ub = builder.createIntegerConstant(
636 loc, idxTy, *fir::getIntIfConstant(dataBound.getUpperbound()));
637 step = builder.createIntegerConstant(loc, idxTy, 1);
638 } else if (dataBound.getExtent()) {
639 lb = builder.createIntegerConstant(loc, idxTy, 0);
640 ub = builder.createIntegerConstant(
641 loc, idxTy, *fir::getIntIfConstant(dataBound.getExtent()) - 1);
642 step = builder.createIntegerConstant(loc, idxTy, 1);
643 } else {
644 llvm::report_fatal_error("Expect constant lb/ub or extent");
646 return {lb, ub, step};
649 static mlir::Value genShapeFromBoundsOrArgs(
650 mlir::Location loc, fir::FirOpBuilder &builder, fir::SequenceType seqTy,
651 const llvm::SmallVector<mlir::Value> &bounds, mlir::ValueRange arguments) {
652 llvm::SmallVector<mlir::Value> args;
653 if (bounds.empty() && seqTy) {
654 if (seqTy.hasDynamicExtents()) {
655 assert(!arguments.empty() && "arguments must hold the entity");
656 auto entity = hlfir::Entity{arguments[0]};
657 return hlfir::genShape(loc, builder, entity);
659 return genShapeOp(builder, seqTy, loc).getResult();
660 } else if (areAllBoundConstant(bounds)) {
661 for (auto bound : llvm::reverse(bounds)) {
662 auto dataBound =
663 mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
664 args.append(genConstantBounds(builder, loc, dataBound));
666 } else {
667 assert(((arguments.size() - 2) / 3 == seqTy.getDimension()) &&
668 "Expect 3 block arguments per dimension");
669 for (auto arg : arguments.drop_front(2))
670 args.push_back(arg);
673 assert(args.size() % 3 == 0 && "Triplets must be a multiple of 3");
674 llvm::SmallVector<mlir::Value> extents;
675 mlir::Type idxTy = builder.getIndexType();
676 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
677 mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
678 for (unsigned i = 0; i < args.size(); i += 3) {
679 mlir::Value s1 =
680 builder.create<mlir::arith::SubIOp>(loc, args[i + 1], args[0]);
681 mlir::Value s2 = builder.create<mlir::arith::AddIOp>(loc, s1, one);
682 mlir::Value s3 = builder.create<mlir::arith::DivSIOp>(loc, s2, args[i + 2]);
683 mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
684 loc, mlir::arith::CmpIPredicate::sgt, s3, zero);
685 mlir::Value ext = builder.create<mlir::arith::SelectOp>(loc, cmp, s3, zero);
686 extents.push_back(ext);
688 return builder.create<fir::ShapeOp>(loc, extents);
691 static hlfir::DesignateOp::Subscripts
692 getSubscriptsFromArgs(mlir::ValueRange args) {
693 hlfir::DesignateOp::Subscripts triplets;
694 for (unsigned i = 2; i < args.size(); i += 3)
695 triplets.emplace_back(
696 hlfir::DesignateOp::Triplet{args[i], args[i + 1], args[i + 2]});
697 return triplets;
700 static hlfir::Entity genDesignateWithTriplets(
701 fir::FirOpBuilder &builder, mlir::Location loc, hlfir::Entity &entity,
702 hlfir::DesignateOp::Subscripts &triplets, mlir::Value shape) {
703 llvm::SmallVector<mlir::Value> lenParams;
704 hlfir::genLengthParameters(loc, builder, entity, lenParams);
705 auto designate = builder.create<hlfir::DesignateOp>(
706 loc, entity.getBase().getType(), entity, /*component=*/"",
707 /*componentShape=*/mlir::Value{}, triplets,
708 /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt, shape,
709 lenParams);
710 return hlfir::Entity{designate.getResult()};
713 mlir::acc::FirstprivateRecipeOp Fortran::lower::createOrGetFirstprivateRecipe(
714 mlir::OpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc,
715 mlir::Type ty, llvm::SmallVector<mlir::Value> &bounds) {
716 mlir::ModuleOp mod =
717 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
718 if (auto recipe =
719 mod.lookupSymbol<mlir::acc::FirstprivateRecipeOp>(recipeName))
720 return recipe;
722 auto crtPos = builder.saveInsertionPoint();
723 mlir::OpBuilder modBuilder(mod.getBodyRegion());
724 auto recipe =
725 modBuilder.create<mlir::acc::FirstprivateRecipeOp>(loc, recipeName, ty);
726 llvm::SmallVector<mlir::Type> initArgsTy{ty};
727 llvm::SmallVector<mlir::Location> initArgsLoc{loc};
728 auto refTy = fir::unwrapRefType(ty);
729 if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(refTy)) {
730 if (seqTy.hasDynamicExtents()) {
731 mlir::Type idxTy = builder.getIndexType();
732 for (unsigned i = 0; i < seqTy.getDimension(); ++i) {
733 initArgsTy.push_back(idxTy);
734 initArgsLoc.push_back(loc);
738 builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(),
739 initArgsTy, initArgsLoc);
740 builder.setInsertionPointToEnd(&recipe.getInitRegion().back());
741 genPrivateLikeInitRegion<mlir::acc::FirstprivateRecipeOp>(builder, recipe, ty,
742 loc);
744 bool allConstantBound = areAllBoundConstant(bounds);
745 llvm::SmallVector<mlir::Type> argsTy{ty, ty};
746 llvm::SmallVector<mlir::Location> argsLoc{loc, loc};
747 if (!allConstantBound) {
748 for (mlir::Value bound : llvm::reverse(bounds)) {
749 auto dataBound =
750 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
751 argsTy.push_back(dataBound.getLowerbound().getType());
752 argsLoc.push_back(dataBound.getLowerbound().getLoc());
753 argsTy.push_back(dataBound.getUpperbound().getType());
754 argsLoc.push_back(dataBound.getUpperbound().getLoc());
755 argsTy.push_back(dataBound.getStartIdx().getType());
756 argsLoc.push_back(dataBound.getStartIdx().getLoc());
759 builder.createBlock(&recipe.getCopyRegion(), recipe.getCopyRegion().end(),
760 argsTy, argsLoc);
762 builder.setInsertionPointToEnd(&recipe.getCopyRegion().back());
763 ty = fir::unwrapRefType(ty);
764 if (fir::isa_trivial(ty)) {
765 mlir::Value initValue = builder.create<fir::LoadOp>(
766 loc, recipe.getCopyRegion().front().getArgument(0));
767 builder.create<fir::StoreOp>(loc, initValue,
768 recipe.getCopyRegion().front().getArgument(1));
769 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(ty)) {
770 fir::FirOpBuilder firBuilder{builder, recipe.getOperation()};
771 auto shape = genShapeFromBoundsOrArgs(
772 loc, firBuilder, seqTy, bounds, recipe.getCopyRegion().getArguments());
774 auto leftDeclOp = builder.create<hlfir::DeclareOp>(
775 loc, recipe.getCopyRegion().getArgument(0), llvm::StringRef{}, shape,
776 llvm::ArrayRef<mlir::Value>{}, /*dummy_scope=*/nullptr,
777 fir::FortranVariableFlagsAttr{});
778 auto rightDeclOp = builder.create<hlfir::DeclareOp>(
779 loc, recipe.getCopyRegion().getArgument(1), llvm::StringRef{}, shape,
780 llvm::ArrayRef<mlir::Value>{}, /*dummy_scope=*/nullptr,
781 fir::FortranVariableFlagsAttr{});
783 hlfir::DesignateOp::Subscripts triplets =
784 getSubscriptsFromArgs(recipe.getCopyRegion().getArguments());
785 auto leftEntity = hlfir::Entity{leftDeclOp.getBase()};
786 auto left =
787 genDesignateWithTriplets(firBuilder, loc, leftEntity, triplets, shape);
788 auto rightEntity = hlfir::Entity{rightDeclOp.getBase()};
789 auto right =
790 genDesignateWithTriplets(firBuilder, loc, rightEntity, triplets, shape);
792 firBuilder.create<hlfir::AssignOp>(loc, left, right);
794 } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
795 fir::FirOpBuilder firBuilder{builder, recipe.getOperation()};
796 llvm::SmallVector<mlir::Value> tripletArgs;
797 mlir::Type innerTy = fir::extractSequenceType(boxTy);
798 fir::SequenceType seqTy =
799 mlir::dyn_cast_or_null<fir::SequenceType>(innerTy);
800 if (!seqTy)
801 TODO(loc, "Unsupported boxed type in OpenACC firstprivate");
803 auto shape = genShapeFromBoundsOrArgs(
804 loc, firBuilder, seqTy, bounds, recipe.getCopyRegion().getArguments());
805 hlfir::DesignateOp::Subscripts triplets =
806 getSubscriptsFromArgs(recipe.getCopyRegion().getArguments());
807 auto leftEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(0)};
808 auto left =
809 genDesignateWithTriplets(firBuilder, loc, leftEntity, triplets, shape);
810 auto rightEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(1)};
811 auto right =
812 genDesignateWithTriplets(firBuilder, loc, rightEntity, triplets, shape);
813 firBuilder.create<hlfir::AssignOp>(loc, left, right);
816 builder.create<mlir::acc::TerminatorOp>(loc);
817 builder.restoreInsertionPoint(crtPos);
818 return recipe;
821 /// Get a string representation of the bounds.
822 std::string getBoundsString(llvm::SmallVector<mlir::Value> &bounds) {
823 std::stringstream boundStr;
824 if (!bounds.empty())
825 boundStr << "_section_";
826 llvm::interleave(
827 bounds,
828 [&](mlir::Value bound) {
829 auto boundsOp =
830 mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
831 if (boundsOp.getLowerbound() &&
832 fir::getIntIfConstant(boundsOp.getLowerbound()) &&
833 boundsOp.getUpperbound() &&
834 fir::getIntIfConstant(boundsOp.getUpperbound())) {
835 boundStr << "lb" << *fir::getIntIfConstant(boundsOp.getLowerbound())
836 << ".ub" << *fir::getIntIfConstant(boundsOp.getUpperbound());
837 } else if (boundsOp.getExtent() &&
838 fir::getIntIfConstant(boundsOp.getExtent())) {
839 boundStr << "ext" << *fir::getIntIfConstant(boundsOp.getExtent());
840 } else {
841 boundStr << "?";
844 [&] { boundStr << "x"; });
845 return boundStr.str();
848 /// Rebuild the array type from the acc.bounds operation with constant
849 /// lowerbound/upperbound or extent.
850 mlir::Type getTypeFromBounds(llvm::SmallVector<mlir::Value> &bounds,
851 mlir::Type ty) {
852 auto seqTy =
853 mlir::dyn_cast_or_null<fir::SequenceType>(fir::unwrapRefType(ty));
854 if (!bounds.empty() && seqTy) {
855 llvm::SmallVector<int64_t> shape;
856 for (auto b : bounds) {
857 auto boundsOp =
858 mlir::dyn_cast<mlir::acc::DataBoundsOp>(b.getDefiningOp());
859 if (boundsOp.getLowerbound() &&
860 fir::getIntIfConstant(boundsOp.getLowerbound()) &&
861 boundsOp.getUpperbound() &&
862 fir::getIntIfConstant(boundsOp.getUpperbound())) {
863 int64_t ext = *fir::getIntIfConstant(boundsOp.getUpperbound()) -
864 *fir::getIntIfConstant(boundsOp.getLowerbound()) + 1;
865 shape.push_back(ext);
866 } else if (boundsOp.getExtent() &&
867 fir::getIntIfConstant(boundsOp.getExtent())) {
868 shape.push_back(*fir::getIntIfConstant(boundsOp.getExtent()));
869 } else {
870 return ty; // TODO: handle dynamic shaped array slice.
873 if (shape.empty() || shape.size() != bounds.size())
874 return ty;
875 auto newSeqTy = fir::SequenceType::get(shape, seqTy.getEleTy());
876 if (mlir::isa<fir::ReferenceType, fir::PointerType>(ty))
877 return fir::ReferenceType::get(newSeqTy);
878 return newSeqTy;
880 return ty;
883 template <typename RecipeOp>
884 static void
885 genPrivatizations(const Fortran::parser::AccObjectList &objectList,
886 Fortran::lower::AbstractConverter &converter,
887 Fortran::semantics::SemanticsContext &semanticsContext,
888 Fortran::lower::StatementContext &stmtCtx,
889 llvm::SmallVectorImpl<mlir::Value> &dataOperands,
890 llvm::SmallVector<mlir::Attribute> &privatizations,
891 llvm::ArrayRef<mlir::Value> async,
892 llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
893 llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes) {
894 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
895 Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
896 for (const auto &accObject : objectList.v) {
897 llvm::SmallVector<mlir::Value> bounds;
898 std::stringstream asFortran;
899 mlir::Location operandLocation = genOperandLocation(converter, accObject);
900 Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
901 Fortran::semantics::MaybeExpr designator = Fortran::common::visit(
902 [&](auto &&s) { return ea.Analyze(s); }, accObject.u);
903 fir::factory::AddrAndBoundsInfo info =
904 Fortran::lower::gatherDataOperandAddrAndBounds<
905 mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
906 converter, builder, semanticsContext, stmtCtx, symbol, designator,
907 operandLocation, asFortran, bounds,
908 /*treatIndexAsSection=*/true, /*unwrapFirBox=*/unwrapFirBox,
909 /*genDefaultBounds=*/generateDefaultBounds);
910 LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
912 RecipeOp recipe;
913 mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType());
914 if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) {
915 std::string recipeName =
916 fir::getTypeAsString(retTy, converter.getKindMap(),
917 Fortran::lower::privatizationRecipePrefix);
918 recipe = Fortran::lower::createOrGetPrivateRecipe(builder, recipeName,
919 operandLocation, retTy);
920 auto op = createDataEntryOp<mlir::acc::PrivateOp>(
921 builder, operandLocation, info.addr, asFortran, bounds, true,
922 /*implicit=*/false, mlir::acc::DataClause::acc_private, retTy, async,
923 asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true);
924 dataOperands.push_back(op.getAccVar());
925 } else {
926 std::string suffix =
927 areAllBoundConstant(bounds) ? getBoundsString(bounds) : "";
928 std::string recipeName = fir::getTypeAsString(
929 retTy, converter.getKindMap(), "firstprivatization" + suffix);
930 recipe = Fortran::lower::createOrGetFirstprivateRecipe(
931 builder, recipeName, operandLocation, retTy, bounds);
932 auto op = createDataEntryOp<mlir::acc::FirstprivateOp>(
933 builder, operandLocation, info.addr, asFortran, bounds, true,
934 /*implicit=*/false, mlir::acc::DataClause::acc_firstprivate, retTy,
935 async, asyncDeviceTypes, asyncOnlyDeviceTypes,
936 /*unwrapBoxAddr=*/true);
937 dataOperands.push_back(op.getAccVar());
939 privatizations.push_back(mlir::SymbolRefAttr::get(
940 builder.getContext(), recipe.getSymName().str()));
944 /// Return the corresponding enum value for the mlir::acc::ReductionOperator
945 /// from the parser representation.
946 static mlir::acc::ReductionOperator
947 getReductionOperator(const Fortran::parser::ReductionOperator &op) {
948 switch (op.v) {
949 case Fortran::parser::ReductionOperator::Operator::Plus:
950 return mlir::acc::ReductionOperator::AccAdd;
951 case Fortran::parser::ReductionOperator::Operator::Multiply:
952 return mlir::acc::ReductionOperator::AccMul;
953 case Fortran::parser::ReductionOperator::Operator::Max:
954 return mlir::acc::ReductionOperator::AccMax;
955 case Fortran::parser::ReductionOperator::Operator::Min:
956 return mlir::acc::ReductionOperator::AccMin;
957 case Fortran::parser::ReductionOperator::Operator::Iand:
958 return mlir::acc::ReductionOperator::AccIand;
959 case Fortran::parser::ReductionOperator::Operator::Ior:
960 return mlir::acc::ReductionOperator::AccIor;
961 case Fortran::parser::ReductionOperator::Operator::Ieor:
962 return mlir::acc::ReductionOperator::AccXor;
963 case Fortran::parser::ReductionOperator::Operator::And:
964 return mlir::acc::ReductionOperator::AccLand;
965 case Fortran::parser::ReductionOperator::Operator::Or:
966 return mlir::acc::ReductionOperator::AccLor;
967 case Fortran::parser::ReductionOperator::Operator::Eqv:
968 return mlir::acc::ReductionOperator::AccEqv;
969 case Fortran::parser::ReductionOperator::Operator::Neqv:
970 return mlir::acc::ReductionOperator::AccNeqv;
972 llvm_unreachable("unexpected reduction operator");
975 /// Get the initial value for reduction operator.
976 template <typename R>
977 static R getReductionInitValue(mlir::acc::ReductionOperator op, mlir::Type ty) {
978 if (op == mlir::acc::ReductionOperator::AccMin) {
979 // min init value -> largest
980 if constexpr (std::is_same_v<R, llvm::APInt>) {
981 assert(ty.isIntOrIndex() && "expect integer or index type");
982 return llvm::APInt::getSignedMaxValue(ty.getIntOrFloatBitWidth());
984 if constexpr (std::is_same_v<R, llvm::APFloat>) {
985 auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty);
986 assert(floatTy && "expect float type");
987 return llvm::APFloat::getLargest(floatTy.getFloatSemantics(),
988 /*negative=*/false);
990 } else if (op == mlir::acc::ReductionOperator::AccMax) {
991 // max init value -> smallest
992 if constexpr (std::is_same_v<R, llvm::APInt>) {
993 assert(ty.isIntOrIndex() && "expect integer or index type");
994 return llvm::APInt::getSignedMinValue(ty.getIntOrFloatBitWidth());
996 if constexpr (std::is_same_v<R, llvm::APFloat>) {
997 auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty);
998 assert(floatTy && "expect float type");
999 return llvm::APFloat::getSmallest(floatTy.getFloatSemantics(),
1000 /*negative=*/true);
1002 } else if (op == mlir::acc::ReductionOperator::AccIand) {
1003 if constexpr (std::is_same_v<R, llvm::APInt>) {
1004 assert(ty.isIntOrIndex() && "expect integer type");
1005 unsigned bits = ty.getIntOrFloatBitWidth();
1006 return llvm::APInt::getAllOnes(bits);
1008 } else {
1009 // +, ior, ieor init value -> 0
1010 // * init value -> 1
1011 int64_t value = (op == mlir::acc::ReductionOperator::AccMul) ? 1 : 0;
1012 if constexpr (std::is_same_v<R, llvm::APInt>) {
1013 assert(ty.isIntOrIndex() && "expect integer or index type");
1014 return llvm::APInt(ty.getIntOrFloatBitWidth(), value, true);
1017 if constexpr (std::is_same_v<R, llvm::APFloat>) {
1018 assert(mlir::isa<mlir::FloatType>(ty) && "expect float type");
1019 auto floatTy = mlir::dyn_cast<mlir::FloatType>(ty);
1020 return llvm::APFloat(floatTy.getFloatSemantics(), value);
1023 if constexpr (std::is_same_v<R, int64_t>)
1024 return value;
1026 llvm_unreachable("OpenACC reduction unsupported type");
1029 /// Return a constant with the initial value for the reduction operator and
1030 /// type combination.
1031 static mlir::Value getReductionInitValue(fir::FirOpBuilder &builder,
1032 mlir::Location loc, mlir::Type ty,
1033 mlir::acc::ReductionOperator op) {
1034 if (op == mlir::acc::ReductionOperator::AccLand ||
1035 op == mlir::acc::ReductionOperator::AccLor ||
1036 op == mlir::acc::ReductionOperator::AccEqv ||
1037 op == mlir::acc::ReductionOperator::AccNeqv) {
1038 assert(mlir::isa<fir::LogicalType>(ty) && "expect fir.logical type");
1039 bool value = true; // .true. for .and. and .eqv.
1040 if (op == mlir::acc::ReductionOperator::AccLor ||
1041 op == mlir::acc::ReductionOperator::AccNeqv)
1042 value = false; // .false. for .or. and .neqv.
1043 return builder.createBool(loc, value);
1045 if (ty.isIntOrIndex())
1046 return builder.create<mlir::arith::ConstantOp>(
1047 loc, ty,
1048 builder.getIntegerAttr(ty, getReductionInitValue<llvm::APInt>(op, ty)));
1049 if (op == mlir::acc::ReductionOperator::AccMin ||
1050 op == mlir::acc::ReductionOperator::AccMax) {
1051 if (mlir::isa<mlir::ComplexType>(ty))
1052 llvm::report_fatal_error(
1053 "min/max reduction not supported for complex type");
1054 if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty))
1055 return builder.create<mlir::arith::ConstantOp>(
1056 loc, ty,
1057 builder.getFloatAttr(ty,
1058 getReductionInitValue<llvm::APFloat>(op, ty)));
1059 } else if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty)) {
1060 return builder.create<mlir::arith::ConstantOp>(
1061 loc, ty,
1062 builder.getFloatAttr(ty, getReductionInitValue<int64_t>(op, ty)));
1063 } else if (auto cmplxTy = mlir::dyn_cast_or_null<mlir::ComplexType>(ty)) {
1064 mlir::Type floatTy = cmplxTy.getElementType();
1065 mlir::Value realInit = builder.createRealConstant(
1066 loc, floatTy, getReductionInitValue<int64_t>(op, cmplxTy));
1067 mlir::Value imagInit = builder.createRealConstant(loc, floatTy, 0.0);
1068 return fir::factory::Complex{builder, loc}.createComplex(cmplxTy, realInit,
1069 imagInit);
1072 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
1073 return getReductionInitValue(builder, loc, seqTy.getEleTy(), op);
1075 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
1076 return getReductionInitValue(builder, loc, boxTy.getEleTy(), op);
1078 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
1079 return getReductionInitValue(builder, loc, heapTy.getEleTy(), op);
1081 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
1082 return getReductionInitValue(builder, loc, ptrTy.getEleTy(), op);
1084 llvm::report_fatal_error("Unsupported OpenACC reduction type");
1087 static mlir::Value genReductionInitRegion(fir::FirOpBuilder &builder,
1088 mlir::Location loc, mlir::Type ty,
1089 mlir::acc::ReductionOperator op) {
1090 ty = fir::unwrapRefType(ty);
1091 mlir::Value initValue = getReductionInitValue(builder, loc, ty, op);
1092 if (fir::isa_trivial(ty)) {
1093 mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
1094 auto declareOp = builder.create<hlfir::DeclareOp>(
1095 loc, alloca, accReductionInitName, /*shape=*/nullptr,
1096 llvm::ArrayRef<mlir::Value>{}, /*dummy_scope=*/nullptr,
1097 fir::FortranVariableFlagsAttr{});
1098 builder.create<fir::StoreOp>(loc, builder.createConvert(loc, ty, initValue),
1099 declareOp.getBase());
1100 return declareOp.getBase();
1101 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(ty)) {
1102 if (fir::isa_trivial(seqTy.getEleTy())) {
1103 mlir::Value shape;
1104 auto extents = builder.getBlock()->getArguments().drop_front(1);
1105 if (seqTy.hasDynamicExtents())
1106 shape = builder.create<fir::ShapeOp>(loc, extents);
1107 else
1108 shape = genShapeOp(builder, seqTy, loc);
1109 mlir::Value alloca = builder.create<fir::AllocaOp>(
1110 loc, seqTy, /*typeparams=*/mlir::ValueRange{}, extents);
1111 auto declareOp = builder.create<hlfir::DeclareOp>(
1112 loc, alloca, accReductionInitName, shape,
1113 llvm::ArrayRef<mlir::Value>{}, /*dummy_scope=*/nullptr,
1114 fir::FortranVariableFlagsAttr{});
1115 mlir::Type idxTy = builder.getIndexType();
1116 mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
1117 llvm::SmallVector<fir::DoLoopOp> loops;
1118 llvm::SmallVector<mlir::Value> ivs;
1120 if (seqTy.hasDynamicExtents()) {
1121 builder.create<hlfir::AssignOp>(loc, initValue, declareOp.getBase());
1122 return declareOp.getBase();
1124 for (auto ext : seqTy.getShape()) {
1125 auto lb = builder.createIntegerConstant(loc, idxTy, 0);
1126 auto ub = builder.createIntegerConstant(loc, idxTy, ext - 1);
1127 auto step = builder.createIntegerConstant(loc, idxTy, 1);
1128 auto loop = builder.create<fir::DoLoopOp>(loc, lb, ub, step,
1129 /*unordered=*/false);
1130 builder.setInsertionPointToStart(loop.getBody());
1131 loops.push_back(loop);
1132 ivs.push_back(loop.getInductionVar());
1134 auto coord = builder.create<fir::CoordinateOp>(loc, refTy,
1135 declareOp.getBase(), ivs);
1136 builder.create<fir::StoreOp>(loc, initValue, coord);
1137 builder.setInsertionPointAfter(loops[0]);
1138 return declareOp.getBase();
1140 } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
1141 mlir::Type innerTy = fir::unwrapRefType(boxTy.getEleTy());
1142 if (!fir::isa_trivial(innerTy) && !mlir::isa<fir::SequenceType>(innerTy))
1143 TODO(loc, "Unsupported boxed type for reduction");
1144 // Create the private copy from the initial fir.box.
1145 hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)};
1146 auto [temp, cleanup] = hlfir::createTempFromMold(loc, builder, source);
1147 mlir::Value newBox = temp;
1148 if (!mlir::isa<fir::BaseBoxType>(temp.getType())) {
1149 newBox = builder.create<fir::EmboxOp>(loc, boxTy, temp);
1151 builder.create<hlfir::AssignOp>(loc, initValue, newBox);
1152 return newBox;
1154 llvm::report_fatal_error("Unsupported OpenACC reduction type");
1157 template <typename Op>
1158 static mlir::Value genLogicalCombiner(fir::FirOpBuilder &builder,
1159 mlir::Location loc, mlir::Value value1,
1160 mlir::Value value2) {
1161 mlir::Type i1 = builder.getI1Type();
1162 mlir::Value v1 = builder.create<fir::ConvertOp>(loc, i1, value1);
1163 mlir::Value v2 = builder.create<fir::ConvertOp>(loc, i1, value2);
1164 mlir::Value combined = builder.create<Op>(loc, v1, v2);
1165 return builder.create<fir::ConvertOp>(loc, value1.getType(), combined);
1168 static mlir::Value genComparisonCombiner(fir::FirOpBuilder &builder,
1169 mlir::Location loc,
1170 mlir::arith::CmpIPredicate pred,
1171 mlir::Value value1,
1172 mlir::Value value2) {
1173 mlir::Type i1 = builder.getI1Type();
1174 mlir::Value v1 = builder.create<fir::ConvertOp>(loc, i1, value1);
1175 mlir::Value v2 = builder.create<fir::ConvertOp>(loc, i1, value2);
1176 mlir::Value add = builder.create<mlir::arith::CmpIOp>(loc, pred, v1, v2);
1177 return builder.create<fir::ConvertOp>(loc, value1.getType(), add);
1180 static mlir::Value genScalarCombiner(fir::FirOpBuilder &builder,
1181 mlir::Location loc,
1182 mlir::acc::ReductionOperator op,
1183 mlir::Type ty, mlir::Value value1,
1184 mlir::Value value2) {
1185 value1 = builder.loadIfRef(loc, value1);
1186 value2 = builder.loadIfRef(loc, value2);
1187 if (op == mlir::acc::ReductionOperator::AccAdd) {
1188 if (ty.isIntOrIndex())
1189 return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
1190 if (mlir::isa<mlir::FloatType>(ty))
1191 return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
1192 if (auto cmplxTy = mlir::dyn_cast_or_null<mlir::ComplexType>(ty))
1193 return builder.create<fir::AddcOp>(loc, value1, value2);
1194 TODO(loc, "reduction add type");
1197 if (op == mlir::acc::ReductionOperator::AccMul) {
1198 if (ty.isIntOrIndex())
1199 return builder.create<mlir::arith::MulIOp>(loc, value1, value2);
1200 if (mlir::isa<mlir::FloatType>(ty))
1201 return builder.create<mlir::arith::MulFOp>(loc, value1, value2);
1202 if (mlir::isa<mlir::ComplexType>(ty))
1203 return builder.create<fir::MulcOp>(loc, value1, value2);
1204 TODO(loc, "reduction mul type");
1207 if (op == mlir::acc::ReductionOperator::AccMin)
1208 return fir::genMin(builder, loc, {value1, value2});
1210 if (op == mlir::acc::ReductionOperator::AccMax)
1211 return fir::genMax(builder, loc, {value1, value2});
1213 if (op == mlir::acc::ReductionOperator::AccIand)
1214 return builder.create<mlir::arith::AndIOp>(loc, value1, value2);
1216 if (op == mlir::acc::ReductionOperator::AccIor)
1217 return builder.create<mlir::arith::OrIOp>(loc, value1, value2);
1219 if (op == mlir::acc::ReductionOperator::AccXor)
1220 return builder.create<mlir::arith::XOrIOp>(loc, value1, value2);
1222 if (op == mlir::acc::ReductionOperator::AccLand)
1223 return genLogicalCombiner<mlir::arith::AndIOp>(builder, loc, value1,
1224 value2);
1226 if (op == mlir::acc::ReductionOperator::AccLor)
1227 return genLogicalCombiner<mlir::arith::OrIOp>(builder, loc, value1, value2);
1229 if (op == mlir::acc::ReductionOperator::AccEqv)
1230 return genComparisonCombiner(builder, loc, mlir::arith::CmpIPredicate::eq,
1231 value1, value2);
1233 if (op == mlir::acc::ReductionOperator::AccNeqv)
1234 return genComparisonCombiner(builder, loc, mlir::arith::CmpIPredicate::ne,
1235 value1, value2);
1237 TODO(loc, "reduction operator");
1240 static hlfir::DesignateOp::Subscripts
1241 getTripletsFromArgs(mlir::acc::ReductionRecipeOp recipe) {
1242 hlfir::DesignateOp::Subscripts triplets;
1243 for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size();
1244 i += 3)
1245 triplets.emplace_back(hlfir::DesignateOp::Triplet{
1246 recipe.getCombinerRegion().getArgument(i),
1247 recipe.getCombinerRegion().getArgument(i + 1),
1248 recipe.getCombinerRegion().getArgument(i + 2)});
1249 return triplets;
1252 static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
1253 mlir::acc::ReductionOperator op, mlir::Type ty,
1254 mlir::Value value1, mlir::Value value2,
1255 mlir::acc::ReductionRecipeOp &recipe,
1256 llvm::SmallVector<mlir::Value> &bounds,
1257 bool allConstantBound) {
1258 ty = fir::unwrapRefType(ty);
1260 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
1261 mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
1262 llvm::SmallVector<fir::DoLoopOp> loops;
1263 llvm::SmallVector<mlir::Value> ivs;
1264 if (seqTy.hasDynamicExtents()) {
1265 auto shape =
1266 genShapeFromBoundsOrArgs(loc, builder, seqTy, bounds,
1267 recipe.getCombinerRegion().getArguments());
1268 auto v1DeclareOp = builder.create<hlfir::DeclareOp>(
1269 loc, value1, llvm::StringRef{}, shape, llvm::ArrayRef<mlir::Value>{},
1270 /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{});
1271 auto v2DeclareOp = builder.create<hlfir::DeclareOp>(
1272 loc, value2, llvm::StringRef{}, shape, llvm::ArrayRef<mlir::Value>{},
1273 /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{});
1274 hlfir::DesignateOp::Subscripts triplets = getTripletsFromArgs(recipe);
1276 llvm::SmallVector<mlir::Value> lenParamsLeft;
1277 auto leftEntity = hlfir::Entity{v1DeclareOp.getBase()};
1278 hlfir::genLengthParameters(loc, builder, leftEntity, lenParamsLeft);
1279 auto leftDesignate = builder.create<hlfir::DesignateOp>(
1280 loc, v1DeclareOp.getBase().getType(), v1DeclareOp.getBase(),
1281 /*component=*/"",
1282 /*componentShape=*/mlir::Value{}, triplets,
1283 /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
1284 shape, lenParamsLeft);
1285 auto left = hlfir::Entity{leftDesignate.getResult()};
1287 llvm::SmallVector<mlir::Value> lenParamsRight;
1288 auto rightEntity = hlfir::Entity{v2DeclareOp.getBase()};
1289 hlfir::genLengthParameters(loc, builder, rightEntity, lenParamsLeft);
1290 auto rightDesignate = builder.create<hlfir::DesignateOp>(
1291 loc, v2DeclareOp.getBase().getType(), v2DeclareOp.getBase(),
1292 /*component=*/"",
1293 /*componentShape=*/mlir::Value{}, triplets,
1294 /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
1295 shape, lenParamsRight);
1296 auto right = hlfir::Entity{rightDesignate.getResult()};
1298 llvm::SmallVector<mlir::Value, 1> typeParams;
1299 auto genKernel = [&builder, &loc, op, seqTy, &left, &right](
1300 mlir::Location l, fir::FirOpBuilder &b,
1301 mlir::ValueRange oneBasedIndices) -> hlfir::Entity {
1302 auto leftElement = hlfir::getElementAt(l, b, left, oneBasedIndices);
1303 auto rightElement = hlfir::getElementAt(l, b, right, oneBasedIndices);
1304 auto leftVal = hlfir::loadTrivialScalar(l, b, leftElement);
1305 auto rightVal = hlfir::loadTrivialScalar(l, b, rightElement);
1306 return hlfir::Entity{genScalarCombiner(
1307 builder, loc, op, seqTy.getEleTy(), leftVal, rightVal)};
1309 mlir::Value elemental = hlfir::genElementalOp(
1310 loc, builder, seqTy.getEleTy(), shape, typeParams, genKernel,
1311 /*isUnordered=*/true);
1312 builder.create<hlfir::AssignOp>(loc, elemental, v1DeclareOp.getBase());
1313 return;
1315 if (bounds.empty()) {
1316 llvm::SmallVector<mlir::Value> extents;
1317 mlir::Type idxTy = builder.getIndexType();
1318 for (auto extent : seqTy.getShape()) {
1319 mlir::Value lb = builder.create<mlir::arith::ConstantOp>(
1320 loc, idxTy, builder.getIntegerAttr(idxTy, 0));
1321 mlir::Value ub = builder.create<mlir::arith::ConstantOp>(
1322 loc, idxTy, builder.getIntegerAttr(idxTy, extent - 1));
1323 mlir::Value step = builder.create<mlir::arith::ConstantOp>(
1324 loc, idxTy, builder.getIntegerAttr(idxTy, 1));
1325 auto loop = builder.create<fir::DoLoopOp>(loc, lb, ub, step,
1326 /*unordered=*/false);
1327 builder.setInsertionPointToStart(loop.getBody());
1328 loops.push_back(loop);
1329 ivs.push_back(loop.getInductionVar());
1331 } else if (allConstantBound) {
1332 // Use the constant bound directly in the combiner region so they do not
1333 // need to be passed as block argument.
1334 assert(!bounds.empty() &&
1335 "seq type with constant bounds cannot have empty bounds");
1336 for (auto bound : llvm::reverse(bounds)) {
1337 auto dataBound =
1338 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
1339 llvm::SmallVector<mlir::Value> values =
1340 genConstantBounds(builder, loc, dataBound);
1341 auto loop =
1342 builder.create<fir::DoLoopOp>(loc, values[0], values[1], values[2],
1343 /*unordered=*/false);
1344 builder.setInsertionPointToStart(loop.getBody());
1345 loops.push_back(loop);
1346 ivs.push_back(loop.getInductionVar());
1348 } else {
1349 // Lowerbound, upperbound and step are passed as block arguments.
1350 [[maybe_unused]] unsigned nbRangeArgs =
1351 recipe.getCombinerRegion().getArguments().size() - 2;
1352 assert((nbRangeArgs / 3 == seqTy.getDimension()) &&
1353 "Expect 3 block arguments per dimension");
1354 for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size();
1355 i += 3) {
1356 mlir::Value lb = recipe.getCombinerRegion().getArgument(i);
1357 mlir::Value ub = recipe.getCombinerRegion().getArgument(i + 1);
1358 mlir::Value step = recipe.getCombinerRegion().getArgument(i + 2);
1359 auto loop = builder.create<fir::DoLoopOp>(loc, lb, ub, step,
1360 /*unordered=*/false);
1361 builder.setInsertionPointToStart(loop.getBody());
1362 loops.push_back(loop);
1363 ivs.push_back(loop.getInductionVar());
1366 auto addr1 = builder.create<fir::CoordinateOp>(loc, refTy, value1, ivs);
1367 auto addr2 = builder.create<fir::CoordinateOp>(loc, refTy, value2, ivs);
1368 auto load1 = builder.create<fir::LoadOp>(loc, addr1);
1369 auto load2 = builder.create<fir::LoadOp>(loc, addr2);
1370 mlir::Value res =
1371 genScalarCombiner(builder, loc, op, seqTy.getEleTy(), load1, load2);
1372 builder.create<fir::StoreOp>(loc, res, addr1);
1373 builder.setInsertionPointAfter(loops[0]);
1374 } else if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
1375 mlir::Type innerTy = fir::unwrapRefType(boxTy.getEleTy());
1376 if (fir::isa_trivial(innerTy)) {
1377 mlir::Value boxAddr1 = value1, boxAddr2 = value2;
1378 if (fir::isBoxAddress(boxAddr1.getType()))
1379 boxAddr1 = builder.create<fir::LoadOp>(loc, boxAddr1);
1380 if (fir::isBoxAddress(boxAddr2.getType()))
1381 boxAddr2 = builder.create<fir::LoadOp>(loc, boxAddr2);
1382 boxAddr1 = builder.create<fir::BoxAddrOp>(loc, boxAddr1);
1383 boxAddr2 = builder.create<fir::BoxAddrOp>(loc, boxAddr2);
1384 auto leftEntity = hlfir::Entity{boxAddr1};
1385 auto rightEntity = hlfir::Entity{boxAddr2};
1387 auto leftVal = hlfir::loadTrivialScalar(loc, builder, leftEntity);
1388 auto rightVal = hlfir::loadTrivialScalar(loc, builder, rightEntity);
1389 mlir::Value res =
1390 genScalarCombiner(builder, loc, op, innerTy, leftVal, rightVal);
1391 builder.create<hlfir::AssignOp>(loc, res, boxAddr1);
1392 } else {
1393 mlir::Type innerTy = fir::extractSequenceType(boxTy);
1394 fir::SequenceType seqTy =
1395 mlir::dyn_cast_or_null<fir::SequenceType>(innerTy);
1396 if (!seqTy)
1397 TODO(loc, "Unsupported boxed type in OpenACC reduction combiner");
1399 auto shape =
1400 genShapeFromBoundsOrArgs(loc, builder, seqTy, bounds,
1401 recipe.getCombinerRegion().getArguments());
1402 hlfir::DesignateOp::Subscripts triplets =
1403 getSubscriptsFromArgs(recipe.getCombinerRegion().getArguments());
1404 auto leftEntity = hlfir::Entity{value1};
1405 if (fir::isBoxAddress(value1.getType()))
1406 leftEntity =
1407 hlfir::Entity{builder.create<fir::LoadOp>(loc, value1).getResult()};
1408 auto left =
1409 genDesignateWithTriplets(builder, loc, leftEntity, triplets, shape);
1410 auto rightEntity = hlfir::Entity{value2};
1411 if (fir::isBoxAddress(value2.getType()))
1412 rightEntity =
1413 hlfir::Entity{builder.create<fir::LoadOp>(loc, value2).getResult()};
1414 auto right =
1415 genDesignateWithTriplets(builder, loc, rightEntity, triplets, shape);
1417 llvm::SmallVector<mlir::Value, 1> typeParams;
1418 auto genKernel = [&builder, &loc, op, seqTy, &left, &right](
1419 mlir::Location l, fir::FirOpBuilder &b,
1420 mlir::ValueRange oneBasedIndices) -> hlfir::Entity {
1421 auto leftElement = hlfir::getElementAt(l, b, left, oneBasedIndices);
1422 auto rightElement = hlfir::getElementAt(l, b, right, oneBasedIndices);
1423 auto leftVal = hlfir::loadTrivialScalar(l, b, leftElement);
1424 auto rightVal = hlfir::loadTrivialScalar(l, b, rightElement);
1425 return hlfir::Entity{genScalarCombiner(
1426 builder, loc, op, seqTy.getEleTy(), leftVal, rightVal)};
1428 mlir::Value elemental = hlfir::genElementalOp(
1429 loc, builder, seqTy.getEleTy(), shape, typeParams, genKernel,
1430 /*isUnordered=*/true);
1431 builder.create<hlfir::AssignOp>(loc, elemental, value1);
1433 } else {
1434 mlir::Value res = genScalarCombiner(builder, loc, op, ty, value1, value2);
1435 builder.create<fir::StoreOp>(loc, res, value1);
1439 mlir::acc::ReductionRecipeOp Fortran::lower::createOrGetReductionRecipe(
1440 fir::FirOpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc,
1441 mlir::Type ty, mlir::acc::ReductionOperator op,
1442 llvm::SmallVector<mlir::Value> &bounds) {
1443 mlir::ModuleOp mod =
1444 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
1445 if (auto recipe = mod.lookupSymbol<mlir::acc::ReductionRecipeOp>(recipeName))
1446 return recipe;
1448 auto crtPos = builder.saveInsertionPoint();
1449 mlir::OpBuilder modBuilder(mod.getBodyRegion());
1450 auto recipe =
1451 modBuilder.create<mlir::acc::ReductionRecipeOp>(loc, recipeName, ty, op);
1452 llvm::SmallVector<mlir::Type> initArgsTy{ty};
1453 llvm::SmallVector<mlir::Location> initArgsLoc{loc};
1454 mlir::Type refTy = fir::unwrapRefType(ty);
1455 if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(refTy)) {
1456 if (seqTy.hasDynamicExtents()) {
1457 mlir::Type idxTy = builder.getIndexType();
1458 for (unsigned i = 0; i < seqTy.getDimension(); ++i) {
1459 initArgsTy.push_back(idxTy);
1460 initArgsLoc.push_back(loc);
1464 builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(),
1465 initArgsTy, initArgsLoc);
1466 builder.setInsertionPointToEnd(&recipe.getInitRegion().back());
1467 mlir::Value initValue = genReductionInitRegion(builder, loc, ty, op);
1468 builder.create<mlir::acc::YieldOp>(loc, initValue);
1470 // The two first block arguments are the two values to be combined.
1471 // The next arguments are the iteration ranges (lb, ub, step) to be used
1472 // for the combiner if needed.
1473 llvm::SmallVector<mlir::Type> argsTy{ty, ty};
1474 llvm::SmallVector<mlir::Location> argsLoc{loc, loc};
1475 bool allConstantBound = areAllBoundConstant(bounds);
1476 if (!allConstantBound) {
1477 for (mlir::Value bound : llvm::reverse(bounds)) {
1478 auto dataBound =
1479 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
1480 argsTy.push_back(dataBound.getLowerbound().getType());
1481 argsLoc.push_back(dataBound.getLowerbound().getLoc());
1482 argsTy.push_back(dataBound.getUpperbound().getType());
1483 argsLoc.push_back(dataBound.getUpperbound().getLoc());
1484 argsTy.push_back(dataBound.getStartIdx().getType());
1485 argsLoc.push_back(dataBound.getStartIdx().getLoc());
1488 builder.createBlock(&recipe.getCombinerRegion(),
1489 recipe.getCombinerRegion().end(), argsTy, argsLoc);
1490 builder.setInsertionPointToEnd(&recipe.getCombinerRegion().back());
1491 mlir::Value v1 = recipe.getCombinerRegion().front().getArgument(0);
1492 mlir::Value v2 = recipe.getCombinerRegion().front().getArgument(1);
1493 genCombiner(builder, loc, op, ty, v1, v2, recipe, bounds, allConstantBound);
1494 builder.create<mlir::acc::YieldOp>(loc, v1);
1495 builder.restoreInsertionPoint(crtPos);
1496 return recipe;
1499 static bool isSupportedReductionType(mlir::Type ty) {
1500 ty = fir::unwrapRefType(ty);
1501 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
1502 return isSupportedReductionType(boxTy.getEleTy());
1503 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
1504 return isSupportedReductionType(seqTy.getEleTy());
1505 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
1506 return isSupportedReductionType(heapTy.getEleTy());
1507 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
1508 return isSupportedReductionType(ptrTy.getEleTy());
1509 return fir::isa_trivial(ty);
1512 static void
1513 genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
1514 Fortran::lower::AbstractConverter &converter,
1515 Fortran::semantics::SemanticsContext &semanticsContext,
1516 Fortran::lower::StatementContext &stmtCtx,
1517 llvm::SmallVectorImpl<mlir::Value> &reductionOperands,
1518 llvm::SmallVector<mlir::Attribute> &reductionRecipes,
1519 llvm::ArrayRef<mlir::Value> async,
1520 llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
1521 llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes) {
1522 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
1523 const auto &objects = std::get<Fortran::parser::AccObjectList>(objectList.t);
1524 const auto &op = std::get<Fortran::parser::ReductionOperator>(objectList.t);
1525 mlir::acc::ReductionOperator mlirOp = getReductionOperator(op);
1526 Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
1527 for (const auto &accObject : objects.v) {
1528 llvm::SmallVector<mlir::Value> bounds;
1529 std::stringstream asFortran;
1530 mlir::Location operandLocation = genOperandLocation(converter, accObject);
1531 Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
1532 Fortran::semantics::MaybeExpr designator = Fortran::common::visit(
1533 [&](auto &&s) { return ea.Analyze(s); }, accObject.u);
1534 fir::factory::AddrAndBoundsInfo info =
1535 Fortran::lower::gatherDataOperandAddrAndBounds<
1536 mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
1537 converter, builder, semanticsContext, stmtCtx, symbol, designator,
1538 operandLocation, asFortran, bounds,
1539 /*treatIndexAsSection=*/true, /*unwrapFirBox=*/unwrapFirBox,
1540 /*genDefaultBounds=*/generateDefaultBounds);
1541 LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
1543 mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType());
1544 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))
1545 reductionTy = seqTy.getEleTy();
1547 if (!isSupportedReductionType(reductionTy))
1548 TODO(operandLocation, "reduction with unsupported type");
1550 auto op = createDataEntryOp<mlir::acc::ReductionOp>(
1551 builder, operandLocation, info.addr, asFortran, bounds,
1552 /*structured=*/true, /*implicit=*/false,
1553 mlir::acc::DataClause::acc_reduction, info.addr.getType(), async,
1554 asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true);
1555 mlir::Type ty = op.getAccVar().getType();
1556 if (!areAllBoundConstant(bounds) ||
1557 fir::isAssumedShape(info.addr.getType()) ||
1558 fir::isAllocatableOrPointerArray(info.addr.getType()))
1559 ty = info.addr.getType();
1560 std::string suffix =
1561 areAllBoundConstant(bounds) ? getBoundsString(bounds) : "";
1562 std::string recipeName = fir::getTypeAsString(
1563 ty, converter.getKindMap(),
1564 ("reduction_" + stringifyReductionOperator(mlirOp)).str() + suffix);
1566 mlir::acc::ReductionRecipeOp recipe =
1567 Fortran::lower::createOrGetReductionRecipe(
1568 builder, recipeName, operandLocation, ty, mlirOp, bounds);
1569 reductionRecipes.push_back(mlir::SymbolRefAttr::get(
1570 builder.getContext(), recipe.getSymName().str()));
1571 reductionOperands.push_back(op.getAccVar());
1575 template <typename Op, typename Terminator>
1576 static Op
1577 createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
1578 mlir::Location returnLoc, Fortran::lower::pft::Evaluation &eval,
1579 const llvm::SmallVectorImpl<mlir::Value> &operands,
1580 const llvm::SmallVectorImpl<int32_t> &operandSegments,
1581 bool outerCombined = false,
1582 llvm::SmallVector<mlir::Type> retTy = {},
1583 mlir::Value yieldValue = {}, mlir::TypeRange argsTy = {},
1584 llvm::SmallVector<mlir::Location> locs = {}) {
1585 Op op = builder.create<Op>(loc, retTy, operands);
1586 builder.createBlock(&op.getRegion(), op.getRegion().end(), argsTy, locs);
1587 mlir::Block &block = op.getRegion().back();
1588 builder.setInsertionPointToStart(&block);
1590 op->setAttr(Op::getOperandSegmentSizeAttr(),
1591 builder.getDenseI32ArrayAttr(operandSegments));
1593 // Place the insertion point to the start of the first block.
1594 builder.setInsertionPointToStart(&block);
1596 // If it is an unstructured region and is not the outer region of a combined
1597 // construct, create empty blocks for all evaluations.
1598 if (eval.lowerAsUnstructured() && !outerCombined)
1599 Fortran::lower::createEmptyRegionBlocks<mlir::acc::TerminatorOp,
1600 mlir::acc::YieldOp>(
1601 builder, eval.getNestedEvaluations());
1603 if (yieldValue) {
1604 if constexpr (std::is_same_v<Terminator, mlir::acc::YieldOp>) {
1605 Terminator yieldOp = builder.create<Terminator>(returnLoc, yieldValue);
1606 yieldValue.getDefiningOp()->moveBefore(yieldOp);
1607 } else {
1608 builder.create<Terminator>(returnLoc);
1610 } else {
1611 builder.create<Terminator>(returnLoc);
1613 builder.setInsertionPointToStart(&block);
1614 return op;
1617 static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
1618 const Fortran::parser::AccClause::Async *asyncClause,
1619 mlir::Value &async, bool &addAsyncAttr,
1620 Fortran::lower::StatementContext &stmtCtx) {
1621 const auto &asyncClauseValue = asyncClause->v;
1622 if (asyncClauseValue) { // async has a value.
1623 async = fir::getBase(converter.genExprValue(
1624 *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
1625 } else {
1626 addAsyncAttr = true;
1630 static void
1631 genAsyncClause(Fortran::lower::AbstractConverter &converter,
1632 const Fortran::parser::AccClause::Async *asyncClause,
1633 llvm::SmallVector<mlir::Value> &async,
1634 llvm::SmallVector<mlir::Attribute> &asyncDeviceTypes,
1635 llvm::SmallVector<mlir::Attribute> &asyncOnlyDeviceTypes,
1636 llvm::SmallVector<mlir::Attribute> &deviceTypeAttrs,
1637 Fortran::lower::StatementContext &stmtCtx) {
1638 const auto &asyncClauseValue = asyncClause->v;
1639 if (asyncClauseValue) { // async has a value.
1640 mlir::Value asyncValue = fir::getBase(converter.genExprValue(
1641 *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
1642 for (auto deviceTypeAttr : deviceTypeAttrs) {
1643 async.push_back(asyncValue);
1644 asyncDeviceTypes.push_back(deviceTypeAttr);
1646 } else {
1647 for (auto deviceTypeAttr : deviceTypeAttrs)
1648 asyncOnlyDeviceTypes.push_back(deviceTypeAttr);
1652 static mlir::acc::DeviceType
1653 getDeviceType(Fortran::common::OpenACCDeviceType device) {
1654 switch (device) {
1655 case Fortran::common::OpenACCDeviceType::Star:
1656 return mlir::acc::DeviceType::Star;
1657 case Fortran::common::OpenACCDeviceType::Default:
1658 return mlir::acc::DeviceType::Default;
1659 case Fortran::common::OpenACCDeviceType::Nvidia:
1660 return mlir::acc::DeviceType::Nvidia;
1661 case Fortran::common::OpenACCDeviceType::Radeon:
1662 return mlir::acc::DeviceType::Radeon;
1663 case Fortran::common::OpenACCDeviceType::Host:
1664 return mlir::acc::DeviceType::Host;
1665 case Fortran::common::OpenACCDeviceType::Multicore:
1666 return mlir::acc::DeviceType::Multicore;
1667 case Fortran::common::OpenACCDeviceType::None:
1668 return mlir::acc::DeviceType::None;
1670 return mlir::acc::DeviceType::None;
1673 static void gatherDeviceTypeAttrs(
1674 fir::FirOpBuilder &builder,
1675 const Fortran::parser::AccClause::DeviceType *deviceTypeClause,
1676 llvm::SmallVector<mlir::Attribute> &deviceTypes) {
1677 const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
1678 deviceTypeClause->v;
1679 for (const auto &deviceTypeExpr : deviceTypeExprList.v)
1680 deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
1681 builder.getContext(), getDeviceType(deviceTypeExpr.v)));
1684 static void genIfClause(Fortran::lower::AbstractConverter &converter,
1685 mlir::Location clauseLocation,
1686 const Fortran::parser::AccClause::If *ifClause,
1687 mlir::Value &ifCond,
1688 Fortran::lower::StatementContext &stmtCtx) {
1689 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1690 mlir::Value cond = fir::getBase(converter.genExprValue(
1691 *Fortran::semantics::GetExpr(ifClause->v), stmtCtx, &clauseLocation));
1692 ifCond = firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(),
1693 cond);
1696 static void genWaitClause(Fortran::lower::AbstractConverter &converter,
1697 const Fortran::parser::AccClause::Wait *waitClause,
1698 llvm::SmallVectorImpl<mlir::Value> &operands,
1699 mlir::Value &waitDevnum, bool &addWaitAttr,
1700 Fortran::lower::StatementContext &stmtCtx) {
1701 const auto &waitClauseValue = waitClause->v;
1702 if (waitClauseValue) { // wait has a value.
1703 const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1704 const auto &waitList =
1705 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1706 for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1707 mlir::Value v = fir::getBase(
1708 converter.genExprValue(*Fortran::semantics::GetExpr(value), stmtCtx));
1709 operands.push_back(v);
1712 const auto &waitDevnumValue =
1713 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1714 if (waitDevnumValue)
1715 waitDevnum = fir::getBase(converter.genExprValue(
1716 *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
1717 } else {
1718 addWaitAttr = true;
1722 static void genWaitClauseWithDeviceType(
1723 Fortran::lower::AbstractConverter &converter,
1724 const Fortran::parser::AccClause::Wait *waitClause,
1725 llvm::SmallVector<mlir::Value> &waitOperands,
1726 llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
1727 llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
1728 llvm::SmallVector<bool> &hasDevnums,
1729 llvm::SmallVector<int32_t> &waitOperandsSegments,
1730 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
1731 Fortran::lower::StatementContext &stmtCtx) {
1732 const auto &waitClauseValue = waitClause->v;
1733 if (waitClauseValue) { // wait has a value.
1734 llvm::SmallVector<mlir::Value> waitValues;
1736 const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1737 const auto &waitDevnumValue =
1738 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1739 bool hasDevnum = false;
1740 if (waitDevnumValue) {
1741 waitValues.push_back(fir::getBase(converter.genExprValue(
1742 *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)));
1743 hasDevnum = true;
1746 const auto &waitList =
1747 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1748 for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1749 waitValues.push_back(fir::getBase(converter.genExprValue(
1750 *Fortran::semantics::GetExpr(value), stmtCtx)));
1753 for (auto deviceTypeAttr : deviceTypeAttrs) {
1754 for (auto value : waitValues)
1755 waitOperands.push_back(value);
1756 waitOperandsDeviceTypes.push_back(deviceTypeAttr);
1757 waitOperandsSegments.push_back(waitValues.size());
1758 hasDevnums.push_back(hasDevnum);
1760 } else {
1761 for (auto deviceTypeAttr : deviceTypeAttrs)
1762 waitOnlyDeviceTypes.push_back(deviceTypeAttr);
1766 mlir::Type getTypeFromIvTypeSize(fir::FirOpBuilder &builder,
1767 const Fortran::semantics::Symbol &ivSym) {
1768 std::size_t ivTypeSize = ivSym.size();
1769 if (ivTypeSize == 0)
1770 llvm::report_fatal_error("unexpected induction variable size");
1771 // ivTypeSize is in bytes and IntegerType needs to be in bits.
1772 return builder.getIntegerType(ivTypeSize * 8);
1775 static void privatizeIv(Fortran::lower::AbstractConverter &converter,
1776 const Fortran::semantics::Symbol &sym,
1777 mlir::Location loc,
1778 llvm::SmallVector<mlir::Type> &ivTypes,
1779 llvm::SmallVector<mlir::Location> &ivLocs,
1780 llvm::SmallVector<mlir::Value> &privateOperands,
1781 llvm::SmallVector<mlir::Value> &ivPrivate,
1782 llvm::SmallVector<mlir::Attribute> &privatizations,
1783 bool isDoConcurrent = false) {
1784 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
1786 mlir::Type ivTy = getTypeFromIvTypeSize(builder, sym);
1787 ivTypes.push_back(ivTy);
1788 ivLocs.push_back(loc);
1789 mlir::Value ivValue = converter.getSymbolAddress(sym);
1790 if (!ivValue && isDoConcurrent) {
1791 // DO CONCURRENT induction variables are not mapped yet since they are local
1792 // to the DO CONCURRENT scope.
1793 mlir::OpBuilder::InsertPoint insPt = builder.saveInsertionPoint();
1794 builder.setInsertionPointToStart(builder.getAllocaBlock());
1795 ivValue = builder.createTemporaryAlloc(loc, ivTy, toStringRef(sym.name()));
1796 builder.restoreInsertionPoint(insPt);
1799 std::string recipeName =
1800 fir::getTypeAsString(ivValue.getType(), converter.getKindMap(),
1801 Fortran::lower::privatizationRecipePrefix);
1802 auto recipe = Fortran::lower::createOrGetPrivateRecipe(
1803 builder, recipeName, loc, ivValue.getType());
1805 std::stringstream asFortran;
1806 auto op = createDataEntryOp<mlir::acc::PrivateOp>(
1807 builder, loc, ivValue, asFortran, {}, true, /*implicit=*/true,
1808 mlir::acc::DataClause::acc_private, ivValue.getType(),
1809 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
1811 privateOperands.push_back(op.getAccVar());
1812 privatizations.push_back(mlir::SymbolRefAttr::get(builder.getContext(),
1813 recipe.getSymName().str()));
1815 // Map the new private iv to its symbol for the scope of the loop. bindSymbol
1816 // might create a hlfir.declare op, if so, we map its result in order to
1817 // use the sym value in the scope.
1818 converter.bindSymbol(sym, op.getAccVar());
1819 auto privateValue = converter.getSymbolAddress(sym);
1820 if (auto declareOp =
1821 mlir::dyn_cast<hlfir::DeclareOp>(privateValue.getDefiningOp()))
1822 privateValue = declareOp.getResults()[0];
1823 ivPrivate.push_back(privateValue);
1826 static mlir::acc::LoopOp createLoopOp(
1827 Fortran::lower::AbstractConverter &converter,
1828 mlir::Location currentLocation,
1829 Fortran::semantics::SemanticsContext &semanticsContext,
1830 Fortran::lower::StatementContext &stmtCtx,
1831 const Fortran::parser::DoConstruct &outerDoConstruct,
1832 Fortran::lower::pft::Evaluation &eval,
1833 const Fortran::parser::AccClauseList &accClauseList,
1834 std::optional<mlir::acc::CombinedConstructsType> combinedConstructs =
1835 std::nullopt,
1836 bool needEarlyReturnHandling = false) {
1837 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
1838 llvm::SmallVector<mlir::Value> tileOperands, privateOperands, ivPrivate,
1839 reductionOperands, cacheOperands, vectorOperands, workerNumOperands,
1840 gangOperands, lowerbounds, upperbounds, steps;
1841 llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;
1842 llvm::SmallVector<int32_t> tileOperandsSegments, gangOperandsSegments;
1843 llvm::SmallVector<int64_t> collapseValues;
1845 llvm::SmallVector<mlir::Attribute> gangArgTypes;
1846 llvm::SmallVector<mlir::Attribute> seqDeviceTypes, independentDeviceTypes,
1847 autoDeviceTypes, vectorOperandsDeviceTypes, workerNumOperandsDeviceTypes,
1848 vectorDeviceTypes, workerNumDeviceTypes, tileOperandsDeviceTypes,
1849 collapseDeviceTypes, gangDeviceTypes, gangOperandsDeviceTypes;
1851 // device_type attribute is set to `none` until a device_type clause is
1852 // encountered.
1853 llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
1854 crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
1855 builder.getContext(), mlir::acc::DeviceType::None));
1857 llvm::SmallVector<mlir::Type> ivTypes;
1858 llvm::SmallVector<mlir::Location> ivLocs;
1859 llvm::SmallVector<bool> inclusiveBounds;
1861 llvm::SmallVector<mlir::Location> locs;
1862 locs.push_back(currentLocation); // Location of the directive
1863 Fortran::lower::pft::Evaluation *crtEval = &eval.getFirstNestedEvaluation();
1864 bool isDoConcurrent = outerDoConstruct.IsDoConcurrent();
1865 if (isDoConcurrent) {
1866 locs.push_back(converter.genLocation(
1867 Fortran::parser::FindSourceLocation(outerDoConstruct)));
1868 const Fortran::parser::LoopControl *loopControl =
1869 &*outerDoConstruct.GetLoopControl();
1870 const auto &concurrent =
1871 std::get<Fortran::parser::LoopControl::Concurrent>(loopControl->u);
1872 if (!std::get<std::list<Fortran::parser::LocalitySpec>>(concurrent.t)
1873 .empty())
1874 TODO(currentLocation, "DO CONCURRENT with locality spec");
1876 const auto &concurrentHeader =
1877 std::get<Fortran::parser::ConcurrentHeader>(concurrent.t);
1878 const auto &controls =
1879 std::get<std::list<Fortran::parser::ConcurrentControl>>(
1880 concurrentHeader.t);
1881 for (const auto &control : controls) {
1882 lowerbounds.push_back(fir::getBase(converter.genExprValue(
1883 *Fortran::semantics::GetExpr(std::get<1>(control.t)), stmtCtx)));
1884 upperbounds.push_back(fir::getBase(converter.genExprValue(
1885 *Fortran::semantics::GetExpr(std::get<2>(control.t)), stmtCtx)));
1886 if (const auto &expr =
1887 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
1888 control.t))
1889 steps.push_back(fir::getBase(converter.genExprValue(
1890 *Fortran::semantics::GetExpr(*expr), stmtCtx)));
1891 else // If `step` is not present, assume it is `1`.
1892 steps.push_back(builder.createIntegerConstant(
1893 currentLocation, upperbounds[upperbounds.size() - 1].getType(), 1));
1895 const auto &name = std::get<Fortran::parser::Name>(control.t);
1896 privatizeIv(converter, *name.symbol, currentLocation, ivTypes, ivLocs,
1897 privateOperands, ivPrivate, privatizations, isDoConcurrent);
1899 inclusiveBounds.push_back(true);
1901 } else {
1902 int64_t collapseValue = Fortran::lower::getCollapseValue(accClauseList);
1903 for (unsigned i = 0; i < collapseValue; ++i) {
1904 const Fortran::parser::LoopControl *loopControl;
1905 if (i == 0) {
1906 loopControl = &*outerDoConstruct.GetLoopControl();
1907 locs.push_back(converter.genLocation(
1908 Fortran::parser::FindSourceLocation(outerDoConstruct)));
1909 } else {
1910 auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>();
1911 assert(doCons && "expect do construct");
1912 loopControl = &*doCons->GetLoopControl();
1913 locs.push_back(converter.genLocation(
1914 Fortran::parser::FindSourceLocation(*doCons)));
1917 const Fortran::parser::LoopControl::Bounds *bounds =
1918 std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
1919 assert(bounds && "Expected bounds on the loop construct");
1920 lowerbounds.push_back(fir::getBase(converter.genExprValue(
1921 *Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
1922 upperbounds.push_back(fir::getBase(converter.genExprValue(
1923 *Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
1924 if (bounds->step)
1925 steps.push_back(fir::getBase(converter.genExprValue(
1926 *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
1927 else // If `step` is not present, assume it is `1`.
1928 steps.push_back(builder.createIntegerConstant(
1929 currentLocation, upperbounds[upperbounds.size() - 1].getType(), 1));
1931 Fortran::semantics::Symbol &ivSym =
1932 bounds->name.thing.symbol->GetUltimate();
1933 privatizeIv(converter, ivSym, currentLocation, ivTypes, ivLocs,
1934 privateOperands, ivPrivate, privatizations);
1936 inclusiveBounds.push_back(true);
1938 if (i < collapseValue - 1)
1939 crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
1943 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
1944 mlir::Location clauseLocation = converter.genLocation(clause.source);
1945 if (const auto *gangClause =
1946 std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
1947 if (gangClause->v) {
1948 const Fortran::parser::AccGangArgList &x = *gangClause->v;
1949 mlir::SmallVector<mlir::Value> gangValues;
1950 mlir::SmallVector<mlir::Attribute> gangArgs;
1951 for (const Fortran::parser::AccGangArg &gangArg : x.v) {
1952 if (const auto *num =
1953 std::get_if<Fortran::parser::AccGangArg::Num>(&gangArg.u)) {
1954 gangValues.push_back(fir::getBase(converter.genExprValue(
1955 *Fortran::semantics::GetExpr(num->v), stmtCtx)));
1956 gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
1957 builder.getContext(), mlir::acc::GangArgType::Num));
1958 } else if (const auto *staticArg =
1959 std::get_if<Fortran::parser::AccGangArg::Static>(
1960 &gangArg.u)) {
1961 const Fortran::parser::AccSizeExpr &sizeExpr = staticArg->v;
1962 if (sizeExpr.v) {
1963 gangValues.push_back(fir::getBase(converter.genExprValue(
1964 *Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx)));
1965 } else {
1966 // * was passed as value and will be represented as a special
1967 // constant.
1968 gangValues.push_back(builder.createIntegerConstant(
1969 clauseLocation, builder.getIndexType(), starCst));
1971 gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
1972 builder.getContext(), mlir::acc::GangArgType::Static));
1973 } else if (const auto *dim =
1974 std::get_if<Fortran::parser::AccGangArg::Dim>(
1975 &gangArg.u)) {
1976 gangValues.push_back(fir::getBase(converter.genExprValue(
1977 *Fortran::semantics::GetExpr(dim->v), stmtCtx)));
1978 gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
1979 builder.getContext(), mlir::acc::GangArgType::Dim));
1982 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
1983 for (const auto &pair : llvm::zip(gangValues, gangArgs)) {
1984 gangOperands.push_back(std::get<0>(pair));
1985 gangArgTypes.push_back(std::get<1>(pair));
1987 gangOperandsSegments.push_back(gangValues.size());
1988 gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
1990 } else {
1991 for (auto crtDeviceTypeAttr : crtDeviceTypes)
1992 gangDeviceTypes.push_back(crtDeviceTypeAttr);
1994 } else if (const auto *workerClause =
1995 std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
1996 if (workerClause->v) {
1997 mlir::Value workerNumValue = fir::getBase(converter.genExprValue(
1998 *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx));
1999 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2000 workerNumOperands.push_back(workerNumValue);
2001 workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
2003 } else {
2004 for (auto crtDeviceTypeAttr : crtDeviceTypes)
2005 workerNumDeviceTypes.push_back(crtDeviceTypeAttr);
2007 } else if (const auto *vectorClause =
2008 std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
2009 if (vectorClause->v) {
2010 mlir::Value vectorValue = fir::getBase(converter.genExprValue(
2011 *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx));
2012 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2013 vectorOperands.push_back(vectorValue);
2014 vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
2016 } else {
2017 for (auto crtDeviceTypeAttr : crtDeviceTypes)
2018 vectorDeviceTypes.push_back(crtDeviceTypeAttr);
2020 } else if (const auto *tileClause =
2021 std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
2022 const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v;
2023 llvm::SmallVector<mlir::Value> tileValues;
2024 for (const auto &accTileExpr : accTileExprList.v) {
2025 const auto &expr =
2026 std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(
2027 accTileExpr.t);
2028 if (expr) {
2029 tileValues.push_back(fir::getBase(converter.genExprValue(
2030 *Fortran::semantics::GetExpr(*expr), stmtCtx)));
2031 } else {
2032 // * was passed as value and will be represented as a special
2033 // constant.
2034 mlir::Value tileStar = builder.createIntegerConstant(
2035 clauseLocation, builder.getIntegerType(32), starCst);
2036 tileValues.push_back(tileStar);
2039 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2040 for (auto value : tileValues)
2041 tileOperands.push_back(value);
2042 tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
2043 tileOperandsSegments.push_back(tileValues.size());
2045 } else if (const auto *privateClause =
2046 std::get_if<Fortran::parser::AccClause::Private>(
2047 &clause.u)) {
2048 genPrivatizations<mlir::acc::PrivateRecipeOp>(
2049 privateClause->v, converter, semanticsContext, stmtCtx,
2050 privateOperands, privatizations, /*async=*/{},
2051 /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
2052 } else if (const auto *reductionClause =
2053 std::get_if<Fortran::parser::AccClause::Reduction>(
2054 &clause.u)) {
2055 genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
2056 reductionOperands, reductionRecipes, /*async=*/{},
2057 /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
2058 } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
2059 for (auto crtDeviceTypeAttr : crtDeviceTypes)
2060 seqDeviceTypes.push_back(crtDeviceTypeAttr);
2061 } else if (std::get_if<Fortran::parser::AccClause::Independent>(
2062 &clause.u)) {
2063 for (auto crtDeviceTypeAttr : crtDeviceTypes)
2064 independentDeviceTypes.push_back(crtDeviceTypeAttr);
2065 } else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) {
2066 for (auto crtDeviceTypeAttr : crtDeviceTypes)
2067 autoDeviceTypes.push_back(crtDeviceTypeAttr);
2068 } else if (const auto *deviceTypeClause =
2069 std::get_if<Fortran::parser::AccClause::DeviceType>(
2070 &clause.u)) {
2071 crtDeviceTypes.clear();
2072 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
2073 } else if (const auto *collapseClause =
2074 std::get_if<Fortran::parser::AccClause::Collapse>(
2075 &clause.u)) {
2076 const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
2077 const auto &force = std::get<bool>(arg.t);
2078 if (force)
2079 TODO(clauseLocation, "OpenACC collapse force modifier");
2081 const auto &intExpr =
2082 std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
2083 const auto *expr = Fortran::semantics::GetExpr(intExpr);
2084 const std::optional<int64_t> collapseValue =
2085 Fortran::evaluate::ToInt64(*expr);
2086 assert(collapseValue && "expect integer value for the collapse clause");
2088 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2089 collapseValues.push_back(*collapseValue);
2090 collapseDeviceTypes.push_back(crtDeviceTypeAttr);
2095 // Prepare the operand segment size attribute and the operands value range.
2096 llvm::SmallVector<mlir::Value> operands;
2097 llvm::SmallVector<int32_t> operandSegments;
2098 addOperands(operands, operandSegments, lowerbounds);
2099 addOperands(operands, operandSegments, upperbounds);
2100 addOperands(operands, operandSegments, steps);
2101 addOperands(operands, operandSegments, gangOperands);
2102 addOperands(operands, operandSegments, workerNumOperands);
2103 addOperands(operands, operandSegments, vectorOperands);
2104 addOperands(operands, operandSegments, tileOperands);
2105 addOperands(operands, operandSegments, cacheOperands);
2106 addOperands(operands, operandSegments, privateOperands);
2107 addOperands(operands, operandSegments, reductionOperands);
2109 llvm::SmallVector<mlir::Type> retTy;
2110 mlir::Value yieldValue;
2111 if (needEarlyReturnHandling) {
2112 mlir::Type i1Ty = builder.getI1Type();
2113 yieldValue = builder.createIntegerConstant(currentLocation, i1Ty, 0);
2114 retTy.push_back(i1Ty);
2117 auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
2118 builder, builder.getFusedLoc(locs), currentLocation, eval, operands,
2119 operandSegments, /*outerCombined=*/false, retTy, yieldValue, ivTypes,
2120 ivLocs);
2122 for (auto [arg, value] : llvm::zip(
2123 loopOp.getLoopRegions().front()->front().getArguments(), ivPrivate))
2124 builder.create<fir::StoreOp>(currentLocation, arg, value);
2126 loopOp.setInclusiveUpperbound(inclusiveBounds);
2128 if (!gangDeviceTypes.empty())
2129 loopOp.setGangAttr(builder.getArrayAttr(gangDeviceTypes));
2130 if (!gangArgTypes.empty())
2131 loopOp.setGangOperandsArgTypeAttr(builder.getArrayAttr(gangArgTypes));
2132 if (!gangOperandsSegments.empty())
2133 loopOp.setGangOperandsSegmentsAttr(
2134 builder.getDenseI32ArrayAttr(gangOperandsSegments));
2135 if (!gangOperandsDeviceTypes.empty())
2136 loopOp.setGangOperandsDeviceTypeAttr(
2137 builder.getArrayAttr(gangOperandsDeviceTypes));
2139 if (!workerNumDeviceTypes.empty())
2140 loopOp.setWorkerAttr(builder.getArrayAttr(workerNumDeviceTypes));
2141 if (!workerNumOperandsDeviceTypes.empty())
2142 loopOp.setWorkerNumOperandsDeviceTypeAttr(
2143 builder.getArrayAttr(workerNumOperandsDeviceTypes));
2145 if (!vectorDeviceTypes.empty())
2146 loopOp.setVectorAttr(builder.getArrayAttr(vectorDeviceTypes));
2147 if (!vectorOperandsDeviceTypes.empty())
2148 loopOp.setVectorOperandsDeviceTypeAttr(
2149 builder.getArrayAttr(vectorOperandsDeviceTypes));
2151 if (!tileOperandsDeviceTypes.empty())
2152 loopOp.setTileOperandsDeviceTypeAttr(
2153 builder.getArrayAttr(tileOperandsDeviceTypes));
2154 if (!tileOperandsSegments.empty())
2155 loopOp.setTileOperandsSegmentsAttr(
2156 builder.getDenseI32ArrayAttr(tileOperandsSegments));
2158 if (!seqDeviceTypes.empty())
2159 loopOp.setSeqAttr(builder.getArrayAttr(seqDeviceTypes));
2160 if (!independentDeviceTypes.empty())
2161 loopOp.setIndependentAttr(builder.getArrayAttr(independentDeviceTypes));
2162 if (!autoDeviceTypes.empty())
2163 loopOp.setAuto_Attr(builder.getArrayAttr(autoDeviceTypes));
2165 if (!privatizations.empty())
2166 loopOp.setPrivatizationsAttr(
2167 mlir::ArrayAttr::get(builder.getContext(), privatizations));
2169 if (!reductionRecipes.empty())
2170 loopOp.setReductionRecipesAttr(
2171 mlir::ArrayAttr::get(builder.getContext(), reductionRecipes));
2173 if (!collapseValues.empty())
2174 loopOp.setCollapseAttr(builder.getI64ArrayAttr(collapseValues));
2175 if (!collapseDeviceTypes.empty())
2176 loopOp.setCollapseDeviceTypeAttr(builder.getArrayAttr(collapseDeviceTypes));
2178 if (combinedConstructs)
2179 loopOp.setCombinedAttr(mlir::acc::CombinedConstructsTypeAttr::get(
2180 builder.getContext(), *combinedConstructs));
2182 // TODO: retrieve directives from NonLabelDoStmt pft::Evaluation, and add them
2183 // as attribute to the acc.loop as an extra attribute. It is not quite clear
2184 // how useful these $dir are in acc contexts, but they could still provide
2185 // more information about the loop acc codegen. They can be obtained by
2186 // looking for the first lexicalSuccessor of eval that is a NonLabelDoStmt,
2187 // and using the related `dirs` member.
2189 return loopOp;
2192 static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) {
2193 bool hasReturnStmt = false;
2194 for (auto &e : eval.getNestedEvaluations()) {
2195 e.visit(Fortran::common::visitors{
2196 [&](const Fortran::parser::ReturnStmt &) { hasReturnStmt = true; },
2197 [&](const auto &s) {},
2199 if (e.hasNestedEvaluations())
2200 hasReturnStmt = hasEarlyReturn(e);
2202 return hasReturnStmt;
2205 static mlir::Value
2206 genACC(Fortran::lower::AbstractConverter &converter,
2207 Fortran::semantics::SemanticsContext &semanticsContext,
2208 Fortran::lower::pft::Evaluation &eval,
2209 const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
2211 const auto &beginLoopDirective =
2212 std::get<Fortran::parser::AccBeginLoopDirective>(loopConstruct.t);
2213 const auto &loopDirective =
2214 std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t);
2216 bool needEarlyExitHandling = false;
2217 if (eval.lowerAsUnstructured())
2218 needEarlyExitHandling = hasEarlyReturn(eval);
2220 mlir::Location currentLocation =
2221 converter.genLocation(beginLoopDirective.source);
2222 Fortran::lower::StatementContext stmtCtx;
2224 assert(loopDirective.v == llvm::acc::ACCD_loop &&
2225 "Unsupported OpenACC loop construct");
2226 (void)loopDirective;
2228 const auto &accClauseList =
2229 std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
2230 const auto &outerDoConstruct =
2231 std::get<std::optional<Fortran::parser::DoConstruct>>(loopConstruct.t);
2232 auto loopOp = createLoopOp(converter, currentLocation, semanticsContext,
2233 stmtCtx, *outerDoConstruct, eval, accClauseList,
2234 /*combinedConstructs=*/{}, needEarlyExitHandling);
2235 if (needEarlyExitHandling)
2236 return loopOp.getResult(0);
2238 return mlir::Value{};
2241 template <typename Op, typename Clause>
2242 static void genDataOperandOperationsWithModifier(
2243 const Clause *x, Fortran::lower::AbstractConverter &converter,
2244 Fortran::semantics::SemanticsContext &semanticsContext,
2245 Fortran::lower::StatementContext &stmtCtx,
2246 Fortran::parser::AccDataModifier::Modifier mod,
2247 llvm::SmallVectorImpl<mlir::Value> &dataClauseOperands,
2248 const mlir::acc::DataClause clause,
2249 const mlir::acc::DataClause clauseWithModifier,
2250 llvm::ArrayRef<mlir::Value> async,
2251 llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
2252 llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes,
2253 bool setDeclareAttr = false) {
2254 const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
2255 const auto &accObjectList =
2256 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
2257 const auto &modifier =
2258 std::get<std::optional<Fortran::parser::AccDataModifier>>(
2259 listWithModifier.t);
2260 mlir::acc::DataClause dataClause =
2261 (modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
2262 genDataOperandOperations<Op>(accObjectList, converter, semanticsContext,
2263 stmtCtx, dataClauseOperands, dataClause,
2264 /*structured=*/true, /*implicit=*/false, async,
2265 asyncDeviceTypes, asyncOnlyDeviceTypes,
2266 setDeclareAttr);
2269 template <typename Op>
2270 static Op createComputeOp(
2271 Fortran::lower::AbstractConverter &converter,
2272 mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
2273 Fortran::semantics::SemanticsContext &semanticsContext,
2274 Fortran::lower::StatementContext &stmtCtx,
2275 const Fortran::parser::AccClauseList &accClauseList,
2276 std::optional<mlir::acc::CombinedConstructsType> combinedConstructs =
2277 std::nullopt) {
2279 // Parallel operation operands
2280 mlir::Value ifCond;
2281 mlir::Value selfCond;
2282 llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
2283 copyEntryOperands, copyinEntryOperands, copyoutEntryOperands,
2284 createEntryOperands, dataClauseOperands, numGangs, numWorkers,
2285 vectorLength, async;
2286 llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
2287 vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
2288 waitOperandsDeviceTypes, waitOnlyDeviceTypes;
2289 llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
2290 llvm::SmallVector<bool> hasWaitDevnums;
2292 llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
2293 firstprivateOperands;
2294 llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
2295 reductionRecipes;
2297 // Self clause has optional values but can be present with
2298 // no value as well. When there is no value, the op has an attribute to
2299 // represent the clause.
2300 bool addSelfAttr = false;
2302 bool hasDefaultNone = false;
2303 bool hasDefaultPresent = false;
2305 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2307 // device_type attribute is set to `none` until a device_type clause is
2308 // encountered.
2309 llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
2310 auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
2311 builder.getContext(), mlir::acc::DeviceType::None);
2312 crtDeviceTypes.push_back(crtDeviceTypeAttr);
2314 // Lower clauses values mapped to operands and array attributes.
2315 // Keep track of each group of operands separately as clauses can appear
2316 // more than once.
2318 // Process the clauses that may have a specified device_type first.
2319 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2320 if (const auto *asyncClause =
2321 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
2322 genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
2323 asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
2324 } else if (const auto *waitClause =
2325 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
2326 genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
2327 waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2328 hasWaitDevnums, waitOperandsSegments,
2329 crtDeviceTypes, stmtCtx);
2330 } else if (const auto *numGangsClause =
2331 std::get_if<Fortran::parser::AccClause::NumGangs>(
2332 &clause.u)) {
2333 llvm::SmallVector<mlir::Value> numGangValues;
2334 for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
2335 numGangValues.push_back(fir::getBase(converter.genExprValue(
2336 *Fortran::semantics::GetExpr(expr), stmtCtx)));
2337 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2338 for (auto value : numGangValues)
2339 numGangs.push_back(value);
2340 numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
2341 numGangsSegments.push_back(numGangValues.size());
2343 } else if (const auto *numWorkersClause =
2344 std::get_if<Fortran::parser::AccClause::NumWorkers>(
2345 &clause.u)) {
2346 mlir::Value numWorkerValue = fir::getBase(converter.genExprValue(
2347 *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
2348 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2349 numWorkers.push_back(numWorkerValue);
2350 numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
2352 } else if (const auto *vectorLengthClause =
2353 std::get_if<Fortran::parser::AccClause::VectorLength>(
2354 &clause.u)) {
2355 mlir::Value vectorLengthValue = fir::getBase(converter.genExprValue(
2356 *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
2357 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
2358 vectorLength.push_back(vectorLengthValue);
2359 vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
2361 } else if (const auto *deviceTypeClause =
2362 std::get_if<Fortran::parser::AccClause::DeviceType>(
2363 &clause.u)) {
2364 crtDeviceTypes.clear();
2365 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
2369 // Process the clauses independent of device_type.
2370 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2371 mlir::Location clauseLocation = converter.genLocation(clause.source);
2372 if (const auto *ifClause =
2373 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2374 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2375 } else if (const auto *selfClause =
2376 std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
2377 const std::optional<Fortran::parser::AccSelfClause> &accSelfClause =
2378 selfClause->v;
2379 if (accSelfClause) {
2380 if (const auto *optCondition =
2381 std::get_if<std::optional<Fortran::parser::ScalarLogicalExpr>>(
2382 &(*accSelfClause).u)) {
2383 if (*optCondition) {
2384 mlir::Value cond = fir::getBase(converter.genExprValue(
2385 *Fortran::semantics::GetExpr(*optCondition), stmtCtx));
2386 selfCond = builder.createConvert(clauseLocation,
2387 builder.getI1Type(), cond);
2389 } else if (const auto *accClauseList =
2390 std::get_if<Fortran::parser::AccObjectList>(
2391 &(*accSelfClause).u)) {
2392 // TODO This would be nicer to be done in canonicalization step.
2393 if (accClauseList->v.size() == 1) {
2394 const auto &accObject = accClauseList->v.front();
2395 if (const auto *designator =
2396 std::get_if<Fortran::parser::Designator>(&accObject.u)) {
2397 if (const auto *name =
2398 Fortran::semantics::getDesignatorNameIfDataRef(
2399 *designator)) {
2400 auto cond = converter.getSymbolAddress(*name->symbol);
2401 selfCond = builder.createConvert(clauseLocation,
2402 builder.getI1Type(), cond);
2407 } else {
2408 addSelfAttr = true;
2410 } else if (const auto *copyClause =
2411 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
2412 auto crtDataStart = dataClauseOperands.size();
2413 genDataOperandOperations<mlir::acc::CopyinOp>(
2414 copyClause->v, converter, semanticsContext, stmtCtx,
2415 dataClauseOperands, mlir::acc::DataClause::acc_copy,
2416 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
2417 asyncOnlyDeviceTypes);
2418 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2419 dataClauseOperands.end());
2420 } else if (const auto *copyinClause =
2421 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
2422 auto crtDataStart = dataClauseOperands.size();
2423 genDataOperandOperationsWithModifier<mlir::acc::CopyinOp,
2424 Fortran::parser::AccClause::Copyin>(
2425 copyinClause, converter, semanticsContext, stmtCtx,
2426 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
2427 dataClauseOperands, mlir::acc::DataClause::acc_copyin,
2428 mlir::acc::DataClause::acc_copyin_readonly, async, asyncDeviceTypes,
2429 asyncOnlyDeviceTypes);
2430 copyinEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2431 dataClauseOperands.end());
2432 } else if (const auto *copyoutClause =
2433 std::get_if<Fortran::parser::AccClause::Copyout>(
2434 &clause.u)) {
2435 auto crtDataStart = dataClauseOperands.size();
2436 genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
2437 Fortran::parser::AccClause::Copyout>(
2438 copyoutClause, converter, semanticsContext, stmtCtx,
2439 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
2440 dataClauseOperands, mlir::acc::DataClause::acc_copyout,
2441 mlir::acc::DataClause::acc_copyout_zero, async, asyncDeviceTypes,
2442 asyncOnlyDeviceTypes);
2443 copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2444 dataClauseOperands.end());
2445 } else if (const auto *createClause =
2446 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
2447 auto crtDataStart = dataClauseOperands.size();
2448 genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
2449 Fortran::parser::AccClause::Create>(
2450 createClause, converter, semanticsContext, stmtCtx,
2451 Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands,
2452 mlir::acc::DataClause::acc_create,
2453 mlir::acc::DataClause::acc_create_zero, async, asyncDeviceTypes,
2454 asyncOnlyDeviceTypes);
2455 createEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2456 dataClauseOperands.end());
2457 } else if (const auto *noCreateClause =
2458 std::get_if<Fortran::parser::AccClause::NoCreate>(
2459 &clause.u)) {
2460 genDataOperandOperations<mlir::acc::NoCreateOp>(
2461 noCreateClause->v, converter, semanticsContext, stmtCtx,
2462 dataClauseOperands, mlir::acc::DataClause::acc_no_create,
2463 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
2464 asyncOnlyDeviceTypes);
2465 } else if (const auto *presentClause =
2466 std::get_if<Fortran::parser::AccClause::Present>(
2467 &clause.u)) {
2468 genDataOperandOperations<mlir::acc::PresentOp>(
2469 presentClause->v, converter, semanticsContext, stmtCtx,
2470 dataClauseOperands, mlir::acc::DataClause::acc_present,
2471 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
2472 asyncOnlyDeviceTypes);
2473 } else if (const auto *devicePtrClause =
2474 std::get_if<Fortran::parser::AccClause::Deviceptr>(
2475 &clause.u)) {
2476 genDataOperandOperations<mlir::acc::DevicePtrOp>(
2477 devicePtrClause->v, converter, semanticsContext, stmtCtx,
2478 dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
2479 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
2480 asyncOnlyDeviceTypes);
2481 } else if (const auto *attachClause =
2482 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
2483 auto crtDataStart = dataClauseOperands.size();
2484 genDataOperandOperations<mlir::acc::AttachOp>(
2485 attachClause->v, converter, semanticsContext, stmtCtx,
2486 dataClauseOperands, mlir::acc::DataClause::acc_attach,
2487 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
2488 asyncOnlyDeviceTypes);
2489 attachEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2490 dataClauseOperands.end());
2491 } else if (const auto *privateClause =
2492 std::get_if<Fortran::parser::AccClause::Private>(
2493 &clause.u)) {
2494 if (!combinedConstructs)
2495 genPrivatizations<mlir::acc::PrivateRecipeOp>(
2496 privateClause->v, converter, semanticsContext, stmtCtx,
2497 privateOperands, privatizations, async, asyncDeviceTypes,
2498 asyncOnlyDeviceTypes);
2499 } else if (const auto *firstprivateClause =
2500 std::get_if<Fortran::parser::AccClause::Firstprivate>(
2501 &clause.u)) {
2502 genPrivatizations<mlir::acc::FirstprivateRecipeOp>(
2503 firstprivateClause->v, converter, semanticsContext, stmtCtx,
2504 firstprivateOperands, firstPrivatizations, async, asyncDeviceTypes,
2505 asyncOnlyDeviceTypes);
2506 } else if (const auto *reductionClause =
2507 std::get_if<Fortran::parser::AccClause::Reduction>(
2508 &clause.u)) {
2509 // A reduction clause on a combined construct is treated as if it appeared
2510 // on the loop construct. So don't generate a reduction clause when it is
2511 // combined - delay it to the loop. However, a reduction clause on a
2512 // combined construct implies a copy clause so issue an implicit copy
2513 // instead.
2514 if (!combinedConstructs) {
2515 genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
2516 reductionOperands, reductionRecipes, async,
2517 asyncDeviceTypes, asyncOnlyDeviceTypes);
2518 } else {
2519 auto crtDataStart = dataClauseOperands.size();
2520 genDataOperandOperations<mlir::acc::CopyinOp>(
2521 std::get<Fortran::parser::AccObjectList>(reductionClause->v.t),
2522 converter, semanticsContext, stmtCtx, dataClauseOperands,
2523 mlir::acc::DataClause::acc_reduction,
2524 /*structured=*/true, /*implicit=*/true, async, asyncDeviceTypes,
2525 asyncOnlyDeviceTypes);
2526 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2527 dataClauseOperands.end());
2529 } else if (const auto *defaultClause =
2530 std::get_if<Fortran::parser::AccClause::Default>(
2531 &clause.u)) {
2532 if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
2533 hasDefaultNone = true;
2534 else if ((defaultClause->v).v ==
2535 llvm::acc::DefaultValue::ACC_Default_present)
2536 hasDefaultPresent = true;
2540 // Prepare the operand segment size attribute and the operands value range.
2541 llvm::SmallVector<mlir::Value, 8> operands;
2542 llvm::SmallVector<int32_t, 8> operandSegments;
2543 addOperands(operands, operandSegments, async);
2544 addOperands(operands, operandSegments, waitOperands);
2545 if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2546 addOperands(operands, operandSegments, numGangs);
2547 addOperands(operands, operandSegments, numWorkers);
2548 addOperands(operands, operandSegments, vectorLength);
2550 addOperand(operands, operandSegments, ifCond);
2551 addOperand(operands, operandSegments, selfCond);
2552 if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
2553 addOperands(operands, operandSegments, reductionOperands);
2554 addOperands(operands, operandSegments, privateOperands);
2555 addOperands(operands, operandSegments, firstprivateOperands);
2557 addOperands(operands, operandSegments, dataClauseOperands);
2559 Op computeOp;
2560 if constexpr (std::is_same_v<Op, mlir::acc::KernelsOp>)
2561 computeOp = createRegionOp<Op, mlir::acc::TerminatorOp>(
2562 builder, currentLocation, currentLocation, eval, operands,
2563 operandSegments, /*outerCombined=*/combinedConstructs.has_value());
2564 else
2565 computeOp = createRegionOp<Op, mlir::acc::YieldOp>(
2566 builder, currentLocation, currentLocation, eval, operands,
2567 operandSegments, /*outerCombined=*/combinedConstructs.has_value());
2569 if (addSelfAttr)
2570 computeOp.setSelfAttrAttr(builder.getUnitAttr());
2572 if (hasDefaultNone)
2573 computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::None);
2574 if (hasDefaultPresent)
2575 computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
2577 if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2578 if (!numWorkersDeviceTypes.empty())
2579 computeOp.setNumWorkersDeviceTypeAttr(
2580 mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes));
2581 if (!vectorLengthDeviceTypes.empty())
2582 computeOp.setVectorLengthDeviceTypeAttr(
2583 mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes));
2584 if (!numGangsDeviceTypes.empty())
2585 computeOp.setNumGangsDeviceTypeAttr(
2586 mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes));
2587 if (!numGangsSegments.empty())
2588 computeOp.setNumGangsSegmentsAttr(
2589 builder.getDenseI32ArrayAttr(numGangsSegments));
2591 if (!asyncDeviceTypes.empty())
2592 computeOp.setAsyncOperandsDeviceTypeAttr(
2593 builder.getArrayAttr(asyncDeviceTypes));
2594 if (!asyncOnlyDeviceTypes.empty())
2595 computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
2597 if (!waitOperandsDeviceTypes.empty())
2598 computeOp.setWaitOperandsDeviceTypeAttr(
2599 builder.getArrayAttr(waitOperandsDeviceTypes));
2600 if (!waitOperandsSegments.empty())
2601 computeOp.setWaitOperandsSegmentsAttr(
2602 builder.getDenseI32ArrayAttr(waitOperandsSegments));
2603 if (!hasWaitDevnums.empty())
2604 computeOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
2605 if (!waitOnlyDeviceTypes.empty())
2606 computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
2608 if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
2609 if (!privatizations.empty())
2610 computeOp.setPrivatizationsAttr(
2611 mlir::ArrayAttr::get(builder.getContext(), privatizations));
2612 if (!reductionRecipes.empty())
2613 computeOp.setReductionRecipesAttr(
2614 mlir::ArrayAttr::get(builder.getContext(), reductionRecipes));
2615 if (!firstPrivatizations.empty())
2616 computeOp.setFirstprivatizationsAttr(
2617 mlir::ArrayAttr::get(builder.getContext(), firstPrivatizations));
2620 if (combinedConstructs)
2621 computeOp.setCombinedAttr(builder.getUnitAttr());
2623 auto insPt = builder.saveInsertionPoint();
2624 builder.setInsertionPointAfter(computeOp);
2626 // Create the exit operations after the region.
2627 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>(
2628 builder, copyEntryOperands, /*structured=*/true);
2629 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::DeleteOp>(
2630 builder, copyinEntryOperands, /*structured=*/true);
2631 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>(
2632 builder, copyoutEntryOperands, /*structured=*/true);
2633 genDataExitOperations<mlir::acc::AttachOp, mlir::acc::DetachOp>(
2634 builder, attachEntryOperands, /*structured=*/true);
2635 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
2636 builder, createEntryOperands, /*structured=*/true);
2638 builder.restoreInsertionPoint(insPt);
2639 return computeOp;
2642 static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
2643 mlir::Location currentLocation,
2644 Fortran::lower::pft::Evaluation &eval,
2645 Fortran::semantics::SemanticsContext &semanticsContext,
2646 Fortran::lower::StatementContext &stmtCtx,
2647 const Fortran::parser::AccClauseList &accClauseList) {
2648 mlir::Value ifCond;
2649 llvm::SmallVector<mlir::Value> attachEntryOperands, createEntryOperands,
2650 copyEntryOperands, copyinEntryOperands, copyoutEntryOperands,
2651 dataClauseOperands, waitOperands, async;
2652 llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
2653 waitOperandsDeviceTypes, waitOnlyDeviceTypes;
2654 llvm::SmallVector<int32_t> waitOperandsSegments;
2655 llvm::SmallVector<bool> hasWaitDevnums;
2657 bool hasDefaultNone = false;
2658 bool hasDefaultPresent = false;
2660 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2662 // device_type attribute is set to `none` until a device_type clause is
2663 // encountered.
2664 llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
2665 crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
2666 builder.getContext(), mlir::acc::DeviceType::None));
2668 // Lower clauses values mapped to operands and array attributes.
2669 // Keep track of each group of operands separately as clauses can appear
2670 // more than once.
2672 // Process the clauses that may have a specified device_type first.
2673 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2674 if (const auto *asyncClause =
2675 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
2676 genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
2677 asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
2678 } else if (const auto *waitClause =
2679 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
2680 genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
2681 waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2682 hasWaitDevnums, waitOperandsSegments,
2683 crtDeviceTypes, stmtCtx);
2684 } else if (const auto *deviceTypeClause =
2685 std::get_if<Fortran::parser::AccClause::DeviceType>(
2686 &clause.u)) {
2687 crtDeviceTypes.clear();
2688 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
2692 // Process the clauses independent of device_type.
2693 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2694 mlir::Location clauseLocation = converter.genLocation(clause.source);
2695 if (const auto *ifClause =
2696 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2697 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2698 } else if (const auto *copyClause =
2699 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
2700 auto crtDataStart = dataClauseOperands.size();
2701 genDataOperandOperations<mlir::acc::CopyinOp>(
2702 copyClause->v, converter, semanticsContext, stmtCtx,
2703 dataClauseOperands, mlir::acc::DataClause::acc_copy,
2704 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
2705 asyncOnlyDeviceTypes);
2706 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2707 dataClauseOperands.end());
2708 } else if (const auto *copyinClause =
2709 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
2710 auto crtDataStart = dataClauseOperands.size();
2711 genDataOperandOperationsWithModifier<mlir::acc::CopyinOp,
2712 Fortran::parser::AccClause::Copyin>(
2713 copyinClause, converter, semanticsContext, stmtCtx,
2714 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
2715 dataClauseOperands, mlir::acc::DataClause::acc_copyin,
2716 mlir::acc::DataClause::acc_copyin_readonly, async, asyncDeviceTypes,
2717 asyncOnlyDeviceTypes);
2718 copyinEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2719 dataClauseOperands.end());
2720 } else if (const auto *copyoutClause =
2721 std::get_if<Fortran::parser::AccClause::Copyout>(
2722 &clause.u)) {
2723 auto crtDataStart = dataClauseOperands.size();
2724 genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
2725 Fortran::parser::AccClause::Copyout>(
2726 copyoutClause, converter, semanticsContext, stmtCtx,
2727 Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands,
2728 mlir::acc::DataClause::acc_copyout,
2729 mlir::acc::DataClause::acc_copyout_zero, async, asyncDeviceTypes,
2730 asyncOnlyDeviceTypes);
2731 copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2732 dataClauseOperands.end());
2733 } else if (const auto *createClause =
2734 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
2735 auto crtDataStart = dataClauseOperands.size();
2736 genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
2737 Fortran::parser::AccClause::Create>(
2738 createClause, converter, semanticsContext, stmtCtx,
2739 Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands,
2740 mlir::acc::DataClause::acc_create,
2741 mlir::acc::DataClause::acc_create_zero, async, asyncDeviceTypes,
2742 asyncOnlyDeviceTypes);
2743 createEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2744 dataClauseOperands.end());
2745 } else if (const auto *noCreateClause =
2746 std::get_if<Fortran::parser::AccClause::NoCreate>(
2747 &clause.u)) {
2748 genDataOperandOperations<mlir::acc::NoCreateOp>(
2749 noCreateClause->v, converter, semanticsContext, stmtCtx,
2750 dataClauseOperands, mlir::acc::DataClause::acc_no_create,
2751 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
2752 asyncOnlyDeviceTypes);
2753 } else if (const auto *presentClause =
2754 std::get_if<Fortran::parser::AccClause::Present>(
2755 &clause.u)) {
2756 genDataOperandOperations<mlir::acc::PresentOp>(
2757 presentClause->v, converter, semanticsContext, stmtCtx,
2758 dataClauseOperands, mlir::acc::DataClause::acc_present,
2759 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
2760 asyncOnlyDeviceTypes);
2761 } else if (const auto *deviceptrClause =
2762 std::get_if<Fortran::parser::AccClause::Deviceptr>(
2763 &clause.u)) {
2764 genDataOperandOperations<mlir::acc::DevicePtrOp>(
2765 deviceptrClause->v, converter, semanticsContext, stmtCtx,
2766 dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
2767 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
2768 asyncOnlyDeviceTypes);
2769 } else if (const auto *attachClause =
2770 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
2771 auto crtDataStart = dataClauseOperands.size();
2772 genDataOperandOperations<mlir::acc::AttachOp>(
2773 attachClause->v, converter, semanticsContext, stmtCtx,
2774 dataClauseOperands, mlir::acc::DataClause::acc_attach,
2775 /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
2776 asyncOnlyDeviceTypes);
2777 attachEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2778 dataClauseOperands.end());
2779 } else if (const auto *defaultClause =
2780 std::get_if<Fortran::parser::AccClause::Default>(
2781 &clause.u)) {
2782 if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
2783 hasDefaultNone = true;
2784 else if ((defaultClause->v).v ==
2785 llvm::acc::DefaultValue::ACC_Default_present)
2786 hasDefaultPresent = true;
2790 // Prepare the operand segment size attribute and the operands value range.
2791 llvm::SmallVector<mlir::Value> operands;
2792 llvm::SmallVector<int32_t> operandSegments;
2793 addOperand(operands, operandSegments, ifCond);
2794 addOperands(operands, operandSegments, async);
2795 addOperands(operands, operandSegments, waitOperands);
2796 addOperands(operands, operandSegments, dataClauseOperands);
2798 if (dataClauseOperands.empty() && !hasDefaultNone && !hasDefaultPresent)
2799 return;
2801 auto dataOp = createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>(
2802 builder, currentLocation, currentLocation, eval, operands,
2803 operandSegments);
2805 if (!asyncDeviceTypes.empty())
2806 dataOp.setAsyncOperandsDeviceTypeAttr(
2807 builder.getArrayAttr(asyncDeviceTypes));
2808 if (!asyncOnlyDeviceTypes.empty())
2809 dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
2810 if (!waitOperandsDeviceTypes.empty())
2811 dataOp.setWaitOperandsDeviceTypeAttr(
2812 builder.getArrayAttr(waitOperandsDeviceTypes));
2813 if (!waitOperandsSegments.empty())
2814 dataOp.setWaitOperandsSegmentsAttr(
2815 builder.getDenseI32ArrayAttr(waitOperandsSegments));
2816 if (!hasWaitDevnums.empty())
2817 dataOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
2818 if (!waitOnlyDeviceTypes.empty())
2819 dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
2821 if (hasDefaultNone)
2822 dataOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::None);
2823 if (hasDefaultPresent)
2824 dataOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
2826 auto insPt = builder.saveInsertionPoint();
2827 builder.setInsertionPointAfter(dataOp);
2829 // Create the exit operations after the region.
2830 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>(
2831 builder, copyEntryOperands, /*structured=*/true);
2832 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::DeleteOp>(
2833 builder, copyinEntryOperands, /*structured=*/true);
2834 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>(
2835 builder, copyoutEntryOperands, /*structured=*/true);
2836 genDataExitOperations<mlir::acc::AttachOp, mlir::acc::DetachOp>(
2837 builder, attachEntryOperands, /*structured=*/true);
2838 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
2839 builder, createEntryOperands, /*structured=*/true);
2841 builder.restoreInsertionPoint(insPt);
2844 static void
2845 genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
2846 mlir::Location currentLocation,
2847 Fortran::lower::pft::Evaluation &eval,
2848 Fortran::semantics::SemanticsContext &semanticsContext,
2849 Fortran::lower::StatementContext &stmtCtx,
2850 const Fortran::parser::AccClauseList &accClauseList) {
2851 mlir::Value ifCond;
2852 llvm::SmallVector<mlir::Value> dataOperands;
2853 bool addIfPresentAttr = false;
2855 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2857 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2858 mlir::Location clauseLocation = converter.genLocation(clause.source);
2859 if (const auto *ifClause =
2860 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2861 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2862 } else if (const auto *useDevice =
2863 std::get_if<Fortran::parser::AccClause::UseDevice>(
2864 &clause.u)) {
2865 genDataOperandOperations<mlir::acc::UseDeviceOp>(
2866 useDevice->v, converter, semanticsContext, stmtCtx, dataOperands,
2867 mlir::acc::DataClause::acc_use_device,
2868 /*structured=*/true, /*implicit=*/false, /*async=*/{},
2869 /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
2870 } else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) {
2871 addIfPresentAttr = true;
2875 if (ifCond) {
2876 if (auto cst =
2877 mlir::dyn_cast<mlir::arith::ConstantOp>(ifCond.getDefiningOp()))
2878 if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(cst.getValue())) {
2879 if (boolAttr.getValue()) {
2880 // get rid of the if condition if it is always true.
2881 ifCond = mlir::Value();
2882 } else {
2883 // Do not generate the acc.host_data op if the if condition is always
2884 // false.
2885 return;
2890 // Prepare the operand segment size attribute and the operands value range.
2891 llvm::SmallVector<mlir::Value> operands;
2892 llvm::SmallVector<int32_t> operandSegments;
2893 addOperand(operands, operandSegments, ifCond);
2894 addOperands(operands, operandSegments, dataOperands);
2896 auto hostDataOp =
2897 createRegionOp<mlir::acc::HostDataOp, mlir::acc::TerminatorOp>(
2898 builder, currentLocation, currentLocation, eval, operands,
2899 operandSegments);
2901 if (addIfPresentAttr)
2902 hostDataOp.setIfPresentAttr(builder.getUnitAttr());
2905 static void
2906 genACC(Fortran::lower::AbstractConverter &converter,
2907 Fortran::semantics::SemanticsContext &semanticsContext,
2908 Fortran::lower::pft::Evaluation &eval,
2909 const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
2910 const auto &beginBlockDirective =
2911 std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
2912 const auto &blockDirective =
2913 std::get<Fortran::parser::AccBlockDirective>(beginBlockDirective.t);
2914 const auto &accClauseList =
2915 std::get<Fortran::parser::AccClauseList>(beginBlockDirective.t);
2917 mlir::Location currentLocation = converter.genLocation(blockDirective.source);
2918 Fortran::lower::StatementContext stmtCtx;
2920 if (blockDirective.v == llvm::acc::ACCD_parallel) {
2921 createComputeOp<mlir::acc::ParallelOp>(converter, currentLocation, eval,
2922 semanticsContext, stmtCtx,
2923 accClauseList);
2924 } else if (blockDirective.v == llvm::acc::ACCD_data) {
2925 genACCDataOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
2926 accClauseList);
2927 } else if (blockDirective.v == llvm::acc::ACCD_serial) {
2928 createComputeOp<mlir::acc::SerialOp>(converter, currentLocation, eval,
2929 semanticsContext, stmtCtx,
2930 accClauseList);
2931 } else if (blockDirective.v == llvm::acc::ACCD_kernels) {
2932 createComputeOp<mlir::acc::KernelsOp>(converter, currentLocation, eval,
2933 semanticsContext, stmtCtx,
2934 accClauseList);
2935 } else if (blockDirective.v == llvm::acc::ACCD_host_data) {
2936 genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
2937 stmtCtx, accClauseList);
2941 static void
2942 genACC(Fortran::lower::AbstractConverter &converter,
2943 Fortran::semantics::SemanticsContext &semanticsContext,
2944 Fortran::lower::pft::Evaluation &eval,
2945 const Fortran::parser::OpenACCCombinedConstruct &combinedConstruct) {
2946 const auto &beginCombinedDirective =
2947 std::get<Fortran::parser::AccBeginCombinedDirective>(combinedConstruct.t);
2948 const auto &combinedDirective =
2949 std::get<Fortran::parser::AccCombinedDirective>(beginCombinedDirective.t);
2950 const auto &accClauseList =
2951 std::get<Fortran::parser::AccClauseList>(beginCombinedDirective.t);
2952 const auto &outerDoConstruct =
2953 std::get<std::optional<Fortran::parser::DoConstruct>>(
2954 combinedConstruct.t);
2956 mlir::Location currentLocation =
2957 converter.genLocation(beginCombinedDirective.source);
2958 Fortran::lower::StatementContext stmtCtx;
2960 if (combinedDirective.v == llvm::acc::ACCD_kernels_loop) {
2961 createComputeOp<mlir::acc::KernelsOp>(
2962 converter, currentLocation, eval, semanticsContext, stmtCtx,
2963 accClauseList, mlir::acc::CombinedConstructsType::KernelsLoop);
2964 createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
2965 *outerDoConstruct, eval, accClauseList,
2966 mlir::acc::CombinedConstructsType::KernelsLoop);
2967 } else if (combinedDirective.v == llvm::acc::ACCD_parallel_loop) {
2968 createComputeOp<mlir::acc::ParallelOp>(
2969 converter, currentLocation, eval, semanticsContext, stmtCtx,
2970 accClauseList, mlir::acc::CombinedConstructsType::ParallelLoop);
2971 createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
2972 *outerDoConstruct, eval, accClauseList,
2973 mlir::acc::CombinedConstructsType::ParallelLoop);
2974 } else if (combinedDirective.v == llvm::acc::ACCD_serial_loop) {
2975 createComputeOp<mlir::acc::SerialOp>(
2976 converter, currentLocation, eval, semanticsContext, stmtCtx,
2977 accClauseList, mlir::acc::CombinedConstructsType::SerialLoop);
2978 createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
2979 *outerDoConstruct, eval, accClauseList,
2980 mlir::acc::CombinedConstructsType::SerialLoop);
2981 } else {
2982 llvm::report_fatal_error("Unknown combined construct encountered");
2986 static void
2987 genACCEnterDataOp(Fortran::lower::AbstractConverter &converter,
2988 mlir::Location currentLocation,
2989 Fortran::semantics::SemanticsContext &semanticsContext,
2990 Fortran::lower::StatementContext &stmtCtx,
2991 const Fortran::parser::AccClauseList &accClauseList) {
2992 mlir::Value ifCond, async, waitDevnum;
2993 llvm::SmallVector<mlir::Value> waitOperands, dataClauseOperands;
2995 // Async, wait and self clause have optional values but can be present with
2996 // no value as well. When there is no value, the op has an attribute to
2997 // represent the clause.
2998 bool addAsyncAttr = false;
2999 bool addWaitAttr = false;
3001 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
3003 // Lower clauses values mapped to operands.
3004 // Keep track of each group of operands separately as clauses can appear
3005 // more than once.
3007 // Process the async clause first.
3008 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3009 if (const auto *asyncClause =
3010 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
3011 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
3015 // The async clause of 'enter data' applies to all device types,
3016 // so propagate the async clause to copyin/create/attach ops
3017 // as if it is an async clause without preceding device_type clause.
3018 llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes;
3019 llvm::SmallVector<mlir::Value> asyncValues;
3020 auto noneDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
3021 firOpBuilder.getContext(), mlir::acc::DeviceType::None);
3022 if (addAsyncAttr) {
3023 asyncOnlyDeviceTypes.push_back(noneDeviceTypeAttr);
3024 } else if (async) {
3025 asyncValues.push_back(async);
3026 asyncDeviceTypes.push_back(noneDeviceTypeAttr);
3029 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3030 mlir::Location clauseLocation = converter.genLocation(clause.source);
3031 if (const auto *ifClause =
3032 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3033 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3034 } else if (const auto *waitClause =
3035 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
3036 genWaitClause(converter, waitClause, waitOperands, waitDevnum,
3037 addWaitAttr, stmtCtx);
3038 } else if (const auto *copyinClause =
3039 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
3040 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
3041 copyinClause->v;
3042 const auto &accObjectList =
3043 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3044 genDataOperandOperations<mlir::acc::CopyinOp>(
3045 accObjectList, converter, semanticsContext, stmtCtx,
3046 dataClauseOperands, mlir::acc::DataClause::acc_copyin, false,
3047 /*implicit=*/false, asyncValues, asyncDeviceTypes,
3048 asyncOnlyDeviceTypes);
3049 } else if (const auto *createClause =
3050 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
3051 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
3052 createClause->v;
3053 const auto &accObjectList =
3054 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3055 const auto &modifier =
3056 std::get<std::optional<Fortran::parser::AccDataModifier>>(
3057 listWithModifier.t);
3058 mlir::acc::DataClause clause = mlir::acc::DataClause::acc_create;
3059 if (modifier &&
3060 (*modifier).v == Fortran::parser::AccDataModifier::Modifier::Zero)
3061 clause = mlir::acc::DataClause::acc_create_zero;
3062 genDataOperandOperations<mlir::acc::CreateOp>(
3063 accObjectList, converter, semanticsContext, stmtCtx,
3064 dataClauseOperands, clause, false, /*implicit=*/false, asyncValues,
3065 asyncDeviceTypes, asyncOnlyDeviceTypes);
3066 } else if (const auto *attachClause =
3067 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
3068 genDataOperandOperations<mlir::acc::AttachOp>(
3069 attachClause->v, converter, semanticsContext, stmtCtx,
3070 dataClauseOperands, mlir::acc::DataClause::acc_attach, false,
3071 /*implicit=*/false, asyncValues, asyncDeviceTypes,
3072 asyncOnlyDeviceTypes);
3073 } else if (!std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
3074 llvm::report_fatal_error(
3075 "Unknown clause in ENTER DATA directive lowering");
3079 // Prepare the operand segment size attribute and the operands value range.
3080 llvm::SmallVector<mlir::Value, 16> operands;
3081 llvm::SmallVector<int32_t, 8> operandSegments;
3082 addOperand(operands, operandSegments, ifCond);
3083 addOperand(operands, operandSegments, async);
3084 addOperand(operands, operandSegments, waitDevnum);
3085 addOperands(operands, operandSegments, waitOperands);
3086 addOperands(operands, operandSegments, dataClauseOperands);
3088 mlir::acc::EnterDataOp enterDataOp = createSimpleOp<mlir::acc::EnterDataOp>(
3089 firOpBuilder, currentLocation, operands, operandSegments);
3091 if (addAsyncAttr)
3092 enterDataOp.setAsyncAttr(firOpBuilder.getUnitAttr());
3093 if (addWaitAttr)
3094 enterDataOp.setWaitAttr(firOpBuilder.getUnitAttr());
3097 static void
3098 genACCExitDataOp(Fortran::lower::AbstractConverter &converter,
3099 mlir::Location currentLocation,
3100 Fortran::semantics::SemanticsContext &semanticsContext,
3101 Fortran::lower::StatementContext &stmtCtx,
3102 const Fortran::parser::AccClauseList &accClauseList) {
3103 mlir::Value ifCond, async, waitDevnum;
3104 llvm::SmallVector<mlir::Value> waitOperands, dataClauseOperands,
3105 copyoutOperands, deleteOperands, detachOperands;
3107 // Async and wait clause have optional values but can be present with
3108 // no value as well. When there is no value, the op has an attribute to
3109 // represent the clause.
3110 bool addAsyncAttr = false;
3111 bool addWaitAttr = false;
3112 bool addFinalizeAttr = false;
3114 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3116 // Lower clauses values mapped to operands.
3117 // Keep track of each group of operands separately as clauses can appear
3118 // more than once.
3120 // Process the async clause first.
3121 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3122 if (const auto *asyncClause =
3123 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
3124 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
3128 // The async clause of 'exit data' applies to all device types,
3129 // so propagate the async clause to copyin/create/attach ops
3130 // as if it is an async clause without preceding device_type clause.
3131 llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes;
3132 llvm::SmallVector<mlir::Value> asyncValues;
3133 auto noneDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
3134 builder.getContext(), mlir::acc::DeviceType::None);
3135 if (addAsyncAttr) {
3136 asyncOnlyDeviceTypes.push_back(noneDeviceTypeAttr);
3137 } else if (async) {
3138 asyncValues.push_back(async);
3139 asyncDeviceTypes.push_back(noneDeviceTypeAttr);
3142 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3143 mlir::Location clauseLocation = converter.genLocation(clause.source);
3144 if (const auto *ifClause =
3145 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3146 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3147 } else if (const auto *waitClause =
3148 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
3149 genWaitClause(converter, waitClause, waitOperands, waitDevnum,
3150 addWaitAttr, stmtCtx);
3151 } else if (const auto *copyoutClause =
3152 std::get_if<Fortran::parser::AccClause::Copyout>(
3153 &clause.u)) {
3154 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
3155 copyoutClause->v;
3156 const auto &accObjectList =
3157 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3158 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
3159 accObjectList, converter, semanticsContext, stmtCtx, copyoutOperands,
3160 mlir::acc::DataClause::acc_copyout, false, /*implicit=*/false,
3161 asyncValues, asyncDeviceTypes, asyncOnlyDeviceTypes);
3162 } else if (const auto *deleteClause =
3163 std::get_if<Fortran::parser::AccClause::Delete>(&clause.u)) {
3164 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
3165 deleteClause->v, converter, semanticsContext, stmtCtx, deleteOperands,
3166 mlir::acc::DataClause::acc_delete, false, /*implicit=*/false,
3167 asyncValues, asyncDeviceTypes, asyncOnlyDeviceTypes);
3168 } else if (const auto *detachClause =
3169 std::get_if<Fortran::parser::AccClause::Detach>(&clause.u)) {
3170 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
3171 detachClause->v, converter, semanticsContext, stmtCtx, detachOperands,
3172 mlir::acc::DataClause::acc_detach, false, /*implicit=*/false,
3173 asyncValues, asyncDeviceTypes, asyncOnlyDeviceTypes);
3174 } else if (std::get_if<Fortran::parser::AccClause::Finalize>(&clause.u)) {
3175 addFinalizeAttr = true;
3179 dataClauseOperands.append(copyoutOperands);
3180 dataClauseOperands.append(deleteOperands);
3181 dataClauseOperands.append(detachOperands);
3183 // Prepare the operand segment size attribute and the operands value range.
3184 llvm::SmallVector<mlir::Value, 14> operands;
3185 llvm::SmallVector<int32_t, 7> operandSegments;
3186 addOperand(operands, operandSegments, ifCond);
3187 addOperand(operands, operandSegments, async);
3188 addOperand(operands, operandSegments, waitDevnum);
3189 addOperands(operands, operandSegments, waitOperands);
3190 addOperands(operands, operandSegments, dataClauseOperands);
3192 mlir::acc::ExitDataOp exitDataOp = createSimpleOp<mlir::acc::ExitDataOp>(
3193 builder, currentLocation, operands, operandSegments);
3195 if (addAsyncAttr)
3196 exitDataOp.setAsyncAttr(builder.getUnitAttr());
3197 if (addWaitAttr)
3198 exitDataOp.setWaitAttr(builder.getUnitAttr());
3199 if (addFinalizeAttr)
3200 exitDataOp.setFinalizeAttr(builder.getUnitAttr());
3202 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::CopyoutOp>(
3203 builder, copyoutOperands, /*structured=*/false);
3204 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::DeleteOp>(
3205 builder, deleteOperands, /*structured=*/false);
3206 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::DetachOp>(
3207 builder, detachOperands, /*structured=*/false);
3210 template <typename Op>
3211 static void
3212 genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter,
3213 mlir::Location currentLocation,
3214 const Fortran::parser::AccClauseList &accClauseList) {
3215 mlir::Value ifCond, deviceNum;
3217 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3218 Fortran::lower::StatementContext stmtCtx;
3219 llvm::SmallVector<mlir::Attribute> deviceTypes;
3221 // Lower clauses values mapped to operands.
3222 // Keep track of each group of operands separately as clauses can appear
3223 // more than once.
3224 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3225 mlir::Location clauseLocation = converter.genLocation(clause.source);
3226 if (const auto *ifClause =
3227 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3228 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3229 } else if (const auto *deviceNumClause =
3230 std::get_if<Fortran::parser::AccClause::DeviceNum>(
3231 &clause.u)) {
3232 deviceNum = fir::getBase(converter.genExprValue(
3233 *Fortran::semantics::GetExpr(deviceNumClause->v), stmtCtx));
3234 } else if (const auto *deviceTypeClause =
3235 std::get_if<Fortran::parser::AccClause::DeviceType>(
3236 &clause.u)) {
3237 gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
3241 // Prepare the operand segment size attribute and the operands value range.
3242 llvm::SmallVector<mlir::Value, 6> operands;
3243 llvm::SmallVector<int32_t, 2> operandSegments;
3245 addOperand(operands, operandSegments, deviceNum);
3246 addOperand(operands, operandSegments, ifCond);
3248 Op op =
3249 createSimpleOp<Op>(builder, currentLocation, operands, operandSegments);
3250 if (!deviceTypes.empty())
3251 op.setDeviceTypesAttr(
3252 mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
3255 void genACCSetOp(Fortran::lower::AbstractConverter &converter,
3256 mlir::Location currentLocation,
3257 const Fortran::parser::AccClauseList &accClauseList) {
3258 mlir::Value ifCond, deviceNum, defaultAsync;
3259 llvm::SmallVector<mlir::Value> deviceTypeOperands;
3261 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3262 Fortran::lower::StatementContext stmtCtx;
3263 llvm::SmallVector<mlir::Attribute> deviceTypes;
3265 // Lower clauses values mapped to operands.
3266 // Keep track of each group of operands separately as clauses can appear
3267 // more than once.
3268 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3269 mlir::Location clauseLocation = converter.genLocation(clause.source);
3270 if (const auto *ifClause =
3271 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3272 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3273 } else if (const auto *defaultAsyncClause =
3274 std::get_if<Fortran::parser::AccClause::DefaultAsync>(
3275 &clause.u)) {
3276 defaultAsync = fir::getBase(converter.genExprValue(
3277 *Fortran::semantics::GetExpr(defaultAsyncClause->v), stmtCtx));
3278 } else if (const auto *deviceNumClause =
3279 std::get_if<Fortran::parser::AccClause::DeviceNum>(
3280 &clause.u)) {
3281 deviceNum = fir::getBase(converter.genExprValue(
3282 *Fortran::semantics::GetExpr(deviceNumClause->v), stmtCtx));
3283 } else if (const auto *deviceTypeClause =
3284 std::get_if<Fortran::parser::AccClause::DeviceType>(
3285 &clause.u)) {
3286 gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
3290 // Prepare the operand segment size attribute and the operands value range.
3291 llvm::SmallVector<mlir::Value> operands;
3292 llvm::SmallVector<int32_t, 3> operandSegments;
3293 addOperand(operands, operandSegments, defaultAsync);
3294 addOperand(operands, operandSegments, deviceNum);
3295 addOperand(operands, operandSegments, ifCond);
3297 auto op = createSimpleOp<mlir::acc::SetOp>(builder, currentLocation, operands,
3298 operandSegments);
3299 if (!deviceTypes.empty()) {
3300 assert(deviceTypes.size() == 1 && "expect only one value for acc.set");
3301 op.setDeviceTypeAttr(mlir::cast<mlir::acc::DeviceTypeAttr>(deviceTypes[0]));
3305 static inline mlir::ArrayAttr
3306 getArrayAttr(fir::FirOpBuilder &b,
3307 llvm::SmallVector<mlir::Attribute> &attributes) {
3308 return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
3311 static inline mlir::ArrayAttr
3312 getBoolArrayAttr(fir::FirOpBuilder &b, llvm::SmallVector<bool> &values) {
3313 return values.empty() ? nullptr : b.getBoolArrayAttr(values);
3316 static inline mlir::DenseI32ArrayAttr
3317 getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
3318 llvm::SmallVector<int32_t> &values) {
3319 return values.empty() ? nullptr : builder.getDenseI32ArrayAttr(values);
3322 static void
3323 genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
3324 mlir::Location currentLocation,
3325 Fortran::semantics::SemanticsContext &semanticsContext,
3326 Fortran::lower::StatementContext &stmtCtx,
3327 const Fortran::parser::AccClauseList &accClauseList) {
3328 mlir::Value ifCond;
3329 llvm::SmallVector<mlir::Value> dataClauseOperands, updateHostOperands,
3330 waitOperands, deviceTypeOperands, asyncOperands;
3331 llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
3332 asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
3333 llvm::SmallVector<bool> hasWaitDevnums;
3334 llvm::SmallVector<int32_t> waitOperandsSegments;
3336 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3338 // device_type attribute is set to `none` until a device_type clause is
3339 // encountered.
3340 llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
3341 crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
3342 builder.getContext(), mlir::acc::DeviceType::None));
3344 bool ifPresent = false;
3346 // Lower clauses values mapped to operands and array attributes.
3347 // Keep track of each group of operands separately as clauses can appear
3348 // more than once.
3350 // Process the clauses that may have a specified device_type first.
3351 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3352 if (const auto *asyncClause =
3353 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
3354 genAsyncClause(converter, asyncClause, asyncOperands,
3355 asyncOperandsDeviceTypes, asyncOnlyDeviceTypes,
3356 crtDeviceTypes, stmtCtx);
3357 } else if (const auto *waitClause =
3358 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
3359 genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
3360 waitOperandsDeviceTypes, waitOnlyDeviceTypes,
3361 hasWaitDevnums, waitOperandsSegments,
3362 crtDeviceTypes, stmtCtx);
3363 } else if (const auto *deviceTypeClause =
3364 std::get_if<Fortran::parser::AccClause::DeviceType>(
3365 &clause.u)) {
3366 crtDeviceTypes.clear();
3367 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
3371 // Process the clauses independent of device_type.
3372 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3373 mlir::Location clauseLocation = converter.genLocation(clause.source);
3374 if (const auto *ifClause =
3375 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3376 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3377 } else if (const auto *hostClause =
3378 std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
3379 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
3380 hostClause->v, converter, semanticsContext, stmtCtx,
3381 updateHostOperands, mlir::acc::DataClause::acc_update_host, false,
3382 /*implicit=*/false, asyncOperands, asyncOperandsDeviceTypes,
3383 asyncOnlyDeviceTypes);
3384 } else if (const auto *deviceClause =
3385 std::get_if<Fortran::parser::AccClause::Device>(&clause.u)) {
3386 genDataOperandOperations<mlir::acc::UpdateDeviceOp>(
3387 deviceClause->v, converter, semanticsContext, stmtCtx,
3388 dataClauseOperands, mlir::acc::DataClause::acc_update_device, false,
3389 /*implicit=*/false, asyncOperands, asyncOperandsDeviceTypes,
3390 asyncOnlyDeviceTypes);
3391 } else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) {
3392 ifPresent = true;
3393 } else if (const auto *selfClause =
3394 std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
3395 const std::optional<Fortran::parser::AccSelfClause> &accSelfClause =
3396 selfClause->v;
3397 const auto *accObjectList =
3398 std::get_if<Fortran::parser::AccObjectList>(&(*accSelfClause).u);
3399 assert(accObjectList && "expect AccObjectList");
3400 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
3401 *accObjectList, converter, semanticsContext, stmtCtx,
3402 updateHostOperands, mlir::acc::DataClause::acc_update_self, false,
3403 /*implicit=*/false, asyncOperands, asyncOperandsDeviceTypes,
3404 asyncOnlyDeviceTypes);
3408 dataClauseOperands.append(updateHostOperands);
3410 builder.create<mlir::acc::UpdateOp>(
3411 currentLocation, ifCond, asyncOperands,
3412 getArrayAttr(builder, asyncOperandsDeviceTypes),
3413 getArrayAttr(builder, asyncOnlyDeviceTypes), waitOperands,
3414 getDenseI32ArrayAttr(builder, waitOperandsSegments),
3415 getArrayAttr(builder, waitOperandsDeviceTypes),
3416 getBoolArrayAttr(builder, hasWaitDevnums),
3417 getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
3418 ifPresent);
3420 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::UpdateHostOp>(
3421 builder, updateHostOperands, /*structured=*/false);
3424 static void
3425 genACC(Fortran::lower::AbstractConverter &converter,
3426 Fortran::semantics::SemanticsContext &semanticsContext,
3427 const Fortran::parser::OpenACCStandaloneConstruct &standaloneConstruct) {
3428 const auto &standaloneDirective =
3429 std::get<Fortran::parser::AccStandaloneDirective>(standaloneConstruct.t);
3430 const auto &accClauseList =
3431 std::get<Fortran::parser::AccClauseList>(standaloneConstruct.t);
3433 mlir::Location currentLocation =
3434 converter.genLocation(standaloneDirective.source);
3435 Fortran::lower::StatementContext stmtCtx;
3437 if (standaloneDirective.v == llvm::acc::Directive::ACCD_enter_data) {
3438 genACCEnterDataOp(converter, currentLocation, semanticsContext, stmtCtx,
3439 accClauseList);
3440 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_exit_data) {
3441 genACCExitDataOp(converter, currentLocation, semanticsContext, stmtCtx,
3442 accClauseList);
3443 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_init) {
3444 genACCInitShutdownOp<mlir::acc::InitOp>(converter, currentLocation,
3445 accClauseList);
3446 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_shutdown) {
3447 genACCInitShutdownOp<mlir::acc::ShutdownOp>(converter, currentLocation,
3448 accClauseList);
3449 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_set) {
3450 genACCSetOp(converter, currentLocation, accClauseList);
3451 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_update) {
3452 genACCUpdateOp(converter, currentLocation, semanticsContext, stmtCtx,
3453 accClauseList);
3457 static void genACC(Fortran::lower::AbstractConverter &converter,
3458 const Fortran::parser::OpenACCWaitConstruct &waitConstruct) {
3460 const auto &waitArgument =
3461 std::get<std::optional<Fortran::parser::AccWaitArgument>>(
3462 waitConstruct.t);
3463 const auto &accClauseList =
3464 std::get<Fortran::parser::AccClauseList>(waitConstruct.t);
3466 mlir::Value ifCond, waitDevnum, async;
3467 llvm::SmallVector<mlir::Value> waitOperands;
3469 // Async clause have optional values but can be present with
3470 // no value as well. When there is no value, the op has an attribute to
3471 // represent the clause.
3472 bool addAsyncAttr = false;
3474 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
3475 mlir::Location currentLocation = converter.genLocation(waitConstruct.source);
3476 Fortran::lower::StatementContext stmtCtx;
3478 if (waitArgument) { // wait has a value.
3479 const Fortran::parser::AccWaitArgument &waitArg = *waitArgument;
3480 const auto &waitList =
3481 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
3482 for (const Fortran::parser::ScalarIntExpr &value : waitList) {
3483 mlir::Value v = fir::getBase(
3484 converter.genExprValue(*Fortran::semantics::GetExpr(value), stmtCtx));
3485 waitOperands.push_back(v);
3488 const auto &waitDevnumValue =
3489 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
3490 if (waitDevnumValue)
3491 waitDevnum = fir::getBase(converter.genExprValue(
3492 *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
3495 // Lower clauses values mapped to operands.
3496 // Keep track of each group of operands separately as clauses can appear
3497 // more than once.
3498 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3499 mlir::Location clauseLocation = converter.genLocation(clause.source);
3500 if (const auto *ifClause =
3501 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
3502 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
3503 } else if (const auto *asyncClause =
3504 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
3505 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
3509 // Prepare the operand segment size attribute and the operands value range.
3510 llvm::SmallVector<mlir::Value> operands;
3511 llvm::SmallVector<int32_t> operandSegments;
3512 addOperands(operands, operandSegments, waitOperands);
3513 addOperand(operands, operandSegments, async);
3514 addOperand(operands, operandSegments, waitDevnum);
3515 addOperand(operands, operandSegments, ifCond);
3517 mlir::acc::WaitOp waitOp = createSimpleOp<mlir::acc::WaitOp>(
3518 firOpBuilder, currentLocation, operands, operandSegments);
3520 if (addAsyncAttr)
3521 waitOp.setAsyncAttr(firOpBuilder.getUnitAttr());
3524 template <typename GlobalOp, typename EntryOp, typename DeclareOp,
3525 typename ExitOp>
3526 static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder,
3527 fir::FirOpBuilder &builder,
3528 mlir::Location loc, fir::GlobalOp globalOp,
3529 mlir::acc::DataClause clause,
3530 const std::string &declareGlobalName,
3531 bool implicit, std::stringstream &asFortran) {
3532 GlobalOp declareGlobalOp =
3533 modBuilder.create<GlobalOp>(loc, declareGlobalName);
3534 builder.createBlock(&declareGlobalOp.getRegion(),
3535 declareGlobalOp.getRegion().end(), {}, {});
3536 builder.setInsertionPointToEnd(&declareGlobalOp.getRegion().back());
3538 fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
3539 loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
3540 addDeclareAttr(builder, addrOp, clause);
3542 llvm::SmallVector<mlir::Value> bounds;
3543 EntryOp entryOp = createDataEntryOp<EntryOp>(
3544 builder, loc, addrOp.getResTy(), asFortran, bounds,
3545 /*structured=*/false, implicit, clause, addrOp.getResTy().getType(),
3546 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
3547 if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>)
3548 builder.create<DeclareOp>(
3549 loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
3550 mlir::ValueRange(entryOp.getAccVar()));
3551 else
3552 builder.create<DeclareOp>(loc, mlir::Value{},
3553 mlir::ValueRange(entryOp.getAccVar()));
3554 if constexpr (std::is_same_v<GlobalOp, mlir::acc::GlobalDestructorOp>) {
3555 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccVar(),
3556 entryOp.getBounds(), entryOp.getAsyncOperands(),
3557 entryOp.getAsyncOperandsDeviceTypeAttr(),
3558 entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
3559 /*structured=*/false, /*implicit=*/false,
3560 builder.getStringAttr(*entryOp.getName()));
3562 builder.create<mlir::acc::TerminatorOp>(loc);
3563 modBuilder.setInsertionPointAfter(declareGlobalOp);
3566 template <typename EntryOp>
3567 static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
3568 fir::FirOpBuilder &builder,
3569 mlir::Location loc, fir::GlobalOp &globalOp,
3570 mlir::acc::DataClause clause) {
3571 std::stringstream registerFuncName;
3572 registerFuncName << globalOp.getSymName().str()
3573 << Fortran::lower::declarePostAllocSuffix.str();
3574 auto registerFuncOp =
3575 createDeclareFunc(modBuilder, builder, loc, registerFuncName.str());
3577 fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
3578 loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
3580 std::stringstream asFortran;
3581 asFortran << Fortran::lower::mangle::demangleName(globalOp.getSymName());
3582 std::stringstream asFortranDesc;
3583 asFortranDesc << asFortran.str();
3584 if (unwrapFirBox)
3585 asFortranDesc << accFirDescriptorPostfix.str();
3586 llvm::SmallVector<mlir::Value> bounds;
3588 // Updating descriptor must occur before the mapping of the data so that
3589 // attached data pointer is not overwritten.
3590 mlir::acc::UpdateDeviceOp updateDeviceOp =
3591 createDataEntryOp<mlir::acc::UpdateDeviceOp>(
3592 builder, loc, addrOp, asFortranDesc, bounds,
3593 /*structured=*/false, /*implicit=*/true,
3594 mlir::acc::DataClause::acc_update_device, addrOp.getType(),
3595 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
3596 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
3597 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
3598 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
3600 if (unwrapFirBox) {
3601 auto loadOp = builder.create<fir::LoadOp>(loc, addrOp.getResult());
3602 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
3603 addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
3604 EntryOp entryOp = createDataEntryOp<EntryOp>(
3605 builder, loc, boxAddrOp.getResult(), asFortran, bounds,
3606 /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType(),
3607 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
3608 builder.create<mlir::acc::DeclareEnterOp>(
3609 loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
3610 mlir::ValueRange(entryOp.getAccVar()));
3613 modBuilder.setInsertionPointAfter(registerFuncOp);
3616 /// Action to be performed on deallocation are split in two distinct functions.
3617 /// - Pre deallocation function includes all the action to be performed before
3618 /// the actual deallocation is done on the host side.
3619 /// - Post deallocation function includes update to the descriptor.
3620 template <typename ExitOp>
3621 static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
3622 fir::FirOpBuilder &builder,
3623 mlir::Location loc,
3624 fir::GlobalOp &globalOp,
3625 mlir::acc::DataClause clause) {
3626 std::stringstream asFortran;
3627 asFortran << Fortran::lower::mangle::demangleName(globalOp.getSymName());
3629 // If FIR box semantics are being unwrapped, then a pre-dealloc function
3630 // needs generated to ensure to delete the device data pointed to by the
3631 // descriptor before this information is lost.
3632 if (unwrapFirBox) {
3633 // Generate the pre dealloc function.
3634 std::stringstream preDeallocFuncName;
3635 preDeallocFuncName << globalOp.getSymName().str()
3636 << Fortran::lower::declarePreDeallocSuffix.str();
3637 auto preDeallocOp =
3638 createDeclareFunc(modBuilder, builder, loc, preDeallocFuncName.str());
3640 fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
3641 loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
3642 auto loadOp = builder.create<fir::LoadOp>(loc, addrOp.getResult());
3643 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
3644 mlir::Value var = boxAddrOp.getResult();
3645 addDeclareAttr(builder, var.getDefiningOp(), clause);
3647 llvm::SmallVector<mlir::Value> bounds;
3648 mlir::acc::GetDevicePtrOp entryOp =
3649 createDataEntryOp<mlir::acc::GetDevicePtrOp>(
3650 builder, loc, var, asFortran, bounds,
3651 /*structured=*/false, /*implicit=*/false, clause, var.getType(),
3652 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
3654 builder.create<mlir::acc::DeclareExitOp>(
3655 loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccVar()));
3657 if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
3658 std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
3659 builder.create<ExitOp>(
3660 entryOp.getLoc(), entryOp.getAccVar(), entryOp.getVar(),
3661 entryOp.getBounds(), entryOp.getAsyncOperands(),
3662 entryOp.getAsyncOperandsDeviceTypeAttr(), entryOp.getAsyncOnlyAttr(),
3663 entryOp.getDataClause(),
3664 /*structured=*/false, /*implicit=*/false,
3665 builder.getStringAttr(*entryOp.getName()));
3666 else
3667 builder.create<ExitOp>(
3668 entryOp.getLoc(), entryOp.getAccVar(), entryOp.getBounds(),
3669 entryOp.getAsyncOperands(), entryOp.getAsyncOperandsDeviceTypeAttr(),
3670 entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
3671 /*structured=*/false, /*implicit=*/false,
3672 builder.getStringAttr(*entryOp.getName()));
3674 // Generate the post dealloc function.
3675 modBuilder.setInsertionPointAfter(preDeallocOp);
3678 std::stringstream postDeallocFuncName;
3679 postDeallocFuncName << globalOp.getSymName().str()
3680 << Fortran::lower::declarePostDeallocSuffix.str();
3681 auto postDeallocOp =
3682 createDeclareFunc(modBuilder, builder, loc, postDeallocFuncName.str());
3684 fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
3685 loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
3686 if (unwrapFirBox)
3687 asFortran << accFirDescriptorPostfix.str();
3688 llvm::SmallVector<mlir::Value> bounds;
3689 mlir::acc::UpdateDeviceOp updateDeviceOp =
3690 createDataEntryOp<mlir::acc::UpdateDeviceOp>(
3691 builder, loc, addrOp, asFortran, bounds,
3692 /*structured=*/false, /*implicit=*/true,
3693 mlir::acc::DataClause::acc_update_device, addrOp.getType(),
3694 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
3695 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
3696 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
3697 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
3698 modBuilder.setInsertionPointAfter(postDeallocOp);
3701 template <typename EntryOp, typename ExitOp>
3702 static void genGlobalCtors(Fortran::lower::AbstractConverter &converter,
3703 mlir::OpBuilder &modBuilder,
3704 const Fortran::parser::AccObjectList &accObjectList,
3705 mlir::acc::DataClause clause) {
3706 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3707 for (const auto &accObject : accObjectList.v) {
3708 mlir::Location operandLocation = genOperandLocation(converter, accObject);
3709 Fortran::common::visit(
3710 Fortran::common::visitors{
3711 [&](const Fortran::parser::Designator &designator) {
3712 if (const auto *name =
3713 Fortran::semantics::getDesignatorNameIfDataRef(
3714 designator)) {
3715 std::string globalName = converter.mangleName(*name->symbol);
3716 fir::GlobalOp globalOp = builder.getNamedGlobal(globalName);
3717 std::stringstream declareGlobalCtorName;
3718 declareGlobalCtorName << globalName << "_acc_ctor";
3719 std::stringstream declareGlobalDtorName;
3720 declareGlobalDtorName << globalName << "_acc_dtor";
3721 std::stringstream asFortran;
3722 asFortran << name->symbol->name().ToString();
3724 if (builder.getModule()
3725 .lookupSymbol<mlir::acc::GlobalConstructorOp>(
3726 declareGlobalCtorName.str()))
3727 return;
3729 if (!globalOp) {
3730 if (Fortran::semantics::FindEquivalenceSet(*name->symbol)) {
3731 for (Fortran::semantics::EquivalenceObject eqObj :
3732 *Fortran::semantics::FindEquivalenceSet(
3733 *name->symbol)) {
3734 std::string eqName = converter.mangleName(eqObj.symbol);
3735 globalOp = builder.getNamedGlobal(eqName);
3736 if (globalOp)
3737 break;
3740 if (!globalOp)
3741 llvm::report_fatal_error(
3742 "could not retrieve global symbol");
3743 } else {
3744 llvm::report_fatal_error(
3745 "could not retrieve global symbol");
3749 addDeclareAttr(builder, globalOp.getOperation(), clause);
3750 auto crtPos = builder.saveInsertionPoint();
3751 modBuilder.setInsertionPointAfter(globalOp);
3752 if (mlir::isa<fir::BaseBoxType>(
3753 fir::unwrapRefType(globalOp.getType()))) {
3754 createDeclareGlobalOp<mlir::acc::GlobalConstructorOp,
3755 mlir::acc::CopyinOp,
3756 mlir::acc::DeclareEnterOp, ExitOp>(
3757 modBuilder, builder, operandLocation, globalOp, clause,
3758 declareGlobalCtorName.str(), /*implicit=*/true,
3759 asFortran);
3760 createDeclareAllocFunc<EntryOp>(
3761 modBuilder, builder, operandLocation, globalOp, clause);
3762 if constexpr (!std::is_same_v<EntryOp, ExitOp>)
3763 createDeclareDeallocFunc<ExitOp>(
3764 modBuilder, builder, operandLocation, globalOp, clause);
3765 } else {
3766 createDeclareGlobalOp<mlir::acc::GlobalConstructorOp, EntryOp,
3767 mlir::acc::DeclareEnterOp, ExitOp>(
3768 modBuilder, builder, operandLocation, globalOp, clause,
3769 declareGlobalCtorName.str(), /*implicit=*/false,
3770 asFortran);
3772 if constexpr (!std::is_same_v<EntryOp, ExitOp>) {
3773 createDeclareGlobalOp<mlir::acc::GlobalDestructorOp,
3774 mlir::acc::GetDevicePtrOp,
3775 mlir::acc::DeclareExitOp, ExitOp>(
3776 modBuilder, builder, operandLocation, globalOp, clause,
3777 declareGlobalDtorName.str(), /*implicit=*/false,
3778 asFortran);
3780 builder.restoreInsertionPoint(crtPos);
3783 [&](const Fortran::parser::Name &name) {
3784 TODO(operandLocation, "OpenACC Global Ctor from parser::Name");
3786 accObject.u);
3790 template <typename Clause, typename EntryOp, typename ExitOp>
3791 static void
3792 genGlobalCtorsWithModifier(Fortran::lower::AbstractConverter &converter,
3793 mlir::OpBuilder &modBuilder, const Clause *x,
3794 Fortran::parser::AccDataModifier::Modifier mod,
3795 const mlir::acc::DataClause clause,
3796 const mlir::acc::DataClause clauseWithModifier) {
3797 const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
3798 const auto &accObjectList =
3799 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3800 const auto &modifier =
3801 std::get<std::optional<Fortran::parser::AccDataModifier>>(
3802 listWithModifier.t);
3803 mlir::acc::DataClause dataClause =
3804 (modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
3805 genGlobalCtors<EntryOp, ExitOp>(converter, modBuilder, accObjectList,
3806 dataClause);
3809 static void
3810 genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
3811 Fortran::semantics::SemanticsContext &semanticsContext,
3812 Fortran::lower::StatementContext &openAccCtx,
3813 mlir::Location loc,
3814 const Fortran::parser::AccClauseList &accClauseList) {
3815 llvm::SmallVector<mlir::Value> dataClauseOperands, copyEntryOperands,
3816 copyinEntryOperands, createEntryOperands, copyoutEntryOperands,
3817 deviceResidentEntryOperands;
3818 Fortran::lower::StatementContext stmtCtx;
3819 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3821 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3822 if (const auto *copyClause =
3823 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
3824 auto crtDataStart = dataClauseOperands.size();
3825 genDeclareDataOperandOperations<mlir::acc::CopyinOp,
3826 mlir::acc::CopyoutOp>(
3827 copyClause->v, converter, semanticsContext, stmtCtx,
3828 dataClauseOperands, mlir::acc::DataClause::acc_copy,
3829 /*structured=*/true, /*implicit=*/false);
3830 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3831 dataClauseOperands.end());
3832 } else if (const auto *createClause =
3833 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
3834 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
3835 createClause->v;
3836 const auto &accObjectList =
3837 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3838 auto crtDataStart = dataClauseOperands.size();
3839 genDeclareDataOperandOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
3840 accObjectList, converter, semanticsContext, stmtCtx,
3841 dataClauseOperands, mlir::acc::DataClause::acc_create,
3842 /*structured=*/true, /*implicit=*/false);
3843 createEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3844 dataClauseOperands.end());
3845 } else if (const auto *presentClause =
3846 std::get_if<Fortran::parser::AccClause::Present>(
3847 &clause.u)) {
3848 genDeclareDataOperandOperations<mlir::acc::PresentOp,
3849 mlir::acc::PresentOp>(
3850 presentClause->v, converter, semanticsContext, stmtCtx,
3851 dataClauseOperands, mlir::acc::DataClause::acc_present,
3852 /*structured=*/true, /*implicit=*/false);
3853 } else if (const auto *copyinClause =
3854 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
3855 auto crtDataStart = dataClauseOperands.size();
3856 genDeclareDataOperandOperationsWithModifier<mlir::acc::CopyinOp,
3857 mlir::acc::DeleteOp>(
3858 copyinClause, converter, semanticsContext, stmtCtx,
3859 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
3860 dataClauseOperands, mlir::acc::DataClause::acc_copyin,
3861 mlir::acc::DataClause::acc_copyin_readonly);
3862 copyinEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3863 dataClauseOperands.end());
3864 } else if (const auto *copyoutClause =
3865 std::get_if<Fortran::parser::AccClause::Copyout>(
3866 &clause.u)) {
3867 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
3868 copyoutClause->v;
3869 const auto &accObjectList =
3870 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3871 auto crtDataStart = dataClauseOperands.size();
3872 genDeclareDataOperandOperations<mlir::acc::CreateOp,
3873 mlir::acc::CopyoutOp>(
3874 accObjectList, converter, semanticsContext, stmtCtx,
3875 dataClauseOperands, mlir::acc::DataClause::acc_copyout,
3876 /*structured=*/true, /*implicit=*/false);
3877 copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3878 dataClauseOperands.end());
3879 } else if (const auto *devicePtrClause =
3880 std::get_if<Fortran::parser::AccClause::Deviceptr>(
3881 &clause.u)) {
3882 genDeclareDataOperandOperations<mlir::acc::DevicePtrOp,
3883 mlir::acc::DevicePtrOp>(
3884 devicePtrClause->v, converter, semanticsContext, stmtCtx,
3885 dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
3886 /*structured=*/true, /*implicit=*/false);
3887 } else if (const auto *linkClause =
3888 std::get_if<Fortran::parser::AccClause::Link>(&clause.u)) {
3889 genDeclareDataOperandOperations<mlir::acc::DeclareLinkOp,
3890 mlir::acc::DeclareLinkOp>(
3891 linkClause->v, converter, semanticsContext, stmtCtx,
3892 dataClauseOperands, mlir::acc::DataClause::acc_declare_link,
3893 /*structured=*/true, /*implicit=*/false);
3894 } else if (const auto *deviceResidentClause =
3895 std::get_if<Fortran::parser::AccClause::DeviceResident>(
3896 &clause.u)) {
3897 auto crtDataStart = dataClauseOperands.size();
3898 genDeclareDataOperandOperations<mlir::acc::DeclareDeviceResidentOp,
3899 mlir::acc::DeleteOp>(
3900 deviceResidentClause->v, converter, semanticsContext, stmtCtx,
3901 dataClauseOperands,
3902 mlir::acc::DataClause::acc_declare_device_resident,
3903 /*structured=*/true, /*implicit=*/false);
3904 deviceResidentEntryOperands.append(
3905 dataClauseOperands.begin() + crtDataStart, dataClauseOperands.end());
3906 } else {
3907 mlir::Location clauseLocation = converter.genLocation(clause.source);
3908 TODO(clauseLocation, "clause on declare directive");
3912 mlir::func::FuncOp funcOp = builder.getFunction();
3913 auto ops = funcOp.getOps<mlir::acc::DeclareEnterOp>();
3914 mlir::Value declareToken;
3915 if (ops.empty()) {
3916 declareToken = builder.create<mlir::acc::DeclareEnterOp>(
3917 loc, mlir::acc::DeclareTokenType::get(builder.getContext()),
3918 dataClauseOperands);
3919 } else {
3920 auto declareOp = *ops.begin();
3921 auto newDeclareOp = builder.create<mlir::acc::DeclareEnterOp>(
3922 loc, mlir::acc::DeclareTokenType::get(builder.getContext()),
3923 declareOp.getDataClauseOperands());
3924 newDeclareOp.getDataClauseOperandsMutable().append(dataClauseOperands);
3925 declareToken = newDeclareOp.getToken();
3926 declareOp.erase();
3929 openAccCtx.attachCleanup([&builder, loc, createEntryOperands,
3930 copyEntryOperands, copyinEntryOperands,
3931 copyoutEntryOperands, deviceResidentEntryOperands,
3932 declareToken]() {
3933 llvm::SmallVector<mlir::Value> operands;
3934 operands.append(createEntryOperands);
3935 operands.append(deviceResidentEntryOperands);
3936 operands.append(copyEntryOperands);
3937 operands.append(copyinEntryOperands);
3938 operands.append(copyoutEntryOperands);
3940 mlir::func::FuncOp funcOp = builder.getFunction();
3941 auto ops = funcOp.getOps<mlir::acc::DeclareExitOp>();
3942 if (ops.empty()) {
3943 builder.create<mlir::acc::DeclareExitOp>(loc, declareToken, operands);
3944 } else {
3945 auto declareOp = *ops.begin();
3946 declareOp.getDataClauseOperandsMutable().append(operands);
3949 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
3950 builder, createEntryOperands, /*structured=*/true);
3951 genDataExitOperations<mlir::acc::DeclareDeviceResidentOp,
3952 mlir::acc::DeleteOp>(
3953 builder, deviceResidentEntryOperands, /*structured=*/true);
3954 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>(
3955 builder, copyEntryOperands, /*structured=*/true);
3956 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::DeleteOp>(
3957 builder, copyinEntryOperands, /*structured=*/true);
3958 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>(
3959 builder, copyoutEntryOperands, /*structured=*/true);
3963 static void
3964 genDeclareInModule(Fortran::lower::AbstractConverter &converter,
3965 mlir::ModuleOp moduleOp,
3966 const Fortran::parser::AccClauseList &accClauseList) {
3967 mlir::OpBuilder modBuilder(moduleOp.getBodyRegion());
3968 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3969 if (const auto *createClause =
3970 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
3971 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
3972 createClause->v;
3973 const auto &accObjectList =
3974 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3975 genGlobalCtors<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
3976 converter, modBuilder, accObjectList,
3977 mlir::acc::DataClause::acc_create);
3978 } else if (const auto *copyinClause =
3979 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
3980 genGlobalCtorsWithModifier<Fortran::parser::AccClause::Copyin,
3981 mlir::acc::CopyinOp, mlir::acc::DeleteOp>(
3982 converter, modBuilder, copyinClause,
3983 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
3984 mlir::acc::DataClause::acc_copyin,
3985 mlir::acc::DataClause::acc_copyin_readonly);
3986 } else if (const auto *deviceResidentClause =
3987 std::get_if<Fortran::parser::AccClause::DeviceResident>(
3988 &clause.u)) {
3989 genGlobalCtors<mlir::acc::DeclareDeviceResidentOp, mlir::acc::DeleteOp>(
3990 converter, modBuilder, deviceResidentClause->v,
3991 mlir::acc::DataClause::acc_declare_device_resident);
3992 } else if (const auto *linkClause =
3993 std::get_if<Fortran::parser::AccClause::Link>(&clause.u)) {
3994 genGlobalCtors<mlir::acc::DeclareLinkOp, mlir::acc::DeclareLinkOp>(
3995 converter, modBuilder, linkClause->v,
3996 mlir::acc::DataClause::acc_declare_link);
3997 } else {
3998 llvm::report_fatal_error("unsupported clause on DECLARE directive");
4003 static void genACC(Fortran::lower::AbstractConverter &converter,
4004 Fortran::semantics::SemanticsContext &semanticsContext,
4005 Fortran::lower::StatementContext &openAccCtx,
4006 const Fortran::parser::OpenACCStandaloneDeclarativeConstruct
4007 &declareConstruct) {
4009 const auto &declarativeDir =
4010 std::get<Fortran::parser::AccDeclarativeDirective>(declareConstruct.t);
4011 mlir::Location directiveLocation =
4012 converter.genLocation(declarativeDir.source);
4013 const auto &accClauseList =
4014 std::get<Fortran::parser::AccClauseList>(declareConstruct.t);
4016 if (declarativeDir.v == llvm::acc::Directive::ACCD_declare) {
4017 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
4018 auto moduleOp =
4019 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
4020 auto funcOp =
4021 builder.getBlock()->getParent()->getParentOfType<mlir::func::FuncOp>();
4022 if (funcOp)
4023 genDeclareInFunction(converter, semanticsContext, openAccCtx,
4024 directiveLocation, accClauseList);
4025 else if (moduleOp)
4026 genDeclareInModule(converter, moduleOp, accClauseList);
4027 return;
4029 llvm_unreachable("unsupported declarative directive");
4032 static bool hasDeviceType(llvm::SmallVector<mlir::Attribute> &arrayAttr,
4033 mlir::acc::DeviceType deviceType) {
4034 for (auto attr : arrayAttr) {
4035 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4036 if (deviceTypeAttr.getValue() == deviceType)
4037 return true;
4039 return false;
4042 template <typename RetTy, typename AttrTy>
4043 static std::optional<RetTy>
4044 getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
4045 llvm::SmallVector<mlir::Attribute> &deviceTypes,
4046 mlir::acc::DeviceType deviceType) {
4047 assert(attributes.size() == deviceTypes.size() &&
4048 "expect same number of attributes");
4049 for (auto it : llvm::enumerate(deviceTypes)) {
4050 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(it.value());
4051 if (deviceTypeAttr.getValue() == deviceType) {
4052 if constexpr (std::is_same_v<mlir::StringAttr, AttrTy>) {
4053 auto strAttr = mlir::dyn_cast<AttrTy>(attributes[it.index()]);
4054 return strAttr.getValue();
4055 } else if constexpr (std::is_same_v<mlir::IntegerAttr, AttrTy>) {
4056 auto intAttr =
4057 mlir::dyn_cast<mlir::IntegerAttr>(attributes[it.index()]);
4058 return intAttr.getInt();
4062 return std::nullopt;
4065 static bool compareDeviceTypeInfo(
4066 mlir::acc::RoutineOp op,
4067 llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
4068 llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
4069 llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
4070 llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
4071 llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
4072 llvm::SmallVector<mlir::Attribute> &seqArrayAttr,
4073 llvm::SmallVector<mlir::Attribute> &workerArrayAttr,
4074 llvm::SmallVector<mlir::Attribute> &vectorArrayAttr) {
4075 for (uint32_t dtypeInt = 0;
4076 dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) {
4077 auto dtype = static_cast<mlir::acc::DeviceType>(dtypeInt);
4078 if (op.getBindNameValue(dtype) !=
4079 getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4080 bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
4081 return false;
4082 if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype))
4083 return false;
4084 if (op.getGangDimValue(dtype) !=
4085 getAttributeValueByDeviceType<int64_t, mlir::IntegerAttr>(
4086 gangDimArrayAttr, gangDimDeviceTypeArrayAttr, dtype))
4087 return false;
4088 if (op.hasSeq(dtype) != hasDeviceType(seqArrayAttr, dtype))
4089 return false;
4090 if (op.hasWorker(dtype) != hasDeviceType(workerArrayAttr, dtype))
4091 return false;
4092 if (op.hasVector(dtype) != hasDeviceType(vectorArrayAttr, dtype))
4093 return false;
4095 return true;
4098 static void attachRoutineInfo(mlir::func::FuncOp func,
4099 mlir::SymbolRefAttr routineAttr) {
4100 llvm::SmallVector<mlir::SymbolRefAttr> routines;
4101 if (func.getOperation()->hasAttr(mlir::acc::getRoutineInfoAttrName())) {
4102 auto routineInfo =
4103 func.getOperation()->getAttrOfType<mlir::acc::RoutineInfoAttr>(
4104 mlir::acc::getRoutineInfoAttrName());
4105 routines.append(routineInfo.getAccRoutines().begin(),
4106 routineInfo.getAccRoutines().end());
4108 routines.push_back(routineAttr);
4109 func.getOperation()->setAttr(
4110 mlir::acc::getRoutineInfoAttrName(),
4111 mlir::acc::RoutineInfoAttr::get(func.getContext(), routines));
4114 void Fortran::lower::genOpenACCRoutineConstruct(
4115 Fortran::lower::AbstractConverter &converter,
4116 Fortran::semantics::SemanticsContext &semanticsContext, mlir::ModuleOp mod,
4117 const Fortran::parser::OpenACCRoutineConstruct &routineConstruct,
4118 Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
4119 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
4120 mlir::Location loc = converter.genLocation(routineConstruct.source);
4121 std::optional<Fortran::parser::Name> name =
4122 std::get<std::optional<Fortran::parser::Name>>(routineConstruct.t);
4123 const auto &clauses =
4124 std::get<Fortran::parser::AccClauseList>(routineConstruct.t);
4125 mlir::func::FuncOp funcOp;
4126 std::string funcName;
4127 if (name) {
4128 funcName = converter.mangleName(*name->symbol);
4129 funcOp =
4130 builder.getNamedFunction(mod, builder.getMLIRSymbolTable(), funcName);
4131 } else {
4132 Fortran::semantics::Scope &scope =
4133 semanticsContext.FindScope(routineConstruct.source);
4134 const Fortran::semantics::Scope &progUnit{GetProgramUnitContaining(scope)};
4135 const auto *subpDetails{
4136 progUnit.symbol()
4137 ? progUnit.symbol()
4138 ->detailsIf<Fortran::semantics::SubprogramDetails>()
4139 : nullptr};
4140 if (subpDetails && subpDetails->isInterface()) {
4141 funcName = converter.mangleName(*progUnit.symbol());
4142 funcOp =
4143 builder.getNamedFunction(mod, builder.getMLIRSymbolTable(), funcName);
4144 } else {
4145 funcOp = builder.getFunction();
4146 funcName = funcOp.getName();
4149 bool hasNohost = false;
4151 llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
4152 workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
4153 gangDimDeviceTypes, gangDimValues;
4155 // device_type attribute is set to `none` until a device_type clause is
4156 // encountered.
4157 llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
4158 crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
4159 builder.getContext(), mlir::acc::DeviceType::None));
4161 for (const Fortran::parser::AccClause &clause : clauses.v) {
4162 if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
4163 for (auto crtDeviceTypeAttr : crtDeviceTypes)
4164 seqDeviceTypes.push_back(crtDeviceTypeAttr);
4165 } else if (const auto *gangClause =
4166 std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
4167 if (gangClause->v) {
4168 const Fortran::parser::AccGangArgList &x = *gangClause->v;
4169 for (const Fortran::parser::AccGangArg &gangArg : x.v) {
4170 if (const auto *dim =
4171 std::get_if<Fortran::parser::AccGangArg::Dim>(&gangArg.u)) {
4172 const std::optional<int64_t> dimValue = Fortran::evaluate::ToInt64(
4173 *Fortran::semantics::GetExpr(dim->v));
4174 if (!dimValue)
4175 mlir::emitError(loc,
4176 "dim value must be a constant positive integer");
4177 mlir::Attribute gangDimAttr =
4178 builder.getIntegerAttr(builder.getI64Type(), *dimValue);
4179 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
4180 gangDimValues.push_back(gangDimAttr);
4181 gangDimDeviceTypes.push_back(crtDeviceTypeAttr);
4185 } else {
4186 for (auto crtDeviceTypeAttr : crtDeviceTypes)
4187 gangDeviceTypes.push_back(crtDeviceTypeAttr);
4189 } else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
4190 for (auto crtDeviceTypeAttr : crtDeviceTypes)
4191 vectorDeviceTypes.push_back(crtDeviceTypeAttr);
4192 } else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
4193 for (auto crtDeviceTypeAttr : crtDeviceTypes)
4194 workerDeviceTypes.push_back(crtDeviceTypeAttr);
4195 } else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u)) {
4196 hasNohost = true;
4197 } else if (const auto *bindClause =
4198 std::get_if<Fortran::parser::AccClause::Bind>(&clause.u)) {
4199 if (const auto *name =
4200 std::get_if<Fortran::parser::Name>(&bindClause->v.u)) {
4201 mlir::Attribute bindNameAttr =
4202 builder.getStringAttr(converter.mangleName(*name->symbol));
4203 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
4204 bindNames.push_back(bindNameAttr);
4205 bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
4207 } else if (const auto charExpr =
4208 std::get_if<Fortran::parser::ScalarDefaultCharExpr>(
4209 &bindClause->v.u)) {
4210 const std::optional<std::string> name =
4211 Fortran::semantics::GetConstExpr<std::string>(semanticsContext,
4212 *charExpr);
4213 if (!name)
4214 mlir::emitError(loc, "Could not retrieve the bind name");
4216 mlir::Attribute bindNameAttr = builder.getStringAttr(*name);
4217 for (auto crtDeviceTypeAttr : crtDeviceTypes) {
4218 bindNames.push_back(bindNameAttr);
4219 bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
4222 } else if (const auto *deviceTypeClause =
4223 std::get_if<Fortran::parser::AccClause::DeviceType>(
4224 &clause.u)) {
4225 crtDeviceTypes.clear();
4226 gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
4230 mlir::OpBuilder modBuilder(mod.getBodyRegion());
4231 std::stringstream routineOpName;
4232 routineOpName << accRoutinePrefix.str() << routineCounter++;
4234 for (auto routineOp : mod.getOps<mlir::acc::RoutineOp>()) {
4235 if (routineOp.getFuncName().str().compare(funcName) == 0) {
4236 // If the routine is already specified with the same clauses, just skip
4237 // the operation creation.
4238 if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
4239 gangDeviceTypes, gangDimValues,
4240 gangDimDeviceTypes, seqDeviceTypes,
4241 workerDeviceTypes, vectorDeviceTypes) &&
4242 routineOp.getNohost() == hasNohost)
4243 return;
4244 mlir::emitError(loc, "Routine already specified with different clauses");
4248 modBuilder.create<mlir::acc::RoutineOp>(
4249 loc, routineOpName.str(), funcName,
4250 bindNames.empty() ? nullptr : builder.getArrayAttr(bindNames),
4251 bindNameDeviceTypes.empty() ? nullptr
4252 : builder.getArrayAttr(bindNameDeviceTypes),
4253 workerDeviceTypes.empty() ? nullptr
4254 : builder.getArrayAttr(workerDeviceTypes),
4255 vectorDeviceTypes.empty() ? nullptr
4256 : builder.getArrayAttr(vectorDeviceTypes),
4257 seqDeviceTypes.empty() ? nullptr : builder.getArrayAttr(seqDeviceTypes),
4258 hasNohost, /*implicit=*/false,
4259 gangDeviceTypes.empty() ? nullptr : builder.getArrayAttr(gangDeviceTypes),
4260 gangDimValues.empty() ? nullptr : builder.getArrayAttr(gangDimValues),
4261 gangDimDeviceTypes.empty() ? nullptr
4262 : builder.getArrayAttr(gangDimDeviceTypes));
4264 if (funcOp)
4265 attachRoutineInfo(funcOp, builder.getSymbolRefAttr(routineOpName.str()));
4266 else
4267 // FuncOp is not lowered yet. Keep the information so the routine info
4268 // can be attached later to the funcOp.
4269 accRoutineInfos.push_back(std::make_pair(
4270 funcName, builder.getSymbolRefAttr(routineOpName.str())));
4273 void Fortran::lower::finalizeOpenACCRoutineAttachment(
4274 mlir::ModuleOp mod,
4275 Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
4276 for (auto &mapping : accRoutineInfos) {
4277 mlir::func::FuncOp funcOp =
4278 mod.lookupSymbol<mlir::func::FuncOp>(mapping.first);
4279 if (!funcOp)
4280 mlir::emitWarning(mod.getLoc(),
4281 llvm::Twine("function '") + llvm::Twine(mapping.first) +
4282 llvm::Twine("' in acc routine directive is not "
4283 "found in this translation unit."));
4284 else
4285 attachRoutineInfo(funcOp, mapping.second);
4287 accRoutineInfos.clear();
4290 static void
4291 genACC(Fortran::lower::AbstractConverter &converter,
4292 Fortran::lower::pft::Evaluation &eval,
4293 const Fortran::parser::OpenACCAtomicConstruct &atomicConstruct) {
4295 mlir::Location loc = converter.genLocation(atomicConstruct.source);
4296 Fortran::common::visit(
4297 Fortran::common::visitors{
4298 [&](const Fortran::parser::AccAtomicRead &atomicRead) {
4299 Fortran::lower::genOmpAccAtomicRead<Fortran::parser::AccAtomicRead,
4300 void>(converter, atomicRead,
4301 loc);
4303 [&](const Fortran::parser::AccAtomicWrite &atomicWrite) {
4304 Fortran::lower::genOmpAccAtomicWrite<
4305 Fortran::parser::AccAtomicWrite, void>(converter, atomicWrite,
4306 loc);
4308 [&](const Fortran::parser::AccAtomicUpdate &atomicUpdate) {
4309 Fortran::lower::genOmpAccAtomicUpdate<
4310 Fortran::parser::AccAtomicUpdate, void>(converter, atomicUpdate,
4311 loc);
4313 [&](const Fortran::parser::AccAtomicCapture &atomicCapture) {
4314 Fortran::lower::genOmpAccAtomicCapture<
4315 Fortran::parser::AccAtomicCapture, void>(converter,
4316 atomicCapture, loc);
4319 atomicConstruct.u);
4322 static void
4323 genACC(Fortran::lower::AbstractConverter &converter,
4324 Fortran::semantics::SemanticsContext &semanticsContext,
4325 const Fortran::parser::OpenACCCacheConstruct &cacheConstruct) {
4326 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
4327 auto loopOp = builder.getRegion().getParentOfType<mlir::acc::LoopOp>();
4328 auto crtPos = builder.saveInsertionPoint();
4329 if (loopOp) {
4330 builder.setInsertionPoint(loopOp);
4331 Fortran::lower::StatementContext stmtCtx;
4332 llvm::SmallVector<mlir::Value> cacheOperands;
4333 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
4334 std::get<Fortran::parser::AccObjectListWithModifier>(cacheConstruct.t);
4335 const auto &accObjectList =
4336 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
4337 const auto &modifier =
4338 std::get<std::optional<Fortran::parser::AccDataModifier>>(
4339 listWithModifier.t);
4341 mlir::acc::DataClause dataClause = mlir::acc::DataClause::acc_cache;
4342 if (modifier &&
4343 (*modifier).v == Fortran::parser::AccDataModifier::Modifier::ReadOnly)
4344 dataClause = mlir::acc::DataClause::acc_cache_readonly;
4345 genDataOperandOperations<mlir::acc::CacheOp>(
4346 accObjectList, converter, semanticsContext, stmtCtx, cacheOperands,
4347 dataClause,
4348 /*structured=*/true, /*implicit=*/false,
4349 /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{},
4350 /*setDeclareAttr*/ false);
4351 loopOp.getCacheOperandsMutable().append(cacheOperands);
4352 } else {
4353 llvm::report_fatal_error(
4354 "could not find loop to attach OpenACC cache information.");
4356 builder.restoreInsertionPoint(crtPos);
4359 mlir::Value Fortran::lower::genOpenACCConstruct(
4360 Fortran::lower::AbstractConverter &converter,
4361 Fortran::semantics::SemanticsContext &semanticsContext,
4362 Fortran::lower::pft::Evaluation &eval,
4363 const Fortran::parser::OpenACCConstruct &accConstruct) {
4365 mlir::Value exitCond;
4366 Fortran::common::visit(
4367 common::visitors{
4368 [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
4369 genACC(converter, semanticsContext, eval, blockConstruct);
4371 [&](const Fortran::parser::OpenACCCombinedConstruct
4372 &combinedConstruct) {
4373 genACC(converter, semanticsContext, eval, combinedConstruct);
4375 [&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
4376 exitCond = genACC(converter, semanticsContext, eval, loopConstruct);
4378 [&](const Fortran::parser::OpenACCStandaloneConstruct
4379 &standaloneConstruct) {
4380 genACC(converter, semanticsContext, standaloneConstruct);
4382 [&](const Fortran::parser::OpenACCCacheConstruct &cacheConstruct) {
4383 genACC(converter, semanticsContext, cacheConstruct);
4385 [&](const Fortran::parser::OpenACCWaitConstruct &waitConstruct) {
4386 genACC(converter, waitConstruct);
4388 [&](const Fortran::parser::OpenACCAtomicConstruct &atomicConstruct) {
4389 genACC(converter, eval, atomicConstruct);
4391 [&](const Fortran::parser::OpenACCEndConstruct &) {
4392 // No op
4395 accConstruct.u);
4396 return exitCond;
4399 void Fortran::lower::genOpenACCDeclarativeConstruct(
4400 Fortran::lower::AbstractConverter &converter,
4401 Fortran::semantics::SemanticsContext &semanticsContext,
4402 Fortran::lower::StatementContext &openAccCtx,
4403 const Fortran::parser::OpenACCDeclarativeConstruct &accDeclConstruct,
4404 Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
4406 Fortran::common::visit(
4407 common::visitors{
4408 [&](const Fortran::parser::OpenACCStandaloneDeclarativeConstruct
4409 &standaloneDeclarativeConstruct) {
4410 genACC(converter, semanticsContext, openAccCtx,
4411 standaloneDeclarativeConstruct);
4413 [&](const Fortran::parser::OpenACCRoutineConstruct
4414 &routineConstruct) {
4415 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
4416 mlir::ModuleOp mod = builder.getModule();
4417 Fortran::lower::genOpenACCRoutineConstruct(
4418 converter, semanticsContext, mod, routineConstruct,
4419 accRoutineInfos);
4422 accDeclConstruct.u);
4425 void Fortran::lower::attachDeclarePostAllocAction(
4426 AbstractConverter &converter, fir::FirOpBuilder &builder,
4427 const Fortran::semantics::Symbol &sym) {
4428 std::stringstream fctName;
4429 fctName << converter.mangleName(sym) << declarePostAllocSuffix.str();
4430 mlir::Operation *op = &builder.getInsertionBlock()->back();
4432 if (auto resOp = mlir::dyn_cast<fir::ResultOp>(*op)) {
4433 assert(resOp.getOperands().size() == 0 &&
4434 "expect only fir.result op with no operand");
4435 op = op->getPrevNode();
4437 assert(op && "expect operation to attach the post allocation action");
4439 if (op->hasAttr(mlir::acc::getDeclareActionAttrName())) {
4440 auto attr = op->getAttrOfType<mlir::acc::DeclareActionAttr>(
4441 mlir::acc::getDeclareActionAttrName());
4442 op->setAttr(mlir::acc::getDeclareActionAttrName(),
4443 mlir::acc::DeclareActionAttr::get(
4444 builder.getContext(), attr.getPreAlloc(),
4445 /*postAlloc=*/builder.getSymbolRefAttr(fctName.str()),
4446 attr.getPreDealloc(), attr.getPostDealloc()));
4447 } else {
4448 op->setAttr(mlir::acc::getDeclareActionAttrName(),
4449 mlir::acc::DeclareActionAttr::get(
4450 builder.getContext(),
4451 /*preAlloc=*/{},
4452 /*postAlloc=*/builder.getSymbolRefAttr(fctName.str()),
4453 /*preDealloc=*/{}, /*postDealloc=*/{}));
4457 void Fortran::lower::attachDeclarePreDeallocAction(
4458 AbstractConverter &converter, fir::FirOpBuilder &builder,
4459 mlir::Value beginOpValue, const Fortran::semantics::Symbol &sym) {
4460 if (!sym.test(Fortran::semantics::Symbol::Flag::AccCreate) &&
4461 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyIn) &&
4462 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyInReadOnly) &&
4463 !sym.test(Fortran::semantics::Symbol::Flag::AccCopy) &&
4464 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyOut) &&
4465 !sym.test(Fortran::semantics::Symbol::Flag::AccDeviceResident))
4466 return;
4468 std::stringstream fctName;
4469 fctName << converter.mangleName(sym) << declarePreDeallocSuffix.str();
4471 auto *op = beginOpValue.getDefiningOp();
4472 if (op->hasAttr(mlir::acc::getDeclareActionAttrName())) {
4473 auto attr = op->getAttrOfType<mlir::acc::DeclareActionAttr>(
4474 mlir::acc::getDeclareActionAttrName());
4475 op->setAttr(mlir::acc::getDeclareActionAttrName(),
4476 mlir::acc::DeclareActionAttr::get(
4477 builder.getContext(), attr.getPreAlloc(),
4478 attr.getPostAlloc(),
4479 /*preDealloc=*/builder.getSymbolRefAttr(fctName.str()),
4480 attr.getPostDealloc()));
4481 } else {
4482 op->setAttr(mlir::acc::getDeclareActionAttrName(),
4483 mlir::acc::DeclareActionAttr::get(
4484 builder.getContext(),
4485 /*preAlloc=*/{}, /*postAlloc=*/{},
4486 /*preDealloc=*/builder.getSymbolRefAttr(fctName.str()),
4487 /*postDealloc=*/{}));
4491 void Fortran::lower::attachDeclarePostDeallocAction(
4492 AbstractConverter &converter, fir::FirOpBuilder &builder,
4493 const Fortran::semantics::Symbol &sym) {
4494 if (!sym.test(Fortran::semantics::Symbol::Flag::AccCreate) &&
4495 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyIn) &&
4496 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyInReadOnly) &&
4497 !sym.test(Fortran::semantics::Symbol::Flag::AccCopy) &&
4498 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyOut) &&
4499 !sym.test(Fortran::semantics::Symbol::Flag::AccDeviceResident))
4500 return;
4502 std::stringstream fctName;
4503 fctName << converter.mangleName(sym) << declarePostDeallocSuffix.str();
4504 mlir::Operation *op = &builder.getInsertionBlock()->back();
4505 if (auto resOp = mlir::dyn_cast<fir::ResultOp>(*op)) {
4506 assert(resOp.getOperands().size() == 0 &&
4507 "expect only fir.result op with no operand");
4508 op = op->getPrevNode();
4510 assert(op && "expect operation to attach the post deallocation action");
4511 if (op->hasAttr(mlir::acc::getDeclareActionAttrName())) {
4512 auto attr = op->getAttrOfType<mlir::acc::DeclareActionAttr>(
4513 mlir::acc::getDeclareActionAttrName());
4514 op->setAttr(mlir::acc::getDeclareActionAttrName(),
4515 mlir::acc::DeclareActionAttr::get(
4516 builder.getContext(), attr.getPreAlloc(),
4517 attr.getPostAlloc(), attr.getPreDealloc(),
4518 /*postDealloc=*/builder.getSymbolRefAttr(fctName.str())));
4519 } else {
4520 op->setAttr(mlir::acc::getDeclareActionAttrName(),
4521 mlir::acc::DeclareActionAttr::get(
4522 builder.getContext(),
4523 /*preAlloc=*/{}, /*postAlloc=*/{}, /*preDealloc=*/{},
4524 /*postDealloc=*/builder.getSymbolRefAttr(fctName.str())));
4528 void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder,
4529 mlir::Operation *op,
4530 mlir::Location loc) {
4531 if (mlir::isa<mlir::acc::ParallelOp, mlir::acc::LoopOp>(op))
4532 builder.create<mlir::acc::YieldOp>(loc);
4533 else
4534 builder.create<mlir::acc::TerminatorOp>(loc);
4537 bool Fortran::lower::isInOpenACCLoop(fir::FirOpBuilder &builder) {
4538 if (builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
4539 return true;
4540 return false;
4543 void Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(
4544 fir::FirOpBuilder &builder) {
4545 if (auto loopOp =
4546 builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
4547 builder.setInsertionPointAfter(loopOp);
4550 void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
4551 mlir::Location loc) {
4552 mlir::Value yieldValue =
4553 builder.createIntegerConstant(loc, builder.getI1Type(), 1);
4554 builder.create<mlir::acc::YieldOp>(loc, yieldValue);
4557 int64_t Fortran::lower::getCollapseValue(
4558 const Fortran::parser::AccClauseList &clauseList) {
4559 for (const Fortran::parser::AccClause &clause : clauseList.v) {
4560 if (const auto *collapseClause =
4561 std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
4562 const parser::AccCollapseArg &arg = collapseClause->v;
4563 const auto &collapseValue{std::get<parser::ScalarIntConstantExpr>(arg.t)};
4564 return *Fortran::semantics::GetIntValue(collapseValue);
4567 return 1;