[Flang][RISCV] Set vscale_range based off zvl*b (#77277)
[llvm-project.git] / flang / lib / Lower / OpenACC.cpp
blobdb9ed72bc8725703e42b1d3467e8aee28a4ad0fa
1 //===-- OpenACC.cpp -- OpenACC directive lowering -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
11 //===----------------------------------------------------------------------===//
13 #include "flang/Lower/OpenACC.h"
14 #include "DirectivesCommon.h"
15 #include "flang/Common/idioms.h"
16 #include "flang/Lower/Bridge.h"
17 #include "flang/Lower/ConvertType.h"
18 #include "flang/Lower/Mangler.h"
19 #include "flang/Lower/PFTBuilder.h"
20 #include "flang/Lower/StatementContext.h"
21 #include "flang/Lower/Support/Utils.h"
22 #include "flang/Optimizer/Builder/BoxValue.h"
23 #include "flang/Optimizer/Builder/Complex.h"
24 #include "flang/Optimizer/Builder/FIRBuilder.h"
25 #include "flang/Optimizer/Builder/HLFIRTools.h"
26 #include "flang/Optimizer/Builder/IntrinsicCall.h"
27 #include "flang/Optimizer/Builder/Todo.h"
28 #include "flang/Parser/parse-tree-visitor.h"
29 #include "flang/Parser/parse-tree.h"
30 #include "flang/Semantics/expression.h"
31 #include "flang/Semantics/scope.h"
32 #include "flang/Semantics/tools.h"
33 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
34 #include "llvm/Frontend/OpenACC/ACC.h.inc"
36 // Special value for * passed in device_type or gang clauses.
37 static constexpr std::int64_t starCst = -1;
39 static unsigned routineCounter = 0;
40 static constexpr llvm::StringRef accRoutinePrefix = "acc_routine_";
41 static constexpr llvm::StringRef accPrivateInitName = "acc.private.init";
42 static constexpr llvm::StringRef accReductionInitName = "acc.reduction.init";
43 static constexpr llvm::StringRef accFirDescriptorPostfix = "_desc";
45 static mlir::Location
46 genOperandLocation(Fortran::lower::AbstractConverter &converter,
47 const Fortran::parser::AccObject &accObject) {
48 mlir::Location loc = converter.genUnknownLocation();
49 std::visit(Fortran::common::visitors{
50 [&](const Fortran::parser::Designator &designator) {
51 loc = converter.genLocation(designator.source);
53 [&](const Fortran::parser::Name &name) {
54 loc = converter.genLocation(name.source);
55 }},
56 accObject.u);
57 return loc;
60 template <typename Op>
61 static Op createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
62 mlir::Value baseAddr, std::stringstream &name,
63 mlir::SmallVector<mlir::Value> bounds,
64 bool structured, bool implicit,
65 mlir::acc::DataClause dataClause, mlir::Type retTy,
66 mlir::Value isPresent = {}) {
67 mlir::Value varPtrPtr;
68 if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
69 if (isPresent) {
70 baseAddr =
71 builder
72 .genIfOp(loc, {boxTy.getEleTy()}, isPresent,
73 /*withElseRegion=*/true)
74 .genThen([&]() {
75 mlir::Value boxAddr =
76 builder.create<fir::BoxAddrOp>(loc, baseAddr);
77 builder.create<fir::ResultOp>(loc, mlir::ValueRange{boxAddr});
79 .genElse([&] {
80 mlir::Value absent =
81 builder.create<fir::AbsentOp>(loc, boxTy.getEleTy());
82 builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
84 .getResults()[0];
85 } else {
86 baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
88 retTy = baseAddr.getType();
91 Op op = builder.create<Op>(loc, retTy, baseAddr);
92 op.setNameAttr(builder.getStringAttr(name.str()));
93 op.setStructured(structured);
94 op.setImplicit(implicit);
95 op.setDataClause(dataClause);
97 unsigned insPos = 1;
98 if (varPtrPtr)
99 op->insertOperands(insPos++, varPtrPtr);
100 if (bounds.size() > 0)
101 op->insertOperands(insPos, bounds);
102 op->setAttr(Op::getOperandSegmentSizeAttr(),
103 builder.getDenseI32ArrayAttr(
104 {1, varPtrPtr ? 1 : 0, static_cast<int32_t>(bounds.size())}));
105 return op;
108 static void addDeclareAttr(fir::FirOpBuilder &builder, mlir::Operation *op,
109 mlir::acc::DataClause clause) {
110 if (!op)
111 return;
112 op->setAttr(mlir::acc::getDeclareAttrName(),
113 mlir::acc::DeclareAttr::get(builder.getContext(),
114 mlir::acc::DataClauseAttr::get(
115 builder.getContext(), clause)));
118 static mlir::func::FuncOp
119 createDeclareFunc(mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder,
120 mlir::Location loc, llvm::StringRef funcName,
121 llvm::SmallVector<mlir::Type> argsTy = {},
122 llvm::SmallVector<mlir::Location> locs = {}) {
123 auto funcTy = mlir::FunctionType::get(modBuilder.getContext(), argsTy, {});
124 auto funcOp = modBuilder.create<mlir::func::FuncOp>(loc, funcName, funcTy);
125 funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
126 builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
127 locs);
128 builder.setInsertionPointToEnd(&funcOp.getRegion().back());
129 builder.create<mlir::func::ReturnOp>(loc);
130 builder.setInsertionPointToStart(&funcOp.getRegion().back());
131 return funcOp;
134 template <typename Op>
135 static Op
136 createSimpleOp(fir::FirOpBuilder &builder, mlir::Location loc,
137 const llvm::SmallVectorImpl<mlir::Value> &operands,
138 const llvm::SmallVectorImpl<int32_t> &operandSegments) {
139 llvm::ArrayRef<mlir::Type> argTy;
140 Op op = builder.create<Op>(loc, argTy, operands);
141 op->setAttr(Op::getOperandSegmentSizeAttr(),
142 builder.getDenseI32ArrayAttr(operandSegments));
143 return op;
146 template <typename EntryOp>
147 static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
148 fir::FirOpBuilder &builder,
149 mlir::Location loc, mlir::Type descTy,
150 llvm::StringRef funcNamePrefix,
151 std::stringstream &asFortran,
152 mlir::acc::DataClause clause) {
153 auto crtInsPt = builder.saveInsertionPoint();
154 std::stringstream registerFuncName;
155 registerFuncName << funcNamePrefix.str()
156 << Fortran::lower::declarePostAllocSuffix.str();
158 if (!mlir::isa<fir::ReferenceType>(descTy))
159 descTy = fir::ReferenceType::get(descTy);
160 auto registerFuncOp = createDeclareFunc(
161 modBuilder, builder, loc, registerFuncName.str(), {descTy}, {loc});
163 llvm::SmallVector<mlir::Value> bounds;
164 std::stringstream asFortranDesc;
165 asFortranDesc << asFortran.str() << accFirDescriptorPostfix.str();
167 // Updating descriptor must occur before the mapping of the data so that
168 // attached data pointer is not overwritten.
169 mlir::acc::UpdateDeviceOp updateDeviceOp =
170 createDataEntryOp<mlir::acc::UpdateDeviceOp>(
171 builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
172 /*structured=*/false, /*implicit=*/true,
173 mlir::acc::DataClause::acc_update_device, descTy);
174 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
175 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
176 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
178 mlir::Value desc =
179 builder.create<fir::LoadOp>(loc, registerFuncOp.getArgument(0));
180 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, desc);
181 addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
182 EntryOp entryOp = createDataEntryOp<EntryOp>(
183 builder, loc, boxAddrOp.getResult(), asFortran, bounds,
184 /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType());
185 builder.create<mlir::acc::DeclareEnterOp>(
186 loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
187 mlir::ValueRange(entryOp.getAccPtr()));
189 modBuilder.setInsertionPointAfter(registerFuncOp);
190 builder.restoreInsertionPoint(crtInsPt);
193 template <typename ExitOp>
194 static void createDeclareDeallocFuncWithArg(
195 mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder, mlir::Location loc,
196 mlir::Type descTy, llvm::StringRef funcNamePrefix,
197 std::stringstream &asFortran, mlir::acc::DataClause clause) {
198 auto crtInsPt = builder.saveInsertionPoint();
199 // Generate the pre dealloc function.
200 std::stringstream preDeallocFuncName;
201 preDeallocFuncName << funcNamePrefix.str()
202 << Fortran::lower::declarePreDeallocSuffix.str();
203 if (!mlir::isa<fir::ReferenceType>(descTy))
204 descTy = fir::ReferenceType::get(descTy);
205 auto preDeallocOp = createDeclareFunc(
206 modBuilder, builder, loc, preDeallocFuncName.str(), {descTy}, {loc});
207 mlir::Value loadOp =
208 builder.create<fir::LoadOp>(loc, preDeallocOp.getArgument(0));
209 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
210 addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
212 llvm::SmallVector<mlir::Value> bounds;
213 mlir::acc::GetDevicePtrOp entryOp =
214 createDataEntryOp<mlir::acc::GetDevicePtrOp>(
215 builder, loc, boxAddrOp.getResult(), asFortran, bounds,
216 /*structured=*/false, /*implicit=*/false, clause,
217 boxAddrOp.getType());
218 builder.create<mlir::acc::DeclareExitOp>(
219 loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr()));
221 if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
222 std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
223 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
224 entryOp.getVarPtr(), entryOp.getBounds(),
225 entryOp.getDataClause(),
226 /*structured=*/false, /*implicit=*/false,
227 builder.getStringAttr(*entryOp.getName()));
228 else
229 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
230 entryOp.getBounds(), entryOp.getDataClause(),
231 /*structured=*/false, /*implicit=*/false,
232 builder.getStringAttr(*entryOp.getName()));
234 // Generate the post dealloc function.
235 modBuilder.setInsertionPointAfter(preDeallocOp);
236 std::stringstream postDeallocFuncName;
237 postDeallocFuncName << funcNamePrefix.str()
238 << Fortran::lower::declarePostDeallocSuffix.str();
239 auto postDeallocOp = createDeclareFunc(
240 modBuilder, builder, loc, postDeallocFuncName.str(), {descTy}, {loc});
241 loadOp = builder.create<fir::LoadOp>(loc, postDeallocOp.getArgument(0));
242 asFortran << accFirDescriptorPostfix.str();
243 mlir::acc::UpdateDeviceOp updateDeviceOp =
244 createDataEntryOp<mlir::acc::UpdateDeviceOp>(
245 builder, loc, loadOp, asFortran, bounds,
246 /*structured=*/false, /*implicit=*/true,
247 mlir::acc::DataClause::acc_update_device, loadOp.getType());
248 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
249 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
250 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
251 modBuilder.setInsertionPointAfter(postDeallocOp);
252 builder.restoreInsertionPoint(crtInsPt);
255 Fortran::semantics::Symbol &
256 getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) {
257 if (const auto *designator =
258 std::get_if<Fortran::parser::Designator>(&accObject.u)) {
259 if (const auto *name =
260 Fortran::semantics::getDesignatorNameIfDataRef(*designator))
261 return *name->symbol;
262 if (const auto *arrayElement =
263 Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
264 *designator)) {
265 const Fortran::parser::Name &name =
266 Fortran::parser::GetLastName(arrayElement->base);
267 return *name.symbol;
269 } else if (const auto *name =
270 std::get_if<Fortran::parser::Name>(&accObject.u)) {
271 return *name->symbol;
273 llvm::report_fatal_error("Could not find symbol");
276 template <typename Op>
277 static void
278 genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
279 Fortran::lower::AbstractConverter &converter,
280 Fortran::semantics::SemanticsContext &semanticsContext,
281 Fortran::lower::StatementContext &stmtCtx,
282 llvm::SmallVectorImpl<mlir::Value> &dataOperands,
283 mlir::acc::DataClause dataClause, bool structured,
284 bool implicit, bool setDeclareAttr = false) {
285 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
286 for (const auto &accObject : objectList.v) {
287 llvm::SmallVector<mlir::Value> bounds;
288 std::stringstream asFortran;
289 mlir::Location operandLocation = genOperandLocation(converter, accObject);
290 Fortran::lower::AddrAndBoundsInfo info =
291 Fortran::lower::gatherDataOperandAddrAndBounds<
292 Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
293 mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
294 stmtCtx, accObject, operandLocation,
295 asFortran, bounds,
296 /*treatIndexAsSection=*/true);
298 Op op = createDataEntryOp<Op>(
299 builder, operandLocation, info.addr, asFortran, bounds, structured,
300 implicit, dataClause, info.addr.getType(), info.isPresent);
301 dataOperands.push_back(op.getAccPtr());
305 template <typename EntryOp, typename ExitOp>
306 static void genDeclareDataOperandOperations(
307 const Fortran::parser::AccObjectList &objectList,
308 Fortran::lower::AbstractConverter &converter,
309 Fortran::semantics::SemanticsContext &semanticsContext,
310 Fortran::lower::StatementContext &stmtCtx,
311 llvm::SmallVectorImpl<mlir::Value> &dataOperands,
312 mlir::acc::DataClause dataClause, bool structured, bool implicit) {
313 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
314 for (const auto &accObject : objectList.v) {
315 llvm::SmallVector<mlir::Value> bounds;
316 std::stringstream asFortran;
317 mlir::Location operandLocation = genOperandLocation(converter, accObject);
318 Fortran::lower::AddrAndBoundsInfo info =
319 Fortran::lower::gatherDataOperandAddrAndBounds<
320 Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
321 mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
322 stmtCtx, accObject, operandLocation,
323 asFortran, bounds);
324 EntryOp op = createDataEntryOp<EntryOp>(
325 builder, operandLocation, info.addr, asFortran, bounds, structured,
326 implicit, dataClause, info.addr.getType());
327 dataOperands.push_back(op.getAccPtr());
328 addDeclareAttr(builder, op.getVarPtr().getDefiningOp(), dataClause);
329 if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(info.addr.getType()))) {
330 mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion());
331 modBuilder.setInsertionPointAfter(builder.getFunction());
332 std::string prefix =
333 converter.mangleName(getSymbolFromAccObject(accObject));
334 createDeclareAllocFuncWithArg<EntryOp>(
335 modBuilder, builder, operandLocation, info.addr.getType(), prefix,
336 asFortran, dataClause);
337 if constexpr (!std::is_same_v<EntryOp, ExitOp>)
338 createDeclareDeallocFuncWithArg<ExitOp>(
339 modBuilder, builder, operandLocation, info.addr.getType(), prefix,
340 asFortran, dataClause);
345 template <typename EntryOp, typename ExitOp, typename Clause>
346 static void genDeclareDataOperandOperationsWithModifier(
347 const Clause *x, Fortran::lower::AbstractConverter &converter,
348 Fortran::semantics::SemanticsContext &semanticsContext,
349 Fortran::lower::StatementContext &stmtCtx,
350 Fortran::parser::AccDataModifier::Modifier mod,
351 llvm::SmallVectorImpl<mlir::Value> &dataClauseOperands,
352 const mlir::acc::DataClause clause,
353 const mlir::acc::DataClause clauseWithModifier) {
354 const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
355 const auto &accObjectList =
356 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
357 const auto &modifier =
358 std::get<std::optional<Fortran::parser::AccDataModifier>>(
359 listWithModifier.t);
360 mlir::acc::DataClause dataClause =
361 (modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
362 genDeclareDataOperandOperations<EntryOp, ExitOp>(
363 accObjectList, converter, semanticsContext, stmtCtx, dataClauseOperands,
364 dataClause,
365 /*structured=*/true, /*implicit=*/false);
368 template <typename EntryOp, typename ExitOp>
369 static void genDataExitOperations(fir::FirOpBuilder &builder,
370 llvm::SmallVector<mlir::Value> operands,
371 bool structured) {
372 for (mlir::Value operand : operands) {
373 auto entryOp = mlir::dyn_cast_or_null<EntryOp>(operand.getDefiningOp());
374 assert(entryOp && "data entry op expected");
375 if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
376 std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
377 builder.create<ExitOp>(
378 entryOp.getLoc(), entryOp.getAccPtr(), entryOp.getVarPtr(),
379 entryOp.getBounds(), entryOp.getDataClause(), structured,
380 entryOp.getImplicit(), builder.getStringAttr(*entryOp.getName()));
381 else
382 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
383 entryOp.getBounds(), entryOp.getDataClause(),
384 structured, entryOp.getImplicit(),
385 builder.getStringAttr(*entryOp.getName()));
389 fir::ShapeOp genShapeOp(mlir::OpBuilder &builder, fir::SequenceType seqTy,
390 mlir::Location loc) {
391 llvm::SmallVector<mlir::Value> extents;
392 mlir::Type idxTy = builder.getIndexType();
393 for (auto extent : seqTy.getShape())
394 extents.push_back(builder.create<mlir::arith::ConstantOp>(
395 loc, idxTy, builder.getIntegerAttr(idxTy, extent)));
396 return builder.create<fir::ShapeOp>(loc, extents);
399 /// Return the nested sequence type if any.
400 static mlir::Type extractSequenceType(mlir::Type ty) {
401 if (mlir::isa<fir::SequenceType>(ty))
402 return ty;
403 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
404 return extractSequenceType(boxTy.getEleTy());
405 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
406 return extractSequenceType(heapTy.getEleTy());
407 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
408 return extractSequenceType(ptrTy.getEleTy());
409 return mlir::Type{};
412 template <typename RecipeOp>
413 static void genPrivateLikeInitRegion(mlir::OpBuilder &builder, RecipeOp recipe,
414 mlir::Type ty, mlir::Location loc) {
415 mlir::Value retVal = recipe.getInitRegion().front().getArgument(0);
416 if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(ty)) {
417 if (fir::isa_trivial(refTy.getEleTy())) {
418 auto alloca = builder.create<fir::AllocaOp>(loc, refTy.getEleTy());
419 auto declareOp = builder.create<hlfir::DeclareOp>(
420 loc, alloca, accPrivateInitName, /*shape=*/nullptr,
421 llvm::ArrayRef<mlir::Value>{}, fir::FortranVariableFlagsAttr{});
422 retVal = declareOp.getBase();
423 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
424 refTy.getEleTy())) {
425 if (fir::isa_trivial(seqTy.getEleTy())) {
426 mlir::Value shape;
427 llvm::SmallVector<mlir::Value> extents;
428 if (seqTy.hasDynamicExtents()) {
429 // Extents are passed as block arguments. First argument is the
430 // original value.
431 for (unsigned i = 1; i < recipe.getInitRegion().getArguments().size();
432 ++i)
433 extents.push_back(recipe.getInitRegion().getArgument(i));
434 shape = builder.create<fir::ShapeOp>(loc, extents);
435 } else {
436 shape = genShapeOp(builder, seqTy, loc);
438 auto alloca = builder.create<fir::AllocaOp>(
439 loc, seqTy, /*typeparams=*/mlir::ValueRange{}, extents);
440 auto declareOp = builder.create<hlfir::DeclareOp>(
441 loc, alloca, accPrivateInitName, shape,
442 llvm::ArrayRef<mlir::Value>{}, fir::FortranVariableFlagsAttr{});
443 retVal = declareOp.getBase();
446 } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
447 mlir::Type innerTy = extractSequenceType(boxTy);
448 if (!innerTy)
449 TODO(loc, "Unsupported boxed type in OpenACC privatization");
450 fir::FirOpBuilder firBuilder{builder, recipe.getOperation()};
451 hlfir::Entity source = hlfir::Entity{retVal};
452 auto [temp, cleanup] = hlfir::createTempFromMold(loc, firBuilder, source);
453 retVal = temp;
455 builder.create<mlir::acc::YieldOp>(loc, retVal);
458 mlir::acc::PrivateRecipeOp
459 Fortran::lower::createOrGetPrivateRecipe(mlir::OpBuilder &builder,
460 llvm::StringRef recipeName,
461 mlir::Location loc, mlir::Type ty) {
462 mlir::ModuleOp mod =
463 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
464 if (auto recipe = mod.lookupSymbol<mlir::acc::PrivateRecipeOp>(recipeName))
465 return recipe;
467 auto crtPos = builder.saveInsertionPoint();
468 mlir::OpBuilder modBuilder(mod.getBodyRegion());
469 auto recipe =
470 modBuilder.create<mlir::acc::PrivateRecipeOp>(loc, recipeName, ty);
471 llvm::SmallVector<mlir::Type> argsTy{ty};
472 llvm::SmallVector<mlir::Location> argsLoc{loc};
473 if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(ty)) {
474 if (auto seqTy =
475 mlir::dyn_cast_or_null<fir::SequenceType>(refTy.getEleTy())) {
476 if (seqTy.hasDynamicExtents()) {
477 mlir::Type idxTy = builder.getIndexType();
478 for (unsigned i = 0; i < seqTy.getDimension(); ++i) {
479 argsTy.push_back(idxTy);
480 argsLoc.push_back(loc);
485 builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(),
486 argsTy, argsLoc);
487 builder.setInsertionPointToEnd(&recipe.getInitRegion().back());
488 genPrivateLikeInitRegion<mlir::acc::PrivateRecipeOp>(builder, recipe, ty,
489 loc);
490 builder.restoreInsertionPoint(crtPos);
491 return recipe;
494 /// Check if the DataBoundsOp is a constant bound (lb and ub are constants or
495 /// extent is a constant).
496 bool isConstantBound(mlir::acc::DataBoundsOp &op) {
497 if (op.getLowerbound() && fir::getIntIfConstant(op.getLowerbound()) &&
498 op.getUpperbound() && fir::getIntIfConstant(op.getUpperbound()))
499 return true;
500 if (op.getExtent() && fir::getIntIfConstant(op.getExtent()))
501 return true;
502 return false;
505 /// Return true iff all the bounds are expressed with constant values.
506 bool areAllBoundConstant(const llvm::SmallVector<mlir::Value> &bounds) {
507 for (auto bound : bounds) {
508 auto dataBound =
509 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
510 assert(dataBound && "Must be DataBoundOp operation");
511 if (!isConstantBound(dataBound))
512 return false;
514 return true;
517 static llvm::SmallVector<mlir::Value>
518 genConstantBounds(fir::FirOpBuilder &builder, mlir::Location loc,
519 mlir::acc::DataBoundsOp &dataBound) {
520 mlir::Type idxTy = builder.getIndexType();
521 mlir::Value lb, ub, step;
522 if (dataBound.getLowerbound() &&
523 fir::getIntIfConstant(dataBound.getLowerbound()) &&
524 dataBound.getUpperbound() &&
525 fir::getIntIfConstant(dataBound.getUpperbound())) {
526 lb = builder.createIntegerConstant(
527 loc, idxTy, *fir::getIntIfConstant(dataBound.getLowerbound()));
528 ub = builder.createIntegerConstant(
529 loc, idxTy, *fir::getIntIfConstant(dataBound.getUpperbound()));
530 step = builder.createIntegerConstant(loc, idxTy, 1);
531 } else if (dataBound.getExtent()) {
532 lb = builder.createIntegerConstant(loc, idxTy, 0);
533 ub = builder.createIntegerConstant(
534 loc, idxTy, *fir::getIntIfConstant(dataBound.getExtent()) - 1);
535 step = builder.createIntegerConstant(loc, idxTy, 1);
536 } else {
537 llvm::report_fatal_error("Expect constant lb/ub or extent");
539 return {lb, ub, step};
542 static fir::ShapeOp genShapeFromBoundsOrArgs(
543 mlir::Location loc, fir::FirOpBuilder &builder, fir::SequenceType seqTy,
544 const llvm::SmallVector<mlir::Value> &bounds, mlir::ValueRange arguments) {
545 llvm::SmallVector<mlir::Value> args;
546 if (areAllBoundConstant(bounds)) {
547 for (auto bound : llvm::reverse(bounds)) {
548 auto dataBound =
549 mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
550 args.append(genConstantBounds(builder, loc, dataBound));
552 } else {
553 assert(((arguments.size() - 2) / 3 == seqTy.getDimension()) &&
554 "Expect 3 block arguments per dimension");
555 for (auto arg : arguments.drop_front(2))
556 args.push_back(arg);
559 assert(args.size() % 3 == 0 && "Triplets must be a multiple of 3");
560 llvm::SmallVector<mlir::Value> extents;
561 mlir::Type idxTy = builder.getIndexType();
562 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
563 mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
564 for (unsigned i = 0; i < args.size(); i += 3) {
565 mlir::Value s1 =
566 builder.create<mlir::arith::SubIOp>(loc, args[i + 1], args[0]);
567 mlir::Value s2 = builder.create<mlir::arith::AddIOp>(loc, s1, one);
568 mlir::Value s3 = builder.create<mlir::arith::DivSIOp>(loc, s2, args[i + 2]);
569 mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
570 loc, mlir::arith::CmpIPredicate::sgt, s3, zero);
571 mlir::Value ext = builder.create<mlir::arith::SelectOp>(loc, cmp, s3, zero);
572 extents.push_back(ext);
574 return builder.create<fir::ShapeOp>(loc, extents);
577 static hlfir::DesignateOp::Subscripts
578 getSubscriptsFromArgs(mlir::ValueRange args) {
579 hlfir::DesignateOp::Subscripts triplets;
580 for (unsigned i = 2; i < args.size(); i += 3)
581 triplets.emplace_back(
582 hlfir::DesignateOp::Triplet{args[i], args[i + 1], args[i + 2]});
583 return triplets;
586 static hlfir::Entity genDesignateWithTriplets(
587 fir::FirOpBuilder &builder, mlir::Location loc, hlfir::Entity &entity,
588 hlfir::DesignateOp::Subscripts &triplets, mlir::Value shape) {
589 llvm::SmallVector<mlir::Value> lenParams;
590 hlfir::genLengthParameters(loc, builder, entity, lenParams);
591 auto designate = builder.create<hlfir::DesignateOp>(
592 loc, entity.getBase().getType(), entity, /*component=*/"",
593 /*componentShape=*/mlir::Value{}, triplets,
594 /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt, shape,
595 lenParams);
596 return hlfir::Entity{designate.getResult()};
599 mlir::acc::FirstprivateRecipeOp Fortran::lower::createOrGetFirstprivateRecipe(
600 mlir::OpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc,
601 mlir::Type ty, llvm::SmallVector<mlir::Value> &bounds) {
602 mlir::ModuleOp mod =
603 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
604 if (auto recipe =
605 mod.lookupSymbol<mlir::acc::FirstprivateRecipeOp>(recipeName))
606 return recipe;
608 auto crtPos = builder.saveInsertionPoint();
609 mlir::OpBuilder modBuilder(mod.getBodyRegion());
610 auto recipe =
611 modBuilder.create<mlir::acc::FirstprivateRecipeOp>(loc, recipeName, ty);
612 llvm::SmallVector<mlir::Type> initArgsTy{ty};
613 llvm::SmallVector<mlir::Location> initArgsLoc{loc};
614 auto refTy = fir::unwrapRefType(ty);
615 if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(refTy)) {
616 if (seqTy.hasDynamicExtents()) {
617 mlir::Type idxTy = builder.getIndexType();
618 for (unsigned i = 0; i < seqTy.getDimension(); ++i) {
619 initArgsTy.push_back(idxTy);
620 initArgsLoc.push_back(loc);
624 builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(),
625 initArgsTy, initArgsLoc);
626 builder.setInsertionPointToEnd(&recipe.getInitRegion().back());
627 genPrivateLikeInitRegion<mlir::acc::FirstprivateRecipeOp>(builder, recipe, ty,
628 loc);
630 bool allConstantBound = areAllBoundConstant(bounds);
631 llvm::SmallVector<mlir::Type> argsTy{ty, ty};
632 llvm::SmallVector<mlir::Location> argsLoc{loc, loc};
633 if (!allConstantBound) {
634 for (mlir::Value bound : llvm::reverse(bounds)) {
635 auto dataBound =
636 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
637 argsTy.push_back(dataBound.getLowerbound().getType());
638 argsLoc.push_back(dataBound.getLowerbound().getLoc());
639 argsTy.push_back(dataBound.getUpperbound().getType());
640 argsLoc.push_back(dataBound.getUpperbound().getLoc());
641 argsTy.push_back(dataBound.getStartIdx().getType());
642 argsLoc.push_back(dataBound.getStartIdx().getLoc());
645 builder.createBlock(&recipe.getCopyRegion(), recipe.getCopyRegion().end(),
646 argsTy, argsLoc);
648 builder.setInsertionPointToEnd(&recipe.getCopyRegion().back());
649 ty = fir::unwrapRefType(ty);
650 if (fir::isa_trivial(ty)) {
651 mlir::Value initValue = builder.create<fir::LoadOp>(
652 loc, recipe.getCopyRegion().front().getArgument(0));
653 builder.create<fir::StoreOp>(loc, initValue,
654 recipe.getCopyRegion().front().getArgument(1));
655 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(ty)) {
656 fir::FirOpBuilder firBuilder{builder, recipe.getOperation()};
657 auto shape = genShapeFromBoundsOrArgs(
658 loc, firBuilder, seqTy, bounds, recipe.getCopyRegion().getArguments());
660 auto leftDeclOp = builder.create<hlfir::DeclareOp>(
661 loc, recipe.getCopyRegion().getArgument(0), llvm::StringRef{}, shape,
662 llvm::ArrayRef<mlir::Value>{}, fir::FortranVariableFlagsAttr{});
663 auto rightDeclOp = builder.create<hlfir::DeclareOp>(
664 loc, recipe.getCopyRegion().getArgument(1), llvm::StringRef{}, shape,
665 llvm::ArrayRef<mlir::Value>{}, fir::FortranVariableFlagsAttr{});
667 hlfir::DesignateOp::Subscripts triplets =
668 getSubscriptsFromArgs(recipe.getCopyRegion().getArguments());
669 auto leftEntity = hlfir::Entity{leftDeclOp.getBase()};
670 auto left =
671 genDesignateWithTriplets(firBuilder, loc, leftEntity, triplets, shape);
672 auto rightEntity = hlfir::Entity{rightDeclOp.getBase()};
673 auto right =
674 genDesignateWithTriplets(firBuilder, loc, rightEntity, triplets, shape);
676 firBuilder.create<hlfir::AssignOp>(loc, left, right);
678 } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
679 fir::FirOpBuilder firBuilder{builder, recipe.getOperation()};
680 llvm::SmallVector<mlir::Value> tripletArgs;
681 mlir::Type innerTy = extractSequenceType(boxTy);
682 fir::SequenceType seqTy =
683 mlir::dyn_cast_or_null<fir::SequenceType>(innerTy);
684 if (!seqTy)
685 TODO(loc, "Unsupported boxed type in OpenACC firstprivate");
687 auto shape = genShapeFromBoundsOrArgs(
688 loc, firBuilder, seqTy, bounds, recipe.getCopyRegion().getArguments());
689 hlfir::DesignateOp::Subscripts triplets =
690 getSubscriptsFromArgs(recipe.getCopyRegion().getArguments());
691 auto leftEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(0)};
692 auto left =
693 genDesignateWithTriplets(firBuilder, loc, leftEntity, triplets, shape);
694 auto rightEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(1)};
695 auto right =
696 genDesignateWithTriplets(firBuilder, loc, rightEntity, triplets, shape);
697 firBuilder.create<hlfir::AssignOp>(loc, left, right);
700 builder.create<mlir::acc::TerminatorOp>(loc);
701 builder.restoreInsertionPoint(crtPos);
702 return recipe;
705 /// Get a string representation of the bounds.
706 std::string getBoundsString(llvm::SmallVector<mlir::Value> &bounds) {
707 std::stringstream boundStr;
708 if (!bounds.empty())
709 boundStr << "_section_";
710 llvm::interleave(
711 bounds,
712 [&](mlir::Value bound) {
713 auto boundsOp =
714 mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
715 if (boundsOp.getLowerbound() &&
716 fir::getIntIfConstant(boundsOp.getLowerbound()) &&
717 boundsOp.getUpperbound() &&
718 fir::getIntIfConstant(boundsOp.getUpperbound())) {
719 boundStr << "lb" << *fir::getIntIfConstant(boundsOp.getLowerbound())
720 << ".ub" << *fir::getIntIfConstant(boundsOp.getUpperbound());
721 } else if (boundsOp.getExtent() &&
722 fir::getIntIfConstant(boundsOp.getExtent())) {
723 boundStr << "ext" << *fir::getIntIfConstant(boundsOp.getExtent());
724 } else {
725 boundStr << "?";
728 [&] { boundStr << "x"; });
729 return boundStr.str();
732 /// Rebuild the array type from the acc.bounds operation with constant
733 /// lowerbound/upperbound or extent.
734 mlir::Type getTypeFromBounds(llvm::SmallVector<mlir::Value> &bounds,
735 mlir::Type ty) {
736 auto seqTy =
737 mlir::dyn_cast_or_null<fir::SequenceType>(fir::unwrapRefType(ty));
738 if (!bounds.empty() && seqTy) {
739 llvm::SmallVector<int64_t> shape;
740 for (auto b : bounds) {
741 auto boundsOp =
742 mlir::dyn_cast<mlir::acc::DataBoundsOp>(b.getDefiningOp());
743 if (boundsOp.getLowerbound() &&
744 fir::getIntIfConstant(boundsOp.getLowerbound()) &&
745 boundsOp.getUpperbound() &&
746 fir::getIntIfConstant(boundsOp.getUpperbound())) {
747 int64_t ext = *fir::getIntIfConstant(boundsOp.getUpperbound()) -
748 *fir::getIntIfConstant(boundsOp.getLowerbound()) + 1;
749 shape.push_back(ext);
750 } else if (boundsOp.getExtent() &&
751 fir::getIntIfConstant(boundsOp.getExtent())) {
752 shape.push_back(*fir::getIntIfConstant(boundsOp.getExtent()));
753 } else {
754 return ty; // TODO: handle dynamic shaped array slice.
757 if (shape.empty() || shape.size() != bounds.size())
758 return ty;
759 auto newSeqTy = fir::SequenceType::get(shape, seqTy.getEleTy());
760 if (mlir::isa<fir::ReferenceType, fir::PointerType>(ty))
761 return fir::ReferenceType::get(newSeqTy);
762 return newSeqTy;
764 return ty;
767 template <typename RecipeOp>
768 static void
769 genPrivatizations(const Fortran::parser::AccObjectList &objectList,
770 Fortran::lower::AbstractConverter &converter,
771 Fortran::semantics::SemanticsContext &semanticsContext,
772 Fortran::lower::StatementContext &stmtCtx,
773 llvm::SmallVectorImpl<mlir::Value> &dataOperands,
774 llvm::SmallVector<mlir::Attribute> &privatizations) {
775 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
776 for (const auto &accObject : objectList.v) {
777 llvm::SmallVector<mlir::Value> bounds;
778 std::stringstream asFortran;
779 mlir::Location operandLocation = genOperandLocation(converter, accObject);
780 Fortran::lower::AddrAndBoundsInfo info =
781 Fortran::lower::gatherDataOperandAddrAndBounds<
782 Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
783 mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
784 stmtCtx, accObject, operandLocation,
785 asFortran, bounds);
786 RecipeOp recipe;
787 mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType());
788 if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) {
789 std::string recipeName =
790 fir::getTypeAsString(retTy, converter.getKindMap(), "privatization");
791 recipe = Fortran::lower::createOrGetPrivateRecipe(builder, recipeName,
792 operandLocation, retTy);
793 auto op = createDataEntryOp<mlir::acc::PrivateOp>(
794 builder, operandLocation, info.addr, asFortran, bounds, true,
795 /*implicit=*/false, mlir::acc::DataClause::acc_private, retTy);
796 dataOperands.push_back(op.getAccPtr());
797 } else {
798 std::string suffix =
799 areAllBoundConstant(bounds) ? getBoundsString(bounds) : "";
800 std::string recipeName = fir::getTypeAsString(
801 retTy, converter.getKindMap(), "firstprivatization" + suffix);
802 recipe = Fortran::lower::createOrGetFirstprivateRecipe(
803 builder, recipeName, operandLocation, retTy, bounds);
804 auto op = createDataEntryOp<mlir::acc::FirstprivateOp>(
805 builder, operandLocation, info.addr, asFortran, bounds, true,
806 /*implicit=*/false, mlir::acc::DataClause::acc_firstprivate, retTy);
807 dataOperands.push_back(op.getAccPtr());
809 privatizations.push_back(mlir::SymbolRefAttr::get(
810 builder.getContext(), recipe.getSymName().str()));
814 /// Return the corresponding enum value for the mlir::acc::ReductionOperator
815 /// from the parser representation.
816 static mlir::acc::ReductionOperator
817 getReductionOperator(const Fortran::parser::AccReductionOperator &op) {
818 switch (op.v) {
819 case Fortran::parser::AccReductionOperator::Operator::Plus:
820 return mlir::acc::ReductionOperator::AccAdd;
821 case Fortran::parser::AccReductionOperator::Operator::Multiply:
822 return mlir::acc::ReductionOperator::AccMul;
823 case Fortran::parser::AccReductionOperator::Operator::Max:
824 return mlir::acc::ReductionOperator::AccMax;
825 case Fortran::parser::AccReductionOperator::Operator::Min:
826 return mlir::acc::ReductionOperator::AccMin;
827 case Fortran::parser::AccReductionOperator::Operator::Iand:
828 return mlir::acc::ReductionOperator::AccIand;
829 case Fortran::parser::AccReductionOperator::Operator::Ior:
830 return mlir::acc::ReductionOperator::AccIor;
831 case Fortran::parser::AccReductionOperator::Operator::Ieor:
832 return mlir::acc::ReductionOperator::AccXor;
833 case Fortran::parser::AccReductionOperator::Operator::And:
834 return mlir::acc::ReductionOperator::AccLand;
835 case Fortran::parser::AccReductionOperator::Operator::Or:
836 return mlir::acc::ReductionOperator::AccLor;
837 case Fortran::parser::AccReductionOperator::Operator::Eqv:
838 return mlir::acc::ReductionOperator::AccEqv;
839 case Fortran::parser::AccReductionOperator::Operator::Neqv:
840 return mlir::acc::ReductionOperator::AccNeqv;
842 llvm_unreachable("unexpected reduction operator");
845 /// Get the initial value for reduction operator.
846 template <typename R>
847 static R getReductionInitValue(mlir::acc::ReductionOperator op, mlir::Type ty) {
848 if (op == mlir::acc::ReductionOperator::AccMin) {
849 // min init value -> largest
850 if constexpr (std::is_same_v<R, llvm::APInt>) {
851 assert(ty.isIntOrIndex() && "expect integer or index type");
852 return llvm::APInt::getSignedMaxValue(ty.getIntOrFloatBitWidth());
854 if constexpr (std::is_same_v<R, llvm::APFloat>) {
855 auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty);
856 assert(floatTy && "expect float type");
857 return llvm::APFloat::getLargest(floatTy.getFloatSemantics(),
858 /*negative=*/false);
860 } else if (op == mlir::acc::ReductionOperator::AccMax) {
861 // max init value -> smallest
862 if constexpr (std::is_same_v<R, llvm::APInt>) {
863 assert(ty.isIntOrIndex() && "expect integer or index type");
864 return llvm::APInt::getSignedMinValue(ty.getIntOrFloatBitWidth());
866 if constexpr (std::is_same_v<R, llvm::APFloat>) {
867 auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty);
868 assert(floatTy && "expect float type");
869 return llvm::APFloat::getSmallest(floatTy.getFloatSemantics(),
870 /*negative=*/true);
872 } else if (op == mlir::acc::ReductionOperator::AccIand) {
873 if constexpr (std::is_same_v<R, llvm::APInt>) {
874 assert(ty.isIntOrIndex() && "expect integer type");
875 unsigned bits = ty.getIntOrFloatBitWidth();
876 return llvm::APInt::getAllOnes(bits);
878 } else {
879 // +, ior, ieor init value -> 0
880 // * init value -> 1
881 int64_t value = (op == mlir::acc::ReductionOperator::AccMul) ? 1 : 0;
882 if constexpr (std::is_same_v<R, llvm::APInt>) {
883 assert(ty.isIntOrIndex() && "expect integer or index type");
884 return llvm::APInt(ty.getIntOrFloatBitWidth(), value, true);
887 if constexpr (std::is_same_v<R, llvm::APFloat>) {
888 assert(mlir::isa<mlir::FloatType>(ty) && "expect float type");
889 auto floatTy = mlir::dyn_cast<mlir::FloatType>(ty);
890 return llvm::APFloat(floatTy.getFloatSemantics(), value);
893 if constexpr (std::is_same_v<R, int64_t>)
894 return value;
896 llvm_unreachable("OpenACC reduction unsupported type");
899 /// Return a constant with the initial value for the reduction operator and
900 /// type combination.
901 static mlir::Value getReductionInitValue(fir::FirOpBuilder &builder,
902 mlir::Location loc, mlir::Type ty,
903 mlir::acc::ReductionOperator op) {
904 if (op == mlir::acc::ReductionOperator::AccLand ||
905 op == mlir::acc::ReductionOperator::AccLor ||
906 op == mlir::acc::ReductionOperator::AccEqv ||
907 op == mlir::acc::ReductionOperator::AccNeqv) {
908 assert(mlir::isa<fir::LogicalType>(ty) && "expect fir.logical type");
909 bool value = true; // .true. for .and. and .eqv.
910 if (op == mlir::acc::ReductionOperator::AccLor ||
911 op == mlir::acc::ReductionOperator::AccNeqv)
912 value = false; // .false. for .or. and .neqv.
913 return builder.createBool(loc, value);
915 if (ty.isIntOrIndex())
916 return builder.create<mlir::arith::ConstantOp>(
917 loc, ty,
918 builder.getIntegerAttr(ty, getReductionInitValue<llvm::APInt>(op, ty)));
919 if (op == mlir::acc::ReductionOperator::AccMin ||
920 op == mlir::acc::ReductionOperator::AccMax) {
921 if (mlir::isa<fir::ComplexType>(ty))
922 llvm::report_fatal_error(
923 "min/max reduction not supported for complex type");
924 if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty))
925 return builder.create<mlir::arith::ConstantOp>(
926 loc, ty,
927 builder.getFloatAttr(ty,
928 getReductionInitValue<llvm::APFloat>(op, ty)));
929 } else if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty)) {
930 return builder.create<mlir::arith::ConstantOp>(
931 loc, ty,
932 builder.getFloatAttr(ty, getReductionInitValue<int64_t>(op, ty)));
933 } else if (auto cmplxTy = mlir::dyn_cast_or_null<fir::ComplexType>(ty)) {
934 mlir::Type floatTy =
935 Fortran::lower::convertReal(builder.getContext(), cmplxTy.getFKind());
936 mlir::Value realInit = builder.createRealConstant(
937 loc, floatTy, getReductionInitValue<int64_t>(op, cmplxTy));
938 mlir::Value imagInit = builder.createRealConstant(loc, floatTy, 0.0);
939 return fir::factory::Complex{builder, loc}.createComplex(
940 cmplxTy.getFKind(), realInit, imagInit);
943 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
944 return getReductionInitValue(builder, loc, seqTy.getEleTy(), op);
946 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
947 return getReductionInitValue(builder, loc, boxTy.getEleTy(), op);
949 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
950 return getReductionInitValue(builder, loc, heapTy.getEleTy(), op);
952 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
953 return getReductionInitValue(builder, loc, ptrTy.getEleTy(), op);
955 llvm::report_fatal_error("Unsupported OpenACC reduction type");
958 static mlir::Value genReductionInitRegion(fir::FirOpBuilder &builder,
959 mlir::Location loc, mlir::Type ty,
960 mlir::acc::ReductionOperator op) {
961 ty = fir::unwrapRefType(ty);
962 mlir::Value initValue = getReductionInitValue(builder, loc, ty, op);
963 if (fir::isa_trivial(ty)) {
964 mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
965 auto declareOp = builder.create<hlfir::DeclareOp>(
966 loc, alloca, accReductionInitName, /*shape=*/nullptr,
967 llvm::ArrayRef<mlir::Value>{}, fir::FortranVariableFlagsAttr{});
968 builder.create<fir::StoreOp>(loc, builder.createConvert(loc, ty, initValue),
969 declareOp.getBase());
970 return declareOp.getBase();
971 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(ty)) {
972 if (fir::isa_trivial(seqTy.getEleTy())) {
973 mlir::Value shape;
974 auto extents = builder.getBlock()->getArguments().drop_front(1);
975 if (seqTy.hasDynamicExtents())
976 shape = builder.create<fir::ShapeOp>(loc, extents);
977 else
978 shape = genShapeOp(builder, seqTy, loc);
979 mlir::Value alloca = builder.create<fir::AllocaOp>(
980 loc, seqTy, /*typeparams=*/mlir::ValueRange{}, extents);
981 auto declareOp = builder.create<hlfir::DeclareOp>(
982 loc, alloca, accReductionInitName, shape,
983 llvm::ArrayRef<mlir::Value>{}, fir::FortranVariableFlagsAttr{});
984 mlir::Type idxTy = builder.getIndexType();
985 mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
986 llvm::SmallVector<fir::DoLoopOp> loops;
987 llvm::SmallVector<mlir::Value> ivs;
989 if (seqTy.hasDynamicExtents()) {
990 builder.create<hlfir::AssignOp>(loc, initValue, declareOp.getBase());
991 return declareOp.getBase();
993 for (auto ext : llvm::reverse(seqTy.getShape())) {
994 auto lb = builder.createIntegerConstant(loc, idxTy, 0);
995 auto ub = builder.createIntegerConstant(loc, idxTy, ext - 1);
996 auto step = builder.createIntegerConstant(loc, idxTy, 1);
997 auto loop = builder.create<fir::DoLoopOp>(loc, lb, ub, step,
998 /*unordered=*/false);
999 builder.setInsertionPointToStart(loop.getBody());
1000 loops.push_back(loop);
1001 ivs.push_back(loop.getInductionVar());
1003 auto coord = builder.create<fir::CoordinateOp>(loc, refTy,
1004 declareOp.getBase(), ivs);
1005 builder.create<fir::StoreOp>(loc, initValue, coord);
1006 builder.setInsertionPointAfter(loops[0]);
1007 return declareOp.getBase();
1009 } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
1010 mlir::Type innerTy = extractSequenceType(boxTy);
1011 if (!mlir::isa<fir::SequenceType>(innerTy))
1012 TODO(loc, "Unsupported boxed type for reduction");
1013 // Create the private copy from the initial fir.box.
1014 hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)};
1015 auto [temp, cleanup] = hlfir::createTempFromMold(loc, builder, source);
1016 builder.create<hlfir::AssignOp>(loc, initValue, temp);
1017 return temp;
1019 llvm::report_fatal_error("Unsupported OpenACC reduction type");
1022 template <typename Op>
1023 static mlir::Value genLogicalCombiner(fir::FirOpBuilder &builder,
1024 mlir::Location loc, mlir::Value value1,
1025 mlir::Value value2) {
1026 mlir::Type i1 = builder.getI1Type();
1027 mlir::Value v1 = builder.create<fir::ConvertOp>(loc, i1, value1);
1028 mlir::Value v2 = builder.create<fir::ConvertOp>(loc, i1, value2);
1029 mlir::Value combined = builder.create<Op>(loc, v1, v2);
1030 return builder.create<fir::ConvertOp>(loc, value1.getType(), combined);
1033 static mlir::Value loadIfRef(fir::FirOpBuilder &builder, mlir::Location loc,
1034 mlir::Value value) {
1035 if (mlir::isa<fir::ReferenceType, fir::PointerType, fir::HeapType>(
1036 value.getType()))
1037 return builder.create<fir::LoadOp>(loc, value);
1038 return value;
1041 static mlir::Value genComparisonCombiner(fir::FirOpBuilder &builder,
1042 mlir::Location loc,
1043 mlir::arith::CmpIPredicate pred,
1044 mlir::Value value1,
1045 mlir::Value value2) {
1046 mlir::Type i1 = builder.getI1Type();
1047 mlir::Value v1 = builder.create<fir::ConvertOp>(loc, i1, value1);
1048 mlir::Value v2 = builder.create<fir::ConvertOp>(loc, i1, value2);
1049 mlir::Value add = builder.create<mlir::arith::CmpIOp>(loc, pred, v1, v2);
1050 return builder.create<fir::ConvertOp>(loc, value1.getType(), add);
1053 static mlir::Value genScalarCombiner(fir::FirOpBuilder &builder,
1054 mlir::Location loc,
1055 mlir::acc::ReductionOperator op,
1056 mlir::Type ty, mlir::Value value1,
1057 mlir::Value value2) {
1058 value1 = loadIfRef(builder, loc, value1);
1059 value2 = loadIfRef(builder, loc, value2);
1060 if (op == mlir::acc::ReductionOperator::AccAdd) {
1061 if (ty.isIntOrIndex())
1062 return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
1063 if (mlir::isa<mlir::FloatType>(ty))
1064 return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
1065 if (auto cmplxTy = mlir::dyn_cast_or_null<fir::ComplexType>(ty))
1066 return builder.create<fir::AddcOp>(loc, value1, value2);
1067 TODO(loc, "reduction add type");
1070 if (op == mlir::acc::ReductionOperator::AccMul) {
1071 if (ty.isIntOrIndex())
1072 return builder.create<mlir::arith::MulIOp>(loc, value1, value2);
1073 if (mlir::isa<mlir::FloatType>(ty))
1074 return builder.create<mlir::arith::MulFOp>(loc, value1, value2);
1075 if (mlir::isa<fir::ComplexType>(ty))
1076 return builder.create<fir::MulcOp>(loc, value1, value2);
1077 TODO(loc, "reduction mul type");
1080 if (op == mlir::acc::ReductionOperator::AccMin)
1081 return fir::genMin(builder, loc, {value1, value2});
1083 if (op == mlir::acc::ReductionOperator::AccMax)
1084 return fir::genMax(builder, loc, {value1, value2});
1086 if (op == mlir::acc::ReductionOperator::AccIand)
1087 return builder.create<mlir::arith::AndIOp>(loc, value1, value2);
1089 if (op == mlir::acc::ReductionOperator::AccIor)
1090 return builder.create<mlir::arith::OrIOp>(loc, value1, value2);
1092 if (op == mlir::acc::ReductionOperator::AccXor)
1093 return builder.create<mlir::arith::XOrIOp>(loc, value1, value2);
1095 if (op == mlir::acc::ReductionOperator::AccLand)
1096 return genLogicalCombiner<mlir::arith::AndIOp>(builder, loc, value1,
1097 value2);
1099 if (op == mlir::acc::ReductionOperator::AccLor)
1100 return genLogicalCombiner<mlir::arith::OrIOp>(builder, loc, value1, value2);
1102 if (op == mlir::acc::ReductionOperator::AccEqv)
1103 return genComparisonCombiner(builder, loc, mlir::arith::CmpIPredicate::eq,
1104 value1, value2);
1106 if (op == mlir::acc::ReductionOperator::AccNeqv)
1107 return genComparisonCombiner(builder, loc, mlir::arith::CmpIPredicate::ne,
1108 value1, value2);
1110 TODO(loc, "reduction operator");
1113 static hlfir::DesignateOp::Subscripts
1114 getTripletsFromArgs(mlir::acc::ReductionRecipeOp recipe) {
1115 hlfir::DesignateOp::Subscripts triplets;
1116 for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size();
1117 i += 3)
1118 triplets.emplace_back(hlfir::DesignateOp::Triplet{
1119 recipe.getCombinerRegion().getArgument(i),
1120 recipe.getCombinerRegion().getArgument(i + 1),
1121 recipe.getCombinerRegion().getArgument(i + 2)});
1122 return triplets;
1125 static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
1126 mlir::acc::ReductionOperator op, mlir::Type ty,
1127 mlir::Value value1, mlir::Value value2,
1128 mlir::acc::ReductionRecipeOp &recipe,
1129 llvm::SmallVector<mlir::Value> &bounds,
1130 bool allConstantBound) {
1131 ty = fir::unwrapRefType(ty);
1133 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
1134 mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
1135 llvm::SmallVector<fir::DoLoopOp> loops;
1136 llvm::SmallVector<mlir::Value> ivs;
1137 if (seqTy.hasDynamicExtents()) {
1138 auto shape =
1139 genShapeFromBoundsOrArgs(loc, builder, seqTy, bounds,
1140 recipe.getCombinerRegion().getArguments());
1141 auto v1DeclareOp = builder.create<hlfir::DeclareOp>(
1142 loc, value1, llvm::StringRef{}, shape, llvm::ArrayRef<mlir::Value>{},
1143 fir::FortranVariableFlagsAttr{});
1144 auto v2DeclareOp = builder.create<hlfir::DeclareOp>(
1145 loc, value2, llvm::StringRef{}, shape, llvm::ArrayRef<mlir::Value>{},
1146 fir::FortranVariableFlagsAttr{});
1147 hlfir::DesignateOp::Subscripts triplets = getTripletsFromArgs(recipe);
1149 llvm::SmallVector<mlir::Value> lenParamsLeft;
1150 auto leftEntity = hlfir::Entity{v1DeclareOp.getBase()};
1151 hlfir::genLengthParameters(loc, builder, leftEntity, lenParamsLeft);
1152 auto leftDesignate = builder.create<hlfir::DesignateOp>(
1153 loc, v1DeclareOp.getBase().getType(), v1DeclareOp.getBase(),
1154 /*component=*/"",
1155 /*componentShape=*/mlir::Value{}, triplets,
1156 /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
1157 shape, lenParamsLeft);
1158 auto left = hlfir::Entity{leftDesignate.getResult()};
1160 llvm::SmallVector<mlir::Value> lenParamsRight;
1161 auto rightEntity = hlfir::Entity{v2DeclareOp.getBase()};
1162 hlfir::genLengthParameters(loc, builder, rightEntity, lenParamsLeft);
1163 auto rightDesignate = builder.create<hlfir::DesignateOp>(
1164 loc, v2DeclareOp.getBase().getType(), v2DeclareOp.getBase(),
1165 /*component=*/"",
1166 /*componentShape=*/mlir::Value{}, triplets,
1167 /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
1168 shape, lenParamsRight);
1169 auto right = hlfir::Entity{rightDesignate.getResult()};
1171 llvm::SmallVector<mlir::Value, 1> typeParams;
1172 auto genKernel = [&builder, &loc, op, seqTy, &left, &right](
1173 mlir::Location l, fir::FirOpBuilder &b,
1174 mlir::ValueRange oneBasedIndices) -> hlfir::Entity {
1175 auto leftElement = hlfir::getElementAt(l, b, left, oneBasedIndices);
1176 auto rightElement = hlfir::getElementAt(l, b, right, oneBasedIndices);
1177 auto leftVal = hlfir::loadTrivialScalar(l, b, leftElement);
1178 auto rightVal = hlfir::loadTrivialScalar(l, b, rightElement);
1179 return hlfir::Entity{genScalarCombiner(
1180 builder, loc, op, seqTy.getEleTy(), leftVal, rightVal)};
1182 mlir::Value elemental = hlfir::genElementalOp(
1183 loc, builder, seqTy.getEleTy(), shape, typeParams, genKernel,
1184 /*isUnordered=*/true);
1185 builder.create<hlfir::AssignOp>(loc, elemental, v1DeclareOp.getBase());
1186 return;
1188 if (allConstantBound) {
1189 // Use the constant bound directly in the combiner region so they do not
1190 // need to be passed as block argument.
1191 for (auto bound : llvm::reverse(bounds)) {
1192 auto dataBound =
1193 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
1194 llvm::SmallVector<mlir::Value> values =
1195 genConstantBounds(builder, loc, dataBound);
1196 auto loop =
1197 builder.create<fir::DoLoopOp>(loc, values[0], values[1], values[2],
1198 /*unordered=*/false);
1199 builder.setInsertionPointToStart(loop.getBody());
1200 loops.push_back(loop);
1201 ivs.push_back(loop.getInductionVar());
1203 } else {
1204 // Lowerbound, upperbound and step are passed as block arguments.
1205 [[maybe_unused]] unsigned nbRangeArgs =
1206 recipe.getCombinerRegion().getArguments().size() - 2;
1207 assert((nbRangeArgs / 3 == seqTy.getDimension()) &&
1208 "Expect 3 block arguments per dimension");
1209 for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size();
1210 i += 3) {
1211 mlir::Value lb = recipe.getCombinerRegion().getArgument(i);
1212 mlir::Value ub = recipe.getCombinerRegion().getArgument(i + 1);
1213 mlir::Value step = recipe.getCombinerRegion().getArgument(i + 2);
1214 auto loop = builder.create<fir::DoLoopOp>(loc, lb, ub, step,
1215 /*unordered=*/false);
1216 builder.setInsertionPointToStart(loop.getBody());
1217 loops.push_back(loop);
1218 ivs.push_back(loop.getInductionVar());
1221 auto addr1 = builder.create<fir::CoordinateOp>(loc, refTy, value1, ivs);
1222 auto addr2 = builder.create<fir::CoordinateOp>(loc, refTy, value2, ivs);
1223 auto load1 = builder.create<fir::LoadOp>(loc, addr1);
1224 auto load2 = builder.create<fir::LoadOp>(loc, addr2);
1225 mlir::Value res =
1226 genScalarCombiner(builder, loc, op, seqTy.getEleTy(), load1, load2);
1227 builder.create<fir::StoreOp>(loc, res, addr1);
1228 builder.setInsertionPointAfter(loops[0]);
1229 } else if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
1230 mlir::Type innerTy = extractSequenceType(boxTy);
1231 fir::SequenceType seqTy =
1232 mlir::dyn_cast_or_null<fir::SequenceType>(innerTy);
1233 if (!seqTy)
1234 TODO(loc, "Unsupported boxed type in OpenACC reduction");
1236 auto shape = genShapeFromBoundsOrArgs(
1237 loc, builder, seqTy, bounds, recipe.getCombinerRegion().getArguments());
1238 hlfir::DesignateOp::Subscripts triplets =
1239 getSubscriptsFromArgs(recipe.getCombinerRegion().getArguments());
1240 auto leftEntity = hlfir::Entity{value1};
1241 auto left =
1242 genDesignateWithTriplets(builder, loc, leftEntity, triplets, shape);
1243 auto rightEntity = hlfir::Entity{value2};
1244 auto right =
1245 genDesignateWithTriplets(builder, loc, rightEntity, triplets, shape);
1247 llvm::SmallVector<mlir::Value, 1> typeParams;
1248 auto genKernel = [&builder, &loc, op, seqTy, &left, &right](
1249 mlir::Location l, fir::FirOpBuilder &b,
1250 mlir::ValueRange oneBasedIndices) -> hlfir::Entity {
1251 auto leftElement = hlfir::getElementAt(l, b, left, oneBasedIndices);
1252 auto rightElement = hlfir::getElementAt(l, b, right, oneBasedIndices);
1253 auto leftVal = hlfir::loadTrivialScalar(l, b, leftElement);
1254 auto rightVal = hlfir::loadTrivialScalar(l, b, rightElement);
1255 return hlfir::Entity{genScalarCombiner(builder, loc, op, seqTy.getEleTy(),
1256 leftVal, rightVal)};
1258 mlir::Value elemental = hlfir::genElementalOp(
1259 loc, builder, seqTy.getEleTy(), shape, typeParams, genKernel,
1260 /*isUnordered=*/true);
1261 builder.create<hlfir::AssignOp>(loc, elemental, value1);
1262 } else {
1263 mlir::Value res = genScalarCombiner(builder, loc, op, ty, value1, value2);
1264 builder.create<fir::StoreOp>(loc, res, value1);
/// Return the reduction recipe named \p recipeName, creating it when the
/// enclosing module does not already define it.
///
/// A new recipe is created in the module body for type \p ty and operator
/// \p op. Its init region is filled by genReductionInitRegion and its
/// combiner region by genCombiner; the combiner yields its first block
/// argument. The builder insertion point is saved before and restored after
/// the recipe is generated.
1268 mlir::acc::ReductionRecipeOp Fortran::lower::createOrGetReductionRecipe(
1269 fir::FirOpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc,
1270 mlir::Type ty, mlir::acc::ReductionOperator op,
1271 llvm::SmallVector<mlir::Value> &bounds) {
1272 mlir::ModuleOp mod =
1273 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
// Reuse an existing recipe with the same symbol name.
1274 if (auto recipe = mod.lookupSymbol<mlir::acc::ReductionRecipeOp>(recipeName))
1275 return recipe;
// Remember where the caller was inserting; restored before returning.
1277 auto crtPos = builder.saveInsertionPoint();
1278 mlir::OpBuilder modBuilder(mod.getBodyRegion());
1279 auto recipe =
1280 modBuilder.create<mlir::acc::ReductionRecipeOp>(loc, recipeName, ty, op);
// The init block always takes the reduction variable; a sequence type with
// dynamic extents adds one index argument per dimension.
1281 llvm::SmallVector<mlir::Type> initArgsTy{ty};
1282 llvm::SmallVector<mlir::Location> initArgsLoc{loc};
1283 mlir::Type refTy = fir::unwrapRefType(ty);
1284 if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(refTy)) {
1285 if (seqTy.hasDynamicExtents()) {
1286 mlir::Type idxTy = builder.getIndexType();
1287 for (unsigned i = 0; i < seqTy.getDimension(); ++i) {
1288 initArgsTy.push_back(idxTy);
1289 initArgsLoc.push_back(loc);
1293 builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(),
1294 initArgsTy, initArgsLoc);
1295 builder.setInsertionPointToEnd(&recipe.getInitRegion().back());
1296 mlir::Value initValue = genReductionInitRegion(builder, loc, ty, op);
1297 builder.create<mlir::acc::YieldOp>(loc, initValue);
1299 // The two first block arguments are the two values to be combined.
1300 // The next arguments are the iteration ranges (lb, ub, step) to be used
1301 // for the combiner if needed.
// Range arguments are only added when the bounds are not all constant, three
// per bound, visiting the bounds in reverse order.
1302 llvm::SmallVector<mlir::Type> argsTy{ty, ty};
1303 llvm::SmallVector<mlir::Location> argsLoc{loc, loc};
1304 bool allConstantBound = areAllBoundConstant(bounds);
1305 if (!allConstantBound) {
1306 for (mlir::Value bound : llvm::reverse(bounds)) {
1307 auto dataBound =
1308 mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
1309 argsTy.push_back(dataBound.getLowerbound().getType());
1310 argsLoc.push_back(dataBound.getLowerbound().getLoc());
1311 argsTy.push_back(dataBound.getUpperbound().getType());
1312 argsLoc.push_back(dataBound.getUpperbound().getLoc());
// NOTE(review): the third argument of each group takes its type and location
// from the bound's start index, while the comment above names that slot
// "step" -- confirm this is intended.
1313 argsTy.push_back(dataBound.getStartIdx().getType());
1314 argsLoc.push_back(dataBound.getStartIdx().getLoc());
1317 builder.createBlock(&recipe.getCombinerRegion(),
1318 recipe.getCombinerRegion().end(), argsTy, argsLoc);
1319 builder.setInsertionPointToEnd(&recipe.getCombinerRegion().back());
1320 mlir::Value v1 = recipe.getCombinerRegion().front().getArgument(0);
1321 mlir::Value v2 = recipe.getCombinerRegion().front().getArgument(1);
1322 genCombiner(builder, loc, op, ty, v1, v2, recipe, bounds, allConstantBound);
// The combined result lives in the first argument, which is what is yielded.
1323 builder.create<mlir::acc::YieldOp>(loc, v1);
1324 builder.restoreInsertionPoint(crtPos);
1325 return recipe;
1328 static bool isSupportedReductionType(mlir::Type ty) {
1329 ty = fir::unwrapRefType(ty);
1330 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
1331 return isSupportedReductionType(boxTy.getEleTy());
1332 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
1333 return isSupportedReductionType(seqTy.getEleTy());
1334 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
1335 return isSupportedReductionType(heapTy.getEleTy());
1336 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
1337 return isSupportedReductionType(ptrTy.getEleTy());
1338 return fir::isa_trivial(ty);
1341 static void
1342 genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
1343 Fortran::lower::AbstractConverter &converter,
1344 Fortran::semantics::SemanticsContext &semanticsContext,
1345 Fortran::lower::StatementContext &stmtCtx,
1346 llvm::SmallVectorImpl<mlir::Value> &reductionOperands,
1347 llvm::SmallVector<mlir::Attribute> &reductionRecipes) {
1348 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
1349 const auto &objects = std::get<Fortran::parser::AccObjectList>(objectList.t);
1350 const auto &op =
1351 std::get<Fortran::parser::AccReductionOperator>(objectList.t);
1352 mlir::acc::ReductionOperator mlirOp = getReductionOperator(op);
1353 for (const auto &accObject : objects.v) {
1354 llvm::SmallVector<mlir::Value> bounds;
1355 std::stringstream asFortran;
1356 mlir::Location operandLocation = genOperandLocation(converter, accObject);
1357 Fortran::lower::AddrAndBoundsInfo info =
1358 Fortran::lower::gatherDataOperandAddrAndBounds<
1359 Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
1360 mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
1361 stmtCtx, accObject, operandLocation,
1362 asFortran, bounds);
1364 mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType());
1365 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))
1366 reductionTy = seqTy.getEleTy();
1368 if (!isSupportedReductionType(reductionTy))
1369 TODO(operandLocation, "reduction with unsupported type");
1371 auto op = createDataEntryOp<mlir::acc::ReductionOp>(
1372 builder, operandLocation, info.addr, asFortran, bounds,
1373 /*structured=*/true, /*implicit=*/false,
1374 mlir::acc::DataClause::acc_reduction, info.addr.getType());
1375 mlir::Type ty = op.getAccPtr().getType();
1376 if (!areAllBoundConstant(bounds) ||
1377 fir::isAssumedShape(info.addr.getType()) ||
1378 fir::isAllocatableOrPointerArray(info.addr.getType()))
1379 ty = info.addr.getType();
1380 std::string suffix =
1381 areAllBoundConstant(bounds) ? getBoundsString(bounds) : "";
1382 std::string recipeName = fir::getTypeAsString(
1383 ty, converter.getKindMap(),
1384 ("reduction_" + stringifyReductionOperator(mlirOp)).str() + suffix);
1386 mlir::acc::ReductionRecipeOp recipe =
1387 Fortran::lower::createOrGetReductionRecipe(
1388 builder, recipeName, operandLocation, ty, mlirOp, bounds);
1389 reductionRecipes.push_back(mlir::SymbolRefAttr::get(
1390 builder.getContext(), recipe.getSymName().str()));
1391 reductionOperands.push_back(op.getAccPtr());
1395 static void
1396 addOperands(llvm::SmallVectorImpl<mlir::Value> &operands,
1397 llvm::SmallVectorImpl<int32_t> &operandSegments,
1398 const llvm::SmallVectorImpl<mlir::Value> &clauseOperands) {
1399 operands.append(clauseOperands.begin(), clauseOperands.end());
1400 operandSegments.push_back(clauseOperands.size());
1403 static void addOperand(llvm::SmallVectorImpl<mlir::Value> &operands,
1404 llvm::SmallVectorImpl<int32_t> &operandSegments,
1405 const mlir::Value &clauseOperand) {
1406 if (clauseOperand) {
1407 operands.push_back(clauseOperand);
1408 operandSegments.push_back(1);
1409 } else {
1410 operandSegments.push_back(0);
/// Create an operation of type \p Op with a single-block region, add its
/// terminator and leave the builder at the start of that block.
///
/// \p operands and \p operandSegments populate the operation and its operand
/// segment size attribute; \p retTy gives its result types. When
/// \p yieldValue is set and \p Terminator is mlir::acc::YieldOp, the
/// terminator yields that value and the operation defining it is moved right
/// before the terminator; in every other case a terminator without operands
/// is created.
1414 template <typename Op, typename Terminator>
1415 static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
1416 Fortran::lower::pft::Evaluation &eval,
1417 const llvm::SmallVectorImpl<mlir::Value> &operands,
1418 const llvm::SmallVectorImpl<int32_t> &operandSegments,
1419 bool outerCombined = false,
1420 llvm::SmallVector<mlir::Type> retTy = {},
1421 mlir::Value yieldValue = {}) {
1422 Op op = builder.create<Op>(loc, retTy, operands);
1423 builder.createBlock(&op.getRegion());
1424 mlir::Block &block = op.getRegion().back();
1425 builder.setInsertionPointToStart(&block);
// Record how the flat operand list is split between the operand groups.
1427 op->setAttr(Op::getOperandSegmentSizeAttr(),
1428 builder.getDenseI32ArrayAttr(operandSegments));
1430 // Place the insertion point to the start of the first block.
1431 builder.setInsertionPointToStart(&block);
1433 // If it is an unstructured region and is not the outer region of a combined
1434 // construct, create empty blocks for all evaluations.
1435 if (eval.lowerAsUnstructured() && !outerCombined)
1436 Fortran::lower::createEmptyRegionBlocks<mlir::acc::TerminatorOp,
1437 mlir::acc::YieldOp>(
1438 builder, eval.getNestedEvaluations());
// Only a yield terminator can carry yieldValue; its defining operation is
// moved next to the yield.
1440 if (yieldValue) {
1441 if constexpr (std::is_same_v<Terminator, mlir::acc::YieldOp>) {
1442 Terminator yieldOp = builder.create<Terminator>(loc, yieldValue);
1443 yieldValue.getDefiningOp()->moveBefore(yieldOp);
1444 } else {
1445 builder.create<Terminator>(loc);
1447 } else {
1448 builder.create<Terminator>(loc);
// Callers continue lowering from the start of the region's block.
1450 builder.setInsertionPointToStart(&block);
1451 return op;
1454 static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
1455 const Fortran::parser::AccClause::Async *asyncClause,
1456 mlir::Value &async, bool &addAsyncAttr,
1457 Fortran::lower::StatementContext &stmtCtx) {
1458 const auto &asyncClauseValue = asyncClause->v;
1459 if (asyncClauseValue) { // async has a value.
1460 async = fir::getBase(converter.genExprValue(
1461 *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
1462 } else {
1463 addAsyncAttr = true;
1467 static void
1468 genAsyncClause(Fortran::lower::AbstractConverter &converter,
1469 const Fortran::parser::AccClause::Async *asyncClause,
1470 llvm::SmallVector<mlir::Value> &async,
1471 llvm::SmallVector<mlir::Attribute> &asyncDeviceTypes,
1472 llvm::SmallVector<mlir::Attribute> &asyncOnlyDeviceTypes,
1473 mlir::acc::DeviceTypeAttr deviceTypeAttr,
1474 Fortran::lower::StatementContext &stmtCtx) {
1475 const auto &asyncClauseValue = asyncClause->v;
1476 if (asyncClauseValue) { // async has a value.
1477 async.push_back(fir::getBase(converter.genExprValue(
1478 *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
1479 asyncDeviceTypes.push_back(deviceTypeAttr);
1480 } else {
1481 asyncOnlyDeviceTypes.push_back(deviceTypeAttr);
1485 static mlir::acc::DeviceType
1486 getDeviceType(Fortran::common::OpenACCDeviceType device) {
1487 switch (device) {
1488 case Fortran::common::OpenACCDeviceType::Star:
1489 return mlir::acc::DeviceType::Star;
1490 case Fortran::common::OpenACCDeviceType::Default:
1491 return mlir::acc::DeviceType::Default;
1492 case Fortran::common::OpenACCDeviceType::Nvidia:
1493 return mlir::acc::DeviceType::Nvidia;
1494 case Fortran::common::OpenACCDeviceType::Radeon:
1495 return mlir::acc::DeviceType::Radeon;
1496 case Fortran::common::OpenACCDeviceType::Host:
1497 return mlir::acc::DeviceType::Host;
1498 case Fortran::common::OpenACCDeviceType::Multicore:
1499 return mlir::acc::DeviceType::Multicore;
1500 case Fortran::common::OpenACCDeviceType::None:
1501 return mlir::acc::DeviceType::None;
1503 return mlir::acc::DeviceType::None;
1506 static void gatherDeviceTypeAttrs(
1507 fir::FirOpBuilder &builder, mlir::Location clauseLocation,
1508 const Fortran::parser::AccClause::DeviceType *deviceTypeClause,
1509 llvm::SmallVector<mlir::Attribute> &deviceTypes,
1510 Fortran::lower::StatementContext &stmtCtx) {
1511 const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
1512 deviceTypeClause->v;
1513 for (const auto &deviceTypeExpr : deviceTypeExprList.v)
1514 deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
1515 builder.getContext(), getDeviceType(deviceTypeExpr.v)));
1518 static void genIfClause(Fortran::lower::AbstractConverter &converter,
1519 mlir::Location clauseLocation,
1520 const Fortran::parser::AccClause::If *ifClause,
1521 mlir::Value &ifCond,
1522 Fortran::lower::StatementContext &stmtCtx) {
1523 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1524 mlir::Value cond = fir::getBase(converter.genExprValue(
1525 *Fortran::semantics::GetExpr(ifClause->v), stmtCtx, &clauseLocation));
1526 ifCond = firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(),
1527 cond);
1530 static void genWaitClause(Fortran::lower::AbstractConverter &converter,
1531 const Fortran::parser::AccClause::Wait *waitClause,
1532 llvm::SmallVectorImpl<mlir::Value> &operands,
1533 mlir::Value &waitDevnum, bool &addWaitAttr,
1534 Fortran::lower::StatementContext &stmtCtx) {
1535 const auto &waitClauseValue = waitClause->v;
1536 if (waitClauseValue) { // wait has a value.
1537 const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1538 const auto &waitList =
1539 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1540 for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1541 mlir::Value v = fir::getBase(
1542 converter.genExprValue(*Fortran::semantics::GetExpr(value), stmtCtx));
1543 operands.push_back(v);
1546 const auto &waitDevnumValue =
1547 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1548 if (waitDevnumValue)
1549 waitDevnum = fir::getBase(converter.genExprValue(
1550 *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
1551 } else {
1552 addWaitAttr = true;
1556 static void
1557 genWaitClause(Fortran::lower::AbstractConverter &converter,
1558 const Fortran::parser::AccClause::Wait *waitClause,
1559 llvm::SmallVector<mlir::Value> &waitOperands,
1560 llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
1561 llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
1562 llvm::SmallVector<int32_t> &waitOperandsSegments,
1563 mlir::Value &waitDevnum, mlir::acc::DeviceTypeAttr deviceTypeAttr,
1564 Fortran::lower::StatementContext &stmtCtx) {
1565 const auto &waitClauseValue = waitClause->v;
1566 if (waitClauseValue) { // wait has a value.
1567 const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1568 const auto &waitList =
1569 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1570 auto crtWaitOperands = waitOperands.size();
1571 for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1572 waitOperands.push_back(fir::getBase(converter.genExprValue(
1573 *Fortran::semantics::GetExpr(value), stmtCtx)));
1575 waitOperandsDeviceTypes.push_back(deviceTypeAttr);
1576 waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
1578 // TODO: move to device_type model.
1579 const auto &waitDevnumValue =
1580 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1581 if (waitDevnumValue)
1582 waitDevnum = fir::getBase(converter.genExprValue(
1583 *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
1584 } else {
1585 waitOnlyDeviceTypes.push_back(deviceTypeAttr);
1589 static mlir::acc::LoopOp
1590 createLoopOp(Fortran::lower::AbstractConverter &converter,
1591 mlir::Location currentLocation,
1592 Fortran::lower::pft::Evaluation &eval,
1593 Fortran::semantics::SemanticsContext &semanticsContext,
1594 Fortran::lower::StatementContext &stmtCtx,
1595 const Fortran::parser::AccClauseList &accClauseList,
1596 bool needEarlyReturnHandling = false) {
1597 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
1598 llvm::SmallVector<mlir::Value> tileOperands, privateOperands,
1599 reductionOperands, cacheOperands, vectorOperands, workerNumOperands,
1600 gangOperands;
1601 llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;
1602 llvm::SmallVector<int32_t> tileOperandsSegments, gangOperandsSegments;
1603 llvm::SmallVector<int64_t> collapseValues;
1605 llvm::SmallVector<mlir::Attribute> gangArgTypes;
1606 llvm::SmallVector<mlir::Attribute> seqDeviceTypes, independentDeviceTypes,
1607 autoDeviceTypes, vectorOperandsDeviceTypes, workerNumOperandsDeviceTypes,
1608 vectorDeviceTypes, workerNumDeviceTypes, tileOperandsDeviceTypes,
1609 collapseDeviceTypes, gangDeviceTypes, gangOperandsDeviceTypes;
1611 // device_type attribute is set to `none` until a device_type clause is
1612 // encountered.
1613 auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
1614 builder.getContext(), mlir::acc::DeviceType::None);
1616 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
1617 mlir::Location clauseLocation = converter.genLocation(clause.source);
1618 if (const auto *gangClause =
1619 std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
1620 if (gangClause->v) {
1621 auto crtGangOperands = gangOperands.size();
1622 const Fortran::parser::AccGangArgList &x = *gangClause->v;
1623 for (const Fortran::parser::AccGangArg &gangArg : x.v) {
1624 if (const auto *num =
1625 std::get_if<Fortran::parser::AccGangArg::Num>(&gangArg.u)) {
1626 gangOperands.push_back(fir::getBase(converter.genExprValue(
1627 *Fortran::semantics::GetExpr(num->v), stmtCtx)));
1628 gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
1629 builder.getContext(), mlir::acc::GangArgType::Num));
1630 } else if (const auto *staticArg =
1631 std::get_if<Fortran::parser::AccGangArg::Static>(
1632 &gangArg.u)) {
1633 const Fortran::parser::AccSizeExpr &sizeExpr = staticArg->v;
1634 if (sizeExpr.v) {
1635 gangOperands.push_back(fir::getBase(converter.genExprValue(
1636 *Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx)));
1637 } else {
1638 // * was passed as value and will be represented as a special
1639 // constant.
1640 gangOperands.push_back(builder.createIntegerConstant(
1641 clauseLocation, builder.getIndexType(), starCst));
1643 gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
1644 builder.getContext(), mlir::acc::GangArgType::Static));
1645 } else if (const auto *dim =
1646 std::get_if<Fortran::parser::AccGangArg::Dim>(
1647 &gangArg.u)) {
1648 gangOperands.push_back(fir::getBase(converter.genExprValue(
1649 *Fortran::semantics::GetExpr(dim->v), stmtCtx)));
1650 gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
1651 builder.getContext(), mlir::acc::GangArgType::Dim));
1654 gangOperandsSegments.push_back(gangOperands.size() - crtGangOperands);
1655 gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
1656 } else {
1657 gangDeviceTypes.push_back(crtDeviceTypeAttr);
1659 } else if (const auto *workerClause =
1660 std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
1661 if (workerClause->v) {
1662 workerNumOperands.push_back(fir::getBase(converter.genExprValue(
1663 *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx)));
1664 workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
1665 } else {
1666 workerNumDeviceTypes.push_back(crtDeviceTypeAttr);
1668 } else if (const auto *vectorClause =
1669 std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
1670 if (vectorClause->v) {
1671 vectorOperands.push_back(fir::getBase(converter.genExprValue(
1672 *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx)));
1673 vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
1674 } else {
1675 vectorDeviceTypes.push_back(crtDeviceTypeAttr);
1677 } else if (const auto *tileClause =
1678 std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
1679 const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v;
1680 auto crtTileOperands = tileOperands.size();
1681 for (const auto &accTileExpr : accTileExprList.v) {
1682 const auto &expr =
1683 std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(
1684 accTileExpr.t);
1685 if (expr) {
1686 tileOperands.push_back(fir::getBase(converter.genExprValue(
1687 *Fortran::semantics::GetExpr(*expr), stmtCtx)));
1688 } else {
1689 // * was passed as value and will be represented as a special
1690 // constant.
1691 mlir::Value tileStar = builder.createIntegerConstant(
1692 clauseLocation, builder.getIntegerType(32), starCst);
1693 tileOperands.push_back(tileStar);
1696 tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
1697 tileOperandsSegments.push_back(tileOperands.size() - crtTileOperands);
1698 } else if (const auto *privateClause =
1699 std::get_if<Fortran::parser::AccClause::Private>(
1700 &clause.u)) {
1701 genPrivatizations<mlir::acc::PrivateRecipeOp>(
1702 privateClause->v, converter, semanticsContext, stmtCtx,
1703 privateOperands, privatizations);
1704 } else if (const auto *reductionClause =
1705 std::get_if<Fortran::parser::AccClause::Reduction>(
1706 &clause.u)) {
1707 genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
1708 reductionOperands, reductionRecipes);
1709 } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
1710 seqDeviceTypes.push_back(crtDeviceTypeAttr);
1711 } else if (std::get_if<Fortran::parser::AccClause::Independent>(
1712 &clause.u)) {
1713 independentDeviceTypes.push_back(crtDeviceTypeAttr);
1714 } else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) {
1715 autoDeviceTypes.push_back(crtDeviceTypeAttr);
1716 } else if (const auto *deviceTypeClause =
1717 std::get_if<Fortran::parser::AccClause::DeviceType>(
1718 &clause.u)) {
1719 const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
1720 deviceTypeClause->v;
1721 assert(deviceTypeExprList.v.size() == 1 &&
1722 "expect only one device_type expr");
1723 crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
1724 builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
1725 } else if (const auto *collapseClause =
1726 std::get_if<Fortran::parser::AccClause::Collapse>(
1727 &clause.u)) {
1728 const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
1729 const auto &force = std::get<bool>(arg.t);
1730 if (force)
1731 TODO(clauseLocation, "OpenACC collapse force modifier");
1732 const auto &intExpr =
1733 std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
1734 const auto *expr = Fortran::semantics::GetExpr(intExpr);
1735 const std::optional<int64_t> collapseValue =
1736 Fortran::evaluate::ToInt64(*expr);
1737 assert(collapseValue && "expect integer value for the collapse clause");
1738 collapseValues.push_back(*collapseValue);
1739 collapseDeviceTypes.push_back(crtDeviceTypeAttr);
1743 // Prepare the operand segment size attribute and the operands value range.
1744 llvm::SmallVector<mlir::Value> operands;
1745 llvm::SmallVector<int32_t> operandSegments;
1746 addOperands(operands, operandSegments, gangOperands);
1747 addOperands(operands, operandSegments, workerNumOperands);
1748 addOperands(operands, operandSegments, vectorOperands);
1749 addOperands(operands, operandSegments, tileOperands);
1750 addOperands(operands, operandSegments, cacheOperands);
1751 addOperands(operands, operandSegments, privateOperands);
1752 addOperands(operands, operandSegments, reductionOperands);
1754 llvm::SmallVector<mlir::Type> retTy;
1755 mlir::Value yieldValue;
1756 if (needEarlyReturnHandling) {
1757 mlir::Type i1Ty = builder.getI1Type();
1758 yieldValue = builder.createIntegerConstant(currentLocation, i1Ty, 0);
1759 retTy.push_back(i1Ty);
1762 auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
1763 builder, currentLocation, eval, operands, operandSegments,
1764 /*outerCombined=*/false, retTy, yieldValue);
1766 if (!gangDeviceTypes.empty())
1767 loopOp.setGangAttr(builder.getArrayAttr(gangDeviceTypes));
1768 if (!gangArgTypes.empty())
1769 loopOp.setGangOperandsArgTypeAttr(builder.getArrayAttr(gangArgTypes));
1770 if (!gangOperandsSegments.empty())
1771 loopOp.setGangOperandsSegmentsAttr(
1772 builder.getDenseI32ArrayAttr(gangOperandsSegments));
1773 if (!gangOperandsDeviceTypes.empty())
1774 loopOp.setGangOperandsDeviceTypeAttr(
1775 builder.getArrayAttr(gangOperandsDeviceTypes));
1777 if (!workerNumDeviceTypes.empty())
1778 loopOp.setWorkerAttr(builder.getArrayAttr(workerNumDeviceTypes));
1779 if (!workerNumOperandsDeviceTypes.empty())
1780 loopOp.setWorkerNumOperandsDeviceTypeAttr(
1781 builder.getArrayAttr(workerNumOperandsDeviceTypes));
1783 if (!vectorDeviceTypes.empty())
1784 loopOp.setVectorAttr(builder.getArrayAttr(vectorDeviceTypes));
1785 if (!vectorOperandsDeviceTypes.empty())
1786 loopOp.setVectorOperandsDeviceTypeAttr(
1787 builder.getArrayAttr(vectorOperandsDeviceTypes));
1789 if (!tileOperandsDeviceTypes.empty())
1790 loopOp.setTileOperandsDeviceTypeAttr(
1791 builder.getArrayAttr(tileOperandsDeviceTypes));
1792 if (!tileOperandsSegments.empty())
1793 loopOp.setTileOperandsSegmentsAttr(
1794 builder.getDenseI32ArrayAttr(tileOperandsSegments));
1796 if (!seqDeviceTypes.empty())
1797 loopOp.setSeqAttr(builder.getArrayAttr(seqDeviceTypes));
1798 if (!independentDeviceTypes.empty())
1799 loopOp.setIndependentAttr(builder.getArrayAttr(independentDeviceTypes));
1800 if (!autoDeviceTypes.empty())
1801 loopOp.setAuto_Attr(builder.getArrayAttr(autoDeviceTypes));
1803 if (!privatizations.empty())
1804 loopOp.setPrivatizationsAttr(
1805 mlir::ArrayAttr::get(builder.getContext(), privatizations));
1807 if (!reductionRecipes.empty())
1808 loopOp.setReductionRecipesAttr(
1809 mlir::ArrayAttr::get(builder.getContext(), reductionRecipes));
1811 if (!collapseValues.empty())
1812 loopOp.setCollapseAttr(builder.getI64ArrayAttr(collapseValues));
1813 if (!collapseDeviceTypes.empty())
1814 loopOp.setCollapseDeviceTypeAttr(builder.getArrayAttr(collapseDeviceTypes));
1816 return loopOp;
1819 static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) {
1820 bool hasReturnStmt = false;
1821 for (auto &e : eval.getNestedEvaluations()) {
1822 e.visit(Fortran::common::visitors{
1823 [&](const Fortran::parser::ReturnStmt &) { hasReturnStmt = true; },
1824 [&](const auto &s) {},
1826 if (e.hasNestedEvaluations())
1827 hasReturnStmt = hasEarlyReturn(e);
1829 return hasReturnStmt;
1832 static mlir::Value
1833 genACC(Fortran::lower::AbstractConverter &converter,
1834 Fortran::semantics::SemanticsContext &semanticsContext,
1835 Fortran::lower::pft::Evaluation &eval,
1836 const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
1838 const auto &beginLoopDirective =
1839 std::get<Fortran::parser::AccBeginLoopDirective>(loopConstruct.t);
1840 const auto &loopDirective =
1841 std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t);
1843 bool needEarlyExitHandling = false;
1844 if (eval.lowerAsUnstructured())
1845 needEarlyExitHandling = hasEarlyReturn(eval);
1847 mlir::Location currentLocation =
1848 converter.genLocation(beginLoopDirective.source);
1849 Fortran::lower::StatementContext stmtCtx;
1851 if (loopDirective.v == llvm::acc::ACCD_loop) {
1852 const auto &accClauseList =
1853 std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
1854 auto loopOp =
1855 createLoopOp(converter, currentLocation, eval, semanticsContext,
1856 stmtCtx, accClauseList, needEarlyExitHandling);
1857 if (needEarlyExitHandling)
1858 return loopOp.getResult(0);
1860 return mlir::Value{};
1863 template <typename Op, typename Clause>
1864 static void genDataOperandOperationsWithModifier(
1865 const Clause *x, Fortran::lower::AbstractConverter &converter,
1866 Fortran::semantics::SemanticsContext &semanticsContext,
1867 Fortran::lower::StatementContext &stmtCtx,
1868 Fortran::parser::AccDataModifier::Modifier mod,
1869 llvm::SmallVectorImpl<mlir::Value> &dataClauseOperands,
1870 const mlir::acc::DataClause clause,
1871 const mlir::acc::DataClause clauseWithModifier,
1872 bool setDeclareAttr = false) {
1873 const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
1874 const auto &accObjectList =
1875 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
1876 const auto &modifier =
1877 std::get<std::optional<Fortran::parser::AccDataModifier>>(
1878 listWithModifier.t);
1879 mlir::acc::DataClause dataClause =
1880 (modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
1881 genDataOperandOperations<Op>(accObjectList, converter, semanticsContext,
1882 stmtCtx, dataClauseOperands, dataClause,
1883 /*structured=*/true, /*implicit=*/false,
1884 setDeclareAttr);
1887 template <typename Op>
1888 static Op
1889 createComputeOp(Fortran::lower::AbstractConverter &converter,
1890 mlir::Location currentLocation,
1891 Fortran::lower::pft::Evaluation &eval,
1892 Fortran::semantics::SemanticsContext &semanticsContext,
1893 Fortran::lower::StatementContext &stmtCtx,
1894 const Fortran::parser::AccClauseList &accClauseList,
1895 bool outerCombined = false) {
1897 // Parallel operation operands
1898 mlir::Value ifCond;
1899 mlir::Value selfCond;
1900 llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
1901 copyEntryOperands, copyoutEntryOperands, createEntryOperands,
1902 dataClauseOperands, numGangs, numWorkers, vectorLength, async;
1903 llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
1904 vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
1905 waitOperandsDeviceTypes, waitOnlyDeviceTypes;
1906 llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
1908 llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
1909 firstprivateOperands;
1910 llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
1911 reductionRecipes;
1912 mlir::Value waitDevnum; // TODO not yet implemented on compute op.
1914 // Self clause has optional values but can be present with
1915 // no value as well. When there is no value, the op has an attribute to
1916 // represent the clause.
1917 bool addSelfAttr = false;
1919 bool hasDefaultNone = false;
1920 bool hasDefaultPresent = false;
1922 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
1924 // device_type attribute is set to `none` until a device_type clause is
1925 // encountered.
1926 auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
1927 builder.getContext(), mlir::acc::DeviceType::None);
1929 // Lower clauses values mapped to operands.
1930 // Keep track of each group of operands separatly as clauses can appear
1931 // more than once.
1932 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
1933 mlir::Location clauseLocation = converter.genLocation(clause.source);
1934 if (const auto *asyncClause =
1935 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
1936 genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
1937 asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
1938 } else if (const auto *waitClause =
1939 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
1940 genWaitClause(converter, waitClause, waitOperands,
1941 waitOperandsDeviceTypes, waitOnlyDeviceTypes,
1942 waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
1943 stmtCtx);
1944 } else if (const auto *numGangsClause =
1945 std::get_if<Fortran::parser::AccClause::NumGangs>(
1946 &clause.u)) {
1947 auto crtNumGangs = numGangs.size();
1948 for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
1949 numGangs.push_back(fir::getBase(converter.genExprValue(
1950 *Fortran::semantics::GetExpr(expr), stmtCtx)));
1951 numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
1952 numGangsSegments.push_back(numGangs.size() - crtNumGangs);
1953 } else if (const auto *numWorkersClause =
1954 std::get_if<Fortran::parser::AccClause::NumWorkers>(
1955 &clause.u)) {
1956 numWorkers.push_back(fir::getBase(converter.genExprValue(
1957 *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
1958 numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
1959 } else if (const auto *vectorLengthClause =
1960 std::get_if<Fortran::parser::AccClause::VectorLength>(
1961 &clause.u)) {
1962 vectorLength.push_back(fir::getBase(converter.genExprValue(
1963 *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
1964 vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
1965 } else if (const auto *ifClause =
1966 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
1967 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
1968 } else if (const auto *selfClause =
1969 std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
1970 const std::optional<Fortran::parser::AccSelfClause> &accSelfClause =
1971 selfClause->v;
1972 if (accSelfClause) {
1973 if (const auto *optCondition =
1974 std::get_if<std::optional<Fortran::parser::ScalarLogicalExpr>>(
1975 &(*accSelfClause).u)) {
1976 if (*optCondition) {
1977 mlir::Value cond = fir::getBase(converter.genExprValue(
1978 *Fortran::semantics::GetExpr(*optCondition), stmtCtx));
1979 selfCond = builder.createConvert(clauseLocation,
1980 builder.getI1Type(), cond);
1982 } else if (const auto *accClauseList =
1983 std::get_if<Fortran::parser::AccObjectList>(
1984 &(*accSelfClause).u)) {
1985 // TODO This would be nicer to be done in canonicalization step.
1986 if (accClauseList->v.size() == 1) {
1987 const auto &accObject = accClauseList->v.front();
1988 if (const auto *designator =
1989 std::get_if<Fortran::parser::Designator>(&accObject.u)) {
1990 if (const auto *name =
1991 Fortran::semantics::getDesignatorNameIfDataRef(
1992 *designator)) {
1993 auto cond = converter.getSymbolAddress(*name->symbol);
1994 selfCond = builder.createConvert(clauseLocation,
1995 builder.getI1Type(), cond);
2000 } else {
2001 addSelfAttr = true;
2003 } else if (const auto *copyClause =
2004 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
2005 auto crtDataStart = dataClauseOperands.size();
2006 genDataOperandOperations<mlir::acc::CopyinOp>(
2007 copyClause->v, converter, semanticsContext, stmtCtx,
2008 dataClauseOperands, mlir::acc::DataClause::acc_copy,
2009 /*structured=*/true, /*implicit=*/false);
2010 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2011 dataClauseOperands.end());
2012 } else if (const auto *copyinClause =
2013 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
2014 genDataOperandOperationsWithModifier<mlir::acc::CopyinOp,
2015 Fortran::parser::AccClause::Copyin>(
2016 copyinClause, converter, semanticsContext, stmtCtx,
2017 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
2018 dataClauseOperands, mlir::acc::DataClause::acc_copyin,
2019 mlir::acc::DataClause::acc_copyin_readonly);
2020 } else if (const auto *copyoutClause =
2021 std::get_if<Fortran::parser::AccClause::Copyout>(
2022 &clause.u)) {
2023 auto crtDataStart = dataClauseOperands.size();
2024 genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
2025 Fortran::parser::AccClause::Copyout>(
2026 copyoutClause, converter, semanticsContext, stmtCtx,
2027 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
2028 dataClauseOperands, mlir::acc::DataClause::acc_copyout,
2029 mlir::acc::DataClause::acc_copyout_zero);
2030 copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2031 dataClauseOperands.end());
2032 } else if (const auto *createClause =
2033 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
2034 auto crtDataStart = dataClauseOperands.size();
2035 genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
2036 Fortran::parser::AccClause::Create>(
2037 createClause, converter, semanticsContext, stmtCtx,
2038 Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands,
2039 mlir::acc::DataClause::acc_create,
2040 mlir::acc::DataClause::acc_create_zero);
2041 createEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2042 dataClauseOperands.end());
2043 } else if (const auto *noCreateClause =
2044 std::get_if<Fortran::parser::AccClause::NoCreate>(
2045 &clause.u)) {
2046 genDataOperandOperations<mlir::acc::NoCreateOp>(
2047 noCreateClause->v, converter, semanticsContext, stmtCtx,
2048 dataClauseOperands, mlir::acc::DataClause::acc_no_create,
2049 /*structured=*/true, /*implicit=*/false);
2050 } else if (const auto *presentClause =
2051 std::get_if<Fortran::parser::AccClause::Present>(
2052 &clause.u)) {
2053 genDataOperandOperations<mlir::acc::PresentOp>(
2054 presentClause->v, converter, semanticsContext, stmtCtx,
2055 dataClauseOperands, mlir::acc::DataClause::acc_present,
2056 /*structured=*/true, /*implicit=*/false);
2057 } else if (const auto *devicePtrClause =
2058 std::get_if<Fortran::parser::AccClause::Deviceptr>(
2059 &clause.u)) {
2060 genDataOperandOperations<mlir::acc::DevicePtrOp>(
2061 devicePtrClause->v, converter, semanticsContext, stmtCtx,
2062 dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
2063 /*structured=*/true, /*implicit=*/false);
2064 } else if (const auto *attachClause =
2065 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
2066 auto crtDataStart = dataClauseOperands.size();
2067 genDataOperandOperations<mlir::acc::AttachOp>(
2068 attachClause->v, converter, semanticsContext, stmtCtx,
2069 dataClauseOperands, mlir::acc::DataClause::acc_attach,
2070 /*structured=*/true, /*implicit=*/false);
2071 attachEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2072 dataClauseOperands.end());
2073 } else if (const auto *privateClause =
2074 std::get_if<Fortran::parser::AccClause::Private>(
2075 &clause.u)) {
2076 if (!outerCombined)
2077 genPrivatizations<mlir::acc::PrivateRecipeOp>(
2078 privateClause->v, converter, semanticsContext, stmtCtx,
2079 privateOperands, privatizations);
2080 } else if (const auto *firstprivateClause =
2081 std::get_if<Fortran::parser::AccClause::Firstprivate>(
2082 &clause.u)) {
2083 genPrivatizations<mlir::acc::FirstprivateRecipeOp>(
2084 firstprivateClause->v, converter, semanticsContext, stmtCtx,
2085 firstprivateOperands, firstPrivatizations);
2086 } else if (const auto *reductionClause =
2087 std::get_if<Fortran::parser::AccClause::Reduction>(
2088 &clause.u)) {
2089 // A reduction clause on a combined construct is treated as if it appeared
2090 // on the loop construct. So don't generate a reduction clause when it is
2091 // combined - delay it to the loop. However, a reduction clause on a
2092 // combined construct implies a copy clause so issue an implicit copy
2093 // instead.
2094 if (!outerCombined) {
2095 genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
2096 reductionOperands, reductionRecipes);
2097 } else {
2098 auto crtDataStart = dataClauseOperands.size();
2099 genDataOperandOperations<mlir::acc::CopyinOp>(
2100 std::get<Fortran::parser::AccObjectList>(reductionClause->v.t),
2101 converter, semanticsContext, stmtCtx, dataClauseOperands,
2102 mlir::acc::DataClause::acc_reduction,
2103 /*structured=*/true, /*implicit=*/true);
2104 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2105 dataClauseOperands.end());
2107 } else if (const auto *defaultClause =
2108 std::get_if<Fortran::parser::AccClause::Default>(
2109 &clause.u)) {
2110 if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
2111 hasDefaultNone = true;
2112 else if ((defaultClause->v).v ==
2113 llvm::acc::DefaultValue::ACC_Default_present)
2114 hasDefaultPresent = true;
2115 } else if (const auto *deviceTypeClause =
2116 std::get_if<Fortran::parser::AccClause::DeviceType>(
2117 &clause.u)) {
2118 const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
2119 deviceTypeClause->v;
2120 assert(deviceTypeExprList.v.size() == 1 &&
2121 "expect only one device_type expr");
2122 crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
2123 builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
2127 // Prepare the operand segment size attribute and the operands value range.
2128 llvm::SmallVector<mlir::Value, 8> operands;
2129 llvm::SmallVector<int32_t, 8> operandSegments;
2130 addOperands(operands, operandSegments, async);
2131 addOperands(operands, operandSegments, waitOperands);
2132 if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2133 addOperands(operands, operandSegments, numGangs);
2134 addOperands(operands, operandSegments, numWorkers);
2135 addOperands(operands, operandSegments, vectorLength);
2137 addOperand(operands, operandSegments, ifCond);
2138 addOperand(operands, operandSegments, selfCond);
2139 if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
2140 addOperands(operands, operandSegments, reductionOperands);
2141 addOperands(operands, operandSegments, privateOperands);
2142 addOperands(operands, operandSegments, firstprivateOperands);
2144 addOperands(operands, operandSegments, dataClauseOperands);
2146 Op computeOp;
2147 if constexpr (std::is_same_v<Op, mlir::acc::KernelsOp>)
2148 computeOp = createRegionOp<Op, mlir::acc::TerminatorOp>(
2149 builder, currentLocation, eval, operands, operandSegments,
2150 outerCombined);
2151 else
2152 computeOp = createRegionOp<Op, mlir::acc::YieldOp>(
2153 builder, currentLocation, eval, operands, operandSegments,
2154 outerCombined);
2156 if (addSelfAttr)
2157 computeOp.setSelfAttrAttr(builder.getUnitAttr());
2159 if (hasDefaultNone)
2160 computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::None);
2161 if (hasDefaultPresent)
2162 computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
2164 if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2165 if (!numWorkersDeviceTypes.empty())
2166 computeOp.setNumWorkersDeviceTypeAttr(
2167 mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes));
2168 if (!vectorLengthDeviceTypes.empty())
2169 computeOp.setVectorLengthDeviceTypeAttr(
2170 mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes));
2171 if (!numGangsDeviceTypes.empty())
2172 computeOp.setNumGangsDeviceTypeAttr(
2173 mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes));
2174 if (!numGangsSegments.empty())
2175 computeOp.setNumGangsSegmentsAttr(
2176 builder.getDenseI32ArrayAttr(numGangsSegments));
2178 if (!asyncDeviceTypes.empty())
2179 computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
2180 if (!asyncOnlyDeviceTypes.empty())
2181 computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
2183 if (!waitOperandsDeviceTypes.empty())
2184 computeOp.setWaitOperandsDeviceTypeAttr(
2185 builder.getArrayAttr(waitOperandsDeviceTypes));
2186 if (!waitOperandsSegments.empty())
2187 computeOp.setWaitOperandsSegmentsAttr(
2188 builder.getDenseI32ArrayAttr(waitOperandsSegments));
2189 if (!waitOnlyDeviceTypes.empty())
2190 computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
2192 if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
2193 if (!privatizations.empty())
2194 computeOp.setPrivatizationsAttr(
2195 mlir::ArrayAttr::get(builder.getContext(), privatizations));
2196 if (!reductionRecipes.empty())
2197 computeOp.setReductionRecipesAttr(
2198 mlir::ArrayAttr::get(builder.getContext(), reductionRecipes));
2199 if (!firstPrivatizations.empty())
2200 computeOp.setFirstprivatizationsAttr(
2201 mlir::ArrayAttr::get(builder.getContext(), firstPrivatizations));
2204 auto insPt = builder.saveInsertionPoint();
2205 builder.setInsertionPointAfter(computeOp);
2207 // Create the exit operations after the region.
2208 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>(
2209 builder, copyEntryOperands, /*structured=*/true);
2210 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>(
2211 builder, copyoutEntryOperands, /*structured=*/true);
2212 genDataExitOperations<mlir::acc::AttachOp, mlir::acc::DetachOp>(
2213 builder, attachEntryOperands, /*structured=*/true);
2214 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
2215 builder, createEntryOperands, /*structured=*/true);
2217 builder.restoreInsertionPoint(insPt);
2218 return computeOp;
2221 static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
2222 mlir::Location currentLocation,
2223 Fortran::lower::pft::Evaluation &eval,
2224 Fortran::semantics::SemanticsContext &semanticsContext,
2225 Fortran::lower::StatementContext &stmtCtx,
2226 const Fortran::parser::AccClauseList &accClauseList) {
2227 mlir::Value ifCond, waitDevnum;
2228 llvm::SmallVector<mlir::Value> attachEntryOperands, createEntryOperands,
2229 copyEntryOperands, copyoutEntryOperands, dataClauseOperands, waitOperands,
2230 async;
2231 llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
2232 waitOperandsDeviceTypes, waitOnlyDeviceTypes;
2233 llvm::SmallVector<int32_t> waitOperandsSegments;
2235 bool hasDefaultNone = false;
2236 bool hasDefaultPresent = false;
2238 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2240 // device_type attribute is set to `none` until a device_type clause is
2241 // encountered.
2242 auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
2243 builder.getContext(), mlir::acc::DeviceType::None);
2245 // Lower clauses values mapped to operands.
2246 // Keep track of each group of operands separately as clauses can appear
2247 // more than once.
2248 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2249 mlir::Location clauseLocation = converter.genLocation(clause.source);
2250 if (const auto *ifClause =
2251 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2252 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2253 } else if (const auto *copyClause =
2254 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
2255 auto crtDataStart = dataClauseOperands.size();
2256 genDataOperandOperations<mlir::acc::CopyinOp>(
2257 copyClause->v, converter, semanticsContext, stmtCtx,
2258 dataClauseOperands, mlir::acc::DataClause::acc_copy,
2259 /*structured=*/true, /*implicit=*/false);
2260 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2261 dataClauseOperands.end());
2262 } else if (const auto *copyinClause =
2263 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
2264 genDataOperandOperationsWithModifier<mlir::acc::CopyinOp,
2265 Fortran::parser::AccClause::Copyin>(
2266 copyinClause, converter, semanticsContext, stmtCtx,
2267 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
2268 dataClauseOperands, mlir::acc::DataClause::acc_copyin,
2269 mlir::acc::DataClause::acc_copyin_readonly);
2270 } else if (const auto *copyoutClause =
2271 std::get_if<Fortran::parser::AccClause::Copyout>(
2272 &clause.u)) {
2273 auto crtDataStart = dataClauseOperands.size();
2274 genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
2275 Fortran::parser::AccClause::Copyout>(
2276 copyoutClause, converter, semanticsContext, stmtCtx,
2277 Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands,
2278 mlir::acc::DataClause::acc_copyout,
2279 mlir::acc::DataClause::acc_copyout_zero);
2280 copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2281 dataClauseOperands.end());
2282 } else if (const auto *createClause =
2283 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
2284 auto crtDataStart = dataClauseOperands.size();
2285 genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
2286 Fortran::parser::AccClause::Create>(
2287 createClause, converter, semanticsContext, stmtCtx,
2288 Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands,
2289 mlir::acc::DataClause::acc_create,
2290 mlir::acc::DataClause::acc_create_zero);
2291 createEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2292 dataClauseOperands.end());
2293 } else if (const auto *noCreateClause =
2294 std::get_if<Fortran::parser::AccClause::NoCreate>(
2295 &clause.u)) {
2296 genDataOperandOperations<mlir::acc::NoCreateOp>(
2297 noCreateClause->v, converter, semanticsContext, stmtCtx,
2298 dataClauseOperands, mlir::acc::DataClause::acc_no_create,
2299 /*structured=*/true, /*implicit=*/false);
2300 } else if (const auto *presentClause =
2301 std::get_if<Fortran::parser::AccClause::Present>(
2302 &clause.u)) {
2303 genDataOperandOperations<mlir::acc::PresentOp>(
2304 presentClause->v, converter, semanticsContext, stmtCtx,
2305 dataClauseOperands, mlir::acc::DataClause::acc_present,
2306 /*structured=*/true, /*implicit=*/false);
2307 } else if (const auto *deviceptrClause =
2308 std::get_if<Fortran::parser::AccClause::Deviceptr>(
2309 &clause.u)) {
2310 genDataOperandOperations<mlir::acc::DevicePtrOp>(
2311 deviceptrClause->v, converter, semanticsContext, stmtCtx,
2312 dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
2313 /*structured=*/true, /*implicit=*/false);
2314 } else if (const auto *attachClause =
2315 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
2316 auto crtDataStart = dataClauseOperands.size();
2317 genDataOperandOperations<mlir::acc::AttachOp>(
2318 attachClause->v, converter, semanticsContext, stmtCtx,
2319 dataClauseOperands, mlir::acc::DataClause::acc_attach,
2320 /*structured=*/true, /*implicit=*/false);
2321 attachEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
2322 dataClauseOperands.end());
2323 } else if (const auto *asyncClause =
2324 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
2325 genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
2326 asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
2327 } else if (const auto *waitClause =
2328 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
2329 genWaitClause(converter, waitClause, waitOperands,
2330 waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2331 waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
2332 stmtCtx);
2333 } else if(const auto *defaultClause =
2334 std::get_if<Fortran::parser::AccClause::Default>(&clause.u)) {
2335 if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
2336 hasDefaultNone = true;
2337 else if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_present)
2338 hasDefaultPresent = true;
2342 // Prepare the operand segment size attribute and the operands value range.
2343 llvm::SmallVector<mlir::Value> operands;
2344 llvm::SmallVector<int32_t> operandSegments;
2345 addOperand(operands, operandSegments, ifCond);
2346 addOperands(operands, operandSegments, async);
2347 addOperand(operands, operandSegments, waitDevnum);
2348 addOperands(operands, operandSegments, waitOperands);
2349 addOperands(operands, operandSegments, dataClauseOperands);
2351 if (dataClauseOperands.empty() && !hasDefaultNone && !hasDefaultPresent)
2352 return;
2354 auto dataOp = createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>(
2355 builder, currentLocation, eval, operands, operandSegments);
2357 if (!asyncDeviceTypes.empty())
2358 dataOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
2359 if (!asyncOnlyDeviceTypes.empty())
2360 dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
2361 if (!waitOperandsDeviceTypes.empty())
2362 dataOp.setWaitOperandsDeviceTypeAttr(
2363 builder.getArrayAttr(waitOperandsDeviceTypes));
2364 if (!waitOperandsSegments.empty())
2365 dataOp.setWaitOperandsSegmentsAttr(
2366 builder.getDenseI32ArrayAttr(waitOperandsSegments));
2367 if (!waitOnlyDeviceTypes.empty())
2368 dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
2370 if (hasDefaultNone)
2371 dataOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::None);
2372 if (hasDefaultPresent)
2373 dataOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
2375 auto insPt = builder.saveInsertionPoint();
2376 builder.setInsertionPointAfter(dataOp);
2378 // Create the exit operations after the region.
2379 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>(
2380 builder, copyEntryOperands, /*structured=*/true);
2381 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>(
2382 builder, copyoutEntryOperands, /*structured=*/true);
2383 genDataExitOperations<mlir::acc::AttachOp, mlir::acc::DetachOp>(
2384 builder, attachEntryOperands, /*structured=*/true);
2385 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
2386 builder, createEntryOperands, /*structured=*/true);
2388 builder.restoreInsertionPoint(insPt);
2391 static void
2392 genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
2393 mlir::Location currentLocation,
2394 Fortran::lower::pft::Evaluation &eval,
2395 Fortran::semantics::SemanticsContext &semanticsContext,
2396 Fortran::lower::StatementContext &stmtCtx,
2397 const Fortran::parser::AccClauseList &accClauseList) {
2398 mlir::Value ifCond;
2399 llvm::SmallVector<mlir::Value> dataOperands;
2400 bool addIfPresentAttr = false;
2402 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2404 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2405 mlir::Location clauseLocation = converter.genLocation(clause.source);
2406 if (const auto *ifClause =
2407 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2408 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2409 } else if (const auto *useDevice =
2410 std::get_if<Fortran::parser::AccClause::UseDevice>(
2411 &clause.u)) {
2412 genDataOperandOperations<mlir::acc::UseDeviceOp>(
2413 useDevice->v, converter, semanticsContext, stmtCtx, dataOperands,
2414 mlir::acc::DataClause::acc_use_device,
2415 /*structured=*/true, /*implicit=*/false);
2416 } else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) {
2417 addIfPresentAttr = true;
2421 if (ifCond) {
2422 if (auto cst =
2423 mlir::dyn_cast<mlir::arith::ConstantOp>(ifCond.getDefiningOp()))
2424 if (auto boolAttr = cst.getValue().dyn_cast<mlir::BoolAttr>()) {
2425 if (boolAttr.getValue()) {
2426 // get rid of the if condition if it is always true.
2427 ifCond = mlir::Value();
2428 } else {
2429 // Do not generate the acc.host_data op if the if condition is always
2430 // false.
2431 return;
2436 // Prepare the operand segment size attribute and the operands value range.
2437 llvm::SmallVector<mlir::Value> operands;
2438 llvm::SmallVector<int32_t> operandSegments;
2439 addOperand(operands, operandSegments, ifCond);
2440 addOperands(operands, operandSegments, dataOperands);
2442 auto hostDataOp =
2443 createRegionOp<mlir::acc::HostDataOp, mlir::acc::TerminatorOp>(
2444 builder, currentLocation, eval, operands, operandSegments);
2446 if (addIfPresentAttr)
2447 hostDataOp.setIfPresentAttr(builder.getUnitAttr());
2450 static void
2451 genACC(Fortran::lower::AbstractConverter &converter,
2452 Fortran::semantics::SemanticsContext &semanticsContext,
2453 Fortran::lower::pft::Evaluation &eval,
2454 const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
2455 const auto &beginBlockDirective =
2456 std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
2457 const auto &blockDirective =
2458 std::get<Fortran::parser::AccBlockDirective>(beginBlockDirective.t);
2459 const auto &accClauseList =
2460 std::get<Fortran::parser::AccClauseList>(beginBlockDirective.t);
2462 mlir::Location currentLocation = converter.genLocation(blockDirective.source);
2463 Fortran::lower::StatementContext stmtCtx;
2465 if (blockDirective.v == llvm::acc::ACCD_parallel) {
2466 createComputeOp<mlir::acc::ParallelOp>(converter, currentLocation, eval,
2467 semanticsContext, stmtCtx,
2468 accClauseList);
2469 } else if (blockDirective.v == llvm::acc::ACCD_data) {
2470 genACCDataOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
2471 accClauseList);
2472 } else if (blockDirective.v == llvm::acc::ACCD_serial) {
2473 createComputeOp<mlir::acc::SerialOp>(converter, currentLocation, eval,
2474 semanticsContext, stmtCtx,
2475 accClauseList);
2476 } else if (blockDirective.v == llvm::acc::ACCD_kernels) {
2477 createComputeOp<mlir::acc::KernelsOp>(converter, currentLocation, eval,
2478 semanticsContext, stmtCtx,
2479 accClauseList);
2480 } else if (blockDirective.v == llvm::acc::ACCD_host_data) {
2481 genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
2482 stmtCtx, accClauseList);
2486 static void
2487 genACC(Fortran::lower::AbstractConverter &converter,
2488 Fortran::semantics::SemanticsContext &semanticsContext,
2489 Fortran::lower::pft::Evaluation &eval,
2490 const Fortran::parser::OpenACCCombinedConstruct &combinedConstruct) {
2491 const auto &beginCombinedDirective =
2492 std::get<Fortran::parser::AccBeginCombinedDirective>(combinedConstruct.t);
2493 const auto &combinedDirective =
2494 std::get<Fortran::parser::AccCombinedDirective>(beginCombinedDirective.t);
2495 const auto &accClauseList =
2496 std::get<Fortran::parser::AccClauseList>(beginCombinedDirective.t);
2498 mlir::Location currentLocation =
2499 converter.genLocation(beginCombinedDirective.source);
2500 Fortran::lower::StatementContext stmtCtx;
2502 if (combinedDirective.v == llvm::acc::ACCD_kernels_loop) {
2503 createComputeOp<mlir::acc::KernelsOp>(
2504 converter, currentLocation, eval, semanticsContext, stmtCtx,
2505 accClauseList, /*outerCombined=*/true);
2506 createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
2507 accClauseList);
2508 } else if (combinedDirective.v == llvm::acc::ACCD_parallel_loop) {
2509 createComputeOp<mlir::acc::ParallelOp>(
2510 converter, currentLocation, eval, semanticsContext, stmtCtx,
2511 accClauseList, /*outerCombined=*/true);
2512 createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
2513 accClauseList);
2514 } else if (combinedDirective.v == llvm::acc::ACCD_serial_loop) {
2515 createComputeOp<mlir::acc::SerialOp>(converter, currentLocation, eval,
2516 semanticsContext, stmtCtx,
2517 accClauseList, /*outerCombined=*/true);
2518 createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
2519 accClauseList);
2520 } else {
2521 llvm::report_fatal_error("Unknown combined construct encountered");
2525 static void
2526 genACCEnterDataOp(Fortran::lower::AbstractConverter &converter,
2527 mlir::Location currentLocation,
2528 Fortran::semantics::SemanticsContext &semanticsContext,
2529 Fortran::lower::StatementContext &stmtCtx,
2530 const Fortran::parser::AccClauseList &accClauseList) {
2531 mlir::Value ifCond, async, waitDevnum;
2532 llvm::SmallVector<mlir::Value> waitOperands, dataClauseOperands;
2534 // Async, wait and self clause have optional values but can be present with
2535 // no value as well. When there is no value, the op has an attribute to
2536 // represent the clause.
2537 bool addAsyncAttr = false;
2538 bool addWaitAttr = false;
2540 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2542 // Lower clauses values mapped to operands.
2543 // Keep track of each group of operands separately as clauses can appear
2544 // more than once.
2545 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2546 mlir::Location clauseLocation = converter.genLocation(clause.source);
2547 if (const auto *ifClause =
2548 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2549 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2550 } else if (const auto *asyncClause =
2551 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
2552 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
2553 } else if (const auto *waitClause =
2554 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
2555 genWaitClause(converter, waitClause, waitOperands, waitDevnum,
2556 addWaitAttr, stmtCtx);
2557 } else if (const auto *copyinClause =
2558 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
2559 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
2560 copyinClause->v;
2561 const auto &accObjectList =
2562 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
2563 genDataOperandOperations<mlir::acc::CopyinOp>(
2564 accObjectList, converter, semanticsContext, stmtCtx,
2565 dataClauseOperands, mlir::acc::DataClause::acc_copyin, false,
2566 /*implicit=*/false);
2567 } else if (const auto *createClause =
2568 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
2569 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
2570 createClause->v;
2571 const auto &accObjectList =
2572 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
2573 const auto &modifier =
2574 std::get<std::optional<Fortran::parser::AccDataModifier>>(
2575 listWithModifier.t);
2576 mlir::acc::DataClause clause = mlir::acc::DataClause::acc_create;
2577 if (modifier &&
2578 (*modifier).v == Fortran::parser::AccDataModifier::Modifier::Zero)
2579 clause = mlir::acc::DataClause::acc_create_zero;
2580 genDataOperandOperations<mlir::acc::CreateOp>(
2581 accObjectList, converter, semanticsContext, stmtCtx,
2582 dataClauseOperands, clause, false, /*implicit=*/false);
2583 } else if (const auto *attachClause =
2584 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
2585 genDataOperandOperations<mlir::acc::AttachOp>(
2586 attachClause->v, converter, semanticsContext, stmtCtx,
2587 dataClauseOperands, mlir::acc::DataClause::acc_attach, false,
2588 /*implicit=*/false);
2589 } else {
2590 llvm::report_fatal_error(
2591 "Unknown clause in ENTER DATA directive lowering");
2595 // Prepare the operand segment size attribute and the operands value range.
2596 llvm::SmallVector<mlir::Value, 16> operands;
2597 llvm::SmallVector<int32_t, 8> operandSegments;
2598 addOperand(operands, operandSegments, ifCond);
2599 addOperand(operands, operandSegments, async);
2600 addOperand(operands, operandSegments, waitDevnum);
2601 addOperands(operands, operandSegments, waitOperands);
2602 addOperands(operands, operandSegments, dataClauseOperands);
2604 mlir::acc::EnterDataOp enterDataOp = createSimpleOp<mlir::acc::EnterDataOp>(
2605 firOpBuilder, currentLocation, operands, operandSegments);
2607 if (addAsyncAttr)
2608 enterDataOp.setAsyncAttr(firOpBuilder.getUnitAttr());
2609 if (addWaitAttr)
2610 enterDataOp.setWaitAttr(firOpBuilder.getUnitAttr());
2613 static void
2614 genACCExitDataOp(Fortran::lower::AbstractConverter &converter,
2615 mlir::Location currentLocation,
2616 Fortran::semantics::SemanticsContext &semanticsContext,
2617 Fortran::lower::StatementContext &stmtCtx,
2618 const Fortran::parser::AccClauseList &accClauseList) {
2619 mlir::Value ifCond, async, waitDevnum;
2620 llvm::SmallVector<mlir::Value> waitOperands, dataClauseOperands,
2621 copyoutOperands, deleteOperands, detachOperands;
2623 // Async and wait clause have optional values but can be present with
2624 // no value as well. When there is no value, the op has an attribute to
2625 // represent the clause.
2626 bool addAsyncAttr = false;
2627 bool addWaitAttr = false;
2628 bool addFinalizeAttr = false;
2630 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2632 // Lower clauses values mapped to operands.
2633 // Keep track of each group of operands separately as clauses can appear
2634 // more than once.
2635 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2636 mlir::Location clauseLocation = converter.genLocation(clause.source);
2637 if (const auto *ifClause =
2638 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2639 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2640 } else if (const auto *asyncClause =
2641 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
2642 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
2643 } else if (const auto *waitClause =
2644 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
2645 genWaitClause(converter, waitClause, waitOperands, waitDevnum,
2646 addWaitAttr, stmtCtx);
2647 } else if (const auto *copyoutClause =
2648 std::get_if<Fortran::parser::AccClause::Copyout>(
2649 &clause.u)) {
2650 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
2651 copyoutClause->v;
2652 const auto &accObjectList =
2653 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
2654 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
2655 accObjectList, converter, semanticsContext, stmtCtx, copyoutOperands,
2656 mlir::acc::DataClause::acc_copyout, false, /*implicit=*/false);
2657 } else if (const auto *deleteClause =
2658 std::get_if<Fortran::parser::AccClause::Delete>(&clause.u)) {
2659 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
2660 deleteClause->v, converter, semanticsContext, stmtCtx, deleteOperands,
2661 mlir::acc::DataClause::acc_delete, false, /*implicit=*/false);
2662 } else if (const auto *detachClause =
2663 std::get_if<Fortran::parser::AccClause::Detach>(&clause.u)) {
2664 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
2665 detachClause->v, converter, semanticsContext, stmtCtx, detachOperands,
2666 mlir::acc::DataClause::acc_detach, false, /*implicit=*/false);
2667 } else if (std::get_if<Fortran::parser::AccClause::Finalize>(&clause.u)) {
2668 addFinalizeAttr = true;
2672 dataClauseOperands.append(copyoutOperands);
2673 dataClauseOperands.append(deleteOperands);
2674 dataClauseOperands.append(detachOperands);
2676 // Prepare the operand segment size attribute and the operands value range.
2677 llvm::SmallVector<mlir::Value, 14> operands;
2678 llvm::SmallVector<int32_t, 7> operandSegments;
2679 addOperand(operands, operandSegments, ifCond);
2680 addOperand(operands, operandSegments, async);
2681 addOperand(operands, operandSegments, waitDevnum);
2682 addOperands(operands, operandSegments, waitOperands);
2683 addOperands(operands, operandSegments, dataClauseOperands);
2685 mlir::acc::ExitDataOp exitDataOp = createSimpleOp<mlir::acc::ExitDataOp>(
2686 builder, currentLocation, operands, operandSegments);
2688 if (addAsyncAttr)
2689 exitDataOp.setAsyncAttr(builder.getUnitAttr());
2690 if (addWaitAttr)
2691 exitDataOp.setWaitAttr(builder.getUnitAttr());
2692 if (addFinalizeAttr)
2693 exitDataOp.setFinalizeAttr(builder.getUnitAttr());
2695 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::CopyoutOp>(
2696 builder, copyoutOperands, /*structured=*/false);
2697 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::DeleteOp>(
2698 builder, deleteOperands, /*structured=*/false);
2699 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::DetachOp>(
2700 builder, detachOperands, /*structured=*/false);
2703 template <typename Op>
2704 static void
2705 genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter,
2706 mlir::Location currentLocation,
2707 const Fortran::parser::AccClauseList &accClauseList) {
2708 mlir::Value ifCond, deviceNum;
2710 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2711 Fortran::lower::StatementContext stmtCtx;
2712 llvm::SmallVector<mlir::Attribute> deviceTypes;
2714 // Lower clauses values mapped to operands.
2715 // Keep track of each group of operands separately as clauses can appear
2716 // more than once.
2717 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2718 mlir::Location clauseLocation = converter.genLocation(clause.source);
2719 if (const auto *ifClause =
2720 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2721 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2722 } else if (const auto *deviceNumClause =
2723 std::get_if<Fortran::parser::AccClause::DeviceNum>(
2724 &clause.u)) {
2725 deviceNum = fir::getBase(converter.genExprValue(
2726 *Fortran::semantics::GetExpr(deviceNumClause->v), stmtCtx));
2727 } else if (const auto *deviceTypeClause =
2728 std::get_if<Fortran::parser::AccClause::DeviceType>(
2729 &clause.u)) {
2730 gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
2731 deviceTypes, stmtCtx);
2735 // Prepare the operand segment size attribute and the operands value range.
2736 llvm::SmallVector<mlir::Value, 6> operands;
2737 llvm::SmallVector<int32_t, 2> operandSegments;
2739 addOperand(operands, operandSegments, deviceNum);
2740 addOperand(operands, operandSegments, ifCond);
2742 Op op =
2743 createSimpleOp<Op>(builder, currentLocation, operands, operandSegments);
2744 if (!deviceTypes.empty())
2745 op.setDeviceTypesAttr(
2746 mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
2749 void genACCSetOp(Fortran::lower::AbstractConverter &converter,
2750 mlir::Location currentLocation,
2751 const Fortran::parser::AccClauseList &accClauseList) {
2752 mlir::Value ifCond, deviceNum, defaultAsync;
2753 llvm::SmallVector<mlir::Value> deviceTypeOperands;
2755 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2756 Fortran::lower::StatementContext stmtCtx;
2757 llvm::SmallVector<mlir::Attribute> deviceTypes;
2759 // Lower clauses values mapped to operands.
2760 // Keep track of each group of operands separately as clauses can appear
2761 // more than once.
2762 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2763 mlir::Location clauseLocation = converter.genLocation(clause.source);
2764 if (const auto *ifClause =
2765 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2766 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2767 } else if (const auto *defaultAsyncClause =
2768 std::get_if<Fortran::parser::AccClause::DefaultAsync>(
2769 &clause.u)) {
2770 defaultAsync = fir::getBase(converter.genExprValue(
2771 *Fortran::semantics::GetExpr(defaultAsyncClause->v), stmtCtx));
2772 } else if (const auto *deviceNumClause =
2773 std::get_if<Fortran::parser::AccClause::DeviceNum>(
2774 &clause.u)) {
2775 deviceNum = fir::getBase(converter.genExprValue(
2776 *Fortran::semantics::GetExpr(deviceNumClause->v), stmtCtx));
2777 } else if (const auto *deviceTypeClause =
2778 std::get_if<Fortran::parser::AccClause::DeviceType>(
2779 &clause.u)) {
2780 gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
2781 deviceTypes, stmtCtx);
2785 // Prepare the operand segment size attribute and the operands value range.
2786 llvm::SmallVector<mlir::Value> operands;
2787 llvm::SmallVector<int32_t, 3> operandSegments;
2788 addOperand(operands, operandSegments, defaultAsync);
2789 addOperand(operands, operandSegments, deviceNum);
2790 addOperand(operands, operandSegments, ifCond);
2792 auto op = createSimpleOp<mlir::acc::SetOp>(builder, currentLocation, operands,
2793 operandSegments);
2794 if (!deviceTypes.empty()) {
2795 assert(deviceTypes.size() == 1 && "expect only one value for acc.set");
2796 op.setDeviceTypeAttr(mlir::cast<mlir::acc::DeviceTypeAttr>(deviceTypes[0]));
2800 static void
2801 genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
2802 mlir::Location currentLocation,
2803 Fortran::semantics::SemanticsContext &semanticsContext,
2804 Fortran::lower::StatementContext &stmtCtx,
2805 const Fortran::parser::AccClauseList &accClauseList) {
2806 mlir::Value ifCond, async, waitDevnum;
2807 llvm::SmallVector<mlir::Value> dataClauseOperands, updateHostOperands,
2808 waitOperands, deviceTypeOperands;
2809 llvm::SmallVector<mlir::Attribute> deviceTypes;
2811 // Async and wait clause have optional values but can be present with
2812 // no value as well. When there is no value, the op has an attribute to
2813 // represent the clause.
2814 bool addAsyncAttr = false;
2815 bool addWaitAttr = false;
2816 bool addIfPresentAttr = false;
2818 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2820 // Lower clauses values mapped to operands.
2821 // Keep track of each group of operands separately as clauses can appear
2822 // more than once.
2823 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2824 mlir::Location clauseLocation = converter.genLocation(clause.source);
2825 if (const auto *ifClause =
2826 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2827 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2828 } else if (const auto *asyncClause =
2829 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
2830 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
2831 } else if (const auto *waitClause =
2832 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
2833 genWaitClause(converter, waitClause, waitOperands, waitDevnum,
2834 addWaitAttr, stmtCtx);
2835 } else if (const auto *deviceTypeClause =
2836 std::get_if<Fortran::parser::AccClause::DeviceType>(
2837 &clause.u)) {
2838 gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
2839 deviceTypes, stmtCtx);
2840 } else if (const auto *hostClause =
2841 std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
2842 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
2843 hostClause->v, converter, semanticsContext, stmtCtx,
2844 updateHostOperands, mlir::acc::DataClause::acc_update_host, false,
2845 /*implicit=*/false);
2846 } else if (const auto *deviceClause =
2847 std::get_if<Fortran::parser::AccClause::Device>(&clause.u)) {
2848 genDataOperandOperations<mlir::acc::UpdateDeviceOp>(
2849 deviceClause->v, converter, semanticsContext, stmtCtx,
2850 dataClauseOperands, mlir::acc::DataClause::acc_update_device, false,
2851 /*implicit=*/false);
2852 } else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) {
2853 addIfPresentAttr = true;
2854 } else if (const auto *selfClause =
2855 std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
2856 const std::optional<Fortran::parser::AccSelfClause> &accSelfClause =
2857 selfClause->v;
2858 const auto *accObjectList =
2859 std::get_if<Fortran::parser::AccObjectList>(&(*accSelfClause).u);
2860 assert(accObjectList && "expect AccObjectList");
2861 genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
2862 *accObjectList, converter, semanticsContext, stmtCtx,
2863 updateHostOperands, mlir::acc::DataClause::acc_update_self, false,
2864 /*implicit=*/false);
2868 dataClauseOperands.append(updateHostOperands);
2870 // Prepare the operand segment size attribute and the operands value range.
2871 llvm::SmallVector<mlir::Value> operands;
2872 llvm::SmallVector<int32_t> operandSegments;
2873 addOperand(operands, operandSegments, ifCond);
2874 addOperand(operands, operandSegments, async);
2875 addOperand(operands, operandSegments, waitDevnum);
2876 addOperands(operands, operandSegments, waitOperands);
2877 addOperands(operands, operandSegments, dataClauseOperands);
2879 mlir::acc::UpdateOp updateOp = createSimpleOp<mlir::acc::UpdateOp>(
2880 builder, currentLocation, operands, operandSegments);
2881 if (!deviceTypes.empty())
2882 updateOp.setDeviceTypesAttr(
2883 mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
2885 genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::UpdateHostOp>(
2886 builder, updateHostOperands, /*structured=*/false);
2888 if (addAsyncAttr)
2889 updateOp.setAsyncAttr(builder.getUnitAttr());
2890 if (addWaitAttr)
2891 updateOp.setWaitAttr(builder.getUnitAttr());
2892 if (addIfPresentAttr)
2893 updateOp.setIfPresentAttr(builder.getUnitAttr());
2896 static void
2897 genACC(Fortran::lower::AbstractConverter &converter,
2898 Fortran::semantics::SemanticsContext &semanticsContext,
2899 const Fortran::parser::OpenACCStandaloneConstruct &standaloneConstruct) {
2900 const auto &standaloneDirective =
2901 std::get<Fortran::parser::AccStandaloneDirective>(standaloneConstruct.t);
2902 const auto &accClauseList =
2903 std::get<Fortran::parser::AccClauseList>(standaloneConstruct.t);
2905 mlir::Location currentLocation =
2906 converter.genLocation(standaloneDirective.source);
2907 Fortran::lower::StatementContext stmtCtx;
2909 if (standaloneDirective.v == llvm::acc::Directive::ACCD_enter_data) {
2910 genACCEnterDataOp(converter, currentLocation, semanticsContext, stmtCtx,
2911 accClauseList);
2912 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_exit_data) {
2913 genACCExitDataOp(converter, currentLocation, semanticsContext, stmtCtx,
2914 accClauseList);
2915 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_init) {
2916 genACCInitShutdownOp<mlir::acc::InitOp>(converter, currentLocation,
2917 accClauseList);
2918 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_shutdown) {
2919 genACCInitShutdownOp<mlir::acc::ShutdownOp>(converter, currentLocation,
2920 accClauseList);
2921 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_set) {
2922 genACCSetOp(converter, currentLocation, accClauseList);
2923 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_update) {
2924 genACCUpdateOp(converter, currentLocation, semanticsContext, stmtCtx,
2925 accClauseList);
2929 static void genACC(Fortran::lower::AbstractConverter &converter,
2930 const Fortran::parser::OpenACCWaitConstruct &waitConstruct) {
2932 const auto &waitArgument =
2933 std::get<std::optional<Fortran::parser::AccWaitArgument>>(
2934 waitConstruct.t);
2935 const auto &accClauseList =
2936 std::get<Fortran::parser::AccClauseList>(waitConstruct.t);
2938 mlir::Value ifCond, waitDevnum, async;
2939 llvm::SmallVector<mlir::Value> waitOperands;
2941 // Async clause have optional values but can be present with
2942 // no value as well. When there is no value, the op has an attribute to
2943 // represent the clause.
2944 bool addAsyncAttr = false;
2946 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2947 mlir::Location currentLocation = converter.genLocation(waitConstruct.source);
2948 Fortran::lower::StatementContext stmtCtx;
2950 if (waitArgument) { // wait has a value.
2951 const Fortran::parser::AccWaitArgument &waitArg = *waitArgument;
2952 const auto &waitList =
2953 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
2954 for (const Fortran::parser::ScalarIntExpr &value : waitList) {
2955 mlir::Value v = fir::getBase(
2956 converter.genExprValue(*Fortran::semantics::GetExpr(value), stmtCtx));
2957 waitOperands.push_back(v);
2960 const auto &waitDevnumValue =
2961 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
2962 if (waitDevnumValue)
2963 waitDevnum = fir::getBase(converter.genExprValue(
2964 *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
2967 // Lower clauses values mapped to operands.
2968 // Keep track of each group of operands separately as clauses can appear
2969 // more than once.
2970 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
2971 mlir::Location clauseLocation = converter.genLocation(clause.source);
2972 if (const auto *ifClause =
2973 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
2974 genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
2975 } else if (const auto *asyncClause =
2976 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
2977 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
2981 // Prepare the operand segment size attribute and the operands value range.
2982 llvm::SmallVector<mlir::Value> operands;
2983 llvm::SmallVector<int32_t> operandSegments;
2984 addOperands(operands, operandSegments, waitOperands);
2985 addOperand(operands, operandSegments, async);
2986 addOperand(operands, operandSegments, waitDevnum);
2987 addOperand(operands, operandSegments, ifCond);
2989 mlir::acc::WaitOp waitOp = createSimpleOp<mlir::acc::WaitOp>(
2990 firOpBuilder, currentLocation, operands, operandSegments);
2992 if (addAsyncAttr)
2993 waitOp.setAsyncAttr(firOpBuilder.getUnitAttr());
2996 template <typename GlobalOp, typename EntryOp, typename DeclareOp,
2997 typename ExitOp>
2998 static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder,
2999 fir::FirOpBuilder &builder,
3000 mlir::Location loc, fir::GlobalOp globalOp,
3001 mlir::acc::DataClause clause,
3002 const std::string declareGlobalName,
3003 bool implicit, std::stringstream &asFortran) {
3004 GlobalOp declareGlobalOp =
3005 modBuilder.create<GlobalOp>(loc, declareGlobalName);
3006 builder.createBlock(&declareGlobalOp.getRegion(),
3007 declareGlobalOp.getRegion().end(), {}, {});
3008 builder.setInsertionPointToEnd(&declareGlobalOp.getRegion().back());
3010 fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
3011 loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
3012 addDeclareAttr(builder, addrOp, clause);
3014 llvm::SmallVector<mlir::Value> bounds;
3015 EntryOp entryOp = createDataEntryOp<EntryOp>(
3016 builder, loc, addrOp.getResTy(), asFortran, bounds,
3017 /*structured=*/false, implicit, clause, addrOp.getResTy().getType());
3018 if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>)
3019 builder.create<DeclareOp>(
3020 loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
3021 mlir::ValueRange(entryOp.getAccPtr()));
3022 else
3023 builder.create<DeclareOp>(loc, mlir::Value{},
3024 mlir::ValueRange(entryOp.getAccPtr()));
3025 if constexpr (std::is_same_v<GlobalOp, mlir::acc::GlobalDestructorOp>) {
3026 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
3027 entryOp.getBounds(), entryOp.getDataClause(),
3028 /*structured=*/false, /*implicit=*/false,
3029 builder.getStringAttr(*entryOp.getName()));
3031 builder.create<mlir::acc::TerminatorOp>(loc);
3032 modBuilder.setInsertionPointAfter(declareGlobalOp);
3035 template <typename EntryOp>
3036 static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
3037 fir::FirOpBuilder &builder,
3038 mlir::Location loc, fir::GlobalOp &globalOp,
3039 mlir::acc::DataClause clause) {
3040 std::stringstream registerFuncName;
3041 registerFuncName << globalOp.getSymName().str()
3042 << Fortran::lower::declarePostAllocSuffix.str();
3043 auto registerFuncOp =
3044 createDeclareFunc(modBuilder, builder, loc, registerFuncName.str());
3046 fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
3047 loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
3049 std::stringstream asFortran;
3050 asFortran << Fortran::lower::mangle::demangleName(globalOp.getSymName());
3051 std::stringstream asFortranDesc;
3052 asFortranDesc << asFortran.str() << accFirDescriptorPostfix.str();
3053 llvm::SmallVector<mlir::Value> bounds;
3055 // Updating descriptor must occur before the mapping of the data so that
3056 // attached data pointer is not overwritten.
3057 mlir::acc::UpdateDeviceOp updateDeviceOp =
3058 createDataEntryOp<mlir::acc::UpdateDeviceOp>(
3059 builder, loc, addrOp, asFortranDesc, bounds,
3060 /*structured=*/false, /*implicit=*/true,
3061 mlir::acc::DataClause::acc_update_device, addrOp.getType());
3062 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
3063 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
3064 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
3066 auto loadOp = builder.create<fir::LoadOp>(loc, addrOp.getResult());
3067 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
3068 addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
3069 EntryOp entryOp = createDataEntryOp<EntryOp>(
3070 builder, loc, boxAddrOp.getResult(), asFortran, bounds,
3071 /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType());
3072 builder.create<mlir::acc::DeclareEnterOp>(
3073 loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
3074 mlir::ValueRange(entryOp.getAccPtr()));
3076 modBuilder.setInsertionPointAfter(registerFuncOp);
3079 /// Action to be performed on deallocation are split in two distinct functions.
3080 /// - Pre deallocation function includes all the action to be performed before
3081 /// the actual deallocation is done on the host side.
3082 /// - Post deallocation function includes update to the descriptor.
3083 template <typename ExitOp>
3084 static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
3085 fir::FirOpBuilder &builder,
3086 mlir::Location loc,
3087 fir::GlobalOp &globalOp,
3088 mlir::acc::DataClause clause) {
3090 // Generate the pre dealloc function.
3091 std::stringstream preDeallocFuncName;
3092 preDeallocFuncName << globalOp.getSymName().str()
3093 << Fortran::lower::declarePreDeallocSuffix.str();
3094 auto preDeallocOp =
3095 createDeclareFunc(modBuilder, builder, loc, preDeallocFuncName.str());
3096 fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
3097 loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
3098 auto loadOp = builder.create<fir::LoadOp>(loc, addrOp.getResult());
3099 fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
3100 addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
3102 std::stringstream asFortran;
3103 asFortran << Fortran::lower::mangle::demangleName(globalOp.getSymName());
3104 llvm::SmallVector<mlir::Value> bounds;
3105 mlir::acc::GetDevicePtrOp entryOp =
3106 createDataEntryOp<mlir::acc::GetDevicePtrOp>(
3107 builder, loc, boxAddrOp.getResult(), asFortran, bounds,
3108 /*structured=*/false, /*implicit=*/false, clause,
3109 boxAddrOp.getType());
3111 builder.create<mlir::acc::DeclareExitOp>(
3112 loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr()));
3114 if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
3115 std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
3116 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
3117 entryOp.getVarPtr(), entryOp.getBounds(),
3118 entryOp.getDataClause(),
3119 /*structured=*/false, /*implicit=*/false,
3120 builder.getStringAttr(*entryOp.getName()));
3121 else
3122 builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
3123 entryOp.getBounds(), entryOp.getDataClause(),
3124 /*structured=*/false, /*implicit=*/false,
3125 builder.getStringAttr(*entryOp.getName()));
3127 // Generate the post dealloc function.
3128 modBuilder.setInsertionPointAfter(preDeallocOp);
3129 std::stringstream postDeallocFuncName;
3130 postDeallocFuncName << globalOp.getSymName().str()
3131 << Fortran::lower::declarePostDeallocSuffix.str();
3132 auto postDeallocOp =
3133 createDeclareFunc(modBuilder, builder, loc, postDeallocFuncName.str());
3135 addrOp = builder.create<fir::AddrOfOp>(
3136 loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
3137 asFortran << accFirDescriptorPostfix.str();
3138 mlir::acc::UpdateDeviceOp updateDeviceOp =
3139 createDataEntryOp<mlir::acc::UpdateDeviceOp>(
3140 builder, loc, addrOp, asFortran, bounds,
3141 /*structured=*/false, /*implicit=*/true,
3142 mlir::acc::DataClause::acc_update_device, addrOp.getType());
3143 llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
3144 llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
3145 createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
3146 modBuilder.setInsertionPointAfter(postDeallocOp);
3149 template <typename EntryOp, typename ExitOp>
3150 static void genGlobalCtors(Fortran::lower::AbstractConverter &converter,
3151 mlir::OpBuilder &modBuilder,
3152 const Fortran::parser::AccObjectList &accObjectList,
3153 mlir::acc::DataClause clause) {
3154 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3155 for (const auto &accObject : accObjectList.v) {
3156 mlir::Location operandLocation = genOperandLocation(converter, accObject);
3157 std::visit(
3158 Fortran::common::visitors{
3159 [&](const Fortran::parser::Designator &designator) {
3160 if (const auto *name =
3161 Fortran::semantics::getDesignatorNameIfDataRef(
3162 designator)) {
3163 std::string globalName = converter.mangleName(*name->symbol);
3164 fir::GlobalOp globalOp = builder.getNamedGlobal(globalName);
3165 std::stringstream declareGlobalCtorName;
3166 declareGlobalCtorName << globalName << "_acc_ctor";
3167 std::stringstream declareGlobalDtorName;
3168 declareGlobalDtorName << globalName << "_acc_dtor";
3169 std::stringstream asFortran;
3170 asFortran << name->symbol->name().ToString();
3172 if (builder.getModule()
3173 .lookupSymbol<mlir::acc::GlobalConstructorOp>(
3174 declareGlobalCtorName.str()))
3175 return;
3177 if (!globalOp) {
3178 if (Fortran::semantics::FindEquivalenceSet(*name->symbol)) {
3179 for (Fortran::semantics::EquivalenceObject eqObj :
3180 *Fortran::semantics::FindEquivalenceSet(
3181 *name->symbol)) {
3182 std::string eqName = converter.mangleName(eqObj.symbol);
3183 globalOp = builder.getNamedGlobal(eqName);
3184 if (globalOp)
3185 break;
3188 if (!globalOp)
3189 llvm::report_fatal_error(
3190 "could not retrieve global symbol");
3191 } else {
3192 llvm::report_fatal_error(
3193 "could not retrieve global symbol");
3197 addDeclareAttr(builder, globalOp.getOperation(), clause);
3198 auto crtPos = builder.saveInsertionPoint();
3199 modBuilder.setInsertionPointAfter(globalOp);
3200 if (mlir::isa<fir::BaseBoxType>(
3201 fir::unwrapRefType(globalOp.getType()))) {
3202 createDeclareGlobalOp<mlir::acc::GlobalConstructorOp,
3203 mlir::acc::CopyinOp,
3204 mlir::acc::DeclareEnterOp, ExitOp>(
3205 modBuilder, builder, operandLocation, globalOp, clause,
3206 declareGlobalCtorName.str(), /*implicit=*/true,
3207 asFortran);
3208 createDeclareAllocFunc<EntryOp>(
3209 modBuilder, builder, operandLocation, globalOp, clause);
3210 if constexpr (!std::is_same_v<EntryOp, ExitOp>)
3211 createDeclareDeallocFunc<ExitOp>(
3212 modBuilder, builder, operandLocation, globalOp, clause);
3213 } else {
3214 createDeclareGlobalOp<mlir::acc::GlobalConstructorOp, EntryOp,
3215 mlir::acc::DeclareEnterOp, ExitOp>(
3216 modBuilder, builder, operandLocation, globalOp, clause,
3217 declareGlobalCtorName.str(), /*implicit=*/false,
3218 asFortran);
3220 if constexpr (!std::is_same_v<EntryOp, ExitOp>) {
3221 createDeclareGlobalOp<mlir::acc::GlobalDestructorOp,
3222 mlir::acc::GetDevicePtrOp,
3223 mlir::acc::DeclareExitOp, ExitOp>(
3224 modBuilder, builder, operandLocation, globalOp, clause,
3225 declareGlobalDtorName.str(), /*implicit=*/false,
3226 asFortran);
3228 builder.restoreInsertionPoint(crtPos);
3231 [&](const Fortran::parser::Name &name) {
3232 TODO(operandLocation, "OpenACC Global Ctor from parser::Name");
3234 accObject.u);
3238 template <typename Clause, typename EntryOp, typename ExitOp>
3239 static void
3240 genGlobalCtorsWithModifier(Fortran::lower::AbstractConverter &converter,
3241 mlir::OpBuilder &modBuilder, const Clause *x,
3242 Fortran::parser::AccDataModifier::Modifier mod,
3243 const mlir::acc::DataClause clause,
3244 const mlir::acc::DataClause clauseWithModifier) {
3245 const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
3246 const auto &accObjectList =
3247 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3248 const auto &modifier =
3249 std::get<std::optional<Fortran::parser::AccDataModifier>>(
3250 listWithModifier.t);
3251 mlir::acc::DataClause dataClause =
3252 (modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
3253 genGlobalCtors<EntryOp, ExitOp>(converter, modBuilder, accObjectList,
3254 dataClause);
3257 static void
3258 genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
3259 Fortran::semantics::SemanticsContext &semanticsContext,
3260 Fortran::lower::StatementContext &openAccCtx,
3261 mlir::Location loc,
3262 const Fortran::parser::AccClauseList &accClauseList) {
3263 llvm::SmallVector<mlir::Value> dataClauseOperands, copyEntryOperands,
3264 createEntryOperands, copyoutEntryOperands, deviceResidentEntryOperands;
3265 Fortran::lower::StatementContext stmtCtx;
3266 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3268 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3269 if (const auto *copyClause =
3270 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
3271 auto crtDataStart = dataClauseOperands.size();
3272 genDeclareDataOperandOperations<mlir::acc::CopyinOp,
3273 mlir::acc::CopyoutOp>(
3274 copyClause->v, converter, semanticsContext, stmtCtx,
3275 dataClauseOperands, mlir::acc::DataClause::acc_copy,
3276 /*structured=*/true, /*implicit=*/false);
3277 copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3278 dataClauseOperands.end());
3279 } else if (const auto *createClause =
3280 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
3281 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
3282 createClause->v;
3283 const auto &accObjectList =
3284 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3285 auto crtDataStart = dataClauseOperands.size();
3286 genDeclareDataOperandOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
3287 accObjectList, converter, semanticsContext, stmtCtx,
3288 dataClauseOperands, mlir::acc::DataClause::acc_create,
3289 /*structured=*/true, /*implicit=*/false);
3290 createEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3291 dataClauseOperands.end());
3292 } else if (const auto *presentClause =
3293 std::get_if<Fortran::parser::AccClause::Present>(
3294 &clause.u)) {
3295 genDeclareDataOperandOperations<mlir::acc::PresentOp,
3296 mlir::acc::PresentOp>(
3297 presentClause->v, converter, semanticsContext, stmtCtx,
3298 dataClauseOperands, mlir::acc::DataClause::acc_present,
3299 /*structured=*/true, /*implicit=*/false);
3300 } else if (const auto *copyinClause =
3301 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
3302 genDeclareDataOperandOperationsWithModifier<mlir::acc::CopyinOp,
3303 mlir::acc::DeleteOp>(
3304 copyinClause, converter, semanticsContext, stmtCtx,
3305 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
3306 dataClauseOperands, mlir::acc::DataClause::acc_copyin,
3307 mlir::acc::DataClause::acc_copyin_readonly);
3308 } else if (const auto *copyoutClause =
3309 std::get_if<Fortran::parser::AccClause::Copyout>(
3310 &clause.u)) {
3311 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
3312 copyoutClause->v;
3313 const auto &accObjectList =
3314 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3315 auto crtDataStart = dataClauseOperands.size();
3316 genDeclareDataOperandOperations<mlir::acc::CreateOp,
3317 mlir::acc::CopyoutOp>(
3318 accObjectList, converter, semanticsContext, stmtCtx,
3319 dataClauseOperands, mlir::acc::DataClause::acc_copyout,
3320 /*structured=*/true, /*implicit=*/false);
3321 copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
3322 dataClauseOperands.end());
3323 } else if (const auto *devicePtrClause =
3324 std::get_if<Fortran::parser::AccClause::Deviceptr>(
3325 &clause.u)) {
3326 genDeclareDataOperandOperations<mlir::acc::DevicePtrOp,
3327 mlir::acc::DevicePtrOp>(
3328 devicePtrClause->v, converter, semanticsContext, stmtCtx,
3329 dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
3330 /*structured=*/true, /*implicit=*/false);
3331 } else if (const auto *linkClause =
3332 std::get_if<Fortran::parser::AccClause::Link>(&clause.u)) {
3333 genDeclareDataOperandOperations<mlir::acc::DeclareLinkOp,
3334 mlir::acc::DeclareLinkOp>(
3335 linkClause->v, converter, semanticsContext, stmtCtx,
3336 dataClauseOperands, mlir::acc::DataClause::acc_declare_link,
3337 /*structured=*/true, /*implicit=*/false);
3338 } else if (const auto *deviceResidentClause =
3339 std::get_if<Fortran::parser::AccClause::DeviceResident>(
3340 &clause.u)) {
3341 auto crtDataStart = dataClauseOperands.size();
3342 genDeclareDataOperandOperations<mlir::acc::DeclareDeviceResidentOp,
3343 mlir::acc::DeleteOp>(
3344 deviceResidentClause->v, converter, semanticsContext, stmtCtx,
3345 dataClauseOperands,
3346 mlir::acc::DataClause::acc_declare_device_resident,
3347 /*structured=*/true, /*implicit=*/false);
3348 deviceResidentEntryOperands.append(
3349 dataClauseOperands.begin() + crtDataStart, dataClauseOperands.end());
3350 } else {
3351 mlir::Location clauseLocation = converter.genLocation(clause.source);
3352 TODO(clauseLocation, "clause on declare directive");
3356 mlir::func::FuncOp funcOp = builder.getFunction();
3357 auto ops = funcOp.getOps<mlir::acc::DeclareEnterOp>();
3358 mlir::Value declareToken;
3359 if (ops.empty()) {
3360 declareToken = builder.create<mlir::acc::DeclareEnterOp>(
3361 loc, mlir::acc::DeclareTokenType::get(builder.getContext()),
3362 dataClauseOperands);
3363 } else {
3364 auto declareOp = *ops.begin();
3365 auto newDeclareOp = builder.create<mlir::acc::DeclareEnterOp>(
3366 loc, mlir::acc::DeclareTokenType::get(builder.getContext()),
3367 declareOp.getDataClauseOperands());
3368 newDeclareOp.getDataClauseOperandsMutable().append(dataClauseOperands);
3369 declareToken = newDeclareOp.getToken();
3370 declareOp.erase();
3373 openAccCtx.attachCleanup([&builder, loc, createEntryOperands,
3374 copyEntryOperands, copyoutEntryOperands,
3375 deviceResidentEntryOperands, declareToken]() {
3376 llvm::SmallVector<mlir::Value> operands;
3377 operands.append(createEntryOperands);
3378 operands.append(deviceResidentEntryOperands);
3379 operands.append(copyEntryOperands);
3380 operands.append(copyoutEntryOperands);
3382 mlir::func::FuncOp funcOp = builder.getFunction();
3383 auto ops = funcOp.getOps<mlir::acc::DeclareExitOp>();
3384 if (ops.empty()) {
3385 builder.create<mlir::acc::DeclareExitOp>(loc, declareToken, operands);
3386 } else {
3387 auto declareOp = *ops.begin();
3388 declareOp.getDataClauseOperandsMutable().append(operands);
3391 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
3392 builder, createEntryOperands, /*structured=*/true);
3393 genDataExitOperations<mlir::acc::DeclareDeviceResidentOp,
3394 mlir::acc::DeleteOp>(
3395 builder, deviceResidentEntryOperands, /*structured=*/true);
3396 genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>(
3397 builder, copyEntryOperands, /*structured=*/true);
3398 genDataExitOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>(
3399 builder, copyoutEntryOperands, /*structured=*/true);
3403 static void
3404 genDeclareInModule(Fortran::lower::AbstractConverter &converter,
3405 mlir::ModuleOp &moduleOp,
3406 const Fortran::parser::AccClauseList &accClauseList) {
3407 mlir::OpBuilder modBuilder(moduleOp.getBodyRegion());
3408 for (const Fortran::parser::AccClause &clause : accClauseList.v) {
3409 if (const auto *createClause =
3410 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
3411 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
3412 createClause->v;
3413 const auto &accObjectList =
3414 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3415 genGlobalCtors<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
3416 converter, modBuilder, accObjectList,
3417 mlir::acc::DataClause::acc_create);
3418 } else if (const auto *copyinClause =
3419 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
3420 genGlobalCtorsWithModifier<Fortran::parser::AccClause::Copyin,
3421 mlir::acc::CopyinOp, mlir::acc::CopyinOp>(
3422 converter, modBuilder, copyinClause,
3423 Fortran::parser::AccDataModifier::Modifier::ReadOnly,
3424 mlir::acc::DataClause::acc_copyin,
3425 mlir::acc::DataClause::acc_copyin_readonly);
3426 } else if (const auto *deviceResidentClause =
3427 std::get_if<Fortran::parser::AccClause::DeviceResident>(
3428 &clause.u)) {
3429 genGlobalCtors<mlir::acc::DeclareDeviceResidentOp, mlir::acc::DeleteOp>(
3430 converter, modBuilder, deviceResidentClause->v,
3431 mlir::acc::DataClause::acc_declare_device_resident);
3432 } else if (const auto *linkClause =
3433 std::get_if<Fortran::parser::AccClause::Link>(&clause.u)) {
3434 genGlobalCtors<mlir::acc::DeclareLinkOp, mlir::acc::DeclareLinkOp>(
3435 converter, modBuilder, linkClause->v,
3436 mlir::acc::DataClause::acc_declare_link);
3437 } else {
3438 llvm::report_fatal_error("unsupported clause on DECLARE directive");
3443 static void genACC(Fortran::lower::AbstractConverter &converter,
3444 Fortran::semantics::SemanticsContext &semanticsContext,
3445 Fortran::lower::StatementContext &openAccCtx,
3446 const Fortran::parser::OpenACCStandaloneDeclarativeConstruct
3447 &declareConstruct) {
3449 const auto &declarativeDir =
3450 std::get<Fortran::parser::AccDeclarativeDirective>(declareConstruct.t);
3451 mlir::Location directiveLocation =
3452 converter.genLocation(declarativeDir.source);
3453 const auto &accClauseList =
3454 std::get<Fortran::parser::AccClauseList>(declareConstruct.t);
3456 if (declarativeDir.v == llvm::acc::Directive::ACCD_declare) {
3457 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3458 auto moduleOp =
3459 builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
3460 auto funcOp =
3461 builder.getBlock()->getParent()->getParentOfType<mlir::func::FuncOp>();
3462 if (funcOp)
3463 genDeclareInFunction(converter, semanticsContext, openAccCtx,
3464 directiveLocation, accClauseList);
3465 else if (moduleOp)
3466 genDeclareInModule(converter, moduleOp, accClauseList);
3467 return;
3469 llvm_unreachable("unsupported declarative directive");
3472 static void attachRoutineInfo(mlir::func::FuncOp func,
3473 mlir::SymbolRefAttr routineAttr) {
3474 llvm::SmallVector<mlir::SymbolRefAttr> routines;
3475 if (func.getOperation()->hasAttr(mlir::acc::getRoutineInfoAttrName())) {
3476 auto routineInfo =
3477 func.getOperation()->getAttrOfType<mlir::acc::RoutineInfoAttr>(
3478 mlir::acc::getRoutineInfoAttrName());
3479 routines.append(routineInfo.getAccRoutines().begin(),
3480 routineInfo.getAccRoutines().end());
3482 routines.push_back(routineAttr);
3483 func.getOperation()->setAttr(
3484 mlir::acc::getRoutineInfoAttrName(),
3485 mlir::acc::RoutineInfoAttr::get(func.getContext(), routines));
3488 void Fortran::lower::genOpenACCRoutineConstruct(
3489 Fortran::lower::AbstractConverter &converter,
3490 Fortran::semantics::SemanticsContext &semanticsContext, mlir::ModuleOp &mod,
3491 const Fortran::parser::OpenACCRoutineConstruct &routineConstruct,
3492 Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
3493 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3494 mlir::Location loc = converter.genLocation(routineConstruct.source);
3495 std::optional<Fortran::parser::Name> name =
3496 std::get<std::optional<Fortran::parser::Name>>(routineConstruct.t);
3497 const auto &clauses =
3498 std::get<Fortran::parser::AccClauseList>(routineConstruct.t);
3499 mlir::func::FuncOp funcOp;
3500 std::string funcName;
3501 if (name) {
3502 funcName = converter.mangleName(*name->symbol);
3503 funcOp = builder.getNamedFunction(mod, funcName);
3504 } else {
3505 Fortran::semantics::Scope &scope =
3506 semanticsContext.FindScope(routineConstruct.source);
3507 const Fortran::semantics::Scope &progUnit{GetProgramUnitContaining(scope)};
3508 const auto *subpDetails{
3509 progUnit.symbol()
3510 ? progUnit.symbol()
3511 ->detailsIf<Fortran::semantics::SubprogramDetails>()
3512 : nullptr};
3513 if (subpDetails && subpDetails->isInterface()) {
3514 funcName = converter.mangleName(*progUnit.symbol());
3515 funcOp = builder.getNamedFunction(mod, funcName);
3516 } else {
3517 funcOp = builder.getFunction();
3518 funcName = funcOp.getName();
3521 bool hasSeq = false, hasGang = false, hasWorker = false, hasVector = false,
3522 hasNohost = false;
3523 std::optional<std::string> bindName = std::nullopt;
3524 std::optional<int64_t> gangDim = std::nullopt;
3526 for (const Fortran::parser::AccClause &clause : clauses.v) {
3527 if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
3528 hasSeq = true;
3529 } else if (const auto *gangClause =
3530 std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
3531 hasGang = true;
3532 if (gangClause->v) {
3533 const Fortran::parser::AccGangArgList &x = *gangClause->v;
3534 for (const Fortran::parser::AccGangArg &gangArg : x.v) {
3535 if (const auto *dim =
3536 std::get_if<Fortran::parser::AccGangArg::Dim>(&gangArg.u)) {
3537 const std::optional<int64_t> dimValue = Fortran::evaluate::ToInt64(
3538 *Fortran::semantics::GetExpr(dim->v));
3539 if (!dimValue)
3540 mlir::emitError(loc,
3541 "dim value must be a constant positive integer");
3542 gangDim = *dimValue;
3546 } else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
3547 hasVector = true;
3548 } else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
3549 hasWorker = true;
3550 } else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u)) {
3551 hasNohost = true;
3552 } else if (const auto *bindClause =
3553 std::get_if<Fortran::parser::AccClause::Bind>(&clause.u)) {
3554 if (const auto *name =
3555 std::get_if<Fortran::parser::Name>(&bindClause->v.u)) {
3556 bindName = converter.mangleName(*name->symbol);
3557 } else if (const auto charExpr =
3558 std::get_if<Fortran::parser::ScalarDefaultCharExpr>(
3559 &bindClause->v.u)) {
3560 const std::optional<std::string> name =
3561 Fortran::semantics::GetConstExpr<std::string>(semanticsContext,
3562 *charExpr);
3563 if (!name)
3564 mlir::emitError(loc, "Could not retrieve the bind name");
3565 bindName = *name;
3570 mlir::OpBuilder modBuilder(mod.getBodyRegion());
3571 std::stringstream routineOpName;
3572 routineOpName << accRoutinePrefix.str() << routineCounter++;
3574 for (auto routineOp : mod.getOps<mlir::acc::RoutineOp>()) {
3575 if (routineOp.getFuncName().str().compare(funcName) == 0) {
3576 // If the routine is already specified with the same clauses, just skip
3577 // the operation creation.
3578 if (routineOp.getBindName() == bindName &&
3579 routineOp.getGang() == hasGang &&
3580 routineOp.getWorker() == hasWorker &&
3581 routineOp.getVector() == hasVector && routineOp.getSeq() == hasSeq &&
3582 routineOp.getNohost() == hasNohost &&
3583 routineOp.getGangDim() == gangDim)
3584 return;
3585 mlir::emitError(loc, "Routine already specified with different clauses");
3589 modBuilder.create<mlir::acc::RoutineOp>(
3590 loc, routineOpName.str(), funcName,
3591 bindName ? builder.getStringAttr(*bindName) : mlir::StringAttr{}, hasGang,
3592 hasWorker, hasVector, hasSeq, hasNohost, /*implicit=*/false,
3593 gangDim ? builder.getIntegerAttr(builder.getIntegerType(32), *gangDim)
3594 : mlir::IntegerAttr{});
3596 if (funcOp)
3597 attachRoutineInfo(funcOp, builder.getSymbolRefAttr(routineOpName.str()));
3598 else
3599 // FuncOp is not lowered yet. Keep the information so the routine info
3600 // can be attached later to the funcOp.
3601 accRoutineInfos.push_back(std::make_pair(
3602 funcName, builder.getSymbolRefAttr(routineOpName.str())));
3605 void Fortran::lower::finalizeOpenACCRoutineAttachment(
3606 mlir::ModuleOp &mod,
3607 Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
3608 for (auto &mapping : accRoutineInfos) {
3609 mlir::func::FuncOp funcOp =
3610 mod.lookupSymbol<mlir::func::FuncOp>(mapping.first);
3611 if (!funcOp)
3612 mlir::emitWarning(mod.getLoc(),
3613 llvm::Twine("function '") + llvm::Twine(mapping.first) +
3614 llvm::Twine("' in acc routine directive is not "
3615 "found in this translation unit."));
3616 else
3617 attachRoutineInfo(funcOp, mapping.second);
3619 accRoutineInfos.clear();
3622 static void
3623 genACC(Fortran::lower::AbstractConverter &converter,
3624 Fortran::lower::pft::Evaluation &eval,
3625 const Fortran::parser::OpenACCAtomicConstruct &atomicConstruct) {
3627 mlir::Location loc = converter.genLocation(atomicConstruct.source);
3628 std::visit(
3629 Fortran::common::visitors{
3630 [&](const Fortran::parser::AccAtomicRead &atomicRead) {
3631 Fortran::lower::genOmpAccAtomicRead<Fortran::parser::AccAtomicRead,
3632 void>(converter, atomicRead,
3633 loc);
3635 [&](const Fortran::parser::AccAtomicWrite &atomicWrite) {
3636 Fortran::lower::genOmpAccAtomicWrite<
3637 Fortran::parser::AccAtomicWrite, void>(converter, atomicWrite,
3638 loc);
3640 [&](const Fortran::parser::AccAtomicUpdate &atomicUpdate) {
3641 Fortran::lower::genOmpAccAtomicUpdate<
3642 Fortran::parser::AccAtomicUpdate, void>(converter, atomicUpdate,
3643 loc);
3645 [&](const Fortran::parser::AccAtomicCapture &atomicCapture) {
3646 Fortran::lower::genOmpAccAtomicCapture<
3647 Fortran::parser::AccAtomicCapture, void>(converter,
3648 atomicCapture, loc);
3651 atomicConstruct.u);
3654 static void
3655 genACC(Fortran::lower::AbstractConverter &converter,
3656 Fortran::semantics::SemanticsContext &semanticsContext,
3657 const Fortran::parser::OpenACCCacheConstruct &cacheConstruct) {
3658 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3659 auto loopOp = builder.getRegion().getParentOfType<mlir::acc::LoopOp>();
3660 auto crtPos = builder.saveInsertionPoint();
3661 if (loopOp) {
3662 builder.setInsertionPoint(loopOp);
3663 Fortran::lower::StatementContext stmtCtx;
3664 llvm::SmallVector<mlir::Value> cacheOperands;
3665 const Fortran::parser::AccObjectListWithModifier &listWithModifier =
3666 std::get<Fortran::parser::AccObjectListWithModifier>(cacheConstruct.t);
3667 const auto &accObjectList =
3668 std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
3669 const auto &modifier =
3670 std::get<std::optional<Fortran::parser::AccDataModifier>>(
3671 listWithModifier.t);
3673 mlir::acc::DataClause dataClause = mlir::acc::DataClause::acc_cache;
3674 if (modifier &&
3675 (*modifier).v == Fortran::parser::AccDataModifier::Modifier::ReadOnly)
3676 dataClause = mlir::acc::DataClause::acc_cache_readonly;
3677 genDataOperandOperations<mlir::acc::CacheOp>(
3678 accObjectList, converter, semanticsContext, stmtCtx, cacheOperands,
3679 dataClause,
3680 /*structured=*/true, /*implicit=*/false, /*setDeclareAttr*/ false);
3681 loopOp.getCacheOperandsMutable().append(cacheOperands);
3682 } else {
3683 llvm::report_fatal_error(
3684 "could not find loop to attach OpenACC cache information.");
3686 builder.restoreInsertionPoint(crtPos);
3689 mlir::Value Fortran::lower::genOpenACCConstruct(
3690 Fortran::lower::AbstractConverter &converter,
3691 Fortran::semantics::SemanticsContext &semanticsContext,
3692 Fortran::lower::pft::Evaluation &eval,
3693 const Fortran::parser::OpenACCConstruct &accConstruct) {
3695 mlir::Value exitCond;
3696 std::visit(
3697 common::visitors{
3698 [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
3699 genACC(converter, semanticsContext, eval, blockConstruct);
3701 [&](const Fortran::parser::OpenACCCombinedConstruct
3702 &combinedConstruct) {
3703 genACC(converter, semanticsContext, eval, combinedConstruct);
3705 [&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
3706 exitCond = genACC(converter, semanticsContext, eval, loopConstruct);
3708 [&](const Fortran::parser::OpenACCStandaloneConstruct
3709 &standaloneConstruct) {
3710 genACC(converter, semanticsContext, standaloneConstruct);
3712 [&](const Fortran::parser::OpenACCCacheConstruct &cacheConstruct) {
3713 genACC(converter, semanticsContext, cacheConstruct);
3715 [&](const Fortran::parser::OpenACCWaitConstruct &waitConstruct) {
3716 genACC(converter, waitConstruct);
3718 [&](const Fortran::parser::OpenACCAtomicConstruct &atomicConstruct) {
3719 genACC(converter, eval, atomicConstruct);
3721 [&](const Fortran::parser::OpenACCEndConstruct &) {
3722 // No op
3725 accConstruct.u);
3726 return exitCond;
3729 void Fortran::lower::genOpenACCDeclarativeConstruct(
3730 Fortran::lower::AbstractConverter &converter,
3731 Fortran::semantics::SemanticsContext &semanticsContext,
3732 Fortran::lower::StatementContext &openAccCtx,
3733 const Fortran::parser::OpenACCDeclarativeConstruct &accDeclConstruct,
3734 Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
3736 std::visit(
3737 common::visitors{
3738 [&](const Fortran::parser::OpenACCStandaloneDeclarativeConstruct
3739 &standaloneDeclarativeConstruct) {
3740 genACC(converter, semanticsContext, openAccCtx,
3741 standaloneDeclarativeConstruct);
3743 [&](const Fortran::parser::OpenACCRoutineConstruct
3744 &routineConstruct) {
3745 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3746 mlir::ModuleOp mod = builder.getModule();
3747 Fortran::lower::genOpenACCRoutineConstruct(
3748 converter, semanticsContext, mod, routineConstruct,
3749 accRoutineInfos);
3752 accDeclConstruct.u);
3755 void Fortran::lower::attachDeclarePostAllocAction(
3756 AbstractConverter &converter, fir::FirOpBuilder &builder,
3757 const Fortran::semantics::Symbol &sym) {
3758 std::stringstream fctName;
3759 fctName << converter.mangleName(sym) << declarePostAllocSuffix.str();
3760 mlir::Operation &op = builder.getInsertionBlock()->back();
3761 op.setAttr(mlir::acc::getDeclareActionAttrName(),
3762 mlir::acc::DeclareActionAttr::get(
3763 builder.getContext(),
3764 /*preAlloc=*/{},
3765 /*postAlloc=*/builder.getSymbolRefAttr(fctName.str()),
3766 /*preDealloc=*/{}, /*postDealloc=*/{}));
3769 void Fortran::lower::attachDeclarePreDeallocAction(
3770 AbstractConverter &converter, fir::FirOpBuilder &builder,
3771 mlir::Value beginOpValue, const Fortran::semantics::Symbol &sym) {
3772 if (!sym.test(Fortran::semantics::Symbol::Flag::AccCreate) &&
3773 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyIn) &&
3774 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyInReadOnly) &&
3775 !sym.test(Fortran::semantics::Symbol::Flag::AccCopy) &&
3776 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyOut) &&
3777 !sym.test(Fortran::semantics::Symbol::Flag::AccDeviceResident))
3778 return;
3780 std::stringstream fctName;
3781 fctName << converter.mangleName(sym) << declarePreDeallocSuffix.str();
3782 beginOpValue.getDefiningOp()->setAttr(
3783 mlir::acc::getDeclareActionAttrName(),
3784 mlir::acc::DeclareActionAttr::get(
3785 builder.getContext(),
3786 /*preAlloc=*/{}, /*postAlloc=*/{},
3787 /*preDealloc=*/builder.getSymbolRefAttr(fctName.str()),
3788 /*postDealloc=*/{}));
3791 void Fortran::lower::attachDeclarePostDeallocAction(
3792 AbstractConverter &converter, fir::FirOpBuilder &builder,
3793 const Fortran::semantics::Symbol &sym) {
3794 if (!sym.test(Fortran::semantics::Symbol::Flag::AccCreate) &&
3795 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyIn) &&
3796 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyInReadOnly) &&
3797 !sym.test(Fortran::semantics::Symbol::Flag::AccCopy) &&
3798 !sym.test(Fortran::semantics::Symbol::Flag::AccCopyOut) &&
3799 !sym.test(Fortran::semantics::Symbol::Flag::AccDeviceResident))
3800 return;
3802 std::stringstream fctName;
3803 fctName << converter.mangleName(sym) << declarePostDeallocSuffix.str();
3804 mlir::Operation &op = builder.getInsertionBlock()->back();
3805 op.setAttr(mlir::acc::getDeclareActionAttrName(),
3806 mlir::acc::DeclareActionAttr::get(
3807 builder.getContext(),
3808 /*preAlloc=*/{}, /*postAlloc=*/{}, /*preDealloc=*/{},
3809 /*postDealloc=*/builder.getSymbolRefAttr(fctName.str())));
3812 void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder,
3813 mlir::Operation *op,
3814 mlir::Location loc) {
3815 if (mlir::isa<mlir::acc::ParallelOp, mlir::acc::LoopOp>(op))
3816 builder.create<mlir::acc::YieldOp>(loc);
3817 else
3818 builder.create<mlir::acc::TerminatorOp>(loc);
3821 bool Fortran::lower::isInOpenACCLoop(fir::FirOpBuilder &builder) {
3822 if (builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
3823 return true;
3824 return false;
3827 void Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(
3828 fir::FirOpBuilder &builder) {
3829 if (auto loopOp =
3830 builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
3831 builder.setInsertionPointAfter(loopOp);
3834 void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
3835 mlir::Location loc) {
3836 mlir::Value yieldValue =
3837 builder.createIntegerConstant(loc, builder.getI1Type(), 1);
3838 builder.create<mlir::acc::YieldOp>(loc, yieldValue);