//===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/Todo.h"
11 #include "flang/Optimizer/Dialect/FIRDialect.h"
12 #include "flang/Optimizer/Dialect/FIROps.h"
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
15 #include "flang/Optimizer/Transforms/Passes.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Diagnostics.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "mlir/Transforms/Passes.h"
21 #include "llvm/ADT/TypeSwitch.h"
24 #define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
25 #define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
26 #include "flang/Optimizer/Transforms/Passes.h.inc"
29 #define DEBUG_TYPE "flang-abstract-result-opt"
36 static mlir::Type
getResultArgumentType(mlir::Type resultType
,
37 bool shouldBoxResult
) {
38 return llvm::TypeSwitch
<mlir::Type
, mlir::Type
>(resultType
)
39 .Case
<fir::SequenceType
, fir::RecordType
>(
40 [&](mlir::Type type
) -> mlir::Type
{
42 return fir::BoxType::get(type
);
43 return fir::ReferenceType::get(type
);
45 .Case
<fir::BaseBoxType
>([](mlir::Type type
) -> mlir::Type
{
46 return fir::ReferenceType::get(type
);
48 .Default([](mlir::Type
) -> mlir::Type
{
49 llvm_unreachable("bad abstract result type");
53 static mlir::FunctionType
getNewFunctionType(mlir::FunctionType funcTy
,
54 bool shouldBoxResult
) {
55 auto resultType
= funcTy
.getResult(0);
56 auto argTy
= getResultArgumentType(resultType
, shouldBoxResult
);
57 llvm::SmallVector
<mlir::Type
> newInputTypes
= {argTy
};
58 newInputTypes
.append(funcTy
.getInputs().begin(), funcTy
.getInputs().end());
59 return mlir::FunctionType::get(funcTy
.getContext(), newInputTypes
,
63 /// This is for function result types that are of type C_PTR from ISO_C_BINDING.
64 /// Follow the ABI for interoperability with C.
65 static mlir::FunctionType
getCPtrFunctionType(mlir::FunctionType funcTy
) {
66 auto resultType
= funcTy
.getResult(0);
67 assert(fir::isa_builtin_cptr_type(resultType
));
68 llvm::SmallVector
<mlir::Type
> outputTypes
;
69 auto recTy
= resultType
.dyn_cast
<fir::RecordType
>();
70 outputTypes
.emplace_back(recTy
.getTypeList()[0].second
);
71 return mlir::FunctionType::get(funcTy
.getContext(), funcTy
.getInputs(),
75 static bool mustEmboxResult(mlir::Type resultType
, bool shouldBoxResult
) {
76 return resultType
.isa
<fir::SequenceType
, fir::RecordType
>() &&
80 template <typename Op
>
81 class CallConversion
: public mlir::OpRewritePattern
<Op
> {
83 using mlir::OpRewritePattern
<Op
>::OpRewritePattern
;
85 CallConversion(mlir::MLIRContext
*context
, bool shouldBoxResult
)
86 : OpRewritePattern
<Op
>(context
, 1), shouldBoxResult
{shouldBoxResult
} {}
89 matchAndRewrite(Op op
, mlir::PatternRewriter
&rewriter
) const override
{
90 auto loc
= op
.getLoc();
91 auto result
= op
->getResult(0);
92 if (!result
.hasOneUse()) {
94 "calls with abstract result must have exactly one user");
95 return mlir::failure();
98 mlir::dyn_cast
<fir::SaveResultOp
>(result
.use_begin().getUser());
101 loc
, "calls with abstract result must be used in fir.save_result");
102 return mlir::failure();
104 auto argType
= getResultArgumentType(result
.getType(), shouldBoxResult
);
105 auto buffer
= saveResult
.getMemref();
106 mlir::Value arg
= buffer
;
107 if (mustEmboxResult(result
.getType(), shouldBoxResult
))
108 arg
= rewriter
.create
<fir::EmboxOp
>(
109 loc
, argType
, buffer
, saveResult
.getShape(), /*slice*/ mlir::Value
{},
110 saveResult
.getTypeparams());
112 llvm::SmallVector
<mlir::Type
> newResultTypes
;
113 // TODO: This should be generalized for derived types, and it is
114 // architecture and OS dependent.
115 bool isResultBuiltinCPtr
= fir::isa_builtin_cptr_type(result
.getType());
117 if (isResultBuiltinCPtr
) {
118 auto recTy
= result
.getType().template dyn_cast
<fir::RecordType
>();
119 newResultTypes
.emplace_back(recTy
.getTypeList()[0].second
);
122 // fir::CallOp specific handling.
123 if constexpr (std::is_same_v
<Op
, fir::CallOp
>) {
124 if (op
.getCallee()) {
125 llvm::SmallVector
<mlir::Value
> newOperands
;
126 if (!isResultBuiltinCPtr
)
127 newOperands
.emplace_back(arg
);
128 newOperands
.append(op
.getOperands().begin(), op
.getOperands().end());
129 newOp
= rewriter
.create
<fir::CallOp
>(loc
, *op
.getCallee(),
130 newResultTypes
, newOperands
);
133 llvm::SmallVector
<mlir::Type
> newInputTypes
;
134 if (!isResultBuiltinCPtr
)
135 newInputTypes
.emplace_back(argType
);
136 for (auto operand
: op
.getOperands().drop_front())
137 newInputTypes
.push_back(operand
.getType());
138 auto newFuncTy
= mlir::FunctionType::get(op
.getContext(), newInputTypes
,
141 llvm::SmallVector
<mlir::Value
> newOperands
;
142 newOperands
.push_back(
143 rewriter
.create
<fir::ConvertOp
>(loc
, newFuncTy
, op
.getOperand(0)));
144 if (!isResultBuiltinCPtr
)
145 newOperands
.push_back(arg
);
146 newOperands
.append(op
.getOperands().begin() + 1,
147 op
.getOperands().end());
148 newOp
= rewriter
.create
<fir::CallOp
>(loc
, mlir::SymbolRefAttr
{},
149 newResultTypes
, newOperands
);
153 // fir::DispatchOp specific handling.
154 if constexpr (std::is_same_v
<Op
, fir::DispatchOp
>) {
155 llvm::SmallVector
<mlir::Value
> newOperands
;
156 if (!isResultBuiltinCPtr
)
157 newOperands
.emplace_back(arg
);
158 unsigned passArgShift
= newOperands
.size();
159 newOperands
.append(op
.getOperands().begin() + 1, op
.getOperands().end());
161 fir::DispatchOp newDispatchOp
;
162 if (op
.getPassArgPos())
163 newOp
= rewriter
.create
<fir::DispatchOp
>(
164 loc
, newResultTypes
, rewriter
.getStringAttr(op
.getMethod()),
165 op
.getOperands()[0], newOperands
,
166 rewriter
.getI32IntegerAttr(*op
.getPassArgPos() + passArgShift
));
168 newOp
= rewriter
.create
<fir::DispatchOp
>(
169 loc
, newResultTypes
, rewriter
.getStringAttr(op
.getMethod()),
170 op
.getOperands()[0], newOperands
, nullptr);
173 if (isResultBuiltinCPtr
) {
174 mlir::Value save
= saveResult
.getMemref();
175 auto module
= op
->template getParentOfType
<mlir::ModuleOp
>();
176 fir::KindMapping kindMap
= fir::getKindMapping(module
);
177 FirOpBuilder
builder(rewriter
, kindMap
);
178 mlir::Value saveAddr
= fir::factory::genCPtrOrCFunptrAddr(
179 builder
, loc
, save
, result
.getType());
180 rewriter
.create
<fir::StoreOp
>(loc
, newOp
->getResult(0), saveAddr
);
182 op
->dropAllReferences();
183 rewriter
.eraseOp(op
);
184 return mlir::success();
188 bool shouldBoxResult
;
191 class SaveResultOpConversion
192 : public mlir::OpRewritePattern
<fir::SaveResultOp
> {
194 using OpRewritePattern::OpRewritePattern
;
195 SaveResultOpConversion(mlir::MLIRContext
*context
)
196 : OpRewritePattern(context
) {}
198 matchAndRewrite(fir::SaveResultOp op
,
199 mlir::PatternRewriter
&rewriter
) const override
{
200 rewriter
.eraseOp(op
);
201 return mlir::success();
205 class ReturnOpConversion
: public mlir::OpRewritePattern
<mlir::func::ReturnOp
> {
207 using OpRewritePattern::OpRewritePattern
;
208 ReturnOpConversion(mlir::MLIRContext
*context
, mlir::Value newArg
)
209 : OpRewritePattern(context
), newArg
{newArg
} {}
211 matchAndRewrite(mlir::func::ReturnOp ret
,
212 mlir::PatternRewriter
&rewriter
) const override
{
213 auto loc
= ret
.getLoc();
214 rewriter
.setInsertionPoint(ret
);
215 auto returnedValue
= ret
.getOperand(0);
216 bool replacedStorage
= false;
217 if (auto *op
= returnedValue
.getDefiningOp())
218 if (auto load
= mlir::dyn_cast
<fir::LoadOp
>(op
)) {
219 auto resultStorage
= load
.getMemref();
220 // TODO: This should be generalized for derived types, and it is
221 // architecture and OS dependent.
222 if (fir::isa_builtin_cptr_type(returnedValue
.getType())) {
223 rewriter
.eraseOp(load
);
224 auto module
= ret
->getParentOfType
<mlir::ModuleOp
>();
225 fir::KindMapping kindMap
= fir::getKindMapping(module
);
226 FirOpBuilder
builder(rewriter
, kindMap
);
227 mlir::Value retAddr
= fir::factory::genCPtrOrCFunptrAddr(
228 builder
, loc
, resultStorage
, returnedValue
.getType());
229 mlir::Value retValue
= rewriter
.create
<fir::LoadOp
>(
230 loc
, fir::unwrapRefType(retAddr
.getType()), retAddr
);
231 rewriter
.replaceOpWithNewOp
<mlir::func::ReturnOp
>(
232 ret
, mlir::ValueRange
{retValue
});
233 return mlir::success();
235 load
.getMemref().replaceAllUsesWith(newArg
);
236 replacedStorage
= true;
237 if (auto *alloc
= resultStorage
.getDefiningOp())
238 if (alloc
->use_empty())
239 rewriter
.eraseOp(alloc
);
241 // The result storage may have been optimized out by a memory to
242 // register pass, this is possible for fir.box results, or fir.record
243 // with no length parameters. Simply store the result in the result storage.
244 // at the return point.
245 if (!replacedStorage
)
246 rewriter
.create
<fir::StoreOp
>(loc
, returnedValue
, newArg
);
247 rewriter
.replaceOpWithNewOp
<mlir::func::ReturnOp
>(ret
);
248 return mlir::success();
255 class AddrOfOpConversion
: public mlir::OpRewritePattern
<fir::AddrOfOp
> {
257 using OpRewritePattern::OpRewritePattern
;
258 AddrOfOpConversion(mlir::MLIRContext
*context
, bool shouldBoxResult
)
259 : OpRewritePattern(context
), shouldBoxResult
{shouldBoxResult
} {}
261 matchAndRewrite(fir::AddrOfOp addrOf
,
262 mlir::PatternRewriter
&rewriter
) const override
{
263 auto oldFuncTy
= addrOf
.getType().cast
<mlir::FunctionType
>();
264 mlir::FunctionType newFuncTy
;
265 // TODO: This should be generalized for derived types, and it is
266 // architecture and OS dependent.
267 if (oldFuncTy
.getNumResults() != 0 &&
268 fir::isa_builtin_cptr_type(oldFuncTy
.getResult(0)))
269 newFuncTy
= getCPtrFunctionType(oldFuncTy
);
271 newFuncTy
= getNewFunctionType(oldFuncTy
, shouldBoxResult
);
272 auto newAddrOf
= rewriter
.create
<fir::AddrOfOp
>(addrOf
.getLoc(), newFuncTy
,
274 // Rather than converting all op a function pointer might transit through
275 // (e.g calls, stores, loads, converts...), cast new type to the abstract
276 // type. A conversion will be added when calling indirect calls of abstract
278 rewriter
.replaceOpWithNewOp
<fir::ConvertOp
>(addrOf
, oldFuncTy
, newAddrOf
);
279 return mlir::success();
283 bool shouldBoxResult
;
286 /// @brief Base CRTP class for AbstractResult pass family.
287 /// Contains common logic for abstract result conversion in a reusable fashion.
288 /// @tparam Pass target class that implements operation-specific logic.
289 /// @tparam PassBase base class template for the pass generated by TableGen.
290 /// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
291 /// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
292 /// This function should implement operation-specific functionality.
293 template <typename Pass
, template <typename
> class PassBase
>
294 class AbstractResultOptTemplate
: public PassBase
<Pass
> {
296 void runOnOperation() override
{
297 auto *context
= &this->getContext();
298 auto op
= this->getOperation();
300 mlir::RewritePatternSet
patterns(context
);
301 mlir::ConversionTarget target
= *context
;
302 const bool shouldBoxResult
= this->passResultAsBox
.getValue();
304 auto &self
= static_cast<Pass
&>(*this);
305 self
.runOnSpecificOperation(op
, shouldBoxResult
, patterns
, target
);
307 // Convert the calls and, if needed, the ReturnOp in the function body.
308 target
.addLegalDialect
<fir::FIROpsDialect
, mlir::arith::ArithDialect
,
309 mlir::func::FuncDialect
>();
310 target
.addIllegalOp
<fir::SaveResultOp
>();
311 target
.addDynamicallyLegalOp
<fir::CallOp
>([](fir::CallOp call
) {
312 return !hasAbstractResult(call
.getFunctionType());
314 target
.addDynamicallyLegalOp
<fir::AddrOfOp
>([](fir::AddrOfOp addrOf
) {
315 if (auto funTy
= addrOf
.getType().dyn_cast
<mlir::FunctionType
>())
316 return !hasAbstractResult(funTy
);
319 target
.addDynamicallyLegalOp
<fir::DispatchOp
>([](fir::DispatchOp dispatch
) {
320 return !hasAbstractResult(dispatch
.getFunctionType());
323 patterns
.insert
<CallConversion
<fir::CallOp
>>(context
, shouldBoxResult
);
324 patterns
.insert
<CallConversion
<fir::DispatchOp
>>(context
, shouldBoxResult
);
325 patterns
.insert
<SaveResultOpConversion
>(context
);
326 patterns
.insert
<AddrOfOpConversion
>(context
, shouldBoxResult
);
328 mlir::applyPartialConversion(op
, target
, std::move(patterns
)))) {
329 mlir::emitError(op
.getLoc(), "error in converting abstract results\n");
330 this->signalPassFailure();
335 class AbstractResultOnFuncOpt
336 : public AbstractResultOptTemplate
<AbstractResultOnFuncOpt
,
337 fir::impl::AbstractResultOnFuncOptBase
> {
339 void runOnSpecificOperation(mlir::func::FuncOp func
, bool shouldBoxResult
,
340 mlir::RewritePatternSet
&patterns
,
341 mlir::ConversionTarget
&target
) {
342 auto loc
= func
.getLoc();
343 auto *context
= &getContext();
344 // Convert function type itself if it has an abstract result.
345 auto funcTy
= func
.getFunctionType().cast
<mlir::FunctionType
>();
346 if (hasAbstractResult(funcTy
)) {
347 // TODO: This should be generalized for derived types, and it is
348 // architecture and OS dependent.
349 if (fir::isa_builtin_cptr_type(funcTy
.getResult(0))) {
350 func
.setType(getCPtrFunctionType(funcTy
));
351 patterns
.insert
<ReturnOpConversion
>(context
, mlir::Value
{});
352 target
.addDynamicallyLegalOp
<mlir::func::ReturnOp
>(
353 [](mlir::func::ReturnOp ret
) {
354 mlir::Type retTy
= ret
.getOperand(0).getType();
355 return !fir::isa_builtin_cptr_type(retTy
);
360 // Insert new argument.
361 mlir::OpBuilder
rewriter(context
);
362 auto resultType
= funcTy
.getResult(0);
363 auto argTy
= getResultArgumentType(resultType
, shouldBoxResult
);
364 func
.insertArgument(0u, argTy
, {}, loc
);
365 func
.eraseResult(0u);
366 mlir::Value newArg
= func
.getArgument(0u);
367 if (mustEmboxResult(resultType
, shouldBoxResult
)) {
368 auto bufferType
= fir::ReferenceType::get(resultType
);
369 rewriter
.setInsertionPointToStart(&func
.front());
370 newArg
= rewriter
.create
<fir::BoxAddrOp
>(loc
, bufferType
, newArg
);
372 patterns
.insert
<ReturnOpConversion
>(context
, newArg
);
373 target
.addDynamicallyLegalOp
<mlir::func::ReturnOp
>(
374 [](mlir::func::ReturnOp ret
) { return ret
.getOperands().empty(); });
375 assert(func
.getFunctionType() ==
376 getNewFunctionType(funcTy
, shouldBoxResult
));
378 llvm::SmallVector
<mlir::DictionaryAttr
> allArgs
;
379 func
.getAllArgAttrs(allArgs
);
380 allArgs
.insert(allArgs
.begin(),
381 mlir::DictionaryAttr::get(func
->getContext()));
382 func
.setType(getNewFunctionType(funcTy
, shouldBoxResult
));
383 func
.setAllArgAttrs(allArgs
);
389 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type
) {
390 return mlir::TypeSwitch
<mlir::Type
, bool>(type
)
391 .Case([](fir::BoxProcType boxProc
) {
392 return fir::hasAbstractResult(
393 boxProc
.getEleTy().cast
<mlir::FunctionType
>());
395 .Case([](fir::PointerType pointer
) {
396 return fir::hasAbstractResult(
397 pointer
.getEleTy().cast
<mlir::FunctionType
>());
399 .Default([](auto &&) { return false; });
402 class AbstractResultOnGlobalOpt
403 : public AbstractResultOptTemplate
<
404 AbstractResultOnGlobalOpt
, fir::impl::AbstractResultOnGlobalOptBase
> {
406 void runOnSpecificOperation(fir::GlobalOp global
, bool,
407 mlir::RewritePatternSet
&,
408 mlir::ConversionTarget
&) {
409 if (containsFunctionTypeWithAbstractResult(global
.getType())) {
410 TODO(global
->getLoc(), "support for procedure pointers");
414 } // end anonymous namespace
417 std::unique_ptr
<mlir::Pass
> fir::createAbstractResultOnFuncOptPass() {
418 return std::make_unique
<AbstractResultOnFuncOpt
>();
421 std::unique_ptr
<mlir::Pass
> fir::createAbstractResultOnGlobalOptPass() {
422 return std::make_unique
<AbstractResultOnGlobalOpt
>();