1 //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 //===----------------------------------------------------------------------===//
11 /// This pass looks for suitable calls to runtime library for intrinsics that
12 /// can be simplified/specialized and replaces with a specialized function.
14 /// For example, SUM(arr) can be specialized as a simple function with one loop,
15 /// compared to the three arguments (plus file & line info) that the runtime
16 /// call has - when the argument is a 1D-array (multiple loops may be needed
17 /// for higher dimension arrays, of course)
19 /// The general idea is that besides making the call simpler, it can also be
20 /// inlined by other passes that run after this pass, which further improves
21 /// performance, particularly when the work done in the function is trivial
22 /// and small in size.
23 //===----------------------------------------------------------------------===//
25 #include "flang/Common/Fortran.h"
26 #include "flang/Optimizer/Builder/BoxValue.h"
27 #include "flang/Optimizer/Builder/FIRBuilder.h"
28 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
29 #include "flang/Optimizer/Builder/Todo.h"
30 #include "flang/Optimizer/Dialect/FIROps.h"
31 #include "flang/Optimizer/Dialect/FIRType.h"
32 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
33 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
34 #include "flang/Optimizer/Transforms/Passes.h"
35 #include "flang/Runtime/entry-names.h"
36 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
37 #include "mlir/IR/Matchers.h"
38 #include "mlir/IR/Operation.h"
39 #include "mlir/Pass/Pass.h"
40 #include "mlir/Transforms/DialectConversion.h"
41 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
42 #include "mlir/Transforms/RegionUtils.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/Support/raw_ostream.h"
45 #include <llvm/CodeGen/SelectionDAGNodes.h>
46 #include <llvm/Support/ErrorHandling.h>
47 #include <mlir/Dialect/Arith/IR/Arith.h>
48 #include <mlir/IR/BuiltinTypes.h>
49 #include <mlir/IR/Location.h>
50 #include <mlir/IR/MLIRContext.h>
51 #include <mlir/IR/Value.h>
52 #include <mlir/Support/LLVM.h>
56 #define GEN_PASS_DEF_SIMPLIFYINTRINSICS
57 #include "flang/Optimizer/Transforms/Passes.h.inc"
60 #define DEBUG_TYPE "flang-simplify-intrinsics"
64 class SimplifyIntrinsicsPass
65 : public fir::impl::SimplifyIntrinsicsBase
<SimplifyIntrinsicsPass
> {
66 using FunctionTypeGeneratorTy
=
67 llvm::function_ref
<mlir::FunctionType(fir::FirOpBuilder
&)>;
68 using FunctionBodyGeneratorTy
=
69 llvm::function_ref
<void(fir::FirOpBuilder
&, mlir::func::FuncOp
&)>;
70 using GenReductionBodyTy
= llvm::function_ref
<void(
71 fir::FirOpBuilder
&builder
, mlir::func::FuncOp
&funcOp
, unsigned rank
,
72 mlir::Type elementType
)>;
75 /// Generate a new function implementing a simplified version
76 /// of a Fortran runtime function defined by \p basename name.
77 /// \p typeGenerator is a callback that generates the new function's type.
78 /// \p bodyGenerator is a callback that generates the new function's body.
79 /// The new function is created in the \p builder's Module.
80 mlir::func::FuncOp
getOrCreateFunction(fir::FirOpBuilder
&builder
,
81 const mlir::StringRef
&basename
,
82 FunctionTypeGeneratorTy typeGenerator
,
83 FunctionBodyGeneratorTy bodyGenerator
);
84 void runOnOperation() override
;
85 void getDependentDialects(mlir::DialectRegistry
®istry
) const override
;
88 /// Helper functions to replace a reduction type of call with its
89 /// simplified form. The actual function is generated using a callback
91 /// \p call is the call to be replaced
92 /// \p kindMap is used to create FIROpBuilder
93 /// \p genBodyFunc is the callback that builds the replacement function
94 void simplifyIntOrFloatReduction(fir::CallOp call
,
95 const fir::KindMapping
&kindMap
,
96 GenReductionBodyTy genBodyFunc
);
97 void simplifyLogicalDim0Reduction(fir::CallOp call
,
98 const fir::KindMapping
&kindMap
,
99 GenReductionBodyTy genBodyFunc
);
100 void simplifyLogicalDim1Reduction(fir::CallOp call
,
101 const fir::KindMapping
&kindMap
,
102 GenReductionBodyTy genBodyFunc
);
103 void simplifyMinlocReduction(fir::CallOp call
,
104 const fir::KindMapping
&kindMap
);
105 void simplifyReductionBody(fir::CallOp call
, const fir::KindMapping
&kindMap
,
106 GenReductionBodyTy genBodyFunc
,
107 fir::FirOpBuilder
&builder
,
108 const mlir::StringRef
&basename
,
109 mlir::Type elementType
);
114 /// Create FirOpBuilder with the provided \p op insertion point
115 /// and \p kindMap additionally inheriting FastMathFlags from \p op.
116 static fir::FirOpBuilder
117 getSimplificationBuilder(mlir::Operation
*op
, const fir::KindMapping
&kindMap
) {
118 fir::FirOpBuilder builder
{op
, kindMap
};
119 auto fmi
= mlir::dyn_cast
<mlir::arith::ArithFastMathInterface
>(*op
);
123 // Regardless of what default FastMathFlags are used by FirOpBuilder,
124 // override them with FastMathFlags attached to the operation.
125 builder
.setFastMathFlags(fmi
.getFastMathFlagsAttr().getValue());
129 /// Stringify FastMathFlags set for the given \p builder in a way
130 /// that the string may be used for mangling a function name.
131 /// If FastMathFlags are set to 'none', then the result is an empty
133 static std::string
getFastMathFlagsString(const fir::FirOpBuilder
&builder
) {
134 mlir::arith::FastMathFlags flags
= builder
.getFastMathFlags();
135 if (flags
== mlir::arith::FastMathFlags::none
)
138 std::string fmfString
{mlir::arith::stringifyFastMathFlags(flags
)};
139 std::replace(fmfString
.begin(), fmfString
.end(), ',', '_');
143 /// Generate function type for the simplified version of RTNAME(Sum) and
144 /// similar functions with a fir.box<none> type returning \p elementType.
145 static mlir::FunctionType
genNoneBoxType(fir::FirOpBuilder
&builder
,
146 const mlir::Type
&elementType
) {
147 mlir::Type boxType
= fir::BoxType::get(builder
.getNoneType());
148 return mlir::FunctionType::get(builder
.getContext(), {boxType
},
152 template <typename Op
>
153 Op
expectOp(mlir::Value val
) {
154 if (Op op
= mlir::dyn_cast_or_null
<Op
>(val
.getDefiningOp()))
156 LLVM_DEBUG(llvm::dbgs() << "Didn't find expected " << Op::getOperationName()
161 template <typename Op
>
162 static mlir::Value
findDefSingle(fir::ConvertOp op
) {
163 if (auto defOp
= expectOp
<Op
>(op
->getOperand(0))) {
164 return defOp
.getResult();
169 template <typename
... Ops
>
170 static mlir::Value
findDef(fir::ConvertOp op
) {
172 // Loop over the operation types given to see if any match, exiting once
173 // a match is found. Cast to void is needed to avoid compiler complaining
174 // that the result of expression is unused
175 (void)((defOp
= findDefSingle
<Ops
>(op
), (defOp
)) || ...);
179 static bool isOperandAbsent(mlir::Value val
) {
180 if (auto op
= expectOp
<fir::ConvertOp
>(val
)) {
181 assert(op
->getOperands().size() != 0);
182 return mlir::isa_and_nonnull
<fir::AbsentOp
>(
183 op
->getOperand(0).getDefiningOp());
188 static bool isTrueOrNotConstant(mlir::Value val
) {
189 if (auto op
= expectOp
<mlir::arith::ConstantOp
>(val
)) {
190 return !mlir::matchPattern(val
, mlir::m_Zero());
195 static bool isZero(mlir::Value val
) {
196 if (auto op
= expectOp
<fir::ConvertOp
>(val
)) {
197 assert(op
->getOperands().size() != 0);
198 if (mlir::Operation
*defOp
= op
->getOperand(0).getDefiningOp())
199 return mlir::matchPattern(defOp
, mlir::m_Zero());
204 static mlir::Value
findBoxDef(mlir::Value val
) {
205 if (auto op
= expectOp
<fir::ConvertOp
>(val
)) {
206 assert(op
->getOperands().size() != 0);
207 return findDef
<fir::EmboxOp
, fir::ReboxOp
>(op
);
212 static mlir::Value
findMaskDef(mlir::Value val
) {
213 if (auto op
= expectOp
<fir::ConvertOp
>(val
)) {
214 assert(op
->getOperands().size() != 0);
215 return findDef
<fir::EmboxOp
, fir::ReboxOp
, fir::AbsentOp
>(op
);
220 static unsigned getDimCount(mlir::Value val
) {
221 // In order to find the dimensions count, we look for EmboxOp/ReboxOp
222 // and take the count from its *result* type. Note that in case
223 // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
224 // have different types.
225 // Actually, we can take the box type from the operand of
226 // the first ConvertOp that has non-opaque box type that we meet
227 // going through the ConvertOp chain.
228 if (mlir::Value emboxVal
= findBoxDef(val
))
229 if (auto boxTy
= emboxVal
.getType().dyn_cast
<fir::BoxType
>())
230 if (auto seqTy
= boxTy
.getEleTy().dyn_cast
<fir::SequenceType
>())
231 return seqTy
.getDimension();
235 /// Given the call operation's box argument \p val, discover
236 /// the element type of the underlying array object.
237 /// \returns the element type or std::nullopt if the type cannot
238 /// be reliably found.
239 /// We expect that the argument is a result of fir.convert
240 /// with the destination type of !fir.box<none>.
241 static std::optional
<mlir::Type
> getArgElementType(mlir::Value val
) {
242 mlir::Operation
*defOp
;
244 defOp
= val
.getDefiningOp();
245 // Analyze only sequences of convert operations.
246 if (!mlir::isa
<fir::ConvertOp
>(defOp
))
248 val
= defOp
->getOperand(0);
249 // The convert operation is expected to convert from one
250 // box type to another box type.
251 auto boxType
= val
.getType().cast
<fir::BoxType
>();
252 auto elementType
= fir::unwrapSeqOrBoxedSeqType(boxType
);
253 if (!elementType
.isa
<mlir::NoneType
>())
258 using BodyOpGeneratorTy
= llvm::function_ref
<mlir::Value(
259 fir::FirOpBuilder
&, mlir::Location
, const mlir::Type
&, mlir::Value
,
261 using InitValGeneratorTy
= llvm::function_ref
<mlir::Value(
262 fir::FirOpBuilder
&, mlir::Location
, const mlir::Type
&)>;
263 using ContinueLoopGenTy
= llvm::function_ref
<llvm::SmallVector
<mlir::Value
>(
264 fir::FirOpBuilder
&, mlir::Location
, mlir::Value
)>;
266 /// Generate the reduction loop into \p funcOp.
268 /// \p initVal is a function, called to get the initial value for
269 /// the reduction value
270 /// \p genBody is called to fill in the actual reduciton operation
271 /// for example add for SUM, MAX for MAXVAL, etc.
272 /// \p rank is the rank of the input argument.
273 /// \p elementType is the type of the elements in the input array,
274 /// which may be different to the return type.
275 /// \p loopCond is called to generate the condition to continue or
276 /// not for IterWhile loops
277 /// \p unorderedOrInitalLoopCond contains either a boolean or bool
278 /// mlir constant, and controls the inital value for while loops
279 /// or if DoLoop is ordered/unordered.
281 template <typename OP
, typename T
, int resultIndex
>
283 genReductionLoop(fir::FirOpBuilder
&builder
, mlir::func::FuncOp
&funcOp
,
284 InitValGeneratorTy initVal
, ContinueLoopGenTy loopCond
,
285 T unorderedOrInitialLoopCond
, BodyOpGeneratorTy genBody
,
286 unsigned rank
, mlir::Type elementType
, mlir::Location loc
) {
288 mlir::IndexType idxTy
= builder
.getIndexType();
290 mlir::Block::BlockArgListType args
= funcOp
.front().getArguments();
291 mlir::Value arg
= args
[0];
293 mlir::Value zeroIdx
= builder
.createIntegerConstant(loc
, idxTy
, 0);
295 fir::SequenceType::Shape
flatShape(rank
,
296 fir::SequenceType::getUnknownExtent());
297 mlir::Type arrTy
= fir::SequenceType::get(flatShape
, elementType
);
298 mlir::Type boxArrTy
= fir::BoxType::get(arrTy
);
299 mlir::Value array
= builder
.create
<fir::ConvertOp
>(loc
, boxArrTy
, arg
);
300 mlir::Type resultType
= funcOp
.getResultTypes()[0];
301 mlir::Value init
= initVal(builder
, loc
, resultType
);
303 llvm::SmallVector
<mlir::Value
, Fortran::common::maxRank
> bounds
;
305 assert(rank
> 0 && "rank cannot be zero");
306 mlir::Value one
= builder
.createIntegerConstant(loc
, idxTy
, 1);
308 // Compute all the upper bounds before the loop nest.
309 // It is not strictly necessary for performance, since the loop nest
310 // does not have any store operations and any LICM optimization
311 // should be able to optimize the redundancy.
312 for (unsigned i
= 0; i
< rank
; ++i
) {
313 mlir::Value dimIdx
= builder
.createIntegerConstant(loc
, idxTy
, i
);
315 builder
.create
<fir::BoxDimsOp
>(loc
, idxTy
, idxTy
, idxTy
, array
, dimIdx
);
316 mlir::Value len
= dims
.getResult(1);
317 // We use C indexing here, so len-1 as loopcount
318 mlir::Value loopCount
= builder
.create
<mlir::arith::SubIOp
>(loc
, len
, one
);
319 bounds
.push_back(loopCount
);
321 // Create a loop nest consisting of OP operations.
322 // Collect the loops' induction variables into indices array,
323 // which will be used in the innermost loop to load the input
325 // The loops are generated such that the innermost loop processes
327 llvm::SmallVector
<mlir::Value
, Fortran::common::maxRank
> indices
;
328 for (unsigned i
= rank
; 0 < i
; --i
) {
329 mlir::Value step
= one
;
330 mlir::Value loopCount
= bounds
[i
- 1];
331 auto loop
= builder
.create
<OP
>(loc
, zeroIdx
, loopCount
, step
,
332 unorderedOrInitialLoopCond
,
333 /*finalCountValue=*/false, init
);
334 init
= loop
.getRegionIterArgs()[resultIndex
];
335 indices
.push_back(loop
.getInductionVar());
336 // Set insertion point to the loop body so that the next loop
337 // is inserted inside the current one.
338 builder
.setInsertionPointToStart(loop
.getBody());
341 // Reverse the indices such that they are ordered as:
342 // <dim-0-idx, dim-1-idx, ...>
343 std::reverse(indices
.begin(), indices
.end());
344 // We are in the innermost loop: generate the reduction body.
345 mlir::Type eleRefTy
= builder
.getRefType(elementType
);
347 builder
.create
<fir::CoordinateOp
>(loc
, eleRefTy
, array
, indices
);
348 mlir::Value elem
= builder
.create
<fir::LoadOp
>(loc
, addr
);
349 mlir::Value reductionVal
= genBody(builder
, loc
, elementType
, elem
, init
);
350 // Generate vector with condition to continue while loop at [0] and result
351 // from current loop at [1] for IterWhileOp loops, just result at [0] for
353 llvm::SmallVector
<mlir::Value
> results
= loopCond(builder
, loc
, reductionVal
);
355 // Unwind the loop nest and insert ResultOp on each level
356 // to return the updated value of the reduction to the enclosing
358 for (unsigned i
= 0; i
< rank
; ++i
) {
359 auto result
= builder
.create
<fir::ResultOp
>(loc
, results
);
360 // Proceed to the outer loop.
361 auto loop
= mlir::cast
<OP
>(result
->getParentOp());
362 results
= loop
.getResults();
363 // Set insertion point after the loop operation that we have
365 builder
.setInsertionPointAfter(loop
.getOperation());
367 // End of loop nest. The insertion point is after the outermost loop.
368 // Return the reduction value from the function.
369 builder
.create
<mlir::func::ReturnOp
>(loc
, results
[resultIndex
]);
371 using MinlocBodyOpGeneratorTy
= llvm::function_ref
<mlir::Value(
372 fir::FirOpBuilder
&, mlir::Location
, const mlir::Type
&, mlir::Value
,
373 mlir::Value
, llvm::SmallVector
<mlir::Value
, Fortran::common::maxRank
> &)>;
376 genMinlocReductionLoop(fir::FirOpBuilder
&builder
, mlir::func::FuncOp
&funcOp
,
377 InitValGeneratorTy initVal
,
378 MinlocBodyOpGeneratorTy genBody
, unsigned rank
,
379 mlir::Type elementType
, mlir::Location loc
, bool hasMask
,
380 mlir::Type maskElemType
, mlir::Value resultArr
) {
382 mlir::IndexType idxTy
= builder
.getIndexType();
384 mlir::Block::BlockArgListType args
= funcOp
.front().getArguments();
385 mlir::Value arg
= args
[1];
387 mlir::Value zeroIdx
= builder
.createIntegerConstant(loc
, idxTy
, 0);
389 fir::SequenceType::Shape
flatShape(rank
,
390 fir::SequenceType::getUnknownExtent());
391 mlir::Type arrTy
= fir::SequenceType::get(flatShape
, elementType
);
392 mlir::Type boxArrTy
= fir::BoxType::get(arrTy
);
393 mlir::Value array
= builder
.create
<fir::ConvertOp
>(loc
, boxArrTy
, arg
);
395 mlir::Type resultElemType
= hlfir::getFortranElementType(resultArr
.getType());
396 mlir::Value flagSet
= builder
.createIntegerConstant(loc
, resultElemType
, 1);
397 mlir::Value zero
= builder
.createIntegerConstant(loc
, resultElemType
, 0);
398 mlir::Value flagRef
= builder
.createTemporary(loc
, resultElemType
);
399 builder
.create
<fir::StoreOp
>(loc
, zero
, flagRef
);
403 mlir::Type maskTy
= fir::SequenceType::get(flatShape
, maskElemType
);
404 mlir::Type boxMaskTy
= fir::BoxType::get(maskTy
);
405 mask
= builder
.create
<fir::ConvertOp
>(loc
, boxMaskTy
, args
[2]);
408 mlir::Value init
= initVal(builder
, loc
, elementType
);
409 llvm::SmallVector
<mlir::Value
, Fortran::common::maxRank
> bounds
;
411 assert(rank
> 0 && "rank cannot be zero");
412 mlir::Value one
= builder
.createIntegerConstant(loc
, idxTy
, 1);
414 // Compute all the upper bounds before the loop nest.
415 // It is not strictly necessary for performance, since the loop nest
416 // does not have any store operations and any LICM optimization
417 // should be able to optimize the redundancy.
418 for (unsigned i
= 0; i
< rank
; ++i
) {
419 mlir::Value dimIdx
= builder
.createIntegerConstant(loc
, idxTy
, i
);
421 builder
.create
<fir::BoxDimsOp
>(loc
, idxTy
, idxTy
, idxTy
, array
, dimIdx
);
422 mlir::Value len
= dims
.getResult(1);
423 // We use C indexing here, so len-1 as loopcount
424 mlir::Value loopCount
= builder
.create
<mlir::arith::SubIOp
>(loc
, len
, one
);
425 bounds
.push_back(loopCount
);
427 // Create a loop nest consisting of OP operations.
428 // Collect the loops' induction variables into indices array,
429 // which will be used in the innermost loop to load the input
431 // The loops are generated such that the innermost loop processes
433 llvm::SmallVector
<mlir::Value
, Fortran::common::maxRank
> indices
;
434 for (unsigned i
= rank
; 0 < i
; --i
) {
435 mlir::Value step
= one
;
436 mlir::Value loopCount
= bounds
[i
- 1];
438 builder
.create
<fir::DoLoopOp
>(loc
, zeroIdx
, loopCount
, step
, false,
439 /*finalCountValue=*/false, init
);
440 init
= loop
.getRegionIterArgs()[0];
441 indices
.push_back(loop
.getInductionVar());
442 // Set insertion point to the loop body so that the next loop
443 // is inserted inside the current one.
444 builder
.setInsertionPointToStart(loop
.getBody());
447 // Reverse the indices such that they are ordered as:
448 // <dim-0-idx, dim-1-idx, ...>
449 std::reverse(indices
.begin(), indices
.end());
450 // We are in the innermost loop: generate the reduction body.
452 mlir::Type logicalRef
= builder
.getRefType(maskElemType
);
453 mlir::Value maskAddr
=
454 builder
.create
<fir::CoordinateOp
>(loc
, logicalRef
, mask
, indices
);
455 mlir::Value maskElem
= builder
.create
<fir::LoadOp
>(loc
, maskAddr
);
457 // fir::IfOp requires argument to be I1 - won't accept logical or any other
459 mlir::Type ifCompatType
= builder
.getI1Type();
460 mlir::Value ifCompatElem
=
461 builder
.create
<fir::ConvertOp
>(loc
, ifCompatType
, maskElem
);
463 llvm::SmallVector
<mlir::Type
> resultsTy
= {elementType
, elementType
};
464 fir::IfOp ifOp
= builder
.create
<fir::IfOp
>(loc
, elementType
, ifCompatElem
,
465 /*withElseRegion=*/true);
466 builder
.setInsertionPointToStart(&ifOp
.getThenRegion().front());
469 // Set flag that mask was true at some point
470 builder
.create
<fir::StoreOp
>(loc
, flagSet
, flagRef
);
471 mlir::Type eleRefTy
= builder
.getRefType(elementType
);
473 builder
.create
<fir::CoordinateOp
>(loc
, eleRefTy
, array
, indices
);
474 mlir::Value elem
= builder
.create
<fir::LoadOp
>(loc
, addr
);
476 mlir::Value reductionVal
=
477 genBody(builder
, loc
, elementType
, elem
, init
, indices
);
481 mlir::dyn_cast
<fir::IfOp
>(builder
.getBlock()->getParentOp());
482 builder
.create
<fir::ResultOp
>(loc
, reductionVal
);
483 builder
.setInsertionPointToStart(&ifOp
.getElseRegion().front());
484 builder
.create
<fir::ResultOp
>(loc
, init
);
485 reductionVal
= ifOp
.getResult(0);
486 builder
.setInsertionPointAfter(ifOp
);
489 // Unwind the loop nest and insert ResultOp on each level
490 // to return the updated value of the reduction to the enclosing
492 for (unsigned i
= 0; i
< rank
; ++i
) {
493 auto result
= builder
.create
<fir::ResultOp
>(loc
, reductionVal
);
494 // Proceed to the outer loop.
495 auto loop
= mlir::cast
<fir::DoLoopOp
>(result
->getParentOp());
496 reductionVal
= loop
.getResult(0);
497 // Set insertion point after the loop operation that we have
499 builder
.setInsertionPointAfter(loop
.getOperation());
501 // End of loop nest. The insertion point is after the outermost loop.
503 mlir::dyn_cast
<fir::IfOp
>(builder
.getBlock()->getParentOp())) {
504 builder
.create
<fir::ResultOp
>(loc
, reductionVal
);
505 builder
.setInsertionPointAfter(ifOp
);
506 // Redefine flagSet to escape scope of ifOp
507 flagSet
= builder
.createIntegerConstant(loc
, resultElemType
, 1);
508 reductionVal
= ifOp
.getResult(0);
511 // Check for case where array was full of max values.
512 // flag will be 0 if mask was never true, 1 if mask was true as some point,
513 // this is needed to avoid catching cases where we didn't access any elements
515 mlir::Value flagValue
=
516 builder
.create
<fir::LoadOp
>(loc
, resultElemType
, flagRef
);
517 mlir::Value flagCmp
= builder
.create
<mlir::arith::CmpIOp
>(
518 loc
, mlir::arith::CmpIPredicate::eq
, flagValue
, flagSet
);
519 fir::IfOp ifMaskTrueOp
=
520 builder
.create
<fir::IfOp
>(loc
, flagCmp
, /*withElseRegion=*/false);
521 builder
.setInsertionPointToStart(&ifMaskTrueOp
.getThenRegion().front());
523 mlir::Value testInit
= initVal(builder
, loc
, elementType
);
524 fir::IfOp ifMinSetOp
;
525 if (elementType
.isa
<mlir::FloatType
>()) {
526 mlir::Value cmp
= builder
.create
<mlir::arith::CmpFOp
>(
527 loc
, mlir::arith::CmpFPredicate::OEQ
, testInit
, reductionVal
);
528 ifMinSetOp
= builder
.create
<fir::IfOp
>(loc
, cmp
,
529 /*withElseRegion*/ false);
531 mlir::Value cmp
= builder
.create
<mlir::arith::CmpIOp
>(
532 loc
, mlir::arith::CmpIPredicate::eq
, testInit
, reductionVal
);
533 ifMinSetOp
= builder
.create
<fir::IfOp
>(loc
, cmp
,
534 /*withElseRegion*/ false);
536 builder
.setInsertionPointToStart(&ifMinSetOp
.getThenRegion().front());
538 // Load output array with 1s instead of 0s
539 for (unsigned int i
= 0; i
< rank
; ++i
) {
540 mlir::Type resultRefTy
= builder
.getRefType(resultElemType
);
541 // mlir::Value one = builder.createIntegerConstant(loc, resultElemType, 1);
542 mlir::Value index
= builder
.createIntegerConstant(loc
, idxTy
, i
);
543 mlir::Value resultElemAddr
=
544 builder
.create
<fir::CoordinateOp
>(loc
, resultRefTy
, resultArr
, index
);
545 builder
.create
<fir::StoreOp
>(loc
, flagSet
, resultElemAddr
);
547 builder
.setInsertionPointAfter(ifMaskTrueOp
);
548 // Store newly created output array to the reference passed in
549 fir::SequenceType::Shape
resultShape(1, rank
);
550 mlir::Type outputArrTy
= fir::SequenceType::get(resultShape
, resultElemType
);
551 mlir::Type outputHeapTy
= fir::HeapType::get(outputArrTy
);
552 mlir::Type outputBoxTy
= fir::BoxType::get(outputHeapTy
);
553 mlir::Type outputRefTy
= builder
.getRefType(outputBoxTy
);
555 mlir::Value outputArrNone
= args
[0];
556 mlir::Value outputArr
=
557 builder
.create
<fir::ConvertOp
>(loc
, outputRefTy
, outputArrNone
);
559 // Store nearly created array to output array
560 builder
.create
<fir::StoreOp
>(loc
, resultArr
, outputArr
);
561 builder
.create
<mlir::func::ReturnOp
>(loc
);
564 static llvm::SmallVector
<mlir::Value
> nopLoopCond(fir::FirOpBuilder
&builder
,
566 mlir::Value reductionVal
) {
567 return {reductionVal
};
570 /// Generate function body of the simplified version of RTNAME(Sum)
571 /// with signature provided by \p funcOp. The caller is responsible
572 /// for saving/restoring the original insertion point of \p builder.
573 /// \p funcOp is expected to be empty on entry to this function.
574 /// \p rank specifies the rank of the input argument.
575 static void genRuntimeSumBody(fir::FirOpBuilder
&builder
,
576 mlir::func::FuncOp
&funcOp
, unsigned rank
,
577 mlir::Type elementType
) {
578 // function RTNAME(Sum)<T>x<rank>_simplified(arr)
579 // T, dimension(:) :: arr
582 // do iter = 0, extent(arr)
583 // sum = sum + arr[iter]
585 // RTNAME(Sum)<T>x<rank>_simplified = sum
586 // end function RTNAME(Sum)<T>x<rank>_simplified
587 auto zero
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
588 mlir::Type elementType
) {
589 if (auto ty
= elementType
.dyn_cast
<mlir::FloatType
>()) {
590 const llvm::fltSemantics
&sem
= ty
.getFloatSemantics();
591 return builder
.createRealConstant(loc
, elementType
,
592 llvm::APFloat::getZero(sem
));
594 return builder
.createIntegerConstant(loc
, elementType
, 0);
597 auto genBodyOp
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
598 mlir::Type elementType
, mlir::Value elem1
,
599 mlir::Value elem2
) -> mlir::Value
{
600 if (elementType
.isa
<mlir::FloatType
>())
601 return builder
.create
<mlir::arith::AddFOp
>(loc
, elem1
, elem2
);
602 if (elementType
.isa
<mlir::IntegerType
>())
603 return builder
.create
<mlir::arith::AddIOp
>(loc
, elem1
, elem2
);
605 llvm_unreachable("unsupported type");
609 mlir::Location loc
= mlir::UnknownLoc::get(builder
.getContext());
610 builder
.setInsertionPointToEnd(funcOp
.addEntryBlock());
612 genReductionLoop
<fir::DoLoopOp
, bool, 0>(builder
, funcOp
, zero
, nopLoopCond
,
613 false, genBodyOp
, rank
, elementType
,
617 static void genRuntimeMaxvalBody(fir::FirOpBuilder
&builder
,
618 mlir::func::FuncOp
&funcOp
, unsigned rank
,
619 mlir::Type elementType
) {
620 auto init
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
621 mlir::Type elementType
) {
622 if (auto ty
= elementType
.dyn_cast
<mlir::FloatType
>()) {
623 const llvm::fltSemantics
&sem
= ty
.getFloatSemantics();
624 return builder
.createRealConstant(
625 loc
, elementType
, llvm::APFloat::getLargest(sem
, /*Negative=*/true));
627 unsigned bits
= elementType
.getIntOrFloatBitWidth();
628 int64_t minInt
= llvm::APInt::getSignedMinValue(bits
).getSExtValue();
629 return builder
.createIntegerConstant(loc
, elementType
, minInt
);
632 auto genBodyOp
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
633 mlir::Type elementType
, mlir::Value elem1
,
634 mlir::Value elem2
) -> mlir::Value
{
635 if (elementType
.isa
<mlir::FloatType
>())
636 return builder
.create
<mlir::arith::MaxFOp
>(loc
, elem1
, elem2
);
637 if (elementType
.isa
<mlir::IntegerType
>())
638 return builder
.create
<mlir::arith::MaxSIOp
>(loc
, elem1
, elem2
);
640 llvm_unreachable("unsupported type");
644 mlir::Location loc
= mlir::UnknownLoc::get(builder
.getContext());
645 builder
.setInsertionPointToEnd(funcOp
.addEntryBlock());
647 genReductionLoop
<fir::DoLoopOp
, bool, 0>(builder
, funcOp
, init
, nopLoopCond
,
648 false, genBodyOp
, rank
, elementType
,
652 static void genRuntimeCountBody(fir::FirOpBuilder
&builder
,
653 mlir::func::FuncOp
&funcOp
, unsigned rank
,
654 mlir::Type elementType
) {
655 auto zero
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
656 mlir::Type elementType
) {
657 unsigned bits
= elementType
.getIntOrFloatBitWidth();
658 int64_t zeroInt
= llvm::APInt::getZero(bits
).getSExtValue();
659 return builder
.createIntegerConstant(loc
, elementType
, zeroInt
);
662 auto genBodyOp
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
663 mlir::Type elementType
, mlir::Value elem1
,
664 mlir::Value elem2
) -> mlir::Value
{
665 auto zero32
= builder
.createIntegerConstant(loc
, elementType
, 0);
666 auto zero64
= builder
.createIntegerConstant(loc
, builder
.getI64Type(), 0);
667 auto one64
= builder
.createIntegerConstant(loc
, builder
.getI64Type(), 1);
669 auto compare
= builder
.create
<mlir::arith::CmpIOp
>(
670 loc
, mlir::arith::CmpIPredicate::eq
, elem1
, zero32
);
672 builder
.create
<mlir::arith::SelectOp
>(loc
, compare
, zero64
, one64
);
673 return builder
.create
<mlir::arith::AddIOp
>(loc
, select
, elem2
);
676 // Count always gets I32 for elementType as it converts logical input to
677 // logical<4> before passing to the function.
678 mlir::Location loc
= mlir::UnknownLoc::get(builder
.getContext());
679 builder
.setInsertionPointToEnd(funcOp
.addEntryBlock());
681 genReductionLoop
<fir::DoLoopOp
, bool, 0>(builder
, funcOp
, zero
, nopLoopCond
,
682 false, genBodyOp
, rank
, elementType
,
686 static void genRuntimeAnyBody(fir::FirOpBuilder
&builder
,
687 mlir::func::FuncOp
&funcOp
, unsigned rank
,
688 mlir::Type elementType
) {
689 auto zero
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
690 mlir::Type elementType
) {
691 return builder
.createIntegerConstant(loc
, elementType
, 0);
694 auto genBodyOp
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
695 mlir::Type elementType
, mlir::Value elem1
,
696 mlir::Value elem2
) -> mlir::Value
{
697 auto zero
= builder
.createIntegerConstant(loc
, elementType
, 0);
698 return builder
.create
<mlir::arith::CmpIOp
>(
699 loc
, mlir::arith::CmpIPredicate::ne
, elem1
, zero
);
702 auto continueCond
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
703 mlir::Value reductionVal
) {
704 auto one1
= builder
.createIntegerConstant(loc
, builder
.getI1Type(), 1);
705 auto eor
= builder
.create
<mlir::arith::XOrIOp
>(loc
, reductionVal
, one1
);
706 llvm::SmallVector
<mlir::Value
> results
= {eor
, reductionVal
};
710 mlir::Location loc
= mlir::UnknownLoc::get(builder
.getContext());
711 builder
.setInsertionPointToEnd(funcOp
.addEntryBlock());
712 mlir::Value ok
= builder
.createBool(loc
, true);
714 genReductionLoop
<fir::IterWhileOp
, mlir::Value
, 1>(
715 builder
, funcOp
, zero
, continueCond
, ok
, genBodyOp
, rank
, elementType
,
719 static void genRuntimeAllBody(fir::FirOpBuilder
&builder
,
720 mlir::func::FuncOp
&funcOp
, unsigned rank
,
721 mlir::Type elementType
) {
722 auto one
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
723 mlir::Type elementType
) {
724 return builder
.createIntegerConstant(loc
, elementType
, 1);
727 auto genBodyOp
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
728 mlir::Type elementType
, mlir::Value elem1
,
729 mlir::Value elem2
) -> mlir::Value
{
730 auto zero
= builder
.createIntegerConstant(loc
, elementType
, 0);
731 return builder
.create
<mlir::arith::CmpIOp
>(
732 loc
, mlir::arith::CmpIPredicate::ne
, elem1
, zero
);
735 auto continueCond
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
736 mlir::Value reductionVal
) {
737 llvm::SmallVector
<mlir::Value
> results
= {reductionVal
, reductionVal
};
741 mlir::Location loc
= mlir::UnknownLoc::get(builder
.getContext());
742 builder
.setInsertionPointToEnd(funcOp
.addEntryBlock());
743 mlir::Value ok
= builder
.createBool(loc
, true);
745 genReductionLoop
<fir::IterWhileOp
, mlir::Value
, 1>(
746 builder
, funcOp
, one
, continueCond
, ok
, genBodyOp
, rank
, elementType
,
750 static mlir::FunctionType
genRuntimeMinlocType(fir::FirOpBuilder
&builder
,
752 mlir::Type boxType
= fir::BoxType::get(builder
.getNoneType());
753 mlir::Type boxRefType
= builder
.getRefType(boxType
);
755 return mlir::FunctionType::get(builder
.getContext(),
756 {boxRefType
, boxType
, boxType
}, {});
// Emits the body of the simplified Minloc replacement into \p funcOp.
// From the statements visible here: an entry block is added, a result array
// of \p resultElemTy is allocated with fir::AllocMemOp, emboxed and
// zero-filled, a compare-and-record reduction body is defined, and
// genMinlocReductionLoop is invoked to drive the reduction.
// \p rank is used as the number of result entries, \p maskRank decides
// whether a mask is honoured, \p elementType is the input element type and
// \p maskElemType is forwarded to the reduction loop.
// NOTE(review): several lines of this definition (a lambda introducer,
// closing braces, and the guard around the scalar-mask path) are not visible
// in this listing, so the comments below describe only what is shown.
759 static void genRuntimeMinlocBody(fir::FirOpBuilder
&builder
,
760 mlir::func::FuncOp
&funcOp
, unsigned rank
,
761 int maskRank
, mlir::Type elementType
,
762 mlir::Type maskElemType
,
763 mlir::Type resultElemTy
) {
// Starting value of the reduction: the largest finite positive value for
// floating-point element types, otherwise the signed integer maximum for the
// element's bit width.
764 auto init
= [](fir::FirOpBuilder builder
, mlir::Location loc
,
765 mlir::Type elementType
) {
766 if (auto ty
= elementType
.dyn_cast
<mlir::FloatType
>()) {
767 const llvm::fltSemantics
&sem
= ty
.getFloatSemantics();
768 return builder
.createRealConstant(
769 loc
, elementType
, llvm::APFloat::getLargest(sem
, /*Negative=*/false));
771 unsigned bits
= elementType
.getIntOrFloatBitWidth();
772 int64_t maxInt
= llvm::APInt::getSignedMaxValue(bits
).getSExtValue();
773 return builder
.createIntegerConstant(loc
, elementType
, maxInt
);
776 mlir::Location loc
= mlir::UnknownLoc::get(builder
.getContext());
777 builder
.setInsertionPointToEnd(funcOp
.addEntryBlock());
// The mask is the third entry-block argument (index 2).
779 mlir::Value mask
= funcOp
.front().getArgument(2);
781 // Set up result array in case of early exit / 0 length array
782 mlir::IndexType idxTy
= builder
.getIndexType();
783 mlir::Type resultTy
= fir::SequenceType::get(rank
, resultElemTy
);
784 mlir::Type resultHeapTy
= fir::HeapType::get(resultTy
);
785 mlir::Type resultBoxTy
= fir::BoxType::get(resultHeapTy
);
787 mlir::Value returnValue
= builder
.createIntegerConstant(loc
, resultElemTy
, 0);
788 mlir::Value resultArrSize
= builder
.createIntegerConstant(loc
, idxTy
, rank
);
// Allocate the result storage and wrap it in a box whose shape has a single
// extent equal to rank.
790 mlir::Value resultArrInit
= builder
.create
<fir::AllocMemOp
>(loc
, resultTy
);
791 mlir::Value resultArrShape
= builder
.create
<fir::ShapeOp
>(loc
, resultArrSize
);
792 mlir::Value resultArr
= builder
.create
<fir::EmboxOp
>(
793 loc
, resultBoxTy
, resultArrInit
, resultArrShape
);
795 mlir::Type resultRefTy
= builder
.getRefType(resultElemTy
);
// Store the zero constant into each of the rank result entries.
797 for (unsigned int i
= 0; i
< rank
; ++i
) {
798 mlir::Value index
= builder
.createIntegerConstant(loc
, idxTy
, i
);
799 mlir::Value resultElemAddr
=
800 builder
.create
<fir::CoordinateOp
>(loc
, resultRefTy
, resultArr
, index
);
801 builder
.create
<fir::StoreOp
>(loc
, returnValue
, resultElemAddr
);
// Parameter list of the per-element reduction body.
// NOTE(review): the lambda introducer is not visible in this listing; the
// body below reads rank and resultArr from the enclosing function.
806 fir::FirOpBuilder builder
, mlir::Location loc
, mlir::Type elementType
,
807 mlir::Value elem1
, mlir::Value elem2
,
808 llvm::SmallVector
<mlir::Value
, Fortran::common::maxRank
> indices
)
// elem1 < elem2: ordered less-than for floats, signed less-than for integers.
811 if (elementType
.isa
<mlir::FloatType
>()) {
812 cmp
= builder
.create
<mlir::arith::CmpFOp
>(
813 loc
, mlir::arith::CmpFPredicate::OLT
, elem1
, elem2
);
814 } else if (elementType
.isa
<mlir::IntegerType
>()) {
815 cmp
= builder
.create
<mlir::arith::CmpIOp
>(
816 loc
, mlir::arith::CmpIPredicate::slt
, elem1
, elem2
);
818 llvm_unreachable("unsupported type");
821 fir::IfOp ifOp
= builder
.create
<fir::IfOp
>(loc
, elementType
, cmp
,
822 /*withElseRegion*/ true);
// Then-branch: record the current position. Each loop index is converted to
// the result element type and incremented by one before being stored, and
// elem1 is yielded as the if-result.
824 builder
.setInsertionPointToStart(&ifOp
.getThenRegion().front());
825 mlir::Type resultElemTy
= hlfir::getFortranElementType(resultArr
.getType());
826 mlir::Type returnRefTy
= builder
.getRefType(resultElemTy
);
827 mlir::IndexType idxTy
= builder
.getIndexType();
829 mlir::Value one
= builder
.createIntegerConstant(loc
, resultElemTy
, 1);
831 for (unsigned int i
= 0; i
< rank
; ++i
) {
832 mlir::Value index
= builder
.createIntegerConstant(loc
, idxTy
, i
);
833 mlir::Value resultElemAddr
=
834 builder
.create
<fir::CoordinateOp
>(loc
, returnRefTy
, resultArr
, index
);
835 mlir::Value convert
=
836 builder
.create
<fir::ConvertOp
>(loc
, resultElemTy
, indices
[i
]);
837 mlir::Value fortranIndex
=
838 builder
.create
<mlir::arith::AddIOp
>(loc
, convert
, one
);
839 builder
.create
<fir::StoreOp
>(loc
, fortranIndex
, resultElemAddr
);
841 builder
.create
<fir::ResultOp
>(loc
, elem1
);
// Else-branch: yield elem2 unchanged, then continue emitting after the if.
842 builder
.setInsertionPointToStart(&ifOp
.getElseRegion().front());
843 builder
.create
<fir::ResultOp
>(loc
, elem2
);
844 builder
.setInsertionPointAfter(ifOp
);
845 return ifOp
.getResult(0);
848 // if mask is a logical scalar, we can check its value before the main loop
849 // and either ignore the fact it is there or exit early.
// NOTE(review): the condition that selects this scalar-mask path is not
// visible in this listing.
851 mlir::Type logical
= builder
.getI1Type();
852 mlir::IndexType idxTy
= builder
.getIndexType();
// View the mask box as a box of a one-element i1 array and load element 0 as
// the condition.
854 fir::SequenceType::Shape
singleElement(1, 1);
855 mlir::Type arrTy
= fir::SequenceType::get(singleElement
, logical
);
856 mlir::Type boxArrTy
= fir::BoxType::get(arrTy
);
857 mlir::Value array
= builder
.create
<fir::ConvertOp
>(loc
, boxArrTy
, mask
);
859 mlir::Value indx
= builder
.createIntegerConstant(loc
, idxTy
, 0);
860 mlir::Type logicalRefTy
= builder
.getRefType(logical
);
861 mlir::Value condAddr
=
862 builder
.create
<fir::CoordinateOp
>(loc
, logicalRefTy
, array
, indx
);
863 mlir::Value cond
= builder
.create
<fir::LoadOp
>(loc
, condAddr
);
865 fir::IfOp ifOp
= builder
.create
<fir::IfOp
>(loc
, elementType
, cond
,
866 /*withElseRegion=*/true);
// Else-branch: yield a zero of the element type.
868 builder
.setInsertionPointToStart(&ifOp
.getElseRegion().front());
869 mlir::Value basicValue
;
870 if (elementType
.isa
<mlir::IntegerType
>()) {
871 basicValue
= builder
.createIntegerConstant(loc
, elementType
, 0);
873 basicValue
= builder
.createRealConstant(loc
, elementType
, 0);
875 builder
.create
<fir::ResultOp
>(loc
, basicValue
);
// Subsequent operations are emitted into the then-branch.
877 builder
.setInsertionPointToStart(&ifOp
.getThenRegion().front());
880 // bit of a hack - maskRank is set to -1 for absent mask arg, so don't
881 // generate high level mask or element by element mask.
882 bool hasMask
= maskRank
> 0;
884 genMinlocReductionLoop(builder
, funcOp
, init
, genBodyOp
, rank
, elementType
,
885 loc
, hasMask
, maskElemType
, resultArr
);
888 /// Generate function type for the simplified version of RTNAME(DotProduct)
889 /// operating on the given \p elementType.
// Both inputs are typed as boxes of none.
// NOTE(review): the end of the return statement (the result type list) is
// not visible in this listing.
890 static mlir::FunctionType
genRuntimeDotType(fir::FirOpBuilder
&builder
,
891 const mlir::Type
&elementType
) {
892 mlir::Type boxType
= fir::BoxType::get(builder
.getNoneType());
893 return mlir::FunctionType::get(builder
.getContext(), {boxType
, boxType
},
897 /// Generate function body of the simplified version of RTNAME(DotProduct)
898 /// with signature provided by \p funcOp. The caller is responsible
899 /// for saving/restoring the original insertion point of \p builder.
900 /// \p funcOp is expected to be empty on entry to this function.
901 /// \p arg1ElementTy and \p arg2ElementTy specify elements types
902 /// of the underlying array objects - they are used to generate proper
903 /// element accesses.
904 static void genRuntimeDotBody(fir::FirOpBuilder
&builder
,
905 mlir::func::FuncOp
&funcOp
,
906 mlir::Type arg1ElementTy
,
907 mlir::Type arg2ElementTy
) {
908 // function RTNAME(DotProduct)<T>_simplified(arr1, arr2)
909 // T, dimension(:) :: arr1, arr2
912 // do iter = 0, extent(arr1)
913 // product = product + arr1[iter] * arr2[iter]
915 // RTNAME(DotProduct)<T>_simplified = product
916 // end function RTNAME(DotProduct)<T>_simplified
917 auto loc
= mlir::UnknownLoc::get(builder
.getContext());
// The accumulation type is the function's first (only used) result type.
918 mlir::Type resultElementType
= funcOp
.getResultTypes()[0];
919 builder
.setInsertionPointToEnd(funcOp
.addEntryBlock());
921 mlir::IndexType idxTy
= builder
.getIndexType();
// Zero constant of the result type (real or integer), used below as the
// loop's initial iteration argument.
924 resultElementType
.isa
<mlir::FloatType
>()
925 ? builder
.createRealConstant(loc
, resultElementType
, 0.0)
926 : builder
.createIntegerConstant(loc
, resultElementType
, 0);
928 mlir::Block::BlockArgListType args
= funcOp
.front().getArguments();
929 mlir::Value arg1
= args
[0];
930 mlir::Value arg2
= args
[1];
932 mlir::Value zeroIdx
= builder
.createIntegerConstant(loc
, idxTy
, 0);
// Reinterpret both incoming boxes as boxes of one-dimensional arrays with
// unknown extent of the respective element types.
934 fir::SequenceType::Shape flatShape
= {fir::SequenceType::getUnknownExtent()};
935 mlir::Type arrTy1
= fir::SequenceType::get(flatShape
, arg1ElementTy
);
936 mlir::Type boxArrTy1
= fir::BoxType::get(arrTy1
);
937 mlir::Value array1
= builder
.create
<fir::ConvertOp
>(loc
, boxArrTy1
, arg1
);
938 mlir::Type arrTy2
= fir::SequenceType::get(flatShape
, arg2ElementTy
);
939 mlir::Type boxArrTy2
= fir::BoxType::get(arrTy2
);
940 mlir::Value array2
= builder
.create
<fir::ConvertOp
>(loc
, boxArrTy2
, arg2
);
941 // This version takes the loop trip count from the first argument.
942 // If the first argument's box has unknown (at compilation time)
943 // extent, then it may be better to take the extent from the second
944 // argument - so that after inlining the loop may be better optimized, e.g.
945 // fully unrolled. This requires generating two versions of the simplified
946 // function and some analysis at the call site to choose which version
947 // is more profitable to call.
948 // Note that we can assume that both arguments have the same extent.
950 builder
.create
<fir::BoxDimsOp
>(loc
, idxTy
, idxTy
, idxTy
, array1
, zeroIdx
);
// Result 1 of the box-dims query is used as the length.
951 mlir::Value len
= dims
.getResult(1);
952 mlir::Value one
= builder
.createIntegerConstant(loc
, idxTy
, 1);
953 mlir::Value step
= one
;
955 // We use C indexing here, so len-1 as loopcount
956 mlir::Value loopCount
= builder
.create
<mlir::arith::SubIOp
>(loc
, len
, one
);
957 auto loop
= builder
.create
<fir::DoLoopOp
>(loc
, zeroIdx
, loopCount
, step
,
959 /*finalCountValue=*/false, zero
);
// Running sum carried through the loop as its iteration argument.
960 mlir::Value sumVal
= loop
.getRegionIterArgs()[0];
// Remember the position after the loop, then emit into the loop body.
963 mlir::OpBuilder::InsertPoint loopEndPt
= builder
.saveInsertionPoint();
964 builder
.setInsertionPointToStart(loop
.getBody());
966 mlir::Type eleRef1Ty
= builder
.getRefType(arg1ElementTy
);
967 mlir::Value index
= loop
.getInductionVar();
969 builder
.create
<fir::CoordinateOp
>(loc
, eleRef1Ty
, array1
, index
);
970 mlir::Value elem1
= builder
.create
<fir::LoadOp
>(loc
, addr1
);
971 // Convert to the result type.
972 elem1
= builder
.create
<fir::ConvertOp
>(loc
, resultElementType
, elem1
);
974 mlir::Type eleRef2Ty
= builder
.getRefType(arg2ElementTy
);
976 builder
.create
<fir::CoordinateOp
>(loc
, eleRef2Ty
, array2
, index
);
977 mlir::Value elem2
= builder
.create
<fir::LoadOp
>(loc
, addr2
);
978 // Convert to the result type.
979 elem2
= builder
.create
<fir::ConvertOp
>(loc
, resultElementType
, elem2
);
// sumVal = elem1 * elem2 + sumVal, using float or integer arithmetic to
// match the result type.
981 if (resultElementType
.isa
<mlir::FloatType
>())
982 sumVal
= builder
.create
<mlir::arith::AddFOp
>(
983 loc
, builder
.create
<mlir::arith::MulFOp
>(loc
, elem1
, elem2
), sumVal
);
984 else if (resultElementType
.isa
<mlir::IntegerType
>())
985 sumVal
= builder
.create
<mlir::arith::AddIOp
>(
986 loc
, builder
.create
<mlir::arith::MulIOp
>(loc
, elem1
, elem2
), sumVal
);
988 llvm_unreachable("unsupported type");
// Yield the updated sum from the loop body.
990 builder
.create
<fir::ResultOp
>(loc
, sumVal
);
992 builder
.restoreInsertionPoint(loopEndPt
);
// The loop's first result is the function's return value.
994 mlir::Value resultVal
= loop
.getResult(0);
995 builder
.create
<mlir::func::ReturnOp
>(loc
, resultVal
);
// Looks up, or builds, the function named "<baseName>_simplified" in the
// builder's module.
// \p typeGenerator supplies the expected function type; when a function with
// that name is found its type is asserted to match. Otherwise a new function
// is created, tagged with LinkonceODR linkage through the "llvm.linkage"
// attribute, and \p bodyGenerator fills in its body. The builder's insertion
// point is saved before and restored after the body is generated.
// NOTE(review): the branch that returns an already existing function and
// the final return are not visible in this listing.
998 mlir::func::FuncOp
SimplifyIntrinsicsPass::getOrCreateFunction(
999 fir::FirOpBuilder
&builder
, const mlir::StringRef
&baseName
,
1000 FunctionTypeGeneratorTy typeGenerator
,
1001 FunctionBodyGeneratorTy bodyGenerator
) {
1002 // WARNING: if the function generated here changes its signature
1003 // or behavior (the body code), we should probably embed some
1004 // versioning information into its name, otherwise libraries
1005 // statically linked with older versions of Flang may stop
1006 // working with object files created with newer Flang.
1007 // We can also avoid this by using internal linkage, but
1008 // this may increase the size of final executable/shared library.
1009 std::string replacementName
= mlir::Twine
{baseName
, "_simplified"}.str();
1010 mlir::ModuleOp module
= builder
.getModule();
1011 // If we already have a function, just return it.
1012 mlir::func::FuncOp newFunc
=
1013 fir::FirOpBuilder::getNamedFunction(module
, replacementName
);
1014 mlir::FunctionType fType
= typeGenerator(builder
);
1016 assert(newFunc
.getFunctionType() == fType
&&
1017 "type mismatch for simplified function");
1021 // Need to build the function!
1022 auto loc
= mlir::UnknownLoc::get(builder
.getContext());
1024 fir::FirOpBuilder::createFunction(loc
, module
, replacementName
, fType
);
1025 auto inlineLinkage
= mlir::LLVM::linkage::Linkage::LinkonceODR
;
1027 mlir::LLVM::LinkageAttr::get(builder
.getContext(), inlineLinkage
);
1028 newFunc
->setAttr("llvm.linkage", linkage
);
1030 // Save the position of the original call.
1031 mlir::OpBuilder::InsertPoint insertPt
= builder
.saveInsertionPoint();
1033 bodyGenerator(builder
, newFunc
);
1035 // Now back to where we were adding code earlier...
1036 builder
.restoreInsertionPoint(insertPt
);
// Handles a runtime reduction call whose result is an integer or a
// floating-point value.
// The visible checks require that dim (args[3]) is zero, the mask (args[4])
// is absent, the rank of args[0] is non-zero, and the result type is a float
// or an integer. The replacement function name is built from the callee
// name, an "x" marker and, when non-empty, the FastMathFlags string, and the
// work is then handed to simplifyReductionBody.
// NOTE(review): the statements guarded by the checks below are not visible
// in this listing; per the existing comments the call is not simplified in
// those cases.
1041 void SimplifyIntrinsicsPass::simplifyIntOrFloatReduction(
1042 fir::CallOp call
, const fir::KindMapping
&kindMap
,
1043 GenReductionBodyTy genBodyFunc
) {
1044 // args[1] and args[2] are source filename and line number, ignored.
1045 mlir::Operation::operand_range args
= call
.getArgs();
1047 const mlir::Value
&dim
= args
[3];
1048 const mlir::Value
&mask
= args
[4];
1049 // dim is zero when it is absent, which is an implementation
1050 // detail in the runtime library.
1052 bool dimAndMaskAbsent
= isZero(dim
) && isOperandAbsent(mask
);
1053 unsigned rank
= getDimCount(args
[0]);
1055 // Rank is set to 0 for assumed shape arrays, don't simplify
1057 if (!(dimAndMaskAbsent
&& rank
> 0))
1060 mlir::Type resultType
= call
.getResult(0).getType();
1062 if (!resultType
.isa
<mlir::FloatType
>() &&
1063 !resultType
.isa
<mlir::IntegerType
>())
1066 auto argType
= getArgElementType(args
[0]);
// The argument's element type is expected to equal the call's result type.
1069 assert(*argType
== resultType
&&
1070 "Argument/result types mismatch in reduction");
1072 mlir::SymbolRefAttr callee
= call
.getCalleeAttr();
1074 fir::FirOpBuilder builder
{getSimplificationBuilder(call
, kindMap
)};
1075 std::string fmfString
{getFastMathFlagsString(builder
)};
1076 std::string funcName
=
1077 (mlir::Twine
{callee
.getLeafReference().getValue(), "x"} +
1079 // We must mangle the generated function name with FastMathFlags
1081 (fmfString
.empty() ? mlir::Twine
{} : mlir::Twine
{"_", fmfString
}))
1084 simplifyReductionBody(call
, kindMap
, genBodyFunc
, builder
, funcName
,
// Handles a runtime reduction over a logical array for callees that pass
// zero for an absent dim argument.
// The visible check requires dim (args[3]) to be zero and the rank of
// args[0] to be non-zero. The logical element type is mapped to an integer
// type of kind * 8 bits, and the replacement name is the callee name followed
// by "Logical<kind>x<rank>".
// NOTE(review): elementType is cast to fir::LogicalType without a visible
// null check - confirm callers only reach this with logical arrays.
1088 void SimplifyIntrinsicsPass::simplifyLogicalDim0Reduction(
1089 fir::CallOp call
, const fir::KindMapping
&kindMap
,
1090 GenReductionBodyTy genBodyFunc
) {
1092 mlir::Operation::operand_range args
= call
.getArgs();
1093 const mlir::Value
&dim
= args
[3];
1094 unsigned rank
= getDimCount(args
[0]);
1096 // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
// these cases.
1098 if (!(isZero(dim
) && rank
> 0))
1101 mlir::Value inputBox
= findBoxDef(args
[0]);
1103 mlir::Type elementType
= hlfir::getFortranElementType(inputBox
.getType());
1104 mlir::SymbolRefAttr callee
= call
.getCalleeAttr();
1106 fir::FirOpBuilder builder
{getSimplificationBuilder(call
, kindMap
)};
1108 // Treating logicals as integers makes things a lot easier
1109 fir::LogicalType logicalType
= {elementType
.dyn_cast
<fir::LogicalType
>()};
1110 fir::KindTy kind
= logicalType
.getFKind();
1111 mlir::Type intElementType
= builder
.getIntegerType(kind
* 8);
1113 // Mangle kind into function name as it is not done by default
1114 std::string funcName
=
1115 (mlir::Twine
{callee
.getLeafReference().getValue(), "Logical"} +
1116 mlir::Twine
{kind
} + "x" + mlir::Twine
{rank
})
1119 simplifyReductionBody(call
, kindMap
, genBodyFunc
, builder
, funcName
,
// Handles a runtime reduction over a logical array for callees that set dim
// to 1 when it is absent.
// Instead of inspecting the dim operand, the visible check looks at the
// callee name: names ending in "Dim", or a rank of zero for args[0], are
// excluded. The logical element type is mapped to an integer type of
// kind * 8 bits, and the replacement name is the callee name followed by
// "Logical<kind>x<rank>".
// NOTE(review): elementType is cast to fir::LogicalType without a visible
// null check - confirm callers only reach this with logical arrays.
1123 void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction(
1124 fir::CallOp call
, const fir::KindMapping
&kindMap
,
1125 GenReductionBodyTy genBodyFunc
) {
1127 mlir::Operation::operand_range args
= call
.getArgs();
1128 mlir::SymbolRefAttr callee
= call
.getCalleeAttr();
1129 mlir::StringRef funcNameBase
= callee
.getLeafReference().getValue();
1130 unsigned rank
= getDimCount(args
[0]);
1132 // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
1133 // these cases. We check for Dim at the end as some logical functions (Any,
1134 // All) set dim to 1 instead of 0 when the argument is not present.
1135 if (funcNameBase
.ends_with("Dim") || !(rank
> 0))
1138 mlir::Value inputBox
= findBoxDef(args
[0]);
1139 mlir::Type elementType
= hlfir::getFortranElementType(inputBox
.getType());
1141 fir::FirOpBuilder builder
{getSimplificationBuilder(call
, kindMap
)};
1143 // Treating logicals as integers makes things a lot easier
1144 fir::LogicalType logicalType
= {elementType
.dyn_cast
<fir::LogicalType
>()};
1145 fir::KindTy kind
= logicalType
.getFKind();
1146 mlir::Type intElementType
= builder
.getIntegerType(kind
* 8);
1148 // Mangle kind into function name as it is not done by default
1149 std::string funcName
=
1150 (mlir::Twine
{callee
.getLeafReference().getValue(), "Logical"} +
1151 mlir::Twine
{kind
} + "x" + mlir::Twine
{rank
})
1154 simplifyReductionBody(call
, kindMap
, genBodyFunc
, builder
, funcName
,
// Handles a runtime Minloc call.
// Operands read here: args[0] is the output, args[1] the input array,
// args[5] the mask and args[6] the back flag. The visible checks look at
// whether back is true or not a constant, whether an accepted mask
// definition was found, whether the callee name ends in "Dim", whether the
// input rank is zero, and whether the input element type is a character
// type. The replacement name is assembled from the callee name, an "x"
// marker, mask kind/rank information, the printed output element type and
// the FastMathFlags string. The new call receives args[0], args[1] and
// args[5].
// NOTE(review): the statements guarded by the checks (and some pieces of
// the name mangling expression) are not visible in this listing.
1158 void SimplifyIntrinsicsPass::simplifyMinlocReduction(
1159 fir::CallOp call
, const fir::KindMapping
&kindMap
) {
1161 mlir::Operation::operand_range args
= call
.getArgs();
1163 mlir::Value back
= args
[6];
1164 if (isTrueOrNotConstant(back
))
1167 mlir::Value mask
= args
[5];
1168 mlir::Value maskDef
= findMaskDef(mask
);
1170 // maskDef is set to NULL when the defining op is not one we accept.
1171 // This tends to be because it is a selectOp, in which case let the
1172 // runtime deal with it.
// NOTE(review): nullptr (or `!maskDef`) would be the idiomatic spelling.
1173 if (maskDef
== NULL
)
1176 mlir::SymbolRefAttr callee
= call
.getCalleeAttr();
1177 mlir::StringRef funcNameBase
= callee
.getLeafReference().getValue();
1178 unsigned rank
= getDimCount(args
[1]);
1179 if (funcNameBase
.ends_with("Dim") || !(rank
> 0))
1182 fir::FirOpBuilder builder
{getSimplificationBuilder(call
, kindMap
)};
1183 mlir::Location loc
= call
.getLoc();
1184 auto inputBox
= findBoxDef(args
[1]);
1185 mlir::Type inputType
= hlfir::getFortranElementType(inputBox
.getType());
1187 if (inputType
.isa
<fir::CharacterType
>())
// Defaults used when no mask is present: kind 0 and an i1 mask element type.
1191 fir::KindTy kind
= 0;
1192 mlir::Type logicalElemType
= builder
.getI1Type();
1193 if (isOperandAbsent(mask
)) {
// With a mask present, take its rank, logical kind and element type.
1196 maskRank
= getDimCount(mask
);
1197 mlir::Type maskElemTy
= hlfir::getFortranElementType(maskDef
.getType());
1198 fir::LogicalType logicalFirType
= {maskElemTy
.dyn_cast
<fir::LogicalType
>()};
1199 kind
= logicalFirType
.getFKind();
1200 // Convert fir::LogicalType to mlir::Type
1201 logicalElemType
= logicalFirType
;
// The output element type is taken from operand 0 of the op defining args[0].
// NOTE(review): getDefiningOp() is dereferenced without a visible null
// check - confirm args[0] is never a block argument here.
1204 mlir::Operation
*outputDef
= args
[0].getDefiningOp();
1205 mlir::Value outputAlloc
= outputDef
->getOperand(0);
1206 mlir::Type outType
= hlfir::getFortranElementType(outputAlloc
.getType());
1208 std::string fmfString
{getFastMathFlagsString(builder
)};
1209 std::string funcName
=
1210 (mlir::Twine
{callee
.getLeafReference().getValue(), "x"} +
1213 ? "_Logical" + mlir::Twine
{kind
} + "x" + mlir::Twine
{maskRank
}
// Append the printed output element type and the FastMathFlags string.
1218 llvm::raw_string_ostream
nameOS(funcName
);
1219 outType
.print(nameOS
);
1220 nameOS
<< '_' << fmfString
;
1222 auto typeGenerator
= [rank
](fir::FirOpBuilder
&builder
) {
1223 return genRuntimeMinlocType(builder
, rank
);
1225 auto bodyGenerator
= [rank
, maskRank
, inputType
, logicalElemType
,
1226 outType
](fir::FirOpBuilder
&builder
,
1227 mlir::func::FuncOp
&funcOp
) {
1228 genRuntimeMinlocBody(builder
, funcOp
, rank
, maskRank
, inputType
,
1229 logicalElemType
, outType
);
1232 mlir::func::FuncOp newFunc
=
1233 getOrCreateFunction(builder
, funcName
, typeGenerator
, bodyGenerator
);
1234 builder
.create
<fir::CallOp
>(loc
, newFunc
,
1235 mlir::ValueRange
{args
[0], args
[1], args
[5]});
1236 call
->dropAllReferences();
// Shared tail of the reduction simplifications: obtains (or creates) the
// function named from \p funcName, emits a call to it that passes only
// args[0] of the original call, redirects all uses of the original call's
// results to the new call, and drops the original call's references.
// The function type comes from genNoneBoxType with the call's result type;
// the body comes from \p genBodyFunc, which receives the rank of args[0] and
// \p elementType.
1240 void SimplifyIntrinsicsPass::simplifyReductionBody(
1241 fir::CallOp call
, const fir::KindMapping
&kindMap
,
1242 GenReductionBodyTy genBodyFunc
, fir::FirOpBuilder
&builder
,
1243 const mlir::StringRef
&funcName
, mlir::Type elementType
) {
1245 mlir::Operation::operand_range args
= call
.getArgs();
1247 mlir::Type resultType
= call
.getResult(0).getType();
1248 unsigned rank
= getDimCount(args
[0]);
1250 mlir::Location loc
= call
.getLoc();
// Both generators capture locals by reference; they are used by the
// getOrCreateFunction call below.
1252 auto typeGenerator
= [&resultType
](fir::FirOpBuilder
&builder
) {
1253 return genNoneBoxType(builder
, resultType
);
1255 auto bodyGenerator
= [&rank
, &genBodyFunc
,
1256 &elementType
](fir::FirOpBuilder
&builder
,
1257 mlir::func::FuncOp
&funcOp
) {
1258 genBodyFunc(builder
, funcOp
, rank
, elementType
);
1260 // Mangle the function name with the rank value as "x<rank>".
1261 mlir::func::FuncOp newFunc
=
1262 getOrCreateFunction(builder
, funcName
, typeGenerator
, bodyGenerator
);
1264 builder
.create
<fir::CallOp
>(loc
, newFunc
, mlir::ValueRange
{args
[0]});
1265 call
->replaceAllUsesWith(newCall
.getResults());
1266 call
->dropAllReferences();
// Pass entry point: walks every operation in the module and, for each
// fir::CallOp with a callee symbol, dispatches on the callee name to the
// matching simplification (Sum, DotProduct, Maxval, Count, Any, All, Minloc).
// DotProduct is handled inline here; the others delegate to helper methods.
// NOTE(review): the statements that end each dispatch branch are not
// visible in this listing.
1270 void SimplifyIntrinsicsPass::runOnOperation() {
1271 LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE
" ===\n");
1272 mlir::ModuleOp module
= getOperation();
1273 fir::KindMapping kindMap
= fir::getKindMapping(module
);
1274 module
.walk([&](mlir::Operation
*op
) {
1275 if (auto call
= mlir::dyn_cast
<fir::CallOp
>(op
)) {
1276 if (mlir::SymbolRefAttr callee
= call
.getCalleeAttr()) {
1277 mlir::StringRef funcName
= callee
.getLeafReference().getValue();
1278 // Replace call to runtime function for SUM when it has single
1279 // argument (no dim or mask argument) for 1D arrays with either
1280 // Integer4 or Real8 types. Other forms are ignored.
1281 // The new function is added to the module.
1283 // Prototype for runtime call (from sum.cpp):
1284 // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
1285 // int dim, const Descriptor *mask)
1287 if (funcName
.startswith(RTNAME_STRING(Sum
))) {
1288 simplifyIntOrFloatReduction(call
, kindMap
, genRuntimeSumBody
);
1291 if (funcName
.startswith(RTNAME_STRING(DotProduct
))) {
1292 LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName
<< "\n");
1293 LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op
->dump();
1294 llvm::dbgs() << "\n");
1295 mlir::Operation::operand_range args
= call
.getArgs();
1296 const mlir::Value
&v1
= args
[0];
1297 const mlir::Value
&v2
= args
[1];
1298 mlir::Location loc
= call
.getLoc();
1299 fir::FirOpBuilder builder
{getSimplificationBuilder(op
, kindMap
)};
1300 // Stringize the builder's FastMathFlags flags for mangling
1301 // the generated function name.
1302 std::string fmfString
{getFastMathFlagsString(builder
)};
// Only float or integer results are considered.
1304 mlir::Type type
= call
.getResult(0).getType();
1305 if (!type
.isa
<mlir::FloatType
>() && !type
.isa
<mlir::IntegerType
>())
1308 // Try to find the element types of the boxed arguments.
1309 auto arg1Type
= getArgElementType(v1
);
1310 auto arg2Type
= getArgElementType(v2
);
1312 if (!arg1Type
|| !arg2Type
)
1315 // Support only floating point and integer arguments
1316 // now (e.g. logical is skipped here).
1317 if (!arg1Type
->isa
<mlir::FloatType
>() &&
1318 !arg1Type
->isa
<mlir::IntegerType
>())
1320 if (!arg2Type
->isa
<mlir::FloatType
>() &&
1321 !arg2Type
->isa
<mlir::IntegerType
>())
1324 auto typeGenerator
= [&type
](fir::FirOpBuilder
&builder
) {
1325 return genRuntimeDotType(builder
, type
);
1327 auto bodyGenerator
= [&arg1Type
,
1328 &arg2Type
](fir::FirOpBuilder
&builder
,
1329 mlir::func::FuncOp
&funcOp
) {
1330 genRuntimeDotBody(builder
, funcOp
, *arg1Type
, *arg2Type
);
1333 // Suffix the function name with the element types
1334 // of the arguments.
1335 std::string
typedFuncName(funcName
);
1336 llvm::raw_string_ostream
nameOS(typedFuncName
);
1337 // We must mangle the generated function name with FastMathFlags
1339 if (!fmfString
.empty())
1340 nameOS
<< '_' << fmfString
;
1342 arg1Type
->print(nameOS
);
1344 arg2Type
->print(nameOS
);
// Emit the replacement call with just the two array arguments and
// redirect all uses of the original call's results to it.
1346 mlir::func::FuncOp newFunc
= getOrCreateFunction(
1347 builder
, typedFuncName
, typeGenerator
, bodyGenerator
);
1348 auto newCall
= builder
.create
<fir::CallOp
>(loc
, newFunc
,
1349 mlir::ValueRange
{v1
, v2
});
1350 call
->replaceAllUsesWith(newCall
.getResults());
1351 call
->dropAllReferences();
1354 LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall
.dump();
1355 llvm::dbgs() << "\n");
1358 if (funcName
.startswith(RTNAME_STRING(Maxval
))) {
1359 simplifyIntOrFloatReduction(call
, kindMap
, genRuntimeMaxvalBody
);
1362 if (funcName
.startswith(RTNAME_STRING(Count
))) {
1363 simplifyLogicalDim0Reduction(call
, kindMap
, genRuntimeCountBody
);
1366 if (funcName
.startswith(RTNAME_STRING(Any
))) {
1367 simplifyLogicalDim1Reduction(call
, kindMap
, genRuntimeAnyBody
);
// NOTE(review): this dispatch uses endswith while its siblings use
// startswith - confirm that is intended.
1370 if (funcName
.endswith(RTNAME_STRING(All
))) {
1371 simplifyLogicalDim1Reduction(call
, kindMap
, genRuntimeAllBody
);
1374 if (funcName
.startswith(RTNAME_STRING(Minloc
))) {
1375 simplifyMinlocReduction(call
, kindMap
);
1381 LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE
" ===\n");
1384 void SimplifyIntrinsicsPass::getDependentDialects(
1385 mlir::DialectRegistry
®istry
) const {
1386 // LLVM::LinkageAttr creation requires that LLVM dialect is loaded.
1387 registry
.insert
<mlir::LLVM::LLVMDialect
>();
// Factory for the pass: returns a newly allocated SimplifyIntrinsicsPass
// owned by a std::unique_ptr<mlir::Pass>.
1389 std::unique_ptr
<mlir::Pass
> fir::createSimplifyIntrinsicsPass() {
1390 return std::make_unique
<SimplifyIntrinsicsPass
>();