1 //===- ConstantArgumentGlobalisation.cpp ----------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Dialect/FIRDialect.h"
11 #include "flang/Optimizer/Dialect/FIROps.h"
12 #include "flang/Optimizer/Dialect/FIRType.h"
13 #include "flang/Optimizer/Transforms/Passes.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dominance.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 #define GEN_PASS_DEF_CONSTANTARGUMENTGLOBALISATIONOPT
22 #include "flang/Optimizer/Transforms/Passes.h.inc"
25 #define DEBUG_TYPE "flang-constant-argument-globalisation-opt"
28 unsigned uniqueLitId
= 1;
30 class CallOpRewriter
: public mlir::OpRewritePattern
<fir::CallOp
> {
32 const mlir::DominanceInfo
&di
;
35 using OpRewritePattern::OpRewritePattern
;
37 CallOpRewriter(mlir::MLIRContext
*ctx
, const mlir::DominanceInfo
&_di
)
38 : OpRewritePattern(ctx
), di(_di
) {}
41 matchAndRewrite(fir::CallOp callOp
,
42 mlir::PatternRewriter
&rewriter
) const override
{
43 LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp
<< "\n");
44 auto module
= callOp
->getParentOfType
<mlir::ModuleOp
>();
45 bool needUpdate
= false;
46 fir::FirOpBuilder
builder(rewriter
, module
);
47 llvm::SmallVector
<mlir::Value
> newOperands
;
48 llvm::SmallVector
<std::pair
<mlir::Operation
*, mlir::Operation
*>> allocas
;
49 for (const mlir::Value
&a
: callOp
.getArgs()) {
50 auto alloca
= mlir::dyn_cast_or_null
<fir::AllocaOp
>(a
.getDefiningOp());
51 // We can convert arguments that are alloca, and that has
52 // the value by reference attribute. All else is just added
53 // to the argument list.
54 if (!alloca
|| !alloca
->hasAttr(fir::getAdaptToByRefAttrName())) {
55 newOperands
.push_back(a
);
59 mlir::Type varTy
= alloca
.getInType();
60 assert(!fir::hasDynamicSize(varTy
) &&
61 "only expect statically sized scalars to be by value");
63 // Find immediate store with const argument
64 mlir::Operation
*store
= nullptr;
65 for (mlir::Operation
*s
: alloca
->getUsers()) {
66 if (mlir::isa
<fir::StoreOp
>(s
) && di
.dominates(s
, callOp
)) {
67 // We can only deal with ONE store - if already found one,
68 // set to nullptr and exit the loop.
77 // If we didn't find any store, or multiple stores, add argument as is
80 newOperands
.push_back(a
);
84 LLVM_DEBUG(llvm::dbgs() << " found store " << *store
<< "\n");
86 mlir::Operation
*definingOp
= store
->getOperand(0).getDefiningOp();
87 // If not a constant, add to operands and move on.
88 if (!mlir::isa
<mlir::arith::ConstantOp
>(definingOp
)) {
89 // Unable to remove alloca arg
90 newOperands
.push_back(a
);
94 LLVM_DEBUG(llvm::dbgs() << " found define " << *definingOp
<< "\n");
96 std::string globalName
=
97 "_global_const_." + std::to_string(uniqueLitId
++);
98 assert(!builder
.getNamedGlobal(globalName
) &&
99 "We should have a unique name here");
101 if (llvm::none_of(allocas
,
102 [alloca
](auto x
) { return x
.first
== alloca
; })) {
103 allocas
.push_back(std::make_pair(alloca
, store
));
106 auto loc
= callOp
.getLoc();
107 fir::GlobalOp global
= builder
.createGlobalConstant(
108 loc
, varTy
, globalName
,
109 [&](fir::FirOpBuilder
&builder
) {
110 mlir::Operation
*cln
= definingOp
->clone();
113 builder
.createConvert(loc
, varTy
, cln
->getResult(0));
114 builder
.create
<fir::HasValueOp
>(loc
, val
);
116 builder
.createInternalLinkage());
117 mlir::Value addr
= builder
.create
<fir::AddrOfOp
>(loc
, global
.resultType(),
119 newOperands
.push_back(addr
);
124 auto loc
= callOp
.getLoc();
125 llvm::SmallVector
<mlir::Type
> newResultTypes
;
126 newResultTypes
.append(callOp
.getResultTypes().begin(),
127 callOp
.getResultTypes().end());
128 fir::CallOp newOp
= builder
.create
<fir::CallOp
>(
130 callOp
.getCallee().has_value() ? callOp
.getCallee().value()
131 : mlir::SymbolRefAttr
{},
132 newResultTypes
, newOperands
);
133 // Copy all the attributes from the old to new op.
134 newOp
->setAttrs(callOp
->getAttrs());
135 rewriter
.replaceOp(callOp
, newOp
);
137 for (auto a
: allocas
) {
138 if (a
.first
->hasOneUse()) {
139 // If the alloca is only used for a store and the call operand, the
140 // store is no longer required.
141 rewriter
.eraseOp(a
.second
);
142 rewriter
.eraseOp(a
.first
);
145 LLVM_DEBUG(llvm::dbgs() << "global constant for " << callOp
<< " as "
147 return mlir::success();
150 // Failure here just means "we couldn't do the conversion", which is
151 // perfectly acceptable to the upper layers of this function.
152 return mlir::failure();
156 // this pass attempts to convert immediate scalar literals in function calls
157 // to global constants to allow transformations such as Dead Argument
159 class ConstantArgumentGlobalisationOpt
160 : public fir::impl::ConstantArgumentGlobalisationOptBase
<
161 ConstantArgumentGlobalisationOpt
> {
163 ConstantArgumentGlobalisationOpt() = default;
165 void runOnOperation() override
{
166 mlir::ModuleOp mod
= getOperation();
167 mlir::DominanceInfo
*di
= &getAnalysis
<mlir::DominanceInfo
>();
168 auto *context
= &getContext();
169 mlir::RewritePatternSet
patterns(context
);
170 mlir::GreedyRewriteConfig config
;
171 config
.enableRegionSimplification
=
172 mlir::GreedySimplifyRegionLevel::Disabled
;
173 config
.strictMode
= mlir::GreedyRewriteStrictness::ExistingOps
;
175 patterns
.insert
<CallOpRewriter
>(context
, *di
);
177 mlir::applyPatternsGreedily(mod
, std::move(patterns
), config
))) {
178 mlir::emitError(mod
.getLoc(),
179 "error in constant globalisation optimization\n");