LAA: improve code in getStrideFromPointer (NFC) (#124780)
[llvm-project.git] / flang / lib / Lower / HlfirIntrinsics.cpp
blob8b96b209ddb00e0428d69fcefbe12ecdc7bd4da1
1 //===-- HlfirIntrinsics.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
11 //===----------------------------------------------------------------------===//
#include "flang/Lower/HlfirIntrinsics.h"

#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/MutableBox.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>
#include <utility>
28 namespace {
30 class HlfirTransformationalIntrinsic {
31 public:
32 explicit HlfirTransformationalIntrinsic(fir::FirOpBuilder &builder,
33 mlir::Location loc)
34 : builder(builder), loc(loc) {}
36 virtual ~HlfirTransformationalIntrinsic() = default;
38 hlfir::EntityWithAttributes
39 lower(const Fortran::lower::PreparedActualArguments &loweredActuals,
40 const fir::IntrinsicArgumentLoweringRules *argLowering,
41 mlir::Type stmtResultType) {
42 mlir::Value res = lowerImpl(loweredActuals, argLowering, stmtResultType);
43 for (const hlfir::CleanupFunction &fn : cleanupFns)
44 fn();
45 return {hlfir::EntityWithAttributes{res}};
48 protected:
49 fir::FirOpBuilder &builder;
50 mlir::Location loc;
51 llvm::SmallVector<hlfir::CleanupFunction, 3> cleanupFns;
53 virtual mlir::Value
54 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
55 const fir::IntrinsicArgumentLoweringRules *argLowering,
56 mlir::Type stmtResultType) = 0;
58 llvm::SmallVector<mlir::Value> getOperandVector(
59 const Fortran::lower::PreparedActualArguments &loweredActuals,
60 const fir::IntrinsicArgumentLoweringRules *argLowering);
62 mlir::Type computeResultType(mlir::Value argArray, mlir::Type stmtResultType);
64 template <typename OP, typename... BUILD_ARGS>
65 inline OP createOp(BUILD_ARGS... args) {
66 return builder.create<OP>(loc, args...);
69 mlir::Value loadBoxAddress(
70 const std::optional<Fortran::lower::PreparedActualArgument> &arg);
72 void addCleanup(std::optional<hlfir::CleanupFunction> cleanup) {
73 if (cleanup)
74 cleanupFns.emplace_back(std::move(*cleanup));
78 template <typename OP, bool HAS_MASK>
79 class HlfirReductionIntrinsic : public HlfirTransformationalIntrinsic {
80 public:
81 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
83 protected:
84 mlir::Value
85 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
86 const fir::IntrinsicArgumentLoweringRules *argLowering,
87 mlir::Type stmtResultType) override;
89 using HlfirSumLowering = HlfirReductionIntrinsic<hlfir::SumOp, true>;
90 using HlfirProductLowering = HlfirReductionIntrinsic<hlfir::ProductOp, true>;
91 using HlfirMaxvalLowering = HlfirReductionIntrinsic<hlfir::MaxvalOp, true>;
92 using HlfirMinvalLowering = HlfirReductionIntrinsic<hlfir::MinvalOp, true>;
93 using HlfirAnyLowering = HlfirReductionIntrinsic<hlfir::AnyOp, false>;
94 using HlfirAllLowering = HlfirReductionIntrinsic<hlfir::AllOp, false>;
96 template <typename OP>
97 class HlfirMinMaxLocIntrinsic : public HlfirTransformationalIntrinsic {
98 public:
99 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
101 protected:
102 mlir::Value
103 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
104 const fir::IntrinsicArgumentLoweringRules *argLowering,
105 mlir::Type stmtResultType) override;
107 using HlfirMinlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MinlocOp>;
108 using HlfirMaxlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MaxlocOp>;
110 template <typename OP>
111 class HlfirProductIntrinsic : public HlfirTransformationalIntrinsic {
112 public:
113 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
115 protected:
116 mlir::Value
117 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
118 const fir::IntrinsicArgumentLoweringRules *argLowering,
119 mlir::Type stmtResultType) override;
121 using HlfirMatmulLowering = HlfirProductIntrinsic<hlfir::MatmulOp>;
122 using HlfirDotProductLowering = HlfirProductIntrinsic<hlfir::DotProductOp>;
124 class HlfirTransposeLowering : public HlfirTransformationalIntrinsic {
125 public:
126 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
128 protected:
129 mlir::Value
130 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
131 const fir::IntrinsicArgumentLoweringRules *argLowering,
132 mlir::Type stmtResultType) override;
135 class HlfirCountLowering : public HlfirTransformationalIntrinsic {
136 public:
137 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
139 protected:
140 mlir::Value
141 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
142 const fir::IntrinsicArgumentLoweringRules *argLowering,
143 mlir::Type stmtResultType) override;
146 class HlfirCharExtremumLowering : public HlfirTransformationalIntrinsic {
147 public:
148 HlfirCharExtremumLowering(fir::FirOpBuilder &builder, mlir::Location loc,
149 hlfir::CharExtremumPredicate pred)
150 : HlfirTransformationalIntrinsic(builder, loc), pred{pred} {}
152 protected:
153 mlir::Value
154 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
155 const fir::IntrinsicArgumentLoweringRules *argLowering,
156 mlir::Type stmtResultType) override;
158 protected:
159 hlfir::CharExtremumPredicate pred;
162 class HlfirCShiftLowering : public HlfirTransformationalIntrinsic {
163 public:
164 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
166 protected:
167 mlir::Value
168 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
169 const fir::IntrinsicArgumentLoweringRules *argLowering,
170 mlir::Type stmtResultType) override;
173 class HlfirReshapeLowering : public HlfirTransformationalIntrinsic {
174 public:
175 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
177 protected:
178 mlir::Value
179 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
180 const fir::IntrinsicArgumentLoweringRules *argLowering,
181 mlir::Type stmtResultType) override;
184 } // namespace
186 mlir::Value HlfirTransformationalIntrinsic::loadBoxAddress(
187 const std::optional<Fortran::lower::PreparedActualArgument> &arg) {
188 if (!arg)
189 return mlir::Value{};
191 hlfir::Entity actual = arg->getActual(loc, builder);
193 if (!arg->handleDynamicOptional()) {
194 if (actual.isMutableBox()) {
195 // this is a box address type but is not dynamically optional. Just load
196 // the box, assuming it is well formed (!fir.ref<!fir.box<...>> ->
197 // !fir.box<...>)
198 return builder.create<fir::LoadOp>(loc, actual.getBase());
200 return actual;
203 auto [exv, cleanup] = hlfir::translateToExtendedValue(loc, builder, actual);
204 addCleanup(cleanup);
206 mlir::Value isPresent = arg->getIsPresent();
207 // createBox will not do create any invalid memory dereferences if exv is
208 // absent. The created fir.box will not be usable, but the SelectOp below
209 // ensures it won't be.
210 mlir::Value box = builder.createBox(loc, exv);
211 mlir::Type boxType = box.getType();
212 auto absent = builder.create<fir::AbsentOp>(loc, boxType);
213 auto boxOrAbsent = builder.create<mlir::arith::SelectOp>(
214 loc, boxType, isPresent, box, absent);
216 return boxOrAbsent;
219 static mlir::Value loadOptionalValue(
220 mlir::Location loc, fir::FirOpBuilder &builder,
221 const std::optional<Fortran::lower::PreparedActualArgument> &arg,
222 hlfir::Entity actual) {
223 if (!arg->handleDynamicOptional())
224 return hlfir::loadTrivialScalar(loc, builder, actual);
226 mlir::Value isPresent = arg->getIsPresent();
227 mlir::Type eleType = hlfir::getFortranElementType(actual.getType());
228 return builder
229 .genIfOp(loc, {eleType}, isPresent,
230 /*withElseRegion=*/true)
231 .genThen([&]() {
232 assert(actual.isScalar() && fir::isa_trivial(eleType) &&
233 "must be a numerical or logical scalar");
234 hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, actual);
235 builder.create<fir::ResultOp>(loc, val);
237 .genElse([&]() {
238 mlir::Value zero = fir::factory::createZeroValue(builder, loc, eleType);
239 builder.create<fir::ResultOp>(loc, zero);
241 .getResults()[0];
244 llvm::SmallVector<mlir::Value> HlfirTransformationalIntrinsic::getOperandVector(
245 const Fortran::lower::PreparedActualArguments &loweredActuals,
246 const fir::IntrinsicArgumentLoweringRules *argLowering) {
247 llvm::SmallVector<mlir::Value> operands;
248 operands.reserve(loweredActuals.size());
250 for (size_t i = 0; i < loweredActuals.size(); ++i) {
251 std::optional<Fortran::lower::PreparedActualArgument> arg =
252 loweredActuals[i];
253 if (!arg) {
254 operands.emplace_back();
255 continue;
257 hlfir::Entity actual = arg->getActual(loc, builder);
258 mlir::Value valArg;
260 if (!argLowering) {
261 valArg = hlfir::loadTrivialScalar(loc, builder, actual);
262 } else {
263 fir::ArgLoweringRule argRules =
264 fir::lowerIntrinsicArgumentAs(*argLowering, i);
265 if (argRules.lowerAs == fir::LowerIntrinsicArgAs::Box)
266 valArg = loadBoxAddress(arg);
267 else if (!argRules.handleDynamicOptional &&
268 argRules.lowerAs != fir::LowerIntrinsicArgAs::Inquired)
269 valArg = hlfir::derefPointersAndAllocatables(loc, builder, actual);
270 else if (argRules.handleDynamicOptional &&
271 argRules.lowerAs == fir::LowerIntrinsicArgAs::Value)
272 valArg = loadOptionalValue(loc, builder, arg, actual);
273 else if (argRules.handleDynamicOptional)
274 TODO(loc, "hlfir transformational intrinsic dynamically optional "
275 "argument without box lowering");
276 else
277 valArg = actual.getBase();
280 operands.emplace_back(valArg);
282 return operands;
285 mlir::Type
286 HlfirTransformationalIntrinsic::computeResultType(mlir::Value argArray,
287 mlir::Type stmtResultType) {
288 mlir::Type normalisedResult =
289 hlfir::getFortranElementOrSequenceType(stmtResultType);
290 if (auto array = mlir::dyn_cast<fir::SequenceType>(normalisedResult)) {
291 hlfir::ExprType::Shape resultShape =
292 hlfir::ExprType::Shape{array.getShape()};
293 mlir::Type elementType = array.getEleTy();
294 return hlfir::ExprType::get(builder.getContext(), resultShape, elementType,
295 fir::isPolymorphicType(stmtResultType));
296 } else if (auto resCharType =
297 mlir::dyn_cast<fir::CharacterType>(stmtResultType)) {
298 normalisedResult = hlfir::ExprType::get(
299 builder.getContext(), hlfir::ExprType::Shape{}, resCharType,
300 /*polymorphic=*/false);
302 return normalisedResult;
305 template <typename OP, bool HAS_MASK>
306 mlir::Value HlfirReductionIntrinsic<OP, HAS_MASK>::lowerImpl(
307 const Fortran::lower::PreparedActualArguments &loweredActuals,
308 const fir::IntrinsicArgumentLoweringRules *argLowering,
309 mlir::Type stmtResultType) {
310 auto operands = getOperandVector(loweredActuals, argLowering);
311 mlir::Value array = operands[0];
312 mlir::Value dim = operands[1];
313 // dim, mask can be NULL if these arguments are not given
314 if (dim)
315 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
317 mlir::Type resultTy = computeResultType(array, stmtResultType);
319 OP op;
320 if constexpr (HAS_MASK)
321 op = createOp<OP>(resultTy, array, dim,
322 /*mask=*/operands[2]);
323 else
324 op = createOp<OP>(resultTy, array, dim);
325 return op;
328 template <typename OP>
329 mlir::Value HlfirMinMaxLocIntrinsic<OP>::lowerImpl(
330 const Fortran::lower::PreparedActualArguments &loweredActuals,
331 const fir::IntrinsicArgumentLoweringRules *argLowering,
332 mlir::Type stmtResultType) {
333 auto operands = getOperandVector(loweredActuals, argLowering);
334 mlir::Value array = operands[0];
335 mlir::Value dim = operands[1];
336 mlir::Value mask = operands[2];
337 mlir::Value back = operands[4];
338 // dim, mask and back can be NULL if these arguments are not given.
339 if (dim)
340 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
341 if (back)
342 back = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{back});
344 mlir::Type resultTy = computeResultType(array, stmtResultType);
346 return createOp<OP>(resultTy, array, dim, mask, back);
349 template <typename OP>
350 mlir::Value HlfirProductIntrinsic<OP>::lowerImpl(
351 const Fortran::lower::PreparedActualArguments &loweredActuals,
352 const fir::IntrinsicArgumentLoweringRules *argLowering,
353 mlir::Type stmtResultType) {
354 auto operands = getOperandVector(loweredActuals, argLowering);
355 mlir::Type resultType = computeResultType(operands[0], stmtResultType);
356 return createOp<OP>(resultType, operands[0], operands[1]);
359 mlir::Value HlfirTransposeLowering::lowerImpl(
360 const Fortran::lower::PreparedActualArguments &loweredActuals,
361 const fir::IntrinsicArgumentLoweringRules *argLowering,
362 mlir::Type stmtResultType) {
363 auto operands = getOperandVector(loweredActuals, argLowering);
364 hlfir::ExprType::Shape resultShape;
365 mlir::Type normalisedResult =
366 hlfir::getFortranElementOrSequenceType(stmtResultType);
367 auto array = mlir::cast<fir::SequenceType>(normalisedResult);
368 llvm::ArrayRef<int64_t> arrayShape = array.getShape();
369 assert(arrayShape.size() == 2 && "arguments to transpose have a rank of 2");
370 mlir::Type elementType = array.getEleTy();
371 resultShape.push_back(arrayShape[0]);
372 resultShape.push_back(arrayShape[1]);
373 if (auto resCharType = mlir::dyn_cast<fir::CharacterType>(elementType))
374 if (!resCharType.hasConstantLen()) {
375 // The FunctionRef expression might have imprecise character
376 // type at this point, and we can improve it by propagating
377 // the constant length from the argument.
378 auto argCharType = mlir::dyn_cast<fir::CharacterType>(
379 hlfir::getFortranElementType(operands[0].getType()));
380 if (argCharType && argCharType.hasConstantLen())
381 elementType = fir::CharacterType::get(
382 builder.getContext(), resCharType.getFKind(), argCharType.getLen());
385 mlir::Type resultTy =
386 hlfir::ExprType::get(builder.getContext(), resultShape, elementType,
387 fir::isPolymorphicType(stmtResultType));
388 return createOp<hlfir::TransposeOp>(resultTy, operands[0]);
391 mlir::Value HlfirCountLowering::lowerImpl(
392 const Fortran::lower::PreparedActualArguments &loweredActuals,
393 const fir::IntrinsicArgumentLoweringRules *argLowering,
394 mlir::Type stmtResultType) {
395 auto operands = getOperandVector(loweredActuals, argLowering);
396 mlir::Value array = operands[0];
397 mlir::Value dim = operands[1];
398 if (dim)
399 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
400 mlir::Type resultType = computeResultType(array, stmtResultType);
401 return createOp<hlfir::CountOp>(resultType, array, dim);
404 mlir::Value HlfirCharExtremumLowering::lowerImpl(
405 const Fortran::lower::PreparedActualArguments &loweredActuals,
406 const fir::IntrinsicArgumentLoweringRules *argLowering,
407 mlir::Type stmtResultType) {
408 auto operands = getOperandVector(loweredActuals, argLowering);
409 assert(operands.size() >= 2);
410 return createOp<hlfir::CharExtremumOp>(pred, mlir::ValueRange{operands});
413 mlir::Value HlfirCShiftLowering::lowerImpl(
414 const Fortran::lower::PreparedActualArguments &loweredActuals,
415 const fir::IntrinsicArgumentLoweringRules *argLowering,
416 mlir::Type stmtResultType) {
417 auto operands = getOperandVector(loweredActuals, argLowering);
418 assert(operands.size() == 3);
419 mlir::Value dim = operands[2];
420 if (!dim) {
421 // If DIM is not present, drop the last element which is a null Value.
422 operands.truncate(2);
423 } else {
424 // If DIM is present, then dereference it if it is a ref.
425 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
426 operands[2] = dim;
429 mlir::Type resultType = computeResultType(operands[0], stmtResultType);
430 return createOp<hlfir::CShiftOp>(resultType, operands);
433 mlir::Value HlfirReshapeLowering::lowerImpl(
434 const Fortran::lower::PreparedActualArguments &loweredActuals,
435 const fir::IntrinsicArgumentLoweringRules *argLowering,
436 mlir::Type stmtResultType) {
437 auto operands = getOperandVector(loweredActuals, argLowering);
438 assert(operands.size() == 4);
439 mlir::Type resultType = computeResultType(operands[0], stmtResultType);
440 return createOp<hlfir::ReshapeOp>(resultType, operands[0], operands[1],
441 operands[2], operands[3]);
444 std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
445 fir::FirOpBuilder &builder, mlir::Location loc, const std::string &name,
446 const Fortran::lower::PreparedActualArguments &loweredActuals,
447 const fir::IntrinsicArgumentLoweringRules *argLowering,
448 mlir::Type stmtResultType) {
449 // If the result is of a derived type that may need finalization,
450 // we have to use DestroyOp with 'finalize' attribute for the result
451 // of the intrinsic operation.
452 if (name == "sum")
453 return HlfirSumLowering{builder, loc}.lower(loweredActuals, argLowering,
454 stmtResultType);
455 if (name == "product")
456 return HlfirProductLowering{builder, loc}.lower(loweredActuals, argLowering,
457 stmtResultType);
458 if (name == "any")
459 return HlfirAnyLowering{builder, loc}.lower(loweredActuals, argLowering,
460 stmtResultType);
461 if (name == "all")
462 return HlfirAllLowering{builder, loc}.lower(loweredActuals, argLowering,
463 stmtResultType);
464 if (name == "matmul")
465 return HlfirMatmulLowering{builder, loc}.lower(loweredActuals, argLowering,
466 stmtResultType);
467 if (name == "dot_product")
468 return HlfirDotProductLowering{builder, loc}.lower(
469 loweredActuals, argLowering, stmtResultType);
470 // FIXME: the result may need finalization.
471 if (name == "transpose")
472 return HlfirTransposeLowering{builder, loc}.lower(
473 loweredActuals, argLowering, stmtResultType);
474 if (name == "count")
475 return HlfirCountLowering{builder, loc}.lower(loweredActuals, argLowering,
476 stmtResultType);
477 if (name == "maxval")
478 return HlfirMaxvalLowering{builder, loc}.lower(loweredActuals, argLowering,
479 stmtResultType);
480 if (name == "minval")
481 return HlfirMinvalLowering{builder, loc}.lower(loweredActuals, argLowering,
482 stmtResultType);
483 if (name == "minloc")
484 return HlfirMinlocLowering{builder, loc}.lower(loweredActuals, argLowering,
485 stmtResultType);
486 if (name == "maxloc")
487 return HlfirMaxlocLowering{builder, loc}.lower(loweredActuals, argLowering,
488 stmtResultType);
489 if (name == "cshift")
490 return HlfirCShiftLowering{builder, loc}.lower(loweredActuals, argLowering,
491 stmtResultType);
492 if (name == "reshape")
493 return HlfirReshapeLowering{builder, loc}.lower(loweredActuals, argLowering,
494 stmtResultType);
495 if (mlir::isa<fir::CharacterType>(stmtResultType)) {
496 if (name == "min")
497 return HlfirCharExtremumLowering{builder, loc,
498 hlfir::CharExtremumPredicate::min}
499 .lower(loweredActuals, argLowering, stmtResultType);
500 if (name == "max")
501 return HlfirCharExtremumLowering{builder, loc,
502 hlfir::CharExtremumPredicate::max}
503 .lower(loweredActuals, argLowering, stmtResultType);
505 return std::nullopt;