//===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/TypeSwitch.h"
25 #define GEN_PASS_DEF_ABSTRACTRESULTOPT
26 #include "flang/Optimizer/Transforms/Passes.h.inc"
29 #define DEBUG_TYPE "flang-abstract-result-opt"
36 // Helper to only build the symbol table if needed because its build time is
37 // linear on the number of symbols in the module.
38 struct LazySymbolTable
{
39 LazySymbolTable(mlir::Operation
*op
)
40 : module
{op
->getParentOfType
<mlir::ModuleOp
>()} {}
44 table
= std::make_unique
<mlir::SymbolTable
>(module
);
48 T
lookup(llvm::StringRef name
) {
50 return table
->lookup
<T
>(name
);
54 std::unique_ptr
<mlir::SymbolTable
> table
;
55 mlir::ModuleOp module
;
58 bool hasScalarDerivedResult(mlir::FunctionType funTy
) {
59 // C_PTR/C_FUNPTR are results to void* in this pass, do not consider
60 // them as normal derived types.
61 return funTy
.getNumResults() == 1 &&
62 mlir::isa
<fir::RecordType
>(funTy
.getResult(0)) &&
63 !fir::isa_builtin_cptr_type(funTy
.getResult(0));
66 static mlir::Type
getResultArgumentType(mlir::Type resultType
,
67 bool shouldBoxResult
) {
68 return llvm::TypeSwitch
<mlir::Type
, mlir::Type
>(resultType
)
69 .Case
<fir::SequenceType
, fir::RecordType
>(
70 [&](mlir::Type type
) -> mlir::Type
{
72 return fir::BoxType::get(type
);
73 return fir::ReferenceType::get(type
);
75 .Case
<fir::BaseBoxType
>([](mlir::Type type
) -> mlir::Type
{
76 return fir::ReferenceType::get(type
);
78 .Default([](mlir::Type
) -> mlir::Type
{
79 llvm_unreachable("bad abstract result type");
83 static mlir::FunctionType
getNewFunctionType(mlir::FunctionType funcTy
,
84 bool shouldBoxResult
) {
85 auto resultType
= funcTy
.getResult(0);
86 auto argTy
= getResultArgumentType(resultType
, shouldBoxResult
);
87 llvm::SmallVector
<mlir::Type
> newInputTypes
= {argTy
};
88 newInputTypes
.append(funcTy
.getInputs().begin(), funcTy
.getInputs().end());
89 return mlir::FunctionType::get(funcTy
.getContext(), newInputTypes
,
93 static mlir::Type
getVoidPtrType(mlir::MLIRContext
*context
) {
94 return fir::ReferenceType::get(mlir::NoneType::get(context
));
97 /// This is for function result types that are of type C_PTR from ISO_C_BINDING.
98 /// Follow the ABI for interoperability with C.
99 static mlir::FunctionType
getCPtrFunctionType(mlir::FunctionType funcTy
) {
100 assert(fir::isa_builtin_cptr_type(funcTy
.getResult(0)));
101 llvm::SmallVector
<mlir::Type
> outputTypes
{
102 getVoidPtrType(funcTy
.getContext())};
103 return mlir::FunctionType::get(funcTy
.getContext(), funcTy
.getInputs(),
107 static bool mustEmboxResult(mlir::Type resultType
, bool shouldBoxResult
) {
108 return mlir::isa
<fir::SequenceType
, fir::RecordType
>(resultType
) &&
112 template <typename Op
>
113 class CallConversion
: public mlir::OpRewritePattern
<Op
> {
115 using mlir::OpRewritePattern
<Op
>::OpRewritePattern
;
117 CallConversion(mlir::MLIRContext
*context
, bool shouldBoxResult
)
118 : OpRewritePattern
<Op
>(context
, 1), shouldBoxResult
{shouldBoxResult
} {}
121 matchAndRewrite(Op op
, mlir::PatternRewriter
&rewriter
) const override
{
122 auto loc
= op
.getLoc();
123 auto result
= op
->getResult(0);
124 if (!result
.hasOneUse()) {
126 "calls with abstract result must have exactly one user");
127 return mlir::failure();
130 mlir::dyn_cast
<fir::SaveResultOp
>(result
.use_begin().getUser());
133 loc
, "calls with abstract result must be used in fir.save_result");
134 return mlir::failure();
136 auto argType
= getResultArgumentType(result
.getType(), shouldBoxResult
);
137 auto buffer
= saveResult
.getMemref();
138 mlir::Value arg
= buffer
;
139 if (mustEmboxResult(result
.getType(), shouldBoxResult
))
140 arg
= rewriter
.create
<fir::EmboxOp
>(
141 loc
, argType
, buffer
, saveResult
.getShape(), /*slice*/ mlir::Value
{},
142 saveResult
.getTypeparams());
144 llvm::SmallVector
<mlir::Type
> newResultTypes
;
145 bool isResultBuiltinCPtr
= fir::isa_builtin_cptr_type(result
.getType());
146 if (isResultBuiltinCPtr
)
147 newResultTypes
.emplace_back(getVoidPtrType(result
.getContext()));
150 // fir::CallOp specific handling.
151 if constexpr (std::is_same_v
<Op
, fir::CallOp
>) {
152 if (op
.getCallee()) {
153 llvm::SmallVector
<mlir::Value
> newOperands
;
154 if (!isResultBuiltinCPtr
)
155 newOperands
.emplace_back(arg
);
156 newOperands
.append(op
.getOperands().begin(), op
.getOperands().end());
157 newOp
= rewriter
.create
<fir::CallOp
>(loc
, *op
.getCallee(),
158 newResultTypes
, newOperands
);
161 llvm::SmallVector
<mlir::Type
> newInputTypes
;
162 if (!isResultBuiltinCPtr
)
163 newInputTypes
.emplace_back(argType
);
164 for (auto operand
: op
.getOperands().drop_front())
165 newInputTypes
.push_back(operand
.getType());
166 auto newFuncTy
= mlir::FunctionType::get(op
.getContext(), newInputTypes
,
169 llvm::SmallVector
<mlir::Value
> newOperands
;
170 newOperands
.push_back(
171 rewriter
.create
<fir::ConvertOp
>(loc
, newFuncTy
, op
.getOperand(0)));
172 if (!isResultBuiltinCPtr
)
173 newOperands
.push_back(arg
);
174 newOperands
.append(op
.getOperands().begin() + 1,
175 op
.getOperands().end());
176 newOp
= rewriter
.create
<fir::CallOp
>(loc
, mlir::SymbolRefAttr
{},
177 newResultTypes
, newOperands
);
181 // fir::DispatchOp specific handling.
182 if constexpr (std::is_same_v
<Op
, fir::DispatchOp
>) {
183 llvm::SmallVector
<mlir::Value
> newOperands
;
184 if (!isResultBuiltinCPtr
)
185 newOperands
.emplace_back(arg
);
186 unsigned passArgShift
= newOperands
.size();
187 newOperands
.append(op
.getOperands().begin() + 1, op
.getOperands().end());
188 mlir::IntegerAttr passArgPos
;
189 if (op
.getPassArgPos())
191 rewriter
.getI32IntegerAttr(*op
.getPassArgPos() + passArgShift
);
192 newOp
= rewriter
.create
<fir::DispatchOp
>(
193 loc
, newResultTypes
, rewriter
.getStringAttr(op
.getMethod()),
194 op
.getOperands()[0], newOperands
, passArgPos
,
195 op
.getProcedureAttrsAttr());
198 if (isResultBuiltinCPtr
) {
199 mlir::Value save
= saveResult
.getMemref();
200 auto module
= op
->template getParentOfType
<mlir::ModuleOp
>();
201 FirOpBuilder
builder(rewriter
, module
);
202 mlir::Value saveAddr
= fir::factory::genCPtrOrCFunptrAddr(
203 builder
, loc
, save
, result
.getType());
204 builder
.createStoreWithConvert(loc
, newOp
->getResult(0), saveAddr
);
206 op
->dropAllReferences();
207 rewriter
.eraseOp(op
);
208 return mlir::success();
212 bool shouldBoxResult
;
215 class SaveResultOpConversion
216 : public mlir::OpRewritePattern
<fir::SaveResultOp
> {
218 using OpRewritePattern::OpRewritePattern
;
219 SaveResultOpConversion(mlir::MLIRContext
*context
)
220 : OpRewritePattern(context
) {}
222 matchAndRewrite(fir::SaveResultOp op
,
223 mlir::PatternRewriter
&rewriter
) const override
{
224 mlir::Operation
*call
= op
.getValue().getDefiningOp();
225 mlir::Type type
= op
.getValue().getType();
226 if (mlir::isa
<fir::RecordType
>(type
) && call
&& fir::hasBindcAttr(call
) &&
227 !fir::isa_builtin_cptr_type(type
)) {
228 rewriter
.replaceOpWithNewOp
<fir::StoreOp
>(op
, op
.getValue(),
231 rewriter
.eraseOp(op
);
233 return mlir::success();
237 class ReturnOpConversion
: public mlir::OpRewritePattern
<mlir::func::ReturnOp
> {
239 using OpRewritePattern::OpRewritePattern
;
240 ReturnOpConversion(mlir::MLIRContext
*context
, mlir::Value newArg
)
241 : OpRewritePattern(context
), newArg
{newArg
} {}
243 matchAndRewrite(mlir::func::ReturnOp ret
,
244 mlir::PatternRewriter
&rewriter
) const override
{
245 auto loc
= ret
.getLoc();
246 rewriter
.setInsertionPoint(ret
);
247 mlir::Value resultValue
= ret
.getOperand(0);
248 fir::LoadOp resultLoad
;
249 mlir::Value resultStorage
;
250 // Identify result local storage.
251 if (auto load
= resultValue
.getDefiningOp
<fir::LoadOp
>()) {
253 resultStorage
= load
.getMemref();
254 // The result alloca may be behind a fir.declare, if any.
255 if (auto declare
= resultStorage
.getDefiningOp
<fir::DeclareOp
>())
256 resultStorage
= declare
.getMemref();
258 // Replace old local storage with new storage argument, unless
259 // the derived type is C_PTR/C_FUN_PTR, in which case the return
260 // type is updated to return void* (no new argument is passed).
261 if (fir::isa_builtin_cptr_type(resultValue
.getType())) {
262 auto module
= ret
->getParentOfType
<mlir::ModuleOp
>();
263 FirOpBuilder
builder(rewriter
, module
);
264 mlir::Value cptr
= resultValue
;
266 // Replace whole derived type load by component load.
267 cptr
= resultLoad
.getMemref();
268 rewriter
.setInsertionPoint(resultLoad
);
270 mlir::Value newResultValue
=
271 fir::factory::genCPtrOrCFunptrValue(builder
, loc
, cptr
);
272 newResultValue
= builder
.createConvert(
273 loc
, getVoidPtrType(ret
.getContext()), newResultValue
);
274 rewriter
.setInsertionPoint(ret
);
275 rewriter
.replaceOpWithNewOp
<mlir::func::ReturnOp
>(
276 ret
, mlir::ValueRange
{newResultValue
});
277 } else if (resultStorage
) {
278 resultStorage
.replaceAllUsesWith(newArg
);
279 rewriter
.replaceOpWithNewOp
<mlir::func::ReturnOp
>(ret
);
281 // The result storage may have been optimized out by a memory to
282 // register pass, this is possible for fir.box results, or fir.record
283 // with no length parameters. Simply store the result in the result
284 // storage. at the return point.
285 rewriter
.create
<fir::StoreOp
>(loc
, resultValue
, newArg
);
286 rewriter
.replaceOpWithNewOp
<mlir::func::ReturnOp
>(ret
);
288 // Delete result old local storage if unused.
290 if (auto alloc
= resultStorage
.getDefiningOp
<fir::AllocaOp
>())
291 if (alloc
->use_empty())
292 rewriter
.eraseOp(alloc
);
293 return mlir::success();
300 class AddrOfOpConversion
: public mlir::OpRewritePattern
<fir::AddrOfOp
> {
302 using OpRewritePattern::OpRewritePattern
;
303 AddrOfOpConversion(mlir::MLIRContext
*context
, bool shouldBoxResult
)
304 : OpRewritePattern(context
), shouldBoxResult
{shouldBoxResult
} {}
306 matchAndRewrite(fir::AddrOfOp addrOf
,
307 mlir::PatternRewriter
&rewriter
) const override
{
308 auto oldFuncTy
= mlir::cast
<mlir::FunctionType
>(addrOf
.getType());
309 mlir::FunctionType newFuncTy
;
310 if (oldFuncTy
.getNumResults() != 0 &&
311 fir::isa_builtin_cptr_type(oldFuncTy
.getResult(0)))
312 newFuncTy
= getCPtrFunctionType(oldFuncTy
);
314 newFuncTy
= getNewFunctionType(oldFuncTy
, shouldBoxResult
);
315 auto newAddrOf
= rewriter
.create
<fir::AddrOfOp
>(addrOf
.getLoc(), newFuncTy
,
317 // Rather than converting all op a function pointer might transit through
318 // (e.g calls, stores, loads, converts...), cast new type to the abstract
319 // type. A conversion will be added when calling indirect calls of abstract
321 rewriter
.replaceOpWithNewOp
<fir::ConvertOp
>(addrOf
, oldFuncTy
, newAddrOf
);
322 return mlir::success();
326 bool shouldBoxResult
;
329 class AbstractResultOpt
330 : public fir::impl::AbstractResultOptBase
<AbstractResultOpt
> {
332 using fir::impl::AbstractResultOptBase
<
333 AbstractResultOpt
>::AbstractResultOptBase
;
335 template <typename OpTy
>
336 void runOnFunctionLikeOperation(OpTy func
, bool shouldBoxResult
,
337 mlir::RewritePatternSet
&patterns
,
338 mlir::ConversionTarget
&target
) {
339 auto loc
= func
.getLoc();
340 auto *context
= &getContext();
341 // Convert function type itself if it has an abstract result.
342 auto funcTy
= mlir::cast
<mlir::FunctionType
>(func
.getFunctionType());
343 // Scalar derived result of BIND(C) function must be returned according
344 // to the C struct return ABI which is target dependent and implemented in
345 // the target-rewrite pass.
346 if (hasScalarDerivedResult(funcTy
) &&
347 fir::hasBindcAttr(func
.getOperation()))
349 if (hasAbstractResult(funcTy
)) {
350 if (fir::isa_builtin_cptr_type(funcTy
.getResult(0))) {
351 func
.setType(getCPtrFunctionType(funcTy
));
352 patterns
.insert
<ReturnOpConversion
>(context
, mlir::Value
{});
353 target
.addDynamicallyLegalOp
<mlir::func::ReturnOp
>(
354 [](mlir::func::ReturnOp ret
) {
355 mlir::Type retTy
= ret
.getOperand(0).getType();
356 return !fir::isa_builtin_cptr_type(retTy
);
361 // Insert new argument.
362 mlir::OpBuilder
rewriter(context
);
363 auto resultType
= funcTy
.getResult(0);
364 auto argTy
= getResultArgumentType(resultType
, shouldBoxResult
);
365 func
.insertArgument(0u, argTy
, {}, loc
);
366 func
.eraseResult(0u);
367 mlir::Value newArg
= func
.getArgument(0u);
368 if (mustEmboxResult(resultType
, shouldBoxResult
)) {
369 auto bufferType
= fir::ReferenceType::get(resultType
);
370 rewriter
.setInsertionPointToStart(&func
.front());
371 newArg
= rewriter
.create
<fir::BoxAddrOp
>(loc
, bufferType
, newArg
);
373 patterns
.insert
<ReturnOpConversion
>(context
, newArg
);
374 target
.addDynamicallyLegalOp
<mlir::func::ReturnOp
>(
375 [](mlir::func::ReturnOp ret
) { return ret
.getOperands().empty(); });
376 assert(func
.getFunctionType() ==
377 getNewFunctionType(funcTy
, shouldBoxResult
));
379 llvm::SmallVector
<mlir::DictionaryAttr
> allArgs
;
380 func
.getAllArgAttrs(allArgs
);
381 allArgs
.insert(allArgs
.begin(),
382 mlir::DictionaryAttr::get(func
->getContext()));
383 func
.setType(getNewFunctionType(funcTy
, shouldBoxResult
));
384 func
.setAllArgAttrs(allArgs
);
389 void runOnSpecificOperation(mlir::func::FuncOp func
, bool shouldBoxResult
,
390 mlir::RewritePatternSet
&patterns
,
391 mlir::ConversionTarget
&target
) {
392 runOnFunctionLikeOperation(func
, shouldBoxResult
, patterns
, target
);
395 void runOnSpecificOperation(mlir::gpu::GPUFuncOp func
, bool shouldBoxResult
,
396 mlir::RewritePatternSet
&patterns
,
397 mlir::ConversionTarget
&target
) {
398 runOnFunctionLikeOperation(func
, shouldBoxResult
, patterns
, target
);
401 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type
) {
402 return mlir::TypeSwitch
<mlir::Type
, bool>(type
)
403 .Case([](fir::BoxProcType boxProc
) {
404 return fir::hasAbstractResult(
405 mlir::cast
<mlir::FunctionType
>(boxProc
.getEleTy()));
407 .Case([](fir::PointerType pointer
) {
408 return fir::hasAbstractResult(
409 mlir::cast
<mlir::FunctionType
>(pointer
.getEleTy()));
411 .Default([](auto &&) { return false; });
414 void runOnSpecificOperation(fir::GlobalOp global
, bool,
415 mlir::RewritePatternSet
&,
416 mlir::ConversionTarget
&) {
417 if (containsFunctionTypeWithAbstractResult(global
.getType())) {
418 TODO(global
->getLoc(), "support for procedure pointers");
422 /// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
424 mlir::ModuleOp mod
= mlir::cast
<mlir::ModuleOp
>(getOperation());
426 auto pass
= std::make_unique
<AbstractResultOpt
>();
427 pass
->copyOptionValuesFrom(this);
428 mlir::OpPassManager pipeline
;
429 pipeline
.addPass(std::unique_ptr
<mlir::Pass
>{pass
.release()});
431 // Run the pass on all operations directly nested inside of the ModuleOp
432 // we can't just call runOnSpecificOperation here because the pass
433 // implementation only works when scoped to a particular func.func or
435 for (mlir::Region
®ion
: mod
->getRegions()) {
436 for (mlir::Block
&block
: region
.getBlocks()) {
437 for (mlir::Operation
&op
: block
.getOperations()) {
438 if (mlir::failed(runPipeline(pipeline
, &op
))) {
439 mlir::emitError(op
.getLoc(), "Failed to run abstract result pass");
448 void runOnOperation() override
{
449 auto *context
= &this->getContext();
450 mlir::Operation
*op
= this->getOperation();
451 if (mlir::isa
<mlir::ModuleOp
>(op
)) {
456 LazySymbolTable
symbolTable(op
);
458 mlir::RewritePatternSet
patterns(context
);
459 mlir::ConversionTarget target
= *context
;
460 const bool shouldBoxResult
= this->passResultAsBox
.getValue();
462 mlir::TypeSwitch
<mlir::Operation
*, void>(op
)
463 .Case
<mlir::func::FuncOp
, fir::GlobalOp
>([&](auto op
) {
464 runOnSpecificOperation(op
, shouldBoxResult
, patterns
, target
);
466 .Case
<mlir::gpu::GPUModuleOp
>([&](auto op
) {
467 auto gpuMod
= mlir::dyn_cast
<mlir::gpu::GPUModuleOp
>(*op
);
468 for (auto funcOp
: gpuMod
.template getOps
<mlir::func::FuncOp
>())
469 runOnSpecificOperation(funcOp
, shouldBoxResult
, patterns
, target
);
470 for (auto gpuFuncOp
: gpuMod
.template getOps
<mlir::gpu::GPUFuncOp
>())
471 runOnSpecificOperation(gpuFuncOp
, shouldBoxResult
, patterns
,
475 // Convert the calls and, if needed, the ReturnOp in the function body.
476 target
.addLegalDialect
<fir::FIROpsDialect
, mlir::arith::ArithDialect
,
477 mlir::func::FuncDialect
>();
478 target
.addIllegalOp
<fir::SaveResultOp
>();
479 target
.addDynamicallyLegalOp
<fir::CallOp
>([](fir::CallOp call
) {
480 mlir::FunctionType funTy
= call
.getFunctionType();
481 if (hasScalarDerivedResult(funTy
) &&
482 fir::hasBindcAttr(call
.getOperation()))
484 return !hasAbstractResult(funTy
);
486 target
.addDynamicallyLegalOp
<fir::AddrOfOp
>([&symbolTable
](
487 fir::AddrOfOp addrOf
) {
488 if (auto funTy
= mlir::dyn_cast
<mlir::FunctionType
>(addrOf
.getType())) {
489 if (hasScalarDerivedResult(funTy
)) {
490 auto func
= symbolTable
.lookup
<mlir::func::FuncOp
>(
491 addrOf
.getSymbol().getRootReference().getValue());
492 return func
&& fir::hasBindcAttr(func
.getOperation());
494 return !hasAbstractResult(funTy
);
498 target
.addDynamicallyLegalOp
<fir::DispatchOp
>([](fir::DispatchOp dispatch
) {
499 mlir::FunctionType funTy
= dispatch
.getFunctionType();
500 if (hasScalarDerivedResult(funTy
) &&
501 fir::hasBindcAttr(dispatch
.getOperation()))
503 return !hasAbstractResult(dispatch
.getFunctionType());
506 patterns
.insert
<CallConversion
<fir::CallOp
>>(context
, shouldBoxResult
);
507 patterns
.insert
<CallConversion
<fir::DispatchOp
>>(context
, shouldBoxResult
);
508 patterns
.insert
<SaveResultOpConversion
>(context
);
509 patterns
.insert
<AddrOfOpConversion
>(context
, shouldBoxResult
);
511 mlir::applyPartialConversion(op
, target
, std::move(patterns
)))) {
512 mlir::emitError(op
->getLoc(), "error in converting abstract results\n");
513 this->signalPassFailure();
518 } // end anonymous namespace