1 //===-- HlfirIntrinsics.cpp -----------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
11 //===----------------------------------------------------------------------===//
#include "flang/Lower/HlfirIntrinsics.h"

#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/MutableBox.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include <mlir/IR/ValueRange.h>
#include <optional>
30 class HlfirTransformationalIntrinsic
{
32 explicit HlfirTransformationalIntrinsic(fir::FirOpBuilder
&builder
,
34 : builder(builder
), loc(loc
) {}
36 virtual ~HlfirTransformationalIntrinsic() = default;
38 hlfir::EntityWithAttributes
39 lower(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
40 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
41 mlir::Type stmtResultType
) {
42 mlir::Value res
= lowerImpl(loweredActuals
, argLowering
, stmtResultType
);
43 for (const hlfir::CleanupFunction
&fn
: cleanupFns
)
45 return {hlfir::EntityWithAttributes
{res
}};
49 fir::FirOpBuilder
&builder
;
51 llvm::SmallVector
<hlfir::CleanupFunction
, 3> cleanupFns
;
54 lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
55 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
56 mlir::Type stmtResultType
) = 0;
58 llvm::SmallVector
<mlir::Value
> getOperandVector(
59 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
60 const fir::IntrinsicArgumentLoweringRules
*argLowering
);
62 mlir::Type
computeResultType(mlir::Value argArray
, mlir::Type stmtResultType
);
64 template <typename OP
, typename
... BUILD_ARGS
>
65 inline OP
createOp(BUILD_ARGS
... args
) {
66 return builder
.create
<OP
>(loc
, args
...);
69 mlir::Value
loadBoxAddress(
70 const std::optional
<Fortran::lower::PreparedActualArgument
> &arg
);
72 void addCleanup(std::optional
<hlfir::CleanupFunction
> cleanup
) {
74 cleanupFns
.emplace_back(std::move(*cleanup
));
78 template <typename OP
, bool HAS_MASK
>
79 class HlfirReductionIntrinsic
: public HlfirTransformationalIntrinsic
{
81 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic
;
85 lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
86 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
87 mlir::Type stmtResultType
) override
;
89 using HlfirSumLowering
= HlfirReductionIntrinsic
<hlfir::SumOp
, true>;
90 using HlfirProductLowering
= HlfirReductionIntrinsic
<hlfir::ProductOp
, true>;
91 using HlfirMaxvalLowering
= HlfirReductionIntrinsic
<hlfir::MaxvalOp
, true>;
92 using HlfirMinvalLowering
= HlfirReductionIntrinsic
<hlfir::MinvalOp
, true>;
93 using HlfirAnyLowering
= HlfirReductionIntrinsic
<hlfir::AnyOp
, false>;
94 using HlfirAllLowering
= HlfirReductionIntrinsic
<hlfir::AllOp
, false>;
96 template <typename OP
>
97 class HlfirMinMaxLocIntrinsic
: public HlfirTransformationalIntrinsic
{
99 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic
;
103 lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
104 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
105 mlir::Type stmtResultType
) override
;
107 using HlfirMinlocLowering
= HlfirMinMaxLocIntrinsic
<hlfir::MinlocOp
>;
108 using HlfirMaxlocLowering
= HlfirMinMaxLocIntrinsic
<hlfir::MaxlocOp
>;
110 template <typename OP
>
111 class HlfirProductIntrinsic
: public HlfirTransformationalIntrinsic
{
113 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic
;
117 lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
118 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
119 mlir::Type stmtResultType
) override
;
121 using HlfirMatmulLowering
= HlfirProductIntrinsic
<hlfir::MatmulOp
>;
122 using HlfirDotProductLowering
= HlfirProductIntrinsic
<hlfir::DotProductOp
>;
124 class HlfirTransposeLowering
: public HlfirTransformationalIntrinsic
{
126 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic
;
130 lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
131 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
132 mlir::Type stmtResultType
) override
;
135 class HlfirCountLowering
: public HlfirTransformationalIntrinsic
{
137 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic
;
141 lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
142 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
143 mlir::Type stmtResultType
) override
;
146 class HlfirCharExtremumLowering
: public HlfirTransformationalIntrinsic
{
148 HlfirCharExtremumLowering(fir::FirOpBuilder
&builder
, mlir::Location loc
,
149 hlfir::CharExtremumPredicate pred
)
150 : HlfirTransformationalIntrinsic(builder
, loc
), pred
{pred
} {}
154 lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
155 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
156 mlir::Type stmtResultType
) override
;
159 hlfir::CharExtremumPredicate pred
;
162 class HlfirCShiftLowering
: public HlfirTransformationalIntrinsic
{
164 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic
;
168 lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
169 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
170 mlir::Type stmtResultType
) override
;
173 class HlfirReshapeLowering
: public HlfirTransformationalIntrinsic
{
175 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic
;
179 lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
180 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
181 mlir::Type stmtResultType
) override
;
186 mlir::Value
HlfirTransformationalIntrinsic::loadBoxAddress(
187 const std::optional
<Fortran::lower::PreparedActualArgument
> &arg
) {
189 return mlir::Value
{};
191 hlfir::Entity actual
= arg
->getActual(loc
, builder
);
193 if (!arg
->handleDynamicOptional()) {
194 if (actual
.isMutableBox()) {
195 // this is a box address type but is not dynamically optional. Just load
196 // the box, assuming it is well formed (!fir.ref<!fir.box<...>> ->
198 return builder
.create
<fir::LoadOp
>(loc
, actual
.getBase());
203 auto [exv
, cleanup
] = hlfir::translateToExtendedValue(loc
, builder
, actual
);
206 mlir::Value isPresent
= arg
->getIsPresent();
207 // createBox will not do create any invalid memory dereferences if exv is
208 // absent. The created fir.box will not be usable, but the SelectOp below
209 // ensures it won't be.
210 mlir::Value box
= builder
.createBox(loc
, exv
);
211 mlir::Type boxType
= box
.getType();
212 auto absent
= builder
.create
<fir::AbsentOp
>(loc
, boxType
);
213 auto boxOrAbsent
= builder
.create
<mlir::arith::SelectOp
>(
214 loc
, boxType
, isPresent
, box
, absent
);
219 static mlir::Value
loadOptionalValue(
220 mlir::Location loc
, fir::FirOpBuilder
&builder
,
221 const std::optional
<Fortran::lower::PreparedActualArgument
> &arg
,
222 hlfir::Entity actual
) {
223 if (!arg
->handleDynamicOptional())
224 return hlfir::loadTrivialScalar(loc
, builder
, actual
);
226 mlir::Value isPresent
= arg
->getIsPresent();
227 mlir::Type eleType
= hlfir::getFortranElementType(actual
.getType());
229 .genIfOp(loc
, {eleType
}, isPresent
,
230 /*withElseRegion=*/true)
232 assert(actual
.isScalar() && fir::isa_trivial(eleType
) &&
233 "must be a numerical or logical scalar");
234 hlfir::Entity val
= hlfir::loadTrivialScalar(loc
, builder
, actual
);
235 builder
.create
<fir::ResultOp
>(loc
, val
);
238 mlir::Value zero
= fir::factory::createZeroValue(builder
, loc
, eleType
);
239 builder
.create
<fir::ResultOp
>(loc
, zero
);
244 llvm::SmallVector
<mlir::Value
> HlfirTransformationalIntrinsic::getOperandVector(
245 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
246 const fir::IntrinsicArgumentLoweringRules
*argLowering
) {
247 llvm::SmallVector
<mlir::Value
> operands
;
248 operands
.reserve(loweredActuals
.size());
250 for (size_t i
= 0; i
< loweredActuals
.size(); ++i
) {
251 std::optional
<Fortran::lower::PreparedActualArgument
> arg
=
254 operands
.emplace_back();
257 hlfir::Entity actual
= arg
->getActual(loc
, builder
);
261 valArg
= hlfir::loadTrivialScalar(loc
, builder
, actual
);
263 fir::ArgLoweringRule argRules
=
264 fir::lowerIntrinsicArgumentAs(*argLowering
, i
);
265 if (argRules
.lowerAs
== fir::LowerIntrinsicArgAs::Box
)
266 valArg
= loadBoxAddress(arg
);
267 else if (!argRules
.handleDynamicOptional
&&
268 argRules
.lowerAs
!= fir::LowerIntrinsicArgAs::Inquired
)
269 valArg
= hlfir::derefPointersAndAllocatables(loc
, builder
, actual
);
270 else if (argRules
.handleDynamicOptional
&&
271 argRules
.lowerAs
== fir::LowerIntrinsicArgAs::Value
)
272 valArg
= loadOptionalValue(loc
, builder
, arg
, actual
);
273 else if (argRules
.handleDynamicOptional
)
274 TODO(loc
, "hlfir transformational intrinsic dynamically optional "
275 "argument without box lowering");
277 valArg
= actual
.getBase();
280 operands
.emplace_back(valArg
);
286 HlfirTransformationalIntrinsic::computeResultType(mlir::Value argArray
,
287 mlir::Type stmtResultType
) {
288 mlir::Type normalisedResult
=
289 hlfir::getFortranElementOrSequenceType(stmtResultType
);
290 if (auto array
= mlir::dyn_cast
<fir::SequenceType
>(normalisedResult
)) {
291 hlfir::ExprType::Shape resultShape
=
292 hlfir::ExprType::Shape
{array
.getShape()};
293 mlir::Type elementType
= array
.getEleTy();
294 return hlfir::ExprType::get(builder
.getContext(), resultShape
, elementType
,
295 fir::isPolymorphicType(stmtResultType
));
296 } else if (auto resCharType
=
297 mlir::dyn_cast
<fir::CharacterType
>(stmtResultType
)) {
298 normalisedResult
= hlfir::ExprType::get(
299 builder
.getContext(), hlfir::ExprType::Shape
{}, resCharType
,
300 /*polymorphic=*/false);
302 return normalisedResult
;
305 template <typename OP
, bool HAS_MASK
>
306 mlir::Value HlfirReductionIntrinsic
<OP
, HAS_MASK
>::lowerImpl(
307 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
308 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
309 mlir::Type stmtResultType
) {
310 auto operands
= getOperandVector(loweredActuals
, argLowering
);
311 mlir::Value array
= operands
[0];
312 mlir::Value dim
= operands
[1];
313 // dim, mask can be NULL if these arguments are not given
315 dim
= hlfir::loadTrivialScalar(loc
, builder
, hlfir::Entity
{dim
});
317 mlir::Type resultTy
= computeResultType(array
, stmtResultType
);
320 if constexpr (HAS_MASK
)
321 op
= createOp
<OP
>(resultTy
, array
, dim
,
322 /*mask=*/operands
[2]);
324 op
= createOp
<OP
>(resultTy
, array
, dim
);
328 template <typename OP
>
329 mlir::Value HlfirMinMaxLocIntrinsic
<OP
>::lowerImpl(
330 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
331 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
332 mlir::Type stmtResultType
) {
333 auto operands
= getOperandVector(loweredActuals
, argLowering
);
334 mlir::Value array
= operands
[0];
335 mlir::Value dim
= operands
[1];
336 mlir::Value mask
= operands
[2];
337 mlir::Value back
= operands
[4];
338 // dim, mask and back can be NULL if these arguments are not given.
340 dim
= hlfir::loadTrivialScalar(loc
, builder
, hlfir::Entity
{dim
});
342 back
= hlfir::loadTrivialScalar(loc
, builder
, hlfir::Entity
{back
});
344 mlir::Type resultTy
= computeResultType(array
, stmtResultType
);
346 return createOp
<OP
>(resultTy
, array
, dim
, mask
, back
);
349 template <typename OP
>
350 mlir::Value HlfirProductIntrinsic
<OP
>::lowerImpl(
351 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
352 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
353 mlir::Type stmtResultType
) {
354 auto operands
= getOperandVector(loweredActuals
, argLowering
);
355 mlir::Type resultType
= computeResultType(operands
[0], stmtResultType
);
356 return createOp
<OP
>(resultType
, operands
[0], operands
[1]);
359 mlir::Value
HlfirTransposeLowering::lowerImpl(
360 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
361 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
362 mlir::Type stmtResultType
) {
363 auto operands
= getOperandVector(loweredActuals
, argLowering
);
364 hlfir::ExprType::Shape resultShape
;
365 mlir::Type normalisedResult
=
366 hlfir::getFortranElementOrSequenceType(stmtResultType
);
367 auto array
= mlir::cast
<fir::SequenceType
>(normalisedResult
);
368 llvm::ArrayRef
<int64_t> arrayShape
= array
.getShape();
369 assert(arrayShape
.size() == 2 && "arguments to transpose have a rank of 2");
370 mlir::Type elementType
= array
.getEleTy();
371 resultShape
.push_back(arrayShape
[0]);
372 resultShape
.push_back(arrayShape
[1]);
373 if (auto resCharType
= mlir::dyn_cast
<fir::CharacterType
>(elementType
))
374 if (!resCharType
.hasConstantLen()) {
375 // The FunctionRef expression might have imprecise character
376 // type at this point, and we can improve it by propagating
377 // the constant length from the argument.
378 auto argCharType
= mlir::dyn_cast
<fir::CharacterType
>(
379 hlfir::getFortranElementType(operands
[0].getType()));
380 if (argCharType
&& argCharType
.hasConstantLen())
381 elementType
= fir::CharacterType::get(
382 builder
.getContext(), resCharType
.getFKind(), argCharType
.getLen());
385 mlir::Type resultTy
=
386 hlfir::ExprType::get(builder
.getContext(), resultShape
, elementType
,
387 fir::isPolymorphicType(stmtResultType
));
388 return createOp
<hlfir::TransposeOp
>(resultTy
, operands
[0]);
391 mlir::Value
HlfirCountLowering::lowerImpl(
392 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
393 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
394 mlir::Type stmtResultType
) {
395 auto operands
= getOperandVector(loweredActuals
, argLowering
);
396 mlir::Value array
= operands
[0];
397 mlir::Value dim
= operands
[1];
399 dim
= hlfir::loadTrivialScalar(loc
, builder
, hlfir::Entity
{dim
});
400 mlir::Type resultType
= computeResultType(array
, stmtResultType
);
401 return createOp
<hlfir::CountOp
>(resultType
, array
, dim
);
404 mlir::Value
HlfirCharExtremumLowering::lowerImpl(
405 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
406 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
407 mlir::Type stmtResultType
) {
408 auto operands
= getOperandVector(loweredActuals
, argLowering
);
409 assert(operands
.size() >= 2);
410 return createOp
<hlfir::CharExtremumOp
>(pred
, mlir::ValueRange
{operands
});
413 mlir::Value
HlfirCShiftLowering::lowerImpl(
414 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
415 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
416 mlir::Type stmtResultType
) {
417 auto operands
= getOperandVector(loweredActuals
, argLowering
);
418 assert(operands
.size() == 3);
419 mlir::Value dim
= operands
[2];
421 // If DIM is not present, drop the last element which is a null Value.
422 operands
.truncate(2);
424 // If DIM is present, then dereference it if it is a ref.
425 dim
= hlfir::loadTrivialScalar(loc
, builder
, hlfir::Entity
{dim
});
429 mlir::Type resultType
= computeResultType(operands
[0], stmtResultType
);
430 return createOp
<hlfir::CShiftOp
>(resultType
, operands
);
433 mlir::Value
HlfirReshapeLowering::lowerImpl(
434 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
435 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
436 mlir::Type stmtResultType
) {
437 auto operands
= getOperandVector(loweredActuals
, argLowering
);
438 assert(operands
.size() == 4);
439 mlir::Type resultType
= computeResultType(operands
[0], stmtResultType
);
440 return createOp
<hlfir::ReshapeOp
>(resultType
, operands
[0], operands
[1],
441 operands
[2], operands
[3]);
444 std::optional
<hlfir::EntityWithAttributes
> Fortran::lower::lowerHlfirIntrinsic(
445 fir::FirOpBuilder
&builder
, mlir::Location loc
, const std::string
&name
,
446 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
447 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
448 mlir::Type stmtResultType
) {
449 // If the result is of a derived type that may need finalization,
450 // we have to use DestroyOp with 'finalize' attribute for the result
451 // of the intrinsic operation.
453 return HlfirSumLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
455 if (name
== "product")
456 return HlfirProductLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
459 return HlfirAnyLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
462 return HlfirAllLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
464 if (name
== "matmul")
465 return HlfirMatmulLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
467 if (name
== "dot_product")
468 return HlfirDotProductLowering
{builder
, loc
}.lower(
469 loweredActuals
, argLowering
, stmtResultType
);
470 // FIXME: the result may need finalization.
471 if (name
== "transpose")
472 return HlfirTransposeLowering
{builder
, loc
}.lower(
473 loweredActuals
, argLowering
, stmtResultType
);
475 return HlfirCountLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
477 if (name
== "maxval")
478 return HlfirMaxvalLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
480 if (name
== "minval")
481 return HlfirMinvalLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
483 if (name
== "minloc")
484 return HlfirMinlocLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
486 if (name
== "maxloc")
487 return HlfirMaxlocLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
489 if (name
== "cshift")
490 return HlfirCShiftLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
492 if (name
== "reshape")
493 return HlfirReshapeLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
495 if (mlir::isa
<fir::CharacterType
>(stmtResultType
)) {
497 return HlfirCharExtremumLowering
{builder
, loc
,
498 hlfir::CharExtremumPredicate::min
}
499 .lower(loweredActuals
, argLowering
, stmtResultType
);
501 return HlfirCharExtremumLowering
{builder
, loc
,
502 hlfir::CharExtremumPredicate::max
}
503 .lower(loweredActuals
, argLowering
, stmtResultType
);