1 //===- MemRefDataFlowOpt.cpp - Memory DataFlow Optimization pass ----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>
#include <optional>
#include <vector>
22 #define GEN_PASS_DEF_MEMREFDATAFLOWOPT
23 #include "flang/Optimizer/Transforms/Passes.h.inc"
26 #define DEBUG_TYPE "fir-memref-dataflow-opt"
32 template <typename OpT
>
33 static std::vector
<OpT
> getSpecificUsers(mlir::Value v
) {
35 for (mlir::Operation
*user
: v
.getUsers())
36 if (auto op
= dyn_cast
<OpT
>(user
))
41 /// This is based on MLIR's MemRefDataFlowOpt which is specialized on AffineRead
42 /// and AffineWrite interface
43 template <typename ReadOp
, typename WriteOp
>
44 class LoadStoreForwarding
{
46 LoadStoreForwarding(mlir::DominanceInfo
*di
) : domInfo(di
) {}
48 // FIXME: This algorithm has a bug. It ignores escaping references between a
50 std::optional
<WriteOp
> findStoreToForward(ReadOp loadOp
,
51 std::vector
<WriteOp
> &&storeOps
) {
52 llvm::SmallVector
<WriteOp
> candidateSet
;
54 for (auto storeOp
: storeOps
)
55 if (domInfo
->dominates(storeOp
, loadOp
))
56 candidateSet
.push_back(storeOp
);
58 if (candidateSet
.empty())
61 std::optional
<WriteOp
> nearestStore
;
62 for (auto candidate
: candidateSet
) {
63 auto nearerThan
= [&](WriteOp otherStore
) {
64 if (candidate
== otherStore
)
66 bool rv
= domInfo
->properlyDominates(candidate
, otherStore
);
68 LLVM_DEBUG(llvm::dbgs()
69 << "candidate " << candidate
<< " is not the nearest to "
70 << loadOp
<< " because " << otherStore
<< " is closer\n");
74 if (!llvm::any_of(candidateSet
, nearerThan
)) {
75 nearestStore
= mlir::cast
<WriteOp
>(candidate
);
82 << "load " << loadOp
<< " has " << candidateSet
.size()
83 << " store candidates, but this algorithm can't find a best.\n");
88 std::optional
<ReadOp
> findReadForWrite(WriteOp storeOp
,
89 std::vector
<ReadOp
> &&loadOps
) {
90 for (auto &loadOp
: loadOps
) {
91 if (domInfo
->dominates(storeOp
, loadOp
))
98 mlir::DominanceInfo
*domInfo
;
101 class MemDataFlowOpt
: public fir::impl::MemRefDataFlowOptBase
<MemDataFlowOpt
> {
103 void runOnOperation() override
{
104 mlir::func::FuncOp f
= getOperation();
106 auto *domInfo
= &getAnalysis
<mlir::DominanceInfo
>();
107 LoadStoreForwarding
<fir::LoadOp
, fir::StoreOp
> lsf(domInfo
);
108 f
.walk([&](fir::LoadOp loadOp
) {
109 auto maybeStore
= lsf
.findStoreToForward(
110 loadOp
, getSpecificUsers
<fir::StoreOp
>(loadOp
.getMemref()));
112 auto storeOp
= *maybeStore
;
113 LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f
.getName()
114 << " erasing load " << loadOp
115 << " with value from " << storeOp
<< '\n');
116 loadOp
.getResult().replaceAllUsesWith(storeOp
.getValue());
120 f
.walk([&](fir::AllocaOp alloca
) {
121 for (auto &storeOp
: getSpecificUsers
<fir::StoreOp
>(alloca
.getResult())) {
122 if (!lsf
.findReadForWrite(
123 storeOp
, getSpecificUsers
<fir::LoadOp
>(storeOp
.getMemref()))) {
124 LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f
.getName()
125 << " erasing store " << storeOp
<< '\n');
134 std::unique_ptr
<mlir::Pass
> fir::createMemDataFlowOptPass() {
135 return std::make_unique
<MemDataFlowOpt
>();