1 //===-- BoxedProcedure.cpp ------------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/CodeGen/CodeGen.h"
11 #include "flang/Optimizer/Builder/FIRBuilder.h"
12 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
13 #include "flang/Optimizer/Dialect/FIRDialect.h"
14 #include "flang/Optimizer/Dialect/FIROps.h"
15 #include "flang/Optimizer/Dialect/FIRType.h"
16 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
17 #include "flang/Optimizer/Support/FatalError.h"
18 #include "flang/Optimizer/Support/InternalNames.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/DenseMap.h"
25 #define GEN_PASS_DEF_BOXEDPROCEDUREPASS
26 #include "flang/Optimizer/CodeGen/CGPasses.h.inc"
29 #define DEBUG_TYPE "flang-procedure-pointer"
34 /// Options to the procedure pointer pass.
35 struct BoxedProcedureOptions
{
36 // Lower the boxproc abstraction to function pointers and thunks where
// required. Defaults to true; the rewrite in `runOnOperation` is guarded by
// this flag.
38 bool useThunks
= true;
41 /// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types.
/// Types that merely contain a `!fir.boxproc` (tuples, functions, references,
/// boxes, sequences, records, ...) are rebuilt around their converted
/// element types by the conversions registered in the constructor.
42 class BoxprocTypeRewriter
: public mlir::TypeConverter
{
// Make the inherited `convertType` overloads visible as members of this
// class.
44 using mlir::TypeConverter::convertType
;
46 /// Does the type \p ty need to be converted?
47 /// Any type that is a `!fir.boxproc` in whole or in part will need to be
48 /// converted to a function type to lower the IR to function pointer form in
49 /// the default implementation performed in this pass. Other implementations
50 /// are possible, so those may convert `!fir.boxproc` to some other type or
51 /// not at all depending on the implementation target's characteristics and
/// preference.
///
/// The check recurses structurally through function, tuple, record, box,
/// reference, sequence and type-descriptor types. Answers for record types
/// are memoized in `visitedTypes`.
53 bool needsConversion(mlir::Type ty
) {
// Base case: the type is itself a `!fir.boxproc`.
54 if (mlir::isa
<BoxProcType
>(ty
))
// Function types: inspect every input type and every result type.
56 if (auto funcTy
= mlir::dyn_cast
<mlir::FunctionType
>(ty
)) {
57 for (auto t
: funcTy
.getInputs())
58 if (needsConversion(t
))
60 for (auto t
: funcTy
.getResults())
61 if (needsConversion(t
))
// Tuple types: inspect every member type.
65 if (auto tupleTy
= mlir::dyn_cast
<mlir::TupleType
>(ty
)) {
66 for (auto t
: tupleTy
.getTypes())
67 if (needsConversion(t
))
// Record types can refer to themselves, so their answer is cached. The cache
// entry is seeded with `false` before any component is examined.
71 if (auto recTy
= mlir::dyn_cast
<RecordType
>(ty
)) {
72 auto [visited
, inserted
] = visitedTypes
.try_emplace(ty
, false);
// Hand back the answer already stored for this type.
// NOTE(review): the condition guarding this return is not visible in this
// listing; presumably it tests `inserted` - confirm.
74 return visited
->second
;
// Save the "currently inside a record type" flag so it can be restored once
// the components have been walked.
75 bool wasAlreadyVisitingRecordType
= needConversionIsVisitingRecordType
;
76 needConversionIsVisitingRecordType
= true;
// Walk the components; `t.second` is the component's type.
78 for (auto t
: recTy
.getTypeList()) {
79 if (needsConversion(t
.second
)) {
84 // Only keep the result cached if the fir.type visited was a "top-level
85 // type". Nested types with a recursive reference to the "top-level type"
86 // may incorrectly have been resolved as not needed conversions because it
87 // had not been determined yet if the "top-level type" needed conversion.
88 // This is not an issue to determine the "top-level type" need of
89 // conversion, but the result should not be kept and later used in other
// contexts.
91 needConversionIsVisitingRecordType
= wasAlreadyVisitingRecordType
;
// Nested visit: drop the cache entry for this type. Top-level visit: store
// the final answer in the cache.
// NOTE(review): the declaration of `result` and the `else` separating the
// two statements below are presumably on lines not visible here - confirm.
92 if (needConversionIsVisitingRecordType
)
93 visitedTypes
.erase(ty
);
95 visitedTypes
.find(ty
)->second
= result
;
// Box, reference, sequence and type-descriptor types need conversion exactly
// when the type they wrap does.
98 if (auto boxTy
= mlir::dyn_cast
<BaseBoxType
>(ty
))
99 return needsConversion(boxTy
.getEleTy());
100 if (isa_ref_type(ty
))
101 return needsConversion(unwrapRefType(ty
));
102 if (auto t
= mlir::dyn_cast
<SequenceType
>(ty
))
103 return needsConversion(unwrapSequenceType(ty
));
104 if (auto t
= mlir::dyn_cast
<TypeDescType
>(ty
))
105 return needsConversion(t
.getOfTy());
/// Registers the type conversion callbacks and the materialization hooks.
/// \p location initializes `loc`.
109 BoxprocTypeRewriter(mlir::Location location
) : loc
{location
} {
// Identity conversion accepting any type.
// NOTE(review): MLIR is documented to try the most recently registered
// conversion first, so this presumably acts as the fallback - confirm.
110 addConversion([](mlir::Type ty
) { return ty
; });
// `!fir.boxproc<T>` becomes the conversion of `T`.
112 [&](BoxProcType boxproc
) { return convertType(boxproc
.getEleTy()); });
// Tuples are rebuilt from their converted member types.
113 addConversion([&](mlir::TupleType tupTy
) {
114 llvm::SmallVector
<mlir::Type
> memTys
;
115 for (auto ty
: tupTy
.getTypes())
116 memTys
.push_back(convertType(ty
));
117 return mlir::TupleType::get(tupTy
.getContext(), memTys
);
// Function types are rebuilt from their converted input and result types.
119 addConversion([&](mlir::FunctionType funcTy
) {
120 llvm::SmallVector
<mlir::Type
> inTys
;
121 llvm::SmallVector
<mlir::Type
> resTys
;
122 for (auto ty
: funcTy
.getInputs())
123 inTys
.push_back(convertType(ty
));
124 for (auto ty
: funcTy
.getResults())
125 resTys
.push_back(convertType(ty
));
126 return mlir::FunctionType::get(funcTy
.getContext(), inTys
, resTys
);
// Reference, pointer, heap, LLVM pointer, box and class types are each
// recreated around their converted element type.
128 addConversion([&](ReferenceType ty
) {
129 return ReferenceType::get(convertType(ty
.getEleTy()));
131 addConversion([&](PointerType ty
) {
132 return PointerType::get(convertType(ty
.getEleTy()));
135 [&](HeapType ty
) { return HeapType::get(convertType(ty
.getEleTy())); });
136 addConversion([&](fir::LLVMPointerType ty
) {
137 return fir::LLVMPointerType::get(convertType(ty
.getEleTy()));
140 [&](BoxType ty
) { return BoxType::get(convertType(ty
.getEleTy())); });
141 addConversion([&](ClassType ty
) {
142 return ClassType::get(convertType(ty
.getEleTy()));
// Sequences keep their shape; only the element type is converted.
144 addConversion([&](SequenceType ty
) {
145 // TODO: add ty.getLayoutMap() as needed.
146 return SequenceType::get(ty
.getShape(), convertType(ty
.getEleTy()));
// Record types: a new record named `<name><boxprocSuffix>` is created and
// its components converted on demand.
148 addConversion([&](RecordType ty
) -> mlir::Type
{
// Records without any `!fir.boxproc` component take the early exit here.
149 if (!needsConversion(ty
))
// Reuse the conversion computed earlier for this record, if any.
151 if (auto converted
= convertedTypes
.lookup(ty
))
153 auto rec
= RecordType::get(ty
.getContext(),
154 ty
.getName().str() + boxprocSuffix
.str());
155 if (rec
.isFinalized())
// The mapping is recorded before the components are converted, so a
// recursive reference to `ty` met below is answered by the lookup above.
157 [[maybe_unused
]] auto it
= convertedTypes
.try_emplace(ty
, rec
);
158 assert(it
.second
&& "expected ty to not be in the map");
// Length parameters are copied as-is; components are converted only when
// they need it.
159 std::vector
<RecordType::TypePair
> ps
= ty
.getLenParamList();
160 std::vector
<RecordType::TypePair
> cs
;
161 for (auto t
: ty
.getTypeList()) {
162 if (needsConversion(t
.second
))
163 cs
.emplace_back(t
.first
, convertType(t
.second
));
165 cs
.emplace_back(t
.first
, t
.second
);
// Complete the new record and carry over the packed flag of the original.
167 rec
.finalize(ps
, cs
);
168 rec
.pack(ty
.isPacked());
// Type descriptors are recreated around their converted described type.
171 addConversion([&](TypeDescType ty
) {
172 return TypeDescType::get(convertType(ty
.getOfTy()));
// The same callback serves as both source and target materialization.
174 addSourceMaterialization(materializeProcedure
);
175 addTargetMaterialization(materializeProcedure
);
/// Materialization callback registered by the constructor for both source and
/// target materializations. Expects exactly one input value and creates a
/// `ConvertOp` whose result type is `unwrapRefType(type.getEleTy())`.
/// NOTE(review): the declaration of the `type` parameter and the remaining
/// `ConvertOp` arguments are not visible in this listing - confirm against
/// the original source.
178 static mlir::Value
materializeProcedure(mlir::OpBuilder
&builder
,
180 mlir::ValueRange inputs
,
181 mlir::Location loc
) {
182 assert(inputs
.size() == 1);
183 return builder
.create
<ConvertOp
>(loc
, unwrapRefType(type
.getEleTy()),
/// Replaces the stored location `loc` with \p location. `runOnOperation`
/// calls this with the location of each operation it visits.
187 void setLocation(mlir::Location location
) { loc
= location
; }
190 // Maps to deal with recursive derived types (avoid infinite loops).
191 // Caching is also beneficial for apps with big types (dozens of
192 // components and or parent types), so the lifetime of the cache
193 // is the whole pass.
// `visitedTypes`: memoized `needsConversion` answers, keyed by record type.
194 llvm::DenseMap
<mlir::Type
, bool> visitedTypes
;
// Set while `needsConversion` is walking the components of a record type;
// used there to tell nested record visits from the top-level one.
195 bool needConversionIsVisitingRecordType
= false;
// `convertedTypes`: original record type -> its converted record type,
// filled in by the `RecordType` conversion.
196 llvm::DenseMap
<mlir::Type
, mlir::Type
> convertedTypes
;
200 /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically,
201 /// Fortran procedures can be referenced directly through a function pointer.
202 /// However, Fortran has one-level dynamic scoping between a host procedure and
203 /// its internal procedures. This allows internal procedures to directly access
204 /// and modify the state of the host procedure's variables.
206 /// Any number of implementations of this abstraction are possible.
208 /// The implementation used here is to convert `boxproc` values to function
209 /// pointers everywhere. If a `boxproc` value includes a frame pointer to the
210 /// host procedure's data, then a thunk will be created at runtime to capture
211 /// the frame pointer during execution. In LLVM IR, the frame pointer is
212 /// designated with the `nest` attribute. The thunk's address will then be used
213 /// as the call target instead of the original function's address directly.
214 class BoxedProcedurePass
215 : public fir::impl::BoxedProcedurePassBase
<BoxedProcedurePass
> {
// Inherit the constructors of the generated pass base class.
217 using BoxedProcedurePassBase
<BoxedProcedurePass
>::BoxedProcedurePassBase
;
/// Returns the operation this pass is running on, as an `mlir::ModuleOp`.
219 inline mlir::ModuleOp
getModule() { return getOperation(); }
/// Pass entry point. When `options.useThunks` is set, walks every operation
/// in the module and rewrites those whose types mention `!fir.boxproc`,
/// dispatching on the kind of operation (see the branches below).
221 void runOnOperation() override final
{
222 if (options
.useThunks
) {
223 auto *context
= &getContext();
224 mlir::IRRewriter
rewriter(context
);
// The converter starts with an unknown location; it is updated for each
// visited operation below.
225 BoxprocTypeRewriter
typeConverter(mlir::UnknownLoc::get(context
));
226 getModule().walk([&](mlir::Operation
*op
) {
// Consulted before the block-argument update at the end of this callback.
// NOTE(review): presumably cleared by the branches that replace `op`; those
// assignments are not visible in this listing - confirm.
227 bool opIsValid
= true;
228 typeConverter
.setLocation(op
->getLoc());
// --- fir.box_addr ---
229 if (auto addr
= mlir::dyn_cast
<BoxAddrOp
>(op
)) {
230 mlir::Type ty
= addr
.getVal().getType();
231 mlir::Type resTy
= addr
.getResult().getType();
232 if (llvm::isa
<mlir::FunctionType
>(ty
) ||
233 llvm::isa
<fir::BoxProcType
>(ty
)) {
234 // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
235 // or function type to be `fir.convert` ops.
236 rewriter
.setInsertionPoint(addr
);
237 rewriter
.replaceOpWithNewOp
<ConvertOp
>(
238 addr
, typeConverter
.convertType(addr
.getType()), addr
.getVal());
// Any other `fir.box_addr` only has its result type updated in place.
240 } else if (typeConverter
.needsConversion(resTy
)) {
241 rewriter
.startOpModification(op
);
242 op
->getResult(0).setType(typeConverter
.convertType(resTy
));
243 rewriter
.finalizeOpModification(op
);
// --- func.func: convert the signature and the entry block arguments ---
245 } else if (auto func
= mlir::dyn_cast
<mlir::func::FuncOp
>(op
)) {
246 mlir::FunctionType ty
= func
.getFunctionType();
247 if (typeConverter
.needsConversion(ty
)) {
248 rewriter
.startOpModification(func
);
250 mlir::cast
<mlir::FunctionType
>(typeConverter
.convertType(ty
));
// For each converted input type: insert a fresh argument of that type at
// index i (which shifts the old one to i + 1), redirect the old argument's
// uses to the new one, then erase the old argument.
252 for (auto e
: llvm::enumerate(toTy
.getInputs())) {
253 unsigned i
= e
.index();
254 auto &block
= func
.front();
255 block
.insertArgument(i
, e
.value(), func
.getLoc());
256 block
.getArgument(i
+ 1).replaceAllUsesWith(
257 block
.getArgument(i
));
258 block
.eraseArgument(i
+ 1);
261 rewriter
.finalizeOpModification(func
);
// --- fir.emboxproc ---
263 } else if (auto embox
= mlir::dyn_cast
<EmboxProcOp
>(op
)) {
264 // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
// as required.
266 mlir::Type toTy
= typeConverter
.convertType(
267 mlir::cast
<BoxProcType
>(embox
.getType()).getEleTy());
268 rewriter
.setInsertionPoint(embox
);
// A host link is present: build a trampoline that captures it.
269 if (embox
.getHost()) {
271 auto module
= embox
->getParentOfType
<mlir::ModuleOp
>();
272 FirOpBuilder
builder(rewriter
, module
);
273 const auto triple
{fir::getTargetTriple(module
)};
274 auto loc
= embox
.getLoc();
275 mlir::Type i8Ty
= builder
.getI8Type();
276 mlir::Type i8Ptr
= builder
.getRefType(i8Ty
);
277 // For AArch64, PPC32 and PPC64, the thunk is populated by a call to
278 // __trampoline_setup, which is defined in
279 // compiler-rt/lib/builtins/trampoline_setup.c and requires the
280 // thunk size greater than 32 bytes. For RISCV and x86_64, the
281 // thunk setup doesn't go through __trampoline_setup and fits in 32
// bytes.
283 fir::SequenceType::Extent thunkSize
= triple
.getTrampolineSize();
// Buffer of `thunkSize` i8 elements, created with `fir.alloca`, that serves
// as the trampoline storage.
284 mlir::Type buffTy
= SequenceType::get({thunkSize
}, i8Ty
);
285 auto buffer
= builder
.create
<AllocaOp
>(loc
, buffTy
);
// The host link, the buffer and the function are each converted to a
// reference to i8 before being passed to the trampoline calls.
286 mlir::Value closure
=
287 builder
.createConvert(loc
, i8Ptr
, embox
.getHost());
288 mlir::Value tramp
= builder
.createConvert(loc
, i8Ptr
, buffer
);
290 builder
.createConvert(loc
, i8Ptr
, embox
.getFunc());
// Initialize the trampoline from (buffer, function, host link) ...
291 builder
.create
<fir::CallOp
>(
292 loc
, factory::getLlvmInitTrampoline(builder
),
293 llvm::ArrayRef
<mlir::Value
>{tramp
, func
, closure
});
// ... then obtain the adjusted address; converted to `toTy`, it replaces
// the `fir.emboxproc` result.
294 auto adjustCall
= builder
.create
<fir::CallOp
>(
295 loc
, factory::getLlvmAdjustTrampoline(builder
),
296 llvm::ArrayRef
<mlir::Value
>{tramp
});
297 rewriter
.replaceOpWithNewOp
<ConvertOp
>(embox
, toTy
,
298 adjustCall
.getResult(0));
301 // Just forward the function as a pointer.
302 rewriter
.replaceOpWithNewOp
<ConvertOp
>(embox
, toTy
,
// --- fir.global: update the global's type in place ---
306 } else if (auto global
= mlir::dyn_cast
<GlobalOp
>(op
)) {
307 auto ty
= global
.getType();
308 if (typeConverter
.needsConversion(ty
)) {
309 rewriter
.startOpModification(global
);
310 auto toTy
= typeConverter
.convertType(ty
);
311 global
.setType(toTy
);
312 rewriter
.finalizeOpModification(global
);
// --- fir.alloca: recreate the op with the converted (unwrapped) type,
// carrying over the unique name, bindc name, pinned flag and type
// parameters ---
314 } else if (auto mem
= mlir::dyn_cast
<AllocaOp
>(op
)) {
315 auto ty
= mem
.getType();
316 if (typeConverter
.needsConversion(ty
)) {
317 rewriter
.setInsertionPoint(mem
);
318 auto toTy
= typeConverter
.convertType(unwrapRefType(ty
));
319 bool isPinned
= mem
.getPinned();
// A missing name is passed on as an empty StringRef.
320 llvm::StringRef uniqName
=
321 mem
.getUniqName().value_or(llvm::StringRef());
322 llvm::StringRef bindcName
=
323 mem
.getBindcName().value_or(llvm::StringRef());
324 rewriter
.replaceOpWithNewOp
<AllocaOp
>(
325 mem
, toTy
, uniqName
, bindcName
, isPinned
, mem
.getTypeparams(),
// --- fir.allocmem: same treatment as fir.alloca, without the pinned
// flag ---
329 } else if (auto mem
= mlir::dyn_cast
<AllocMemOp
>(op
)) {
330 auto ty
= mem
.getType();
331 if (typeConverter
.needsConversion(ty
)) {
332 rewriter
.setInsertionPoint(mem
);
333 auto toTy
= typeConverter
.convertType(unwrapRefType(ty
));
334 llvm::StringRef uniqName
=
335 mem
.getUniqName().value_or(llvm::StringRef());
336 llvm::StringRef bindcName
=
337 mem
.getBindcName().value_or(llvm::StringRef());
338 rewriter
.replaceOpWithNewOp
<AllocMemOp
>(
339 mem
, toTy
, uniqName
, bindcName
, mem
.getTypeparams(),
// --- fir.coordinate_of: recreate the op when either its result type or
// its base type needs conversion ---
343 } else if (auto coor
= mlir::dyn_cast
<CoordinateOp
>(op
)) {
344 auto ty
= coor
.getType();
345 mlir::Type baseTy
= coor
.getBaseType();
346 if (typeConverter
.needsConversion(ty
) ||
347 typeConverter
.needsConversion(baseTy
)) {
348 rewriter
.setInsertionPoint(coor
);
349 auto toTy
= typeConverter
.convertType(ty
);
350 auto toBaseTy
= typeConverter
.convertType(baseTy
);
351 rewriter
.replaceOpWithNewOp
<CoordinateOp
>(coor
, toTy
, coor
.getRef(),
352 coor
.getCoor(), toBaseTy
);
// --- fir.field_index: recreate the op when either its result type or the
// type it indexes into needs conversion ---
355 } else if (auto index
= mlir::dyn_cast
<FieldIndexOp
>(op
)) {
356 auto ty
= index
.getType();
357 mlir::Type onTy
= index
.getOnType();
358 if (typeConverter
.needsConversion(ty
) ||
359 typeConverter
.needsConversion(onTy
)) {
360 rewriter
.setInsertionPoint(index
);
361 auto toTy
= typeConverter
.convertType(ty
);
362 auto toOnTy
= typeConverter
.convertType(onTy
);
363 rewriter
.replaceOpWithNewOp
<FieldIndexOp
>(
364 index
, toTy
, index
.getFieldId(), toOnTy
, index
.getTypeparams());
// --- fir.len_param_index: same treatment as fir.field_index ---
367 } else if (auto index
= mlir::dyn_cast
<LenParamIndexOp
>(op
)) {
368 auto ty
= index
.getType();
369 mlir::Type onTy
= index
.getOnType();
370 if (typeConverter
.needsConversion(ty
) ||
371 typeConverter
.needsConversion(onTy
)) {
372 rewriter
.setInsertionPoint(index
);
373 auto toTy
= typeConverter
.convertType(ty
);
374 auto toOnTy
= typeConverter
.convertType(onTy
);
375 rewriter
.replaceOpWithNewOp
<LenParamIndexOp
>(
376 index
, toTy
, index
.getFieldId(), toOnTy
, index
.getTypeparams());
// --- remaining handling: update result types and type attributes in
// place ---
380 rewriter
.startOpModification(op
);
381 // Convert the result types if needed
382 for (auto i
: llvm::enumerate(op
->getResultTypes()))
383 if (typeConverter
.needsConversion(i
.value())) {
384 auto toTy
= typeConverter
.convertType(i
.value());
385 op
->getResult(i
.index()).setType(toTy
);
388 // Convert the type attributes if needed
389 for (const mlir::NamedAttribute
&attr
: op
->getAttrDictionary())
390 if (auto tyAttr
= llvm::dyn_cast
<mlir::TypeAttr
>(attr
.getValue()))
391 if (typeConverter
.needsConversion(tyAttr
.getValue())) {
392 auto toTy
= typeConverter
.convertType(tyAttr
.getValue());
393 op
->setAttr(attr
.getName(), mlir::TypeAttr::get(toTy
));
395 rewriter
.finalizeOpModification(op
);
397 // Ensure block arguments are updated if needed.
// Only done while `opIsValid` still holds and the op has at least one
// region; every block argument of every region is re-typed in place.
398 if (opIsValid
&& op
->getNumRegions() != 0) {
399 rewriter
.startOpModification(op
);
// NOTE(review): `®ion` below looks like a mis-encoded `&region` (the
// inner loop iterates `region.getBlocks()`) - confirm against the original
// source.
400 for (mlir::Region
®ion
: op
->getRegions())
401 for (mlir::Block
&block
: region
.getBlocks())
402 for (mlir::BlockArgument blockArg
: block
.getArguments())
403 if (typeConverter
.needsConversion(blockArg
.getType())) {
405 typeConverter
.convertType(blockArg
.getType());
406 blockArg
.setType(toTy
);
408 rewriter
.finalizeOpModification(op
);
// Pass options; `options.useThunks` guards the rewrite in `runOnOperation`.
415 BoxedProcedureOptions options
;