//===-- HlfirIntrinsics.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "flang/Lower/HlfirIntrinsics.h"

#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/MutableBox.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include <mlir/IR/ValueRange.h>
#include <optional>
#include <utility>

30 class HlfirTransformationalIntrinsic
{
32 explicit HlfirTransformationalIntrinsic(fir::FirOpBuilder
&builder
,
34 : builder(builder
), loc(loc
) {}
36 virtual ~HlfirTransformationalIntrinsic() = default;
38 hlfir::EntityWithAttributes
39 lower(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
40 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
41 mlir::Type stmtResultType
) {
42 mlir::Value res
= lowerImpl(loweredActuals
, argLowering
, stmtResultType
);
43 for (const hlfir::CleanupFunction
&fn
: cleanupFns
)
45 return {hlfir::EntityWithAttributes
{res
}};
49 fir::FirOpBuilder
&builder
;
51 llvm::SmallVector
<hlfir::CleanupFunction
, 3> cleanupFns
;
54 lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
55 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
56 mlir::Type stmtResultType
) = 0;
58 llvm::SmallVector
<mlir::Value
> getOperandVector(
59 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
60 const fir::IntrinsicArgumentLoweringRules
*argLowering
);
62 mlir::Type
computeResultType(mlir::Value argArray
, mlir::Type stmtResultType
);
64 template <typename OP
, typename
... BUILD_ARGS
>
65 inline OP
createOp(BUILD_ARGS
... args
) {
66 return builder
.create
<OP
>(loc
, args
...);
69 mlir::Value
loadBoxAddress(
70 const std::optional
<Fortran::lower::PreparedActualArgument
> &arg
);
72 void addCleanup(std::optional
<hlfir::CleanupFunction
> cleanup
) {
74 cleanupFns
.emplace_back(std::move(*cleanup
));
78 template <typename OP
, bool HAS_MASK
>
79 class HlfirReductionIntrinsic
: public HlfirTransformationalIntrinsic
{
81 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic
;
85 lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
86 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
87 mlir::Type stmtResultType
) override
;
89 using HlfirSumLowering
= HlfirReductionIntrinsic
<hlfir::SumOp
, true>;
90 using HlfirProductLowering
= HlfirReductionIntrinsic
<hlfir::ProductOp
, true>;
91 using HlfirMaxvalLowering
= HlfirReductionIntrinsic
<hlfir::MaxvalOp
, true>;
92 using HlfirMinvalLowering
= HlfirReductionIntrinsic
<hlfir::MinvalOp
, true>;
93 using HlfirAnyLowering
= HlfirReductionIntrinsic
<hlfir::AnyOp
, false>;
94 using HlfirAllLowering
= HlfirReductionIntrinsic
<hlfir::AllOp
, false>;
96 template <typename OP
>
97 class HlfirMinMaxLocIntrinsic
: public HlfirTransformationalIntrinsic
{
99 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic
;
103 lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
104 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
105 mlir::Type stmtResultType
) override
;
107 using HlfirMinlocLowering
= HlfirMinMaxLocIntrinsic
<hlfir::MinlocOp
>;
108 using HlfirMaxlocLowering
= HlfirMinMaxLocIntrinsic
<hlfir::MaxlocOp
>;
110 template <typename OP
>
111 class HlfirProductIntrinsic
: public HlfirTransformationalIntrinsic
{
113 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic
;
117 lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
118 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
119 mlir::Type stmtResultType
) override
;
121 using HlfirMatmulLowering
= HlfirProductIntrinsic
<hlfir::MatmulOp
>;
122 using HlfirDotProductLowering
= HlfirProductIntrinsic
<hlfir::DotProductOp
>;
124 class HlfirTransposeLowering
: public HlfirTransformationalIntrinsic
{
126 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic
;
130 lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
131 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
132 mlir::Type stmtResultType
) override
;
135 class HlfirCountLowering
: public HlfirTransformationalIntrinsic
{
137 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic
;
141 lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
142 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
143 mlir::Type stmtResultType
) override
;
146 class HlfirCharExtremumLowering
: public HlfirTransformationalIntrinsic
{
148 HlfirCharExtremumLowering(fir::FirOpBuilder
&builder
, mlir::Location loc
,
149 hlfir::CharExtremumPredicate pred
)
150 : HlfirTransformationalIntrinsic(builder
, loc
), pred
{pred
} {}
154 lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals
,
155 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
156 mlir::Type stmtResultType
) override
;
159 hlfir::CharExtremumPredicate pred
;
164 mlir::Value
HlfirTransformationalIntrinsic::loadBoxAddress(
165 const std::optional
<Fortran::lower::PreparedActualArgument
> &arg
) {
167 return mlir::Value
{};
169 hlfir::Entity actual
= arg
->getActual(loc
, builder
);
171 if (!arg
->handleDynamicOptional()) {
172 if (actual
.isMutableBox()) {
173 // this is a box address type but is not dynamically optional. Just load
174 // the box, assuming it is well formed (!fir.ref<!fir.box<...>> ->
176 return builder
.create
<fir::LoadOp
>(loc
, actual
.getBase());
181 auto [exv
, cleanup
] = hlfir::translateToExtendedValue(loc
, builder
, actual
);
184 mlir::Value isPresent
= arg
->getIsPresent();
185 // createBox will not do create any invalid memory dereferences if exv is
186 // absent. The created fir.box will not be usable, but the SelectOp below
187 // ensures it won't be.
188 mlir::Value box
= builder
.createBox(loc
, exv
);
189 mlir::Type boxType
= box
.getType();
190 auto absent
= builder
.create
<fir::AbsentOp
>(loc
, boxType
);
191 auto boxOrAbsent
= builder
.create
<mlir::arith::SelectOp
>(
192 loc
, boxType
, isPresent
, box
, absent
);
197 static mlir::Value
loadOptionalValue(
198 mlir::Location loc
, fir::FirOpBuilder
&builder
,
199 const std::optional
<Fortran::lower::PreparedActualArgument
> &arg
,
200 hlfir::Entity actual
) {
201 if (!arg
->handleDynamicOptional())
202 return hlfir::loadTrivialScalar(loc
, builder
, actual
);
204 mlir::Value isPresent
= arg
->getIsPresent();
205 mlir::Type eleType
= hlfir::getFortranElementType(actual
.getType());
207 .genIfOp(loc
, {eleType
}, isPresent
,
208 /*withElseRegion=*/true)
210 assert(actual
.isScalar() && fir::isa_trivial(eleType
) &&
211 "must be a numerical or logical scalar");
212 hlfir::Entity val
= hlfir::loadTrivialScalar(loc
, builder
, actual
);
213 builder
.create
<fir::ResultOp
>(loc
, val
);
216 mlir::Value zero
= fir::factory::createZeroValue(builder
, loc
, eleType
);
217 builder
.create
<fir::ResultOp
>(loc
, zero
);
222 llvm::SmallVector
<mlir::Value
> HlfirTransformationalIntrinsic::getOperandVector(
223 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
224 const fir::IntrinsicArgumentLoweringRules
*argLowering
) {
225 llvm::SmallVector
<mlir::Value
> operands
;
226 operands
.reserve(loweredActuals
.size());
228 for (size_t i
= 0; i
< loweredActuals
.size(); ++i
) {
229 std::optional
<Fortran::lower::PreparedActualArgument
> arg
=
232 operands
.emplace_back();
235 hlfir::Entity actual
= arg
->getActual(loc
, builder
);
239 valArg
= hlfir::loadTrivialScalar(loc
, builder
, actual
);
241 fir::ArgLoweringRule argRules
=
242 fir::lowerIntrinsicArgumentAs(*argLowering
, i
);
243 if (argRules
.lowerAs
== fir::LowerIntrinsicArgAs::Box
)
244 valArg
= loadBoxAddress(arg
);
245 else if (!argRules
.handleDynamicOptional
&&
246 argRules
.lowerAs
!= fir::LowerIntrinsicArgAs::Inquired
)
247 valArg
= hlfir::derefPointersAndAllocatables(loc
, builder
, actual
);
248 else if (argRules
.handleDynamicOptional
&&
249 argRules
.lowerAs
== fir::LowerIntrinsicArgAs::Value
)
250 valArg
= loadOptionalValue(loc
, builder
, arg
, actual
);
251 else if (argRules
.handleDynamicOptional
)
252 TODO(loc
, "hlfir transformational intrinsic dynamically optional "
253 "argument without box lowering");
255 valArg
= actual
.getBase();
258 operands
.emplace_back(valArg
);
264 HlfirTransformationalIntrinsic::computeResultType(mlir::Value argArray
,
265 mlir::Type stmtResultType
) {
266 mlir::Type normalisedResult
=
267 hlfir::getFortranElementOrSequenceType(stmtResultType
);
268 if (auto array
= mlir::dyn_cast
<fir::SequenceType
>(normalisedResult
)) {
269 hlfir::ExprType::Shape resultShape
=
270 hlfir::ExprType::Shape
{array
.getShape()};
271 mlir::Type elementType
= array
.getEleTy();
272 return hlfir::ExprType::get(builder
.getContext(), resultShape
, elementType
,
273 /*polymorphic=*/false);
274 } else if (auto resCharType
=
275 mlir::dyn_cast
<fir::CharacterType
>(stmtResultType
)) {
276 normalisedResult
= hlfir::ExprType::get(
277 builder
.getContext(), hlfir::ExprType::Shape
{}, resCharType
, false);
279 return normalisedResult
;
282 template <typename OP
, bool HAS_MASK
>
283 mlir::Value HlfirReductionIntrinsic
<OP
, HAS_MASK
>::lowerImpl(
284 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
285 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
286 mlir::Type stmtResultType
) {
287 auto operands
= getOperandVector(loweredActuals
, argLowering
);
288 mlir::Value array
= operands
[0];
289 mlir::Value dim
= operands
[1];
290 // dim, mask can be NULL if these arguments are not given
292 dim
= hlfir::loadTrivialScalar(loc
, builder
, hlfir::Entity
{dim
});
294 mlir::Type resultTy
= computeResultType(array
, stmtResultType
);
297 if constexpr (HAS_MASK
)
298 op
= createOp
<OP
>(resultTy
, array
, dim
,
299 /*mask=*/operands
[2]);
301 op
= createOp
<OP
>(resultTy
, array
, dim
);
305 template <typename OP
>
306 mlir::Value HlfirMinMaxLocIntrinsic
<OP
>::lowerImpl(
307 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
308 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
309 mlir::Type stmtResultType
) {
310 auto operands
= getOperandVector(loweredActuals
, argLowering
);
311 mlir::Value array
= operands
[0];
312 mlir::Value dim
= operands
[1];
313 mlir::Value mask
= operands
[2];
314 mlir::Value back
= operands
[4];
315 // dim, mask and back can be NULL if these arguments are not given.
317 dim
= hlfir::loadTrivialScalar(loc
, builder
, hlfir::Entity
{dim
});
319 back
= hlfir::loadTrivialScalar(loc
, builder
, hlfir::Entity
{back
});
321 mlir::Type resultTy
= computeResultType(array
, stmtResultType
);
323 return createOp
<OP
>(resultTy
, array
, dim
, mask
, back
);
326 template <typename OP
>
327 mlir::Value HlfirProductIntrinsic
<OP
>::lowerImpl(
328 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
329 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
330 mlir::Type stmtResultType
) {
331 auto operands
= getOperandVector(loweredActuals
, argLowering
);
332 mlir::Type resultType
= computeResultType(operands
[0], stmtResultType
);
333 return createOp
<OP
>(resultType
, operands
[0], operands
[1]);
336 mlir::Value
HlfirTransposeLowering::lowerImpl(
337 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
338 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
339 mlir::Type stmtResultType
) {
340 auto operands
= getOperandVector(loweredActuals
, argLowering
);
341 hlfir::ExprType::Shape resultShape
;
342 mlir::Type normalisedResult
=
343 hlfir::getFortranElementOrSequenceType(stmtResultType
);
344 auto array
= mlir::cast
<fir::SequenceType
>(normalisedResult
);
345 llvm::ArrayRef
<int64_t> arrayShape
= array
.getShape();
346 assert(arrayShape
.size() == 2 && "arguments to transpose have a rank of 2");
347 mlir::Type elementType
= array
.getEleTy();
348 resultShape
.push_back(arrayShape
[0]);
349 resultShape
.push_back(arrayShape
[1]);
350 if (auto resCharType
= mlir::dyn_cast
<fir::CharacterType
>(elementType
))
351 if (!resCharType
.hasConstantLen()) {
352 // The FunctionRef expression might have imprecise character
353 // type at this point, and we can improve it by propagating
354 // the constant length from the argument.
355 auto argCharType
= mlir::dyn_cast
<fir::CharacterType
>(
356 hlfir::getFortranElementType(operands
[0].getType()));
357 if (argCharType
&& argCharType
.hasConstantLen())
358 elementType
= fir::CharacterType::get(
359 builder
.getContext(), resCharType
.getFKind(), argCharType
.getLen());
362 mlir::Type resultTy
=
363 hlfir::ExprType::get(builder
.getContext(), resultShape
, elementType
,
364 fir::isPolymorphicType(stmtResultType
));
365 return createOp
<hlfir::TransposeOp
>(resultTy
, operands
[0]);
368 mlir::Value
HlfirCountLowering::lowerImpl(
369 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
370 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
371 mlir::Type stmtResultType
) {
372 auto operands
= getOperandVector(loweredActuals
, argLowering
);
373 mlir::Value array
= operands
[0];
374 mlir::Value dim
= operands
[1];
376 dim
= hlfir::loadTrivialScalar(loc
, builder
, hlfir::Entity
{dim
});
377 mlir::Type resultType
= computeResultType(array
, stmtResultType
);
378 return createOp
<hlfir::CountOp
>(resultType
, array
, dim
);
381 mlir::Value
HlfirCharExtremumLowering::lowerImpl(
382 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
383 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
384 mlir::Type stmtResultType
) {
385 auto operands
= getOperandVector(loweredActuals
, argLowering
);
386 assert(operands
.size() >= 2);
387 return createOp
<hlfir::CharExtremumOp
>(pred
, mlir::ValueRange
{operands
});
390 std::optional
<hlfir::EntityWithAttributes
> Fortran::lower::lowerHlfirIntrinsic(
391 fir::FirOpBuilder
&builder
, mlir::Location loc
, const std::string
&name
,
392 const Fortran::lower::PreparedActualArguments
&loweredActuals
,
393 const fir::IntrinsicArgumentLoweringRules
*argLowering
,
394 mlir::Type stmtResultType
) {
395 // If the result is of a derived type that may need finalization,
396 // we have to use DestroyOp with 'finalize' attribute for the result
397 // of the intrinsic operation.
399 return HlfirSumLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
401 if (name
== "product")
402 return HlfirProductLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
405 return HlfirAnyLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
408 return HlfirAllLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
410 if (name
== "matmul")
411 return HlfirMatmulLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
413 if (name
== "dot_product")
414 return HlfirDotProductLowering
{builder
, loc
}.lower(
415 loweredActuals
, argLowering
, stmtResultType
);
416 // FIXME: the result may need finalization.
417 if (name
== "transpose")
418 return HlfirTransposeLowering
{builder
, loc
}.lower(
419 loweredActuals
, argLowering
, stmtResultType
);
421 return HlfirCountLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
423 if (name
== "maxval")
424 return HlfirMaxvalLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
426 if (name
== "minval")
427 return HlfirMinvalLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
429 if (name
== "minloc")
430 return HlfirMinlocLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
432 if (name
== "maxloc")
433 return HlfirMaxlocLowering
{builder
, loc
}.lower(loweredActuals
, argLowering
,
435 if (mlir::isa
<fir::CharacterType
>(stmtResultType
)) {
437 return HlfirCharExtremumLowering
{builder
, loc
,
438 hlfir::CharExtremumPredicate::min
}
439 .lower(loweredActuals
, argLowering
, stmtResultType
);
441 return HlfirCharExtremumLowering
{builder
, loc
,
442 hlfir::CharExtremumPredicate::max
}
443 .lower(loweredActuals
, argLowering
, stmtResultType
);