// [Offload] Fix offload-info interface
// [llvm-project.git] / flang / lib / Optimizer / CodeGen / BoxedProcedure.cpp
// blob 26f4aee21d8bda13553a0979f20e48d7961478ef
1 //===-- BoxedProcedure.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/CodeGen/CodeGen.h"
11 #include "flang/Optimizer/Builder/FIRBuilder.h"
12 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
13 #include "flang/Optimizer/Dialect/FIRDialect.h"
14 #include "flang/Optimizer/Dialect/FIROps.h"
15 #include "flang/Optimizer/Dialect/FIRType.h"
16 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
17 #include "flang/Optimizer/Support/FatalError.h"
18 #include "flang/Optimizer/Support/InternalNames.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/DenseMap.h"
24 namespace fir {
25 #define GEN_PASS_DEF_BOXEDPROCEDUREPASS
26 #include "flang/Optimizer/CodeGen/CGPasses.h.inc"
27 } // namespace fir
29 #define DEBUG_TYPE "flang-procedure-pointer"
31 using namespace fir;
33 namespace {
/// Configuration knobs for the boxed-procedure lowering pass.
struct BoxedProcedureOptions {
  /// When true (the default), `boxproc` values are lowered to plain function
  /// pointers, and thunks are generated wherever a host frame pointer has to
  /// be captured.
  bool useThunks{true};
};
41 /// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types.
42 class BoxprocTypeRewriter : public mlir::TypeConverter {
43 public:
44 using mlir::TypeConverter::convertType;
46 /// Does the type \p ty need to be converted?
47 /// Any type that is a `!fir.boxproc` in whole or in part will need to be
48 /// converted to a function type to lower the IR to function pointer form in
49 /// the default implementation performed in this pass. Other implementations
50 /// are possible, so those may convert `!fir.boxproc` to some other type or
51 /// not at all depending on the implementation target's characteristics and
52 /// preference.
53 bool needsConversion(mlir::Type ty) {
54 if (mlir::isa<BoxProcType>(ty))
55 return true;
56 if (auto funcTy = mlir::dyn_cast<mlir::FunctionType>(ty)) {
57 for (auto t : funcTy.getInputs())
58 if (needsConversion(t))
59 return true;
60 for (auto t : funcTy.getResults())
61 if (needsConversion(t))
62 return true;
63 return false;
65 if (auto tupleTy = mlir::dyn_cast<mlir::TupleType>(ty)) {
66 for (auto t : tupleTy.getTypes())
67 if (needsConversion(t))
68 return true;
69 return false;
71 if (auto recTy = mlir::dyn_cast<RecordType>(ty)) {
72 auto [visited, inserted] = visitedTypes.try_emplace(ty, false);
73 if (!inserted)
74 return visited->second;
75 bool wasAlreadyVisitingRecordType = needConversionIsVisitingRecordType;
76 needConversionIsVisitingRecordType = true;
77 bool result = false;
78 for (auto t : recTy.getTypeList()) {
79 if (needsConversion(t.second)) {
80 result = true;
81 break;
84 // Only keep the result cached if the fir.type visited was a "top-level
85 // type". Nested types with a recursive reference to the "top-level type"
86 // may incorrectly have been resolved as not needed conversions because it
87 // had not been determined yet if the "top-level type" needed conversion.
88 // This is not an issue to determine the "top-level type" need of
89 // conversion, but the result should not be kept and later used in other
90 // contexts.
91 needConversionIsVisitingRecordType = wasAlreadyVisitingRecordType;
92 if (needConversionIsVisitingRecordType)
93 visitedTypes.erase(ty);
94 else
95 visitedTypes.find(ty)->second = result;
96 return result;
98 if (auto boxTy = mlir::dyn_cast<BaseBoxType>(ty))
99 return needsConversion(boxTy.getEleTy());
100 if (isa_ref_type(ty))
101 return needsConversion(unwrapRefType(ty));
102 if (auto t = mlir::dyn_cast<SequenceType>(ty))
103 return needsConversion(unwrapSequenceType(ty));
104 if (auto t = mlir::dyn_cast<TypeDescType>(ty))
105 return needsConversion(t.getOfTy());
106 return false;
109 BoxprocTypeRewriter(mlir::Location location) : loc{location} {
110 addConversion([](mlir::Type ty) { return ty; });
111 addConversion(
112 [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); });
113 addConversion([&](mlir::TupleType tupTy) {
114 llvm::SmallVector<mlir::Type> memTys;
115 for (auto ty : tupTy.getTypes())
116 memTys.push_back(convertType(ty));
117 return mlir::TupleType::get(tupTy.getContext(), memTys);
119 addConversion([&](mlir::FunctionType funcTy) {
120 llvm::SmallVector<mlir::Type> inTys;
121 llvm::SmallVector<mlir::Type> resTys;
122 for (auto ty : funcTy.getInputs())
123 inTys.push_back(convertType(ty));
124 for (auto ty : funcTy.getResults())
125 resTys.push_back(convertType(ty));
126 return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys);
128 addConversion([&](ReferenceType ty) {
129 return ReferenceType::get(convertType(ty.getEleTy()));
131 addConversion([&](PointerType ty) {
132 return PointerType::get(convertType(ty.getEleTy()));
134 addConversion(
135 [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); });
136 addConversion([&](fir::LLVMPointerType ty) {
137 return fir::LLVMPointerType::get(convertType(ty.getEleTy()));
139 addConversion(
140 [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); });
141 addConversion([&](ClassType ty) {
142 return ClassType::get(convertType(ty.getEleTy()));
144 addConversion([&](SequenceType ty) {
145 // TODO: add ty.getLayoutMap() as needed.
146 return SequenceType::get(ty.getShape(), convertType(ty.getEleTy()));
148 addConversion([&](RecordType ty) -> mlir::Type {
149 if (!needsConversion(ty))
150 return ty;
151 if (auto converted = convertedTypes.lookup(ty))
152 return converted;
153 auto rec = RecordType::get(ty.getContext(),
154 ty.getName().str() + boxprocSuffix.str());
155 if (rec.isFinalized())
156 return rec;
157 [[maybe_unused]] auto it = convertedTypes.try_emplace(ty, rec);
158 assert(it.second && "expected ty to not be in the map");
159 std::vector<RecordType::TypePair> ps = ty.getLenParamList();
160 std::vector<RecordType::TypePair> cs;
161 for (auto t : ty.getTypeList()) {
162 if (needsConversion(t.second))
163 cs.emplace_back(t.first, convertType(t.second));
164 else
165 cs.emplace_back(t.first, t.second);
167 rec.finalize(ps, cs);
168 rec.pack(ty.isPacked());
169 return rec;
171 addConversion([&](TypeDescType ty) {
172 return TypeDescType::get(convertType(ty.getOfTy()));
174 addSourceMaterialization(materializeProcedure);
175 addTargetMaterialization(materializeProcedure);
178 static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
179 BoxProcType type,
180 mlir::ValueRange inputs,
181 mlir::Location loc) {
182 assert(inputs.size() == 1);
183 return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
184 inputs[0]);
187 void setLocation(mlir::Location location) { loc = location; }
189 private:
190 // Maps to deal with recursive derived types (avoid infinite loops).
191 // Caching is also beneficial for apps with big types (dozens of
192 // components and or parent types), so the lifetime of the cache
193 // is the whole pass.
194 llvm::DenseMap<mlir::Type, bool> visitedTypes;
195 bool needConversionIsVisitingRecordType = false;
196 llvm::DenseMap<mlir::Type, mlir::Type> convertedTypes;
197 mlir::Location loc;
200 /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically,
201 /// Fortran procedures can be referenced directly through a function pointer.
202 /// However, Fortran has one-level dynamic scoping between a host procedure and
203 /// its internal procedures. This allows internal procedures to directly access
204 /// and modify the state of the host procedure's variables.
206 /// There are any number of possible implementations possible.
208 /// The implementation used here is to convert `boxproc` values to function
209 /// pointers everywhere. If a `boxproc` value includes a frame pointer to the
210 /// host procedure's data, then a thunk will be created at runtime to capture
211 /// the frame pointer during execution. In LLVM IR, the frame pointer is
212 /// designated with the `nest` attribute. The thunk's address will then be used
213 /// as the call target instead of the original function's address directly.
214 class BoxedProcedurePass
215 : public fir::impl::BoxedProcedurePassBase<BoxedProcedurePass> {
216 public:
217 using BoxedProcedurePassBase<BoxedProcedurePass>::BoxedProcedurePassBase;
219 inline mlir::ModuleOp getModule() { return getOperation(); }
221 void runOnOperation() override final {
222 if (options.useThunks) {
223 auto *context = &getContext();
224 mlir::IRRewriter rewriter(context);
225 BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
226 getModule().walk([&](mlir::Operation *op) {
227 bool opIsValid = true;
228 typeConverter.setLocation(op->getLoc());
229 if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
230 mlir::Type ty = addr.getVal().getType();
231 mlir::Type resTy = addr.getResult().getType();
232 if (llvm::isa<mlir::FunctionType>(ty) ||
233 llvm::isa<fir::BoxProcType>(ty)) {
234 // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
235 // or function type to be `fir.convert` ops.
236 rewriter.setInsertionPoint(addr);
237 rewriter.replaceOpWithNewOp<ConvertOp>(
238 addr, typeConverter.convertType(addr.getType()), addr.getVal());
239 opIsValid = false;
240 } else if (typeConverter.needsConversion(resTy)) {
241 rewriter.startOpModification(op);
242 op->getResult(0).setType(typeConverter.convertType(resTy));
243 rewriter.finalizeOpModification(op);
245 } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
246 mlir::FunctionType ty = func.getFunctionType();
247 if (typeConverter.needsConversion(ty)) {
248 rewriter.startOpModification(func);
249 auto toTy =
250 mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty));
251 if (!func.empty())
252 for (auto e : llvm::enumerate(toTy.getInputs())) {
253 unsigned i = e.index();
254 auto &block = func.front();
255 block.insertArgument(i, e.value(), func.getLoc());
256 block.getArgument(i + 1).replaceAllUsesWith(
257 block.getArgument(i));
258 block.eraseArgument(i + 1);
260 func.setType(toTy);
261 rewriter.finalizeOpModification(func);
263 } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
264 // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
265 // as required.
266 mlir::Type toTy = typeConverter.convertType(
267 mlir::cast<BoxProcType>(embox.getType()).getEleTy());
268 rewriter.setInsertionPoint(embox);
269 if (embox.getHost()) {
270 // Create the thunk.
271 auto module = embox->getParentOfType<mlir::ModuleOp>();
272 FirOpBuilder builder(rewriter, module);
273 const auto triple{fir::getTargetTriple(module)};
274 auto loc = embox.getLoc();
275 mlir::Type i8Ty = builder.getI8Type();
276 mlir::Type i8Ptr = builder.getRefType(i8Ty);
277 // For AArch64, PPC32 and PPC64, the thunk is populated by a call to
278 // __trampoline_setup, which is defined in
279 // compiler-rt/lib/builtins/trampoline_setup.c and requires the
280 // thunk size greater than 32 bytes. For RISCV and x86_64, the
281 // thunk setup doesn't go through __trampoline_setup and fits in 32
282 // bytes.
283 fir::SequenceType::Extent thunkSize = triple.getTrampolineSize();
284 mlir::Type buffTy = SequenceType::get({thunkSize}, i8Ty);
285 auto buffer = builder.create<AllocaOp>(loc, buffTy);
286 mlir::Value closure =
287 builder.createConvert(loc, i8Ptr, embox.getHost());
288 mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer);
289 mlir::Value func =
290 builder.createConvert(loc, i8Ptr, embox.getFunc());
291 builder.create<fir::CallOp>(
292 loc, factory::getLlvmInitTrampoline(builder),
293 llvm::ArrayRef<mlir::Value>{tramp, func, closure});
294 auto adjustCall = builder.create<fir::CallOp>(
295 loc, factory::getLlvmAdjustTrampoline(builder),
296 llvm::ArrayRef<mlir::Value>{tramp});
297 rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
298 adjustCall.getResult(0));
299 opIsValid = false;
300 } else {
301 // Just forward the function as a pointer.
302 rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
303 embox.getFunc());
304 opIsValid = false;
306 } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
307 auto ty = global.getType();
308 if (typeConverter.needsConversion(ty)) {
309 rewriter.startOpModification(global);
310 auto toTy = typeConverter.convertType(ty);
311 global.setType(toTy);
312 rewriter.finalizeOpModification(global);
314 } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
315 auto ty = mem.getType();
316 if (typeConverter.needsConversion(ty)) {
317 rewriter.setInsertionPoint(mem);
318 auto toTy = typeConverter.convertType(unwrapRefType(ty));
319 bool isPinned = mem.getPinned();
320 llvm::StringRef uniqName =
321 mem.getUniqName().value_or(llvm::StringRef());
322 llvm::StringRef bindcName =
323 mem.getBindcName().value_or(llvm::StringRef());
324 rewriter.replaceOpWithNewOp<AllocaOp>(
325 mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
326 mem.getShape());
327 opIsValid = false;
329 } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
330 auto ty = mem.getType();
331 if (typeConverter.needsConversion(ty)) {
332 rewriter.setInsertionPoint(mem);
333 auto toTy = typeConverter.convertType(unwrapRefType(ty));
334 llvm::StringRef uniqName =
335 mem.getUniqName().value_or(llvm::StringRef());
336 llvm::StringRef bindcName =
337 mem.getBindcName().value_or(llvm::StringRef());
338 rewriter.replaceOpWithNewOp<AllocMemOp>(
339 mem, toTy, uniqName, bindcName, mem.getTypeparams(),
340 mem.getShape());
341 opIsValid = false;
343 } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
344 auto ty = coor.getType();
345 mlir::Type baseTy = coor.getBaseType();
346 if (typeConverter.needsConversion(ty) ||
347 typeConverter.needsConversion(baseTy)) {
348 rewriter.setInsertionPoint(coor);
349 auto toTy = typeConverter.convertType(ty);
350 auto toBaseTy = typeConverter.convertType(baseTy);
351 rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(),
352 coor.getCoor(), toBaseTy);
353 opIsValid = false;
355 } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
356 auto ty = index.getType();
357 mlir::Type onTy = index.getOnType();
358 if (typeConverter.needsConversion(ty) ||
359 typeConverter.needsConversion(onTy)) {
360 rewriter.setInsertionPoint(index);
361 auto toTy = typeConverter.convertType(ty);
362 auto toOnTy = typeConverter.convertType(onTy);
363 rewriter.replaceOpWithNewOp<FieldIndexOp>(
364 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
365 opIsValid = false;
367 } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
368 auto ty = index.getType();
369 mlir::Type onTy = index.getOnType();
370 if (typeConverter.needsConversion(ty) ||
371 typeConverter.needsConversion(onTy)) {
372 rewriter.setInsertionPoint(index);
373 auto toTy = typeConverter.convertType(ty);
374 auto toOnTy = typeConverter.convertType(onTy);
375 rewriter.replaceOpWithNewOp<LenParamIndexOp>(
376 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
377 opIsValid = false;
379 } else {
380 rewriter.startOpModification(op);
381 // Convert the operands if needed
382 for (auto i : llvm::enumerate(op->getResultTypes()))
383 if (typeConverter.needsConversion(i.value())) {
384 auto toTy = typeConverter.convertType(i.value());
385 op->getResult(i.index()).setType(toTy);
388 // Convert the type attributes if needed
389 for (const mlir::NamedAttribute &attr : op->getAttrDictionary())
390 if (auto tyAttr = llvm::dyn_cast<mlir::TypeAttr>(attr.getValue()))
391 if (typeConverter.needsConversion(tyAttr.getValue())) {
392 auto toTy = typeConverter.convertType(tyAttr.getValue());
393 op->setAttr(attr.getName(), mlir::TypeAttr::get(toTy));
395 rewriter.finalizeOpModification(op);
397 // Ensure block arguments are updated if needed.
398 if (opIsValid && op->getNumRegions() != 0) {
399 rewriter.startOpModification(op);
400 for (mlir::Region &region : op->getRegions())
401 for (mlir::Block &block : region.getBlocks())
402 for (mlir::BlockArgument blockArg : block.getArguments())
403 if (typeConverter.needsConversion(blockArg.getType())) {
404 mlir::Type toTy =
405 typeConverter.convertType(blockArg.getType());
406 blockArg.setType(toTy);
408 rewriter.finalizeOpModification(op);
414 private:
415 BoxedProcedureOptions options;
417 } // namespace