[lldb] Add ability to hide the root name of a value
[llvm-project.git] / flang / lib / Optimizer / Transforms / StackArrays.cpp
blob60a30d2d1ef64e83ffdb17f818a9be7d23621580
1 //===- StackArrays.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
11 #include "flang/Optimizer/Dialect/FIRAttr.h"
12 #include "flang/Optimizer/Dialect/FIRDialect.h"
13 #include "flang/Optimizer/Dialect/FIROps.h"
14 #include "flang/Optimizer/Dialect/FIRType.h"
15 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
16 #include "flang/Optimizer/Transforms/Passes.h"
17 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
18 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
19 #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
20 #include "mlir/Analysis/DataFlowFramework.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"
22 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/Diagnostics.h"
25 #include "mlir/IR/Value.h"
26 #include "mlir/Interfaces/LoopLikeInterface.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Support/LogicalResult.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "mlir/Transforms/Passes.h"
31 #include "llvm/ADT/DenseMap.h"
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/PointerUnion.h"
34 #include "llvm/Support/Casting.h"
35 #include "llvm/Support/raw_ostream.h"
36 #include <optional>
38 namespace fir {
39 #define GEN_PASS_DEF_STACKARRAYS
40 #include "flang/Optimizer/Transforms/Passes.h.inc"
41 } // namespace fir
43 #define DEBUG_TYPE "stack-arrays"
45 namespace {
47 /// The state of an SSA value at each program point
48 enum class AllocationState {
49 /// This means that the allocation state of a variable cannot be determined
50 /// at this program point, e.g. because one route through a conditional freed
51 /// the variable and the other route didn't.
52 /// This asserts a known-unknown: different from the unknown-unknown of having
53 /// no AllocationState stored for a particular SSA value
54 Unknown,
55 /// Means this SSA value was allocated on the heap in this function and has
56 /// now been freed
57 Freed,
58 /// Means this SSA value was allocated on the heap in this function and is a
59 /// candidate for moving to the stack
60 Allocated,
63 /// Stores where an alloca should be inserted. If the PointerUnion is an
64 /// Operation the alloca should be inserted /after/ the operation. If it is a
65 /// block, the alloca can be placed anywhere in that block.
66 class InsertionPoint {
67 llvm::PointerUnion<mlir::Operation *, mlir::Block *> location;
68 bool saveRestoreStack;
70 /// Get contained pointer type or nullptr
71 template <class T>
72 T *tryGetPtr() const {
73 if (location.is<T *>())
74 return location.get<T *>();
75 return nullptr;
78 public:
79 template <class T>
80 InsertionPoint(T *ptr, bool saveRestoreStack = false)
81 : location(ptr), saveRestoreStack{saveRestoreStack} {}
82 InsertionPoint(std::nullptr_t null)
83 : location(null), saveRestoreStack{false} {}
85 /// Get contained operation, or nullptr
86 mlir::Operation *tryGetOperation() const {
87 return tryGetPtr<mlir::Operation>();
90 /// Get contained block, or nullptr
91 mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); }
93 /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave
94 /// intrinsic should be added before the alloca, and an llvm.stackrestore
95 /// intrinsic should be added where the freemem is
96 bool shouldSaveRestoreStack() const { return saveRestoreStack; }
98 operator bool() const { return tryGetOperation() || tryGetBlock(); }
100 bool operator==(const InsertionPoint &rhs) const {
101 return (location == rhs.location) &&
102 (saveRestoreStack == rhs.saveRestoreStack);
105 bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); }
108 /// Maps SSA values to their AllocationState at a particular program point.
109 /// Also caches the insertion points for the new alloca operations
110 class LatticePoint : public mlir::dataflow::AbstractDenseLattice {
111 // Maps all values we are interested in to states
112 llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap;
114 public:
115 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint)
116 using AbstractDenseLattice::AbstractDenseLattice;
118 bool operator==(const LatticePoint &rhs) const {
119 return stateMap == rhs.stateMap;
122 /// Join the lattice accross control-flow edges
123 mlir::ChangeResult join(const AbstractDenseLattice &lattice) override;
125 void print(llvm::raw_ostream &os) const override;
127 /// Clear all modifications
128 mlir::ChangeResult reset();
130 /// Set the state of an SSA value
131 mlir::ChangeResult set(mlir::Value value, AllocationState state);
133 /// Get fir.allocmem ops which were allocated in this function and always
134 /// freed before the function returns, plus whre to insert replacement
135 /// fir.alloca ops
136 void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const;
138 std::optional<AllocationState> get(mlir::Value val) const;
141 class AllocationAnalysis
142 : public mlir::dataflow::DenseDataFlowAnalysis<LatticePoint> {
143 public:
144 using DenseDataFlowAnalysis::DenseDataFlowAnalysis;
146 void visitOperation(mlir::Operation *op, const LatticePoint &before,
147 LatticePoint *after) override;
149 /// At an entry point, the last modifications of all memory resources are
150 /// yet to be determined
151 void setToEntryState(LatticePoint *lattice) override;
153 protected:
154 /// Visit control flow operations and decide whether to call visitOperation
155 /// to apply the transfer function
156 void processOperation(mlir::Operation *op) override;
159 /// Drives analysis to find candidate fir.allocmem operations which could be
160 /// moved to the stack. Intended to be used with mlir::Pass::getAnalysis
161 class StackArraysAnalysisWrapper {
162 public:
163 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper)
165 // Maps fir.allocmem -> place to insert alloca
166 using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>;
168 StackArraysAnalysisWrapper(mlir::Operation *op) {}
170 bool hasErrors() const;
172 const AllocMemMap &getCandidateOps(mlir::Operation *func);
174 private:
175 llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;
176 bool gotError = false;
178 void analyseFunction(mlir::Operation *func);
181 /// Converts a fir.allocmem to a fir.alloca
182 class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
183 public:
184 using OpRewritePattern::OpRewritePattern;
186 AllocMemConversion(
187 mlir::MLIRContext *ctx,
188 const llvm::DenseMap<mlir::Operation *, InsertionPoint> &candidateOps);
190 mlir::LogicalResult
191 matchAndRewrite(fir::AllocMemOp allocmem,
192 mlir::PatternRewriter &rewriter) const override;
194 /// Determine where to insert the alloca operation. The returned value should
195 /// be checked to see if it is inside a loop
196 static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc);
198 private:
199 /// allocmem operations that DFA has determined are safe to move to the stack
200 /// mapping to where to insert replacement freemem operations
201 const llvm::DenseMap<mlir::Operation *, InsertionPoint> &candidateOps;
203 /// If we failed to find an insertion point not inside a loop, see if it would
204 /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
205 static InsertionPoint findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc);
207 /// Returns the alloca if it was successfully inserted, otherwise {}
208 std::optional<fir::AllocaOp>
209 insertAlloca(fir::AllocMemOp &oldAlloc,
210 mlir::PatternRewriter &rewriter) const;
212 /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem
213 void insertStackSaveRestore(fir::AllocMemOp &oldAlloc,
214 mlir::PatternRewriter &rewriter) const;
217 class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> {
218 public:
219 StackArraysPass() = default;
220 StackArraysPass(const StackArraysPass &pass);
222 llvm::StringRef getDescription() const override;
224 void runOnOperation() override;
225 void runOnFunc(mlir::Operation *func);
227 private:
228 Statistic runCount{this, "stackArraysRunCount",
229 "Number of heap allocations moved to the stack"};
232 } // namespace
234 static void print(llvm::raw_ostream &os, AllocationState state) {
235 switch (state) {
236 case AllocationState::Unknown:
237 os << "Unknown";
238 break;
239 case AllocationState::Freed:
240 os << "Freed";
241 break;
242 case AllocationState::Allocated:
243 os << "Allocated";
244 break;
248 /// Join two AllocationStates for the same value coming from different CFG
249 /// blocks
250 static AllocationState join(AllocationState lhs, AllocationState rhs) {
251 // | Allocated | Freed | Unknown
252 // ========= | ========= | ========= | =========
253 // Allocated | Allocated | Unknown | Unknown
254 // Freed | Unknown | Freed | Unknown
255 // Unknown | Unknown | Unknown | Unknown
256 if (lhs == rhs)
257 return lhs;
258 return AllocationState::Unknown;
261 mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) {
262 const auto &rhs = static_cast<const LatticePoint &>(lattice);
263 mlir::ChangeResult changed = mlir::ChangeResult::NoChange;
265 // add everything from rhs to map, handling cases where values are in both
266 for (const auto &[value, rhsState] : rhs.stateMap) {
267 auto it = stateMap.find(value);
268 if (it != stateMap.end()) {
269 // value is present in both maps
270 AllocationState myState = it->second;
271 AllocationState newState = ::join(myState, rhsState);
272 if (newState != myState) {
273 changed = mlir::ChangeResult::Change;
274 it->getSecond() = newState;
276 } else {
277 // value not present in current map: add it
278 stateMap.insert({value, rhsState});
279 changed = mlir::ChangeResult::Change;
283 return changed;
286 void LatticePoint::print(llvm::raw_ostream &os) const {
287 for (const auto &[value, state] : stateMap) {
288 os << value << ": ";
289 ::print(os, state);
293 mlir::ChangeResult LatticePoint::reset() {
294 if (stateMap.empty())
295 return mlir::ChangeResult::NoChange;
296 stateMap.clear();
297 return mlir::ChangeResult::Change;
300 mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) {
301 if (stateMap.count(value)) {
302 // already in map
303 AllocationState &oldState = stateMap[value];
304 if (oldState != state) {
305 stateMap[value] = state;
306 return mlir::ChangeResult::Change;
308 return mlir::ChangeResult::NoChange;
310 stateMap.insert({value, state});
311 return mlir::ChangeResult::Change;
314 /// Get values which were allocated in this function and always freed before
315 /// the function returns
316 void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const {
317 for (auto &[value, state] : stateMap) {
318 if (state == AllocationState::Freed)
319 out.insert(value);
323 std::optional<AllocationState> LatticePoint::get(mlir::Value val) const {
324 auto it = stateMap.find(val);
325 if (it == stateMap.end())
326 return {};
327 return it->second;
330 void AllocationAnalysis::visitOperation(mlir::Operation *op,
331 const LatticePoint &before,
332 LatticePoint *after) {
333 LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
334 << "\n");
335 LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n");
337 // propagate before -> after
338 mlir::ChangeResult changed = after->join(before);
340 if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) {
341 assert(op->getNumResults() == 1 && "fir.allocmem has one result");
342 auto attr = op->getAttrOfType<fir::MustBeHeapAttr>(
343 fir::MustBeHeapAttr::getAttrName());
344 if (attr && attr.getValue()) {
345 LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n");
346 // skip allocation marked not to be moved
347 return;
350 auto retTy = allocmem.getAllocatedType();
351 if (!retTy.isa<fir::SequenceType>()) {
352 LLVM_DEBUG(llvm::dbgs()
353 << "--Allocation is not for an array: skipping\n");
354 return;
357 mlir::Value result = op->getResult(0);
358 changed |= after->set(result, AllocationState::Allocated);
359 } else if (mlir::isa<fir::FreeMemOp>(op)) {
360 assert(op->getNumOperands() == 1 && "fir.freemem has one operand");
361 mlir::Value operand = op->getOperand(0);
362 std::optional<AllocationState> operandState = before.get(operand);
363 if (operandState && *operandState == AllocationState::Allocated) {
364 // don't tag things not allocated in this function as freed, so that we
365 // don't think they are candidates for moving to the stack
366 changed |= after->set(operand, AllocationState::Freed);
368 } else if (mlir::isa<fir::ResultOp>(op)) {
369 mlir::Operation *parent = op->getParentOp();
370 LatticePoint *parentLattice = getLattice(parent);
371 assert(parentLattice);
372 mlir::ChangeResult parentChanged = parentLattice->join(*after);
373 propagateIfChanged(parentLattice, parentChanged);
376 // we pass lattices straight through fir.call because called functions should
377 // not deallocate flang-generated array temporaries
379 LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n");
380 propagateIfChanged(after, changed);
383 void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {
384 propagateIfChanged(lattice, lattice->reset());
387 /// Mostly a copy of AbstractDenseLattice::processOperation - the difference
388 /// being that call operations are passed through to the transfer function
389 void AllocationAnalysis::processOperation(mlir::Operation *op) {
390 // If the containing block is not executable, bail out.
391 if (!getOrCreateFor<mlir::dataflow::Executable>(op, op->getBlock())->isLive())
392 return;
394 // Get the dense lattice to update
395 mlir::dataflow::AbstractDenseLattice *after = getLattice(op);
397 // If this op implements region control-flow, then control-flow dictates its
398 // transfer function.
399 if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op))
400 return visitRegionBranchOperation(op, branch, after);
402 // pass call operations through to the transfer function
404 // Get the dense state before the execution of the op.
405 const mlir::dataflow::AbstractDenseLattice *before;
406 if (mlir::Operation *prev = op->getPrevNode())
407 before = getLatticeFor(op, prev);
408 else
409 before = getLatticeFor(op, op->getBlock());
411 /// Invoke the operation transfer function
412 visitOperationImpl(op, *before, after);
415 void StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
416 assert(mlir::isa<mlir::func::FuncOp>(func));
417 mlir::DataFlowSolver solver;
418 // constant propagation is required for dead code analysis, dead code analysis
419 // is required to mark blocks live (required for mlir dense dfa)
420 solver.load<mlir::dataflow::SparseConstantPropagation>();
421 solver.load<mlir::dataflow::DeadCodeAnalysis>();
423 auto [it, inserted] = funcMaps.try_emplace(func);
424 AllocMemMap &candidateOps = it->second;
426 solver.load<AllocationAnalysis>();
427 if (failed(solver.initializeAndRun(func))) {
428 llvm::errs() << "DataFlowSolver failed!";
429 gotError = true;
430 return;
433 LatticePoint point{func};
434 auto joinOperationLattice = [&](mlir::Operation *op) {
435 const LatticePoint *lattice = solver.lookupState<LatticePoint>(op);
436 // there will be no lattice for an unreachable block
437 if (lattice)
438 point.join(*lattice);
440 func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); });
441 func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); });
442 llvm::DenseSet<mlir::Value> freedValues;
443 point.appendFreedValues(freedValues);
445 // We only replace allocations which are definately freed on all routes
446 // through the function because otherwise the allocation may have an intende
447 // lifetime longer than the current stack frame (e.g. a heap allocation which
448 // is then freed by another function).
449 for (mlir::Value freedValue : freedValues) {
450 fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>();
451 InsertionPoint insertionPoint =
452 AllocMemConversion::findAllocaInsertionPoint(allocmem);
453 if (insertionPoint)
454 candidateOps.insert({allocmem, insertionPoint});
457 LLVM_DEBUG(for (auto [allocMemOp, _]
458 : candidateOps) {
459 llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
463 bool StackArraysAnalysisWrapper::hasErrors() const { return gotError; }
465 const StackArraysAnalysisWrapper::AllocMemMap &
466 StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) {
467 if (!funcMaps.count(func))
468 analyseFunction(func);
469 return funcMaps[func];
472 AllocMemConversion::AllocMemConversion(
473 mlir::MLIRContext *ctx,
474 const llvm::DenseMap<mlir::Operation *, InsertionPoint> &candidateOps)
475 : OpRewritePattern(ctx), candidateOps(candidateOps) {}
477 mlir::LogicalResult
478 AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
479 mlir::PatternRewriter &rewriter) const {
480 auto oldInsertionPt = rewriter.saveInsertionPoint();
481 // add alloca operation
482 std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter);
483 rewriter.restoreInsertionPoint(oldInsertionPt);
484 if (!alloca)
485 return mlir::failure();
487 // remove freemem operations
488 for (mlir::Operation *user : allocmem.getOperation()->getUsers())
489 if (mlir::isa<fir::FreeMemOp>(user))
490 rewriter.eraseOp(user);
492 // replace references to heap allocation with references to stack allocation
493 rewriter.replaceAllUsesWith(allocmem.getResult(), alloca->getResult());
495 // remove allocmem operation
496 rewriter.eraseOp(allocmem.getOperation());
498 return mlir::success();
501 static bool isInLoop(mlir::Block *block) {
502 return mlir::LoopLikeOpInterface::blockIsInLoop(block);
505 static bool isInLoop(mlir::Operation *op) {
506 return isInLoop(op->getBlock()) ||
507 op->getParentOfType<mlir::LoopLikeOpInterface>();
510 InsertionPoint
511 AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc) {
512 // Ideally the alloca should be inserted at the end of the function entry
513 // block so that we do not allocate stack space in a loop. However,
514 // the operands to the alloca may not be available that early, so insert it
515 // after the last operand becomes available
516 // If the old allocmem op was in an openmp region then it should not be moved
517 // outside of that
518 LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: "
519 << oldAlloc << "\n");
521 // check that an Operation or Block we are about to return is not in a loop
522 auto checkReturn = [&](auto *point) -> InsertionPoint {
523 if (isInLoop(point)) {
524 mlir::Operation *oldAllocOp = oldAlloc.getOperation();
525 if (isInLoop(oldAllocOp)) {
526 // where we want to put it is in a loop, and even the old location is in
527 // a loop. Give up.
528 return findAllocaLoopInsertionPoint(oldAlloc);
530 return {oldAllocOp};
532 return {point};
535 auto oldOmpRegion =
536 oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();
538 // Find when the last operand value becomes available
539 mlir::Block *operandsBlock = nullptr;
540 mlir::Operation *lastOperand = nullptr;
541 for (mlir::Value operand : oldAlloc.getOperands()) {
542 LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n");
543 mlir::Operation *op = operand.getDefiningOp();
544 if (!op)
545 return checkReturn(oldAlloc.getOperation());
546 if (!operandsBlock)
547 operandsBlock = op->getBlock();
548 else if (operandsBlock != op->getBlock()) {
549 LLVM_DEBUG(llvm::dbgs()
550 << "----operand declared in a different block!\n");
551 // Operation::isBeforeInBlock requires the operations to be in the same
552 // block. The best we can do is the location of the allocmem.
553 return checkReturn(oldAlloc.getOperation());
555 if (!lastOperand || lastOperand->isBeforeInBlock(op))
556 lastOperand = op;
559 if (lastOperand) {
560 // there were value operands to the allocmem so insert after the last one
561 LLVM_DEBUG(llvm::dbgs()
562 << "--Placing after last operand: " << *lastOperand << "\n");
563 // check we aren't moving out of an omp region
564 auto lastOpOmpRegion =
565 lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();
566 if (lastOpOmpRegion == oldOmpRegion)
567 return checkReturn(lastOperand);
568 // Presumably this happened because the operands became ready before the
569 // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should
570 // imply that oldOmpRegion comes after lastOpOmpRegion.
571 return checkReturn(oldOmpRegion.getAllocaBlock());
574 // There were no value operands to the allocmem so we are safe to insert it
575 // as early as we want
577 // handle openmp case
578 if (oldOmpRegion)
579 return checkReturn(oldOmpRegion.getAllocaBlock());
581 // fall back to the function entry block
582 mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>();
583 assert(func && "This analysis is run on func.func");
584 mlir::Block &entryBlock = func.getBlocks().front();
585 LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n");
586 return checkReturn(&entryBlock);
589 InsertionPoint
590 AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) {
591 mlir::Operation *oldAllocOp = oldAlloc;
592 // This is only called as a last resort. We should try to insert at the
593 // location of the old allocation, which is inside of a loop, using
594 // llvm.stacksave/llvm.stackrestore
596 // find freemem ops
597 llvm::SmallVector<mlir::Operation *, 1> freeOps;
598 for (mlir::Operation *user : oldAllocOp->getUsers())
599 if (mlir::isa<fir::FreeMemOp>(user))
600 freeOps.push_back(user);
601 assert(freeOps.size() && "DFA should only return freed memory");
603 // Don't attempt to reason about a stacksave/stackrestore between different
604 // blocks
605 for (mlir::Operation *free : freeOps)
606 if (free->getBlock() != oldAllocOp->getBlock())
607 return {nullptr};
609 // Check that there aren't any other stack allocations in between the
610 // stack save and stack restore
611 // note: for flang generated temporaries there should only be one free op
612 for (mlir::Operation *free : freeOps) {
613 for (mlir::Operation *op = oldAlloc; op && op != free;
614 op = op->getNextNode()) {
615 if (mlir::isa<fir::AllocaOp>(op))
616 return {nullptr};
620 return InsertionPoint{oldAllocOp, /*shouldStackSaveRestore=*/true};
623 std::optional<fir::AllocaOp>
624 AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc,
625 mlir::PatternRewriter &rewriter) const {
626 auto it = candidateOps.find(oldAlloc.getOperation());
627 if (it == candidateOps.end())
628 return {};
629 InsertionPoint insertionPoint = it->second;
630 if (!insertionPoint)
631 return {};
633 if (insertionPoint.shouldSaveRestoreStack())
634 insertStackSaveRestore(oldAlloc, rewriter);
636 mlir::Location loc = oldAlloc.getLoc();
637 mlir::Type varTy = oldAlloc.getInType();
638 if (mlir::Operation *op = insertionPoint.tryGetOperation()) {
639 rewriter.setInsertionPointAfter(op);
640 } else {
641 mlir::Block *block = insertionPoint.tryGetBlock();
642 assert(block && "There must be a valid insertion point");
643 rewriter.setInsertionPointToStart(block);
646 auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef {
647 if (opt)
648 return *opt;
649 return {};
652 llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName());
653 llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName());
654 return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName,
655 oldAlloc.getTypeparams(),
656 oldAlloc.getShape());
659 void AllocMemConversion::insertStackSaveRestore(
660 fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const {
661 auto oldPoint = rewriter.saveInsertionPoint();
662 auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>();
663 fir::KindMapping kindMap = fir::getKindMapping(mod);
664 fir::FirOpBuilder builder{rewriter, kindMap};
666 mlir::func::FuncOp stackSaveFn = fir::factory::getLlvmStackSave(builder);
667 mlir::SymbolRefAttr stackSaveSym =
668 builder.getSymbolRefAttr(stackSaveFn.getName());
670 builder.setInsertionPoint(oldAlloc);
671 mlir::Value sp =
672 builder
673 .create<fir::CallOp>(oldAlloc.getLoc(),
674 stackSaveFn.getFunctionType().getResults(),
675 stackSaveSym, mlir::ValueRange{})
676 .getResult(0);
678 mlir::func::FuncOp stackRestoreFn =
679 fir::factory::getLlvmStackRestore(builder);
680 mlir::SymbolRefAttr stackRestoreSym =
681 builder.getSymbolRefAttr(stackRestoreFn.getName());
683 for (mlir::Operation *user : oldAlloc->getUsers()) {
684 if (mlir::isa<fir::FreeMemOp>(user)) {
685 builder.setInsertionPoint(user);
686 builder.create<fir::CallOp>(user->getLoc(),
687 stackRestoreFn.getFunctionType().getResults(),
688 stackRestoreSym, mlir::ValueRange{sp});
692 rewriter.restoreInsertionPoint(oldPoint);
695 StackArraysPass::StackArraysPass(const StackArraysPass &pass)
696 : fir::impl::StackArraysBase<StackArraysPass>(pass) {}
698 llvm::StringRef StackArraysPass::getDescription() const {
699 return "Move heap allocated array temporaries to the stack";
702 void StackArraysPass::runOnOperation() {
703 mlir::ModuleOp mod = getOperation();
705 mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
708 void StackArraysPass::runOnFunc(mlir::Operation *func) {
709 assert(mlir::isa<mlir::func::FuncOp>(func));
711 auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
712 const auto &candidateOps = analysis.getCandidateOps(func);
713 if (analysis.hasErrors()) {
714 signalPassFailure();
715 return;
718 if (candidateOps.empty())
719 return;
720 runCount += candidateOps.size();
722 mlir::MLIRContext &context = getContext();
723 mlir::RewritePatternSet patterns(&context);
724 mlir::ConversionTarget target(context);
726 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
727 mlir::func::FuncDialect>();
728 target.addDynamicallyLegalOp<fir::AllocMemOp>([&](fir::AllocMemOp alloc) {
729 return !candidateOps.count(alloc.getOperation());
732 patterns.insert<AllocMemConversion>(&context, candidateOps);
733 if (mlir::failed(
734 mlir::applyPartialConversion(func, target, std::move(patterns)))) {
735 mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
736 signalPassFailure();
740 std::unique_ptr<mlir::Pass> fir::createStackArraysPass() {
741 return std::make_unique<StackArraysPass>();