[Instrumentation] Fix a warning
[llvm-project.git] / flang / lib / Optimizer / Transforms / AbstractResult.cpp
blob2eca349110f3af0a42c0b7d34b582c617c7a4f7f
1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/Todo.h"
11 #include "flang/Optimizer/Dialect/FIRDialect.h"
12 #include "flang/Optimizer/Dialect/FIROps.h"
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
15 #include "flang/Optimizer/Transforms/Passes.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Pass/PassManager.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/TypeSwitch.h"
24 namespace fir {
25 #define GEN_PASS_DEF_ABSTRACTRESULTOPT
26 #include "flang/Optimizer/Transforms/Passes.h.inc"
27 } // namespace fir
29 #define DEBUG_TYPE "flang-abstract-result-opt"
31 using namespace mlir;
33 namespace fir {
34 namespace {
36 // Helper to only build the symbol table if needed because its build time is
37 // linear on the number of symbols in the module.
38 struct LazySymbolTable {
39 LazySymbolTable(mlir::Operation *op)
40 : module{op->getParentOfType<mlir::ModuleOp>()} {}
41 void build() {
42 if (table)
43 return;
44 table = std::make_unique<mlir::SymbolTable>(module);
47 template <typename T>
48 T lookup(llvm::StringRef name) {
49 build();
50 return table->lookup<T>(name);
53 private:
54 std::unique_ptr<mlir::SymbolTable> table;
55 mlir::ModuleOp module;
58 bool hasScalarDerivedResult(mlir::FunctionType funTy) {
59 // C_PTR/C_FUNPTR are results to void* in this pass, do not consider
60 // them as normal derived types.
61 return funTy.getNumResults() == 1 &&
62 mlir::isa<fir::RecordType>(funTy.getResult(0)) &&
63 !fir::isa_builtin_cptr_type(funTy.getResult(0));
66 static mlir::Type getResultArgumentType(mlir::Type resultType,
67 bool shouldBoxResult) {
68 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
69 .Case<fir::SequenceType, fir::RecordType>(
70 [&](mlir::Type type) -> mlir::Type {
71 if (shouldBoxResult)
72 return fir::BoxType::get(type);
73 return fir::ReferenceType::get(type);
75 .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
76 return fir::ReferenceType::get(type);
78 .Default([](mlir::Type) -> mlir::Type {
79 llvm_unreachable("bad abstract result type");
80 });
83 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
84 bool shouldBoxResult) {
85 auto resultType = funcTy.getResult(0);
86 auto argTy = getResultArgumentType(resultType, shouldBoxResult);
87 llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
88 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
89 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
90 /*resultTypes=*/{});
93 static mlir::Type getVoidPtrType(mlir::MLIRContext *context) {
94 return fir::ReferenceType::get(mlir::NoneType::get(context));
97 /// This is for function result types that are of type C_PTR from ISO_C_BINDING.
98 /// Follow the ABI for interoperability with C.
99 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
100 assert(fir::isa_builtin_cptr_type(funcTy.getResult(0)));
101 llvm::SmallVector<mlir::Type> outputTypes{
102 getVoidPtrType(funcTy.getContext())};
103 return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
104 outputTypes);
107 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
108 return mlir::isa<fir::SequenceType, fir::RecordType>(resultType) &&
109 shouldBoxResult;
112 template <typename Op>
113 class CallConversion : public mlir::OpRewritePattern<Op> {
114 public:
115 using mlir::OpRewritePattern<Op>::OpRewritePattern;
117 CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
118 : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}
120 llvm::LogicalResult
121 matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
122 auto loc = op.getLoc();
123 auto result = op->getResult(0);
124 if (!result.hasOneUse()) {
125 mlir::emitError(loc,
126 "calls with abstract result must have exactly one user");
127 return mlir::failure();
129 auto saveResult =
130 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
131 if (!saveResult) {
132 mlir::emitError(
133 loc, "calls with abstract result must be used in fir.save_result");
134 return mlir::failure();
136 auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
137 auto buffer = saveResult.getMemref();
138 mlir::Value arg = buffer;
139 if (mustEmboxResult(result.getType(), shouldBoxResult))
140 arg = rewriter.create<fir::EmboxOp>(
141 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
142 saveResult.getTypeparams());
144 llvm::SmallVector<mlir::Type> newResultTypes;
145 bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
146 if (isResultBuiltinCPtr)
147 newResultTypes.emplace_back(getVoidPtrType(result.getContext()));
149 Op newOp;
150 // fir::CallOp specific handling.
151 if constexpr (std::is_same_v<Op, fir::CallOp>) {
152 if (op.getCallee()) {
153 llvm::SmallVector<mlir::Value> newOperands;
154 if (!isResultBuiltinCPtr)
155 newOperands.emplace_back(arg);
156 newOperands.append(op.getOperands().begin(), op.getOperands().end());
157 newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
158 newResultTypes, newOperands);
159 } else {
160 // Indirect calls.
161 llvm::SmallVector<mlir::Type> newInputTypes;
162 if (!isResultBuiltinCPtr)
163 newInputTypes.emplace_back(argType);
164 for (auto operand : op.getOperands().drop_front())
165 newInputTypes.push_back(operand.getType());
166 auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
167 newResultTypes);
169 llvm::SmallVector<mlir::Value> newOperands;
170 newOperands.push_back(
171 rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
172 if (!isResultBuiltinCPtr)
173 newOperands.push_back(arg);
174 newOperands.append(op.getOperands().begin() + 1,
175 op.getOperands().end());
176 newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
177 newResultTypes, newOperands);
181 // fir::DispatchOp specific handling.
182 if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
183 llvm::SmallVector<mlir::Value> newOperands;
184 if (!isResultBuiltinCPtr)
185 newOperands.emplace_back(arg);
186 unsigned passArgShift = newOperands.size();
187 newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
188 mlir::IntegerAttr passArgPos;
189 if (op.getPassArgPos())
190 passArgPos =
191 rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift);
192 newOp = rewriter.create<fir::DispatchOp>(
193 loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
194 op.getOperands()[0], newOperands, passArgPos,
195 op.getProcedureAttrsAttr());
198 if (isResultBuiltinCPtr) {
199 mlir::Value save = saveResult.getMemref();
200 auto module = op->template getParentOfType<mlir::ModuleOp>();
201 FirOpBuilder builder(rewriter, module);
202 mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
203 builder, loc, save, result.getType());
204 builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr);
206 op->dropAllReferences();
207 rewriter.eraseOp(op);
208 return mlir::success();
211 private:
212 bool shouldBoxResult;
215 class SaveResultOpConversion
216 : public mlir::OpRewritePattern<fir::SaveResultOp> {
217 public:
218 using OpRewritePattern::OpRewritePattern;
219 SaveResultOpConversion(mlir::MLIRContext *context)
220 : OpRewritePattern(context) {}
221 llvm::LogicalResult
222 matchAndRewrite(fir::SaveResultOp op,
223 mlir::PatternRewriter &rewriter) const override {
224 mlir::Operation *call = op.getValue().getDefiningOp();
225 mlir::Type type = op.getValue().getType();
226 if (mlir::isa<fir::RecordType>(type) && call && fir::hasBindcAttr(call) &&
227 !fir::isa_builtin_cptr_type(type)) {
228 rewriter.replaceOpWithNewOp<fir::StoreOp>(op, op.getValue(),
229 op.getMemref());
230 } else {
231 rewriter.eraseOp(op);
233 return mlir::success();
237 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
238 public:
239 using OpRewritePattern::OpRewritePattern;
240 ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
241 : OpRewritePattern(context), newArg{newArg} {}
242 llvm::LogicalResult
243 matchAndRewrite(mlir::func::ReturnOp ret,
244 mlir::PatternRewriter &rewriter) const override {
245 auto loc = ret.getLoc();
246 rewriter.setInsertionPoint(ret);
247 mlir::Value resultValue = ret.getOperand(0);
248 fir::LoadOp resultLoad;
249 mlir::Value resultStorage;
250 // Identify result local storage.
251 if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
252 resultLoad = load;
253 resultStorage = load.getMemref();
254 // The result alloca may be behind a fir.declare, if any.
255 if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
256 resultStorage = declare.getMemref();
258 // Replace old local storage with new storage argument, unless
259 // the derived type is C_PTR/C_FUN_PTR, in which case the return
260 // type is updated to return void* (no new argument is passed).
261 if (fir::isa_builtin_cptr_type(resultValue.getType())) {
262 auto module = ret->getParentOfType<mlir::ModuleOp>();
263 FirOpBuilder builder(rewriter, module);
264 mlir::Value cptr = resultValue;
265 if (resultLoad) {
266 // Replace whole derived type load by component load.
267 cptr = resultLoad.getMemref();
268 rewriter.setInsertionPoint(resultLoad);
270 mlir::Value newResultValue =
271 fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
272 newResultValue = builder.createConvert(
273 loc, getVoidPtrType(ret.getContext()), newResultValue);
274 rewriter.setInsertionPoint(ret);
275 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
276 ret, mlir::ValueRange{newResultValue});
277 } else if (resultStorage) {
278 resultStorage.replaceAllUsesWith(newArg);
279 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
280 } else {
281 // The result storage may have been optimized out by a memory to
282 // register pass, this is possible for fir.box results, or fir.record
283 // with no length parameters. Simply store the result in the result
284 // storage. at the return point.
285 rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
286 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
288 // Delete result old local storage if unused.
289 if (resultStorage)
290 if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
291 if (alloc->use_empty())
292 rewriter.eraseOp(alloc);
293 return mlir::success();
296 private:
297 mlir::Value newArg;
300 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
301 public:
302 using OpRewritePattern::OpRewritePattern;
303 AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
304 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
305 llvm::LogicalResult
306 matchAndRewrite(fir::AddrOfOp addrOf,
307 mlir::PatternRewriter &rewriter) const override {
308 auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType());
309 mlir::FunctionType newFuncTy;
310 if (oldFuncTy.getNumResults() != 0 &&
311 fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
312 newFuncTy = getCPtrFunctionType(oldFuncTy);
313 else
314 newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
315 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
316 addrOf.getSymbol());
317 // Rather than converting all op a function pointer might transit through
318 // (e.g calls, stores, loads, converts...), cast new type to the abstract
319 // type. A conversion will be added when calling indirect calls of abstract
320 // types.
321 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
322 return mlir::success();
325 private:
326 bool shouldBoxResult;
329 class AbstractResultOpt
330 : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
331 public:
332 using fir::impl::AbstractResultOptBase<
333 AbstractResultOpt>::AbstractResultOptBase;
335 template <typename OpTy>
336 void runOnFunctionLikeOperation(OpTy func, bool shouldBoxResult,
337 mlir::RewritePatternSet &patterns,
338 mlir::ConversionTarget &target) {
339 auto loc = func.getLoc();
340 auto *context = &getContext();
341 // Convert function type itself if it has an abstract result.
342 auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
343 // Scalar derived result of BIND(C) function must be returned according
344 // to the C struct return ABI which is target dependent and implemented in
345 // the target-rewrite pass.
346 if (hasScalarDerivedResult(funcTy) &&
347 fir::hasBindcAttr(func.getOperation()))
348 return;
349 if (hasAbstractResult(funcTy)) {
350 if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
351 func.setType(getCPtrFunctionType(funcTy));
352 patterns.insert<ReturnOpConversion>(context, mlir::Value{});
353 target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
354 [](mlir::func::ReturnOp ret) {
355 mlir::Type retTy = ret.getOperand(0).getType();
356 return !fir::isa_builtin_cptr_type(retTy);
358 return;
360 if (!func.empty()) {
361 // Insert new argument.
362 mlir::OpBuilder rewriter(context);
363 auto resultType = funcTy.getResult(0);
364 auto argTy = getResultArgumentType(resultType, shouldBoxResult);
365 func.insertArgument(0u, argTy, {}, loc);
366 func.eraseResult(0u);
367 mlir::Value newArg = func.getArgument(0u);
368 if (mustEmboxResult(resultType, shouldBoxResult)) {
369 auto bufferType = fir::ReferenceType::get(resultType);
370 rewriter.setInsertionPointToStart(&func.front());
371 newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
373 patterns.insert<ReturnOpConversion>(context, newArg);
374 target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
375 [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
376 assert(func.getFunctionType() ==
377 getNewFunctionType(funcTy, shouldBoxResult));
378 } else {
379 llvm::SmallVector<mlir::DictionaryAttr> allArgs;
380 func.getAllArgAttrs(allArgs);
381 allArgs.insert(allArgs.begin(),
382 mlir::DictionaryAttr::get(func->getContext()));
383 func.setType(getNewFunctionType(funcTy, shouldBoxResult));
384 func.setAllArgAttrs(allArgs);
389 void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
390 mlir::RewritePatternSet &patterns,
391 mlir::ConversionTarget &target) {
392 runOnFunctionLikeOperation(func, shouldBoxResult, patterns, target);
395 void runOnSpecificOperation(mlir::gpu::GPUFuncOp func, bool shouldBoxResult,
396 mlir::RewritePatternSet &patterns,
397 mlir::ConversionTarget &target) {
398 runOnFunctionLikeOperation(func, shouldBoxResult, patterns, target);
401 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
402 return mlir::TypeSwitch<mlir::Type, bool>(type)
403 .Case([](fir::BoxProcType boxProc) {
404 return fir::hasAbstractResult(
405 mlir::cast<mlir::FunctionType>(boxProc.getEleTy()));
407 .Case([](fir::PointerType pointer) {
408 return fir::hasAbstractResult(
409 mlir::cast<mlir::FunctionType>(pointer.getEleTy()));
411 .Default([](auto &&) { return false; });
414 void runOnSpecificOperation(fir::GlobalOp global, bool,
415 mlir::RewritePatternSet &,
416 mlir::ConversionTarget &) {
417 if (containsFunctionTypeWithAbstractResult(global.getType())) {
418 TODO(global->getLoc(), "support for procedure pointers");
422 /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
423 void runOnModule() {
424 mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());
426 auto pass = std::make_unique<AbstractResultOpt>();
427 pass->copyOptionValuesFrom(this);
428 mlir::OpPassManager pipeline;
429 pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()});
431 // Run the pass on all operations directly nested inside of the ModuleOp
432 // we can't just call runOnSpecificOperation here because the pass
433 // implementation only works when scoped to a particular func.func or
434 // fir.global
435 for (mlir::Region &region : mod->getRegions()) {
436 for (mlir::Block &block : region.getBlocks()) {
437 for (mlir::Operation &op : block.getOperations()) {
438 if (mlir::failed(runPipeline(pipeline, &op))) {
439 mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
440 signalPassFailure();
441 return;
448 void runOnOperation() override {
449 auto *context = &this->getContext();
450 mlir::Operation *op = this->getOperation();
451 if (mlir::isa<mlir::ModuleOp>(op)) {
452 runOnModule();
453 return;
456 LazySymbolTable symbolTable(op);
458 mlir::RewritePatternSet patterns(context);
459 mlir::ConversionTarget target = *context;
460 const bool shouldBoxResult = this->passResultAsBox.getValue();
462 mlir::TypeSwitch<mlir::Operation *, void>(op)
463 .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) {
464 runOnSpecificOperation(op, shouldBoxResult, patterns, target);
466 .Case<mlir::gpu::GPUModuleOp>([&](auto op) {
467 auto gpuMod = mlir::dyn_cast<mlir::gpu::GPUModuleOp>(*op);
468 for (auto funcOp : gpuMod.template getOps<mlir::func::FuncOp>())
469 runOnSpecificOperation(funcOp, shouldBoxResult, patterns, target);
470 for (auto gpuFuncOp : gpuMod.template getOps<mlir::gpu::GPUFuncOp>())
471 runOnSpecificOperation(gpuFuncOp, shouldBoxResult, patterns,
472 target);
475 // Convert the calls and, if needed, the ReturnOp in the function body.
476 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
477 mlir::func::FuncDialect>();
478 target.addIllegalOp<fir::SaveResultOp>();
479 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
480 mlir::FunctionType funTy = call.getFunctionType();
481 if (hasScalarDerivedResult(funTy) &&
482 fir::hasBindcAttr(call.getOperation()))
483 return true;
484 return !hasAbstractResult(funTy);
486 target.addDynamicallyLegalOp<fir::AddrOfOp>([&symbolTable](
487 fir::AddrOfOp addrOf) {
488 if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType())) {
489 if (hasScalarDerivedResult(funTy)) {
490 auto func = symbolTable.lookup<mlir::func::FuncOp>(
491 addrOf.getSymbol().getRootReference().getValue());
492 return func && fir::hasBindcAttr(func.getOperation());
494 return !hasAbstractResult(funTy);
496 return true;
498 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
499 mlir::FunctionType funTy = dispatch.getFunctionType();
500 if (hasScalarDerivedResult(funTy) &&
501 fir::hasBindcAttr(dispatch.getOperation()))
502 return true;
503 return !hasAbstractResult(dispatch.getFunctionType());
506 patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
507 patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
508 patterns.insert<SaveResultOpConversion>(context);
509 patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
510 if (mlir::failed(
511 mlir::applyPartialConversion(op, target, std::move(patterns)))) {
512 mlir::emitError(op->getLoc(), "error in converting abstract results\n");
513 this->signalPassFailure();
518 } // end anonymous namespace
519 } // namespace fir