//===-- PolymorphicOpConversion.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Lower/BuiltinModules.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Optimizer/Support/TypeCode.h"
#include "flang/Optimizer/Support/Utils.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "flang/Runtime/derived-api.h"
#include "flang/Semantics/runtime-type-info.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/CommandLine.h"

// The generated pass base class must live in namespace fir so that it is
// reachable as fir::impl::PolymorphicOpConversionBase.
namespace fir {
#define GEN_PASS_DEF_POLYMORPHICOPCONVERSION
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

using namespace fir;
using namespace mlir;
43 /// SelectTypeOp converted to an if-then-else chain
45 /// This lowers the test conditions to calls into the runtime.
46 class SelectTypeConv
: public OpConversionPattern
<fir::SelectTypeOp
> {
48 using OpConversionPattern
<fir::SelectTypeOp
>::OpConversionPattern
;
50 SelectTypeConv(mlir::MLIRContext
*ctx
)
51 : mlir::OpConversionPattern
<fir::SelectTypeOp
>(ctx
) {}
54 matchAndRewrite(fir::SelectTypeOp selectType
, OpAdaptor adaptor
,
55 mlir::ConversionPatternRewriter
&rewriter
) const override
;
58 // Generate comparison of type descriptor addresses.
59 mlir::Value
genTypeDescCompare(mlir::Location loc
, mlir::Value selector
,
60 mlir::Type ty
, mlir::ModuleOp mod
,
61 mlir::PatternRewriter
&rewriter
) const;
63 llvm::LogicalResult
genTypeLadderStep(mlir::Location loc
,
65 mlir::Attribute attr
, mlir::Block
*dest
,
66 std::optional
<mlir::ValueRange
> destOps
,
68 mlir::PatternRewriter
&rewriter
,
69 fir::KindMapping
&kindMap
) const;
71 llvm::SmallSet
<llvm::StringRef
, 4> collectAncestors(fir::TypeInfoOp dt
,
72 mlir::ModuleOp mod
) const;
75 /// Lower `fir.dispatch` operation. A virtual call to a method in a dispatch
77 struct DispatchOpConv
: public OpConversionPattern
<fir::DispatchOp
> {
78 using OpConversionPattern
<fir::DispatchOp
>::OpConversionPattern
;
80 DispatchOpConv(mlir::MLIRContext
*ctx
, const BindingTables
&bindingTables
)
81 : mlir::OpConversionPattern
<fir::DispatchOp
>(ctx
),
82 bindingTables(bindingTables
) {}
85 matchAndRewrite(fir::DispatchOp dispatch
, OpAdaptor adaptor
,
86 mlir::ConversionPatternRewriter
&rewriter
) const override
{
87 mlir::Location loc
= dispatch
.getLoc();
89 if (bindingTables
.empty())
90 return emitError(loc
) << "no binding tables found";
92 // Get derived type information.
93 mlir::Type declaredType
=
94 fir::getDerivedType(dispatch
.getObject().getType().getEleTy());
95 assert(mlir::isa
<fir::RecordType
>(declaredType
) && "expecting fir.type");
96 auto recordType
= mlir::dyn_cast
<fir::RecordType
>(declaredType
);
98 // Lookup for the binding table.
99 auto bindingsIter
= bindingTables
.find(recordType
.getName());
100 if (bindingsIter
== bindingTables
.end())
101 return emitError(loc
)
102 << "cannot find binding table for " << recordType
.getName();
104 // Lookup for the binding.
105 const BindingTable
&bindingTable
= bindingsIter
->second
;
106 auto bindingIter
= bindingTable
.find(dispatch
.getMethod());
107 if (bindingIter
== bindingTable
.end())
108 return emitError(loc
)
109 << "cannot find binding for " << dispatch
.getMethod();
110 unsigned bindingIdx
= bindingIter
->second
;
112 mlir::Value passedObject
= dispatch
.getObject();
114 auto module
= dispatch
.getOperation()->getParentOfType
<mlir::ModuleOp
>();
116 std::string typeDescName
=
117 NameUniquer::getTypeDescriptorName(recordType
.getName());
118 if (auto global
= module
.lookupSymbol
<fir::GlobalOp
>(typeDescName
)) {
119 typeDescTy
= global
.getType();
124 // fir.dispatch "proc1"(%11 :
125 // !fir.class<!fir.heap<!fir.type<_QMpolyTp1{a:i32,b:i32}>>>)
128 // %12 = fir.box_tdesc %11 : (!fir.class<!fir.heap<!fir.type<_QMpolyTp1{a:i32,b:i32}>>>) -> !fir.tdesc<none>
129 // %13 = fir.convert %12 : (!fir.tdesc<none>) -> !fir.ref<!fir.type<_QM__fortran_type_infoTderivedtype>>
130 // %14 = fir.field_index binding, !fir.type<_QM__fortran_type_infoTderivedtype>
131 // %15 = fir.coordinate_of %13, %14 : (!fir.ref<!fir.type<_QM__fortran_type_infoTderivedtype>>, !fir.field) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?x!fir.type<_QM__fortran_type_infoTbinding>>>>>
132 // %bindings = fir.load %15 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x!fir.type<_QM__fortran_type_infoTbinding>>>>>
133 // %16 = fir.box_addr %bindings : (!fir.box<!fir.ptr<!fir.array<?x!fir.type<_QM__fortran_type_infoTbinding>>>>) -> !fir.ptr<!fir.array<?x!fir.type<_QM__fortran_type_infoTbinding>>>
134 // %17 = fir.coordinate_of %16, %c0 : (!fir.ptr<!fir.array<?x!fir.type<_QM__fortran_type_infoTbinding>>>, index) -> !fir.ref<!fir.type<_QM__fortran_type_infoTbinding>>
135 // %18 = fir.field_index proc, !fir.type<_QM__fortran_type_infoTbinding>
136 // %19 = fir.coordinate_of %17, %18 : (!fir.ref<!fir.type<_QM__fortran_type_infoTbinding>>, !fir.field) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_funptr>>
137 // %20 = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_funptr>
138 // %21 = fir.coordinate_of %19, %20 : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_funptr>>, !fir.field) -> !fir.ref<i64>
139 // %22 = fir.load %21 : !fir.ref<i64>
140 // %23 = fir.convert %22 : (i64) -> (() -> ())
141 // fir.call %23() : () -> ()
144 // Load the descriptor.
145 mlir::Type fieldTy
= fir::FieldType::get(rewriter
.getContext());
146 mlir::Type tdescType
=
147 fir::TypeDescType::get(mlir::NoneType::get(rewriter
.getContext()));
148 mlir::Value boxDesc
=
149 rewriter
.create
<fir::BoxTypeDescOp
>(loc
, tdescType
, passedObject
);
150 boxDesc
= rewriter
.create
<fir::ConvertOp
>(
151 loc
, fir::ReferenceType::get(typeDescTy
), boxDesc
);
153 // Load the bindings descriptor.
154 auto bindingsCompName
= Fortran::semantics::bindingDescCompName
;
155 fir::RecordType typeDescRecTy
= mlir::cast
<fir::RecordType
>(typeDescTy
);
156 mlir::Value field
= rewriter
.create
<fir::FieldIndexOp
>(
157 loc
, fieldTy
, bindingsCompName
, typeDescRecTy
, mlir::ValueRange
{});
159 fir::ReferenceType::get(typeDescRecTy
.getType(bindingsCompName
));
160 mlir::Value bindingBoxAddr
=
161 rewriter
.create
<fir::CoordinateOp
>(loc
, coorTy
, boxDesc
, field
);
162 mlir::Value bindingBox
= rewriter
.create
<fir::LoadOp
>(loc
, bindingBoxAddr
);
164 // Load the correct binding.
165 mlir::Value bindings
= rewriter
.create
<fir::BoxAddrOp
>(loc
, bindingBox
);
166 fir::RecordType bindingTy
= fir::unwrapIfDerived(
167 mlir::cast
<fir::BaseBoxType
>(bindingBox
.getType()));
168 mlir::Type bindingAddrTy
= fir::ReferenceType::get(bindingTy
);
169 mlir::Value bindingIdxVal
= rewriter
.create
<mlir::arith::ConstantOp
>(
170 loc
, rewriter
.getIndexType(), rewriter
.getIndexAttr(bindingIdx
));
171 mlir::Value bindingAddr
= rewriter
.create
<fir::CoordinateOp
>(
172 loc
, bindingAddrTy
, bindings
, bindingIdxVal
);
174 // Get the function pointer.
175 auto procCompName
= Fortran::semantics::procCompName
;
176 mlir::Value procField
= rewriter
.create
<fir::FieldIndexOp
>(
177 loc
, fieldTy
, procCompName
, bindingTy
, mlir::ValueRange
{});
178 fir::RecordType procTy
=
179 mlir::cast
<fir::RecordType
>(bindingTy
.getType(procCompName
));
180 mlir::Type procRefTy
= fir::ReferenceType::get(procTy
);
181 mlir::Value procRef
= rewriter
.create
<fir::CoordinateOp
>(
182 loc
, procRefTy
, bindingAddr
, procField
);
184 auto addressFieldName
= Fortran::lower::builtin::cptrFieldName
;
185 mlir::Value addressField
= rewriter
.create
<fir::FieldIndexOp
>(
186 loc
, fieldTy
, addressFieldName
, procTy
, mlir::ValueRange
{});
187 mlir::Type addressTy
= procTy
.getType(addressFieldName
);
188 mlir::Type addressRefTy
= fir::ReferenceType::get(addressTy
);
189 mlir::Value addressRef
= rewriter
.create
<fir::CoordinateOp
>(
190 loc
, addressRefTy
, procRef
, addressField
);
191 mlir::Value address
= rewriter
.create
<fir::LoadOp
>(loc
, addressRef
);
193 // Get the function type.
194 llvm::SmallVector
<mlir::Type
> argTypes
;
195 for (mlir::Value operand
: dispatch
.getArgs())
196 argTypes
.push_back(operand
.getType());
197 llvm::SmallVector
<mlir::Type
> resTypes
;
198 if (!dispatch
.getResults().empty())
199 resTypes
.push_back(dispatch
.getResults()[0].getType());
202 mlir::FunctionType::get(rewriter
.getContext(), argTypes
, resTypes
);
203 mlir::Value funcPtr
= rewriter
.create
<fir::ConvertOp
>(loc
, funTy
, address
);
206 llvm::SmallVector
<mlir::Value
> args
{funcPtr
};
207 args
.append(dispatch
.getArgs().begin(), dispatch
.getArgs().end());
208 rewriter
.replaceOpWithNewOp
<fir::CallOp
>(dispatch
, resTypes
, nullptr, args
,
209 dispatch
.getProcedureAttrsAttr());
210 return mlir::success();
214 BindingTables bindingTables
;
217 /// Convert FIR structured control flow ops to CFG ops.
218 class PolymorphicOpConversion
219 : public fir::impl::PolymorphicOpConversionBase
<PolymorphicOpConversion
> {
221 llvm::LogicalResult
initialize(mlir::MLIRContext
*ctx
) override
{
222 return mlir::success();
225 void runOnOperation() override
{
226 auto *context
= &getContext();
227 mlir::ModuleOp mod
= getOperation();
228 mlir::RewritePatternSet
patterns(context
);
230 BindingTables bindingTables
;
231 buildBindingTables(bindingTables
, mod
);
233 patterns
.insert
<SelectTypeConv
>(context
);
234 patterns
.insert
<DispatchOpConv
>(context
, bindingTables
);
235 mlir::ConversionTarget
target(*context
);
236 target
.addLegalDialect
<mlir::affine::AffineDialect
,
237 mlir::cf::ControlFlowDialect
, FIROpsDialect
,
238 mlir::func::FuncDialect
>();
240 // apply the patterns
241 target
.addIllegalOp
<SelectTypeOp
>();
242 target
.addIllegalOp
<DispatchOp
>();
243 target
.markUnknownOpDynamicallyLegal([](Operation
*) { return true; });
244 if (mlir::failed(mlir::applyPartialConversion(getOperation(), target
,
245 std::move(patterns
)))) {
246 mlir::emitError(mlir::UnknownLoc::get(context
),
247 "error in converting to CFG\n");
254 llvm::LogicalResult
SelectTypeConv::matchAndRewrite(
255 fir::SelectTypeOp selectType
, OpAdaptor adaptor
,
256 mlir::ConversionPatternRewriter
&rewriter
) const {
257 auto operands
= adaptor
.getOperands();
258 auto typeGuards
= selectType
.getCases();
259 unsigned typeGuardNum
= typeGuards
.size();
260 auto selector
= selectType
.getSelector();
261 auto loc
= selectType
.getLoc();
262 auto mod
= selectType
.getOperation()->getParentOfType
<mlir::ModuleOp
>();
263 fir::KindMapping kindMap
= fir::getKindMapping(mod
);
265 // Order type guards so the condition and branches are done to respect the
266 // Execution of SELECT TYPE construct as described in the Fortran 2018
267 // standard 11.1.11.2 point 4.
268 // 1. If a TYPE IS type guard statement matches the selector, the block
269 // following that statement is executed.
270 // 2. Otherwise, if exactly one CLASS IS type guard statement matches the
271 // selector, the block following that statement is executed.
272 // 3. Otherwise, if several CLASS IS type guard statements match the
273 // selector, one of these statements will inevitably specify a type that
274 // is an extension of all the types specified in the others; the block
275 // following that statement is executed.
276 // 4. Otherwise, if there is a CLASS DEFAULT type guard statement, the block
277 // following that statement is executed.
278 // 5. Otherwise, no block is executed.
280 llvm::SmallVector
<unsigned> orderedTypeGuards
;
281 llvm::SmallVector
<unsigned> orderedClassIsGuards
;
282 unsigned defaultGuard
= typeGuardNum
- 1;
284 // The following loop go through the type guards in the fir.select_type
285 // operation and sort them into two lists.
286 // - All the TYPE IS type guard are added in order to the orderedTypeGuards
287 // list. This list is used at the end to generate the if-then-else ladder.
288 // - CLASS IS type guard are added in a separate list. If a CLASS IS type
289 // guard type extends a type already present, the type guard is inserted
290 // before in the list to respect point 3. above. Otherwise it is just
291 // added in order at the end.
292 for (unsigned t
= 0; t
< typeGuardNum
; ++t
) {
293 if (auto a
= mlir::dyn_cast
<fir::ExactTypeAttr
>(typeGuards
[t
])) {
294 orderedTypeGuards
.push_back(t
);
298 if (auto a
= mlir::dyn_cast
<fir::SubclassAttr
>(typeGuards
[t
])) {
299 if (auto recTy
= mlir::dyn_cast
<fir::RecordType
>(a
.getType())) {
300 auto dt
= mod
.lookupSymbol
<fir::TypeInfoOp
>(recTy
.getName());
301 assert(dt
&& "dispatch table not found");
302 llvm::SmallSet
<llvm::StringRef
, 4> ancestors
=
303 collectAncestors(dt
, mod
);
304 if (!ancestors
.empty()) {
305 auto it
= orderedClassIsGuards
.begin();
306 while (it
!= orderedClassIsGuards
.end()) {
307 fir::SubclassAttr sAttr
=
308 mlir::dyn_cast
<fir::SubclassAttr
>(typeGuards
[*it
]);
309 if (auto ty
= mlir::dyn_cast
<fir::RecordType
>(sAttr
.getType())) {
310 if (ancestors
.contains(ty
.getName()))
315 if (it
!= orderedClassIsGuards
.end()) {
316 // Parent type is present so place it before.
317 orderedClassIsGuards
.insert(it
, t
);
322 orderedClassIsGuards
.push_back(t
);
325 orderedTypeGuards
.append(orderedClassIsGuards
);
326 orderedTypeGuards
.push_back(defaultGuard
);
327 assert(orderedTypeGuards
.size() == typeGuardNum
&&
328 "ordered type guard size doesn't match number of type guards");
330 for (unsigned idx
: orderedTypeGuards
) {
331 auto *dest
= selectType
.getSuccessor(idx
);
332 std::optional
<mlir::ValueRange
> destOps
=
333 selectType
.getSuccessorOperands(operands
, idx
);
334 if (mlir::dyn_cast
<mlir::UnitAttr
>(typeGuards
[idx
]))
335 rewriter
.replaceOpWithNewOp
<mlir::cf::BranchOp
>(
336 selectType
, dest
, destOps
.value_or(mlir::ValueRange
{}));
337 else if (mlir::failed(genTypeLadderStep(loc
, selector
, typeGuards
[idx
],
338 dest
, destOps
, mod
, rewriter
,
340 return mlir::failure();
342 return mlir::success();
345 llvm::LogicalResult
SelectTypeConv::genTypeLadderStep(
346 mlir::Location loc
, mlir::Value selector
, mlir::Attribute attr
,
347 mlir::Block
*dest
, std::optional
<mlir::ValueRange
> destOps
,
348 mlir::ModuleOp mod
, mlir::PatternRewriter
&rewriter
,
349 fir::KindMapping
&kindMap
) const {
351 // TYPE IS type guard comparison are all done inlined.
352 if (auto a
= mlir::dyn_cast
<fir::ExactTypeAttr
>(attr
)) {
353 if (fir::isa_trivial(a
.getType()) ||
354 mlir::isa
<fir::CharacterType
>(a
.getType())) {
355 // For type guard statement with Intrinsic type spec the type code of
356 // the descriptor is compared.
357 int code
= fir::getTypeCode(a
.getType(), kindMap
);
359 return mlir::emitError(loc
)
360 << "type code unavailable for " << a
.getType();
361 mlir::Value typeCode
= rewriter
.create
<mlir::arith::ConstantOp
>(
362 loc
, rewriter
.getI8IntegerAttr(code
));
363 mlir::Value selectorTypeCode
= rewriter
.create
<fir::BoxTypeCodeOp
>(
364 loc
, rewriter
.getI8Type(), selector
);
365 cmp
= rewriter
.create
<mlir::arith::CmpIOp
>(
366 loc
, mlir::arith::CmpIPredicate::eq
, selectorTypeCode
, typeCode
);
368 // Flang inline the kind parameter in the type descriptor so we can
369 // directly check if the type descriptor addresses are identical for
370 // the TYPE IS type guard statement.
372 genTypeDescCompare(loc
, selector
, a
.getType(), mod
, rewriter
);
374 return mlir::failure();
377 // CLASS IS type guard statement is done with a runtime call.
378 } else if (auto a
= mlir::dyn_cast
<fir::SubclassAttr
>(attr
)) {
379 // Retrieve the type descriptor from the type guard statement record type.
380 assert(mlir::isa
<fir::RecordType
>(a
.getType()) && "expect fir.record type");
381 fir::RecordType recTy
= mlir::dyn_cast
<fir::RecordType
>(a
.getType());
382 std::string typeDescName
=
383 fir::NameUniquer::getTypeDescriptorName(recTy
.getName());
384 auto typeDescGlobal
= mod
.lookupSymbol
<fir::GlobalOp
>(typeDescName
);
385 auto typeDescAddr
= rewriter
.create
<fir::AddrOfOp
>(
386 loc
, fir::ReferenceType::get(typeDescGlobal
.getType()),
387 typeDescGlobal
.getSymbol());
388 mlir::Type typeDescTy
= ReferenceType::get(rewriter
.getNoneType());
389 mlir::Value typeDesc
=
390 rewriter
.create
<ConvertOp
>(loc
, typeDescTy
, typeDescAddr
);
392 // Prepare the selector descriptor for the runtime call.
393 mlir::Type descNoneTy
= fir::BoxType::get(rewriter
.getNoneType());
394 mlir::Value descSelector
=
395 rewriter
.create
<ConvertOp
>(loc
, descNoneTy
, selector
);
397 // Generate runtime call.
398 llvm::StringRef fctName
= RTNAME_STRING(ClassIs
);
399 mlir::func::FuncOp callee
;
401 // Since conversion is done in parallel for each fir.select_type
402 // operation, the runtime function insertion must be threadsafe.
404 fir::createFuncOp(rewriter
.getUnknownLoc(), mod
, fctName
,
405 rewriter
.getFunctionType({descNoneTy
, typeDescTy
},
406 rewriter
.getI1Type()));
409 .create
<fir::CallOp
>(loc
, callee
,
410 mlir::ValueRange
{descSelector
, typeDesc
})
414 auto *thisBlock
= rewriter
.getInsertionBlock();
416 rewriter
.createBlock(dest
->getParent(), mlir::Region::iterator(dest
));
417 rewriter
.setInsertionPointToEnd(thisBlock
);
418 if (destOps
.has_value())
419 rewriter
.create
<mlir::cf::CondBranchOp
>(loc
, cmp
, dest
, destOps
.value(),
420 newBlock
, std::nullopt
);
422 rewriter
.create
<mlir::cf::CondBranchOp
>(loc
, cmp
, dest
, newBlock
);
423 rewriter
.setInsertionPointToEnd(newBlock
);
424 return mlir::success();
427 // Generate comparison of type descriptor addresses.
429 SelectTypeConv::genTypeDescCompare(mlir::Location loc
, mlir::Value selector
,
430 mlir::Type ty
, mlir::ModuleOp mod
,
431 mlir::PatternRewriter
&rewriter
) const {
432 assert(mlir::isa
<fir::RecordType
>(ty
) && "expect fir.record type");
433 fir::RecordType recTy
= mlir::dyn_cast
<fir::RecordType
>(ty
);
434 std::string typeDescName
=
435 fir::NameUniquer::getTypeDescriptorName(recTy
.getName());
436 auto typeDescGlobal
= mod
.lookupSymbol
<fir::GlobalOp
>(typeDescName
);
439 auto typeDescAddr
= rewriter
.create
<fir::AddrOfOp
>(
440 loc
, fir::ReferenceType::get(typeDescGlobal
.getType()),
441 typeDescGlobal
.getSymbol());
442 auto intPtrTy
= rewriter
.getIndexType();
443 mlir::Type tdescType
=
444 fir::TypeDescType::get(mlir::NoneType::get(rewriter
.getContext()));
445 mlir::Value selectorTdescAddr
=
446 rewriter
.create
<fir::BoxTypeDescOp
>(loc
, tdescType
, selector
);
448 rewriter
.create
<fir::ConvertOp
>(loc
, intPtrTy
, typeDescAddr
);
449 auto selectorTdescInt
=
450 rewriter
.create
<fir::ConvertOp
>(loc
, intPtrTy
, selectorTdescAddr
);
451 return rewriter
.create
<mlir::arith::CmpIOp
>(
452 loc
, mlir::arith::CmpIPredicate::eq
, typeDescInt
, selectorTdescInt
);
455 llvm::SmallSet
<llvm::StringRef
, 4>
456 SelectTypeConv::collectAncestors(fir::TypeInfoOp dt
, mlir::ModuleOp mod
) const {
457 llvm::SmallSet
<llvm::StringRef
, 4> ancestors
;
458 while (auto parentName
= dt
.getIfParentName()) {
459 ancestors
.insert(*parentName
);
460 dt
= mod
.lookupSymbol
<fir::TypeInfoOp
>(*parentName
);
461 assert(dt
&& "parent type info not generated");