[lldb] Add ability to hide the root name of a value
[llvm-project.git] / flang / lib / Optimizer / Transforms / AbstractResult.cpp
blob2c0576eaa5cc4d11363052089da1435c9f72c82d
1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/Builder/FIRBuilder.h"
10 #include "flang/Optimizer/Builder/Todo.h"
11 #include "flang/Optimizer/Dialect/FIRDialect.h"
12 #include "flang/Optimizer/Dialect/FIROps.h"
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
15 #include "flang/Optimizer/Transforms/Passes.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/Diagnostics.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "mlir/Transforms/Passes.h"
21 #include "llvm/ADT/TypeSwitch.h"
23 namespace fir {
24 #define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
25 #define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
26 #include "flang/Optimizer/Transforms/Passes.h.inc"
27 } // namespace fir
29 #define DEBUG_TYPE "flang-abstract-result-opt"
31 using namespace mlir;
33 namespace fir {
34 namespace {
36 static mlir::Type getResultArgumentType(mlir::Type resultType,
37 bool shouldBoxResult) {
38 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
39 .Case<fir::SequenceType, fir::RecordType>(
40 [&](mlir::Type type) -> mlir::Type {
41 if (shouldBoxResult)
42 return fir::BoxType::get(type);
43 return fir::ReferenceType::get(type);
45 .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
46 return fir::ReferenceType::get(type);
48 .Default([](mlir::Type) -> mlir::Type {
49 llvm_unreachable("bad abstract result type");
50 });
53 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
54 bool shouldBoxResult) {
55 auto resultType = funcTy.getResult(0);
56 auto argTy = getResultArgumentType(resultType, shouldBoxResult);
57 llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
58 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
59 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
60 /*resultTypes=*/{});
63 /// This is for function result types that are of type C_PTR from ISO_C_BINDING.
64 /// Follow the ABI for interoperability with C.
65 static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
66 auto resultType = funcTy.getResult(0);
67 assert(fir::isa_builtin_cptr_type(resultType));
68 llvm::SmallVector<mlir::Type> outputTypes;
69 auto recTy = resultType.dyn_cast<fir::RecordType>();
70 outputTypes.emplace_back(recTy.getTypeList()[0].second);
71 return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
72 outputTypes);
75 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
76 return resultType.isa<fir::SequenceType, fir::RecordType>() &&
77 shouldBoxResult;
80 template <typename Op>
81 class CallConversion : public mlir::OpRewritePattern<Op> {
82 public:
83 using mlir::OpRewritePattern<Op>::OpRewritePattern;
85 CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
86 : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}
88 mlir::LogicalResult
89 matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
90 auto loc = op.getLoc();
91 auto result = op->getResult(0);
92 if (!result.hasOneUse()) {
93 mlir::emitError(loc,
94 "calls with abstract result must have exactly one user");
95 return mlir::failure();
97 auto saveResult =
98 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
99 if (!saveResult) {
100 mlir::emitError(
101 loc, "calls with abstract result must be used in fir.save_result");
102 return mlir::failure();
104 auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
105 auto buffer = saveResult.getMemref();
106 mlir::Value arg = buffer;
107 if (mustEmboxResult(result.getType(), shouldBoxResult))
108 arg = rewriter.create<fir::EmboxOp>(
109 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
110 saveResult.getTypeparams());
112 llvm::SmallVector<mlir::Type> newResultTypes;
113 // TODO: This should be generalized for derived types, and it is
114 // architecture and OS dependent.
115 bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
116 Op newOp;
117 if (isResultBuiltinCPtr) {
118 auto recTy = result.getType().template dyn_cast<fir::RecordType>();
119 newResultTypes.emplace_back(recTy.getTypeList()[0].second);
122 // fir::CallOp specific handling.
123 if constexpr (std::is_same_v<Op, fir::CallOp>) {
124 if (op.getCallee()) {
125 llvm::SmallVector<mlir::Value> newOperands;
126 if (!isResultBuiltinCPtr)
127 newOperands.emplace_back(arg);
128 newOperands.append(op.getOperands().begin(), op.getOperands().end());
129 newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
130 newResultTypes, newOperands);
131 } else {
132 // Indirect calls.
133 llvm::SmallVector<mlir::Type> newInputTypes;
134 if (!isResultBuiltinCPtr)
135 newInputTypes.emplace_back(argType);
136 for (auto operand : op.getOperands().drop_front())
137 newInputTypes.push_back(operand.getType());
138 auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
139 newResultTypes);
141 llvm::SmallVector<mlir::Value> newOperands;
142 newOperands.push_back(
143 rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
144 if (!isResultBuiltinCPtr)
145 newOperands.push_back(arg);
146 newOperands.append(op.getOperands().begin() + 1,
147 op.getOperands().end());
148 newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
149 newResultTypes, newOperands);
153 // fir::DispatchOp specific handling.
154 if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
155 llvm::SmallVector<mlir::Value> newOperands;
156 if (!isResultBuiltinCPtr)
157 newOperands.emplace_back(arg);
158 unsigned passArgShift = newOperands.size();
159 newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
161 fir::DispatchOp newDispatchOp;
162 if (op.getPassArgPos())
163 newOp = rewriter.create<fir::DispatchOp>(
164 loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
165 op.getOperands()[0], newOperands,
166 rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift));
167 else
168 newOp = rewriter.create<fir::DispatchOp>(
169 loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
170 op.getOperands()[0], newOperands, nullptr);
173 if (isResultBuiltinCPtr) {
174 mlir::Value save = saveResult.getMemref();
175 auto module = op->template getParentOfType<mlir::ModuleOp>();
176 fir::KindMapping kindMap = fir::getKindMapping(module);
177 FirOpBuilder builder(rewriter, kindMap);
178 mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
179 builder, loc, save, result.getType());
180 rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr);
182 op->dropAllReferences();
183 rewriter.eraseOp(op);
184 return mlir::success();
187 private:
188 bool shouldBoxResult;
191 class SaveResultOpConversion
192 : public mlir::OpRewritePattern<fir::SaveResultOp> {
193 public:
194 using OpRewritePattern::OpRewritePattern;
195 SaveResultOpConversion(mlir::MLIRContext *context)
196 : OpRewritePattern(context) {}
197 mlir::LogicalResult
198 matchAndRewrite(fir::SaveResultOp op,
199 mlir::PatternRewriter &rewriter) const override {
200 rewriter.eraseOp(op);
201 return mlir::success();
205 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
206 public:
207 using OpRewritePattern::OpRewritePattern;
208 ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
209 : OpRewritePattern(context), newArg{newArg} {}
210 mlir::LogicalResult
211 matchAndRewrite(mlir::func::ReturnOp ret,
212 mlir::PatternRewriter &rewriter) const override {
213 auto loc = ret.getLoc();
214 rewriter.setInsertionPoint(ret);
215 auto returnedValue = ret.getOperand(0);
216 bool replacedStorage = false;
217 if (auto *op = returnedValue.getDefiningOp())
218 if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
219 auto resultStorage = load.getMemref();
220 // TODO: This should be generalized for derived types, and it is
221 // architecture and OS dependent.
222 if (fir::isa_builtin_cptr_type(returnedValue.getType())) {
223 rewriter.eraseOp(load);
224 auto module = ret->getParentOfType<mlir::ModuleOp>();
225 fir::KindMapping kindMap = fir::getKindMapping(module);
226 FirOpBuilder builder(rewriter, kindMap);
227 mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr(
228 builder, loc, resultStorage, returnedValue.getType());
229 mlir::Value retValue = rewriter.create<fir::LoadOp>(
230 loc, fir::unwrapRefType(retAddr.getType()), retAddr);
231 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
232 ret, mlir::ValueRange{retValue});
233 return mlir::success();
235 load.getMemref().replaceAllUsesWith(newArg);
236 replacedStorage = true;
237 if (auto *alloc = resultStorage.getDefiningOp())
238 if (alloc->use_empty())
239 rewriter.eraseOp(alloc);
241 // The result storage may have been optimized out by a memory to
242 // register pass, this is possible for fir.box results, or fir.record
243 // with no length parameters. Simply store the result in the result storage.
244 // at the return point.
245 if (!replacedStorage)
246 rewriter.create<fir::StoreOp>(loc, returnedValue, newArg);
247 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
248 return mlir::success();
251 private:
252 mlir::Value newArg;
255 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
256 public:
257 using OpRewritePattern::OpRewritePattern;
258 AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
259 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
260 mlir::LogicalResult
261 matchAndRewrite(fir::AddrOfOp addrOf,
262 mlir::PatternRewriter &rewriter) const override {
263 auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
264 mlir::FunctionType newFuncTy;
265 // TODO: This should be generalized for derived types, and it is
266 // architecture and OS dependent.
267 if (oldFuncTy.getNumResults() != 0 &&
268 fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
269 newFuncTy = getCPtrFunctionType(oldFuncTy);
270 else
271 newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
272 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
273 addrOf.getSymbol());
274 // Rather than converting all op a function pointer might transit through
275 // (e.g calls, stores, loads, converts...), cast new type to the abstract
276 // type. A conversion will be added when calling indirect calls of abstract
277 // types.
278 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
279 return mlir::success();
282 private:
283 bool shouldBoxResult;
286 /// @brief Base CRTP class for AbstractResult pass family.
287 /// Contains common logic for abstract result conversion in a reusable fashion.
288 /// @tparam Pass target class that implements operation-specific logic.
289 /// @tparam PassBase base class template for the pass generated by TableGen.
290 /// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
291 /// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
292 /// This function should implement operation-specific functionality.
293 template <typename Pass, template <typename> class PassBase>
294 class AbstractResultOptTemplate : public PassBase<Pass> {
295 public:
296 void runOnOperation() override {
297 auto *context = &this->getContext();
298 auto op = this->getOperation();
300 mlir::RewritePatternSet patterns(context);
301 mlir::ConversionTarget target = *context;
302 const bool shouldBoxResult = this->passResultAsBox.getValue();
304 auto &self = static_cast<Pass &>(*this);
305 self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);
307 // Convert the calls and, if needed, the ReturnOp in the function body.
308 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
309 mlir::func::FuncDialect>();
310 target.addIllegalOp<fir::SaveResultOp>();
311 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
312 return !hasAbstractResult(call.getFunctionType());
314 target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
315 if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
316 return !hasAbstractResult(funTy);
317 return true;
319 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
320 return !hasAbstractResult(dispatch.getFunctionType());
323 patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
324 patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
325 patterns.insert<SaveResultOpConversion>(context);
326 patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
327 if (mlir::failed(
328 mlir::applyPartialConversion(op, target, std::move(patterns)))) {
329 mlir::emitError(op.getLoc(), "error in converting abstract results\n");
330 this->signalPassFailure();
335 class AbstractResultOnFuncOpt
336 : public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
337 fir::impl::AbstractResultOnFuncOptBase> {
338 public:
339 void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
340 mlir::RewritePatternSet &patterns,
341 mlir::ConversionTarget &target) {
342 auto loc = func.getLoc();
343 auto *context = &getContext();
344 // Convert function type itself if it has an abstract result.
345 auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
346 if (hasAbstractResult(funcTy)) {
347 // TODO: This should be generalized for derived types, and it is
348 // architecture and OS dependent.
349 if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
350 func.setType(getCPtrFunctionType(funcTy));
351 patterns.insert<ReturnOpConversion>(context, mlir::Value{});
352 target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
353 [](mlir::func::ReturnOp ret) {
354 mlir::Type retTy = ret.getOperand(0).getType();
355 return !fir::isa_builtin_cptr_type(retTy);
357 return;
359 if (!func.empty()) {
360 // Insert new argument.
361 mlir::OpBuilder rewriter(context);
362 auto resultType = funcTy.getResult(0);
363 auto argTy = getResultArgumentType(resultType, shouldBoxResult);
364 func.insertArgument(0u, argTy, {}, loc);
365 func.eraseResult(0u);
366 mlir::Value newArg = func.getArgument(0u);
367 if (mustEmboxResult(resultType, shouldBoxResult)) {
368 auto bufferType = fir::ReferenceType::get(resultType);
369 rewriter.setInsertionPointToStart(&func.front());
370 newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
372 patterns.insert<ReturnOpConversion>(context, newArg);
373 target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
374 [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
375 assert(func.getFunctionType() ==
376 getNewFunctionType(funcTy, shouldBoxResult));
377 } else {
378 llvm::SmallVector<mlir::DictionaryAttr> allArgs;
379 func.getAllArgAttrs(allArgs);
380 allArgs.insert(allArgs.begin(),
381 mlir::DictionaryAttr::get(func->getContext()));
382 func.setType(getNewFunctionType(funcTy, shouldBoxResult));
383 func.setAllArgAttrs(allArgs);
389 inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
390 return mlir::TypeSwitch<mlir::Type, bool>(type)
391 .Case([](fir::BoxProcType boxProc) {
392 return fir::hasAbstractResult(
393 boxProc.getEleTy().cast<mlir::FunctionType>());
395 .Case([](fir::PointerType pointer) {
396 return fir::hasAbstractResult(
397 pointer.getEleTy().cast<mlir::FunctionType>());
399 .Default([](auto &&) { return false; });
402 class AbstractResultOnGlobalOpt
403 : public AbstractResultOptTemplate<
404 AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> {
405 public:
406 void runOnSpecificOperation(fir::GlobalOp global, bool,
407 mlir::RewritePatternSet &,
408 mlir::ConversionTarget &) {
409 if (containsFunctionTypeWithAbstractResult(global.getType())) {
410 TODO(global->getLoc(), "support for procedure pointers");
414 } // end anonymous namespace
415 } // namespace fir
417 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
418 return std::make_unique<AbstractResultOnFuncOpt>();
421 std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() {
422 return std::make_unique<AbstractResultOnGlobalOpt>();