[TargetVersion] Only enable on RISC-V and AArch64 (#115991)
[llvm-project.git] / flang / lib / Lower / HlfirIntrinsics.cpp
blob 310b62697f710d5d3e7bf9507e61efd5c9660f61
1 //===-- HlfirIntrinsics.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
11 //===----------------------------------------------------------------------===//
13 #include "flang/Lower/HlfirIntrinsics.h"
15 #include "flang/Optimizer/Builder/BoxValue.h"
16 #include "flang/Optimizer/Builder/FIRBuilder.h"
17 #include "flang/Optimizer/Builder/HLFIRTools.h"
18 #include "flang/Optimizer/Builder/IntrinsicCall.h"
19 #include "flang/Optimizer/Builder/MutableBox.h"
20 #include "flang/Optimizer/Builder/Todo.h"
21 #include "flang/Optimizer/Dialect/FIRType.h"
22 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
23 #include "flang/Optimizer/HLFIR/HLFIROps.h"
24 #include "mlir/IR/Value.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include <mlir/IR/ValueRange.h>
28 namespace {
30 class HlfirTransformationalIntrinsic {
31 public:
32 explicit HlfirTransformationalIntrinsic(fir::FirOpBuilder &builder,
33 mlir::Location loc)
34 : builder(builder), loc(loc) {}
36 virtual ~HlfirTransformationalIntrinsic() = default;
38 hlfir::EntityWithAttributes
39 lower(const Fortran::lower::PreparedActualArguments &loweredActuals,
40 const fir::IntrinsicArgumentLoweringRules *argLowering,
41 mlir::Type stmtResultType) {
42 mlir::Value res = lowerImpl(loweredActuals, argLowering, stmtResultType);
43 for (const hlfir::CleanupFunction &fn : cleanupFns)
44 fn();
45 return {hlfir::EntityWithAttributes{res}};
48 protected:
49 fir::FirOpBuilder &builder;
50 mlir::Location loc;
51 llvm::SmallVector<hlfir::CleanupFunction, 3> cleanupFns;
53 virtual mlir::Value
54 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
55 const fir::IntrinsicArgumentLoweringRules *argLowering,
56 mlir::Type stmtResultType) = 0;
58 llvm::SmallVector<mlir::Value> getOperandVector(
59 const Fortran::lower::PreparedActualArguments &loweredActuals,
60 const fir::IntrinsicArgumentLoweringRules *argLowering);
62 mlir::Type computeResultType(mlir::Value argArray, mlir::Type stmtResultType);
64 template <typename OP, typename... BUILD_ARGS>
65 inline OP createOp(BUILD_ARGS... args) {
66 return builder.create<OP>(loc, args...);
69 mlir::Value loadBoxAddress(
70 const std::optional<Fortran::lower::PreparedActualArgument> &arg);
72 void addCleanup(std::optional<hlfir::CleanupFunction> cleanup) {
73 if (cleanup)
74 cleanupFns.emplace_back(std::move(*cleanup));
78 template <typename OP, bool HAS_MASK>
79 class HlfirReductionIntrinsic : public HlfirTransformationalIntrinsic {
80 public:
81 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
83 protected:
84 mlir::Value
85 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
86 const fir::IntrinsicArgumentLoweringRules *argLowering,
87 mlir::Type stmtResultType) override;
89 using HlfirSumLowering = HlfirReductionIntrinsic<hlfir::SumOp, true>;
90 using HlfirProductLowering = HlfirReductionIntrinsic<hlfir::ProductOp, true>;
91 using HlfirMaxvalLowering = HlfirReductionIntrinsic<hlfir::MaxvalOp, true>;
92 using HlfirMinvalLowering = HlfirReductionIntrinsic<hlfir::MinvalOp, true>;
93 using HlfirAnyLowering = HlfirReductionIntrinsic<hlfir::AnyOp, false>;
94 using HlfirAllLowering = HlfirReductionIntrinsic<hlfir::AllOp, false>;
96 template <typename OP>
97 class HlfirMinMaxLocIntrinsic : public HlfirTransformationalIntrinsic {
98 public:
99 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
101 protected:
102 mlir::Value
103 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
104 const fir::IntrinsicArgumentLoweringRules *argLowering,
105 mlir::Type stmtResultType) override;
107 using HlfirMinlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MinlocOp>;
108 using HlfirMaxlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MaxlocOp>;
110 template <typename OP>
111 class HlfirProductIntrinsic : public HlfirTransformationalIntrinsic {
112 public:
113 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
115 protected:
116 mlir::Value
117 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
118 const fir::IntrinsicArgumentLoweringRules *argLowering,
119 mlir::Type stmtResultType) override;
121 using HlfirMatmulLowering = HlfirProductIntrinsic<hlfir::MatmulOp>;
122 using HlfirDotProductLowering = HlfirProductIntrinsic<hlfir::DotProductOp>;
124 class HlfirTransposeLowering : public HlfirTransformationalIntrinsic {
125 public:
126 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
128 protected:
129 mlir::Value
130 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
131 const fir::IntrinsicArgumentLoweringRules *argLowering,
132 mlir::Type stmtResultType) override;
135 class HlfirCountLowering : public HlfirTransformationalIntrinsic {
136 public:
137 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
139 protected:
140 mlir::Value
141 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
142 const fir::IntrinsicArgumentLoweringRules *argLowering,
143 mlir::Type stmtResultType) override;
146 class HlfirCharExtremumLowering : public HlfirTransformationalIntrinsic {
147 public:
148 HlfirCharExtremumLowering(fir::FirOpBuilder &builder, mlir::Location loc,
149 hlfir::CharExtremumPredicate pred)
150 : HlfirTransformationalIntrinsic(builder, loc), pred{pred} {}
152 protected:
153 mlir::Value
154 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
155 const fir::IntrinsicArgumentLoweringRules *argLowering,
156 mlir::Type stmtResultType) override;
158 protected:
159 hlfir::CharExtremumPredicate pred;
162 } // namespace
164 mlir::Value HlfirTransformationalIntrinsic::loadBoxAddress(
165 const std::optional<Fortran::lower::PreparedActualArgument> &arg) {
166 if (!arg)
167 return mlir::Value{};
169 hlfir::Entity actual = arg->getActual(loc, builder);
171 if (!arg->handleDynamicOptional()) {
172 if (actual.isMutableBox()) {
173 // this is a box address type but is not dynamically optional. Just load
174 // the box, assuming it is well formed (!fir.ref<!fir.box<...>> ->
175 // !fir.box<...>)
176 return builder.create<fir::LoadOp>(loc, actual.getBase());
178 return actual;
181 auto [exv, cleanup] = hlfir::translateToExtendedValue(loc, builder, actual);
182 addCleanup(cleanup);
184 mlir::Value isPresent = arg->getIsPresent();
185 // createBox will not do create any invalid memory dereferences if exv is
186 // absent. The created fir.box will not be usable, but the SelectOp below
187 // ensures it won't be.
188 mlir::Value box = builder.createBox(loc, exv);
189 mlir::Type boxType = box.getType();
190 auto absent = builder.create<fir::AbsentOp>(loc, boxType);
191 auto boxOrAbsent = builder.create<mlir::arith::SelectOp>(
192 loc, boxType, isPresent, box, absent);
194 return boxOrAbsent;
197 static mlir::Value loadOptionalValue(
198 mlir::Location loc, fir::FirOpBuilder &builder,
199 const std::optional<Fortran::lower::PreparedActualArgument> &arg,
200 hlfir::Entity actual) {
201 if (!arg->handleDynamicOptional())
202 return hlfir::loadTrivialScalar(loc, builder, actual);
204 mlir::Value isPresent = arg->getIsPresent();
205 mlir::Type eleType = hlfir::getFortranElementType(actual.getType());
206 return builder
207 .genIfOp(loc, {eleType}, isPresent,
208 /*withElseRegion=*/true)
209 .genThen([&]() {
210 assert(actual.isScalar() && fir::isa_trivial(eleType) &&
211 "must be a numerical or logical scalar");
212 hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, actual);
213 builder.create<fir::ResultOp>(loc, val);
215 .genElse([&]() {
216 mlir::Value zero = fir::factory::createZeroValue(builder, loc, eleType);
217 builder.create<fir::ResultOp>(loc, zero);
219 .getResults()[0];
222 llvm::SmallVector<mlir::Value> HlfirTransformationalIntrinsic::getOperandVector(
223 const Fortran::lower::PreparedActualArguments &loweredActuals,
224 const fir::IntrinsicArgumentLoweringRules *argLowering) {
225 llvm::SmallVector<mlir::Value> operands;
226 operands.reserve(loweredActuals.size());
228 for (size_t i = 0; i < loweredActuals.size(); ++i) {
229 std::optional<Fortran::lower::PreparedActualArgument> arg =
230 loweredActuals[i];
231 if (!arg) {
232 operands.emplace_back();
233 continue;
235 hlfir::Entity actual = arg->getActual(loc, builder);
236 mlir::Value valArg;
238 if (!argLowering) {
239 valArg = hlfir::loadTrivialScalar(loc, builder, actual);
240 } else {
241 fir::ArgLoweringRule argRules =
242 fir::lowerIntrinsicArgumentAs(*argLowering, i);
243 if (argRules.lowerAs == fir::LowerIntrinsicArgAs::Box)
244 valArg = loadBoxAddress(arg);
245 else if (!argRules.handleDynamicOptional &&
246 argRules.lowerAs != fir::LowerIntrinsicArgAs::Inquired)
247 valArg = hlfir::derefPointersAndAllocatables(loc, builder, actual);
248 else if (argRules.handleDynamicOptional &&
249 argRules.lowerAs == fir::LowerIntrinsicArgAs::Value)
250 valArg = loadOptionalValue(loc, builder, arg, actual);
251 else if (argRules.handleDynamicOptional)
252 TODO(loc, "hlfir transformational intrinsic dynamically optional "
253 "argument without box lowering");
254 else
255 valArg = actual.getBase();
258 operands.emplace_back(valArg);
260 return operands;
263 mlir::Type
264 HlfirTransformationalIntrinsic::computeResultType(mlir::Value argArray,
265 mlir::Type stmtResultType) {
266 mlir::Type normalisedResult =
267 hlfir::getFortranElementOrSequenceType(stmtResultType);
268 if (auto array = mlir::dyn_cast<fir::SequenceType>(normalisedResult)) {
269 hlfir::ExprType::Shape resultShape =
270 hlfir::ExprType::Shape{array.getShape()};
271 mlir::Type elementType = array.getEleTy();
272 return hlfir::ExprType::get(builder.getContext(), resultShape, elementType,
273 /*polymorphic=*/false);
274 } else if (auto resCharType =
275 mlir::dyn_cast<fir::CharacterType>(stmtResultType)) {
276 normalisedResult = hlfir::ExprType::get(
277 builder.getContext(), hlfir::ExprType::Shape{}, resCharType, false);
279 return normalisedResult;
282 template <typename OP, bool HAS_MASK>
283 mlir::Value HlfirReductionIntrinsic<OP, HAS_MASK>::lowerImpl(
284 const Fortran::lower::PreparedActualArguments &loweredActuals,
285 const fir::IntrinsicArgumentLoweringRules *argLowering,
286 mlir::Type stmtResultType) {
287 auto operands = getOperandVector(loweredActuals, argLowering);
288 mlir::Value array = operands[0];
289 mlir::Value dim = operands[1];
290 // dim, mask can be NULL if these arguments are not given
291 if (dim)
292 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
294 mlir::Type resultTy = computeResultType(array, stmtResultType);
296 OP op;
297 if constexpr (HAS_MASK)
298 op = createOp<OP>(resultTy, array, dim,
299 /*mask=*/operands[2]);
300 else
301 op = createOp<OP>(resultTy, array, dim);
302 return op;
305 template <typename OP>
306 mlir::Value HlfirMinMaxLocIntrinsic<OP>::lowerImpl(
307 const Fortran::lower::PreparedActualArguments &loweredActuals,
308 const fir::IntrinsicArgumentLoweringRules *argLowering,
309 mlir::Type stmtResultType) {
310 auto operands = getOperandVector(loweredActuals, argLowering);
311 mlir::Value array = operands[0];
312 mlir::Value dim = operands[1];
313 mlir::Value mask = operands[2];
314 mlir::Value back = operands[4];
315 // dim, mask and back can be NULL if these arguments are not given.
316 if (dim)
317 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
318 if (back)
319 back = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{back});
321 mlir::Type resultTy = computeResultType(array, stmtResultType);
323 return createOp<OP>(resultTy, array, dim, mask, back);
326 template <typename OP>
327 mlir::Value HlfirProductIntrinsic<OP>::lowerImpl(
328 const Fortran::lower::PreparedActualArguments &loweredActuals,
329 const fir::IntrinsicArgumentLoweringRules *argLowering,
330 mlir::Type stmtResultType) {
331 auto operands = getOperandVector(loweredActuals, argLowering);
332 mlir::Type resultType = computeResultType(operands[0], stmtResultType);
333 return createOp<OP>(resultType, operands[0], operands[1]);
336 mlir::Value HlfirTransposeLowering::lowerImpl(
337 const Fortran::lower::PreparedActualArguments &loweredActuals,
338 const fir::IntrinsicArgumentLoweringRules *argLowering,
339 mlir::Type stmtResultType) {
340 auto operands = getOperandVector(loweredActuals, argLowering);
341 hlfir::ExprType::Shape resultShape;
342 mlir::Type normalisedResult =
343 hlfir::getFortranElementOrSequenceType(stmtResultType);
344 auto array = mlir::cast<fir::SequenceType>(normalisedResult);
345 llvm::ArrayRef<int64_t> arrayShape = array.getShape();
346 assert(arrayShape.size() == 2 && "arguments to transpose have a rank of 2");
347 mlir::Type elementType = array.getEleTy();
348 resultShape.push_back(arrayShape[0]);
349 resultShape.push_back(arrayShape[1]);
350 if (auto resCharType = mlir::dyn_cast<fir::CharacterType>(elementType))
351 if (!resCharType.hasConstantLen()) {
352 // The FunctionRef expression might have imprecise character
353 // type at this point, and we can improve it by propagating
354 // the constant length from the argument.
355 auto argCharType = mlir::dyn_cast<fir::CharacterType>(
356 hlfir::getFortranElementType(operands[0].getType()));
357 if (argCharType && argCharType.hasConstantLen())
358 elementType = fir::CharacterType::get(
359 builder.getContext(), resCharType.getFKind(), argCharType.getLen());
362 mlir::Type resultTy =
363 hlfir::ExprType::get(builder.getContext(), resultShape, elementType,
364 fir::isPolymorphicType(stmtResultType));
365 return createOp<hlfir::TransposeOp>(resultTy, operands[0]);
368 mlir::Value HlfirCountLowering::lowerImpl(
369 const Fortran::lower::PreparedActualArguments &loweredActuals,
370 const fir::IntrinsicArgumentLoweringRules *argLowering,
371 mlir::Type stmtResultType) {
372 auto operands = getOperandVector(loweredActuals, argLowering);
373 mlir::Value array = operands[0];
374 mlir::Value dim = operands[1];
375 if (dim)
376 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
377 mlir::Type resultType = computeResultType(array, stmtResultType);
378 return createOp<hlfir::CountOp>(resultType, array, dim);
381 mlir::Value HlfirCharExtremumLowering::lowerImpl(
382 const Fortran::lower::PreparedActualArguments &loweredActuals,
383 const fir::IntrinsicArgumentLoweringRules *argLowering,
384 mlir::Type stmtResultType) {
385 auto operands = getOperandVector(loweredActuals, argLowering);
386 assert(operands.size() >= 2);
387 return createOp<hlfir::CharExtremumOp>(pred, mlir::ValueRange{operands});
390 std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
391 fir::FirOpBuilder &builder, mlir::Location loc, const std::string &name,
392 const Fortran::lower::PreparedActualArguments &loweredActuals,
393 const fir::IntrinsicArgumentLoweringRules *argLowering,
394 mlir::Type stmtResultType) {
395 // If the result is of a derived type that may need finalization,
396 // we have to use DestroyOp with 'finalize' attribute for the result
397 // of the intrinsic operation.
398 if (name == "sum")
399 return HlfirSumLowering{builder, loc}.lower(loweredActuals, argLowering,
400 stmtResultType);
401 if (name == "product")
402 return HlfirProductLowering{builder, loc}.lower(loweredActuals, argLowering,
403 stmtResultType);
404 if (name == "any")
405 return HlfirAnyLowering{builder, loc}.lower(loweredActuals, argLowering,
406 stmtResultType);
407 if (name == "all")
408 return HlfirAllLowering{builder, loc}.lower(loweredActuals, argLowering,
409 stmtResultType);
410 if (name == "matmul")
411 return HlfirMatmulLowering{builder, loc}.lower(loweredActuals, argLowering,
412 stmtResultType);
413 if (name == "dot_product")
414 return HlfirDotProductLowering{builder, loc}.lower(
415 loweredActuals, argLowering, stmtResultType);
416 // FIXME: the result may need finalization.
417 if (name == "transpose")
418 return HlfirTransposeLowering{builder, loc}.lower(
419 loweredActuals, argLowering, stmtResultType);
420 if (name == "count")
421 return HlfirCountLowering{builder, loc}.lower(loweredActuals, argLowering,
422 stmtResultType);
423 if (name == "maxval")
424 return HlfirMaxvalLowering{builder, loc}.lower(loweredActuals, argLowering,
425 stmtResultType);
426 if (name == "minval")
427 return HlfirMinvalLowering{builder, loc}.lower(loweredActuals, argLowering,
428 stmtResultType);
429 if (name == "minloc")
430 return HlfirMinlocLowering{builder, loc}.lower(loweredActuals, argLowering,
431 stmtResultType);
432 if (name == "maxloc")
433 return HlfirMaxlocLowering{builder, loc}.lower(loweredActuals, argLowering,
434 stmtResultType);
435 if (mlir::isa<fir::CharacterType>(stmtResultType)) {
436 if (name == "min")
437 return HlfirCharExtremumLowering{builder, loc,
438 hlfir::CharExtremumPredicate::min}
439 .lower(loweredActuals, argLowering, stmtResultType);
440 if (name == "max")
441 return HlfirCharExtremumLowering{builder, loc,
442 hlfir::CharExtremumPredicate::max}
443 .lower(loweredActuals, argLowering, stmtResultType);
445 return std::nullopt;