//===- StackArrays.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h" // for the llvm::cl::opt declared below
#include "llvm/Support/raw_ostream.h"
#include <optional>

namespace fir {
#define GEN_PASS_DEF_STACKARRAYS
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

#define DEBUG_TYPE "stack-arrays"
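
// Note: this limit is a hidden llvm::cl option, so tools that forward
// LLVM/MLIR command-line options can override it, e.g. with
// --stack-arrays-max-allocs=0 to remove the limit entirely.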
static llvm::cl::opt<std::size_t> maxAllocsPerFunc(
    "stack-arrays-max-allocs",
    llvm::cl::desc("The maximum number of heap allocations to consider in one "
                   "function before skipping (to save compilation time). Set "
                   "to 0 for no limit."),
    llvm::cl::init(1000), llvm::cl::Hidden);

namespace {

/// The state of an SSA value at each program point
enum class AllocationState {
  /// This means that the allocation state of a variable cannot be determined
  /// at this program point, e.g. because one route through a conditional freed
  /// the variable and the other route didn't.
  /// This asserts a known-unknown: different from the unknown-unknown of
  /// having no AllocationState stored for a particular SSA value
  Unknown,
  /// Means this SSA value was allocated on the heap in this function and has
  /// now been freed
  Freed,
  /// Means this SSA value was allocated on the heap in this function and is a
  /// candidate for moving to the stack
  Allocated,
};

/// Stores where an alloca should be inserted. If the PointerUnion is an
/// Operation the alloca should be inserted /after/ the operation. If it is a
/// block, the alloca can be placed anywhere in that block.
class InsertionPoint {
  llvm::PointerUnion<mlir::Operation *, mlir::Block *> location;
  bool saveRestoreStack;

  /// Get contained pointer type or nullptr
  template <class T>
  T *tryGetPtr() const {
    if (location.is<T *>())
      return location.get<T *>();
    return nullptr;
  }

public:
  template <class T>
  InsertionPoint(T *ptr, bool saveRestoreStack = false)
      : location(ptr), saveRestoreStack{saveRestoreStack} {}
  InsertionPoint(std::nullptr_t null)
      : location(null), saveRestoreStack{false} {}

  /// Get contained operation, or nullptr
  mlir::Operation *tryGetOperation() const {
    return tryGetPtr<mlir::Operation>();
  }

  /// Get contained block, or nullptr
  mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); }

  /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave
  /// intrinsic should be added before the alloca, and an llvm.stackrestore
  /// intrinsic should be added where the freemem is
  bool shouldSaveRestoreStack() const { return saveRestoreStack; }
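
  /// True when an insertion location (an operation or a block) was found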
  operator bool() const { return tryGetOperation() || tryGetBlock(); }

  bool operator==(const InsertionPoint &rhs) const {
    return (location == rhs.location) &&
           (saveRestoreStack == rhs.saveRestoreStack);
  }

  bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); }
};

/// Maps SSA values to their AllocationState at a particular program point.
/// Also caches the insertion points for the new alloca operations
class LatticePoint : public mlir::dataflow::AbstractDenseLattice {
  // Maps all values we are interested in to states
  llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap;

public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint)
  using AbstractDenseLattice::AbstractDenseLattice;

  bool operator==(const LatticePoint &rhs) const {
    return stateMap == rhs.stateMap;
  }

  /// Join the lattice across control-flow edges
  mlir::ChangeResult join(const AbstractDenseLattice &lattice) override;

  void print(llvm::raw_ostream &os) const override;

  /// Clear all modifications
  mlir::ChangeResult reset();

  /// Set the state of an SSA value
  mlir::ChangeResult set(mlir::Value value, AllocationState state);

  /// Get fir.allocmem ops which were allocated in this function and always
  /// freed before the function returns, plus where to insert replacement
  /// fir.alloca ops
  void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const;
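
  /// Get the state of a value, or std::nullopt if it is not tracked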
  std::optional<AllocationState> get(mlir::Value val) const;
};

class AllocationAnalysis
    : public mlir::dataflow::DenseForwardDataFlowAnalysis<LatticePoint> {
public:
  using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;

  mlir::LogicalResult visitOperation(mlir::Operation *op,
                                     const LatticePoint &before,
                                     LatticePoint *after) override;

  /// At an entry point, the last modifications of all memory resources are
  /// yet to be determined
  void setToEntryState(LatticePoint *lattice) override;

protected:
  /// Visit control flow operations and decide whether to call visitOperation
  /// to apply the transfer function
  mlir::LogicalResult processOperation(mlir::Operation *op) override;
};

/// Drives analysis to find candidate fir.allocmem operations which could be
/// moved to the stack. Intended to be used with mlir::Pass::getAnalysis
class StackArraysAnalysisWrapper {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper)

  // Maps fir.allocmem -> place to insert alloca
  using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>;
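
  // Constructor signature required by mlir::Pass::getAnalysis; the operation
  // is unused because the analysis is run lazily, per function, from
  // getCandidateOps.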
  StackArraysAnalysisWrapper(mlir::Operation *op) {}

  // returns nullptr if analysis failed
  const AllocMemMap *getCandidateOps(mlir::Operation *func);

private:
  llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;

  llvm::LogicalResult analyseFunction(mlir::Operation *func);
};

/// Converts a fir.allocmem to a fir.alloca
class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
public:
  explicit AllocMemConversion(
      mlir::MLIRContext *ctx,
      const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
      : OpRewritePattern(ctx), candidateOps{candidateOps} {}

  llvm::LogicalResult
  matchAndRewrite(fir::AllocMemOp allocmem,
                  mlir::PatternRewriter &rewriter) const override;

  /// Determine where to insert the alloca operation. The returned value should
  /// be checked to see if it is inside a loop
  static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc);

private:
  /// Handle to the DFA (already run)
  const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;

  /// If we failed to find an insertion point not inside a loop, see if it
  /// would be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
  static InsertionPoint findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc);

  /// Returns the alloca if it was successfully inserted, otherwise {}
  std::optional<fir::AllocaOp>
  insertAlloca(fir::AllocMemOp &oldAlloc,
               mlir::PatternRewriter &rewriter) const;

  /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem
  void insertStackSaveRestore(fir::AllocMemOp &oldAlloc,
                              mlir::PatternRewriter &rewriter) const;
};

class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> {
public:
  StackArraysPass() = default;
  StackArraysPass(const StackArraysPass &pass);

  llvm::StringRef getDescription() const override;

  void runOnOperation() override;

private:
  Statistic runCount{this, "stackArraysRunCount",
                     "Number of heap allocations moved to the stack"};
};

} // namespace

static void print(llvm::raw_ostream &os, AllocationState state) {
  switch (state) {
  case AllocationState::Unknown:
    os << "Unknown";
    break;
  case AllocationState::Freed:
    os << "Freed";
    break;
  case AllocationState::Allocated:
    os << "Allocated";
    break;
  }
}

/// Join two AllocationStates for the same value coming from different CFG
/// blocks
static AllocationState join(AllocationState lhs, AllocationState rhs) {
  //           | Allocated | Freed     | Unknown
  // ========= | ========= | ========= | =========
  // Allocated | Allocated | Unknown   | Unknown
  // Freed     | Unknown   | Freed     | Unknown
  // Unknown   | Unknown   | Unknown   | Unknown
  if (lhs == rhs)
    return lhs;
  return AllocationState::Unknown;
}

mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) {
  const auto &rhs = static_cast<const LatticePoint &>(lattice);
  mlir::ChangeResult changed = mlir::ChangeResult::NoChange;

  // add everything from rhs to map, handling cases where values are in both
  for (const auto &[value, rhsState] : rhs.stateMap) {
    auto it = stateMap.find(value);
    if (it != stateMap.end()) {
      // value is present in both maps
      AllocationState myState = it->second;
      AllocationState newState = ::join(myState, rhsState);
      if (newState != myState) {
        changed = mlir::ChangeResult::Change;
        it->getSecond() = newState;
      }
    } else {
      // value not present in current map: add it
      stateMap.insert({value, rhsState});
      changed = mlir::ChangeResult::Change;
    }
  }

  return changed;
}

void LatticePoint::print(llvm::raw_ostream &os) const {
  for (const auto &[value, state] : stateMap) {
    os << "\n * " << value << ": ";
    ::print(os, state);
  }
}

mlir::ChangeResult LatticePoint::reset() {
  if (stateMap.empty())
    return mlir::ChangeResult::NoChange;
  stateMap.clear();
  return mlir::ChangeResult::Change;
}

mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) {
  if (stateMap.count(value)) {
    // already in map
    AllocationState &oldState = stateMap[value];
    if (oldState != state) {
      stateMap[value] = state;
      return mlir::ChangeResult::Change;
    }
    return mlir::ChangeResult::NoChange;
  }
  stateMap.insert({value, state});
  return mlir::ChangeResult::Change;
}

/// Get values which were allocated in this function and always freed before
/// the function returns
void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const {
  for (auto &[value, state] : stateMap) {
    if (state == AllocationState::Freed)
      out.insert(value);
  }
}

std::optional<AllocationState> LatticePoint::get(mlir::Value val) const {
  auto it = stateMap.find(val);
  if (it == stateMap.end())
    return {};
  return it->second;
}

mlir::LogicalResult AllocationAnalysis::visitOperation(
    mlir::Operation *op, const LatticePoint &before, LatticePoint *after) {
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
                          << "\n");
  LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n");

  // propagate before -> after
  mlir::ChangeResult changed = after->join(before);

  if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) {
    assert(op->getNumResults() == 1 && "fir.allocmem has one result");
    auto attr = op->getAttrOfType<fir::MustBeHeapAttr>(
        fir::MustBeHeapAttr::getAttrName());
    if (attr && attr.getValue()) {
      LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n");
      // skip allocation marked not to be moved
      return mlir::success();
    }

    auto retTy = allocmem.getAllocatedType();
    if (!mlir::isa<fir::SequenceType>(retTy)) {
      LLVM_DEBUG(llvm::dbgs()
                 << "--Allocation is not for an array: skipping\n");
      return mlir::success();
    }

    mlir::Value result = op->getResult(0);
    changed |= after->set(result, AllocationState::Allocated);
  } else if (mlir::isa<fir::FreeMemOp>(op)) {
    assert(op->getNumOperands() == 1 && "fir.freemem has one operand");
    mlir::Value operand = op->getOperand(0);

    // Note: StackArrays is scheduled in the pass pipeline after lowering hlfir
    // to fir. Therefore, we only need to handle `fir::DeclareOp`s.
    if (auto declareOp =
            llvm::dyn_cast_if_present<fir::DeclareOp>(operand.getDefiningOp()))
      operand = declareOp.getMemref();

    std::optional<AllocationState> operandState = before.get(operand);
    if (operandState && *operandState == AllocationState::Allocated) {
      // don't tag things not allocated in this function as freed, so that we
      // don't think they are candidates for moving to the stack
      changed |= after->set(operand, AllocationState::Freed);
    }
  } else if (mlir::isa<fir::ResultOp>(op)) {
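    // fir.result terminates a region (e.g. of fir.if or fir.do_loop): join our
    // state into the lattice after the parent operation so that it flows out
    // of the region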
    mlir::Operation *parent = op->getParentOp();
    LatticePoint *parentLattice = getLattice(getProgramPointAfter(parent));
    assert(parentLattice);
    mlir::ChangeResult parentChanged = parentLattice->join(*after);
    propagateIfChanged(parentLattice, parentChanged);
  }

  // we pass lattices straight through fir.call because called functions should
  // not deallocate flang-generated array temporaries

  LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n");
  propagateIfChanged(after, changed);
  return mlir::success();
}

void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {
  propagateIfChanged(lattice, lattice->reset());
}

/// Mostly a copy of AbstractDenseForwardDataFlowAnalysis::processOperation -
/// the difference being that call operations are passed through to the
/// transfer function
mlir::LogicalResult AllocationAnalysis::processOperation(mlir::Operation *op) {
  mlir::ProgramPoint *point = getProgramPointAfter(op);
  // If the containing block is not executable, bail out.
  if (op->getBlock() != nullptr &&
      !getOrCreateFor<mlir::dataflow::Executable>(
           point, getProgramPointBefore(op->getBlock()))
           ->isLive())
    return mlir::success();

  // Get the dense lattice to update
  mlir::dataflow::AbstractDenseLattice *after = getLattice(point);

  // If this op implements region control-flow, then control-flow dictates its
  // transfer function.
  if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
    visitRegionBranchOperation(point, branch, after);
    return mlir::success();
  }

  // pass call operations through to the transfer function

  // Get the dense state before the execution of the op.
  const mlir::dataflow::AbstractDenseLattice *before =
      getLatticeFor(point, getProgramPointBefore(op));

  // Invoke the operation transfer function
  return visitOperationImpl(op, *before, after);
}

llvm::LogicalResult
StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
  assert(mlir::isa<mlir::func::FuncOp>(func));
  size_t nAllocs = 0;
  func->walk([&nAllocs](fir::AllocMemOp) { nAllocs++; });
  // don't bother with the analysis if there are no heap allocations
  if (nAllocs == 0)
    return mlir::success();
  if ((maxAllocsPerFunc != 0) && (nAllocs > maxAllocsPerFunc)) {
    LLVM_DEBUG(llvm::dbgs() << "Skipping stack arrays for function with "
                            << nAllocs << " heap allocations\n");
    return mlir::success();
  }

  mlir::DataFlowSolver solver;
  // constant propagation is required for dead code analysis, and dead code
  // analysis is required to mark blocks live (required for MLIR dense DFA)
  solver.load<mlir::dataflow::SparseConstantPropagation>();
  solver.load<mlir::dataflow::DeadCodeAnalysis>();

  auto [it, inserted] = funcMaps.try_emplace(func);
  AllocMemMap &candidateOps = it->second;

  solver.load<AllocationAnalysis>();
  if (failed(solver.initializeAndRun(func))) {
    llvm::errs() << "DataFlowSolver failed!";
    return mlir::failure();
  }

  LatticePoint point{solver.getProgramPointAfter(func)};
  auto joinOperationLattice = [&](mlir::Operation *op) {
    const LatticePoint *lattice =
        solver.lookupState<LatticePoint>(solver.getProgramPointAfter(op));
    // there will be no lattice for an unreachable block
    if (lattice)
      (void)point.join(*lattice);
  };
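
  // Join the lattices at every exit point of the function. OpenMP region
  // terminators are included as exits because those regions are later
  // outlined into separate functions.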
  func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); });
  func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); });
  func->walk(
      [&](mlir::omp::TerminatorOp child) { joinOperationLattice(child); });
  func->walk([&](mlir::omp::YieldOp child) { joinOperationLattice(child); });

  llvm::DenseSet<mlir::Value> freedValues;
  point.appendFreedValues(freedValues);

  // We only replace allocations which are definitely freed on all routes
  // through the function because otherwise the allocation may have an intended
  // lifetime longer than the current stack frame (e.g. a heap allocation which
  // is then freed by another function).
  for (mlir::Value freedValue : freedValues) {
    fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>();
    InsertionPoint insertionPoint =
        AllocMemConversion::findAllocaInsertionPoint(allocmem);
    if (insertionPoint)
      candidateOps.insert({allocmem, insertionPoint});
  }

  LLVM_DEBUG(for (auto [allocMemOp, _]
                  : candidateOps) {
    llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
  });

  return mlir::success();
}

const StackArraysAnalysisWrapper::AllocMemMap *
StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) {
  if (!funcMaps.contains(func))
    if (mlir::failed(analyseFunction(func)))
      return nullptr;
  return &funcMaps[func];
}

/// Restore the old allocation type expected by existing code
static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter,
                                         const mlir::Location &loc,
                                         mlir::Value heap, mlir::Value stack) {
  mlir::Type heapTy = heap.getType();
  mlir::Type stackTy = stack.getType();

  if (heapTy == stackTy)
    return stack;

  fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy);
  LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy =
      mlir::cast<fir::ReferenceType>(stackTy);
  assert(firHeapTy.getElementType() == firRefTy.getElementType() &&
         "Allocations must have the same type");

  auto insertionPoint = rewriter.saveInsertionPoint();
  rewriter.setInsertionPointAfter(stack.getDefiningOp());
  mlir::Value conv =
      rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult();
  rewriter.restoreInsertionPoint(insertionPoint);
  return conv;
}

llvm::LogicalResult
AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
                                    mlir::PatternRewriter &rewriter) const {
  auto oldInsertionPt = rewriter.saveInsertionPoint();
  // add alloca operation
  std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter);
  rewriter.restoreInsertionPoint(oldInsertionPt);
  if (!alloca)
    return mlir::failure();

  // remove freemem operations
  llvm::SmallVector<mlir::Operation *> erases;
  for (mlir::Operation *user : allocmem.getOperation()->getUsers()) {
    if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
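      // the freemem may act on the fir.declare'd alias of the allocation
      // rather than on the fir.allocmem result itself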
      for (mlir::Operation *user : declareOp->getUsers()) {
        if (mlir::isa<fir::FreeMemOp>(user))
          erases.push_back(user);
      }
    }

    if (mlir::isa<fir::FreeMemOp>(user))
      erases.push_back(user);
  }

  // now we are done iterating the users, it is safe to mutate them
  for (mlir::Operation *erase : erases)
    rewriter.eraseOp(erase);

  // replace references to heap allocation with references to stack allocation
  mlir::Value newValue = convertAllocationType(
      rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult());
  rewriter.replaceAllUsesWith(allocmem.getResult(), newValue);

  // remove allocmem operation
  rewriter.eraseOp(allocmem.getOperation());

  return mlir::success();
}

static bool isInLoop(mlir::Block *block) {
  return mlir::LoopLikeOpInterface::blockIsInLoop(block);
}

static bool isInLoop(mlir::Operation *op) {
  return isInLoop(op->getBlock()) ||
         op->getParentOfType<mlir::LoopLikeOpInterface>();
}

InsertionPoint
AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc) {
  // Ideally the alloca should be inserted at the end of the function entry
  // block so that we do not allocate stack space in a loop. However,
  // the operands to the alloca may not be available that early, so insert it
  // after the last operand becomes available.
  // If the old allocmem op was in an OpenMP region then it should not be moved
  // outside of that region.
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: "
                          << oldAlloc << "\n");

  // check that an Operation or Block we are about to return is not in a loop
  auto checkReturn = [&](auto *point) -> InsertionPoint {
    if (isInLoop(point)) {
      mlir::Operation *oldAllocOp = oldAlloc.getOperation();
      if (isInLoop(oldAllocOp)) {
        // where we want to put it is in a loop, and even the old location is
        // in a loop. Give up.
        return findAllocaLoopInsertionPoint(oldAlloc);
      }
      return {oldAllocOp};
    }
    return {point};
  };

  auto oldOmpRegion =
      oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();

  // Find when the last operand value becomes available
  mlir::Block *operandsBlock = nullptr;
  mlir::Operation *lastOperand = nullptr;
  for (mlir::Value operand : oldAlloc.getOperands()) {
    LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n");
    mlir::Operation *op = operand.getDefiningOp();
    if (!op)
      return checkReturn(oldAlloc.getOperation());
    if (!operandsBlock)
      operandsBlock = op->getBlock();
    else if (operandsBlock != op->getBlock()) {
      LLVM_DEBUG(llvm::dbgs()
                 << "----operand declared in a different block!\n");
      // Operation::isBeforeInBlock requires the operations to be in the same
      // block. The best we can do is the location of the allocmem.
      return checkReturn(oldAlloc.getOperation());
    }
    if (!lastOperand || lastOperand->isBeforeInBlock(op))
      lastOperand = op;
  }

  if (lastOperand) {
    // there were value operands to the allocmem so insert after the last one
    LLVM_DEBUG(llvm::dbgs()
               << "--Placing after last operand: " << *lastOperand << "\n");
    // check we aren't moving out of an omp region
    auto lastOpOmpRegion =
        lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();
    if (lastOpOmpRegion == oldOmpRegion)
      return checkReturn(lastOperand);
    // Presumably this happened because the operands became ready before the
    // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should
    // imply that oldOmpRegion comes after lastOpOmpRegion.
    return checkReturn(oldOmpRegion.getAllocaBlock());
  }

  // There were no value operands to the allocmem so we are safe to insert it
  // as early as we want

  // handle openmp case
  if (oldOmpRegion)
    return checkReturn(oldOmpRegion.getAllocaBlock());

  // fall back to the function entry block
  mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>();
  assert(func && "This analysis is run on func.func");
  mlir::Block &entryBlock = func.getBlocks().front();
  LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n");
  return checkReturn(&entryBlock);
}

InsertionPoint
AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) {
  mlir::Operation *oldAllocOp = oldAlloc;
  // This is only called as a last resort. We should try to insert at the
  // location of the old allocation, which is inside of a loop, using
  // llvm.stacksave/llvm.stackrestore

  // find freemem ops
  llvm::SmallVector<mlir::Operation *, 1> freeOps;

  for (mlir::Operation *user : oldAllocOp->getUsers()) {
    if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
      for (mlir::Operation *user : declareOp->getUsers()) {
        if (mlir::isa<fir::FreeMemOp>(user))
          freeOps.push_back(user);
      }
    }

    if (mlir::isa<fir::FreeMemOp>(user))
      freeOps.push_back(user);
  }

  assert(freeOps.size() && "DFA should only return freed memory");

  // Don't attempt to reason about a stacksave/stackrestore between different
  // blocks
  for (mlir::Operation *free : freeOps)
    if (free->getBlock() != oldAllocOp->getBlock())
      return {nullptr};

  // Check that there aren't any other stack allocations in between the
  // stack save and stack restore
  // note: for flang generated temporaries there should only be one free op
  for (mlir::Operation *free : freeOps) {
    for (mlir::Operation *op = oldAlloc; op && op != free;
         op = op->getNextNode()) {
      if (mlir::isa<fir::AllocaOp>(op))
        return {nullptr};
    }
  }

  return InsertionPoint{oldAllocOp, /*saveRestoreStack=*/true};
}

std::optional<fir::AllocaOp>
AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc,
                                 mlir::PatternRewriter &rewriter) const {
  auto it = candidateOps.find(oldAlloc.getOperation());
  if (it == candidateOps.end())
    return {};
  InsertionPoint insertionPoint = it->second;
  if (!insertionPoint)
    return {};

  if (insertionPoint.shouldSaveRestoreStack())
    insertStackSaveRestore(oldAlloc, rewriter);

  mlir::Location loc = oldAlloc.getLoc();
  mlir::Type varTy = oldAlloc.getInType();
  if (mlir::Operation *op = insertionPoint.tryGetOperation()) {
    rewriter.setInsertionPointAfter(op);
  } else {
    mlir::Block *block = insertionPoint.tryGetBlock();
    assert(block && "There must be a valid insertion point");
    rewriter.setInsertionPointToStart(block);
  }

  auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef {
    if (opt)
      return *opt;
    return {};
  };

  llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName());
  llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName());
  return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName,
                                        oldAlloc.getTypeparams(),
                                        oldAlloc.getShape());
}

void AllocMemConversion::insertStackSaveRestore(
    fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const {
  auto oldPoint = rewriter.saveInsertionPoint();
  auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>();
  fir::FirOpBuilder builder{rewriter, mod};

  builder.setInsertionPoint(oldAlloc);
  mlir::Value sp = builder.genStackSave(oldAlloc.getLoc());

  auto createStackRestoreCall = [&](mlir::Operation *user) {
    builder.setInsertionPoint(user);
    builder.genStackRestore(user->getLoc(), sp);
  };

  for (mlir::Operation *user : oldAlloc->getUsers()) {
    if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
      for (mlir::Operation *user : declareOp->getUsers()) {
        if (mlir::isa<fir::FreeMemOp>(user))
          createStackRestoreCall(user);
      }
    }

    if (mlir::isa<fir::FreeMemOp>(user)) {
      createStackRestoreCall(user);
    }
  }

  rewriter.restoreInsertionPoint(oldPoint);
}

StackArraysPass::StackArraysPass(const StackArraysPass &pass)
    : fir::impl::StackArraysBase<StackArraysPass>(pass) {}

llvm::StringRef StackArraysPass::getDescription() const {
  return "Move heap allocated array temporaries to the stack";
}

void StackArraysPass::runOnOperation() {
  mlir::func::FuncOp func = getOperation();

  auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
  const StackArraysAnalysisWrapper::AllocMemMap *candidateOps =
      analysis.getCandidateOps(func);
  if (!candidateOps) {
    signalPassFailure();
    return;
  }

  if (candidateOps->empty())
    return;
  runCount += candidateOps->size();

  llvm::SmallVector<mlir::Operation *> opsToConvert;
  opsToConvert.reserve(candidateOps->size());
  for (auto [op, _] : *candidateOps)
    opsToConvert.push_back(op);

  mlir::MLIRContext &context = getContext();
  mlir::RewritePatternSet patterns(&context);
  mlir::GreedyRewriteConfig config;
  // prevent the pattern driver from merging blocks
  config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
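  // (merging blocks could invalidate the mlir::Block insertion points cached
  // in candidateOps)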

  patterns.insert<AllocMemConversion>(&context, *candidateOps);
  if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
                                                std::move(patterns), config))) {
    mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
    signalPassFailure();
  }
}