[lldb] Add ability to hide the root name of a value
[llvm-project.git] / flang / lib / Optimizer / Transforms / SimplifyIntrinsics.cpp
blob244a462cb72252e784f7e6dbbde6dada67a39f47
1 //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 //===----------------------------------------------------------------------===//
10 /// \file
11 /// This pass looks for suitable calls to runtime library for intrinsics that
12 /// can be simplified/specialized and replaces with a specialized function.
13 ///
14 /// For example, SUM(arr) can be specialized as a simple function with one loop,
15 /// compared to the three arguments (plus file & line info) that the runtime
16 /// call has - when the argument is a 1D-array (multiple loops may be needed
17 /// for higher-dimension arrays, of course).
18 ///
19 /// The general idea is that besides making the call simpler, it can also be
20 /// inlined by other passes that run after this pass, which further improves
21 /// performance, particularly when the work done in the function is trivial
22 /// and small in size.
23 //===----------------------------------------------------------------------===//
25 #include "flang/Common/Fortran.h"
26 #include "flang/Optimizer/Builder/BoxValue.h"
27 #include "flang/Optimizer/Builder/FIRBuilder.h"
28 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
29 #include "flang/Optimizer/Builder/Todo.h"
30 #include "flang/Optimizer/Dialect/FIROps.h"
31 #include "flang/Optimizer/Dialect/FIRType.h"
32 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
33 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
34 #include "flang/Optimizer/Transforms/Passes.h"
35 #include "flang/Runtime/entry-names.h"
36 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
37 #include "mlir/IR/Matchers.h"
38 #include "mlir/IR/Operation.h"
39 #include "mlir/Pass/Pass.h"
40 #include "mlir/Transforms/DialectConversion.h"
41 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
42 #include "mlir/Transforms/RegionUtils.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/Support/raw_ostream.h"
45 #include <llvm/CodeGen/SelectionDAGNodes.h>
46 #include <llvm/Support/ErrorHandling.h>
47 #include <mlir/Dialect/Arith/IR/Arith.h>
48 #include <mlir/IR/BuiltinTypes.h>
49 #include <mlir/IR/Location.h>
50 #include <mlir/IR/MLIRContext.h>
51 #include <mlir/IR/Value.h>
52 #include <mlir/Support/LLVM.h>
53 #include <optional>
55 namespace fir {
56 #define GEN_PASS_DEF_SIMPLIFYINTRINSICS
57 #include "flang/Optimizer/Transforms/Passes.h.inc"
58 } // namespace fir
60 #define DEBUG_TYPE "flang-simplify-intrinsics"
62 namespace {
64 class SimplifyIntrinsicsPass
65 : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
66 using FunctionTypeGeneratorTy =
67 llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
68 using FunctionBodyGeneratorTy =
69 llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
70 using GenReductionBodyTy = llvm::function_ref<void(
71 fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank,
72 mlir::Type elementType)>;
74 public:
75 /// Generate a new function implementing a simplified version
76 /// of a Fortran runtime function defined by \p basename name.
77 /// \p typeGenerator is a callback that generates the new function's type.
78 /// \p bodyGenerator is a callback that generates the new function's body.
79 /// The new function is created in the \p builder's Module.
80 mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
81 const mlir::StringRef &basename,
82 FunctionTypeGeneratorTy typeGenerator,
83 FunctionBodyGeneratorTy bodyGenerator);
84 void runOnOperation() override;
85 void getDependentDialects(mlir::DialectRegistry &registry) const override;
87 private:
88 /// Helper functions to replace a reduction type of call with its
89 /// simplified form. The actual function is generated using a callback
90 /// function.
91 /// \p call is the call to be replaced
92 /// \p kindMap is used to create FIROpBuilder
93 /// \p genBodyFunc is the callback that builds the replacement function
94 void simplifyIntOrFloatReduction(fir::CallOp call,
95 const fir::KindMapping &kindMap,
96 GenReductionBodyTy genBodyFunc);
97 void simplifyLogicalDim0Reduction(fir::CallOp call,
98 const fir::KindMapping &kindMap,
99 GenReductionBodyTy genBodyFunc);
100 void simplifyLogicalDim1Reduction(fir::CallOp call,
101 const fir::KindMapping &kindMap,
102 GenReductionBodyTy genBodyFunc);
103 void simplifyMinlocReduction(fir::CallOp call,
104 const fir::KindMapping &kindMap);
105 void simplifyReductionBody(fir::CallOp call, const fir::KindMapping &kindMap,
106 GenReductionBodyTy genBodyFunc,
107 fir::FirOpBuilder &builder,
108 const mlir::StringRef &basename,
109 mlir::Type elementType);
112 } // namespace
114 /// Create FirOpBuilder with the provided \p op insertion point
115 /// and \p kindMap additionally inheriting FastMathFlags from \p op.
116 static fir::FirOpBuilder
117 getSimplificationBuilder(mlir::Operation *op, const fir::KindMapping &kindMap) {
118 fir::FirOpBuilder builder{op, kindMap};
119 auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
120 if (!fmi)
121 return builder;
123 // Regardless of what default FastMathFlags are used by FirOpBuilder,
124 // override them with FastMathFlags attached to the operation.
125 builder.setFastMathFlags(fmi.getFastMathFlagsAttr().getValue());
126 return builder;
129 /// Stringify FastMathFlags set for the given \p builder in a way
130 /// that the string may be used for mangling a function name.
131 /// If FastMathFlags are set to 'none', then the result is an empty
132 /// string.
133 static std::string getFastMathFlagsString(const fir::FirOpBuilder &builder) {
134 mlir::arith::FastMathFlags flags = builder.getFastMathFlags();
135 if (flags == mlir::arith::FastMathFlags::none)
136 return {};
138 std::string fmfString{mlir::arith::stringifyFastMathFlags(flags)};
139 std::replace(fmfString.begin(), fmfString.end(), ',', '_');
140 return fmfString;
143 /// Generate function type for the simplified version of RTNAME(Sum) and
144 /// similar functions with a fir.box<none> type returning \p elementType.
145 static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
146 const mlir::Type &elementType) {
147 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
148 return mlir::FunctionType::get(builder.getContext(), {boxType},
149 {elementType});
152 template <typename Op>
153 Op expectOp(mlir::Value val) {
154 if (Op op = mlir::dyn_cast_or_null<Op>(val.getDefiningOp()))
155 return op;
156 LLVM_DEBUG(llvm::dbgs() << "Didn't find expected " << Op::getOperationName()
157 << '\n');
158 return nullptr;
161 template <typename Op>
162 static mlir::Value findDefSingle(fir::ConvertOp op) {
163 if (auto defOp = expectOp<Op>(op->getOperand(0))) {
164 return defOp.getResult();
166 return {};
169 template <typename... Ops>
170 static mlir::Value findDef(fir::ConvertOp op) {
171 mlir::Value defOp;
172 // Loop over the operation types given to see if any match, exiting once
173 // a match is found. Cast to void is needed to avoid compiler complaining
174 // that the result of expression is unused
175 (void)((defOp = findDefSingle<Ops>(op), (defOp)) || ...);
176 return defOp;
179 static bool isOperandAbsent(mlir::Value val) {
180 if (auto op = expectOp<fir::ConvertOp>(val)) {
181 assert(op->getOperands().size() != 0);
182 return mlir::isa_and_nonnull<fir::AbsentOp>(
183 op->getOperand(0).getDefiningOp());
185 return false;
188 static bool isTrueOrNotConstant(mlir::Value val) {
189 if (auto op = expectOp<mlir::arith::ConstantOp>(val)) {
190 return !mlir::matchPattern(val, mlir::m_Zero());
192 return true;
195 static bool isZero(mlir::Value val) {
196 if (auto op = expectOp<fir::ConvertOp>(val)) {
197 assert(op->getOperands().size() != 0);
198 if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
199 return mlir::matchPattern(defOp, mlir::m_Zero());
201 return false;
204 static mlir::Value findBoxDef(mlir::Value val) {
205 if (auto op = expectOp<fir::ConvertOp>(val)) {
206 assert(op->getOperands().size() != 0);
207 return findDef<fir::EmboxOp, fir::ReboxOp>(op);
209 return {};
212 static mlir::Value findMaskDef(mlir::Value val) {
213 if (auto op = expectOp<fir::ConvertOp>(val)) {
214 assert(op->getOperands().size() != 0);
215 return findDef<fir::EmboxOp, fir::ReboxOp, fir::AbsentOp>(op);
217 return {};
220 static unsigned getDimCount(mlir::Value val) {
221 // In order to find the dimensions count, we look for EmboxOp/ReboxOp
222 // and take the count from its *result* type. Note that in case
223 // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
224 // have different types.
225 // Actually, we can take the box type from the operand of
226 // the first ConvertOp that has non-opaque box type that we meet
227 // going through the ConvertOp chain.
228 if (mlir::Value emboxVal = findBoxDef(val))
229 if (auto boxTy = emboxVal.getType().dyn_cast<fir::BoxType>())
230 if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>())
231 return seqTy.getDimension();
232 return 0;
235 /// Given the call operation's box argument \p val, discover
236 /// the element type of the underlying array object.
237 /// \returns the element type or std::nullopt if the type cannot
238 /// be reliably found.
239 /// We expect that the argument is a result of fir.convert
240 /// with the destination type of !fir.box<none>.
241 static std::optional<mlir::Type> getArgElementType(mlir::Value val) {
242 mlir::Operation *defOp;
243 do {
244 defOp = val.getDefiningOp();
245 // Analyze only sequences of convert operations.
246 if (!mlir::isa<fir::ConvertOp>(defOp))
247 return std::nullopt;
248 val = defOp->getOperand(0);
249 // The convert operation is expected to convert from one
250 // box type to another box type.
251 auto boxType = val.getType().cast<fir::BoxType>();
252 auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType);
253 if (!elementType.isa<mlir::NoneType>())
254 return elementType;
255 } while (true);
258 using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
259 fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
260 mlir::Value)>;
261 using InitValGeneratorTy = llvm::function_ref<mlir::Value(
262 fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;
263 using ContinueLoopGenTy = llvm::function_ref<llvm::SmallVector<mlir::Value>(
264 fir::FirOpBuilder &, mlir::Location, mlir::Value)>;
266 /// Generate the reduction loop into \p funcOp.
268 /// \p initVal is a function, called to get the initial value for
269 /// the reduction value
270 /// \p genBody is called to fill in the actual reduciton operation
271 /// for example add for SUM, MAX for MAXVAL, etc.
272 /// \p rank is the rank of the input argument.
273 /// \p elementType is the type of the elements in the input array,
274 /// which may be different to the return type.
275 /// \p loopCond is called to generate the condition to continue or
276 /// not for IterWhile loops
277 /// \p unorderedOrInitalLoopCond contains either a boolean or bool
278 /// mlir constant, and controls the inital value for while loops
279 /// or if DoLoop is ordered/unordered.
281 template <typename OP, typename T, int resultIndex>
282 static void
283 genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
284 InitValGeneratorTy initVal, ContinueLoopGenTy loopCond,
285 T unorderedOrInitialLoopCond, BodyOpGeneratorTy genBody,
286 unsigned rank, mlir::Type elementType, mlir::Location loc) {
288 mlir::IndexType idxTy = builder.getIndexType();
290 mlir::Block::BlockArgListType args = funcOp.front().getArguments();
291 mlir::Value arg = args[0];
293 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
295 fir::SequenceType::Shape flatShape(rank,
296 fir::SequenceType::getUnknownExtent());
297 mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
298 mlir::Type boxArrTy = fir::BoxType::get(arrTy);
299 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
300 mlir::Type resultType = funcOp.getResultTypes()[0];
301 mlir::Value init = initVal(builder, loc, resultType);
303 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
305 assert(rank > 0 && "rank cannot be zero");
306 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
308 // Compute all the upper bounds before the loop nest.
309 // It is not strictly necessary for performance, since the loop nest
310 // does not have any store operations and any LICM optimization
311 // should be able to optimize the redundancy.
312 for (unsigned i = 0; i < rank; ++i) {
313 mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
314 auto dims =
315 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
316 mlir::Value len = dims.getResult(1);
317 // We use C indexing here, so len-1 as loopcount
318 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
319 bounds.push_back(loopCount);
321 // Create a loop nest consisting of OP operations.
322 // Collect the loops' induction variables into indices array,
323 // which will be used in the innermost loop to load the input
324 // array's element.
325 // The loops are generated such that the innermost loop processes
326 // the 0 dimension.
327 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
328 for (unsigned i = rank; 0 < i; --i) {
329 mlir::Value step = one;
330 mlir::Value loopCount = bounds[i - 1];
331 auto loop = builder.create<OP>(loc, zeroIdx, loopCount, step,
332 unorderedOrInitialLoopCond,
333 /*finalCountValue=*/false, init);
334 init = loop.getRegionIterArgs()[resultIndex];
335 indices.push_back(loop.getInductionVar());
336 // Set insertion point to the loop body so that the next loop
337 // is inserted inside the current one.
338 builder.setInsertionPointToStart(loop.getBody());
341 // Reverse the indices such that they are ordered as:
342 // <dim-0-idx, dim-1-idx, ...>
343 std::reverse(indices.begin(), indices.end());
344 // We are in the innermost loop: generate the reduction body.
345 mlir::Type eleRefTy = builder.getRefType(elementType);
346 mlir::Value addr =
347 builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
348 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
349 mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
350 // Generate vector with condition to continue while loop at [0] and result
351 // from current loop at [1] for IterWhileOp loops, just result at [0] for
352 // DoLoopOp loops.
353 llvm::SmallVector<mlir::Value> results = loopCond(builder, loc, reductionVal);
355 // Unwind the loop nest and insert ResultOp on each level
356 // to return the updated value of the reduction to the enclosing
357 // loops.
358 for (unsigned i = 0; i < rank; ++i) {
359 auto result = builder.create<fir::ResultOp>(loc, results);
360 // Proceed to the outer loop.
361 auto loop = mlir::cast<OP>(result->getParentOp());
362 results = loop.getResults();
363 // Set insertion point after the loop operation that we have
364 // just processed.
365 builder.setInsertionPointAfter(loop.getOperation());
367 // End of loop nest. The insertion point is after the outermost loop.
368 // Return the reduction value from the function.
369 builder.create<mlir::func::ReturnOp>(loc, results[resultIndex]);
371 using MinlocBodyOpGeneratorTy = llvm::function_ref<mlir::Value(
372 fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
373 mlir::Value, llvm::SmallVector<mlir::Value, Fortran::common::maxRank> &)>;
375 static void
376 genMinlocReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
377 InitValGeneratorTy initVal,
378 MinlocBodyOpGeneratorTy genBody, unsigned rank,
379 mlir::Type elementType, mlir::Location loc, bool hasMask,
380 mlir::Type maskElemType, mlir::Value resultArr) {
382 mlir::IndexType idxTy = builder.getIndexType();
384 mlir::Block::BlockArgListType args = funcOp.front().getArguments();
385 mlir::Value arg = args[1];
387 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
389 fir::SequenceType::Shape flatShape(rank,
390 fir::SequenceType::getUnknownExtent());
391 mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
392 mlir::Type boxArrTy = fir::BoxType::get(arrTy);
393 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
395 mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType());
396 mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
397 mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0);
398 mlir::Value flagRef = builder.createTemporary(loc, resultElemType);
399 builder.create<fir::StoreOp>(loc, zero, flagRef);
401 mlir::Value mask;
402 if (hasMask) {
403 mlir::Type maskTy = fir::SequenceType::get(flatShape, maskElemType);
404 mlir::Type boxMaskTy = fir::BoxType::get(maskTy);
405 mask = builder.create<fir::ConvertOp>(loc, boxMaskTy, args[2]);
408 mlir::Value init = initVal(builder, loc, elementType);
409 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
411 assert(rank > 0 && "rank cannot be zero");
412 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
414 // Compute all the upper bounds before the loop nest.
415 // It is not strictly necessary for performance, since the loop nest
416 // does not have any store operations and any LICM optimization
417 // should be able to optimize the redundancy.
418 for (unsigned i = 0; i < rank; ++i) {
419 mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
420 auto dims =
421 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
422 mlir::Value len = dims.getResult(1);
423 // We use C indexing here, so len-1 as loopcount
424 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
425 bounds.push_back(loopCount);
427 // Create a loop nest consisting of OP operations.
428 // Collect the loops' induction variables into indices array,
429 // which will be used in the innermost loop to load the input
430 // array's element.
431 // The loops are generated such that the innermost loop processes
432 // the 0 dimension.
433 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
434 for (unsigned i = rank; 0 < i; --i) {
435 mlir::Value step = one;
436 mlir::Value loopCount = bounds[i - 1];
437 auto loop =
438 builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false,
439 /*finalCountValue=*/false, init);
440 init = loop.getRegionIterArgs()[0];
441 indices.push_back(loop.getInductionVar());
442 // Set insertion point to the loop body so that the next loop
443 // is inserted inside the current one.
444 builder.setInsertionPointToStart(loop.getBody());
447 // Reverse the indices such that they are ordered as:
448 // <dim-0-idx, dim-1-idx, ...>
449 std::reverse(indices.begin(), indices.end());
450 // We are in the innermost loop: generate the reduction body.
451 if (hasMask) {
452 mlir::Type logicalRef = builder.getRefType(maskElemType);
453 mlir::Value maskAddr =
454 builder.create<fir::CoordinateOp>(loc, logicalRef, mask, indices);
455 mlir::Value maskElem = builder.create<fir::LoadOp>(loc, maskAddr);
457 // fir::IfOp requires argument to be I1 - won't accept logical or any other
458 // Integer.
459 mlir::Type ifCompatType = builder.getI1Type();
460 mlir::Value ifCompatElem =
461 builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem);
463 llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
464 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, ifCompatElem,
465 /*withElseRegion=*/true);
466 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
469 // Set flag that mask was true at some point
470 builder.create<fir::StoreOp>(loc, flagSet, flagRef);
471 mlir::Type eleRefTy = builder.getRefType(elementType);
472 mlir::Value addr =
473 builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
474 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
476 mlir::Value reductionVal =
477 genBody(builder, loc, elementType, elem, init, indices);
479 if (hasMask) {
480 fir::IfOp ifOp =
481 mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp());
482 builder.create<fir::ResultOp>(loc, reductionVal);
483 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
484 builder.create<fir::ResultOp>(loc, init);
485 reductionVal = ifOp.getResult(0);
486 builder.setInsertionPointAfter(ifOp);
489 // Unwind the loop nest and insert ResultOp on each level
490 // to return the updated value of the reduction to the enclosing
491 // loops.
492 for (unsigned i = 0; i < rank; ++i) {
493 auto result = builder.create<fir::ResultOp>(loc, reductionVal);
494 // Proceed to the outer loop.
495 auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
496 reductionVal = loop.getResult(0);
497 // Set insertion point after the loop operation that we have
498 // just processed.
499 builder.setInsertionPointAfter(loop.getOperation());
501 // End of loop nest. The insertion point is after the outermost loop.
502 if (fir::IfOp ifOp =
503 mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) {
504 builder.create<fir::ResultOp>(loc, reductionVal);
505 builder.setInsertionPointAfter(ifOp);
506 // Redefine flagSet to escape scope of ifOp
507 flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
508 reductionVal = ifOp.getResult(0);
511 // Check for case where array was full of max values.
512 // flag will be 0 if mask was never true, 1 if mask was true as some point,
513 // this is needed to avoid catching cases where we didn't access any elements
514 // e.g. mask=.FALSE.
515 mlir::Value flagValue =
516 builder.create<fir::LoadOp>(loc, resultElemType, flagRef);
517 mlir::Value flagCmp = builder.create<mlir::arith::CmpIOp>(
518 loc, mlir::arith::CmpIPredicate::eq, flagValue, flagSet);
519 fir::IfOp ifMaskTrueOp =
520 builder.create<fir::IfOp>(loc, flagCmp, /*withElseRegion=*/false);
521 builder.setInsertionPointToStart(&ifMaskTrueOp.getThenRegion().front());
523 mlir::Value testInit = initVal(builder, loc, elementType);
524 fir::IfOp ifMinSetOp;
525 if (elementType.isa<mlir::FloatType>()) {
526 mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
527 loc, mlir::arith::CmpFPredicate::OEQ, testInit, reductionVal);
528 ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
529 /*withElseRegion*/ false);
530 } else {
531 mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
532 loc, mlir::arith::CmpIPredicate::eq, testInit, reductionVal);
533 ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
534 /*withElseRegion*/ false);
536 builder.setInsertionPointToStart(&ifMinSetOp.getThenRegion().front());
538 // Load output array with 1s instead of 0s
539 for (unsigned int i = 0; i < rank; ++i) {
540 mlir::Type resultRefTy = builder.getRefType(resultElemType);
541 // mlir::Value one = builder.createIntegerConstant(loc, resultElemType, 1);
542 mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
543 mlir::Value resultElemAddr =
544 builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, index);
545 builder.create<fir::StoreOp>(loc, flagSet, resultElemAddr);
547 builder.setInsertionPointAfter(ifMaskTrueOp);
548 // Store newly created output array to the reference passed in
549 fir::SequenceType::Shape resultShape(1, rank);
550 mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemType);
551 mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy);
552 mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy);
553 mlir::Type outputRefTy = builder.getRefType(outputBoxTy);
555 mlir::Value outputArrNone = args[0];
556 mlir::Value outputArr =
557 builder.create<fir::ConvertOp>(loc, outputRefTy, outputArrNone);
559 // Store nearly created array to output array
560 builder.create<fir::StoreOp>(loc, resultArr, outputArr);
561 builder.create<mlir::func::ReturnOp>(loc);
564 static llvm::SmallVector<mlir::Value> nopLoopCond(fir::FirOpBuilder &builder,
565 mlir::Location loc,
566 mlir::Value reductionVal) {
567 return {reductionVal};
570 /// Generate function body of the simplified version of RTNAME(Sum)
571 /// with signature provided by \p funcOp. The caller is responsible
572 /// for saving/restoring the original insertion point of \p builder.
573 /// \p funcOp is expected to be empty on entry to this function.
574 /// \p rank specifies the rank of the input argument.
575 static void genRuntimeSumBody(fir::FirOpBuilder &builder,
576 mlir::func::FuncOp &funcOp, unsigned rank,
577 mlir::Type elementType) {
578 // function RTNAME(Sum)<T>x<rank>_simplified(arr)
579 // T, dimension(:) :: arr
580 // T sum = 0
581 // integer iter
582 // do iter = 0, extent(arr)
583 // sum = sum + arr[iter]
584 // end do
585 // RTNAME(Sum)<T>x<rank>_simplified = sum
586 // end function RTNAME(Sum)<T>x<rank>_simplified
587 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
588 mlir::Type elementType) {
589 if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
590 const llvm::fltSemantics &sem = ty.getFloatSemantics();
591 return builder.createRealConstant(loc, elementType,
592 llvm::APFloat::getZero(sem));
594 return builder.createIntegerConstant(loc, elementType, 0);
597 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
598 mlir::Type elementType, mlir::Value elem1,
599 mlir::Value elem2) -> mlir::Value {
600 if (elementType.isa<mlir::FloatType>())
601 return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2);
602 if (elementType.isa<mlir::IntegerType>())
603 return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2);
605 llvm_unreachable("unsupported type");
606 return {};
609 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
610 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
612 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond,
613 false, genBodyOp, rank, elementType,
614 loc);
617 static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
618 mlir::func::FuncOp &funcOp, unsigned rank,
619 mlir::Type elementType) {
620 auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
621 mlir::Type elementType) {
622 if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
623 const llvm::fltSemantics &sem = ty.getFloatSemantics();
624 return builder.createRealConstant(
625 loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/true));
627 unsigned bits = elementType.getIntOrFloatBitWidth();
628 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
629 return builder.createIntegerConstant(loc, elementType, minInt);
632 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
633 mlir::Type elementType, mlir::Value elem1,
634 mlir::Value elem2) -> mlir::Value {
635 if (elementType.isa<mlir::FloatType>())
636 return builder.create<mlir::arith::MaxFOp>(loc, elem1, elem2);
637 if (elementType.isa<mlir::IntegerType>())
638 return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2);
640 llvm_unreachable("unsupported type");
641 return {};
644 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
645 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
647 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, init, nopLoopCond,
648 false, genBodyOp, rank, elementType,
649 loc);
652 static void genRuntimeCountBody(fir::FirOpBuilder &builder,
653 mlir::func::FuncOp &funcOp, unsigned rank,
654 mlir::Type elementType) {
655 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
656 mlir::Type elementType) {
657 unsigned bits = elementType.getIntOrFloatBitWidth();
658 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
659 return builder.createIntegerConstant(loc, elementType, zeroInt);
662 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
663 mlir::Type elementType, mlir::Value elem1,
664 mlir::Value elem2) -> mlir::Value {
665 auto zero32 = builder.createIntegerConstant(loc, elementType, 0);
666 auto zero64 = builder.createIntegerConstant(loc, builder.getI64Type(), 0);
667 auto one64 = builder.createIntegerConstant(loc, builder.getI64Type(), 1);
669 auto compare = builder.create<mlir::arith::CmpIOp>(
670 loc, mlir::arith::CmpIPredicate::eq, elem1, zero32);
671 auto select =
672 builder.create<mlir::arith::SelectOp>(loc, compare, zero64, one64);
673 return builder.create<mlir::arith::AddIOp>(loc, select, elem2);
676 // Count always gets I32 for elementType as it converts logical input to
677 // logical<4> before passing to the function.
678 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
679 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
681 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond,
682 false, genBodyOp, rank, elementType,
683 loc);
686 static void genRuntimeAnyBody(fir::FirOpBuilder &builder,
687 mlir::func::FuncOp &funcOp, unsigned rank,
688 mlir::Type elementType) {
689 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
690 mlir::Type elementType) {
691 return builder.createIntegerConstant(loc, elementType, 0);
694 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
695 mlir::Type elementType, mlir::Value elem1,
696 mlir::Value elem2) -> mlir::Value {
697 auto zero = builder.createIntegerConstant(loc, elementType, 0);
698 return builder.create<mlir::arith::CmpIOp>(
699 loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
702 auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
703 mlir::Value reductionVal) {
704 auto one1 = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
705 auto eor = builder.create<mlir::arith::XOrIOp>(loc, reductionVal, one1);
706 llvm::SmallVector<mlir::Value> results = {eor, reductionVal};
707 return results;
710 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
711 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
712 mlir::Value ok = builder.createBool(loc, true);
714 genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
715 builder, funcOp, zero, continueCond, ok, genBodyOp, rank, elementType,
716 loc);
719 static void genRuntimeAllBody(fir::FirOpBuilder &builder,
720 mlir::func::FuncOp &funcOp, unsigned rank,
721 mlir::Type elementType) {
722 auto one = [](fir::FirOpBuilder builder, mlir::Location loc,
723 mlir::Type elementType) {
724 return builder.createIntegerConstant(loc, elementType, 1);
727 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
728 mlir::Type elementType, mlir::Value elem1,
729 mlir::Value elem2) -> mlir::Value {
730 auto zero = builder.createIntegerConstant(loc, elementType, 0);
731 return builder.create<mlir::arith::CmpIOp>(
732 loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
735 auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
736 mlir::Value reductionVal) {
737 llvm::SmallVector<mlir::Value> results = {reductionVal, reductionVal};
738 return results;
741 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
742 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
743 mlir::Value ok = builder.createBool(loc, true);
745 genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
746 builder, funcOp, one, continueCond, ok, genBodyOp, rank, elementType,
747 loc);
750 static mlir::FunctionType genRuntimeMinlocType(fir::FirOpBuilder &builder,
751 unsigned int rank) {
752 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
753 mlir::Type boxRefType = builder.getRefType(boxType);
755 return mlir::FunctionType::get(builder.getContext(),
756 {boxRefType, boxType, boxType}, {});
759 static void genRuntimeMinlocBody(fir::FirOpBuilder &builder,
760 mlir::func::FuncOp &funcOp, unsigned rank,
761 int maskRank, mlir::Type elementType,
762 mlir::Type maskElemType,
763 mlir::Type resultElemTy) {
764 auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
765 mlir::Type elementType) {
766 if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
767 const llvm::fltSemantics &sem = ty.getFloatSemantics();
768 return builder.createRealConstant(
769 loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/false));
771 unsigned bits = elementType.getIntOrFloatBitWidth();
772 int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
773 return builder.createIntegerConstant(loc, elementType, maxInt);
776 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
777 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
779 mlir::Value mask = funcOp.front().getArgument(2);
781 // Set up result array in case of early exit / 0 length array
782 mlir::IndexType idxTy = builder.getIndexType();
783 mlir::Type resultTy = fir::SequenceType::get(rank, resultElemTy);
784 mlir::Type resultHeapTy = fir::HeapType::get(resultTy);
785 mlir::Type resultBoxTy = fir::BoxType::get(resultHeapTy);
787 mlir::Value returnValue = builder.createIntegerConstant(loc, resultElemTy, 0);
788 mlir::Value resultArrSize = builder.createIntegerConstant(loc, idxTy, rank);
790 mlir::Value resultArrInit = builder.create<fir::AllocMemOp>(loc, resultTy);
791 mlir::Value resultArrShape = builder.create<fir::ShapeOp>(loc, resultArrSize);
792 mlir::Value resultArr = builder.create<fir::EmboxOp>(
793 loc, resultBoxTy, resultArrInit, resultArrShape);
795 mlir::Type resultRefTy = builder.getRefType(resultElemTy);
797 for (unsigned int i = 0; i < rank; ++i) {
798 mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
799 mlir::Value resultElemAddr =
800 builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, index);
801 builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr);
804 auto genBodyOp =
805 [&rank, &resultArr](
806 fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType,
807 mlir::Value elem1, mlir::Value elem2,
808 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices)
809 -> mlir::Value {
810 mlir::Value cmp;
811 if (elementType.isa<mlir::FloatType>()) {
812 cmp = builder.create<mlir::arith::CmpFOp>(
813 loc, mlir::arith::CmpFPredicate::OLT, elem1, elem2);
814 } else if (elementType.isa<mlir::IntegerType>()) {
815 cmp = builder.create<mlir::arith::CmpIOp>(
816 loc, mlir::arith::CmpIPredicate::slt, elem1, elem2);
817 } else {
818 llvm_unreachable("unsupported type");
821 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
822 /*withElseRegion*/ true);
824 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
825 mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
826 mlir::Type returnRefTy = builder.getRefType(resultElemTy);
827 mlir::IndexType idxTy = builder.getIndexType();
829 mlir::Value one = builder.createIntegerConstant(loc, resultElemTy, 1);
831 for (unsigned int i = 0; i < rank; ++i) {
832 mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
833 mlir::Value resultElemAddr =
834 builder.create<fir::CoordinateOp>(loc, returnRefTy, resultArr, index);
835 mlir::Value convert =
836 builder.create<fir::ConvertOp>(loc, resultElemTy, indices[i]);
837 mlir::Value fortranIndex =
838 builder.create<mlir::arith::AddIOp>(loc, convert, one);
839 builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr);
841 builder.create<fir::ResultOp>(loc, elem1);
842 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
843 builder.create<fir::ResultOp>(loc, elem2);
844 builder.setInsertionPointAfter(ifOp);
845 return ifOp.getResult(0);
848 // if mask is a logical scalar, we can check its value before the main loop
849 // and either ignore the fact it is there or exit early.
850 if (maskRank == 0) {
851 mlir::Type logical = builder.getI1Type();
852 mlir::IndexType idxTy = builder.getIndexType();
854 fir::SequenceType::Shape singleElement(1, 1);
855 mlir::Type arrTy = fir::SequenceType::get(singleElement, logical);
856 mlir::Type boxArrTy = fir::BoxType::get(arrTy);
857 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, mask);
859 mlir::Value indx = builder.createIntegerConstant(loc, idxTy, 0);
860 mlir::Type logicalRefTy = builder.getRefType(logical);
861 mlir::Value condAddr =
862 builder.create<fir::CoordinateOp>(loc, logicalRefTy, array, indx);
863 mlir::Value cond = builder.create<fir::LoadOp>(loc, condAddr);
865 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cond,
866 /*withElseRegion=*/true);
868 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
869 mlir::Value basicValue;
870 if (elementType.isa<mlir::IntegerType>()) {
871 basicValue = builder.createIntegerConstant(loc, elementType, 0);
872 } else {
873 basicValue = builder.createRealConstant(loc, elementType, 0);
875 builder.create<fir::ResultOp>(loc, basicValue);
877 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
880 // bit of a hack - maskRank is set to -1 for absent mask arg, so don't
881 // generate high level mask or element by element mask.
882 bool hasMask = maskRank > 0;
884 genMinlocReductionLoop(builder, funcOp, init, genBodyOp, rank, elementType,
885 loc, hasMask, maskElemType, resultArr);
888 /// Generate function type for the simplified version of RTNAME(DotProduct)
889 /// operating on the given \p elementType.
890 static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder,
891 const mlir::Type &elementType) {
892 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
893 return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
894 {elementType});
897 /// Generate function body of the simplified version of RTNAME(DotProduct)
898 /// with signature provided by \p funcOp. The caller is responsible
899 /// for saving/restoring the original insertion point of \p builder.
900 /// \p funcOp is expected to be empty on entry to this function.
901 /// \p arg1ElementTy and \p arg2ElementTy specify elements types
902 /// of the underlying array objects - they are used to generate proper
903 /// element accesses.
904 static void genRuntimeDotBody(fir::FirOpBuilder &builder,
905 mlir::func::FuncOp &funcOp,
906 mlir::Type arg1ElementTy,
907 mlir::Type arg2ElementTy) {
908 // function RTNAME(DotProduct)<T>_simplified(arr1, arr2)
909 // T, dimension(:) :: arr1, arr2
910 // T product = 0
911 // integer iter
912 // do iter = 0, extent(arr1)
913 // product = product + arr1[iter] * arr2[iter]
914 // end do
915 // RTNAME(ADotProduct)<T>_simplified = product
916 // end function RTNAME(DotProduct)<T>_simplified
917 auto loc = mlir::UnknownLoc::get(builder.getContext());
918 mlir::Type resultElementType = funcOp.getResultTypes()[0];
919 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
921 mlir::IndexType idxTy = builder.getIndexType();
923 mlir::Value zero =
924 resultElementType.isa<mlir::FloatType>()
925 ? builder.createRealConstant(loc, resultElementType, 0.0)
926 : builder.createIntegerConstant(loc, resultElementType, 0);
928 mlir::Block::BlockArgListType args = funcOp.front().getArguments();
929 mlir::Value arg1 = args[0];
930 mlir::Value arg2 = args[1];
932 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
934 fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
935 mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
936 mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
937 mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
938 mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
939 mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
940 mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
941 // This version takes the loop trip count from the first argument.
942 // If the first argument's box has unknown (at compilation time)
943 // extent, then it may be better to take the extent from the second
944 // argument - so that after inlining the loop may be better optimized, e.g.
945 // fully unrolled. This requires generating two versions of the simplified
946 // function and some analysis at the call site to choose which version
947 // is more profitable to call.
948 // Note that we can assume that both arguments have the same extent.
949 auto dims =
950 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx);
951 mlir::Value len = dims.getResult(1);
952 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
953 mlir::Value step = one;
955 // We use C indexing here, so len-1 as loopcount
956 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
957 auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
958 /*unordered=*/false,
959 /*finalCountValue=*/false, zero);
960 mlir::Value sumVal = loop.getRegionIterArgs()[0];
962 // Begin loop code
963 mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
964 builder.setInsertionPointToStart(loop.getBody());
966 mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
967 mlir::Value index = loop.getInductionVar();
968 mlir::Value addr1 =
969 builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index);
970 mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
971 // Convert to the result type.
972 elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1);
974 mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
975 mlir::Value addr2 =
976 builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index);
977 mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
978 // Convert to the result type.
979 elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2);
981 if (resultElementType.isa<mlir::FloatType>())
982 sumVal = builder.create<mlir::arith::AddFOp>(
983 loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
984 else if (resultElementType.isa<mlir::IntegerType>())
985 sumVal = builder.create<mlir::arith::AddIOp>(
986 loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
987 else
988 llvm_unreachable("unsupported type");
990 builder.create<fir::ResultOp>(loc, sumVal);
991 // End of loop.
992 builder.restoreInsertionPoint(loopEndPt);
994 mlir::Value resultVal = loop.getResult(0);
995 builder.create<mlir::func::ReturnOp>(loc, resultVal);
998 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
999 fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
1000 FunctionTypeGeneratorTy typeGenerator,
1001 FunctionBodyGeneratorTy bodyGenerator) {
1002 // WARNING: if the function generated here changes its signature
1003 // or behavior (the body code), we should probably embed some
1004 // versioning information into its name, otherwise libraries
1005 // statically linked with older versions of Flang may stop
1006 // working with object files created with newer Flang.
1007 // We can also avoid this by using internal linkage, but
1008 // this may increase the size of final executable/shared library.
1009 std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
1010 mlir::ModuleOp module = builder.getModule();
1011 // If we already have a function, just return it.
1012 mlir::func::FuncOp newFunc =
1013 fir::FirOpBuilder::getNamedFunction(module, replacementName);
1014 mlir::FunctionType fType = typeGenerator(builder);
1015 if (newFunc) {
1016 assert(newFunc.getFunctionType() == fType &&
1017 "type mismatch for simplified function");
1018 return newFunc;
1021 // Need to build the function!
1022 auto loc = mlir::UnknownLoc::get(builder.getContext());
1023 newFunc =
1024 fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
1025 auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
1026 auto linkage =
1027 mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
1028 newFunc->setAttr("llvm.linkage", linkage);
1030 // Save the position of the original call.
1031 mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
1033 bodyGenerator(builder, newFunc);
1035 // Now back to where we were adding code earlier...
1036 builder.restoreInsertionPoint(insertPt);
1038 return newFunc;
1041 void SimplifyIntrinsicsPass::simplifyIntOrFloatReduction(
1042 fir::CallOp call, const fir::KindMapping &kindMap,
1043 GenReductionBodyTy genBodyFunc) {
1044 // args[1] and args[2] are source filename and line number, ignored.
1045 mlir::Operation::operand_range args = call.getArgs();
1047 const mlir::Value &dim = args[3];
1048 const mlir::Value &mask = args[4];
1049 // dim is zero when it is absent, which is an implementation
1050 // detail in the runtime library.
1052 bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
1053 unsigned rank = getDimCount(args[0]);
1055 // Rank is set to 0 for assumed shape arrays, don't simplify
1056 // in these cases
1057 if (!(dimAndMaskAbsent && rank > 0))
1058 return;
1060 mlir::Type resultType = call.getResult(0).getType();
1062 if (!resultType.isa<mlir::FloatType>() &&
1063 !resultType.isa<mlir::IntegerType>())
1064 return;
1066 auto argType = getArgElementType(args[0]);
1067 if (!argType)
1068 return;
1069 assert(*argType == resultType &&
1070 "Argument/result types mismatch in reduction");
1072 mlir::SymbolRefAttr callee = call.getCalleeAttr();
1074 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1075 std::string fmfString{getFastMathFlagsString(builder)};
1076 std::string funcName =
1077 (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
1078 mlir::Twine{rank} +
1079 // We must mangle the generated function name with FastMathFlags
1080 // value.
1081 (fmfString.empty() ? mlir::Twine{} : mlir::Twine{"_", fmfString}))
1082 .str();
1084 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1085 resultType);
1088 void SimplifyIntrinsicsPass::simplifyLogicalDim0Reduction(
1089 fir::CallOp call, const fir::KindMapping &kindMap,
1090 GenReductionBodyTy genBodyFunc) {
1092 mlir::Operation::operand_range args = call.getArgs();
1093 const mlir::Value &dim = args[3];
1094 unsigned rank = getDimCount(args[0]);
1096 // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
1097 // these cases.
1098 if (!(isZero(dim) && rank > 0))
1099 return;
1101 mlir::Value inputBox = findBoxDef(args[0]);
1103 mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType());
1104 mlir::SymbolRefAttr callee = call.getCalleeAttr();
1106 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1108 // Treating logicals as integers makes things a lot easier
1109 fir::LogicalType logicalType = {elementType.dyn_cast<fir::LogicalType>()};
1110 fir::KindTy kind = logicalType.getFKind();
1111 mlir::Type intElementType = builder.getIntegerType(kind * 8);
1113 // Mangle kind into function name as it is not done by default
1114 std::string funcName =
1115 (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} +
1116 mlir::Twine{kind} + "x" + mlir::Twine{rank})
1117 .str();
1119 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1120 intElementType);
1123 void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction(
1124 fir::CallOp call, const fir::KindMapping &kindMap,
1125 GenReductionBodyTy genBodyFunc) {
1127 mlir::Operation::operand_range args = call.getArgs();
1128 mlir::SymbolRefAttr callee = call.getCalleeAttr();
1129 mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
1130 unsigned rank = getDimCount(args[0]);
1132 // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
1133 // these cases. We check for Dim at the end as some logical functions (Any,
1134 // All) set dim to 1 instead of 0 when the argument is not present.
1135 if (funcNameBase.ends_with("Dim") || !(rank > 0))
1136 return;
1138 mlir::Value inputBox = findBoxDef(args[0]);
1139 mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType());
1141 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1143 // Treating logicals as integers makes things a lot easier
1144 fir::LogicalType logicalType = {elementType.dyn_cast<fir::LogicalType>()};
1145 fir::KindTy kind = logicalType.getFKind();
1146 mlir::Type intElementType = builder.getIntegerType(kind * 8);
1148 // Mangle kind into function name as it is not done by default
1149 std::string funcName =
1150 (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} +
1151 mlir::Twine{kind} + "x" + mlir::Twine{rank})
1152 .str();
1154 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1155 intElementType);
1158 void SimplifyIntrinsicsPass::simplifyMinlocReduction(
1159 fir::CallOp call, const fir::KindMapping &kindMap) {
1161 mlir::Operation::operand_range args = call.getArgs();
1163 mlir::Value back = args[6];
1164 if (isTrueOrNotConstant(back))
1165 return;
1167 mlir::Value mask = args[5];
1168 mlir::Value maskDef = findMaskDef(mask);
1170 // maskDef is set to NULL when the defining op is not one we accept.
1171 // This tends to be because it is a selectOp, in which case let the
1172 // runtime deal with it.
1173 if (maskDef == NULL)
1174 return;
1176 mlir::SymbolRefAttr callee = call.getCalleeAttr();
1177 mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
1178 unsigned rank = getDimCount(args[1]);
1179 if (funcNameBase.ends_with("Dim") || !(rank > 0))
1180 return;
1182 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1183 mlir::Location loc = call.getLoc();
1184 auto inputBox = findBoxDef(args[1]);
1185 mlir::Type inputType = hlfir::getFortranElementType(inputBox.getType());
1187 if (inputType.isa<fir::CharacterType>())
1188 return;
1190 int maskRank;
1191 fir::KindTy kind = 0;
1192 mlir::Type logicalElemType = builder.getI1Type();
1193 if (isOperandAbsent(mask)) {
1194 maskRank = -1;
1195 } else {
1196 maskRank = getDimCount(mask);
1197 mlir::Type maskElemTy = hlfir::getFortranElementType(maskDef.getType());
1198 fir::LogicalType logicalFirType = {maskElemTy.dyn_cast<fir::LogicalType>()};
1199 kind = logicalFirType.getFKind();
1200 // Convert fir::LogicalType to mlir::Type
1201 logicalElemType = logicalFirType;
1204 mlir::Operation *outputDef = args[0].getDefiningOp();
1205 mlir::Value outputAlloc = outputDef->getOperand(0);
1206 mlir::Type outType = hlfir::getFortranElementType(outputAlloc.getType());
1208 std::string fmfString{getFastMathFlagsString(builder)};
1209 std::string funcName =
1210 (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
1211 mlir::Twine{rank} +
1212 (maskRank >= 0
1213 ? "_Logical" + mlir::Twine{kind} + "x" + mlir::Twine{maskRank}
1214 : "") +
1215 "_")
1216 .str();
1218 llvm::raw_string_ostream nameOS(funcName);
1219 outType.print(nameOS);
1220 nameOS << '_' << fmfString;
1222 auto typeGenerator = [rank](fir::FirOpBuilder &builder) {
1223 return genRuntimeMinlocType(builder, rank);
1225 auto bodyGenerator = [rank, maskRank, inputType, logicalElemType,
1226 outType](fir::FirOpBuilder &builder,
1227 mlir::func::FuncOp &funcOp) {
1228 genRuntimeMinlocBody(builder, funcOp, rank, maskRank, inputType,
1229 logicalElemType, outType);
1232 mlir::func::FuncOp newFunc =
1233 getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
1234 builder.create<fir::CallOp>(loc, newFunc,
1235 mlir::ValueRange{args[0], args[1], args[5]});
1236 call->dropAllReferences();
1237 call->erase();
1240 void SimplifyIntrinsicsPass::simplifyReductionBody(
1241 fir::CallOp call, const fir::KindMapping &kindMap,
1242 GenReductionBodyTy genBodyFunc, fir::FirOpBuilder &builder,
1243 const mlir::StringRef &funcName, mlir::Type elementType) {
1245 mlir::Operation::operand_range args = call.getArgs();
1247 mlir::Type resultType = call.getResult(0).getType();
1248 unsigned rank = getDimCount(args[0]);
1250 mlir::Location loc = call.getLoc();
1252 auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
1253 return genNoneBoxType(builder, resultType);
1255 auto bodyGenerator = [&rank, &genBodyFunc,
1256 &elementType](fir::FirOpBuilder &builder,
1257 mlir::func::FuncOp &funcOp) {
1258 genBodyFunc(builder, funcOp, rank, elementType);
1260 // Mangle the function name with the rank value as "x<rank>".
1261 mlir::func::FuncOp newFunc =
1262 getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
1263 auto newCall =
1264 builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
1265 call->replaceAllUsesWith(newCall.getResults());
1266 call->dropAllReferences();
1267 call->erase();
1270 void SimplifyIntrinsicsPass::runOnOperation() {
1271 LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
1272 mlir::ModuleOp module = getOperation();
1273 fir::KindMapping kindMap = fir::getKindMapping(module);
1274 module.walk([&](mlir::Operation *op) {
1275 if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
1276 if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
1277 mlir::StringRef funcName = callee.getLeafReference().getValue();
1278 // Replace call to runtime function for SUM when it has single
1279 // argument (no dim or mask argument) for 1D arrays with either
1280 // Integer4 or Real8 types. Other forms are ignored.
1281 // The new function is added to the module.
1283 // Prototype for runtime call (from sum.cpp):
1284 // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
1285 // int dim, const Descriptor *mask)
1287 if (funcName.startswith(RTNAME_STRING(Sum))) {
1288 simplifyIntOrFloatReduction(call, kindMap, genRuntimeSumBody);
1289 return;
1291 if (funcName.startswith(RTNAME_STRING(DotProduct))) {
1292 LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
1293 LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
1294 llvm::dbgs() << "\n");
1295 mlir::Operation::operand_range args = call.getArgs();
1296 const mlir::Value &v1 = args[0];
1297 const mlir::Value &v2 = args[1];
1298 mlir::Location loc = call.getLoc();
1299 fir::FirOpBuilder builder{getSimplificationBuilder(op, kindMap)};
1300 // Stringize the builder's FastMathFlags flags for mangling
1301 // the generated function name.
1302 std::string fmfString{getFastMathFlagsString(builder)};
1304 mlir::Type type = call.getResult(0).getType();
1305 if (!type.isa<mlir::FloatType>() && !type.isa<mlir::IntegerType>())
1306 return;
1308 // Try to find the element types of the boxed arguments.
1309 auto arg1Type = getArgElementType(v1);
1310 auto arg2Type = getArgElementType(v2);
1312 if (!arg1Type || !arg2Type)
1313 return;
1315 // Support only floating point and integer arguments
1316 // now (e.g. logical is skipped here).
1317 if (!arg1Type->isa<mlir::FloatType>() &&
1318 !arg1Type->isa<mlir::IntegerType>())
1319 return;
1320 if (!arg2Type->isa<mlir::FloatType>() &&
1321 !arg2Type->isa<mlir::IntegerType>())
1322 return;
1324 auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
1325 return genRuntimeDotType(builder, type);
1327 auto bodyGenerator = [&arg1Type,
1328 &arg2Type](fir::FirOpBuilder &builder,
1329 mlir::func::FuncOp &funcOp) {
1330 genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type);
1333 // Suffix the function name with the element types
1334 // of the arguments.
1335 std::string typedFuncName(funcName);
1336 llvm::raw_string_ostream nameOS(typedFuncName);
1337 // We must mangle the generated function name with FastMathFlags
1338 // value.
1339 if (!fmfString.empty())
1340 nameOS << '_' << fmfString;
1341 nameOS << '_';
1342 arg1Type->print(nameOS);
1343 nameOS << '_';
1344 arg2Type->print(nameOS);
1346 mlir::func::FuncOp newFunc = getOrCreateFunction(
1347 builder, typedFuncName, typeGenerator, bodyGenerator);
1348 auto newCall = builder.create<fir::CallOp>(loc, newFunc,
1349 mlir::ValueRange{v1, v2});
1350 call->replaceAllUsesWith(newCall.getResults());
1351 call->dropAllReferences();
1352 call->erase();
1354 LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump();
1355 llvm::dbgs() << "\n");
1356 return;
1358 if (funcName.startswith(RTNAME_STRING(Maxval))) {
1359 simplifyIntOrFloatReduction(call, kindMap, genRuntimeMaxvalBody);
1360 return;
1362 if (funcName.startswith(RTNAME_STRING(Count))) {
1363 simplifyLogicalDim0Reduction(call, kindMap, genRuntimeCountBody);
1364 return;
1366 if (funcName.startswith(RTNAME_STRING(Any))) {
1367 simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAnyBody);
1368 return;
1370 if (funcName.endswith(RTNAME_STRING(All))) {
1371 simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAllBody);
1372 return;
1374 if (funcName.startswith(RTNAME_STRING(Minloc))) {
1375 simplifyMinlocReduction(call, kindMap);
1376 return;
1381 LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
1384 void SimplifyIntrinsicsPass::getDependentDialects(
1385 mlir::DialectRegistry &registry) const {
1386 // LLVM::LinkageAttr creation requires that LLVM dialect is loaded.
1387 registry.insert<mlir::LLVM::LLVMDialect>();
1389 std::unique_ptr<mlir::Pass> fir::createSimplifyIntrinsicsPass() {
1390 return std::make_unique<SimplifyIntrinsicsPass>();