1 //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 //===----------------------------------------------------------------------===//
11 /// This pass looks for calls to the runtime library for intrinsics that
12 /// can be simplified/specialized and replaces them with specialized functions.
14 /// For example, SUM(arr) can be specialized as a simple function with one loop,
15 /// compared to the three arguments (plus file & line info) that the runtime
16 /// call has - when the argument is a 1D-array (multiple loops may be needed
17 /// for higher dimension arrays, of course).
19 /// The general idea is that besides making the call simpler, it can also be
20 /// inlined by other passes that run after this pass, which further improves
21 /// performance, particularly when the work done in the function is trivial
22 /// and small in size.
23 //===----------------------------------------------------------------------===//
25 #include "flang/Common/Fortran.h"
26 #include "flang/Optimizer/Builder/BoxValue.h"
27 #include "flang/Optimizer/Builder/FIRBuilder.h"
28 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
29 #include "flang/Optimizer/Builder/Todo.h"
30 #include "flang/Optimizer/Dialect/FIROps.h"
31 #include "flang/Optimizer/Dialect/FIRType.h"
32 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
33 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
34 #include "flang/Optimizer/Transforms/CUFCommon.h"
35 #include "flang/Optimizer/Transforms/Passes.h"
36 #include "flang/Optimizer/Transforms/Utils.h"
37 #include "flang/Runtime/entry-names.h"
38 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
39 #include "mlir/IR/Matchers.h"
40 #include "mlir/IR/Operation.h"
41 #include "mlir/Pass/Pass.h"
42 #include "mlir/Transforms/DialectConversion.h"
43 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
44 #include "mlir/Transforms/RegionUtils.h"
45 #include "llvm/Support/Debug.h"
46 #include "llvm/Support/raw_ostream.h"
47 #include <llvm/Support/ErrorHandling.h>
48 #include <mlir/Dialect/Arith/IR/Arith.h>
49 #include <mlir/IR/BuiltinTypes.h>
50 #include <mlir/IR/Location.h>
51 #include <mlir/IR/MLIRContext.h>
52 #include <mlir/IR/Value.h>
53 #include <mlir/Support/LLVM.h>
57 #define GEN_PASS_DEF_SIMPLIFYINTRINSICS
58 #include "flang/Optimizer/Transforms/Passes.h.inc"
61 #define DEBUG_TYPE "flang-simplify-intrinsics"
/// Pass class for the intrinsic simplification described in the file header.
/// It declares the helpers used to replace suitable runtime calls with
/// calls to specialized functions that are generated on demand.
65 class SimplifyIntrinsicsPass
66 : public fir::impl::SimplifyIntrinsicsBase
<SimplifyIntrinsicsPass
> {
// Callback that produces the type of a generated function
// (the \p typeGenerator argument of getOrCreateFunction).
67 using FunctionTypeGeneratorTy
=
68 llvm::function_ref
<mlir::FunctionType(fir::FirOpBuilder
&)>;
// Callback that fills in the body of a generated function
// (the \p bodyGenerator argument of getOrCreateFunction).
69 using FunctionBodyGeneratorTy
=
70 llvm::function_ref
<void(fir::FirOpBuilder
&, mlir::func::FuncOp
&)>;
// Callback that generates the body of a reduction function for a given
// rank and element type (the \p genBodyFunc argument of the simplify*
// helpers below).
71 using GenReductionBodyTy
= llvm::function_ref
<void(
72 fir::FirOpBuilder
&builder
, mlir::func::FuncOp
&funcOp
, unsigned rank
,
73 mlir::Type elementType
)>;
// Inherit the constructors of the generated pass base class.
76 using fir::impl::SimplifyIntrinsicsBase
<
77 SimplifyIntrinsicsPass
>::SimplifyIntrinsicsBase
;
79 /// Generate a new function implementing a simplified version
80 /// of a Fortran runtime function defined by \p basename name.
81 /// \p typeGenerator is a callback that generates the new function's type.
82 /// \p bodyGenerator is a callback that generates the new function's body.
83 /// The new function is created in the \p builder's Module.
84 mlir::func::FuncOp
getOrCreateFunction(fir::FirOpBuilder
&builder
,
85 const mlir::StringRef
&basename
,
86 FunctionTypeGeneratorTy typeGenerator
,
87 FunctionBodyGeneratorTy bodyGenerator
);
// Pass entry point (override of the base class hook).
88 void runOnOperation() override
;
// NOTE(review): the parameter name below appears mis-encoded in this copy
// of the file ("&reg" rendered as a single character); it is presumably
// "&registry" - confirm against the original source.
89 void getDependentDialects(mlir::DialectRegistry
®istry
) const override
;
92 /// Helper functions to replace a reduction type of call with its
93 /// simplified form. The actual function is generated using a callback
/// function.
95 /// \p call is the call to be replaced
96 /// \p kindMap is used to create FIROpBuilder
97 /// \p genBodyFunc is the callback that builds the replacement function
98 void simplifyIntOrFloatReduction(fir::CallOp call
,
99 const fir::KindMapping
&kindMap
,
100 GenReductionBodyTy genBodyFunc
);
101 void simplifyLogicalDim0Reduction(fir::CallOp call
,
102 const fir::KindMapping
&kindMap
,
103 GenReductionBodyTy genBodyFunc
);
104 void simplifyLogicalDim1Reduction(fir::CallOp call
,
105 const fir::KindMapping
&kindMap
,
106 GenReductionBodyTy genBodyFunc
);
// Unlike the helpers above, this one takes no body callback;
// \p isMax presumably selects MAXLOC over MINLOC - confirm in the definition.
107 void simplifyMinMaxlocReduction(fir::CallOp call
,
108 const fir::KindMapping
&kindMap
, bool isMax
);
// Takes the union of the arguments of the helpers above plus an existing
// \p builder, the runtime function \p basename and the \p elementType;
// presumably the shared implementation of those helpers - confirm in the
// definition.
109 void simplifyReductionBody(fir::CallOp call
, const fir::KindMapping
&kindMap
,
110 GenReductionBodyTy genBodyFunc
,
111 fir::FirOpBuilder
&builder
,
112 const mlir::StringRef
&basename
,
113 mlir::Type elementType
);
118 /// Create FirOpBuilder with the provided \p op insertion point
119 /// and \p kindMap additionally inheriting FastMathFlags from \p op.
120 static fir::FirOpBuilder
121 getSimplificationBuilder(mlir::Operation
*op
, const fir::KindMapping
&kindMap
) {
122 fir::FirOpBuilder builder
{op
, kindMap
};
// Ask the operation for the arith fast-math interface. dyn_cast yields a
// null interface when \p op does not implement it.
123 auto fmi
= mlir::dyn_cast
<mlir::arith::ArithFastMathInterface
>(*op
);
// NOTE(review): no null check of fmi is visible in this excerpt before it
// is used below - confirm that one exists in the full source.
127 // Regardless of what default FastMathFlags are used by FirOpBuilder,
128 // override them with FastMathFlags attached to the operation.
129 builder
.setFastMathFlags(fmi
.getFastMathFlagsAttr().getValue());
133 /// Generate function type for the simplified version of RTNAME(Sum) and
134 /// similar functions with a fir.box<none> type returning \p elementType.
135 static mlir::FunctionType
genNoneBoxType(fir::FirOpBuilder
&builder
,
136 const mlir::Type
&elementType
) {
// The only input of the generated function type is an opaque box,
// i.e. a fir::BoxType wrapping the none type.
137 mlir::Type boxType
= fir::BoxType::get(builder
.getNoneType());
138 return mlir::FunctionType::get(builder
.getContext(), {boxType
},
/// Look at the operation that defines \p val and try to view it as an
/// operation of type \p Op. dyn_cast_or_null is used, so a value that has
/// no defining operation simply does not match.
/// A message naming the expected operation is printed under LLVM_DEBUG
/// (the "Didn't find expected" diagnostic below).
142 template <typename Op
>
143 Op
expectOp(mlir::Value val
) {
144 if (Op op
= mlir::dyn_cast_or_null
<Op
>(val
.getDefiningOp()))
146 LLVM_DEBUG(llvm::dbgs() << "Didn't find expected " << Op::getOperationName()
/// If operand 0 of the convert operation \p op is defined by an operation
/// of type \p Op, return the result of that defining operation.
151 template <typename Op
>
152 static mlir::Value
findDefSingle(fir::ConvertOp op
) {
153 if (auto defOp
= expectOp
<Op
>(op
->getOperand(0))) {
154 return defOp
.getResult();
/// Try each operation type in \p Ops, in the order given, as the producer
/// of operand 0 of \p op (via findDefSingle) and keep the first match
/// in defOp.
159 template <typename
... Ops
>
160 static mlir::Value
findDef(fir::ConvertOp op
) {
162 // Loop over the operation types given to see if any match, exiting once
163 // a match is found. Cast to void is needed to avoid compiler complaining
164 // that the result of expression is unused
// The fold over || short-circuits as soon as defOp becomes non-null.
165 (void)((defOp
= findDefSingle
<Ops
>(op
), (defOp
)) || ...);
/// Check whether \p val is produced by a fir::ConvertOp whose operand 0 is
/// defined by a fir::AbsentOp.
169 static bool isOperandAbsent(mlir::Value val
) {
170 if (auto op
= expectOp
<fir::ConvertOp
>(val
)) {
171 assert(op
->getOperands().size() != 0);
// isa_and_nonnull tolerates an operand without a defining operation.
172 return mlir::isa_and_nonnull
<fir::AbsentOp
>(
173 op
->getOperand(0).getDefiningOp());
/// When \p val is defined by an arith::ConstantOp, report whether that
/// constant is not matched by m_Zero.
/// NOTE(review): the result for a non-constant \p val is not visible in this
/// excerpt; the function name suggests it is true - confirm.
178 static bool isTrueOrNotConstant(mlir::Value val
) {
179 if (auto op
= expectOp
<mlir::arith::ConstantOp
>(val
)) {
180 return !mlir::matchPattern(val
, mlir::m_Zero());
/// When \p val is produced by a fir::ConvertOp whose operand 0 has a defining
/// operation, report whether that defining operation is matched by m_Zero.
185 static bool isZero(mlir::Value val
) {
186 if (auto op
= expectOp
<fir::ConvertOp
>(val
)) {
187 assert(op
->getOperands().size() != 0);
// The pattern is matched against the defining operation, not the value,
// so operands without a defining operation are skipped.
188 if (mlir::Operation
*defOp
= op
->getOperand(0).getDefiningOp())
189 return mlir::matchPattern(defOp
, mlir::m_Zero());
/// When \p val is produced by a fir::ConvertOp, look for a fir::EmboxOp or
/// fir::ReboxOp defining the converted operand and return what findDef
/// yields for it.
194 static mlir::Value
findBoxDef(mlir::Value val
) {
195 if (auto op
= expectOp
<fir::ConvertOp
>(val
)) {
196 assert(op
->getOperands().size() != 0);
197 return findDef
<fir::EmboxOp
, fir::ReboxOp
>(op
);
/// Same lookup as findBoxDef, but a fir::AbsentOp is also accepted as the
/// producer of the converted operand.
202 static mlir::Value
findMaskDef(mlir::Value val
) {
203 if (auto op
= expectOp
<fir::ConvertOp
>(val
)) {
204 assert(op
->getOperands().size() != 0);
205 return findDef
<fir::EmboxOp
, fir::ReboxOp
, fir::AbsentOp
>(op
);
/// Determine the number of dimensions of the array boxed by \p val.
/// The box definition is located with findBoxDef; when its type is a
/// fir::BoxType whose element type is a fir::SequenceType, the dimension
/// count of that sequence type is returned.
210 static unsigned getDimCount(mlir::Value val
) {
211 // In order to find the dimensions count, we look for EmboxOp/ReboxOp
212 // and take the count from its *result* type. Note that in case
213 // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
214 // have different types.
215 // Actually, we can take the box type from the operand of
216 // the first ConvertOp that has non-opaque box type that we meet
217 // going through the ConvertOp chain.
218 if (mlir::Value emboxVal
= findBoxDef(val
))
219 if (auto boxTy
= mlir::dyn_cast
<fir::BoxType
>(emboxVal
.getType()))
220 if (auto seqTy
= mlir::dyn_cast
<fir::SequenceType
>(boxTy
.getEleTy()))
221 return seqTy
.getDimension();
225 /// Given the call operation's box argument \p val, discover
226 /// the element type of the underlying array object.
227 /// \returns the element type or std::nullopt if the type cannot
228 /// be reliably found.
229 /// We expect that the argument is a result of fir.convert
230 /// with the destination type of !fir.box<none>.
231 static std::optional
<mlir::Type
> getArgElementType(mlir::Value val
) {
232 mlir::Operation
*defOp
;
234 defOp
= val
.getDefiningOp();
235 // Analyze only sequences of convert operations.
// NOTE(review): mlir::isa expects a non-null operation, while
// getDefiningOp() is null for block arguments - confirm that callers only
// pass operation results, or that a null check exists in the full source.
236 if (!mlir::isa
<fir::ConvertOp
>(defOp
))
// Step to the operand of the convert operation.
238 val
= defOp
->getOperand(0);
239 // The convert operation is expected to convert from one
240 // box type to another box type.
241 auto boxType
= mlir::cast
<fir::BoxType
>(val
.getType());
// Strip the box and sequence wrappers to get at the element type.
242 auto elementType
= fir::unwrapSeqOrBoxedSeqType(boxType
);
// A none element type means this box is still opaque.
243 if (!mlir::isa
<mlir::NoneType
>(elementType
))
// Callback type used by genReductionLoop to generate one reduction step.
// genReductionLoop calls it with the builder, the location, the element
// type, the loaded array element and the current reduction value.
248 using BodyOpGeneratorTy
= llvm::function_ref
<mlir::Value(
249 fir::FirOpBuilder
&, mlir::Location
, const mlir::Type
&, mlir::Value
,
// Callback type used by genReductionLoop to turn the new reduction value
// into the list of values yielded by the innermost loop.
251 using ContinueLoopGenTy
= llvm::function_ref
<llvm::SmallVector
<mlir::Value
>(
252 fir::FirOpBuilder
&, mlir::Location
, mlir::Value
)>;
254 /// Generate the reduction loop into \p funcOp.
256 /// \p initVal is a function, called to get the initial value for
257 /// the reduction value
258 /// \p genBody is called to fill in the actual reduction operation
259 /// for example add for SUM, MAX for MAXVAL, etc.
260 /// \p rank is the rank of the input argument.
261 /// \p elementType is the type of the elements in the input array,
262 /// which may be different to the return type.
263 /// \p loopCond is called to generate the condition to continue or
264 /// not for IterWhile loops
265 /// \p unorderedOrInitialLoopCond contains either a boolean or bool
266 /// mlir constant, and controls the initial value for while loops
267 /// or if DoLoop is ordered/unordered.
/// Template parameters: \p OP is the loop operation type used for every
/// level of the nest, \p T is the type of \p unorderedOrInitialLoopCond and
/// \p resultIndex is the position of the reduction value among the loop's
/// region iteration arguments and results.
269 template <typename OP
, typename T
, int resultIndex
>
271 genReductionLoop(fir::FirOpBuilder
&builder
, mlir::func::FuncOp
&funcOp
,
272 fir::InitValGeneratorTy initVal
, ContinueLoopGenTy loopCond
,
273 T unorderedOrInitialLoopCond
, BodyOpGeneratorTy genBody
,
274 unsigned rank
, mlir::Type elementType
, mlir::Location loc
) {
276 mlir::IndexType idxTy
= builder
.getIndexType();
// The array to reduce is the first argument of the generated function.
278 mlir::Block::BlockArgListType args
= funcOp
.front().getArguments();
279 mlir::Value arg
= args
[0];
281 mlir::Value zeroIdx
= builder
.createIntegerConstant(loc
, idxTy
, 0);
// Convert the incoming argument to a box of a rank-dimensional array of
// elementType with unknown extents.
283 fir::SequenceType::Shape
flatShape(rank
,
284 fir::SequenceType::getUnknownExtent());
285 mlir::Type arrTy
= fir::SequenceType::get(flatShape
, elementType
);
286 mlir::Type boxArrTy
= fir::BoxType::get(arrTy
);
287 mlir::Value array
= builder
.create
<fir::ConvertOp
>(loc
, boxArrTy
, arg
);
// The initial reduction value is created with the function's result type,
// not with elementType.
288 mlir::Type resultType
= funcOp
.getResultTypes()[0];
289 mlir::Value init
= initVal(builder
, loc
, resultType
);
291 llvm::SmallVector
<mlir::Value
, Fortran::common::maxRank
> bounds
;
293 assert(rank
> 0 && "rank cannot be zero");
294 mlir::Value one
= builder
.createIntegerConstant(loc
, idxTy
, 1);
296 // Compute all the upper bounds before the loop nest.
297 // It is not strictly necessary for performance, since the loop nest
298 // does not have any store operations and any LICM optimization
299 // should be able to optimize the redundancy.
300 for (unsigned i
= 0; i
< rank
; ++i
) {
301 mlir::Value dimIdx
= builder
.createIntegerConstant(loc
, idxTy
, i
);
303 builder
.create
<fir::BoxDimsOp
>(loc
, idxTy
, idxTy
, idxTy
, array
, dimIdx
);
// Result 1 of the box dims operation is used as the length of the dimension.
304 mlir::Value len
= dims
.getResult(1);
305 // We use C indexing here, so len-1 as loopcount
306 mlir::Value loopCount
= builder
.create
<mlir::arith::SubIOp
>(loc
, len
, one
);
307 bounds
.push_back(loopCount
);
309 // Create a loop nest consisting of OP operations.
310 // Collect the loops' induction variables into indices array,
311 // which will be used in the innermost loop to load the input
// array's element.
313 // The loops are generated such that the innermost loop processes
// dimension 0 (the bounds are consumed from the last one to the first).
315 llvm::SmallVector
<mlir::Value
, Fortran::common::maxRank
> indices
;
316 for (unsigned i
= rank
; 0 < i
; --i
) {
317 mlir::Value step
= one
;
318 mlir::Value loopCount
= bounds
[i
- 1];
319 auto loop
= builder
.create
<OP
>(loc
, zeroIdx
, loopCount
, step
,
320 unorderedOrInitialLoopCond
,
321 /*finalCountValue=*/false, init
);
// Inside the loop the reduction value is the loop's iteration argument.
322 init
= loop
.getRegionIterArgs()[resultIndex
];
323 indices
.push_back(loop
.getInductionVar());
324 // Set insertion point to the loop body so that the next loop
325 // is inserted inside the current one.
326 builder
.setInsertionPointToStart(loop
.getBody());
329 // Reverse the indices such that they are ordered as:
330 // <dim-0-idx, dim-1-idx, ...>
331 std::reverse(indices
.begin(), indices
.end());
332 // We are in the innermost loop: generate the reduction body.
333 mlir::Type eleRefTy
= builder
.getRefType(elementType
);
335 builder
.create
<fir::CoordinateOp
>(loc
, eleRefTy
, array
, indices
);
336 mlir::Value elem
= builder
.create
<fir::LoadOp
>(loc
, addr
);
337 mlir::Value reductionVal
= genBody(builder
, loc
, elementType
, elem
, init
);
338 // Generate vector with condition to continue while loop at [0] and result
339 // from current loop at [1] for IterWhileOp loops, just result at [0] for
// DoLoopOp loops.
341 llvm::SmallVector
<mlir::Value
> results
= loopCond(builder
, loc
, reductionVal
);
343 // Unwind the loop nest and insert ResultOp on each level
344 // to return the updated value of the reduction to the enclosing
// loops.
346 for (unsigned i
= 0; i
< rank
; ++i
) {
347 auto result
= builder
.create
<fir::ResultOp
>(loc
, results
);
348 // Proceed to the outer loop.
349 auto loop
= mlir::cast
<OP
>(result
->getParentOp());
350 results
= loop
.getResults();
351 // Set insertion point after the loop operation that we have
// just terminated.
353 builder
.setInsertionPointAfter(loop
.getOperation());
355 // End of loop nest. The insertion point is after the outermost loop.
356 // Return the reduction value from the function.
357 builder
.create
<mlir::func::ReturnOp
>(loc
, results
[resultIndex
]);
/// Loop-result callback that adds no continuation condition: the returned
/// vector holds only \p reductionVal. The Sum, Maxval and Count body
/// generators pass it to genReductionLoop with fir::DoLoopOp.
360 static llvm::SmallVector
<mlir::Value
> nopLoopCond(fir::FirOpBuilder
&builder
,
362 mlir::Value reductionVal
) {
363 return {reductionVal
};
366 /// Generate function body of the simplified version of RTNAME(Sum)
367 /// with signature provided by \p funcOp. The caller is responsible
368 /// for saving/restoring the original insertion point of \p builder.
369 /// \p funcOp is expected to be empty on entry to this function.
370 /// \p rank specifies the rank of the input argument.
371 static void genRuntimeSumBody(fir::FirOpBuilder
&builder
,
372 mlir::func::FuncOp
&funcOp
, unsigned rank
,
373 mlir::Type elementType
) {
374 // function RTNAME(Sum)<T>x<rank>_simplified(arr)
375 // T, dimension(:) :: arr
378 // do iter = 0, extent(arr)
379 // sum = sum + arr[iter]
381 // RTNAME(Sum)<T>x<rank>_simplified = sum
382 // end function RTNAME(Sum)<T>x<rank>_simplified
// Initial value of the sum: a zero constant of the requested type
// (a real zero for floating-point types, an integer zero otherwise).
383 auto zero
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
384 mlir::Type elementType
) {
385 if (auto ty
= mlir::dyn_cast
<mlir::FloatType
>(elementType
)) {
386 const llvm::fltSemantics
&sem
= ty
.getFloatSemantics();
387 return builder
.createRealConstant(loc
, elementType
,
388 llvm::APFloat::getZero(sem
));
390 return builder
.createIntegerConstant(loc
, elementType
, 0);
// Reduction step: add the two values with the floating-point or the
// integer add operation, depending on elementType.
393 auto genBodyOp
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
394 mlir::Type elementType
, mlir::Value elem1
,
395 mlir::Value elem2
) -> mlir::Value
{
396 if (mlir::isa
<mlir::FloatType
>(elementType
))
397 return builder
.create
<mlir::arith::AddFOp
>(loc
, elem1
, elem2
);
398 if (mlir::isa
<mlir::IntegerType
>(elementType
))
399 return builder
.create
<mlir::arith::AddIOp
>(loc
, elem1
, elem2
);
// Only floating-point and integer element types are handled.
401 llvm_unreachable("unsupported type");
// The generated code carries an unknown location and is emitted into a
// freshly added entry block of funcOp.
405 mlir::Location loc
= mlir::UnknownLoc::get(builder
.getContext());
406 builder
.setInsertionPointToEnd(funcOp
.addEntryBlock());
// Plain fir::DoLoopOp nest; the reduction value is loop result 0.
408 genReductionLoop
<fir::DoLoopOp
, bool, 0>(builder
, funcOp
, zero
, nopLoopCond
,
409 false, genBodyOp
, rank
, elementType
,
/// Generate the body of the simplified maximum-value reduction into
/// \p funcOp for an input of rank \p rank with elements of \p elementType.
/// The structure mirrors genRuntimeSumBody: an initial value, a reduction
/// step and a fir::DoLoopOp nest built by genReductionLoop.
413 static void genRuntimeMaxvalBody(fir::FirOpBuilder
&builder
,
414 mlir::func::FuncOp
&funcOp
, unsigned rank
,
415 mlir::Type elementType
) {
// Initial value: the largest-magnitude negative finite value for
// floating-point types, the signed minimum for integer types.
416 auto init
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
417 mlir::Type elementType
) {
418 if (auto ty
= mlir::dyn_cast
<mlir::FloatType
>(elementType
)) {
419 const llvm::fltSemantics
&sem
= ty
.getFloatSemantics();
420 return builder
.createRealConstant(
421 loc
, elementType
, llvm::APFloat::getLargest(sem
, /*Negative=*/true));
423 unsigned bits
= elementType
.getIntOrFloatBitWidth();
424 int64_t minInt
= llvm::APInt::getSignedMinValue(bits
).getSExtValue();
425 return builder
.createIntegerConstant(loc
, elementType
, minInt
);
// Reduction step: keep the larger of the two values.
428 auto genBodyOp
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
429 mlir::Type elementType
, mlir::Value elem1
,
430 mlir::Value elem2
) -> mlir::Value
{
431 if (mlir::isa
<mlir::FloatType
>(elementType
)) {
432 // arith.maxf later converted to llvm.intr.maxnum does not work
433 // correctly for NaNs and -0.0 (see maxnum/minnum pattern matching
434 // in LLVM's InstCombine pass). Moreover, llvm.intr.maxnum
435 // for F128 operands is lowered into fmaxl call by LLVM.
436 // This libm function may not work properly for F128 arguments
437 // on targets where long double is not F128. It is an LLVM issue,
438 // but we just use normal select here to resolve all the cases.
439 auto compare
= builder
.create
<mlir::arith::CmpFOp
>(
440 loc
, mlir::arith::CmpFPredicate::OGT
, elem1
, elem2
);
441 return builder
.create
<mlir::arith::SelectOp
>(loc
, compare
, elem1
, elem2
);
// Integer elements use the signed maximum operation.
443 if (mlir::isa
<mlir::IntegerType
>(elementType
))
444 return builder
.create
<mlir::arith::MaxSIOp
>(loc
, elem1
, elem2
);
446 llvm_unreachable("unsupported type");
450 mlir::Location loc
= mlir::UnknownLoc::get(builder
.getContext());
451 builder
.setInsertionPointToEnd(funcOp
.addEntryBlock());
453 genReductionLoop
<fir::DoLoopOp
, bool, 0>(builder
, funcOp
, init
, nopLoopCond
,
454 false, genBodyOp
, rank
, elementType
,
/// Generate the body of the simplified count reduction into \p funcOp:
/// a fir::DoLoopOp nest that adds 1 to the running total for every element
/// that is not equal to zero.
458 static void genRuntimeCountBody(fir::FirOpBuilder
&builder
,
459 mlir::func::FuncOp
&funcOp
, unsigned rank
,
460 mlir::Type elementType
) {
// Initial value: an integer zero of the requested type.
461 auto zero
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
462 mlir::Type elementType
) {
463 unsigned bits
= elementType
.getIntOrFloatBitWidth();
464 int64_t zeroInt
= llvm::APInt::getZero(bits
).getSExtValue();
465 return builder
.createIntegerConstant(loc
, elementType
, zeroInt
);
// Reduction step: compare the element (elem1) with zero, select 0 or 1 as
// a 64-bit integer, and add the selection to the running total (elem2).
468 auto genBodyOp
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
469 mlir::Type elementType
, mlir::Value elem1
,
470 mlir::Value elem2
) -> mlir::Value
{
471 auto zero32
= builder
.createIntegerConstant(loc
, elementType
, 0);
472 auto zero64
= builder
.createIntegerConstant(loc
, builder
.getI64Type(), 0);
473 auto one64
= builder
.createIntegerConstant(loc
, builder
.getI64Type(), 1);
475 auto compare
= builder
.create
<mlir::arith::CmpIOp
>(
476 loc
, mlir::arith::CmpIPredicate::eq
, elem1
, zero32
);
478 builder
.create
<mlir::arith::SelectOp
>(loc
, compare
, zero64
, one64
);
479 return builder
.create
<mlir::arith::AddIOp
>(loc
, select
, elem2
);
482 // Count always gets I32 for elementType as it converts logical input to
483 // logical<4> before passing to the function.
484 mlir::Location loc
= mlir::UnknownLoc::get(builder
.getContext());
485 builder
.setInsertionPointToEnd(funcOp
.addEntryBlock());
487 genReductionLoop
<fir::DoLoopOp
, bool, 0>(builder
, funcOp
, zero
, nopLoopCond
,
488 false, genBodyOp
, rank
, elementType
,
/// Generate the body of the simplified "any" reduction into \p funcOp:
/// a fir::IterWhileOp nest whose reduction value is "element != 0" and which
/// keeps iterating only while that value is false.
492 static void genRuntimeAnyBody(fir::FirOpBuilder
&builder
,
493 mlir::func::FuncOp
&funcOp
, unsigned rank
,
494 mlir::Type elementType
) {
// Initial reduction value: zero (false).
495 auto zero
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
496 mlir::Type elementType
) {
497 return builder
.createIntegerConstant(loc
, elementType
, 0);
// Reduction step: the new value depends only on the current element
// (elem1 != 0); the previous reduction value (elem2) is not used.
500 auto genBodyOp
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
501 mlir::Type elementType
, mlir::Value elem1
,
502 mlir::Value elem2
) -> mlir::Value
{
503 auto zero
= builder
.createIntegerConstant(loc
, elementType
, 0);
504 return builder
.create
<mlir::arith::CmpIOp
>(
505 loc
, mlir::arith::CmpIPredicate::ne
, elem1
, zero
);
// Loop results: [0] is the continue condition, the reduction value XORed
// with 1 (its negation as an i1); [1] is the reduction value itself.
508 auto continueCond
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
509 mlir::Value reductionVal
) {
510 auto one1
= builder
.createIntegerConstant(loc
, builder
.getI1Type(), 1);
511 auto eor
= builder
.create
<mlir::arith::XOrIOp
>(loc
, reductionVal
, one1
);
512 llvm::SmallVector
<mlir::Value
> results
= {eor
, reductionVal
};
516 mlir::Location loc
= mlir::UnknownLoc::get(builder
.getContext());
517 builder
.setInsertionPointToEnd(funcOp
.addEntryBlock());
// Initial "keep iterating" flag handed to the while loops.
518 mlir::Value ok
= builder
.createBool(loc
, true);
// With fir::IterWhileOp the reduction value sits at index 1.
520 genReductionLoop
<fir::IterWhileOp
, mlir::Value
, 1>(
521 builder
, funcOp
, zero
, continueCond
, ok
, genBodyOp
, rank
, elementType
,
/// Generate the body of the simplified "all" reduction into \p funcOp:
/// a fir::IterWhileOp nest whose reduction value is "element != 0" and which
/// keeps iterating only while that value is true.
525 static void genRuntimeAllBody(fir::FirOpBuilder
&builder
,
526 mlir::func::FuncOp
&funcOp
, unsigned rank
,
527 mlir::Type elementType
) {
// Initial reduction value: one (true).
528 auto one
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
529 mlir::Type elementType
) {
530 return builder
.createIntegerConstant(loc
, elementType
, 1);
// Reduction step: the new value depends only on the current element
// (elem1 != 0); the previous reduction value (elem2) is not used.
533 auto genBodyOp
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
534 mlir::Type elementType
, mlir::Value elem1
,
535 mlir::Value elem2
) -> mlir::Value
{
536 auto zero
= builder
.createIntegerConstant(loc
, elementType
, 0);
537 return builder
.create
<mlir::arith::CmpIOp
>(
538 loc
, mlir::arith::CmpIPredicate::ne
, elem1
, zero
);
// Loop results: the reduction value serves both as the continue condition
// at [0] and as the result at [1].
541 auto continueCond
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
542 mlir::Value reductionVal
) {
543 llvm::SmallVector
<mlir::Value
> results
= {reductionVal
, reductionVal
};
547 mlir::Location loc
= mlir::UnknownLoc::get(builder
.getContext());
548 builder
.setInsertionPointToEnd(funcOp
.addEntryBlock());
// Initial "keep iterating" flag handed to the while loops.
549 mlir::Value ok
= builder
.createBool(loc
, true);
// With fir::IterWhileOp the reduction value sits at index 1.
551 genReductionLoop
<fir::IterWhileOp
, mlir::Value
, 1>(
552 builder
, funcOp
, one
, continueCond
, ok
, genBodyOp
, rank
, elementType
,
/// Generate the function type used for the simplified Minloc-style
/// functions: three inputs (a reference to an opaque box followed by two
/// opaque boxes) and no results.
556 static mlir::FunctionType
genRuntimeMinlocType(fir::FirOpBuilder
&builder
,
// Opaque box type (a box of the none type) and a reference to such a box.
558 mlir::Type boxType
= fir::BoxType::get(builder
.getNoneType());
559 mlir::Type boxRefType
= builder
.getRefType(boxType
);
561 return mlir::FunctionType::get(builder
.getContext(),
562 {boxRefType
, boxType
, boxType
}, {});
565 // Produces a loop nest for a Minloc intrinsic.
// The nest is made of fir::DoLoopOp operations over \p array (rank \p rank,
// elements of \p elementType). \p initVal provides the initial reduction
// value and \p genBody generates the innermost body; it receives the
// converted array, a reference to a "found" flag, the current reduction
// value and the loop indices. \p resultArr is only used here to derive the
// element type of the flag. When \p maskMayBeLogicalScalar is set and the
// nest was emitted inside a fir::IfOp, that fir::IfOp is closed as well.
566 void fir::genMinMaxlocReductionLoop(
567 fir::FirOpBuilder
&builder
, mlir::Value array
,
568 fir::InitValGeneratorTy initVal
, fir::MinlocBodyOpGeneratorTy genBody
,
569 fir::AddrGeneratorTy getAddrFn
, unsigned rank
, mlir::Type elementType
,
570 mlir::Location loc
, mlir::Type maskElemType
, mlir::Value resultArr
,
571 bool maskMayBeLogicalScalar
) {
572 mlir::IndexType idxTy
= builder
.getIndexType();
574 mlir::Value zeroIdx
= builder
.createIntegerConstant(loc
, idxTy
, 0);
// Convert the array to a box of a rank-dimensional array of elementType
// with unknown extents.
576 fir::SequenceType::Shape
flatShape(rank
,
577 fir::SequenceType::getUnknownExtent());
578 mlir::Type arrTy
= fir::SequenceType::get(flatShape
, elementType
);
579 mlir::Type boxArrTy
= fir::BoxType::get(arrTy
);
580 array
= builder
.create
<fir::ConvertOp
>(loc
, boxArrTy
, array
);
// Allocate a temporary flag with the result array's element type and
// initialize it to zero; its reference is handed to genBody.
582 mlir::Type resultElemType
= hlfir::getFortranElementType(resultArr
.getType());
583 mlir::Value flagSet
= builder
.createIntegerConstant(loc
, resultElemType
, 1);
584 mlir::Value zero
= builder
.createIntegerConstant(loc
, resultElemType
, 0);
585 mlir::Value flagRef
= builder
.createTemporary(loc
, resultElemType
);
586 builder
.create
<fir::StoreOp
>(loc
, zero
, flagRef
);
588 mlir::Value init
= initVal(builder
, loc
, elementType
);
589 llvm::SmallVector
<mlir::Value
, Fortran::common::maxRank
> bounds
;
591 assert(rank
> 0 && "rank cannot be zero");
592 mlir::Value one
= builder
.createIntegerConstant(loc
, idxTy
, 1);
594 // Compute all the upper bounds before the loop nest.
595 // It is not strictly necessary for performance, since the loop nest
596 // does not have any store operations and any LICM optimization
597 // should be able to optimize the redundancy.
598 for (unsigned i
= 0; i
< rank
; ++i
) {
599 mlir::Value dimIdx
= builder
.createIntegerConstant(loc
, idxTy
, i
);
601 builder
.create
<fir::BoxDimsOp
>(loc
, idxTy
, idxTy
, idxTy
, array
, dimIdx
);
// Result 1 of the box dims operation is used as the length of the dimension.
602 mlir::Value len
= dims
.getResult(1);
603 // We use C indexing here, so len-1 as loopcount
604 mlir::Value loopCount
= builder
.create
<mlir::arith::SubIOp
>(loc
, len
, one
);
605 bounds
.push_back(loopCount
);
607 // Create a loop nest consisting of OP operations.
608 // Collect the loops' induction variables into indices array,
609 // which will be used in the innermost loop to load the input
// array's element.
611 // The loops are generated such that the innermost loop processes
// dimension 0 (the bounds are consumed from the last one to the first).
613 llvm::SmallVector
<mlir::Value
, Fortran::common::maxRank
> indices
;
614 for (unsigned i
= rank
; 0 < i
; --i
) {
615 mlir::Value step
= one
;
616 mlir::Value loopCount
= bounds
[i
- 1];
618 builder
.create
<fir::DoLoopOp
>(loc
, zeroIdx
, loopCount
, step
, false,
619 /*finalCountValue=*/false, init
);
// Inside the loop the reduction value is the loop's iteration argument.
620 init
= loop
.getRegionIterArgs()[0];
621 indices
.push_back(loop
.getInductionVar());
622 // Set insertion point to the loop body so that the next loop
623 // is inserted inside the current one.
624 builder
.setInsertionPointToStart(loop
.getBody());
627 // Reverse the indices such that they are ordered as:
628 // <dim-0-idx, dim-1-idx, ...>
629 std::reverse(indices
.begin(), indices
.end());
// We are in the innermost loop: let the callback generate the body.
630 mlir::Value reductionVal
=
631 genBody(builder
, loc
, elementType
, array
, flagRef
, init
, indices
);
633 // Unwind the loop nest and insert ResultOp on each level
634 // to return the updated value of the reduction to the enclosing
// loops.
636 for (unsigned i
= 0; i
< rank
; ++i
) {
637 auto result
= builder
.create
<fir::ResultOp
>(loc
, reductionVal
);
638 // Proceed to the outer loop.
639 auto loop
= mlir::cast
<fir::DoLoopOp
>(result
->getParentOp());
640 reductionVal
= loop
.getResult(0);
641 // Set insertion point after the loop operation that we have
// just terminated.
643 builder
.setInsertionPointAfter(loop
.getOperation());
645 // End of loop nest. The insertion point is after the outermost loop.
// If the nest was emitted inside a fir::IfOp, terminate that region with
// the reduction value and continue after the fir::IfOp.
646 if (maskMayBeLogicalScalar
) {
648 mlir::dyn_cast
<fir::IfOp
>(builder
.getBlock()->getParentOp())) {
649 builder
.create
<fir::ResultOp
>(loc
, reductionVal
);
650 builder
.setInsertionPointAfter(ifOp
);
651 // Redefine flagSet to escape scope of ifOp
652 flagSet
= builder
.createIntegerConstant(loc
, resultElemType
, 1);
653 reductionVal
= ifOp
.getResult(0);
658 static void genRuntimeMinMaxlocBody(fir::FirOpBuilder
&builder
,
659 mlir::func::FuncOp
&funcOp
, bool isMax
,
660 unsigned rank
, int maskRank
,
661 mlir::Type elementType
,
662 mlir::Type maskElemType
,
663 mlir::Type resultElemTy
, bool isDim
) {
664 auto init
= [isMax
](fir::FirOpBuilder builder
, mlir::Location loc
,
665 mlir::Type elementType
) {
666 if (auto ty
= mlir::dyn_cast
<mlir::FloatType
>(elementType
)) {
667 const llvm::fltSemantics
&sem
= ty
.getFloatSemantics();
668 llvm::APFloat limit
= llvm::APFloat::getInf(sem
, /*Negative=*/isMax
);
669 return builder
.createRealConstant(loc
, elementType
, limit
);
671 unsigned bits
= elementType
.getIntOrFloatBitWidth();
672 int64_t initValue
= (isMax
? llvm::APInt::getSignedMinValue(bits
)
673 : llvm::APInt::getSignedMaxValue(bits
))
675 return builder
.createIntegerConstant(loc
, elementType
, initValue
);
678 mlir::Location loc
= mlir::UnknownLoc::get(builder
.getContext());
679 builder
.setInsertionPointToEnd(funcOp
.addEntryBlock());
681 mlir::Value mask
= funcOp
.front().getArgument(2);
683 // Set up result array in case of early exit / 0 length array
684 mlir::IndexType idxTy
= builder
.getIndexType();
685 mlir::Type resultTy
= fir::SequenceType::get(rank
, resultElemTy
);
686 mlir::Type resultHeapTy
= fir::HeapType::get(resultTy
);
687 mlir::Type resultBoxTy
= fir::BoxType::get(resultHeapTy
);
689 mlir::Value returnValue
= builder
.createIntegerConstant(loc
, resultElemTy
, 0);
690 mlir::Value resultArrSize
= builder
.createIntegerConstant(loc
, idxTy
, rank
);
692 mlir::Value resultArrInit
= builder
.create
<fir::AllocMemOp
>(loc
, resultTy
);
693 mlir::Value resultArrShape
= builder
.create
<fir::ShapeOp
>(loc
, resultArrSize
);
694 mlir::Value resultArr
= builder
.create
<fir::EmboxOp
>(
695 loc
, resultBoxTy
, resultArrInit
, resultArrShape
);
697 mlir::Type resultRefTy
= builder
.getRefType(resultElemTy
);
700 fir::SequenceType::Shape
flatShape(rank
,
701 fir::SequenceType::getUnknownExtent());
702 mlir::Type maskTy
= fir::SequenceType::get(flatShape
, maskElemType
);
703 mlir::Type boxMaskTy
= fir::BoxType::get(maskTy
);
704 mask
= builder
.create
<fir::ConvertOp
>(loc
, boxMaskTy
, mask
);
707 for (unsigned int i
= 0; i
< rank
; ++i
) {
708 mlir::Value index
= builder
.createIntegerConstant(loc
, idxTy
, i
);
709 mlir::Value resultElemAddr
=
710 builder
.create
<fir::CoordinateOp
>(loc
, resultRefTy
, resultArr
, index
);
711 builder
.create
<fir::StoreOp
>(loc
, returnValue
, resultElemAddr
);
715 [&rank
, &resultArr
, isMax
, &mask
, &maskElemType
, &maskRank
](
716 fir::FirOpBuilder builder
, mlir::Location loc
, mlir::Type elementType
,
717 mlir::Value array
, mlir::Value flagRef
, mlir::Value reduction
,
718 const llvm::SmallVectorImpl
<mlir::Value
> &indices
) -> mlir::Value
{
719 // We are in the innermost loop: generate the reduction body.
721 mlir::Type logicalRef
= builder
.getRefType(maskElemType
);
722 mlir::Value maskAddr
=
723 builder
.create
<fir::CoordinateOp
>(loc
, logicalRef
, mask
, indices
);
724 mlir::Value maskElem
= builder
.create
<fir::LoadOp
>(loc
, maskAddr
);
726 // fir::IfOp requires argument to be I1 - won't accept logical or any
728 mlir::Type ifCompatType
= builder
.getI1Type();
729 mlir::Value ifCompatElem
=
730 builder
.create
<fir::ConvertOp
>(loc
, ifCompatType
, maskElem
);
732 llvm::SmallVector
<mlir::Type
> resultsTy
= {elementType
, elementType
};
733 fir::IfOp ifOp
= builder
.create
<fir::IfOp
>(loc
, elementType
, ifCompatElem
,
734 /*withElseRegion=*/true);
735 builder
.setInsertionPointToStart(&ifOp
.getThenRegion().front());
738 // Set flag that mask was true at some point
739 mlir::Value flagSet
= builder
.createIntegerConstant(
740 loc
, mlir::cast
<fir::ReferenceType
>(flagRef
.getType()).getEleTy(), 1);
741 mlir::Value isFirst
= builder
.create
<fir::LoadOp
>(loc
, flagRef
);
742 mlir::Type eleRefTy
= builder
.getRefType(elementType
);
744 builder
.create
<fir::CoordinateOp
>(loc
, eleRefTy
, array
, indices
);
745 mlir::Value elem
= builder
.create
<fir::LoadOp
>(loc
, addr
);
748 if (mlir::isa
<mlir::FloatType
>(elementType
)) {
749 // For FP reductions we want the first smallest value to be used, that
750 // is not NaN. A OGL/OLT condition will usually work for this unless all
751 // the values are Nan or Inf. This follows the same logic as
752 // NumericCompare for Minloc/Maxlox in extrema.cpp.
753 cmp
= builder
.create
<mlir::arith::CmpFOp
>(
755 isMax
? mlir::arith::CmpFPredicate::OGT
756 : mlir::arith::CmpFPredicate::OLT
,
759 mlir::Value cmpNan
= builder
.create
<mlir::arith::CmpFOp
>(
760 loc
, mlir::arith::CmpFPredicate::UNE
, reduction
, reduction
);
761 mlir::Value cmpNan2
= builder
.create
<mlir::arith::CmpFOp
>(
762 loc
, mlir::arith::CmpFPredicate::OEQ
, elem
, elem
);
763 cmpNan
= builder
.create
<mlir::arith::AndIOp
>(loc
, cmpNan
, cmpNan2
);
764 cmp
= builder
.create
<mlir::arith::OrIOp
>(loc
, cmp
, cmpNan
);
765 } else if (mlir::isa
<mlir::IntegerType
>(elementType
)) {
766 cmp
= builder
.create
<mlir::arith::CmpIOp
>(
768 isMax
? mlir::arith::CmpIPredicate::sgt
769 : mlir::arith::CmpIPredicate::slt
,
772 llvm_unreachable("unsupported type");
775 // The condition used for the loop is isFirst || <the condition above>.
776 isFirst
= builder
.create
<fir::ConvertOp
>(loc
, cmp
.getType(), isFirst
);
777 isFirst
= builder
.create
<mlir::arith::XOrIOp
>(
778 loc
, isFirst
, builder
.createIntegerConstant(loc
, cmp
.getType(), 1));
779 cmp
= builder
.create
<mlir::arith::OrIOp
>(loc
, cmp
, isFirst
);
780 fir::IfOp ifOp
= builder
.create
<fir::IfOp
>(loc
, elementType
, cmp
,
781 /*withElseRegion*/ true);
783 builder
.setInsertionPointToStart(&ifOp
.getThenRegion().front());
784 builder
.create
<fir::StoreOp
>(loc
, flagSet
, flagRef
);
785 mlir::Type resultElemTy
= hlfir::getFortranElementType(resultArr
.getType());
786 mlir::Type returnRefTy
= builder
.getRefType(resultElemTy
);
787 mlir::IndexType idxTy
= builder
.getIndexType();
789 mlir::Value one
= builder
.createIntegerConstant(loc
, resultElemTy
, 1);
791 for (unsigned int i
= 0; i
< rank
; ++i
) {
792 mlir::Value index
= builder
.createIntegerConstant(loc
, idxTy
, i
);
793 mlir::Value resultElemAddr
=
794 builder
.create
<fir::CoordinateOp
>(loc
, returnRefTy
, resultArr
, index
);
795 mlir::Value convert
=
796 builder
.create
<fir::ConvertOp
>(loc
, resultElemTy
, indices
[i
]);
797 mlir::Value fortranIndex
=
798 builder
.create
<mlir::arith::AddIOp
>(loc
, convert
, one
);
799 builder
.create
<fir::StoreOp
>(loc
, fortranIndex
, resultElemAddr
);
801 builder
.create
<fir::ResultOp
>(loc
, elem
);
802 builder
.setInsertionPointToStart(&ifOp
.getElseRegion().front());
803 builder
.create
<fir::ResultOp
>(loc
, reduction
);
804 builder
.setInsertionPointAfter(ifOp
);
805 mlir::Value reductionVal
= ifOp
.getResult(0);
807 // Close the mask if needed
810 mlir::dyn_cast
<fir::IfOp
>(builder
.getBlock()->getParentOp());
811 builder
.create
<fir::ResultOp
>(loc
, reductionVal
);
812 builder
.setInsertionPointToStart(&ifOp
.getElseRegion().front());
813 builder
.create
<fir::ResultOp
>(loc
, reduction
);
814 reductionVal
= ifOp
.getResult(0);
815 builder
.setInsertionPointAfter(ifOp
);
821 // if mask is a logical scalar, we can check its value before the main loop
822 // and either ignore the fact it is there or exit early.
824 mlir::Type logical
= builder
.getI1Type();
825 mlir::IndexType idxTy
= builder
.getIndexType();
827 fir::SequenceType::Shape
singleElement(1, 1);
828 mlir::Type arrTy
= fir::SequenceType::get(singleElement
, logical
);
829 mlir::Type boxArrTy
= fir::BoxType::get(arrTy
);
830 mlir::Value array
= builder
.create
<fir::ConvertOp
>(loc
, boxArrTy
, mask
);
832 mlir::Value indx
= builder
.createIntegerConstant(loc
, idxTy
, 0);
833 mlir::Type logicalRefTy
= builder
.getRefType(logical
);
834 mlir::Value condAddr
=
835 builder
.create
<fir::CoordinateOp
>(loc
, logicalRefTy
, array
, indx
);
836 mlir::Value cond
= builder
.create
<fir::LoadOp
>(loc
, condAddr
);
838 fir::IfOp ifOp
= builder
.create
<fir::IfOp
>(loc
, elementType
, cond
,
839 /*withElseRegion=*/true);
841 builder
.setInsertionPointToStart(&ifOp
.getElseRegion().front());
842 mlir::Value basicValue
;
843 if (mlir::isa
<mlir::IntegerType
>(elementType
)) {
844 basicValue
= builder
.createIntegerConstant(loc
, elementType
, 0);
846 basicValue
= builder
.createRealConstant(loc
, elementType
, 0);
848 builder
.create
<fir::ResultOp
>(loc
, basicValue
);
850 builder
.setInsertionPointToStart(&ifOp
.getThenRegion().front());
852 auto getAddrFn
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
853 const mlir::Type
&resultElemType
, mlir::Value resultArr
,
855 mlir::Type resultRefTy
= builder
.getRefType(resultElemType
);
856 return builder
.create
<fir::CoordinateOp
>(loc
, resultRefTy
, resultArr
,
860 genMinMaxlocReductionLoop(builder
, funcOp
.front().getArgument(1), init
,
861 genBodyOp
, getAddrFn
, rank
, elementType
, loc
,
862 maskElemType
, resultArr
, maskRank
== 0);
864 // Store newly created output array to the reference passed in
866 mlir::Type resultBoxTy
=
867 fir::BoxType::get(fir::HeapType::get(resultElemTy
));
868 mlir::Value outputArr
= builder
.create
<fir::ConvertOp
>(
869 loc
, builder
.getRefType(resultBoxTy
), funcOp
.front().getArgument(0));
870 mlir::Value resultArrScalar
= builder
.create
<fir::ConvertOp
>(
871 loc
, fir::HeapType::get(resultElemTy
), resultArrInit
);
872 mlir::Value resultBox
=
873 builder
.create
<fir::EmboxOp
>(loc
, resultBoxTy
, resultArrScalar
);
874 builder
.create
<fir::StoreOp
>(loc
, resultBox
, outputArr
);
876 fir::SequenceType::Shape
resultShape(1, rank
);
877 mlir::Type outputArrTy
= fir::SequenceType::get(resultShape
, resultElemTy
);
878 mlir::Type outputHeapTy
= fir::HeapType::get(outputArrTy
);
879 mlir::Type outputBoxTy
= fir::BoxType::get(outputHeapTy
);
880 mlir::Type outputRefTy
= builder
.getRefType(outputBoxTy
);
881 mlir::Value outputArr
= builder
.create
<fir::ConvertOp
>(
882 loc
, outputRefTy
, funcOp
.front().getArgument(0));
883 builder
.create
<fir::StoreOp
>(loc
, resultArr
, outputArr
);
886 builder
.create
<mlir::func::ReturnOp
>(loc
);
889 /// Generate function type for the simplified version of RTNAME(DotProduct)
890 /// operating on the given \p elementType.
891 static mlir::FunctionType
genRuntimeDotType(fir::FirOpBuilder
&builder
,
892 const mlir::Type
&elementType
) {
893 mlir::Type boxType
= fir::BoxType::get(builder
.getNoneType());
894 return mlir::FunctionType::get(builder
.getContext(), {boxType
, boxType
},
898 /// Generate function body of the simplified version of RTNAME(DotProduct)
899 /// with signature provided by \p funcOp. The caller is responsible
900 /// for saving/restoring the original insertion point of \p builder.
901 /// \p funcOp is expected to be empty on entry to this function.
902 /// \p arg1ElementTy and \p arg2ElementTy specify elements types
903 /// of the underlying array objects - they are used to generate proper
904 /// element accesses.
905 static void genRuntimeDotBody(fir::FirOpBuilder
&builder
,
906 mlir::func::FuncOp
&funcOp
,
907 mlir::Type arg1ElementTy
,
908 mlir::Type arg2ElementTy
) {
909 // function RTNAME(DotProduct)<T>_simplified(arr1, arr2)
910 // T, dimension(:) :: arr1, arr2
913 // do iter = 0, extent(arr1)
914 // product = product + arr1[iter] * arr2[iter]
916 // RTNAME(ADotProduct)<T>_simplified = product
917 // end function RTNAME(DotProduct)<T>_simplified
918 auto loc
= mlir::UnknownLoc::get(builder
.getContext());
919 mlir::Type resultElementType
= funcOp
.getResultTypes()[0];
920 builder
.setInsertionPointToEnd(funcOp
.addEntryBlock());
922 mlir::IndexType idxTy
= builder
.getIndexType();
925 mlir::isa
<mlir::FloatType
>(resultElementType
)
926 ? builder
.createRealConstant(loc
, resultElementType
, 0.0)
927 : builder
.createIntegerConstant(loc
, resultElementType
, 0);
929 mlir::Block::BlockArgListType args
= funcOp
.front().getArguments();
930 mlir::Value arg1
= args
[0];
931 mlir::Value arg2
= args
[1];
933 mlir::Value zeroIdx
= builder
.createIntegerConstant(loc
, idxTy
, 0);
935 fir::SequenceType::Shape flatShape
= {fir::SequenceType::getUnknownExtent()};
936 mlir::Type arrTy1
= fir::SequenceType::get(flatShape
, arg1ElementTy
);
937 mlir::Type boxArrTy1
= fir::BoxType::get(arrTy1
);
938 mlir::Value array1
= builder
.create
<fir::ConvertOp
>(loc
, boxArrTy1
, arg1
);
939 mlir::Type arrTy2
= fir::SequenceType::get(flatShape
, arg2ElementTy
);
940 mlir::Type boxArrTy2
= fir::BoxType::get(arrTy2
);
941 mlir::Value array2
= builder
.create
<fir::ConvertOp
>(loc
, boxArrTy2
, arg2
);
942 // This version takes the loop trip count from the first argument.
943 // If the first argument's box has unknown (at compilation time)
944 // extent, then it may be better to take the extent from the second
945 // argument - so that after inlining the loop may be better optimized, e.g.
946 // fully unrolled. This requires generating two versions of the simplified
947 // function and some analysis at the call site to choose which version
948 // is more profitable to call.
949 // Note that we can assume that both arguments have the same extent.
951 builder
.create
<fir::BoxDimsOp
>(loc
, idxTy
, idxTy
, idxTy
, array1
, zeroIdx
);
952 mlir::Value len
= dims
.getResult(1);
953 mlir::Value one
= builder
.createIntegerConstant(loc
, idxTy
, 1);
954 mlir::Value step
= one
;
956 // We use C indexing here, so len-1 as loopcount
957 mlir::Value loopCount
= builder
.create
<mlir::arith::SubIOp
>(loc
, len
, one
);
958 auto loop
= builder
.create
<fir::DoLoopOp
>(loc
, zeroIdx
, loopCount
, step
,
960 /*finalCountValue=*/false, zero
);
961 mlir::Value sumVal
= loop
.getRegionIterArgs()[0];
964 mlir::OpBuilder::InsertPoint loopEndPt
= builder
.saveInsertionPoint();
965 builder
.setInsertionPointToStart(loop
.getBody());
967 mlir::Type eleRef1Ty
= builder
.getRefType(arg1ElementTy
);
968 mlir::Value index
= loop
.getInductionVar();
970 builder
.create
<fir::CoordinateOp
>(loc
, eleRef1Ty
, array1
, index
);
971 mlir::Value elem1
= builder
.create
<fir::LoadOp
>(loc
, addr1
);
972 // Convert to the result type.
973 elem1
= builder
.create
<fir::ConvertOp
>(loc
, resultElementType
, elem1
);
975 mlir::Type eleRef2Ty
= builder
.getRefType(arg2ElementTy
);
977 builder
.create
<fir::CoordinateOp
>(loc
, eleRef2Ty
, array2
, index
);
978 mlir::Value elem2
= builder
.create
<fir::LoadOp
>(loc
, addr2
);
979 // Convert to the result type.
980 elem2
= builder
.create
<fir::ConvertOp
>(loc
, resultElementType
, elem2
);
982 if (mlir::isa
<mlir::FloatType
>(resultElementType
))
983 sumVal
= builder
.create
<mlir::arith::AddFOp
>(
984 loc
, builder
.create
<mlir::arith::MulFOp
>(loc
, elem1
, elem2
), sumVal
);
985 else if (mlir::isa
<mlir::IntegerType
>(resultElementType
))
986 sumVal
= builder
.create
<mlir::arith::AddIOp
>(
987 loc
, builder
.create
<mlir::arith::MulIOp
>(loc
, elem1
, elem2
), sumVal
);
989 llvm_unreachable("unsupported type");
991 builder
.create
<fir::ResultOp
>(loc
, sumVal
);
993 builder
.restoreInsertionPoint(loopEndPt
);
995 mlir::Value resultVal
= loop
.getResult(0);
996 builder
.create
<mlir::func::ReturnOp
>(loc
, resultVal
);
999 mlir::func::FuncOp
SimplifyIntrinsicsPass::getOrCreateFunction(
1000 fir::FirOpBuilder
&builder
, const mlir::StringRef
&baseName
,
1001 FunctionTypeGeneratorTy typeGenerator
,
1002 FunctionBodyGeneratorTy bodyGenerator
) {
1003 // WARNING: if the function generated here changes its signature
1004 // or behavior (the body code), we should probably embed some
1005 // versioning information into its name, otherwise libraries
1006 // statically linked with older versions of Flang may stop
1007 // working with object files created with newer Flang.
1008 // We can also avoid this by using internal linkage, but
1009 // this may increase the size of final executable/shared library.
1010 std::string replacementName
= mlir::Twine
{baseName
, "_simplified"}.str();
1011 // If we already have a function, just return it.
1012 mlir::func::FuncOp newFunc
= builder
.getNamedFunction(replacementName
);
1013 mlir::FunctionType fType
= typeGenerator(builder
);
1015 assert(newFunc
.getFunctionType() == fType
&&
1016 "type mismatch for simplified function");
1020 // Need to build the function!
1021 auto loc
= mlir::UnknownLoc::get(builder
.getContext());
1022 newFunc
= builder
.createFunction(loc
, replacementName
, fType
);
1023 auto inlineLinkage
= mlir::LLVM::linkage::Linkage::LinkonceODR
;
1025 mlir::LLVM::LinkageAttr::get(builder
.getContext(), inlineLinkage
);
1026 newFunc
->setAttr("llvm.linkage", linkage
);
1028 // Save the position of the original call.
1029 mlir::OpBuilder::InsertPoint insertPt
= builder
.saveInsertionPoint();
1031 bodyGenerator(builder
, newFunc
);
1033 // Now back to where we were adding code earlier...
1034 builder
.restoreInsertionPoint(insertPt
);
1039 void SimplifyIntrinsicsPass::simplifyIntOrFloatReduction(
1040 fir::CallOp call
, const fir::KindMapping
&kindMap
,
1041 GenReductionBodyTy genBodyFunc
) {
1042 // args[1] and args[2] are source filename and line number, ignored.
1043 mlir::Operation::operand_range args
= call
.getArgs();
1045 const mlir::Value
&dim
= args
[3];
1046 const mlir::Value
&mask
= args
[4];
1047 // dim is zero when it is absent, which is an implementation
1048 // detail in the runtime library.
1050 bool dimAndMaskAbsent
= isZero(dim
) && isOperandAbsent(mask
);
1051 unsigned rank
= getDimCount(args
[0]);
1053 // Rank is set to 0 for assumed shape arrays, don't simplify
1055 if (!(dimAndMaskAbsent
&& rank
> 0))
1058 mlir::Type resultType
= call
.getResult(0).getType();
1060 if (!mlir::isa
<mlir::FloatType
>(resultType
) &&
1061 !mlir::isa
<mlir::IntegerType
>(resultType
))
1064 auto argType
= getArgElementType(args
[0]);
1067 assert(*argType
== resultType
&&
1068 "Argument/result types mismatch in reduction");
1070 mlir::SymbolRefAttr callee
= call
.getCalleeAttr();
1072 fir::FirOpBuilder builder
{getSimplificationBuilder(call
, kindMap
)};
1073 std::string fmfString
{builder
.getFastMathFlagsString()};
1074 std::string funcName
=
1075 (mlir::Twine
{callee
.getLeafReference().getValue(), "x"} +
1077 // We must mangle the generated function name with FastMathFlags
1079 (fmfString
.empty() ? mlir::Twine
{} : mlir::Twine
{"_", fmfString
}))
1082 simplifyReductionBody(call
, kindMap
, genBodyFunc
, builder
, funcName
,
1086 void SimplifyIntrinsicsPass::simplifyLogicalDim0Reduction(
1087 fir::CallOp call
, const fir::KindMapping
&kindMap
,
1088 GenReductionBodyTy genBodyFunc
) {
1090 mlir::Operation::operand_range args
= call
.getArgs();
1091 const mlir::Value
&dim
= args
[3];
1092 unsigned rank
= getDimCount(args
[0]);
1094 // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
1096 if (!(isZero(dim
) && rank
> 0))
1099 mlir::Value inputBox
= findBoxDef(args
[0]);
1101 mlir::Type elementType
= hlfir::getFortranElementType(inputBox
.getType());
1102 mlir::SymbolRefAttr callee
= call
.getCalleeAttr();
1104 fir::FirOpBuilder builder
{getSimplificationBuilder(call
, kindMap
)};
1106 // Treating logicals as integers makes things a lot easier
1107 fir::LogicalType logicalType
= {
1108 mlir::dyn_cast
<fir::LogicalType
>(elementType
)};
1109 fir::KindTy kind
= logicalType
.getFKind();
1110 mlir::Type intElementType
= builder
.getIntegerType(kind
* 8);
1112 // Mangle kind into function name as it is not done by default
1113 std::string funcName
=
1114 (mlir::Twine
{callee
.getLeafReference().getValue(), "Logical"} +
1115 mlir::Twine
{kind
} + "x" + mlir::Twine
{rank
})
1118 simplifyReductionBody(call
, kindMap
, genBodyFunc
, builder
, funcName
,
1122 void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction(
1123 fir::CallOp call
, const fir::KindMapping
&kindMap
,
1124 GenReductionBodyTy genBodyFunc
) {
1126 mlir::Operation::operand_range args
= call
.getArgs();
1127 mlir::SymbolRefAttr callee
= call
.getCalleeAttr();
1128 mlir::StringRef funcNameBase
= callee
.getLeafReference().getValue();
1129 unsigned rank
= getDimCount(args
[0]);
1131 // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
1132 // these cases. We check for Dim at the end as some logical functions (Any,
1133 // All) set dim to 1 instead of 0 when the argument is not present.
1134 if (funcNameBase
.ends_with("Dim") || !(rank
> 0))
1137 mlir::Value inputBox
= findBoxDef(args
[0]);
1138 mlir::Type elementType
= hlfir::getFortranElementType(inputBox
.getType());
1140 fir::FirOpBuilder builder
{getSimplificationBuilder(call
, kindMap
)};
1142 // Treating logicals as integers makes things a lot easier
1143 fir::LogicalType logicalType
= {
1144 mlir::dyn_cast
<fir::LogicalType
>(elementType
)};
1145 fir::KindTy kind
= logicalType
.getFKind();
1146 mlir::Type intElementType
= builder
.getIntegerType(kind
* 8);
1148 // Mangle kind into function name as it is not done by default
1149 std::string funcName
=
1150 (mlir::Twine
{callee
.getLeafReference().getValue(), "Logical"} +
1151 mlir::Twine
{kind
} + "x" + mlir::Twine
{rank
})
1154 simplifyReductionBody(call
, kindMap
, genBodyFunc
, builder
, funcName
,
1158 void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
1159 fir::CallOp call
, const fir::KindMapping
&kindMap
, bool isMax
) {
1161 mlir::Operation::operand_range args
= call
.getArgs();
1163 mlir::SymbolRefAttr callee
= call
.getCalleeAttr();
1164 mlir::StringRef funcNameBase
= callee
.getLeafReference().getValue();
1165 bool isDim
= funcNameBase
.ends_with("Dim");
1166 mlir::Value back
= args
[isDim
? 7 : 6];
1167 if (isTrueOrNotConstant(back
))
1170 mlir::Value mask
= args
[isDim
? 6 : 5];
1171 mlir::Value maskDef
= findMaskDef(mask
);
1173 // maskDef is set to NULL when the defining op is not one we accept.
1174 // This tends to be because it is a selectOp, in which case let the
1175 // runtime deal with it.
1176 if (maskDef
== NULL
)
1179 unsigned rank
= getDimCount(args
[1]);
1180 if ((isDim
&& rank
!= 1) || !(rank
> 0))
1183 fir::FirOpBuilder builder
{getSimplificationBuilder(call
, kindMap
)};
1184 mlir::Location loc
= call
.getLoc();
1185 auto inputBox
= findBoxDef(args
[1]);
1186 mlir::Type inputType
= hlfir::getFortranElementType(inputBox
.getType());
1188 if (mlir::isa
<fir::CharacterType
>(inputType
))
1192 fir::KindTy kind
= 0;
1193 mlir::Type logicalElemType
= builder
.getI1Type();
1194 if (isOperandAbsent(mask
)) {
1197 maskRank
= getDimCount(mask
);
1198 mlir::Type maskElemTy
= hlfir::getFortranElementType(maskDef
.getType());
1199 fir::LogicalType logicalFirType
= {
1200 mlir::dyn_cast
<fir::LogicalType
>(maskElemTy
)};
1201 kind
= logicalFirType
.getFKind();
1202 // Convert fir::LogicalType to mlir::Type
1203 logicalElemType
= logicalFirType
;
1206 mlir::Operation
*outputDef
= args
[0].getDefiningOp();
1207 mlir::Value outputAlloc
= outputDef
->getOperand(0);
1208 mlir::Type outType
= hlfir::getFortranElementType(outputAlloc
.getType());
1210 std::string fmfString
{builder
.getFastMathFlagsString()};
1211 std::string funcName
=
1212 (mlir::Twine
{callee
.getLeafReference().getValue(), "x"} +
1215 ? "_Logical" + mlir::Twine
{kind
} + "x" + mlir::Twine
{maskRank
}
1220 llvm::raw_string_ostream
nameOS(funcName
);
1221 outType
.print(nameOS
);
1223 nameOS
<< '_' << inputType
;
1224 nameOS
<< '_' << fmfString
;
1226 auto typeGenerator
= [rank
](fir::FirOpBuilder
&builder
) {
1227 return genRuntimeMinlocType(builder
, rank
);
1229 auto bodyGenerator
= [rank
, maskRank
, inputType
, logicalElemType
, outType
,
1230 isMax
, isDim
](fir::FirOpBuilder
&builder
,
1231 mlir::func::FuncOp
&funcOp
) {
1232 genRuntimeMinMaxlocBody(builder
, funcOp
, isMax
, rank
, maskRank
, inputType
,
1233 logicalElemType
, outType
, isDim
);
1236 mlir::func::FuncOp newFunc
=
1237 getOrCreateFunction(builder
, funcName
, typeGenerator
, bodyGenerator
);
1238 builder
.create
<fir::CallOp
>(loc
, newFunc
,
1239 mlir::ValueRange
{args
[0], args
[1], mask
});
1240 call
->dropAllReferences();
1244 void SimplifyIntrinsicsPass::simplifyReductionBody(
1245 fir::CallOp call
, const fir::KindMapping
&kindMap
,
1246 GenReductionBodyTy genBodyFunc
, fir::FirOpBuilder
&builder
,
1247 const mlir::StringRef
&funcName
, mlir::Type elementType
) {
1249 mlir::Operation::operand_range args
= call
.getArgs();
1251 mlir::Type resultType
= call
.getResult(0).getType();
1252 unsigned rank
= getDimCount(args
[0]);
1254 mlir::Location loc
= call
.getLoc();
1256 auto typeGenerator
= [&resultType
](fir::FirOpBuilder
&builder
) {
1257 return genNoneBoxType(builder
, resultType
);
1259 auto bodyGenerator
= [&rank
, &genBodyFunc
,
1260 &elementType
](fir::FirOpBuilder
&builder
,
1261 mlir::func::FuncOp
&funcOp
) {
1262 genBodyFunc(builder
, funcOp
, rank
, elementType
);
1264 // Mangle the function name with the rank value as "x<rank>".
1265 mlir::func::FuncOp newFunc
=
1266 getOrCreateFunction(builder
, funcName
, typeGenerator
, bodyGenerator
);
1268 builder
.create
<fir::CallOp
>(loc
, newFunc
, mlir::ValueRange
{args
[0]});
1269 call
->replaceAllUsesWith(newCall
.getResults());
1270 call
->dropAllReferences();
1274 void SimplifyIntrinsicsPass::runOnOperation() {
1275 LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE
" ===\n");
1276 mlir::ModuleOp module
= getOperation();
1277 fir::KindMapping kindMap
= fir::getKindMapping(module
);
1278 module
.walk([&](mlir::Operation
*op
) {
1279 if (auto call
= mlir::dyn_cast
<fir::CallOp
>(op
)) {
1280 if (cuf::isInCUDADeviceContext(op
))
1282 if (mlir::SymbolRefAttr callee
= call
.getCalleeAttr()) {
1283 mlir::StringRef funcName
= callee
.getLeafReference().getValue();
1284 // Replace call to runtime function for SUM when it has single
1285 // argument (no dim or mask argument) for 1D arrays with either
1286 // Integer4 or Real8 types. Other forms are ignored.
1287 // The new function is added to the module.
1289 // Prototype for runtime call (from sum.cpp):
1290 // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
1291 // int dim, const Descriptor *mask)
1293 if (funcName
.starts_with(RTNAME_STRING(Sum
))) {
1294 simplifyIntOrFloatReduction(call
, kindMap
, genRuntimeSumBody
);
1297 if (funcName
.starts_with(RTNAME_STRING(DotProduct
))) {
1298 LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName
<< "\n");
1299 LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op
->dump();
1300 llvm::dbgs() << "\n");
1301 mlir::Operation::operand_range args
= call
.getArgs();
1302 const mlir::Value
&v1
= args
[0];
1303 const mlir::Value
&v2
= args
[1];
1304 mlir::Location loc
= call
.getLoc();
1305 fir::FirOpBuilder builder
{getSimplificationBuilder(op
, kindMap
)};
1306 // Stringize the builder's FastMathFlags flags for mangling
1307 // the generated function name.
1308 std::string fmfString
{builder
.getFastMathFlagsString()};
1310 mlir::Type type
= call
.getResult(0).getType();
1311 if (!mlir::isa
<mlir::FloatType
>(type
) &&
1312 !mlir::isa
<mlir::IntegerType
>(type
))
1315 // Try to find the element types of the boxed arguments.
1316 auto arg1Type
= getArgElementType(v1
);
1317 auto arg2Type
= getArgElementType(v2
);
1319 if (!arg1Type
|| !arg2Type
)
1322 // Support only floating point and integer arguments
1323 // now (e.g. logical is skipped here).
1324 if (!mlir::isa
<mlir::FloatType
, mlir::IntegerType
>(*arg1Type
))
1326 if (!mlir::isa
<mlir::FloatType
, mlir::IntegerType
>(*arg2Type
))
1329 auto typeGenerator
= [&type
](fir::FirOpBuilder
&builder
) {
1330 return genRuntimeDotType(builder
, type
);
1332 auto bodyGenerator
= [&arg1Type
,
1333 &arg2Type
](fir::FirOpBuilder
&builder
,
1334 mlir::func::FuncOp
&funcOp
) {
1335 genRuntimeDotBody(builder
, funcOp
, *arg1Type
, *arg2Type
);
1338 // Suffix the function name with the element types
1339 // of the arguments.
1340 std::string
typedFuncName(funcName
);
1341 llvm::raw_string_ostream
nameOS(typedFuncName
);
1342 // We must mangle the generated function name with FastMathFlags
1344 if (!fmfString
.empty())
1345 nameOS
<< '_' << fmfString
;
1347 arg1Type
->print(nameOS
);
1349 arg2Type
->print(nameOS
);
1351 mlir::func::FuncOp newFunc
= getOrCreateFunction(
1352 builder
, typedFuncName
, typeGenerator
, bodyGenerator
);
1353 auto newCall
= builder
.create
<fir::CallOp
>(loc
, newFunc
,
1354 mlir::ValueRange
{v1
, v2
});
1355 call
->replaceAllUsesWith(newCall
.getResults());
1356 call
->dropAllReferences();
1359 LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall
.dump();
1360 llvm::dbgs() << "\n");
1363 if (funcName
.starts_with(RTNAME_STRING(Maxval
))) {
1364 simplifyIntOrFloatReduction(call
, kindMap
, genRuntimeMaxvalBody
);
1367 if (funcName
.starts_with(RTNAME_STRING(Count
))) {
1368 simplifyLogicalDim0Reduction(call
, kindMap
, genRuntimeCountBody
);
1371 if (funcName
.starts_with(RTNAME_STRING(Any
))) {
1372 simplifyLogicalDim1Reduction(call
, kindMap
, genRuntimeAnyBody
);
1375 if (funcName
.ends_with(RTNAME_STRING(All
))) {
1376 simplifyLogicalDim1Reduction(call
, kindMap
, genRuntimeAllBody
);
1379 if (funcName
.starts_with(RTNAME_STRING(Minloc
))) {
1380 simplifyMinMaxlocReduction(call
, kindMap
, false);
1383 if (funcName
.starts_with(RTNAME_STRING(Maxloc
))) {
1384 simplifyMinMaxlocReduction(call
, kindMap
, true);
1390 LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE
" ===\n");
1393 void SimplifyIntrinsicsPass::getDependentDialects(
1394 mlir::DialectRegistry
®istry
) const {
1395 // LLVM::LinkageAttr creation requires that LLVM dialect is loaded.
1396 registry
.insert
<mlir::LLVM::LLVMDialect
>();