1 //===- StackArrays.cpp ----------------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
11 #include "flang/Optimizer/Dialect/FIRAttr.h"
12 #include "flang/Optimizer/Dialect/FIRDialect.h"
13 #include "flang/Optimizer/Dialect/FIROps.h"
14 #include "flang/Optimizer/Dialect/FIRType.h"
15 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
16 #include "flang/Optimizer/Transforms/Passes.h"
17 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
18 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
19 #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
20 #include "mlir/Analysis/DataFlowFramework.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"
22 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/Diagnostics.h"
25 #include "mlir/IR/Value.h"
26 #include "mlir/Interfaces/LoopLikeInterface.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Support/LogicalResult.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "mlir/Transforms/Passes.h"
31 #include "llvm/ADT/DenseMap.h"
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/PointerUnion.h"
34 #include "llvm/Support/Casting.h"
35 #include "llvm/Support/raw_ostream.h"
39 #define GEN_PASS_DEF_STACKARRAYS
40 #include "flang/Optimizer/Transforms/Passes.h.inc"
43 #define DEBUG_TYPE "stack-arrays"
/// The state of an SSA value at each program point
enum class AllocationState {
  /// This means that the allocation state of a variable cannot be determined
  /// at this program point, e.g. because one route through a conditional freed
  /// the variable and the other route didn't.
  /// This asserts a known-unknown: different from the unknown-unknown of having
  /// no AllocationState stored for a particular SSA value
  Unknown,
  /// Means this SSA value was allocated on the heap in this function and has
  /// since been freed
  Freed,
  /// Means this SSA value was allocated on the heap in this function and is a
  /// candidate for moving to the stack
  Allocated,
};
63 /// Stores where an alloca should be inserted. If the PointerUnion is an
64 /// Operation the alloca should be inserted /after/ the operation. If it is a
65 /// block, the alloca can be placed anywhere in that block.
66 class InsertionPoint
{
67 llvm::PointerUnion
<mlir::Operation
*, mlir::Block
*> location
;
68 bool saveRestoreStack
;
70 /// Get contained pointer type or nullptr
72 T
*tryGetPtr() const {
73 if (location
.is
<T
*>())
74 return location
.get
<T
*>();
80 InsertionPoint(T
*ptr
, bool saveRestoreStack
= false)
81 : location(ptr
), saveRestoreStack
{saveRestoreStack
} {}
82 InsertionPoint(std::nullptr_t null
)
83 : location(null
), saveRestoreStack
{false} {}
85 /// Get contained operation, or nullptr
86 mlir::Operation
*tryGetOperation() const {
87 return tryGetPtr
<mlir::Operation
>();
90 /// Get contained block, or nullptr
91 mlir::Block
*tryGetBlock() const { return tryGetPtr
<mlir::Block
>(); }
93 /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave
94 /// intrinsic should be added before the alloca, and an llvm.stackrestore
95 /// intrinsic should be added where the freemem is
96 bool shouldSaveRestoreStack() const { return saveRestoreStack
; }
98 operator bool() const { return tryGetOperation() || tryGetBlock(); }
100 bool operator==(const InsertionPoint
&rhs
) const {
101 return (location
== rhs
.location
) &&
102 (saveRestoreStack
== rhs
.saveRestoreStack
);
105 bool operator!=(const InsertionPoint
&rhs
) const { return !(*this == rhs
); }
108 /// Maps SSA values to their AllocationState at a particular program point.
109 /// Also caches the insertion points for the new alloca operations
110 class LatticePoint
: public mlir::dataflow::AbstractDenseLattice
{
111 // Maps all values we are interested in to states
112 llvm::SmallDenseMap
<mlir::Value
, AllocationState
, 1> stateMap
;
115 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint
)
116 using AbstractDenseLattice::AbstractDenseLattice
;
118 bool operator==(const LatticePoint
&rhs
) const {
119 return stateMap
== rhs
.stateMap
;
122 /// Join the lattice accross control-flow edges
123 mlir::ChangeResult
join(const AbstractDenseLattice
&lattice
) override
;
125 void print(llvm::raw_ostream
&os
) const override
;
127 /// Clear all modifications
128 mlir::ChangeResult
reset();
130 /// Set the state of an SSA value
131 mlir::ChangeResult
set(mlir::Value value
, AllocationState state
);
133 /// Get fir.allocmem ops which were allocated in this function and always
134 /// freed before the function returns, plus whre to insert replacement
136 void appendFreedValues(llvm::DenseSet
<mlir::Value
> &out
) const;
138 std::optional
<AllocationState
> get(mlir::Value val
) const;
141 class AllocationAnalysis
142 : public mlir::dataflow::DenseDataFlowAnalysis
<LatticePoint
> {
144 using DenseDataFlowAnalysis::DenseDataFlowAnalysis
;
146 void visitOperation(mlir::Operation
*op
, const LatticePoint
&before
,
147 LatticePoint
*after
) override
;
149 /// At an entry point, the last modifications of all memory resources are
150 /// yet to be determined
151 void setToEntryState(LatticePoint
*lattice
) override
;
154 /// Visit control flow operations and decide whether to call visitOperation
155 /// to apply the transfer function
156 void processOperation(mlir::Operation
*op
) override
;
159 /// Drives analysis to find candidate fir.allocmem operations which could be
160 /// moved to the stack. Intended to be used with mlir::Pass::getAnalysis
161 class StackArraysAnalysisWrapper
{
163 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper
)
165 // Maps fir.allocmem -> place to insert alloca
166 using AllocMemMap
= llvm::DenseMap
<mlir::Operation
*, InsertionPoint
>;
168 StackArraysAnalysisWrapper(mlir::Operation
*op
) {}
170 bool hasErrors() const;
172 const AllocMemMap
&getCandidateOps(mlir::Operation
*func
);
175 llvm::DenseMap
<mlir::Operation
*, AllocMemMap
> funcMaps
;
176 bool gotError
= false;
178 void analyseFunction(mlir::Operation
*func
);
181 /// Converts a fir.allocmem to a fir.alloca
182 class AllocMemConversion
: public mlir::OpRewritePattern
<fir::AllocMemOp
> {
184 using OpRewritePattern::OpRewritePattern
;
187 mlir::MLIRContext
*ctx
,
188 const llvm::DenseMap
<mlir::Operation
*, InsertionPoint
> &candidateOps
);
191 matchAndRewrite(fir::AllocMemOp allocmem
,
192 mlir::PatternRewriter
&rewriter
) const override
;
194 /// Determine where to insert the alloca operation. The returned value should
195 /// be checked to see if it is inside a loop
196 static InsertionPoint
findAllocaInsertionPoint(fir::AllocMemOp
&oldAlloc
);
199 /// allocmem operations that DFA has determined are safe to move to the stack
200 /// mapping to where to insert replacement freemem operations
201 const llvm::DenseMap
<mlir::Operation
*, InsertionPoint
> &candidateOps
;
203 /// If we failed to find an insertion point not inside a loop, see if it would
204 /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
205 static InsertionPoint
findAllocaLoopInsertionPoint(fir::AllocMemOp
&oldAlloc
);
207 /// Returns the alloca if it was successfully inserted, otherwise {}
208 std::optional
<fir::AllocaOp
>
209 insertAlloca(fir::AllocMemOp
&oldAlloc
,
210 mlir::PatternRewriter
&rewriter
) const;
212 /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem
213 void insertStackSaveRestore(fir::AllocMemOp
&oldAlloc
,
214 mlir::PatternRewriter
&rewriter
) const;
217 class StackArraysPass
: public fir::impl::StackArraysBase
<StackArraysPass
> {
219 StackArraysPass() = default;
220 StackArraysPass(const StackArraysPass
&pass
);
222 llvm::StringRef
getDescription() const override
;
224 void runOnOperation() override
;
225 void runOnFunc(mlir::Operation
*func
);
228 Statistic runCount
{this, "stackArraysRunCount",
229 "Number of heap allocations moved to the stack"};
234 static void print(llvm::raw_ostream
&os
, AllocationState state
) {
236 case AllocationState::Unknown
:
239 case AllocationState::Freed
:
242 case AllocationState::Allocated
:
248 /// Join two AllocationStates for the same value coming from different CFG
250 static AllocationState
join(AllocationState lhs
, AllocationState rhs
) {
251 // | Allocated | Freed | Unknown
252 // ========= | ========= | ========= | =========
253 // Allocated | Allocated | Unknown | Unknown
254 // Freed | Unknown | Freed | Unknown
255 // Unknown | Unknown | Unknown | Unknown
258 return AllocationState::Unknown
;
261 mlir::ChangeResult
LatticePoint::join(const AbstractDenseLattice
&lattice
) {
262 const auto &rhs
= static_cast<const LatticePoint
&>(lattice
);
263 mlir::ChangeResult changed
= mlir::ChangeResult::NoChange
;
265 // add everything from rhs to map, handling cases where values are in both
266 for (const auto &[value
, rhsState
] : rhs
.stateMap
) {
267 auto it
= stateMap
.find(value
);
268 if (it
!= stateMap
.end()) {
269 // value is present in both maps
270 AllocationState myState
= it
->second
;
271 AllocationState newState
= ::join(myState
, rhsState
);
272 if (newState
!= myState
) {
273 changed
= mlir::ChangeResult::Change
;
274 it
->getSecond() = newState
;
277 // value not present in current map: add it
278 stateMap
.insert({value
, rhsState
});
279 changed
= mlir::ChangeResult::Change
;
286 void LatticePoint::print(llvm::raw_ostream
&os
) const {
287 for (const auto &[value
, state
] : stateMap
) {
293 mlir::ChangeResult
LatticePoint::reset() {
294 if (stateMap
.empty())
295 return mlir::ChangeResult::NoChange
;
297 return mlir::ChangeResult::Change
;
300 mlir::ChangeResult
LatticePoint::set(mlir::Value value
, AllocationState state
) {
301 if (stateMap
.count(value
)) {
303 AllocationState
&oldState
= stateMap
[value
];
304 if (oldState
!= state
) {
305 stateMap
[value
] = state
;
306 return mlir::ChangeResult::Change
;
308 return mlir::ChangeResult::NoChange
;
310 stateMap
.insert({value
, state
});
311 return mlir::ChangeResult::Change
;
314 /// Get values which were allocated in this function and always freed before
315 /// the function returns
316 void LatticePoint::appendFreedValues(llvm::DenseSet
<mlir::Value
> &out
) const {
317 for (auto &[value
, state
] : stateMap
) {
318 if (state
== AllocationState::Freed
)
323 std::optional
<AllocationState
> LatticePoint::get(mlir::Value val
) const {
324 auto it
= stateMap
.find(val
);
325 if (it
== stateMap
.end())
330 void AllocationAnalysis::visitOperation(mlir::Operation
*op
,
331 const LatticePoint
&before
,
332 LatticePoint
*after
) {
333 LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
335 LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before
<< "\n");
337 // propagate before -> after
338 mlir::ChangeResult changed
= after
->join(before
);
340 if (auto allocmem
= mlir::dyn_cast
<fir::AllocMemOp
>(op
)) {
341 assert(op
->getNumResults() == 1 && "fir.allocmem has one result");
342 auto attr
= op
->getAttrOfType
<fir::MustBeHeapAttr
>(
343 fir::MustBeHeapAttr::getAttrName());
344 if (attr
&& attr
.getValue()) {
345 LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n");
346 // skip allocation marked not to be moved
350 auto retTy
= allocmem
.getAllocatedType();
351 if (!retTy
.isa
<fir::SequenceType
>()) {
352 LLVM_DEBUG(llvm::dbgs()
353 << "--Allocation is not for an array: skipping\n");
357 mlir::Value result
= op
->getResult(0);
358 changed
|= after
->set(result
, AllocationState::Allocated
);
359 } else if (mlir::isa
<fir::FreeMemOp
>(op
)) {
360 assert(op
->getNumOperands() == 1 && "fir.freemem has one operand");
361 mlir::Value operand
= op
->getOperand(0);
362 std::optional
<AllocationState
> operandState
= before
.get(operand
);
363 if (operandState
&& *operandState
== AllocationState::Allocated
) {
364 // don't tag things not allocated in this function as freed, so that we
365 // don't think they are candidates for moving to the stack
366 changed
|= after
->set(operand
, AllocationState::Freed
);
368 } else if (mlir::isa
<fir::ResultOp
>(op
)) {
369 mlir::Operation
*parent
= op
->getParentOp();
370 LatticePoint
*parentLattice
= getLattice(parent
);
371 assert(parentLattice
);
372 mlir::ChangeResult parentChanged
= parentLattice
->join(*after
);
373 propagateIfChanged(parentLattice
, parentChanged
);
376 // we pass lattices straight through fir.call because called functions should
377 // not deallocate flang-generated array temporaries
379 LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after
<< "\n");
380 propagateIfChanged(after
, changed
);
383 void AllocationAnalysis::setToEntryState(LatticePoint
*lattice
) {
384 propagateIfChanged(lattice
, lattice
->reset());
387 /// Mostly a copy of AbstractDenseLattice::processOperation - the difference
388 /// being that call operations are passed through to the transfer function
389 void AllocationAnalysis::processOperation(mlir::Operation
*op
) {
390 // If the containing block is not executable, bail out.
391 if (!getOrCreateFor
<mlir::dataflow::Executable
>(op
, op
->getBlock())->isLive())
394 // Get the dense lattice to update
395 mlir::dataflow::AbstractDenseLattice
*after
= getLattice(op
);
397 // If this op implements region control-flow, then control-flow dictates its
398 // transfer function.
399 if (auto branch
= mlir::dyn_cast
<mlir::RegionBranchOpInterface
>(op
))
400 return visitRegionBranchOperation(op
, branch
, after
);
402 // pass call operations through to the transfer function
404 // Get the dense state before the execution of the op.
405 const mlir::dataflow::AbstractDenseLattice
*before
;
406 if (mlir::Operation
*prev
= op
->getPrevNode())
407 before
= getLatticeFor(op
, prev
);
409 before
= getLatticeFor(op
, op
->getBlock());
411 /// Invoke the operation transfer function
412 visitOperationImpl(op
, *before
, after
);
415 void StackArraysAnalysisWrapper::analyseFunction(mlir::Operation
*func
) {
416 assert(mlir::isa
<mlir::func::FuncOp
>(func
));
417 mlir::DataFlowSolver solver
;
418 // constant propagation is required for dead code analysis, dead code analysis
419 // is required to mark blocks live (required for mlir dense dfa)
420 solver
.load
<mlir::dataflow::SparseConstantPropagation
>();
421 solver
.load
<mlir::dataflow::DeadCodeAnalysis
>();
423 auto [it
, inserted
] = funcMaps
.try_emplace(func
);
424 AllocMemMap
&candidateOps
= it
->second
;
426 solver
.load
<AllocationAnalysis
>();
427 if (failed(solver
.initializeAndRun(func
))) {
428 llvm::errs() << "DataFlowSolver failed!";
433 LatticePoint point
{func
};
434 auto joinOperationLattice
= [&](mlir::Operation
*op
) {
435 const LatticePoint
*lattice
= solver
.lookupState
<LatticePoint
>(op
);
436 // there will be no lattice for an unreachable block
438 point
.join(*lattice
);
440 func
->walk([&](mlir::func::ReturnOp child
) { joinOperationLattice(child
); });
441 func
->walk([&](fir::UnreachableOp child
) { joinOperationLattice(child
); });
442 llvm::DenseSet
<mlir::Value
> freedValues
;
443 point
.appendFreedValues(freedValues
);
445 // We only replace allocations which are definately freed on all routes
446 // through the function because otherwise the allocation may have an intende
447 // lifetime longer than the current stack frame (e.g. a heap allocation which
448 // is then freed by another function).
449 for (mlir::Value freedValue
: freedValues
) {
450 fir::AllocMemOp allocmem
= freedValue
.getDefiningOp
<fir::AllocMemOp
>();
451 InsertionPoint insertionPoint
=
452 AllocMemConversion::findAllocaInsertionPoint(allocmem
);
454 candidateOps
.insert({allocmem
, insertionPoint
});
457 LLVM_DEBUG(for (auto [allocMemOp
, _
]
459 llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp
<< '\n';
463 bool StackArraysAnalysisWrapper::hasErrors() const { return gotError
; }
465 const StackArraysAnalysisWrapper::AllocMemMap
&
466 StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation
*func
) {
467 if (!funcMaps
.count(func
))
468 analyseFunction(func
);
469 return funcMaps
[func
];
472 AllocMemConversion::AllocMemConversion(
473 mlir::MLIRContext
*ctx
,
474 const llvm::DenseMap
<mlir::Operation
*, InsertionPoint
> &candidateOps
)
475 : OpRewritePattern(ctx
), candidateOps(candidateOps
) {}
478 AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem
,
479 mlir::PatternRewriter
&rewriter
) const {
480 auto oldInsertionPt
= rewriter
.saveInsertionPoint();
481 // add alloca operation
482 std::optional
<fir::AllocaOp
> alloca
= insertAlloca(allocmem
, rewriter
);
483 rewriter
.restoreInsertionPoint(oldInsertionPt
);
485 return mlir::failure();
487 // remove freemem operations
488 for (mlir::Operation
*user
: allocmem
.getOperation()->getUsers())
489 if (mlir::isa
<fir::FreeMemOp
>(user
))
490 rewriter
.eraseOp(user
);
492 // replace references to heap allocation with references to stack allocation
493 rewriter
.replaceAllUsesWith(allocmem
.getResult(), alloca
->getResult());
495 // remove allocmem operation
496 rewriter
.eraseOp(allocmem
.getOperation());
498 return mlir::success();
501 static bool isInLoop(mlir::Block
*block
) {
502 return mlir::LoopLikeOpInterface::blockIsInLoop(block
);
505 static bool isInLoop(mlir::Operation
*op
) {
506 return isInLoop(op
->getBlock()) ||
507 op
->getParentOfType
<mlir::LoopLikeOpInterface
>();
511 AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp
&oldAlloc
) {
512 // Ideally the alloca should be inserted at the end of the function entry
513 // block so that we do not allocate stack space in a loop. However,
514 // the operands to the alloca may not be available that early, so insert it
515 // after the last operand becomes available
516 // If the old allocmem op was in an openmp region then it should not be moved
518 LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: "
519 << oldAlloc
<< "\n");
521 // check that an Operation or Block we are about to return is not in a loop
522 auto checkReturn
= [&](auto *point
) -> InsertionPoint
{
523 if (isInLoop(point
)) {
524 mlir::Operation
*oldAllocOp
= oldAlloc
.getOperation();
525 if (isInLoop(oldAllocOp
)) {
526 // where we want to put it is in a loop, and even the old location is in
528 return findAllocaLoopInsertionPoint(oldAlloc
);
536 oldAlloc
->getParentOfType
<mlir::omp::OutlineableOpenMPOpInterface
>();
538 // Find when the last operand value becomes available
539 mlir::Block
*operandsBlock
= nullptr;
540 mlir::Operation
*lastOperand
= nullptr;
541 for (mlir::Value operand
: oldAlloc
.getOperands()) {
542 LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand
<< "\n");
543 mlir::Operation
*op
= operand
.getDefiningOp();
545 return checkReturn(oldAlloc
.getOperation());
547 operandsBlock
= op
->getBlock();
548 else if (operandsBlock
!= op
->getBlock()) {
549 LLVM_DEBUG(llvm::dbgs()
550 << "----operand declared in a different block!\n");
551 // Operation::isBeforeInBlock requires the operations to be in the same
552 // block. The best we can do is the location of the allocmem.
553 return checkReturn(oldAlloc
.getOperation());
555 if (!lastOperand
|| lastOperand
->isBeforeInBlock(op
))
560 // there were value operands to the allocmem so insert after the last one
561 LLVM_DEBUG(llvm::dbgs()
562 << "--Placing after last operand: " << *lastOperand
<< "\n");
563 // check we aren't moving out of an omp region
564 auto lastOpOmpRegion
=
565 lastOperand
->getParentOfType
<mlir::omp::OutlineableOpenMPOpInterface
>();
566 if (lastOpOmpRegion
== oldOmpRegion
)
567 return checkReturn(lastOperand
);
568 // Presumably this happened because the operands became ready before the
569 // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should
570 // imply that oldOmpRegion comes after lastOpOmpRegion.
571 return checkReturn(oldOmpRegion
.getAllocaBlock());
574 // There were no value operands to the allocmem so we are safe to insert it
575 // as early as we want
577 // handle openmp case
579 return checkReturn(oldOmpRegion
.getAllocaBlock());
581 // fall back to the function entry block
582 mlir::func::FuncOp func
= oldAlloc
->getParentOfType
<mlir::func::FuncOp
>();
583 assert(func
&& "This analysis is run on func.func");
584 mlir::Block
&entryBlock
= func
.getBlocks().front();
585 LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n");
586 return checkReturn(&entryBlock
);
590 AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp
&oldAlloc
) {
591 mlir::Operation
*oldAllocOp
= oldAlloc
;
592 // This is only called as a last resort. We should try to insert at the
593 // location of the old allocation, which is inside of a loop, using
594 // llvm.stacksave/llvm.stackrestore
597 llvm::SmallVector
<mlir::Operation
*, 1> freeOps
;
598 for (mlir::Operation
*user
: oldAllocOp
->getUsers())
599 if (mlir::isa
<fir::FreeMemOp
>(user
))
600 freeOps
.push_back(user
);
601 assert(freeOps
.size() && "DFA should only return freed memory");
603 // Don't attempt to reason about a stacksave/stackrestore between different
605 for (mlir::Operation
*free
: freeOps
)
606 if (free
->getBlock() != oldAllocOp
->getBlock())
609 // Check that there aren't any other stack allocations in between the
610 // stack save and stack restore
611 // note: for flang generated temporaries there should only be one free op
612 for (mlir::Operation
*free
: freeOps
) {
613 for (mlir::Operation
*op
= oldAlloc
; op
&& op
!= free
;
614 op
= op
->getNextNode()) {
615 if (mlir::isa
<fir::AllocaOp
>(op
))
620 return InsertionPoint
{oldAllocOp
, /*shouldStackSaveRestore=*/true};
623 std::optional
<fir::AllocaOp
>
624 AllocMemConversion::insertAlloca(fir::AllocMemOp
&oldAlloc
,
625 mlir::PatternRewriter
&rewriter
) const {
626 auto it
= candidateOps
.find(oldAlloc
.getOperation());
627 if (it
== candidateOps
.end())
629 InsertionPoint insertionPoint
= it
->second
;
633 if (insertionPoint
.shouldSaveRestoreStack())
634 insertStackSaveRestore(oldAlloc
, rewriter
);
636 mlir::Location loc
= oldAlloc
.getLoc();
637 mlir::Type varTy
= oldAlloc
.getInType();
638 if (mlir::Operation
*op
= insertionPoint
.tryGetOperation()) {
639 rewriter
.setInsertionPointAfter(op
);
641 mlir::Block
*block
= insertionPoint
.tryGetBlock();
642 assert(block
&& "There must be a valid insertion point");
643 rewriter
.setInsertionPointToStart(block
);
646 auto unpackName
= [](std::optional
<llvm::StringRef
> opt
) -> llvm::StringRef
{
652 llvm::StringRef uniqName
= unpackName(oldAlloc
.getUniqName());
653 llvm::StringRef bindcName
= unpackName(oldAlloc
.getBindcName());
654 return rewriter
.create
<fir::AllocaOp
>(loc
, varTy
, uniqName
, bindcName
,
655 oldAlloc
.getTypeparams(),
656 oldAlloc
.getShape());
659 void AllocMemConversion::insertStackSaveRestore(
660 fir::AllocMemOp
&oldAlloc
, mlir::PatternRewriter
&rewriter
) const {
661 auto oldPoint
= rewriter
.saveInsertionPoint();
662 auto mod
= oldAlloc
->getParentOfType
<mlir::ModuleOp
>();
663 fir::KindMapping kindMap
= fir::getKindMapping(mod
);
664 fir::FirOpBuilder builder
{rewriter
, kindMap
};
666 mlir::func::FuncOp stackSaveFn
= fir::factory::getLlvmStackSave(builder
);
667 mlir::SymbolRefAttr stackSaveSym
=
668 builder
.getSymbolRefAttr(stackSaveFn
.getName());
670 builder
.setInsertionPoint(oldAlloc
);
673 .create
<fir::CallOp
>(oldAlloc
.getLoc(),
674 stackSaveFn
.getFunctionType().getResults(),
675 stackSaveSym
, mlir::ValueRange
{})
678 mlir::func::FuncOp stackRestoreFn
=
679 fir::factory::getLlvmStackRestore(builder
);
680 mlir::SymbolRefAttr stackRestoreSym
=
681 builder
.getSymbolRefAttr(stackRestoreFn
.getName());
683 for (mlir::Operation
*user
: oldAlloc
->getUsers()) {
684 if (mlir::isa
<fir::FreeMemOp
>(user
)) {
685 builder
.setInsertionPoint(user
);
686 builder
.create
<fir::CallOp
>(user
->getLoc(),
687 stackRestoreFn
.getFunctionType().getResults(),
688 stackRestoreSym
, mlir::ValueRange
{sp
});
692 rewriter
.restoreInsertionPoint(oldPoint
);
695 StackArraysPass::StackArraysPass(const StackArraysPass
&pass
)
696 : fir::impl::StackArraysBase
<StackArraysPass
>(pass
) {}
698 llvm::StringRef
StackArraysPass::getDescription() const {
699 return "Move heap allocated array temporaries to the stack";
702 void StackArraysPass::runOnOperation() {
703 mlir::ModuleOp mod
= getOperation();
705 mod
.walk([this](mlir::func::FuncOp func
) { runOnFunc(func
); });
708 void StackArraysPass::runOnFunc(mlir::Operation
*func
) {
709 assert(mlir::isa
<mlir::func::FuncOp
>(func
));
711 auto &analysis
= getAnalysis
<StackArraysAnalysisWrapper
>();
712 const auto &candidateOps
= analysis
.getCandidateOps(func
);
713 if (analysis
.hasErrors()) {
718 if (candidateOps
.empty())
720 runCount
+= candidateOps
.size();
722 mlir::MLIRContext
&context
= getContext();
723 mlir::RewritePatternSet
patterns(&context
);
724 mlir::ConversionTarget
target(context
);
726 target
.addLegalDialect
<fir::FIROpsDialect
, mlir::arith::ArithDialect
,
727 mlir::func::FuncDialect
>();
728 target
.addDynamicallyLegalOp
<fir::AllocMemOp
>([&](fir::AllocMemOp alloc
) {
729 return !candidateOps
.count(alloc
.getOperation());
732 patterns
.insert
<AllocMemConversion
>(&context
, candidateOps
);
734 mlir::applyPartialConversion(func
, target
, std::move(patterns
)))) {
735 mlir::emitError(func
->getLoc(), "error in stack arrays optimization\n");
740 std::unique_ptr
<mlir::Pass
> fir::createStackArraysPass() {
741 return std::make_unique
<StackArraysPass
>();