DAG: Fix assuming f16 is the only 16-bit fp type in concat vector combine (#121637)
[llvm-project.git] / flang / lib / Optimizer / Transforms / ConstantArgumentGlobalisation.cpp
blob: 562f3058f20f3e06a939fb69dcd240c92f1a09ee
1 //===- ConstantArgumentGlobalisation.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Dialect/FIRDialect.h"
11 #include "flang/Optimizer/Dialect/FIROps.h"
12 #include "flang/Optimizer/Dialect/FIRType.h"
13 #include "flang/Optimizer/Transforms/Passes.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dominance.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 namespace fir {
21 #define GEN_PASS_DEF_CONSTANTARGUMENTGLOBALISATIONOPT
22 #include "flang/Optimizer/Transforms/Passes.h.inc"
23 } // namespace fir
25 #define DEBUG_TYPE "flang-constant-argument-globalisation-opt"
27 namespace {
28 unsigned uniqueLitId = 1;
30 class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
31 protected:
32 const mlir::DominanceInfo &di;
34 public:
35 using OpRewritePattern::OpRewritePattern;
37 CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di)
38 : OpRewritePattern(ctx), di(_di) {}
40 llvm::LogicalResult
41 matchAndRewrite(fir::CallOp callOp,
42 mlir::PatternRewriter &rewriter) const override {
43 LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
44 auto module = callOp->getParentOfType<mlir::ModuleOp>();
45 bool needUpdate = false;
46 fir::FirOpBuilder builder(rewriter, module);
47 llvm::SmallVector<mlir::Value> newOperands;
48 llvm::SmallVector<std::pair<mlir::Operation *, mlir::Operation *>> allocas;
49 for (const mlir::Value &a : callOp.getArgs()) {
50 auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp());
51 // We can convert arguments that are alloca, and that has
52 // the value by reference attribute. All else is just added
53 // to the argument list.
54 if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
55 newOperands.push_back(a);
56 continue;
59 mlir::Type varTy = alloca.getInType();
60 assert(!fir::hasDynamicSize(varTy) &&
61 "only expect statically sized scalars to be by value");
63 // Find immediate store with const argument
64 mlir::Operation *store = nullptr;
65 for (mlir::Operation *s : alloca->getUsers()) {
66 if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp)) {
67 // We can only deal with ONE store - if already found one,
68 // set to nullptr and exit the loop.
69 if (store) {
70 store = nullptr;
71 break;
73 store = s;
77 // If we didn't find any store, or multiple stores, add argument as is
78 // and move on.
79 if (!store) {
80 newOperands.push_back(a);
81 continue;
84 LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n");
86 mlir::Operation *definingOp = store->getOperand(0).getDefiningOp();
87 // If not a constant, add to operands and move on.
88 if (!mlir::isa<mlir::arith::ConstantOp>(definingOp)) {
89 // Unable to remove alloca arg
90 newOperands.push_back(a);
91 continue;
94 LLVM_DEBUG(llvm::dbgs() << " found define " << *definingOp << "\n");
96 std::string globalName =
97 "_global_const_." + std::to_string(uniqueLitId++);
98 assert(!builder.getNamedGlobal(globalName) &&
99 "We should have a unique name here");
101 if (llvm::none_of(allocas,
102 [alloca](auto x) { return x.first == alloca; })) {
103 allocas.push_back(std::make_pair(alloca, store));
106 auto loc = callOp.getLoc();
107 fir::GlobalOp global = builder.createGlobalConstant(
108 loc, varTy, globalName,
109 [&](fir::FirOpBuilder &builder) {
110 mlir::Operation *cln = definingOp->clone();
111 builder.insert(cln);
112 mlir::Value val =
113 builder.createConvert(loc, varTy, cln->getResult(0));
114 builder.create<fir::HasValueOp>(loc, val);
116 builder.createInternalLinkage());
117 mlir::Value addr = builder.create<fir::AddrOfOp>(loc, global.resultType(),
118 global.getSymbol());
119 newOperands.push_back(addr);
120 needUpdate = true;
123 if (needUpdate) {
124 auto loc = callOp.getLoc();
125 llvm::SmallVector<mlir::Type> newResultTypes;
126 newResultTypes.append(callOp.getResultTypes().begin(),
127 callOp.getResultTypes().end());
128 fir::CallOp newOp = builder.create<fir::CallOp>(
129 loc,
130 callOp.getCallee().has_value() ? callOp.getCallee().value()
131 : mlir::SymbolRefAttr{},
132 newResultTypes, newOperands);
133 // Copy all the attributes from the old to new op.
134 newOp->setAttrs(callOp->getAttrs());
135 rewriter.replaceOp(callOp, newOp);
137 for (auto a : allocas) {
138 if (a.first->hasOneUse()) {
139 // If the alloca is only used for a store and the call operand, the
140 // store is no longer required.
141 rewriter.eraseOp(a.second);
142 rewriter.eraseOp(a.first);
145 LLVM_DEBUG(llvm::dbgs() << "global constant for " << callOp << " as "
146 << newOp << '\n');
147 return mlir::success();
150 // Failure here just means "we couldn't do the conversion", which is
151 // perfectly acceptable to the upper layers of this function.
152 return mlir::failure();
156 // this pass attempts to convert immediate scalar literals in function calls
157 // to global constants to allow transformations such as Dead Argument
158 // Elimination
159 class ConstantArgumentGlobalisationOpt
160 : public fir::impl::ConstantArgumentGlobalisationOptBase<
161 ConstantArgumentGlobalisationOpt> {
162 public:
163 ConstantArgumentGlobalisationOpt() = default;
165 void runOnOperation() override {
166 mlir::ModuleOp mod = getOperation();
167 mlir::DominanceInfo *di = &getAnalysis<mlir::DominanceInfo>();
168 auto *context = &getContext();
169 mlir::RewritePatternSet patterns(context);
170 mlir::GreedyRewriteConfig config;
171 config.enableRegionSimplification =
172 mlir::GreedySimplifyRegionLevel::Disabled;
173 config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
175 patterns.insert<CallOpRewriter>(context, *di);
176 if (mlir::failed(
177 mlir::applyPatternsGreedily(mod, std::move(patterns), config))) {
178 mlir::emitError(mod.getLoc(),
179 "error in constant globalisation optimization\n");
180 signalPassFailure();
184 } // namespace